"""Bootstrap a DistilBERT sequence classifier for question-intent detection.

Running this module as a script loads the base ``distilbert-base-uncased``
checkpoint with a five-way classification head, saves the model, tokenizer
and label mapping under ``<project root>/model_intent_classifier`` and runs
a handful of example questions through the classifier.

NOTE: the classification head created here is not fine-tuned on intent data,
so the predictions printed by the demo are not meaningful until the saved
model has been trained.
"""

import json
import os

import torch
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast

# Get project root directory
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))


def load_model():
    """Load the base DistilBERT model with an intent classification head.

    Returns:
        A ``(model, tokenizer, device, label_mapping)`` tuple where
        ``label_mapping`` maps each intent label to its class index. The
        model is moved to ``device`` (CUDA when available, otherwise CPU)
        and put in evaluation mode.
    """
    print("📦 Loading pre-trained intent classification model...")
    model_name = "distilbert-base-uncased"
    tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Define intent labels
    intent_labels = [
        "summary",
        "comparison",
        "trend",
        "anomaly",
        "forecast",
    ]
    num_labels = len(intent_labels)

    # Create label mapping (label -> class index) and its inverse
    label_mapping = {label: idx for idx, label in enumerate(intent_labels)}
    id2label = dict(enumerate(intent_labels))

    # Load model with our number of labels. Passing id2label/label2id stores
    # the human-readable label names in the model config, so the saved
    # config.json is self-describing instead of using generic LABEL_i names.
    model = DistilBertForSequenceClassification.from_pretrained(
        model_name,
        num_labels=num_labels,
        id2label=id2label,
        label2id=label_mapping,
    )
    model = model.to(device)
    model.eval()

    return model, tokenizer, device, label_mapping


def classify_intent(question, model, tokenizer, device, label_mapping):
    """Predict the intent label for one question or a batch of questions.

    Args:
        question: A single question string, or a sequence of question
            strings to classify as one batch.
        model: Sequence classification model whose output exposes ``logits``.
        tokenizer: Tokenizer matching ``model``.
        device: Torch device the model lives on; inputs are moved to it.
        label_mapping: Mapping from intent label to class index.

    Returns:
        The predicted intent label (``str``) when ``question`` is a string,
        otherwise a list with one predicted label per input question.

    Raises:
        KeyError: If the model predicts a class index that is absent from
            ``label_mapping``.
    """
    is_single = isinstance(question, str)
    questions = [question] if is_single else list(question)
    if not questions:
        # Nothing to classify; avoid calling the tokenizer on an empty batch.
        return []

    # Tokenize input
    inputs = tokenizer(
        questions,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    ).to(device)

    # Get prediction
    with torch.no_grad():
        outputs = model(**inputs)

    # Take the argmax per row: a flattened argmax is only correct for a
    # batch of exactly one question.
    predicted_class_ids = outputs.logits.argmax(dim=-1).tolist()

    # Convert back to label
    id2label = {idx: label for label, idx in label_mapping.items()}
    intents = [id2label[class_id] for class_id in predicted_class_ids]

    return intents[0] if is_single else intents


def main():
    """Load the model, save it with its label mapping, and run a small demo."""
    # Load the model
    model, tokenizer, device, label_mapping = load_model()

    # Save the model and label mapping
    output_dir = os.path.join(PROJECT_ROOT, "model_intent_classifier")
    print(f"💾 Saving model to {output_dir}")
    os.makedirs(output_dir, exist_ok=True)
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    # Save label mapping
    mapping_path = os.path.join(output_dir, "label_mapping.json")
    with open(mapping_path, "w", encoding="utf-8") as f:
        json.dump(label_mapping, f)

    print(f"✅ Model successfully saved to {output_dir}")

    # Example usage
    test_questions = [
        "What is the total sales amount for each product category?",
        "Compare sales between March and April",
        "Show me the sales trend over the last 6 months",
        "Which products have unusual sales patterns?",
        "What will be the sales forecast for next month?",
    ]

    print("\nTesting intent classification:")
    for question in test_questions:
        intent = classify_intent(question, model, tokenizer, device, label_mapping)
        print(f"Question: {question}")
        print(f"Predicted intent: {intent}\n")


if __name__ == "__main__":
    main()