# NOTE: Hugging Face Spaces page chrome ("Spaces: Sleeping") was captured here
# by the extraction process; it is not part of the program source.
import torch | |
import os | |
import json | |
from huggingface_hub import hf_hub_download | |
import gradio as gr | |
# Hugging Face Hub repository that hosts the trained model artifacts.
repo_id = "iimran/AnalyserV2"
def download_model_files(repo_id):
    """Download all model artifacts from the Hub and return their local paths.

    Args:
        repo_id (str): Hugging Face repository id to fetch from.

    Returns:
        tuple[str, str, str, str]: Local paths for the weights, vocabulary,
        label-encoder, and config files, in that order.
    """
    artifact_names = (
        "model_weights.pth",
        "vocab.json",
        "label_encoder.json",
        "config.json",
    )
    # Downloads happen in declaration order as the generator is unpacked.
    model_path, vocab_path, label_encoder_path, config_path = (
        hf_hub_download(repo_id=repo_id, filename=name) for name in artifact_names
    )
    return model_path, vocab_path, label_encoder_path, config_path
def get_transformer_model_class():
    """Load the TransformerModel class from code stored in the MODEL_VAR env var.

    Returns:
        type: The TransformerModel class defined by the executed code.

    Raises:
        ValueError: If the MODEL_VAR environment variable is not set.
        NameError: If the executed code did not define TransformerModel.
    """
    model_code = os.getenv("MODEL_VAR")
    if model_code is None:
        raise ValueError("Environment variable 'MODEL_VAR' is not set.")
    # SECURITY: exec() runs arbitrary code from the environment. This is only
    # acceptable because the Space owner controls MODEL_VAR; never let this
    # variable be populated from user-supplied input.
    exec(model_code, globals())
    if "TransformerModel" not in globals():
        raise NameError("The TransformerModel class was not defined after executing MODEL_VAR.")
    return globals()["TransformerModel"]
def get_preprocess_function():
    """Load the preprocess_text function from code in the MODEL_PROCESS env var.

    Returns:
        Callable: The preprocess_text function defined by the executed code.

    Raises:
        ValueError: If the MODEL_PROCESS environment variable is not set.
        NameError: If the executed code did not define preprocess_text.
    """
    preprocess_code = os.getenv("MODEL_PROCESS")
    if preprocess_code is None:
        raise ValueError("Environment variable 'MODEL_PROCESS' is not set.")
    # SECURITY: exec() runs arbitrary code from the environment. Acceptable only
    # because the Space owner controls MODEL_PROCESS; never expose it to users.
    exec(preprocess_code, globals())
    if "preprocess_text" not in globals():
        raise NameError("The preprocess_text function was not defined after executing MODEL_PROCESS.")
    return globals()["preprocess_text"]
def _load_json_file(path, description):
    """Load a JSON file, raising descriptive errors for missing/invalid files.

    Args:
        path (str): Local filesystem path to the JSON file.
        description (str): Capitalized artifact name used in error messages
            (e.g. "Vocabulary", "Label encoder").

    Raises:
        FileNotFoundError: If the file does not exist.
        ValueError: If the file is not valid JSON.
    """
    try:
        with open(path, "r") as f:
            return json.load(f)
    except FileNotFoundError:
        raise FileNotFoundError(f"{description} file not found at {path}. Please check the repository.")
    except json.JSONDecodeError:
        raise ValueError(f"Invalid JSON format in {description.lower()} file at {path}.")

def load_model_and_resources(repo_id):
    """Download artifacts and build the model in eval mode on CPU.

    Args:
        repo_id (str): Hugging Face repository id hosting the artifacts.

    Returns:
        tuple: (model, vocab, label_encoder_classes) where vocab is the
        token-to-id mapping and label_encoder_classes maps index -> label.
    """
    model_path, vocab_path, label_encoder_path, config_path = download_model_files(repo_id)
    vocab = _load_json_file(vocab_path, "Vocabulary")
    label_encoder_classes = _load_json_file(label_encoder_path, "Label encoder")
    TransformerModel = get_transformer_model_class()
    model = TransformerModel(vocab_size=len(vocab), num_classes=len(label_encoder_classes))
    # Weights are always mapped to CPU; use "cuda" here if a GPU is available.
    # NOTE(review): torch.load unpickles the checkpoint — only load trusted
    # files; consider weights_only=True if the checkpoint is a pure state_dict.
    model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
    model.eval()
    return model, vocab, label_encoder_classes
# Load the text-preprocessing function once at import time; predict() uses it.
preprocess_text = get_preprocess_function()
def predict(text, model, vocab, label_encoder_classes):
    """Classify ``text`` and return the predicted label.

    Args:
        text (str): Raw input text to classify.
        model: Trained TransformerModel instance in eval mode.
        vocab (dict): Token-to-id mapping used by preprocess_text.
        label_encoder_classes (list): Index-to-label lookup.

    Returns:
        The label whose index is the argmax of the model's output logits.

    Raises:
        ValueError: If the model's forward pass returns None.
    """
    # preprocess_text is loaded at module import time from MODEL_PROCESS.
    input_ids, attention_mask = preprocess_text(text, vocab)
    with torch.no_grad():
        outputs = model(input_ids, attention_mask)
    if outputs is None:
        raise ValueError("Model returned None. Check the forward method and input data.")
    predicted_class_idx = outputs.argmax(1).item()
    return label_encoder_classes[predicted_class_idx]
def create_gradio_interface():
    """Build the Gradio Interface wired to a freshly loaded model.

    Loads the model and its resources once, then returns a gr.Interface whose
    prediction function closes over them.
    """
    model, vocab, label_encoder_classes = load_model_and_resources(repo_id)
    def predict_wrapper(text):
        # Closure captures the loaded model/resources so Gradio passes only text.
        return predict(text, model, vocab, label_encoder_classes)
    interface = gr.Interface(
        fn=predict_wrapper,  # Use the wrapper function
        inputs=gr.Textbox(lines=2, placeholder="Enter text here..."),
        outputs=gr.Textbox(label="Predicted Label"),
        title="Text Classification Model",
        description="Enter text to classify it using the model.",
        examples=[
            ["I would like to bring to your attention a pothole on Main Street that has become a safety hazard. The pothole is quite deep and poses a risk to both drivers and pedestrians. I kindly request the council to inspect and repair it at the earliest to prevent any potential accidents or vehicle damage. Please let me know if any further information is required."],
            ["I am writing to report a clogged drainage system in 1 tonsley. The blockage is causing water to accumulate, leading to potential flooding and sanitation issues. This situation poses a risk to public health and safety, especially during rainfall. I kindly request the council to inspect and resolve this issue at the earliest convenience."],
            ["I am writing to report a persistent issue of loud noise coming from my neighbors at 1 tonsley. The noise, which occurs through out the day, has been causing significant disturbance to me and other residents in the area."]
        ],
        cache_examples=False  # Disable caching
    )
    return interface
# Script entry point: build the Gradio UI and start the web server.
if __name__ == "__main__":
    interface = create_gradio_interface()
    interface.launch()