service-internal committed
Commit 0191339 · verified · 1 Parent(s): 0a4ba60

Update main.py

Files changed (1)
  1. main.py +46 -60
main.py CHANGED
@@ -1,63 +1,49 @@
- import os
- os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf-cache"
- os.environ["HF_HOME"] = "/tmp/hf-home"
-
- from fastapi import FastAPI, Request
- from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
- from scipy.special import softmax
- import numpy as np
-
- app = FastAPI()
-
- MODEL = "cardiffnlp/twitter-roberta-base-sentiment-latest"
-
- # Load model and tokenizer
- tokenizer = AutoTokenizer.from_pretrained(MODEL)
- config = AutoConfig.from_pretrained(MODEL)
- model = AutoModelForSequenceClassification.from_pretrained(MODEL)
-
- # Preprocessing step for Twitter-style input
- def preprocess(text):
-     tokens = []
-     for t in text.split():
-         if t.startswith("@") and len(t) > 1:
-             t = "@user"
-         elif t.startswith("http"):
-             t = "http"
-         tokens.append(t)
-     return " ".join(tokens)
-
  @app.post("/analyze")
  async def analyze(request: Request):
      data = await request.json()
-     raw_text = data.get("text", "")
-
-     # Logging for debugging
-     print(f"Raw input: {raw_text}")
-
-     if not raw_text.strip():
-         return {"error": "Empty input text."}
-
-     text = preprocess(raw_text)
-     print(f"Preprocessed: {text}")
-
-     encoded_input = tokenizer(text, return_tensors='pt', truncation=True, padding=True)
-     print(f"Encoded input: {encoded_input.input_ids}")
-
-     output = model(**encoded_input)
-     scores = output[0][0].detach().numpy()
-     probs = softmax(scores)
-
-     # Logging output
-     print(f"Raw scores: {scores}")
-     print(f"Softmax probs: {probs}")
-
-     result = [
-         {"label": config.id2label[i], "score": round(float(probs[i]), 4)}
-         for i in probs.argsort()[::-1]
-     ]
-
-     print(f"Result: {result}")
-     return {"result": result}
-
-
+     text = preprocess(data.get("text", ""))
+
+     if not text.strip():
+         return {"error": "Empty input"}
+
+     # Tokenize to check length without truncating
+     tokenized = tokenizer(text, return_tensors='pt', add_special_tokens=True)
+     num_tokens = tokenized.input_ids.shape[1]
+
+     if num_tokens <= 512:
+         # ✅ Use direct inference for short inputs
+         encoded_input = tokenizer(text, return_tensors='pt', truncation=True, padding=True)
+         output = model(**encoded_input)
+         scores = output[0][0].detach().numpy()
+         probs = softmax(scores)
+
+         result = [
+             {"label": config.id2label[i], "score": round(float(probs[i]), 4)}
+             for i in probs.argsort()[::-1]
+         ]
+
+         return {"result": result}
+
+     else:
+         # Long input: split into chunks of ~500 words
+         max_words = 500
+         words = text.split()
+         chunks = [" ".join(words[i:i + max_words]) for i in range(0, len(words), max_words)]
+
+         all_scores = []
+         for chunk in chunks:
+             encoded_input = tokenizer(chunk, return_tensors='pt', truncation=True, padding=True, max_length=512)
+             output = model(**encoded_input)
+             scores = output[0][0].detach().numpy()
+             probs = softmax(scores)
+             all_scores.append(probs)
+
+         # Average softmax scores across chunks
+         avg_scores = np.mean(all_scores, axis=0)
+
+         result = [
+             {"label": config.id2label[i], "score": round(float(avg_scores[i]), 4)}
+             for i in avg_scores.argsort()[::-1]
+         ]
+
+         return {"result": result}
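One detail worth noting in the chunked branch: a 500-word chunk can still tokenize to well over 512 subword tokens, which is presumably why truncation=True with max_length=512 is kept on the per-chunk tokenizer call. A quick standalone check (model name taken from the diff; the sample text is illustrative):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment-latest")

    # 500 words that each split into several subword tokens
    chunk = " ".join(["unbelievably"] * 500)
    ids = tokenizer(chunk, add_special_tokens=True).input_ids
    print(len(ids))  # well over 512 for subword-heavy text, hence the max_length guard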
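For quick manual testing of both branches, a minimal client sketch; the URL and port are assumptions about how the app is served (e.g. uvicorn main:app --port 7860) and are not part of this commit:

    import requests

    URL = "http://localhost:7860/analyze"  # hypothetical local address

    # Short input: handled by the direct-inference branch (<= 512 tokens)
    r = requests.post(URL, json={"text": "@someuser I love this! http://t.co/abc"})
    print(r.json())  # {"result": [{"label": ..., "score": ...}, ...]}

    # Long input: split into ~500-word chunks, per-chunk probabilities averaged
    r = requests.post(URL, json={"text": " ".join(["word"] * 2000)})
    print(r.json())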