service-internal committed
Commit 3683686 · verified · 1 Parent(s): 0191339

Update main.py

Files changed (1): main.py +39 -29
main.py CHANGED
@@ -1,3 +1,30 @@
+import os
+os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf-cache"
+os.environ["HF_HOME"] = "/tmp/hf-home"
+
+from fastapi import FastAPI, Request
+from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
+from scipy.special import softmax
+import numpy as np
+
+# ✅ Define app BEFORE any @app.route
+app = FastAPI()
+
+MODEL = "cardiffnlp/twitter-roberta-base-sentiment-latest"
+tokenizer = AutoTokenizer.from_pretrained(MODEL)
+config = AutoConfig.from_pretrained(MODEL)
+model = AutoModelForSequenceClassification.from_pretrained(MODEL)
+
+def preprocess(text):
+    tokens = []
+    for t in text.split():
+        if t.startswith("@") and len(t) > 1:
+            t = "@user"
+        elif t.startswith("http"):
+            t = "http"
+        tokens.append(t)
+    return " ".join(tokens)
+
 @app.post("/analyze")
 async def analyze(request: Request):
     data = await request.json()
@@ -6,44 +33,27 @@ async def analyze(request: Request):
     if not text.strip():
         return {"error": "Empty input"}
 
-    # Tokenize to check length without truncating
+    # Token length check
     tokenized = tokenizer(text, return_tensors='pt', add_special_tokens=True)
-    num_tokens = tokenized.input_ids.shape[1]
-
-    if num_tokens <= 512:
-        # ✅ Use direct inference for short inputs
+    if tokenized.input_ids.shape[1] <= 512:
         encoded_input = tokenizer(text, return_tensors='pt', truncation=True, padding=True)
         output = model(**encoded_input)
-        scores = output[0][0].detach().numpy()
-        probs = softmax(scores)
-
-        result = [
-            {"label": config.id2label[i], "score": round(float(probs[i]), 4)}
-            for i in probs.argsort()[::-1]
-        ]
-
-        return {"result": result}
-
+        probs = softmax(output[0][0].detach().numpy())
     else:
-        # ✅ Long input: Split into chunks of ~500 words
         max_words = 500
         words = text.split()
         chunks = [" ".join(words[i:i + max_words]) for i in range(0, len(words), max_words)]
-
-        all_scores = []
+        all_probs = []
        for chunk in chunks:
             encoded_input = tokenizer(chunk, return_tensors='pt', truncation=True, padding=True, max_length=512)
             output = model(**encoded_input)
-            scores = output[0][0].detach().numpy()
-            probs = softmax(scores)
-            all_scores.append(probs)
-
-        # Average softmax scores
-        avg_scores = np.mean(all_scores, axis=0)
+            probs_chunk = softmax(output[0][0].detach().numpy())
+            all_probs.append(probs_chunk)
+        probs = np.mean(all_probs, axis=0)
 
-        result = [
-            {"label": config.id2label[i], "score": round(float(avg_scores[i]), 4)}
-            for i in avg_scores.argsort()[::-1]
-        ]
+    result = [
+        {"label": config.id2label[i], "score": round(float(probs[i]), 4)}
+        for i in probs.argsort()[::-1]
+    ]
+    return {"result": result}
 
-        return {"result": result}
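
The commit also introduces a `preprocess` helper that masks user handles and URLs, the normalization the cardiffnlp model cards recommend for this model family; note that, as committed, nothing in `analyze` calls it yet. A quick illustration of its behavior on a hypothetical input:

    >>> preprocess("@alice loved the demo at https://example.com today")
    '@user loved the demo at http today'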
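For inputs over 512 tokens, the updated handler averages per-chunk softmax distributions instead of raw logits. Since each chunk's probability vector sums to 1, the mean is itself a valid distribution. A toy check with made-up numbers (not real model output):

    import numpy as np

    all_probs = [
        np.array([0.10, 0.30, 0.60]),  # chunk 1: negative / neutral / positive
        np.array([0.20, 0.20, 0.60]),  # chunk 2
    ]
    probs = np.mean(all_probs, axis=0)
    print(probs)        # [0.15 0.25 0.6 ]
    print(probs.sum())  # 1.0

The `max_length=512, truncation=True` arguments in the chunk loop still matter: 500 words can tokenize to more than 512 subword tokens, so each chunk is clipped to the model's limit.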
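Finally, a minimal client sketch for exercising the endpoint. The host, port, and the "text" request key are assumptions: the lines that read the field out of `data` (old lines 4-5 / new lines 31-32) fall outside the diff context shown above.

    import requests

    resp = requests.post(
        "http://localhost:8000/analyze",  # assumed host/port
        json={"text": "Great model, works even on very long reviews!"},
    )
    print(resp.json())
    # e.g. {"result": [{"label": "positive", "score": 0.98}, ...]}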