# iimran's picture
# Update app.py
# 7b3102a verified
import torch
import os
import json
from huggingface_hub import hf_hub_download
import gradio as gr
# Hugging Face Hub repository that the model files are downloaded from.
repo_id = "iimran/AnalyserV2"
def download_model_files(repo_id):
    """Download the model artifacts stored in the Hub repository *repo_id*.

    Four files are requested, in this order: ``model_weights.pth``,
    ``vocab.json``, ``label_encoder.json`` and ``config.json``.

    Args:
        repo_id: Identifier of the repository passed to ``hf_hub_download``.

    Returns:
        A 4-tuple ``(model_path, vocab_path, label_encoder_path, config_path)``
        holding the value ``hf_hub_download`` returned for each file.
    """
    artifact_names = (
        "model_weights.pth",
        "vocab.json",
        "label_encoder.json",
        "config.json",
    )
    # The tuple order above is the order callers unpack the result in.
    return tuple(
        hf_hub_download(repo_id=repo_id, filename=artifact)
        for artifact in artifact_names
    )
def get_transformer_model_class():
    """Build the ``TransformerModel`` class from the ``MODEL_VAR`` environment variable.

    ``MODEL_VAR`` must contain Python source that binds a top-level name
    ``TransformerModel``. The source is executed in this module's global
    namespace, so every name it defines becomes a module global.

    Returns:
        The object bound to ``TransformerModel`` by the executed source.

    Raises:
        ValueError: If ``MODEL_VAR`` is not set.
        NameError: If the executed source does not define ``TransformerModel``.
    """
    model_code = os.getenv("MODEL_VAR")
    if model_code is None:
        raise ValueError("Environment variable 'MODEL_VAR' is not set.")

    namespace = globals()
    missing = object()
    # Set any earlier definition aside so the check below reflects what THIS
    # execution defined, rather than being satisfied by a stale global left
    # behind by a previous call.
    previous = namespace.pop("TransformerModel", missing)
    try:
        # SECURITY: this runs arbitrary Python taken from the environment.
        # MODEL_VAR must only ever be populated from a trusted source.
        exec(model_code, namespace)
        if "TransformerModel" not in namespace:
            raise NameError("The TransformerModel class was not defined after executing MODEL_VAR.")
    except BaseException:
        # On failure, put the earlier definition back so the module is left
        # as it was before this call.
        if previous is not missing and "TransformerModel" not in namespace:
            namespace["TransformerModel"] = previous
        raise
    return namespace["TransformerModel"]
def get_preprocess_function():
    """Build the ``preprocess_text`` function from the ``MODEL_PROCESS`` environment variable.

    ``MODEL_PROCESS`` must contain Python source that binds a top-level name
    ``preprocess_text``. The source is executed in this module's global
    namespace, so every name it defines becomes a module global.

    Returns:
        The object bound to ``preprocess_text`` by the executed source.

    Raises:
        ValueError: If ``MODEL_PROCESS`` is not set.
        NameError: If the executed source does not define ``preprocess_text``.
    """
    preprocess_code = os.getenv("MODEL_PROCESS")
    if preprocess_code is None:
        raise ValueError("Environment variable 'MODEL_PROCESS' is not set.")

    namespace = globals()
    missing = object()
    # A module global called ``preprocess_text`` may already exist (from an
    # earlier call or a module-level assignment). Set it aside so the check
    # below reflects what THIS execution defined instead of passing vacuously.
    previous = namespace.pop("preprocess_text", missing)
    try:
        # SECURITY: this runs arbitrary Python taken from the environment.
        # MODEL_PROCESS must only ever be populated from a trusted source.
        exec(preprocess_code, namespace)
        if "preprocess_text" not in namespace:
            raise NameError("The preprocess_text function was not defined after executing MODEL_PROCESS.")
    except BaseException:
        # On failure, put the earlier definition back so existing callers of
        # the module-level ``preprocess_text`` keep working.
        if previous is not missing and "preprocess_text" not in namespace:
            namespace["preprocess_text"] = previous
        raise
    return namespace["preprocess_text"]
def _read_json_resource(path, description):
    """Parse the JSON file at *path*; *description* names the file in error messages."""
    try:
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError as err:
        raise FileNotFoundError(
            f"{description.capitalize()} file not found at {path}. Please check the repository."
        ) from err
    except json.JSONDecodeError as err:
        raise ValueError(f"Invalid JSON format in {description} file at {path}.") from err


def load_model_and_resources(repo_id):
    """Download the model files from *repo_id* and return a ready-to-use model.

    The vocabulary and label-encoder JSON files are parsed, a
    ``TransformerModel`` is constructed with ``vocab_size=len(vocab)`` and
    ``num_classes=len(label_encoder_classes)``, the saved weights are loaded
    onto the CPU, and the model is switched to evaluation mode.

    Args:
        repo_id: Identifier of the repository passed to ``download_model_files``.

    Returns:
        A 3-tuple ``(model, vocab, label_encoder_classes)``.

    Raises:
        FileNotFoundError: If the vocabulary or label-encoder file is missing.
        ValueError: If either JSON file cannot be parsed.
    """
    # config_path is downloaded along with the others but is not read here.
    model_path, vocab_path, label_encoder_path, config_path = download_model_files(repo_id)

    vocab = _read_json_resource(vocab_path, "vocabulary")
    label_encoder_classes = _read_json_resource(label_encoder_path, "label encoder")

    TransformerModel = get_transformer_model_class()
    model = TransformerModel(vocab_size=len(vocab), num_classes=len(label_encoder_classes))

    # Weights are always mapped to the CPU; change the device here to use a GPU.
    # NOTE(review): depending on the installed torch version, torch.load may
    # unpickle arbitrary objects. Consider weights_only=True once the minimum
    # supported torch version is confirmed.
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()
    return model, vocab, label_encoder_classes
# Resolved once at import time; predict() looks this name up as a module
# global. Importing the module raises if MODEL_PROCESS is not set.
preprocess_text = get_preprocess_function()
def predict(text, model, vocab, label_encoder_classes):
    """Classify *text* and return the corresponding label.

    The text is turned into model inputs with the module-level
    ``preprocess_text``, the model is run with gradient tracking disabled,
    and the index of the highest-scoring class along dimension 1 is used
    to look up the label. Inputs and outputs are printed for debugging.

    Args:
        text: The text to classify.
        model: Callable invoked as ``model(input_ids, attention_mask)``.
        vocab: Vocabulary forwarded to ``preprocess_text``.
        label_encoder_classes: Labels, indexed by predicted class index.

    Returns:
        The entry of *label_encoder_classes* for the predicted class.

    Raises:
        ValueError: If the model returns ``None``.
    """
    token_ids, mask = preprocess_text(text, vocab)
    print("Input IDs:", token_ids)
    print("Attention Mask:", mask)

    with torch.no_grad():
        model_output = model(token_ids, mask)
    print("Model Outputs:", model_output)  # Debug: inspect the raw model output

    if model_output is None:
        raise ValueError("Model returned None. Check the forward method and input data.")

    return label_encoder_classes[model_output.argmax(1).item()]
def create_gradio_interface():
    """Load the model for the module-level ``repo_id`` and build the Gradio UI.

    Returns:
        A ``gr.Interface`` with one text input and one text output, whose
        handler classifies the entered text with ``predict``.
    """
    model, vocab, label_encoder_classes = load_model_and_resources(repo_id)

    def predict_wrapper(text):
        # Bind the loaded resources so the UI handler only takes the text.
        return predict(text, model, vocab, label_encoder_classes)

    sample_texts = (
        "I would like to bring to your attention a pothole on Main Street that has become a safety hazard. The pothole is quite deep and poses a risk to both drivers and pedestrians. I kindly request the council to inspect and repair it at the earliest to prevent any potential accidents or vehicle damage. Please let me know if any further information is required.",
        "I am writing to report a clogged drainage system in 1 tonsley. The blockage is causing water to accumulate, leading to potential flooding and sanitation issues. This situation poses a risk to public health and safety, especially during rainfall. I kindly request the council to inspect and resolve this issue at the earliest convenience.",
        "I am writing to report a persistent issue of loud noise coming from my neighbors at 1 tonsley. The noise, which occurs through out the day, has been causing significant disturbance to me and other residents in the area.",
    )

    # Components are created input-first, then output.
    input_box = gr.Textbox(lines=2, placeholder="Enter text here...")
    output_box = gr.Textbox(label="Predicted Label")

    return gr.Interface(
        fn=predict_wrapper,
        inputs=input_box,
        outputs=output_box,
        title="Text Classification Model",
        description="Enter text to classify it using the model.",
        examples=[[sample] for sample in sample_texts],  # one single-column row per example
        cache_examples=False,  # examples are not precomputed
    )
if __name__ == "__main__":
    # Script entry point: build the interface and start serving it.
    create_gradio_interface().launch()