# flare/intent.py
import os
import torch
import json
import shutil
import re
import traceback
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments, default_data_collator, AutoConfig
from log import log

INTENT_MODELS = {}  # project_name -> {"model": ..., "tokenizer": ..., "label2id": ...}

async def detect_intent(text):
    # Placeholder: callers must first look up the model for the relevant project.
    raise NotImplementedError("detect_intent must be invoked with a project-specific model.")
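
# A minimal sketch (not part of the original module) of what a project-scoped lookup
# could look like, using the INTENT_MODELS registry that background_training fills in
# below. The function name, return shape, and max_length here are assumptions.
#
# async def detect_intent_for_project(project_name, text):
#     entry = INTENT_MODELS[project_name]
#     model, tokenizer, label2id = entry["model"], entry["tokenizer"], entry["label2id"]
#     inputs = tokenizer(text, truncation=True, padding="max_length", max_length=128, return_tensors="pt")
#     with torch.no_grad():
#         probs = torch.softmax(model(**inputs).logits, dim=-1)[0]
#     best = int(probs.argmax())
#     id2label = {v: k for k, v in label2id.items()}
#     return id2label[best], float(probs[best])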

def background_training(project_name, intents, model_id, output_path, confidence_threshold):
    try:
        log(f"🔧 Intent training started (project: {project_name})")
texts, labels, label2id = [], [], {}
for idx, intent in enumerate(intents):
label2id[intent["name"]] = idx
for ex in intent["examples"]:
texts.append(ex)
labels.append(idx)
dataset = Dataset.from_dict({"text": texts, "label": labels})
tokenizer = AutoTokenizer.from_pretrained(model_id)
config = AutoConfig.from_pretrained(model_id)
config.problem_type = "single_label_classification"
config.num_labels = len(label2id)
model = AutoModelForSequenceClassification.from_pretrained(model_id, config=config)
        # Tokenize every example to a fixed length of 128 tokens.
        tokenized_data = {"input_ids": [], "attention_mask": [], "label": []}
for row in dataset:
out = tokenizer(row["text"], truncation=True, padding="max_length", max_length=128)
tokenized_data["input_ids"].append(out["input_ids"])
tokenized_data["attention_mask"].append(out["attention_mask"])
tokenized_data["label"].append(row["label"])
tokenized = Dataset.from_dict(tokenized_data)
tokenized.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
if os.path.exists(output_path):
shutil.rmtree(output_path)
os.makedirs(output_path, exist_ok=True)
        trainer = Trainer(
            model=model,
            args=TrainingArguments(
                output_dir=output_path,
                per_device_train_batch_size=4,
                num_train_epochs=3,
                logging_steps=10,
                save_strategy="no",
                report_to=[],
            ),
            train_dataset=tokenized,
            data_collator=default_data_collator,
        )
trainer.train()
        # Success report over the training set
        log("🔧 Generating success report...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
        # Columns are already torch tensors thanks to set_format above, so avoid
        # re-wrapping them in torch.tensor (which emits a copy-construct warning).
        input_ids_tensor = tokenized["input_ids"].to(device)
        attention_mask_tensor = tokenized["attention_mask"].to(device)
with torch.no_grad():
outputs = model(input_ids=input_ids_tensor, attention_mask=attention_mask_tensor)
predictions = outputs.logits.argmax(dim=-1).tolist()
        actuals = tokenized["label"].tolist()  # plain ints for the comparisons below
        id2label = {v: k for k, v in label2id.items()}
        counts, correct = {}, {}
        for pred, actual in zip(predictions, actuals):
            intent_name = id2label[actual]
            counts[intent_name] = counts.get(intent_name, 0) + 1
            if pred == actual:
                correct[intent_name] = correct.get(intent_name, 0) + 1
for intent_name, total in counts.items():
accuracy = correct.get(intent_name, 0) / total
log(f"📊 Intent '{intent_name}' doğruluk: {accuracy:.2f}{total} örnek")
if accuracy < confidence_threshold or total < 5:
log(f"⚠️ Yetersiz performanslı intent: '{intent_name}' — Doğruluk: {accuracy:.2f}, Örnek: {total}")
model.save_pretrained(output_path)
tokenizer.save_pretrained(output_path)
with open(os.path.join(output_path, "label2id.json"), "w") as f:
json.dump(label2id, f)
INTENT_MODELS[project_name] = {
"model": model,
"tokenizer": tokenizer,
"label2id": label2id
}
log(f"✅ Intent eğitimi tamamlandı ve '{project_name}' modeli yüklendi.")
except Exception as e:
log(f"❌ Intent eğitimi hatası: {e}")
traceback.print_exc()
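
# Example invocation (hypothetical values; the intents shape matches what the training
# loop above reads: each intent carries a "name" and a list of "examples"):
#
# background_training(
#     project_name="demo",
#     intents=[
#         {"name": "greet", "examples": ["hello", "hi there", "good morning"]},
#         {"name": "goodbye", "examples": ["bye", "see you later", "good night"]},
#     ],
#     model_id="dbmdz/bert-base-turkish-cased",  # any sequence-classification backbone works
#     output_path="/tmp/intent_model_demo",
#     confidence_threshold=0.8,
# )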

def extract_parameters(variables_list, user_input):
    # Each pattern embeds "name:{...}" tokens that are rewritten into named capture
    # groups, e.g. "from from_city:{city}" -> "from (?P<from_city>.+?)".
    for pattern in variables_list:
        regex = re.sub(r"(\w+):\{(.+?)\}", r"(?P<\1>.+?)", pattern)
        # fullmatch anchors the end of the input; with re.match a trailing non-greedy
        # group would capture only a single character.
        match = re.fullmatch(regex, user_input)
        if match:
            return [{"key": k, "value": v} for k, v in match.groupdict().items()]
    return []
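
# Example (illustrative pattern and input, not from the original code):
#     extract_parameters(["from from_city:{city} to to_city:{city}"], "from istanbul to ankara")
# returns [{"key": "from_city", "value": "istanbul"}, {"key": "to_city", "value": "ankara"}].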

def resolve_placeholders(text: str, session: dict, variables: dict) -> str:
    # Expands {variables.x}, {session.x} and {auth_tokens.intent.token_type}
    # placeholders; anything that cannot be resolved is left in place verbatim.
    def replacer(match):
full = match.group(1)
try:
if full.startswith("variables."):
key = full.split(".", 1)[1]
return str(variables.get(key, f"{{{full}}}"))
elif full.startswith("session."):
key = full.split(".", 1)[1]
return str(session.get("variables", {}).get(key, f"{{{full}}}"))
elif full.startswith("auth_tokens."):
parts = full.split(".")
if len(parts) == 3:
intent, token_type = parts[1], parts[2]
return str(session.get("auth_tokens", {}).get(intent, {}).get(token_type, f"{{{full}}}"))
else:
return f"{{{full}}}"
else:
return f"{{{full}}}"
except Exception:
return f"{{{full}}}"
return re.sub(r"\{([^{}]+)\}", replacer, text)
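
# Example (illustrative payloads, not from the original code):
#
# resolve_placeholders(
#     "Balance for {variables.account} is {session.balance} TL",
#     session={"variables": {"balance": "1250"}},
#     variables={"account": "12345"},
# )
# # -> "Balance for 12345 is 1250 TL"; unresolved keys such as {auth_tokens.x} stay verbatim.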

def validate_variable_formats(variables, variable_format_map, data_formats):
    # Returns (is_valid, errors) where errors maps variable name -> error message.
    errors = {}
for var_name, format_name in variable_format_map.items():
value = variables.get(var_name)
if value is None:
continue
format_def = data_formats.get(format_name)
if not format_def:
continue
if "valid_options" in format_def:
if value not in format_def["valid_options"]:
                errors[var_name] = format_def.get("error_message", f"Invalid value for {var_name}.")
elif "pattern" in format_def:
if not re.fullmatch(format_def["pattern"], value):
                errors[var_name] = format_def.get("error_message", f"Invalid format for {var_name}.")
return len(errors) == 0, errors
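
# Example format definitions (hypothetical keys mirroring what the validator reads):
#
# data_formats = {
#     "currency": {"valid_options": ["TRY", "USD", "EUR"], "error_message": "Unsupported currency."},
#     "iban": {"pattern": r"TR\d{24}", "error_message": "IBAN must be TR followed by 24 digits."},
# }
# ok, errors = validate_variable_formats(
#     variables={"currency": "GBP", "iban": "TR123456789012345678901234"},
#     variable_format_map={"currency": "currency", "iban": "iban"},
#     data_formats=data_formats,
# )
# # ok == False; errors == {"currency": "Unsupported currency."}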