# FlinShaHealth / app.py
# mgbam's picture
# Update app.py
# 6918648 verified
import gradio as gr
import tensorflow as tf
from transformers import TFAutoModel, AutoTokenizer
import numpy as np
# Load the tokenizer and the classifier once, at import time.
# The tokenizer is fetched from the Hugging Face checkpoint named below; the
# classifier itself is read from a local Keras file.
MODEL_NAME = "cardiffnlp/twitter-roberta-base-sentiment-latest"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Start from "no model" so a failed load leaves a well-defined sentinel that
# predict_specialist() checks before every prediction.
model = None
try:
    model = tf.keras.models.load_model("model.h5")
except Exception as load_error:
    # Best-effort: the app still starts without a model and reports the
    # problem on each prediction instead of crashing at import.
    print(f"Error loading model: {load_error}")
# Specialist names, in the order predict_specialist() pairs them with the
# columns of the model output (index i <-> predictions[0][i]). Do not reorder.
LABELS = [
    "Cardiologist",
    "Dermatologist",
    "ENT Specialist",
    "Gastroenterologist",
    "General Physicians",
    "Neurologist",
    "Ophthalmologist",
    "Orthopedist",
    "Psychiatrist",
    "Respirologist",
    "Rheumatologist",
    "Surgeon",
]
def preprocess_input(text):
    """Tokenize *text* into the input dict passed to the model.

    The text is truncated/padded to exactly 128 tokens and encoded as
    TensorFlow tensors. The full tokenizer output is printed for debugging.

    Args:
        text: Input accepted by the module-level ``tokenizer``.

    Returns:
        A dict with the keys ``"input_ids"`` and ``"attention_mask"`` taken
        from the tokenizer output.
    """
    encoded = tokenizer(
        text,
        max_length=128,
        truncation=True,
        padding="max_length",
        return_tensors="tf",
    )
    print(f"Tokens: {encoded}")
    # Forward only these two entries, whatever else the tokenizer returns.
    return {key: encoded[key] for key in ("input_ids", "attention_mask")}
def predict_specialist(text):
    """Score a symptom description against every specialist label.

    Args:
        text: Symptom description to classify.

    Returns:
        A dict mapping each entry of ``LABELS`` to the model's score for it
        (as a Python ``float``), or a single-entry ``{"Error": message}``
        dict when the input is missing/blank, the model is not loaded, or
        prediction fails. Those conditions are reported through the return
        value; they are not raised.
    """
    # Reject missing or blank input up front. A blank string tokenizes to
    # pure padding, and whatever the model returns for it would otherwise be
    # presented to the user as a genuine specialist prediction.
    if text is None or (isinstance(text, str) and not text.strip()):
        print("Error during prediction: no symptom description provided.")
        return {"Error": "No symptom description provided."}
    if model is None:
        return {"Error": "Model not loaded."}
    try:
        inputs = preprocess_input(text)
        predictions = model.predict(inputs)
        print(f"Predictions: {predictions}")
        # Only the first output row is read; this assumes a single input
        # per call -- revisit if batching is ever introduced.
        scores = predictions[0]
        return {label: float(scores[i]) for i, label in enumerate(LABELS)}
    except Exception as e:
        # Boundary for the UI layer: log and report the failure as data
        # instead of letting it propagate out of the request handler.
        print(f"Error during prediction: {e}")
        return {"Error": str(e)}
def predict_specialist_ui(text):
    """Adapt ``predict_specialist`` for display in the Gradio interface.

    Args:
        text: Symptom description entered by the user.

    Returns:
        The label-to-score dict from ``predict_specialist``, or a fixed
        generic error string when that result contains an ``"Error"`` key.
    """
    outcome = predict_specialist(text)
    # The detailed error text is not returned to the UI; only this generic
    # notice is.
    return (
        "An error occurred. Check the logs for more details."
        if "Error" in outcome
        else outcome
    )
# Gradio UI
def build_interface():
    """Assemble the Gradio Blocks app and return it without launching.

    The layout is a heading, a textbox for the symptom description, a label
    widget for the result, and a button that runs ``predict_specialist_ui``
    on the textbox content and writes the result to the label widget.

    Returns:
        The constructed ``gr.Blocks`` object.
    """
    with gr.Blocks() as demo:
        gr.Markdown("## Welcome to FlinShaHealth")
        symptoms_box = gr.Textbox(label="Describe your symptoms:")
        result_label = gr.Label(label="Predicted Specialist")
        predict_button = gr.Button("Predict")
        # Wire the event inside the Blocks context so it is registered on
        # this app.
        predict_button.click(
            fn=predict_specialist_ui,
            inputs=symptoms_box,
            outputs=result_label,
        )
    return demo
# Script entry point: build the UI and start the Gradio server only when this
# file is executed directly, not when it is imported.
if __name__ == "__main__":
    app = build_interface()
    app.launch()