File size: 5,163 Bytes
1191472
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b3102a
 
 
1191472
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import json
import logging
import os

import gradio as gr
import torch
from huggingface_hub import hf_hub_download

repo_id = "iimran/AnalyserV2"


def download_model_files(repo_id):
    """Fetch the model artifacts from a Hugging Face Hub repository.

    Each file is downloaded (or served from the local hub cache) via
    ``hf_hub_download``, in the order listed below.

    Args:
        repo_id: Hub repository id, e.g. ``"iimran/AnalyserV2"``.

    Returns:
        tuple[str, str, str, str]: local paths of
        ``(model_weights.pth, vocab.json, label_encoder.json, config.json)``.
    """
    artifact_names = (
        "model_weights.pth",
        "vocab.json",
        "label_encoder.json",
        "config.json",
    )
    return tuple(
        hf_hub_download(repo_id=repo_id, filename=artifact)
        for artifact in artifact_names
    )
def get_transformer_model_class():
    """Return the ``TransformerModel`` class whose source lives in ``MODEL_VAR``.

    The environment variable's contents are executed as Python source in this
    module's global namespace, so any names it defines (including
    ``TransformerModel``) become module globals.

    Returns:
        The object bound to ``TransformerModel`` after execution.

    Raises:
        ValueError: If ``MODEL_VAR`` is not set.
        NameError: If no ``TransformerModel`` exists in the module globals
            after the source has run.
    """
    class_source = os.getenv("MODEL_VAR")
    if class_source is None:
        raise ValueError("Environment variable 'MODEL_VAR' is not set.")

    # NOTE(review): exec runs arbitrary code taken from the environment.
    # This is only acceptable while MODEL_VAR is a deployment-controlled secret.
    module_namespace = globals()
    exec(class_source, module_namespace)

    if "TransformerModel" not in module_namespace:
        raise NameError("The TransformerModel class was not defined after executing MODEL_VAR.")
    return module_namespace["TransformerModel"]
def get_preprocess_function():
    """Return the ``preprocess_text`` function whose source lives in ``MODEL_PROCESS``.

    The environment variable's contents are executed as Python source in this
    module's global namespace; the resulting ``preprocess_text`` global is
    handed back to the caller.

    Raises:
        ValueError: If ``MODEL_PROCESS`` is not set.
        NameError: If no ``preprocess_text`` exists in the module globals
            after the source has run.
    """
    namespace = globals()
    process_source = os.getenv("MODEL_PROCESS")
    if process_source is None:
        raise ValueError("Environment variable 'MODEL_PROCESS' is not set.")

    # NOTE(review): exec of environment-provided code; keep MODEL_PROCESS a
    # deployment-controlled secret, never user-supplied.
    exec(process_source, namespace)

    if "preprocess_text" in namespace:
        return namespace["preprocess_text"]
    raise NameError("The preprocess_text function was not defined after executing MODEL_PROCESS.")
def _read_json(path, label):
    """Read and parse a JSON artifact, mapping failures to descriptive errors.

    Args:
        path: Local filesystem path of the JSON file.
        label: Lower-case human-readable file name used in error messages,
            e.g. ``"vocabulary"`` or ``"label encoder"``.

    Returns:
        The parsed JSON value.

    Raises:
        FileNotFoundError: If *path* does not exist.
        ValueError: If the file is not valid UTF-8 or not valid JSON.
    """
    try:
        # JSON text is UTF-8 (RFC 8259). Without an explicit encoding, open()
        # falls back to the locale's preferred encoding, which may not be
        # UTF-8 (e.g. cp1252 on Windows) and would garble non-ASCII tokens.
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError as err:
        raise FileNotFoundError(
            f"{label.capitalize()} file not found at {path}. Please check the repository."
        ) from err
    except (json.JSONDecodeError, UnicodeDecodeError) as err:
        raise ValueError(f"Invalid JSON format in {label} file at {path}.") from err


def load_model_and_resources(repo_id):
    """Download the model artifacts and build the model in eval mode on CPU.

    Args:
        repo_id: Hugging Face Hub repository id holding the artifacts.

    Returns:
        tuple: ``(model, vocab, label_encoder_classes)``. ``model`` has its
        weights loaded and is in eval mode. ``vocab`` and
        ``label_encoder_classes`` are the parsed JSON contents; only their
        ``len()`` is used here, to size the model.

    Raises:
        FileNotFoundError: If a JSON artifact is missing.
        ValueError: If a JSON artifact cannot be decoded, or ``MODEL_VAR`` is unset.
        NameError: If ``MODEL_VAR`` does not define ``TransformerModel``.
    """
    # config.json is fetched alongside the other artifacts but is not used here.
    model_path, vocab_path, label_encoder_path, _config_path = download_model_files(repo_id)

    vocab = _read_json(vocab_path, "vocabulary")
    label_encoder_classes = _read_json(label_encoder_path, "label encoder")

    TransformerModel = get_transformer_model_class()
    model = TransformerModel(vocab_size=len(vocab), num_classes=len(label_encoder_classes))

    # NOTE(review): torch.load unpickles the file. Once the deployed PyTorch is
    # >= 1.13, pass weights_only=True so that only tensors/plain containers can
    # be loaded (this is the default from PyTorch 2.6).
    # Weights are mapped to CPU; use "cuda" if a GPU is available.
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()
    return model, vocab, label_encoder_classes
logger = logging.getLogger(__name__)

# Resolved once at import time; requires the MODEL_PROCESS env var to be set.
preprocess_text = get_preprocess_function()


def predict(text, model, vocab, label_encoder_classes):
    """Classify *text* with *model* and return the predicted label.

    Intermediate tensors are logged at DEBUG level rather than printed. This
    keeps users' raw input and model internals out of stdout by default.

    Args:
        text: Raw input text from the user.
        model: Model called as ``model(input_ids, attention_mask)``. Its output
            must support ``.argmax(1).item()``, so presumably logits of shape
            ``(1, num_classes)``; TODO confirm.
        vocab: Vocabulary passed through to ``preprocess_text``.
        label_encoder_classes: Indexed by the predicted integer class index,
            presumably a list of label names; TODO confirm.

    Returns:
        The entry of ``label_encoder_classes`` at the arg-max class index.

    Raises:
        ValueError: If the model returns ``None``.
    """
    input_ids, attention_mask = preprocess_text(text, vocab)
    logger.debug("Input IDs: %s", input_ids)
    logger.debug("Attention Mask: %s", attention_mask)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask)
    logger.debug("Model Outputs: %s", outputs)

    if outputs is None:
        raise ValueError("Model returned None. Check the forward method and input data.")
    # .item() requires the arg-max result to hold exactly one element (one input row).
    predicted_class_idx = outputs.argmax(1).item()
    return label_encoder_classes[predicted_class_idx]
def create_gradio_interface():
    """Load the model resources and wrap them in a Gradio text-classification UI.

    Returns:
        gr.Interface: The configured interface (not yet launched).
    """
    model, vocab, label_encoder_classes = load_model_and_resources(repo_id)

    # Gradio supplies only the textbox value, so close over the loaded resources.
    def predict_wrapper(text):
        return predict(text, model, vocab, label_encoder_classes)

    sample_texts = [
        "I would like to bring to your attention a pothole on Main Street that has become a safety hazard. The pothole is quite deep and poses a risk to both drivers and pedestrians. I kindly request the council to inspect and repair it at the earliest to prevent any potential accidents or vehicle damage. Please let me know if any further information is required.",
        "I am writing to report a clogged drainage system in 1 tonsley. The blockage is causing water to accumulate, leading to potential flooding and sanitation issues. This situation poses a risk to public health and safety, especially during rainfall. I kindly request the council to inspect and resolve this issue at the earliest convenience.",
        "I am writing to report a persistent issue of loud noise coming from my neighbors at 1 tonsley. The noise, which occurs through out the day, has been causing significant disturbance to me and other residents in the area.",
    ]

    return gr.Interface(
        fn=predict_wrapper,
        inputs=gr.Textbox(lines=2, placeholder="Enter text here..."),
        outputs=gr.Textbox(label="Predicted Label"),
        title="Text Classification Model",
        description="Enter text to classify it using the model.",
        # One input component, so each example is a single-element row.
        examples=[[sample] for sample in sample_texts],
        cache_examples=False,  # run examples live rather than precomputing them
    )
if __name__ == "__main__":
    # Build the UI (downloads and loads the model), then serve it.
    interface = create_gradio_interface()
    interface.launch()