# Spaces status (page banner captured along with this file): Runtime error
import torch | |
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification | |
import os | |
import json | |
# Absolute path of the directory one level above the folder holding this file.
PROJECT_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir)
)
def load_model():
    """Load the DistilBERT tokenizer and a sequence-classification model.

    The ``distilbert-base-uncased`` checkpoint is loaded with a classification
    head sized for the five intent labels defined below, moved to CUDA when
    available (CPU otherwise) and switched to eval mode.

    NOTE(review): ``distilbert-base-uncased`` is a base checkpoint, so the
    classification head is presumably freshly initialised by Transformers and
    the predictions are not meaningful until the model is fine-tuned. Confirm
    whether a fine-tuned checkpoint should be loaded here instead.

    Returns:
        tuple: ``(model, tokenizer, device, label_mapping)`` where
        ``label_mapping`` maps each intent label to its class index.
    """
    print("📦 Loading pre-trained intent classification model...")
    model_name = "distilbert-base-uncased"
    tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Intent labels; a label's position in this list is its class index.
    intent_labels = ["summary", "comparison", "trend", "anomaly", "forecast"]
    label_mapping = {label: idx for idx, label in enumerate(intent_labels)}

    # Store the label names in the model config as well, so that
    # save_pretrained() writes them out instead of generic "LABEL_i" names.
    model = DistilBertForSequenceClassification.from_pretrained(
        model_name,
        num_labels=len(intent_labels),
        id2label=dict(enumerate(intent_labels)),
        label2id=label_mapping,
    )
    model = model.to(device)
    model.eval()
    return model, tokenizer, device, label_mapping
def classify_intent(question, model, tokenizer, device, label_mapping):
    """Predict the intent label for one question or a batch of questions.

    Args:
        question: A single question string, or a list/tuple of question
            strings to classify together.
        model: Callable invoked as ``model(**inputs)``; its result must expose
            a ``logits`` tensor (assumed to be ``(batch, num_labels)``, as for
            Transformers sequence-classification models -- TODO confirm).
        tokenizer: Callable that tokenizes ``question`` and returns an object
            supporting ``.to(device)`` and ``**`` unpacking.
        device: Device the tokenized inputs are moved to before inference.
        label_mapping: Dict mapping each intent label to its class index.

    Returns:
        The predicted label when ``question`` is a string; otherwise a list
        with one predicted label per question.

    Raises:
        KeyError: If a predicted class index is missing from ``label_mapping``.
    """
    inputs = tokenizer(
        question,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    ).to(device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # Take the argmax along the class axis. A flat argmax() is only correct
    # for a batch of one; with several questions it indexes the flattened
    # (batch * classes) tensor and produces a wrong or out-of-range class id.
    predicted_ids = outputs.logits.argmax(dim=-1).reshape(-1).tolist()

    # Invert label -> index into index -> label.
    id2label = {idx: label for label, idx in label_mapping.items()}
    intents = [id2label[class_id] for class_id in predicted_ids]

    return intents[0] if isinstance(question, str) else intents
if __name__ == "__main__":
    # Build the model, tokenizer, target device and label -> index mapping.
    model, tokenizer, device, label_mapping = load_model()

    # Write the model, tokenizer and label mapping into one directory under
    # the project root.
    output_dir = os.path.join(PROJECT_ROOT, "model_intent_classifier")
    print(f"💾 Saving model to {output_dir}")
    os.makedirs(output_dir, exist_ok=True)
    for artifact in (model, tokenizer):
        artifact.save_pretrained(output_dir)

    # The label mapping is stored next to the model files as plain JSON.
    mapping_path = os.path.join(output_dir, "label_mapping.json")
    with open(mapping_path, "w") as mapping_file:
        json.dump(label_mapping, mapping_file)
    print(f"✅ Model successfully saved to {output_dir}")

    # Smoke run: classify a few sample questions and print the predictions.
    sample_questions = (
        "What is the total sales amount for each product category?",
        "Compare sales between March and April",
        "Show me the sales trend over the last 6 months",
        "Which products have unusual sales patterns?",
        "What will be the sales forecast for next month?",
    )
    print("\nTesting intent classification:")
    for sample in sample_questions:
        predicted = classify_intent(
            sample, model, tokenizer, device, label_mapping
        )
        print(f"Question: {sample}")
        print(f"Predicted intent: {predicted}\n")