import os
# Cache directories must be set before transformers is imported,
# or the library falls back to its default cache locations.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf-cache"
os.environ["HF_HOME"] = "/tmp/hf-home"
from fastapi import FastAPI, Request
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
from scipy.special import softmax
import numpy as np
# ✅ Define app BEFORE any route decorators reference it
app = FastAPI()
MODEL = "cardiffnlp/twitter-roberta-base-sentiment-latest"
tokenizer = AutoTokenizer.from_pretrained(MODEL)
config = AutoConfig.from_pretrained(MODEL)
model = AutoModelForSequenceClassification.from_pretrained(MODEL)
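# Note: the /analyze handler below relies on config.label2id mapping the
# label strings "negative"/"neutral"/"positive" to output indices, which
# is how this checkpoint's config is set up.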
def preprocess(text):
    # Replace user mentions and URLs with placeholder tokens, the
    # preprocessing convention used by the cardiffnlp Twitter models.
    tokens = []
    for t in text.split():
        if t.startswith("@") and len(t) > 1:
            t = "@user"
        elif t.startswith("http"):
            t = "http"
        tokens.append(t)
    return " ".join(tokens)
@app.post("/analyze")
async def analyze(request: Request):
    data = await request.json()
    text = preprocess(data.get("text", ""))
    if not text.strip():
        return {"error": "Empty input"}
    # Tokenize once to decide between single-pass and chunked inference.
    encoded_input = tokenizer(text, return_tensors="pt", add_special_tokens=True)
    if encoded_input.input_ids.shape[1] <= 512:
        output = model(**encoded_input)
        probs = softmax(output.logits[0].detach().numpy())
    else:
        # Long input: split into word chunks, score each chunk, and
        # average the class probabilities across chunks.
        max_words = 500
        words = text.split()
        chunks = [" ".join(words[i:i + max_words]) for i in range(0, len(words), max_words)]
        all_probs = []
        for chunk in chunks:
            encoded_input = tokenizer(chunk, return_tensors="pt", truncation=True, max_length=512)
            output = model(**encoded_input)
            all_probs.append(softmax(output.logits[0].detach().numpy()))
        probs = np.mean(all_probs, axis=0)
    # Report scores in a fixed label order rather than the model's id order.
    fixed_order = ["negative", "neutral", "positive"]
    result = [
        {"label": label, "score": round(float(probs[config.label2id[label]]), 4)}
        for label in fixed_order
    ]
    return {"result": result}