from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel
import torch
import os
app = FastAPI()

# Load Hugging Face Token
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("❌ Hugging Face API token not found! Set HF_TOKEN as an environment variable.")

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("mental/mental-bert-base-uncased", token=HF_TOKEN)
model = AutoModel.from_pretrained("mental/mental-bert-base-uncased", token=HF_TOKEN)
model.eval()  # Set model to evaluation mode

# Request body schema
class TextRequest(BaseModel):
    text: str

# Helper function to compute embedding
def compute_embedding(text: str) -> list[float]:
    """Generate a sentence embedding using mean pooling on MentalBERT output."""
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    # Average over the sequence dimension to get one fixed-size vector per text
    embedding = outputs.last_hidden_state.mean(dim=1).squeeze()
    return embedding.tolist()
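
# A minimal sketch of a batched variant (a hypothetical helper, not used by the
# endpoint below): when several texts are padded to the same length, weighting
# the mean by the attention mask keeps pad tokens from diluting the average.
def compute_embeddings_batch(texts: list[str]) -> list[list[float]]:
    inputs = tokenizer(texts, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    mask = inputs["attention_mask"].unsqueeze(-1)           # (batch, seq, 1)
    summed = (outputs.last_hidden_state * mask).sum(dim=1)  # (batch, hidden)
    counts = mask.sum(dim=1).clamp(min=1)                   # avoid divide-by-zero
    return (summed / counts).tolist()
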
# POST endpoint to return embedding
@app.post("/embed")
def get_embedding(request: TextRequest):
    text = request.text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Input text cannot be empty.")
    try:
        embedding = compute_embedding(text)
        return {"embedding": embedding}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error computing embedding: {str(e)}")
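
# To run locally, a typical invocation (assuming uvicorn is installed; port
# 7860 is the Hugging Face Spaces convention and may differ elsewhere):
#   uvicorn app:app --host 0.0.0.0 --port 7860
#
# Example request:
#   curl -X POST http://localhost:7860/embed \
#        -H "Content-Type: application/json" \
#        -d '{"text": "I have been feeling anxious lately."}'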