# FastAPI service that serves MentalBERT sentence embeddings over HTTP.
# (Removed pasted platform log residue: "Spaces:" / "Runtime error".)
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel
import torch
import os

app = FastAPI()

# Load the Hugging Face token from the environment and fail fast if it is
# missing, so the model download below does not die with a less obvious
# authentication error.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    # FIX: the original message started with a mis-encoded glyph ("β",
    # presumably a garbled emoji) — replaced with plain text.
    raise ValueError("Hugging Face API token not found! Set HF_TOKEN as an environment variable.")

# Load tokenizer and model once at startup rather than per request.
# FIX: `use_auth_token=` is deprecated in recent transformers releases;
# `token=` is the supported replacement.
tokenizer = AutoTokenizer.from_pretrained("mental/mental-bert-base-uncased", token=HF_TOKEN)
model = AutoModel.from_pretrained("mental/mental-bert-base-uncased", token=HF_TOKEN)
model.eval()  # inference only: disables dropout and other train-time behavior
# Request body schema | |
class TextRequest(BaseModel):
    """Request body schema for the embedding endpoint.

    Attributes:
        text: The raw input text to embed.
    """
    text: str
# Helper function to compute an embedding for one piece of text.
def compute_embedding(text: str) -> list[float]:
    """Generate a sentence embedding using mean pooling on MentalBERT output.

    Args:
        text: Input sentence/paragraph to embed.

    Returns:
        The mean-pooled last-hidden-state vector as a plain Python list of
        floats (length = the model's hidden size).
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():  # inference only — no gradient bookkeeping needed
        outputs = model(**inputs)
    # Mean-pool over the sequence (token) dimension, then drop the batch
    # dimension of the single-item batch.
    embedding = outputs.last_hidden_state.mean(dim=1).squeeze()
    return embedding.tolist()
# POST endpoint to return embedding
# FIX: the route decorator was missing, so FastAPI never registered this
# handler and the "POST endpoint" comment was a lie. Path "/embed" is
# assumed — TODO confirm against API consumers.
@app.post("/embed")
def get_embedding(request: TextRequest):
    """Return the MentalBERT embedding for the posted text.

    Args:
        request: Parsed request body containing the text to embed.

    Returns:
        A JSON object ``{"embedding": [...]}``.

    Raises:
        HTTPException: 400 if the text is empty/whitespace-only;
            500 if embedding computation fails for any reason.
    """
    text = request.text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Input text cannot be empty.")
    try:
        embedding = compute_embedding(text)
        return {"embedding": embedding}
    except Exception as e:
        # NOTE(review): echoing the internal error to the client is handy in
        # development but leaks internals — consider logging it server-side
        # in production.
        raise HTTPException(status_code=500, detail=f"Error computing embedding: {str(e)}")