Sid26Roy commited on
Commit
94929e3
·
verified ·
1 Parent(s): 86a8a3f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -0
app.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import BertTokenizer, BertForSequenceClassification
3
+ import gradio as gr
4
+
5
+ # Load model and tokenizer from local directory (same folder as app.py)
6
+ model = BertForSequenceClassification.from_pretrained("bert")
7
+ tokenizer = BertTokenizer.from_pretrained(".")
8
+
9
+ # Ensure model runs on GPU if available
10
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+ model.to(device)
12
+ model.eval()
13
+
14
+ # ID to label mapping
15
+ id2label = {0: "Select", 1: "Insert", 2: "Delete", 3: "Update", 4: "Analyse"}
16
+
17
+ def classify_query(text):
18
+ if not text.strip():
19
+ return "Please enter a valid query."
20
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128).to(device)
21
+ with torch.no_grad():
22
+ outputs = model(**inputs)
23
+ prediction = torch.argmax(outputs.logits, dim=1).item()
24
+ label = id2label.get(prediction, "Unknown")
25
+ return f"🧠 Predicted Query Type: **{label}**"
26
+
27
+ # Gradio UI
28
+ demo = gr.Interface(
29
+ fn=classify_query,
30
+ inputs=gr.Textbox(label="Enter your expense query", lines=2, placeholder="e.g., Show me all expenses from January."),
31
+ outputs=gr.Markdown(label="Query Type"),
32
+ title="💰 Expense Query Type Classifier",
33
+ description="This model classifies your natural language query into one of 5 SQL operation types: Select, Insert, Delete, Update, or Analyse.",
34
+ examples=[
35
+ ["Add an expense of 500 in groceries at Amazon"],
36
+ ["Remove last transaction from Starbucks"],
37
+ ["Update amount of food expense to 850"],
38
+ ["Kitna kharcha hua electronics par?"],
39
+ ["Give me analytics of travel spending"]
40
+ ],
41
+ theme="soft",
42
+ )
43
+
44
+ if __name__ == "__main__":
45
+ demo.launch()