# embedding-model/app.py
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel
import torch

app = FastAPI(
    title="OpenAI-compatible Embedding API",
    version="1.0.0",
)
# Load the embedding model and tokenizer from the Hugging Face Hub
MODEL_NAME = "BAAI/bge-small-en-v1.5"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
model.eval()


# Request body mirrors the OpenAI /v1/embeddings API shape.
class EmbeddingRequest(BaseModel):
    input: list[str]
    model: str
@app.get("/")
def root():
return {"message": "API is working"}
@app.post("/v1/embeddings")
def create_embeddings(request: EmbeddingRequest):
with torch.no_grad():
tokens = tokenizer(request.input, return_tensors="pt", padding=True, truncation=True)
output = model(**tokens)
cls_embeddings = output.last_hidden_state[:, 0]
norm_embeddings = torch.nn.functional.normalize(cls_embeddings, p=2, dim=1)
data = [
{
"object": "embedding",
"embedding": e.tolist(),
"index": i
}
for i, e in enumerate(norm_embeddings)
]
return {
"object": "list",
"data": data,
"model": request.model,
"usage": {
"prompt_tokens": sum(len(tokenizer.encode(x)) for x in request.input),
"total_tokens": sum(len(tokenizer.encode(x)) for x in request.input),
}
}
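

# Example client call (illustrative sketch, not part of the original file).
# It assumes the app is served locally, e.g. with `uvicorn app:app --port 8000`;
# the host, port, and the `requests` dependency are assumptions here.
#
# import requests
#
# resp = requests.post(
#     "http://localhost:8000/v1/embeddings",
#     json={"input": ["hello world"], "model": "BAAI/bge-small-en-v1.5"},
# )
# embedding = resp.json()["data"][0]["embedding"]
# print(len(embedding))  # 384 dimensions for bge-small-en-v1.5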