import torch
from PIL import Image
from transformers import (
    AutoModel,
    AutoModelForMaskedLM,
    AutoProcessor,
    AutoTokenizer,
)

from backend.interfaces import BaseEmbeddingModel


class DefaultDenseEmbeddingModel(BaseEmbeddingModel):
    def __init__(self, config):
        super().__init__(config)
        # Dense multimodal (text + image) embedding model; `self.device`
        # is expected to be set by BaseEmbeddingModel.__init__.
        self.model_name = "BAAI/BGE-VL-base"  # or "BAAI/BGE-VL-large"
        self.model = AutoModel.from_pretrained(
            self.model_name, trust_remote_code=True
        ).to(self.device)
        self.preprocessor = AutoProcessor.from_pretrained(
            self.model_name, trust_remote_code=True
        )

    def encode_text(self, texts: list[str]) -> list[list[float]]:
        if not texts:
            return []
        inputs = self.preprocessor(
            text=texts, return_tensors="pt", truncation=True, padding=True
        ).to(self.device)
        # Run in inference mode so no autograd graph is built.
        with torch.no_grad():
            return self.model.get_text_features(**inputs).cpu().tolist()

    def encode_image(
        self, images: list[str] | list[Image.Image]
    ) -> list[list[float]]:
        if not images:
            return []
        # Accept either file paths or already-loaded PIL images.
        if isinstance(images[0], str):
            images = [Image.open(image_path).convert("RGB") for image_path in images]
        inputs = self.preprocessor(images=images, return_tensors="pt").to(self.device)
        with torch.no_grad():
            return self.model.get_image_features(**inputs).cpu().tolist()
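
# A minimal usage sketch. `config` is hypothetical here: it stands in for
# whatever object BaseEmbeddingModel expects (including the target device),
# which is not shown in this file.
#
#     dense = DefaultDenseEmbeddingModel(config)
#     text_vecs = dense.encode_text(["a red bicycle leaning against a wall"])
#     image_vecs = dense.encode_image(["bike.jpg"])
#     # Text and image vectors live in the same space, so cosine similarity
#     # (or dot product) between them scores text-image relevance.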


class DefaultSparseEmbeddingModel(BaseEmbeddingModel):
    def __init__(self, config):
        super().__init__(config)
        # SPLADE-style sparse lexical embedding model.
        self.model_name = "naver/splade-v3"
        self.model = AutoModelForMaskedLM.from_pretrained(self.model_name).to(
            self.device
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

    def encode_text(self, texts: list[str]) -> list[dict]:
        if not texts:
            return []
        tokens = self.tokenizer(
            texts, return_tensors="pt", truncation=True, padding=True
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model(**tokens)
        # SPLADE pooling: w_j = max_i log(1 + ReLU(logit_ij)), where i runs
        # over sequence positions and j over vocabulary terms; the attention
        # mask zeroes out padding positions before the max.
        sparse_embedding = torch.max(
            torch.log(1 + torch.relu(outputs.logits))
            * tokens.attention_mask.unsqueeze(-1),
            dim=1,
        )[0].cpu()
        # Convert each row to the Pinecone sparse-vector format.
        res = []
        for i in range(len(sparse_embedding)):
            # squeeze(-1), not squeeze(): a row with a single nonzero entry
            # must still yield a list, not a scalar.
            indices = sparse_embedding[i].nonzero().squeeze(-1).tolist()
            values = sparse_embedding[i, indices].tolist()
            res.append({"indices": indices, "values": values})
        return res
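
# A minimal usage sketch (hypothetical `config`; the indices and values
# shown below are illustrative, not real model output):
#
#     sparse = DefaultSparseEmbeddingModel(config)
#     vecs = sparse.encode_text(["sparse lexical search with SPLADE"])
#     # vecs[0] -> {"indices": [2015, 17357, ...], "values": [1.42, 0.87, ...]}
#     # Each index is a vocabulary id from the naver/splade-v3 tokenizer and
#     # each value is the learned weight of that term for the input text.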