Rerandaka commited on
Commit
ff14db7
·
verified ·
1 Parent(s): ebeee49

update for API

Browse files
Files changed (1) hide show
  1. app.py +13 -18
app.py CHANGED
@@ -2,32 +2,27 @@ import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
4
 
 
5
  model_id = "Rerandaka/Cild_safety_bigbird"
6
-
7
  tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
8
  model = AutoModelForSequenceClassification.from_pretrained(model_id)
9
 
10
- label_map = {
11
- 0: "Safe / Normal",
12
- 1: "Inappropriate / Unsafe"
13
- }
14
-
15
- def classify_text(text: str):
16
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=256)
17
  with torch.no_grad():
18
- outputs = model(**inputs)
19
- probs = torch.nn.functional.softmax(outputs.logits, dim=1)
20
- predicted = torch.argmax(probs, dim=1).item()
21
- confidence = probs[0][predicted].item()
22
- return f"{label_map.get(predicted, predicted)} (Confidence: {confidence:.2f})"
23
 
 
24
  demo = gr.Interface(
25
- fn=classify_text,
26
- inputs=gr.Textbox(label="Enter text to classify"),
27
  outputs=gr.Textbox(label="Prediction"),
28
  title="Child-Safety Text Classifier",
29
- description="This model detects unsafe or inappropriate text for children.",
30
- flagging_mode="never"
31
  )
32
 
33
- demo.launch() # 🚫 DO NOT include api_name
 
 
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
4
 
5
+ # Load model
6
  model_id = "Rerandaka/Cild_safety_bigbird"
 
7
  tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
8
  model = AutoModelForSequenceClassification.from_pretrained(model_id)
9
 
10
+ # Inference function
11
+ def classify(text):
12
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
 
 
 
 
13
  with torch.no_grad():
14
+ logits = model(**inputs).logits
15
+ predicted_class = torch.argmax(logits, dim=1).item()
16
+ return str(predicted_class)
 
 
17
 
18
+ # Gradio interface with a named API
19
  demo = gr.Interface(
20
+ fn=classify,
21
+ inputs=gr.Textbox(label="Text Input"),
22
  outputs=gr.Textbox(label="Prediction"),
23
  title="Child-Safety Text Classifier",
24
+ description="This model detects unsafe or inappropriate text for children."
 
25
  )
26
 
27
+ # Launch with API endpoint
28
+ demo.launch(api_name="predict")