koyu008 commited on
Commit
16b2ba7
·
verified ·
1 Parent(s): aaf556c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -47
app.py CHANGED
@@ -1,33 +1,29 @@
1
- from fastapi import FastAPI
2
  from pydantic import BaseModel
3
- from langdetect import detect
4
  import torch
5
  import torch.nn as nn
6
- from transformers import DistilBertModel, AutoModel, AutoTokenizer, DistilBertTokenizer
 
7
  from huggingface_hub import snapshot_download
8
  import os
9
 
10
- # App and device
11
- app = FastAPI()
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
 
14
- # Create safe local cache directory
15
- hf_cache_dir = "./hf_cache"
16
- os.makedirs(hf_cache_dir, exist_ok=True)
17
- os.environ["TRANSFORMERS_CACHE"] = hf_cache_dir
18
 
19
- # Download model repositories to local path
20
- english_path = snapshot_download("koyu008/English_Toxic_Classifier", cache_dir=hf_cache_dir)
21
- hinglish_path = snapshot_download("koyu008/Hinglish_comment_classifier", cache_dir=hf_cache_dir)
22
 
23
- # ----------------------------
24
- # Model classes
25
- # ----------------------------
26
 
 
27
  class ToxicBERT(nn.Module):
28
  def __init__(self):
29
  super().__init__()
30
- self.bert = DistilBertModel.from_pretrained(english_path)
31
  self.dropout = nn.Dropout(0.3)
32
  self.classifier = nn.Linear(self.bert.config.hidden_size, 6)
33
 
@@ -36,10 +32,11 @@ class ToxicBERT(nn.Module):
36
  return self.classifier(self.dropout(output))
37
 
38
 
 
39
  class HinglishToxicClassifier(nn.Module):
40
  def __init__(self):
41
  super().__init__()
42
- self.bert = AutoModel.from_pretrained(hinglish_path)
43
  hidden_size = self.bert.config.hidden_size
44
  self.pool = lambda hidden: torch.cat([
45
  hidden.mean(dim=1),
@@ -58,49 +55,50 @@ class HinglishToxicClassifier(nn.Module):
58
  x = self.bottleneck(pooled)
59
  return self.classifier(x)
60
 
61
- # ----------------------------
62
- # Load Models & Tokenizers
63
- # ----------------------------
64
 
 
65
  english_model = ToxicBERT().to(device)
66
- english_model.load_state_dict(torch.load("bert_toxic_classifier.pt", map_location=device))
67
  english_model.eval()
68
- english_tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
69
 
70
  hinglish_model = HinglishToxicClassifier().to(device)
71
- hinglish_model.load_state_dict(torch.load("best_hinglish_model.pt", map_location=device))
72
  hinglish_model.eval()
73
- hinglish_tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
74
 
75
- # ----------------------------
76
- # API
77
- # ----------------------------
 
 
 
 
78
 
79
- class InputText(BaseModel):
80
  text: str
81
 
 
82
  @app.post("/predict")
83
- async def predict(input: InputText):
84
- text = input.text
85
- lang = detect(text)
 
 
 
86
 
87
  if lang == "en":
88
- inputs = english_tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
 
 
89
  with torch.no_grad():
90
- logits = english_model(**inputs)
91
- probs = torch.softmax(logits, dim=1).cpu().numpy().tolist()[0]
92
- return {
93
- "language": "english",
94
- "classes": ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"],
95
- "probabilities": probs
96
- }
97
  else:
98
- inputs = hinglish_tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
 
 
99
  with torch.no_grad():
100
- logits = hinglish_model(**inputs)
101
- probs = torch.softmax(logits, dim=1).cpu().numpy().tolist()[0]
102
- return {
103
- "language": "hinglish",
104
- "classes": ["toxic", "non-toxic"],
105
- "probabilities": probs
106
- }
 
1
+ from fastapi import FastAPI, Request
2
  from pydantic import BaseModel
 
3
  import torch
4
  import torch.nn as nn
5
+ from transformers import DistilBertTokenizer, DistilBertModel, AutoModel, AutoTokenizer
6
+ from langdetect import detect
7
  from huggingface_hub import snapshot_download
8
  import os
9
 
10
+ # Device
 
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
 
13
+ # Download model repos from HF Hub
14
+ english_repo = snapshot_download("koyu008/English_Toxic_Classifier")
15
+ hinglish_repo = snapshot_download("koyu008/HInglish_comment_classifier")
 
16
 
17
+ # Tokenizers
18
+ english_tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
19
+ hinglish_tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
20
 
 
 
 
21
 
22
+ # English Model
23
  class ToxicBERT(nn.Module):
24
  def __init__(self):
25
  super().__init__()
26
+ self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
27
  self.dropout = nn.Dropout(0.3)
28
  self.classifier = nn.Linear(self.bert.config.hidden_size, 6)
29
 
 
32
  return self.classifier(self.dropout(output))
33
 
34
 
35
+ # Hinglish Model
36
  class HinglishToxicClassifier(nn.Module):
37
  def __init__(self):
38
  super().__init__()
39
+ self.bert = AutoModel.from_pretrained("xlm-roberta-base")
40
  hidden_size = self.bert.config.hidden_size
41
  self.pool = lambda hidden: torch.cat([
42
  hidden.mean(dim=1),
 
55
  x = self.bottleneck(pooled)
56
  return self.classifier(x)
57
 
 
 
 
58
 
59
+ # Instantiate and load models
60
  english_model = ToxicBERT().to(device)
61
+ english_model.load_state_dict(torch.load(os.path.join(english_repo, "bert_toxic_classifier.pt"), map_location=device))
62
  english_model.eval()
 
63
 
64
  hinglish_model = HinglishToxicClassifier().to(device)
65
+ hinglish_model.load_state_dict(torch.load(os.path.join(hinglish_repo, "best_hinglish_model.pt"), map_location=device))
66
  hinglish_model.eval()
 
67
 
68
+ # Labels
69
+ english_labels = ['toxic', 'severe toxic', 'obscene', 'threat', 'insult', 'identity hate']
70
+ hinglish_labels = ['not toxic', 'toxic']
71
+
72
+ # FastAPI
73
+ app = FastAPI()
74
+
75
 
76
+ class TextIn(BaseModel):
77
  text: str
78
 
79
+
80
  @app.post("/predict")
81
+ def predict(data: TextIn):
82
+ text = data.text
83
+ try:
84
+ lang = detect(text)
85
+ except:
86
+ lang = "unknown"
87
 
88
  if lang == "en":
89
+ tokenizer = english_tokenizer
90
+ model = english_model
91
+ inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
92
  with torch.no_grad():
93
+ outputs = model(**inputs)
94
+ probs = torch.sigmoid(outputs).squeeze().cpu().tolist()
95
+ return {"language": "English", "predictions": dict(zip(english_labels, probs))}
96
+
 
 
 
97
  else:
98
+ tokenizer = hinglish_tokenizer
99
+ model = hinglish_model
100
+ inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
101
  with torch.no_grad():
102
+ outputs = model(**inputs)
103
+ probs = torch.softmax(outputs, dim=1).squeeze().cpu().tolist()
104
+ return {"language": "Hinglish", "predictions": dict(zip(hinglish_labels, probs))}