ALIGN-Sim / Models /SentenceTransformersModel.py
yzm0034's picture
Upload folder using huggingface_hub
4f08d2c verified
raw
history blame contribute delete
783 Bytes
import abc
import warnings
from pathlib import Path
from typing import List, Union
import torch
from numpy.typing import NDArray
from sentence_transformers import SentenceTransformer
class SentenceTransformerModels():
def __init__(self, model_id, device: bool = False):
self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model = SentenceTransformer(model_id).eval()
def encode(self, sentences: List[str], batch_size: int = 32) -> NDArray:
with torch.no_grad():
embeddings = self.model.encode(
sentences, batch_size=batch_size, device=self.device
)
if isinstance(embeddings, torch.Tensor):
return embeddings.cpu().numpy()
return embeddings