cjell commited on
Commit
89191ca
·
1 Parent(s): 10f7d04
Files changed (3) hide show
  1. app.py +24 -67
  2. test_spam.py +2 -1
  3. test_toxic.py +1 -0
app.py CHANGED
@@ -1,100 +1,57 @@
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
  from transformers import pipeline
4
- from datetime import datetime
5
  import os
6
 
 
7
  os.environ["HF_HOME"] = "/tmp"
8
 
9
  SPAM_MODEL = "valurank/distilroberta-spam-comments-detection"
10
  TOXIC_MODEL = "s-nlp/roberta_toxicity_classifier"
11
- SENTIMENT_MODEL = "nlptown/bert-base-multilingual-uncased-sentiment"
12
  NSFW_MODEL = "michellejieli/NSFW_text_classifier"
13
 
14
- # Load models
15
  spam = pipeline("text-classification", model=SPAM_MODEL)
 
16
  toxic = pipeline("text-classification", model=TOXIC_MODEL)
17
- sentiment = pipeline("text-classification", model=SENTIMENT_MODEL)
18
- nsfw = pipeline("text-classification", model=NSFW_MODEL)
19
 
20
- app = FastAPI(title="Plebzs AI Models API")
21
 
22
- class Query(BaseModel):
23
- text: str
 
 
24
 
25
  @app.get("/")
26
  def root():
27
- return {"status": "ok", "message": "Plebzs AI Models API"}
28
-
29
- # Required by Plebzs boss
30
- @app.get("/moderation/ping")
31
- def moderation_ping():
32
- return {
33
- "status": "healthy",
34
- "models": ["spam", "toxic", "sentiment", "nsfw"],
35
- "timestamp": datetime.now().isoformat(),
36
- "version": "1.0.0"
37
- }
38
 
39
- # Main endpoints - formatted for Plebzs
40
- @app.post("/toxicity") # Changed name to match Plebzs expectation
41
- def predict_toxicity(query: Query):
 
 
 
 
42
  result = toxic(query.text)[0]
43
-
44
- # Convert to 0-1 toxicity scale
45
- toxicity_score = result["score"] if result["label"] == "TOXIC" else 1 - result["score"]
46
-
47
- return {
48
- "toxicity_score": round(toxicity_score, 3),
49
- "confidence": round(result["score"], 3),
50
- "raw_output": result
51
- }
52
 
53
  @app.post("/sentiment")
54
  def predict_sentiment(query: Query):
55
  result = sentiment(query.text)[0]
56
-
57
- # Convert star rating to -1 to 1 scale
58
- label = result["label"]
59
- if "1" in label or "2" in label: # 1-2 stars = negative
60
- sentiment_score = -0.7
61
- elif "3" in label: # 3 stars = neutral
62
- sentiment_score = 0.0
63
- else: # 4-5 stars = positive
64
- sentiment_score = 0.7
65
-
66
- return {
67
- "sentiment_score": round(sentiment_score, 3),
68
- "confidence": round(result["score"], 3),
69
- "raw_output": result
70
- }
71
-
72
- # Bonus endpoints (not used by Plebzs yet, but good to have)
73
- @app.post("/spam")
74
- def predict_spam(query: Query):
75
- result = spam(query.text)[0]
76
- spam_score = result["score"] if result["label"] == "SPAM" else 1 - result["score"]
77
-
78
- return {
79
- "spam_score": round(spam_score, 3),
80
- "confidence": round(result["score"], 3),
81
- "raw_output": result
82
- }
83
 
84
  @app.post("/nsfw")
85
  def predict_nsfw(query: Query):
86
  result = nsfw(query.text)[0]
87
- nsfw_score = result["score"] if result["label"] == "NSFW" else 1 - result["score"]
88
-
89
- return {
90
- "nsfw_score": round(nsfw_score, 3),
91
- "confidence": round(result["score"], 3),
92
- "raw_output": result
93
- }
94
 
95
- # Keep your detailed health check
96
  @app.get("/health")
97
  def health_check():
 
98
  status = {
99
  "server": "running",
100
  "models": {}
@@ -120,4 +77,4 @@ def health_check():
120
  "status": f"error: {str(e)}"
121
  }
122
 
123
- return status
 
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
  from transformers import pipeline
 
4
  import os
5
 
6
+
7
  os.environ["HF_HOME"] = "/tmp"
8
 
9
  SPAM_MODEL = "valurank/distilroberta-spam-comments-detection"
10
  TOXIC_MODEL = "s-nlp/roberta_toxicity_classifier"
11
+ SENTIMENT_MODEL = "nlptown/bert-base-multilingual-uncased-sentiment"
12
  NSFW_MODEL = "michellejieli/NSFW_text_classifier"
13
 
 
14
  spam = pipeline("text-classification", model=SPAM_MODEL)
15
+
16
  toxic = pipeline("text-classification", model=TOXIC_MODEL)
 
 
17
 
18
+ sentiment = pipeline("text-classification", model = SENTIMENT_MODEL)
19
 
20
+ nsfw = pipeline("text-classification", model = NSFW_MODEL)
21
+
22
+
23
+ app = FastAPI()
24
 
25
  @app.get("/")
26
  def root():
27
+ return {"status": "ok"}
28
+
29
+ class Query(BaseModel):
30
+ text: str
 
 
 
 
 
 
 
31
 
32
+ @app.post("/spam")
33
+ def predict_spam(query: Query):
34
+ result = spam(query.text)[0]
35
+ return {"label": result["label"], "score": result["score"]}
36
+
37
+ @app.post("/toxic")
38
+ def predict_toxic(query: Query):
39
  result = toxic(query.text)[0]
40
+ return {"label": result["label"], "score": result["score"]}
 
 
 
 
 
 
 
 
41
 
42
  @app.post("/sentiment")
43
  def predict_sentiment(query: Query):
44
  result = sentiment(query.text)[0]
45
+ return {"label": result["label"], "score": result["score"]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  @app.post("/nsfw")
48
  def predict_nsfw(query: Query):
49
  result = nsfw(query.text)[0]
50
+ return {"label": result["label"], "score": result["score"]}
 
 
 
 
 
 
51
 
 
52
  @app.get("/health")
53
  def health_check():
54
+
55
  status = {
56
  "server": "running",
57
  "models": {}
 
77
  "status": f"error: {str(e)}"
78
  }
79
 
80
+ return status
test_spam.py CHANGED
@@ -5,6 +5,7 @@ url = "https://cjell-Demo.hf.space/spam"
5
  payload = {"text": "Congratulations! You won $1000! Click this link to claim your prize! htts://fakesite.com."}
6
 
7
  response = requests.post(url, json=payload)
 
8
 
9
  print("Status:", response.status_code)
10
  try:
@@ -14,4 +15,4 @@ except Exception:
14
 
15
  print("")
16
 
17
- print(response.text)
 
5
  payload = {"text": "Congratulations! You won $1000! Click this link to claim your prize! htts://fakesite.com."}
6
 
7
  response = requests.post(url, json=payload)
8
+ data = response.json()
9
 
10
  print("Status:", response.status_code)
11
  try:
 
15
 
16
  print("")
17
 
18
+ print(data['raw_output']['label'])
test_toxic.py CHANGED
@@ -5,6 +5,7 @@ url = "https://cjell-Demo.hf.space/toxic"
5
  payload = {"text": "I hate you!"}
6
 
7
  response = requests.post(url, json=payload)
 
8
 
9
  print("Status:", response.status_code)
10
  try:
 
5
  payload = {"text": "I hate you!"}
6
 
7
  response = requests.post(url, json=payload)
8
+ data = response.json()
9
 
10
  print("Status:", response.status_code)
11
  try: