"""Gradio app for text classification with a model downloaded from the Hugging Face Hub.

Model weights, vocabulary, label encoder and config are downloaded from ``repo_id``.
The model class and the text-preprocessing function are NOT defined in this file:
their source code is read from the ``MODEL_VAR`` and ``MODEL_PROCESS`` environment
variables and executed at runtime.
"""

import json
import logging
import os
from typing import Any, Callable, Tuple

import gradio as gr
import torch
from huggingface_hub import hf_hub_download

logger = logging.getLogger(__name__)

repo_id = "iimran/AnalyserV2"


def download_model_files(repo_id: str) -> Tuple[str, str, str, str]:
    """Download the model artifacts of ``repo_id`` from the Hugging Face Hub.

    Returns:
        Local paths of ``model_weights.pth``, ``vocab.json``,
        ``label_encoder.json`` and ``config.json``, in that order.
    """
    model_path = hf_hub_download(repo_id=repo_id, filename="model_weights.pth")
    vocab_path = hf_hub_download(repo_id=repo_id, filename="vocab.json")
    label_encoder_path = hf_hub_download(repo_id=repo_id, filename="label_encoder.json")
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
    return model_path, vocab_path, label_encoder_path, config_path


def get_transformer_model_class() -> type:
    """Execute the code in ``MODEL_VAR`` and return the ``TransformerModel`` it defines.

    SECURITY NOTE: the environment variable's contents are run with ``exec``, so
    ``MODEL_VAR`` must only ever be set from a trusted source (e.g. a deployment
    secret), never from user-controlled input.

    Raises:
        ValueError: if ``MODEL_VAR`` is not set.
        NameError: if the executed code does not define ``TransformerModel``.
    """
    model_code = os.getenv("MODEL_VAR")
    if model_code is None:
        raise ValueError("Environment variable 'MODEL_VAR' is not set.")
    # Executing in this module's globals() exposes its imports (torch, json, ...)
    # to the snippet and publishes the names the snippet defines here.
    exec(model_code, globals())
    if "TransformerModel" not in globals():
        raise NameError("The TransformerModel class was not defined after executing MODEL_VAR.")
    return globals()["TransformerModel"]


def get_preprocess_function() -> Callable[..., Any]:
    """Execute the code in ``MODEL_PROCESS`` and return the ``preprocess_text`` it defines.

    SECURITY NOTE: like ``MODEL_VAR``, this variable's contents are run with
    ``exec`` and must come from a trusted source only.

    Raises:
        ValueError: if ``MODEL_PROCESS`` is not set.
        NameError: if the executed code does not define ``preprocess_text``.
    """
    preprocess_code = os.getenv("MODEL_PROCESS")
    if preprocess_code is None:
        raise ValueError("Environment variable 'MODEL_PROCESS' is not set.")
    exec(preprocess_code, globals())
    if "preprocess_text" not in globals():
        raise NameError("The preprocess_text function was not defined after executing MODEL_PROCESS.")
    return globals()["preprocess_text"]


def _load_json(path: str, description: str) -> Any:
    """Read and parse the UTF-8 JSON file at ``path``.

    ``description`` (e.g. ``"vocabulary file"``) names the file in error messages.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        ValueError: if the file does not contain valid JSON.
    """
    try:
        # JSON text is UTF-8 by standard; don't depend on the platform's locale encoding.
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError as err:
        raise FileNotFoundError(
            f"{description.capitalize()} not found at {path}. Please check the repository."
        ) from err
    except json.JSONDecodeError as err:
        raise ValueError(f"Invalid JSON format in {description} at {path}.") from err


def load_model_and_resources(repo_id: str):
    """Download the artifacts of ``repo_id`` and build the model in eval mode.

    Returns:
        ``(model, vocab, label_encoder_classes)`` where ``vocab`` and
        ``label_encoder_classes`` are the parsed contents of ``vocab.json`` and
        ``label_encoder.json``.

    Raises:
        FileNotFoundError / ValueError: if a JSON artifact is missing or malformed.
        ValueError / NameError: if ``MODEL_VAR`` is unset or does not define the model.
    """
    # config.json is downloaded but not currently read by this app.
    model_path, vocab_path, label_encoder_path, _config_path = download_model_files(repo_id)

    vocab = _load_json(vocab_path, "vocabulary file")
    label_encoder_classes = _load_json(label_encoder_path, "label encoder file")

    TransformerModel = get_transformer_model_class()
    model = TransformerModel(vocab_size=len(vocab), num_classes=len(label_encoder_classes))
    # Weights are always mapped onto the CPU; change the device here to use a GPU.
    # NOTE(review): torch.load unpickles the checkpoint, which can execute code from a
    # tampered file; consider weights_only=True once the deployed torch version is confirmed.
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()
    return model, vocab, label_encoder_classes


# Resolved at import time: importing this module fails if MODEL_PROCESS is unset.
preprocess_text = get_preprocess_function()


def predict(text, model, vocab, label_encoder_classes):
    """Classify ``text`` and return the predicted label.

    ``preprocess_text(text, vocab)`` must return ``(input_ids, attention_mask)``
    accepted by ``model``'s forward call.

    Raises:
        ValueError: if the model returns ``None``.
    """
    input_ids, attention_mask = preprocess_text(text, vocab)
    # Logged at DEBUG (not printed) so user-derived data stays out of default logs.
    logger.debug("Input IDs: %s", input_ids)
    logger.debug("Attention Mask: %s", attention_mask)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask)
    logger.debug("Model Outputs: %s", outputs)

    if outputs is None:
        raise ValueError("Model returned None. Check the forward method and input data.")

    # argmax over dim 1 -- presumably (batch, num_classes) logits, TODO confirm;
    # .item() requires exactly one prediction, i.e. a single input text.
    predicted_class_idx = outputs.argmax(1).item()
    return label_encoder_classes[predicted_class_idx]


def create_gradio_interface():
    """Load the model once and wrap ``predict`` in a text-in/text-out Gradio interface."""
    model, vocab, label_encoder_classes = load_model_and_resources(repo_id)

    def predict_wrapper(text):
        # Gradio passes only the textbox value; bind the loaded resources here.
        return predict(text, model, vocab, label_encoder_classes)

    interface = gr.Interface(
        fn=predict_wrapper,
        inputs=gr.Textbox(lines=2, placeholder="Enter text here..."),
        outputs=gr.Textbox(label="Predicted Label"),
        title="Text Classification Model",
        description="Enter text to classify it using the model.",
        examples=[
            ["I would like to bring to your attention a pothole on Main Street that has become a safety hazard. The pothole is quite deep and poses a risk to both drivers and pedestrians. I kindly request the council to inspect and repair it at the earliest to prevent any potential accidents or vehicle damage. Please let me know if any further information is required."],
            ["I am writing to report a clogged drainage system in 1 tonsley. The blockage is causing water to accumulate, leading to potential flooding and sanitation issues. This situation poses a risk to public health and safety, especially during rainfall. I kindly request the council to inspect and resolve this issue at the earliest convenience."],
            ["I am writing to report a persistent issue of loud noise coming from my neighbors at 1 tonsley. The noise, which occurs through out the day, has been causing significant disturbance to me and other residents in the area."],
        ],
        cache_examples=False,  # run examples live instead of precomputing them at startup
    )
    return interface


if __name__ == "__main__":
    interface = create_gradio_interface()
    interface.launch()