Rerandaka committed · verified
Commit 130aa0d · 1 Parent(s): 1555460

fix the error

Files changed (1): app.py +7 -17
app.py CHANGED

@@ -2,18 +2,16 @@ import gradio as gr
 from transformers import AutoTokenizer, AutoModelForSequenceClassification
 import torch
 
-# Load model and tokenizer
-model_id = "Rerandaka/Cild_safety_bigbird"
+model_id = "Rerandaka/child-safety-01"
+
 tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
 model = AutoModelForSequenceClassification.from_pretrained(model_id)
 
-# Class mapping (optional — edit as needed)
 label_map = {
     0: "Safe / Normal",
     1: "Inappropriate / Unsafe"
 }
 
-# Inference function
 def classify_text(text: str):
     inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=256)
     with torch.no_grad():
@@ -21,23 +19,15 @@ def classify_text(text: str):
     probs = torch.nn.functional.softmax(outputs.logits, dim=1)
     predicted = torch.argmax(probs, dim=1).item()
     confidence = probs[0][predicted].item()
-    return {
-        "label": label_map.get(predicted, str(predicted)),
-        "confidence": round(confidence, 4)
-    }
+    return f"{label_map.get(predicted, predicted)} (Confidence: {confidence:.2f})"
 
-# Define Gradio Interface
 demo = gr.Interface(
     fn=classify_text,
     inputs=gr.Textbox(label="Enter text to classify"),
-    outputs=[
-        gr.Textbox(label="Predicted Label"),
-        gr.Textbox(label="Confidence")
-    ],
+    outputs=gr.Textbox(label="Prediction"),
     title="Child-Safety Text Classifier",
-    description="This model detects if text content is unsafe or inappropriate for children.",
-    allow_flagging="never"
+    description="This model detects unsafe or inappropriate text for children.",
+    flagging_mode="never"
 )
 
-# Expose API endpoint explicitly
-demo.launch(api_name="predict")
+demo.launch()  # 🚫 DO NOT include api_name
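Note on the fix: in current Gradio releases, launch() has no api_name parameter, so the old demo.launch(api_name="predict") would fail with a TypeError at startup, which is the likely error this commit addresses. api_name is set on endpoints, not on launch(), and a single-function gr.Interface already exposes its endpoint as "/predict" by default. Likewise, allow_flagging was renamed to flagging_mode on gr.Interface in Gradio 5. A minimal sketch of calling the updated app from outside, assuming it is deployed as a Space under the hypothetical id "Rerandaka/child-safety-01" (substitute the real Space path):

from gradio_client import Client

# Hypothetical Space id, for illustration only; replace with the actual Space.
client = Client("Rerandaka/child-safety-01")

# gr.Interface registers its single endpoint as "/predict" by default,
# which is why no explicit api_name is needed in launch().
result = client.predict("example text to classify", api_name="/predict")
print(result)  # e.g. "Safe / Normal (Confidence: 0.97)"

Since classify_text now returns one formatted string instead of a dict, the single gr.Textbox output matches the return type, which is what the +7/-17 line changes amount to.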