# RetailGenie / code / train_intent_classifier_local.py
# Author: shubh7
# Commit message: Adding application file
# Commit: 5f946b0
import torch
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
import os
import json
# Absolute path of the parent of the directory containing this script
# (the script lives one level below the project root).
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
def load_model(model_name="distilbert-base-uncased", intent_labels=None):
    """Load a DistilBERT tokenizer and sequence classifier sized for the intents.

    No fine-tuning happens here. With the default base checkpoint, the
    classification head for these labels is newly (randomly) initialised by
    ``from_pretrained``. Its predictions are therefore not meaningful until the
    model has been trained on labelled questions.

    Args:
        model_name: Hugging Face checkpoint name or local path used for both
            the tokenizer and the model weights.
        intent_labels: Ordered intent names; the label at position ``i``
            becomes class index ``i``. Defaults to
            ``["summary", "comparison", "trend", "anomaly", "forecast"]``.

    Returns:
        Tuple ``(model, tokenizer, device, label_mapping)``:
        - ``model`` is on ``device`` and in eval mode.
        - ``device`` is CUDA when available, otherwise CPU.
        - ``label_mapping`` maps intent name -> class index.

    Raises:
        ValueError: If ``intent_labels`` is empty or contains duplicates.
    """
    print("📦 Loading pre-trained intent classification model...")

    if intent_labels is None:
        intent_labels = ["summary", "comparison", "trend", "anomaly", "forecast"]
    intent_labels = list(intent_labels)
    if not intent_labels:
        raise ValueError("intent_labels must contain at least one label")
    # Duplicates would collapse entries in label_mapping and leave it smaller
    # than the number of model outputs.
    if len(set(intent_labels)) != len(intent_labels):
        raise ValueError(f"intent_labels contains duplicates: {intent_labels}")

    tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    label_mapping = {label: idx for idx, label in enumerate(intent_labels)}
    id2label = {idx: label for label, idx in label_mapping.items()}

    # Record the label names in the model config. Without this, model.config
    # and the saved config.json only carry generic LABEL_0..LABEL_n names.
    model = DistilBertForSequenceClassification.from_pretrained(
        model_name,
        num_labels=len(intent_labels),
        id2label=id2label,
        label2id=dict(label_mapping),
    )
    model = model.to(device)
    model.eval()
    return model, tokenizer, device, label_mapping
def classify_intent(question, model, tokenizer, device, label_mapping):
    """Predict the intent label for a single question.

    Args:
        question: The question text to classify (one string).
        model: Sequence-classification model whose output exposes ``logits``.
        tokenizer: Tokenizer callable matching ``model``.
        device: Device the tokenized inputs are moved to. It should be the
            device ``model`` lives on.
        label_mapping: Dict mapping intent name -> class index, as returned by
            ``load_model``.

    Returns:
        The intent name whose class index has the highest logit.

    Raises:
        ValueError: If the model produced predictions for more than one
            sequence, e.g. when ``question`` is a list of several strings.
        KeyError: If the predicted index is not present in ``label_mapping``.
    """
    inputs = tokenizer(
        question,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Reduce over the label axis explicitly. A bare argmax() flattens the whole
    # (batch, num_labels) logits tensor. That is only correct for a batch of
    # one; for larger batches it silently yields a bogus class index.
    predicted_ids = outputs.logits.argmax(dim=-1)
    if predicted_ids.numel() != 1:
        raise ValueError(
            "classify_intent expects a single question; "
            f"got predictions for {predicted_ids.numel()} sequences"
        )

    id2label = {idx: label for label, idx in label_mapping.items()}
    return id2label[predicted_ids.item()]
if __name__ == "__main__":
    # Build the model and tokenizer, save them together with the label mapping,
    # then run a quick smoke test. No training step runs in this block: the
    # model is saved exactly as load_model() returned it.
    model, tokenizer, device, label_mapping = load_model()

    output_dir = os.path.join(PROJECT_ROOT, "model_intent_classifier")
    print(f"💾 Saving model to {output_dir}")
    os.makedirs(output_dir, exist_ok=True)
    for artifact in (model, tokenizer):
        artifact.save_pretrained(output_dir)

    # Write the intent-name -> class-index mapping next to the saved weights.
    mapping_path = os.path.join(output_dir, "label_mapping.json")
    with open(mapping_path, "w") as mapping_file:
        json.dump(label_mapping, mapping_file)
    print(f"✅ Model successfully saved to {output_dir}")

    # Sample questions for a manual sanity check of the classifier output.
    sample_questions = (
        "What is the total sales amount for each product category?",
        "Compare sales between March and April",
        "Show me the sales trend over the last 6 months",
        "Which products have unusual sales patterns?",
        "What will be the sales forecast for next month?",
    )

    print("\nTesting intent classification:")
    for sample in sample_questions:
        predicted = classify_intent(sample, model, tokenizer, device, label_mapping)
        print(f"Question: {sample}")
        print(f"Predicted intent: {predicted}\n")