iimran commited on
Commit
1191472
·
verified ·
1 Parent(s): 869e048

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -0
app.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import json
4
+ from huggingface_hub import hf_hub_download
5
+ import gradio as gr
6
+
7
+ repo_id = "iimran/AnalyserV2"
8
+ def download_model_files(repo_id):
9
+ model_path = hf_hub_download(repo_id=repo_id, filename="model_weights.pth")
10
+ vocab_path = hf_hub_download(repo_id=repo_id, filename="vocab.json")
11
+ label_encoder_path = hf_hub_download(repo_id=repo_id, filename="label_encoder.json")
12
+ config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
13
+ return model_path, vocab_path, label_encoder_path, config_path
14
+ def get_transformer_model_class():
15
+ model_code = os.getenv("MODEL_VAR")
16
+
17
+ if model_code is None:
18
+ raise ValueError("Environment variable 'MODEL_VAR' is not set.")
19
+ exec(model_code, globals())
20
+ if "TransformerModel" not in globals():
21
+ raise NameError("The TransformerModel class was not defined after executing MODEL_VAR.")
22
+ TransformerModel = globals()["TransformerModel"]
23
+ #print("TransformerModel Class:", TransformerModel)
24
+
25
+ return TransformerModel
26
+ def get_preprocess_function():
27
+ # Retrieve the preprocess_text code from the environment variable
28
+ preprocess_code = os.getenv("MODEL_PROCESS")
29
+
30
+ if preprocess_code is None:
31
+ raise ValueError("Environment variable 'MODEL_PROCESS' is not set.")
32
+ exec(preprocess_code, globals())
33
+ if "preprocess_text" not in globals():
34
+ raise NameError("The preprocess_text function was not defined after executing MODEL_PROCESS.")
35
+ #print("Preprocess Function Loaded:", globals()["preprocess_text"])
36
+
37
+ return globals()["preprocess_text"]
38
+ def load_model_and_resources(repo_id):
39
+ model_path, vocab_path, label_encoder_path, config_path = download_model_files(repo_id)
40
+ try:
41
+ with open(vocab_path, "r") as f:
42
+ vocab = json.load(f)
43
+ except FileNotFoundError:
44
+ raise FileNotFoundError(f"Vocabulary file not found at {vocab_path}. Please check the repository.")
45
+ except json.JSONDecodeError:
46
+ raise ValueError(f"Invalid JSON format in vocabulary file at {vocab_path}.")
47
+ try:
48
+ with open(label_encoder_path, "r") as f:
49
+ label_encoder_classes = json.load(f)
50
+ except FileNotFoundError:
51
+ raise FileNotFoundError(f"Label encoder file not found at {label_encoder_path}. Please check the repository.")
52
+ except json.JSONDecodeError:
53
+ raise ValueError(f"Invalid JSON format in label encoder file at {label_encoder_path}.")
54
+ TransformerModel = get_transformer_model_class()
55
+ model = TransformerModel(vocab_size=len(vocab), num_classes=len(label_encoder_classes))
56
+ model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu"))) # Use "cuda" if GPU is available
57
+ model.eval()
58
+ #print("Model Architecture:")
59
+ #print(model)
60
+
61
+ return model, vocab, label_encoder_classes
62
+ preprocess_text = get_preprocess_function()
63
+ def predict(text, model, vocab, label_encoder_classes):
64
+ input_ids, attention_mask = preprocess_text(text, vocab)
65
+ print("Input IDs:", input_ids)
66
+ print("Attention Mask:", attention_mask)
67
+ with torch.no_grad():
68
+ outputs = model(input_ids, attention_mask)
69
+ print("Model Outputs:", outputs) # Debug: Inspect model outputs
70
+
71
+ if outputs is None:
72
+ raise ValueError("Model returned None. Check the forward method and input data.")
73
+ predicted_class_idx = outputs.argmax(1).item()
74
+ predicted_label = label_encoder_classes[predicted_class_idx]
75
+ return predicted_label
76
+ def create_gradio_interface():
77
+ model, vocab, label_encoder_classes = load_model_and_resources(repo_id)
78
+ def predict_wrapper(text):
79
+ return predict(text, model, vocab, label_encoder_classes)
80
+ interface = gr.Interface(
81
+ fn=predict_wrapper, # Use the wrapper function
82
+ inputs=gr.Textbox(lines=2, placeholder="Enter text here..."),
83
+ outputs=gr.Textbox(label="Predicted Label"),
84
+ title="Text Classification Model",
85
+ description="Enter text to classify it using the model.",
86
+ examples=[
87
+ ["There is a pothole on Main Street."],
88
+ ["The drainage system is clogged."],
89
+ ["The streetlights are not working."]
90
+ ],
91
+ cache_examples=False # Disable caching
92
+ )
93
+ return interface
94
+ if __name__ == "__main__":
95
+ interface = create_gradio_interface()
96
+ interface.launch()