mgbam committed
Commit 6918648 · verified · 1 Parent(s): e9d76a0

Update app.py

Files changed (1)
app.py  +24  -12
app.py CHANGED
@@ -3,16 +3,16 @@ import tensorflow as tf
 from transformers import TFAutoModel, AutoTokenizer
 import numpy as np
 
-# Load pre-trained model
+# Load model and tokenizer
 MODEL_NAME = "cardiffnlp/twitter-roberta-base-sentiment-latest"
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 
-# Ensure model loads with TensorFlow and compatibility fixes
-model = tf.keras.models.load_model("model.h5", custom_objects={
-    "TFRobertaModel": TFAutoModel.from_pretrained(MODEL_NAME)
-})
+try:
+    model = tf.keras.models.load_model("model.h5")
+except Exception as e:
+    print(f"Error loading model: {e}")
+    model = None
 
-# Labels for predictions
 LABELS = [
     "Cardiologist", "Dermatologist", "ENT Specialist", "Gastroenterologist",
     "General Physicians", "Neurologist", "Ophthalmologist",
@@ -20,16 +20,28 @@ LABELS = [
     "Surgeon"
 ]
 
-# Preprocess input data
 def preprocess_input(text):
     tokens = tokenizer(text, max_length=128, truncation=True, padding="max_length", return_tensors="tf")
+    print(f"Tokens: {tokens}")
     return {"input_ids": tokens["input_ids"], "attention_mask": tokens["attention_mask"]}
 
-# Predict from input text
 def predict_specialist(text):
-    inputs = preprocess_input(text)
-    predictions = model.predict(inputs)
-    return {LABELS[i]: float(predictions[0][i]) for i in range(len(LABELS))}
+    if model is None:
+        return {"Error": "Model not loaded."}
+    try:
+        inputs = preprocess_input(text)
+        predictions = model.predict(inputs)
+        print(f"Predictions: {predictions}")
+        return {LABELS[i]: float(predictions[0][i]) for i in range(len(LABELS))}
+    except Exception as e:
+        print(f"Error during prediction: {e}")
+        return {"Error": str(e)}
+
+def predict_specialist_ui(text):
+    predictions = predict_specialist(text)
+    if "Error" in predictions:
+        return "An error occurred. Check the logs for more details."
+    return predictions
 
 # Gradio UI
 def build_interface():
@@ -38,7 +50,7 @@ def build_interface():
         text_input = gr.Textbox(label="Describe your symptoms:")
         output_label = gr.Label(label="Predicted Specialist")
         submit_btn = gr.Button("Predict")
-        submit_btn.click(predict_specialist, inputs=text_input, outputs=output_label)
+        submit_btn.click(predict_specialist_ui, inputs=text_input, outputs=output_label)
     return demo
 
 if __name__ == "__main__":
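
For a quick local check of the new error-handling path, a minimal sketch (not part of the commit) could call the updated functions directly. The `from app import ...` line, the sample symptom text, and the expected outputs are assumptions; it presumes app.py and model.h5 sit in the working directory and that the launch call stays behind the `__main__` guard:

# Hypothetical smoke test for the updated prediction path in app.py.
# Importing app does not start the Gradio UI because launching is assumed
# to live under the __main__ guard; if model.h5 fails to load, model is
# None and the functions return an error value instead of raising.
from app import predict_specialist, predict_specialist_ui

sample = "I have chest pain and shortness of breath."   # illustrative input
print(predict_specialist(sample))     # dict of specialist -> score, or {"Error": ...}
print(predict_specialist_ui(sample))  # same scores, or a user-facing error message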