Spaces:
Paused
Paused
| import os, torch, traceback, json, shutil, re | |
| from datasets import Dataset | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments, default_data_collator, AutoConfig | |
| from log import log | |
| from pydantic import BaseModel | |
# In-memory intent-classification state. background_training declares these
# names `global`; detect_intent reads them. All three start out as None.
INTENT_MODEL = INTENT_TOKENIZER = LABEL2ID = None
class TrainInput(BaseModel):
    """Pydantic model wrapping the list of intents to train on.

    NOTE(review): presumably the request body of a training endpoint — confirm
    against the caller.
    """
    # Validated only as a list. Presumably each item is a dict with "name" and
    # "examples" keys (the shape background_training reads) — TODO confirm.
    intents: list
def background_training(intents, s_config):
    """Fine-tune the intent classifier on ``intents`` and publish the result.

    Steps:
      1. Flatten every intent's ``"examples"`` into (text, label id) pairs.
      2. Fine-tune ``s_config.INTENT_MODEL_ID`` for single-label sequence
         classification (3 epochs, batch size 4).
      3. Log per-intent accuracy measured on the training examples. Warn
         about intents below ``s_config.TRAIN_CONFIDENCE_THRESHOLD`` or with
         fewer than 5 examples.
      4. Save the model, the tokenizer and ``label2id.json`` to
         ``s_config.INTENT_MODEL_PATH``.
      5. Install the trained model, tokenizer and label map in the
         module-level ``INTENT_MODEL`` / ``INTENT_TOKENIZER`` / ``LABEL2ID``.

    Args:
        intents: sequence of dicts, each with a ``"name"`` and a list of
            example strings under ``"examples"``.
        s_config: settings object providing ``INTENT_MODEL_ID``,
            ``INTENT_MODEL_PATH`` and ``TRAIN_CONFIDENCE_THRESHOLD``.

    Nothing is raised to the caller: any exception is logged and its
    traceback printed.
    """
    global INTENT_MODEL, INTENT_TOKENIZER, LABEL2ID
    try:
        log("🔧 Intent eğitimi başlatıldı...")
        texts, labels, label2id = _collect_training_examples(intents)
        if not texts:
            log("⚠️ Eğitim için örnek bulunamadı, intent eğitimi atlandı.")
            return

        tokenizer = AutoTokenizer.from_pretrained(s_config.INTENT_MODEL_ID)
        config = AutoConfig.from_pretrained(s_config.INTENT_MODEL_ID)
        config.problem_type = "single_label_classification"
        config.num_labels = len(label2id)
        model = AutoModelForSequenceClassification.from_pretrained(
            s_config.INTENT_MODEL_ID, config=config
        )

        tokenized = _tokenize_examples(tokenizer, texts, labels)

        output_dir = "/app/intent_train_output"
        os.makedirs(output_dir, exist_ok=True)
        trainer = Trainer(
            model=model,
            args=TrainingArguments(
                output_dir,
                per_device_train_batch_size=4,
                num_train_epochs=3,
                logging_steps=10,
                save_strategy="no",
                report_to=[],
            ),
            train_dataset=tokenized,
            data_collator=default_data_collator,
        )
        trainer.train()

        # Produce the accuracy report.
        log("🔧 Başarı raporu üretiliyor...")
        _log_training_accuracy(
            model, tokenized, label2id, s_config.TRAIN_CONFIDENCE_THRESHOLD
        )

        log("📦 Intent modeli eğitimi kaydediliyor...")
        _save_intent_model(model, tokenizer, label2id, s_config.INTENT_MODEL_PATH)

        # Publish the new model only after it was saved successfully. Keep it
        # on the CPU in eval mode: that is where a freshly loaded
        # from_pretrained model lives, and where the tokenizer creates tensors.
        model.to("cpu")
        model.eval()
        INTENT_MODEL, INTENT_TOKENIZER, LABEL2ID = model, tokenizer, label2id
        log("✅ Intent eğitimi tamamlandı ve model kaydedildi.")
    except Exception as e:
        log(f"❌ Intent eğitimi hatası: {e}")
        traceback.print_exc()


def _collect_training_examples(intents):
    """Flatten intents into parallel text/label lists plus a name -> id map."""
    texts, labels, label2id = [], [], {}
    for intent in intents:
        # setdefault keeps ids dense (0..k-1) even if an intent name repeats,
        # so num_labels == len(label2id) always covers every label id.
        idx = label2id.setdefault(intent["name"], len(label2id))
        for ex in intent["examples"]:
            texts.append(ex)
            labels.append(idx)
    return texts, labels, label2id


def _tokenize_examples(tokenizer, texts, labels):
    """Tokenize ``texts`` into a torch-formatted Dataset alongside ``labels``.

    Every example is truncated/padded to a fixed 128 tokens, so all rows have
    the same length.
    """
    encoded = tokenizer(texts, truncation=True, padding="max_length", max_length=128)
    tokenized = Dataset.from_dict({
        "input_ids": encoded["input_ids"],
        "attention_mask": encoded["attention_mask"],
        "label": labels,
    })
    tokenized.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
    return tokenized


def _log_training_accuracy(model, tokenized, label2id, threshold, batch_size=32):
    """Log per-intent accuracy of ``model`` on the examples it was trained on.

    Intents whose accuracy is below ``threshold`` or that have fewer than 5
    examples get an extra warning line.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Make sure dropout is off; training may leave the model in train mode.
    model.eval()

    predictions = []
    with torch.no_grad():
        # Evaluate in batches instead of one forward pass over the whole set.
        for start in range(0, len(tokenized), batch_size):
            batch = tokenized[start:start + batch_size]
            logits = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
            ).logits
            predictions.extend(logits.argmax(dim=-1).tolist())
    actuals = tokenized["label"].tolist()
    id2label = {idx: name for name, idx in label2id.items()}

    counts, correct = {}, {}
    for pred, actual in zip(predictions, actuals):
        intent = id2label[actual]
        counts[intent] = counts.get(intent, 0) + 1
        if pred == actual:
            correct[intent] = correct.get(intent, 0) + 1

    for intent, total in counts.items():
        accuracy = correct.get(intent, 0) / total
        log(f"📊 Intent '{intent}' doğruluk: {accuracy:.2f} — {total} örnek")
        if accuracy < threshold or total < 5:
            log(f"⚠️ Yetersiz performanslı intent: '{intent}' — Doğruluk: {accuracy:.2f}, Örnek: {total}")


def _save_intent_model(model, tokenizer, label2id, model_path):
    """Write the model, tokenizer and ``label2id.json`` to ``model_path``.

    Everything is first written to a sibling staging directory and only then
    swapped in. A failure while saving therefore leaves the previously saved
    model untouched.
    """
    staging = os.path.normpath(os.fspath(model_path)) + ".tmp"
    if os.path.exists(staging):
        shutil.rmtree(staging)
    model.save_pretrained(staging)
    tokenizer.save_pretrained(staging)
    with open(os.path.join(staging, "label2id.json"), "w", encoding="utf-8") as f:
        json.dump(label2id, f)
    if os.path.exists(model_path):
        shutil.rmtree(model_path)
    os.replace(staging, model_path)
async def detect_intent(text):
    """Classify ``text`` with the currently loaded intent model.

    Reads the module-level INTENT_TOKENIZER, INTENT_MODEL and LABEL2ID, which
    must have been populated before this is called.

    Returns:
        (intent_name, confidence): the highest-probability label and its
        softmax probability as a float.
    """
    # NOTE: there is no await here; the forward pass runs synchronously and
    # blocks the event loop while it executes.
    # Truncate like the training data (max_length=128 in background_training)
    # so long inputs cannot exceed the model's maximum sequence length.
    inputs = INTENT_TOKENIZER(text, return_tensors="pt", truncation=True, max_length=128)
    # Run on whatever device the model's weights live on (CPU or GPU).
    device = next(INTENT_MODEL.parameters()).device
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = INTENT_MODEL(**inputs).logits
    probs = torch.nn.functional.softmax(logits, dim=-1)
    confidence, pred_id = torch.max(probs, dim=-1)
    id2label = {v: k for k, v in LABEL2ID.items()}
    return id2label[pred_id.item()], confidence.item()
# Placeholder syntax inside a parameter pattern: ``name:{example}``.
_PLACEHOLDER_RE = re.compile(r"(\w+):\{(.+?)\}")


def extract_parameters(variables_list, user_input):
    """Extract named parameters from ``user_input`` using template patterns.

    Each entry of ``variables_list`` is turned into a regex: every placeholder
    ``name:{example}`` becomes the named group ``(?P<name>.+?)`` (the text
    inside the braces is discarded). Patterns are tried in order, and the
    first one that matches at the start of ``user_input`` wins.

    NOTE(review): text outside the placeholders is NOT escaped, so regex
    metacharacters there (``?``, ``.``, ``(`` …) act as regex syntax.

    Returns:
        A list of ``{"key": name, "value": captured_text}`` dicts for the first
        matching pattern, or ``[]`` if no pattern matches.
    """
    for pattern in variables_list:
        regex = _PLACEHOLDER_RE.sub(r"(?P<\1>.+?)", pattern)
        if regex.endswith(".+?)"):
            # A lazy group at the very end would capture a single character
            # under re.match; anchor it so it takes the rest of the input.
            regex += "$"
        match = re.match(regex, user_input)
        if match:
            return [{"key": k, "value": v} for k, v in match.groupdict().items()]
    return []
def resolve_placeholders(text: str, session: dict, variables: dict) -> str:
    """Expand ``{...}`` placeholders in ``text``.

    Supported forms:
      * ``{variables.KEY}``           -> ``variables[KEY]``
      * ``{session.KEY}``             -> ``session["variables"][KEY]``
      * ``{auth_tokens.INTENT.TYPE}`` -> ``session["auth_tokens"][INTENT][TYPE]``

    Resolved values are converted with ``str``. A placeholder stays in the
    output exactly as written if it:
      * has an unknown form,
      * has the wrong number of dotted parts,
      * refers to a missing key, or
      * raises while being looked up.
    """

    def expand(m):
        expr = m.group(1)
        verbatim = "{" + expr + "}"
        root, dot, rest = expr.partition(".")
        if not dot:
            # No dotted prefix at all -> unknown form.
            return verbatim
        try:
            if root == "variables":
                return str(variables.get(rest, verbatim))
            if root == "session":
                # Session values live under session["variables"].
                return str(session.get("variables", {}).get(rest, verbatim))
            if root == "auth_tokens":
                path = rest.split(".")
                if len(path) != 2:
                    return verbatim
                intent_name, token_type = path
                tokens = session.get("auth_tokens", {}).get(intent_name, {})
                return str(tokens.get(token_type, verbatim))
        except Exception:
            return verbatim
        return verbatim

    return re.sub(r"\{([^{}]+)\}", expand, text)
def validate_variable_formats(variables, variable_format_map, data_formats):
    """Check variable values against their declared data formats.

    Args:
        variables: mapping of variable name -> value.
        variable_format_map: mapping of variable name -> data-format name.
        data_formats: list of format dicts, each with a ``"name"`` and either
            ``"valid_options"`` (allowed values) or ``"pattern"`` (a regex the
            whole value must match), plus an optional ``"error_message"``.
            When both keys are present, only ``"valid_options"`` is checked.

    Returns:
        ``(is_valid, errors)``, where ``errors`` maps each failing variable
        name to its error message.

    Variables that are missing/None, or whose format name is not defined in
    ``data_formats``, are skipped.
    """
    errors = {}
    for var_name, format_name in variable_format_map.items():
        value = variables.get(var_name)
        if value is None:
            continue  # missing parameters are checked elsewhere
        # First format with a matching name wins.
        format_def = next((fmt for fmt in data_formats if fmt["name"] == format_name), None)
        if not format_def:
            continue  # undefined format
        if "valid_options" in format_def:
            if value not in format_def["valid_options"]:
                errors[var_name] = format_def.get("error_message", f"{var_name} değeri geçersiz.")
        elif "pattern" in format_def:
            # re.fullmatch only accepts strings; match non-string values (e.g.
            # numbers) by their text instead of crashing with TypeError.
            if not re.fullmatch(format_def["pattern"], str(value)):
                errors[var_name] = format_def.get("error_message", f"{var_name} formatı geçersiz.")
    return len(errors) == 0, errors