koyu008 committed on
Commit 3386fb5 · verified
1 Parent(s): f3eb85c

Update app.py

Files changed (1): app.py (+34, -24)
app.py CHANGED
@@ -6,6 +6,7 @@ from transformers import DistilBertTokenizer, DistilBertModel, AutoModel, AutoTo
 from langdetect import detect
 from huggingface_hub import snapshot_download
 import os
+from typing import List
 
 # Device
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -84,34 +85,43 @@ app.add_middleware(
 
 
 class TextIn(BaseModel):
-    text: str
+    texts: List[str]
 
 
 @app.post("/api/predict")
 def predict(data: TextIn):
-    text = data.text
-    try:
-        lang = detect(text)
-    except:
-        lang = "unknown"
-
-    if lang == "en":
-        tokenizer = english_tokenizer
-        model = english_model
-        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
-        with torch.no_grad():
-            outputs = model(**inputs)
-        probs = torch.sigmoid(outputs).squeeze().cpu().tolist()
-        return {"language": "English", "predictions": dict(zip(english_labels, probs))}
-
-    else:
-        tokenizer = hinglish_tokenizer
-        model = hinglish_model
-        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
-        with torch.no_grad():
-            outputs = model(**inputs)
-        probs = torch.softmax(outputs, dim=1).squeeze().cpu().tolist()
-        return {"language": "Hinglish", "predictions": dict(zip(hinglish_labels, probs))}
+    results = []
+
+    for text in data.texts:
+        try:
+            lang = detect(text)
+        except:
+            lang = "unknown"
+
+        if lang == "en":
+            tokenizer = english_tokenizer
+            model = english_model
+            inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
+            with torch.no_grad():
+                outputs = model(**inputs)
+            probs = torch.sigmoid(outputs).squeeze().cpu().tolist()
+            predictions = dict(zip(english_labels, probs))
+        else:
+            tokenizer = hinglish_tokenizer
+            model = hinglish_model
+            inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
+            with torch.no_grad():
+                outputs = model(**inputs)
+            probs = torch.softmax(outputs, dim=1).squeeze().cpu().tolist()
+            predictions = dict(zip(hinglish_labels, probs))
+
+        results.append({
+            "text": text,
+            "language": lang if lang in ["en", "hi"] else "unknown",
+            "predictions": predictions
+        })
+
+    return {"results": results}
 
 @app.get("/")
 def root():
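
Note: the endpoint contract changes here from a single "text" string to a "texts" list, and the response changes from one prediction object to a {"results": [...]} envelope. Below is a minimal client sketch against the new contract, assuming a server reachable at a placeholder base URL (only the /api/predict path and the "texts"/"results" fields come from this diff; the host and example sentences are illustrative).

# Hypothetical client call for the updated batched /api/predict endpoint.
import requests

BASE_URL = "http://localhost:7860"  # assumption: wherever the Space is served

# The new schema expects a list of strings under "texts".
payload = {"texts": ["This is a test sentence.", "Yeh ek test sentence hai."]}

resp = requests.post(f"{BASE_URL}/api/predict", json=payload)
resp.raise_for_status()

# The response now wraps one entry per input under "results",
# each carrying the original text, a language tag, and label probabilities.
for item in resp.json()["results"]:
    print(item["text"], item["language"], item["predictions"])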