aaryan24 committed
Commit b18f4a0 · verified · 1 Parent(s): 0cc2ad1

Update app.py

Files changed (1)
  1. app.py +112 -103
app.py CHANGED
@@ -1,103 +1,112 @@
- from fastapi import FastAPI, Request, HTTPException
- from fastapi.responses import HTMLResponse
- from fastapi.staticfiles import StaticFiles
- from fastapi.middleware.cors import CORSMiddleware
- from pydantic import BaseModel
- import tensorflow as tf
- import pickle
- from tensorflow.keras.preprocessing.sequence import pad_sequences
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
- import torch
- from fastapi.responses import JSONResponse
- # Initialize FastAPI
- app = FastAPI()
-
- # Load GRU model and tokenizer
- gru_model = tf.keras.models.load_model('hs_gru.h5')
- with open('tokenizerpkl_gru.pkl', 'rb') as f:
-     gru_tokenizer = pickle.load(f)
- gru_maxlen = 100
-
- # Load RoBERTa model
- # Load RoBERTa model
- roberta_model_name = "facebook/roberta-hate-speech-dynabench-r4-target"
- roberta_tokenizer = AutoTokenizer.from_pretrained(roberta_model_name)
- if roberta_tokenizer.pad_token is None:
-     roberta_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
- roberta_model = AutoModelForSequenceClassification.from_pretrained(roberta_model_name)
- roberta_model.resize_token_embeddings(len(roberta_tokenizer))
-
- #load toxigen-hatebert model
- toxigen_model_name = "tomh/toxigen_roberta"
- toxigen_tokenizer = AutoTokenizer.from_pretrained(toxigen_model_name)
- if toxigen_tokenizer.pad_token is None:
-     toxigen_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
- toxigen_model = AutoModelForSequenceClassification.from_pretrained(toxigen_model_name)
- toxigen_model.resize_token_embeddings(len(toxigen_tokenizer))
-
- # Enable CORS
- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["*"],
-     allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
- )
-
- # Mount static directory
- # app.mount("/static", StaticFiles(directory="static"), name="static")
-
- # Pydantic input model
- class TextInput(BaseModel):
-     text: str
-
- @app.get("/", response_class=HTMLResponse)
- def read_root():
-     with open("index.html", "r") as f:
-         return f.read()
-
- @app.get("/health")
- def health_check():
-     return {"message": "Hate Speech Detection API is running!"}
-
- @app.post("/predict")
- def predict_ensemble(input: TextInput):
-     try:
-         text = input.text
-         # print(f"Received input: {input.text}")
-
-         # ----- GRU Prediction -----
-         seq = gru_tokenizer.texts_to_sequences([text])
-         padded = pad_sequences(seq, maxlen=gru_maxlen, padding='post')
-         gru_prob = float(gru_model.predict(padded)[0][0])
-
-         # ----- RoBERTa Prediction -----
-         inputs_roberta = roberta_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
-         with torch.no_grad():
-             logits_roberta = roberta_model(**inputs_roberta).logits
-         probs_roberta = torch.nn.functional.softmax(logits_roberta, dim=1)
-         roberta_prob = float(probs_roberta[0][1].item())
-
-         # -----toxigen -hatebert Prediction -----
-         inputs_toxigen = toxigen_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
-         with torch.no_grad():
-             logits_toxigen = toxigen_model(**inputs_toxigen).logits
-         probs_toxigen = torch.nn.functional.softmax(logits_toxigen, dim=1)
-         toxigen_prob = float(probs_toxigen[0][1].item())
-
-         # ----- Weighted Ensemble -----
-         final_score = (0.3 * gru_prob) + (0.4 * roberta_prob) + (0.3 * toxigen_prob)
-         label = "Hate Speech" if final_score > 0.5 else "Not Hate Speech"
-
-         return {
-             # "text": text,
-             "gru_prob": round(gru_prob, 4),
-             "roberta_prob": round(roberta_prob, 4),
-             "toxigen_prob": round(toxigen_prob, 4),
-             "final_score": round(final_score, 4),
-             "prediction": label
-         }
-
-     except Exception as e:
-         print(f"Error during prediction: {str(e)}")
-         return JSONResponse(status_code=500, content={"detail": str(e)})
+ from fastapi import FastAPI, Request, HTTPException
+ from fastapi.responses import HTMLResponse
+ from fastapi.staticfiles import StaticFiles
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel
+ import tensorflow as tf
+ import pickle
+ from tensorflow.keras.preprocessing.sequence import pad_sequences
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+ import torch
+ import os
+ from fastapi.responses import JSONResponse
+ # Initialize FastAPI
+ app = FastAPI()
+
+
+ cache = "/app/hf_cache"
+ os.makedirs(cache, exist_ok=True)
+ os.environ["HF_HOME"] = cache
+ os.environ["TRANSFORMERS_CACHE"] = cache
+ os.environ["XDG_CACHE_HOME"] = cache
+
+ from transformers import AutoTokenizer
+ # Load GRU model and tokenizer
+ gru_model = tf.keras.models.load_model('hs_gru.h5')
+ with open('tokenizerpkl_gru.pkl', 'rb') as f:
+     gru_tokenizer = pickle.load(f)
+ gru_maxlen = 100
+
+ # Load RoBERTa model
+ # Load RoBERTa model
+ roberta_model_name = "facebook/roberta-hate-speech-dynabench-r4-target"
+ roberta_tokenizer = AutoTokenizer.from_pretrained(roberta_model_name)
+ if roberta_tokenizer.pad_token is None:
+     roberta_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+ roberta_model = AutoModelForSequenceClassification.from_pretrained(roberta_model_name)
+ roberta_model.resize_token_embeddings(len(roberta_tokenizer))
+
+ #load toxigen-hatebert model
+ toxigen_model_name = "tomh/toxigen_roberta"
+ toxigen_tokenizer = AutoTokenizer.from_pretrained(toxigen_model_name)
+ if toxigen_tokenizer.pad_token is None:
+     toxigen_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+ toxigen_model = AutoModelForSequenceClassification.from_pretrained(toxigen_model_name)
+ toxigen_model.resize_token_embeddings(len(toxigen_tokenizer))
+
+ # Enable CORS
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Mount static directory
+ # app.mount("/static", StaticFiles(directory="static"), name="static")
+
+ # Pydantic input model
+ class TextInput(BaseModel):
+     text: str
+
+ @app.get("/", response_class=HTMLResponse)
+ def read_root():
+     with open("index.html", "r") as f:
+         return f.read()
+
+ @app.get("/health")
+ def health_check():
+     return {"message": "Hate Speech Detection API is running!"}
+
+ @app.post("/predict")
+ def predict_ensemble(input: TextInput):
+     try:
+         text = input.text
+         # print(f"Received input: {input.text}")
+
+         # ----- GRU Prediction -----
+         seq = gru_tokenizer.texts_to_sequences([text])
+         padded = pad_sequences(seq, maxlen=gru_maxlen, padding='post')
+         gru_prob = float(gru_model.predict(padded)[0][0])
+
+         # ----- RoBERTa Prediction -----
+         inputs_roberta = roberta_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
+         with torch.no_grad():
+             logits_roberta = roberta_model(**inputs_roberta).logits
+         probs_roberta = torch.nn.functional.softmax(logits_roberta, dim=1)
+         roberta_prob = float(probs_roberta[0][1].item())
+
+         # -----toxigen -hatebert Prediction -----
+         inputs_toxigen = toxigen_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
+         with torch.no_grad():
+             logits_toxigen = toxigen_model(**inputs_toxigen).logits
+         probs_toxigen = torch.nn.functional.softmax(logits_toxigen, dim=1)
+         toxigen_prob = float(probs_toxigen[0][1].item())
+
+         # ----- Weighted Ensemble -----
+         final_score = (0.3 * gru_prob) + (0.4 * roberta_prob) + (0.3 * toxigen_prob)
+         label = "Hate Speech" if final_score > 0.5 else "Not Hate Speech"
+
+         return {
+             # "text": text,
+             "gru_prob": round(gru_prob, 4),
+             "roberta_prob": round(roberta_prob, 4),
+             "toxigen_prob": round(toxigen_prob, 4),
+             "final_score": round(final_score, 4),
+             "prediction": label
+         }
+
+     except Exception as e:
+         print(f"Error during prediction: {str(e)}")
+         return JSONResponse(status_code=500, content={"detail": str(e)})
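
The added lines in this commit route the Hugging Face cache to a writable /app/hf_cache directory (via HF_HOME, TRANSFORMERS_CACHE, and XDG_CACHE_HOME) before the pretrained models are downloaded, which is a common workaround when the default cache path is not writable inside a container. As a minimal sketch of exercising the updated /predict endpoint from a client: the base URL and port 7860 below are assumptions (they are not part of the commit), and the response keys mirror the dictionary returned by predict_ensemble above.

import requests  # client-side dependency for this sketch; not used in app.py

# Assumed base URL; adjust to wherever this FastAPI app is actually served.
API_URL = "http://localhost:7860"

# POST a JSON body matching the TextInput schema: {"text": "..."}
payload = {"text": "example sentence to classify"}
resp = requests.post(f"{API_URL}/predict", json=payload, timeout=60)
resp.raise_for_status()

result = resp.json()
# Keys returned by the handler: gru_prob, roberta_prob, toxigen_prob,
# final_score, and prediction ("Hate Speech" or "Not Hate Speech").
print(result["prediction"], result["final_score"])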