# flare/intent.py
import os, torch, traceback, json, shutil, re
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments, default_data_collator, AutoConfig
from log import log
from pydantic import BaseModel
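# Module-level inference state used by detect_intent(); expected to be populated
# elsewhere (e.g. at startup or after training) with the fine-tuned model,
# its tokenizer and the intent-name -> label-id mapping.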
INTENT_MODEL = None
INTENT_TOKENIZER = None
LABEL2ID = None
class TrainInput(BaseModel):
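    """Payload model for intent-training requests: a list of intent definitions,
    each expected to carry a "name" and its "examples" (see background_training)."""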
intents: list
def background_training(intents, s_config):
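    """Fine-tune a sequence-classification model on the given intent examples.

    Builds a text/label dataset from the intents, trains for a few epochs,
    logs per-intent accuracy on the training data as a rough quality report,
    and saves the model, tokenizer and label2id mapping to
    s_config.INTENT_MODEL_PATH. Intended to run as a background task.
    """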
global INTENT_MODEL, INTENT_TOKENIZER, LABEL2ID
try:
log("🔧 Intent eğitimi başlatıldı...")
texts, labels, label2id = [], [], {}
for idx, intent in enumerate(intents):
label2id[intent["name"]] = idx
for ex in intent["examples"]:
texts.append(ex)
labels.append(idx)
dataset = Dataset.from_dict({"text": texts, "label": labels})
tokenizer = AutoTokenizer.from_pretrained(s_config.INTENT_MODEL_ID)
config = AutoConfig.from_pretrained(s_config.INTENT_MODEL_ID)
config.problem_type = "single_label_classification"
config.num_labels = len(label2id)
model = AutoModelForSequenceClassification.from_pretrained(s_config.INTENT_MODEL_ID, config=config)
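        # Tokenize every example up front and build a torch-formatted dataset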
tokenized_data = {"input_ids": [], "attention_mask": [], "label": []}
for row in dataset:
out = tokenizer(row["text"], truncation=True, padding="max_length", max_length=128)
tokenized_data["input_ids"].append(out["input_ids"])
tokenized_data["attention_mask"].append(out["attention_mask"])
tokenized_data["label"].append(row["label"])
tokenized = Dataset.from_dict(tokenized_data)
tokenized.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
output_dir = "/app/intent_train_output"
os.makedirs(output_dir, exist_ok=True)
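        # Short fine-tuning run; no intermediate checkpoints are written (save_strategy="no")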
trainer = Trainer(
model=model,
args=TrainingArguments(output_dir, per_device_train_batch_size=4, num_train_epochs=3, logging_steps=10, save_strategy="no", report_to=[]),
train_dataset=tokenized,
data_collator=default_data_collator
)
trainer.train()
        # ✅ Generate a per-intent accuracy report on the training data
        log("🔧 Generating per-intent accuracy report...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
input_ids_tensor = tokenized["input_ids"].to(device)
attention_mask_tensor = tokenized["attention_mask"].to(device)
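        # Run the fine-tuned model back over the training examples to estimate per-intent accuracy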
with torch.no_grad():
outputs = model(input_ids=input_ids_tensor, attention_mask=attention_mask_tensor)
predictions = outputs.logits.argmax(dim=-1).tolist()
actuals = tokenized["label"]
counts = {}
correct = {}
for pred, actual in zip(predictions, actuals):
intent = list(label2id.keys())[list(label2id.values()).index(actual)]
counts[intent] = counts.get(intent, 0) + 1
if pred == actual:
correct[intent] = correct.get(intent, 0) + 1
for intent, total in counts.items():
accuracy = correct.get(intent, 0) / total
log(f"📊 Intent '{intent}' doğruluk: {accuracy:.2f}{total} örnek")
if accuracy < s_config.TRAIN_CONFIDENCE_THRESHOLD or total < 5:
log(f"⚠️ Yetersiz performanslı intent: '{intent}' — Doğruluk: {accuracy:.2f}, Örnek: {total}")
log("📦 Intent modeli eğitimi kaydediliyor...")
if os.path.exists(s_config.INTENT_MODEL_PATH):
shutil.rmtree(s_config.INTENT_MODEL_PATH)
model.save_pretrained(s_config.INTENT_MODEL_PATH)
tokenizer.save_pretrained(s_config.INTENT_MODEL_PATH)
with open(os.path.join(s_config.INTENT_MODEL_PATH, "label2id.json"), "w") as f:
json.dump(label2id, f)
log("✅ Intent eğitimi tamamlandı ve model kaydedildi.")
except Exception as e:
log(f"❌ Intent eğitimi hatası: {e}")
traceback.print_exc()
async def detect_intent(text):
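    """Classify `text` into an intent using the globally loaded model.

    Returns an (intent_name, confidence) tuple. Assumes INTENT_MODEL,
    INTENT_TOKENIZER and LABEL2ID have already been populated.
    """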
inputs = INTENT_TOKENIZER(text, return_tensors="pt")
outputs = INTENT_MODEL(**inputs)
probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
confidence, pred_id = torch.max(probs, dim=-1)
id2label = {v: k for k, v in LABEL2ID.items()}
return id2label[pred_id.item()], confidence.item()
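# A minimal sketch of how the globals above might be populated from a model
# previously saved by background_training(). The project's real loader is
# assumed to live elsewhere; this function name is hypothetical.
def _load_intent_model(s_config):
    global INTENT_MODEL, INTENT_TOKENIZER, LABEL2ID
    INTENT_TOKENIZER = AutoTokenizer.from_pretrained(s_config.INTENT_MODEL_PATH)
    INTENT_MODEL = AutoModelForSequenceClassification.from_pretrained(s_config.INTENT_MODEL_PATH)
    with open(os.path.join(s_config.INTENT_MODEL_PATH, "label2id.json")) as f:
        LABEL2ID = json.load(f)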
def extract_parameters(variables_list, user_input):
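    """Extract named parameters from user input using example-based patterns.

    Segments of the form name:{example} in a pattern are converted to named
    regex groups (e.g. "to city:{Istanbul}" becomes "to (?P<city>.+?)").
    Returns a list of {"key", "value"} dicts for the first matching pattern,
    or an empty list if none match.
    """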
for pattern in variables_list:
regex = re.sub(r"(\w+):\{(.+?)\}", r"(?P<\1>.+?)", pattern)
match = re.match(regex, user_input)
if match:
return [{"key": k, "value": v} for k, v in match.groupdict().items()]
return []
def resolve_placeholders(text: str, session: dict, variables: dict) -> str:
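    """Replace {placeholder} expressions in `text` with runtime values.

    Supported forms: {variables.<key>} from `variables`,
    {session.<key>} from session["variables"], and
    {auth_tokens.<intent>.<token_type>} from session["auth_tokens"].
    Unknown or unresolvable placeholders are left unchanged.
    """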
def replacer(match):
full = match.group(1)
try:
if full.startswith("variables."):
key = full.split(".", 1)[1]
return str(variables.get(key, f"{{{full}}}"))
elif full.startswith("session."):
key = full.split(".", 1)[1]
return str(session.get("variables", {}).get(key, f"{{{full}}}")) # session.variables içinden
elif full.startswith("auth_tokens."):
                # auth_tokens.<intent>.token or auth_tokens.<intent>.refresh_token
parts = full.split(".")
if len(parts) == 3:
intent, token_type = parts[1], parts[2]
return str(session.get("auth_tokens", {}).get(intent, {}).get(token_type, f"{{{full}}}"))
else:
return f"{{{full}}}"
else:
return f"{{{full}}}" # bilinmeyen yapı
        except Exception:
return f"{{{full}}}"
return re.sub(r"\{([^{}]+)\}", replacer, text)
def validate_variable_formats(variables, variable_format_map, data_formats):
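    """Validate variable values against their configured data formats.

    `variable_format_map` maps variable names to format names; each entry in
    `data_formats` may define either "valid_options" (an allowed-value list)
    or a regex "pattern". Returns (is_valid, errors), where `errors` maps
    variable names to error messages.
    """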
errors = {}
for var_name, format_name in variable_format_map.items():
value = variables.get(var_name)
if value is None:
            continue  # missing-parameter checks are handled elsewhere
format_def = next((fmt for fmt in data_formats if fmt["name"] == format_name), None)
if not format_def:
            continue  # no matching format definition; skip
        # valid_options check
if "valid_options" in format_def:
if value not in format_def["valid_options"]:
                errors[var_name] = format_def.get("error_message", f"{var_name} has an invalid value.")
        # pattern check
elif "pattern" in format_def:
if not re.fullmatch(format_def["pattern"], value):
                errors[var_name] = format_def.get("error_message", f"{var_name} has an invalid format.")
return len(errors) == 0, errors