Spaces:
Running
Running
File size: 6,755 Bytes
4b9f7d2 3ff1612 4b9f7d2 83de6d6 4b9f7d2 75915c2 67f5f23 4b9f7d2 75915c2 4b9f7d2 75915c2 4b9f7d2 75915c2 4b9f7d2 75915c2 4b9f7d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import os, torch, traceback, json, shutil, re
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments, default_data_collator, AutoConfig
from log import log
from pydantic import BaseModel
# Module-level cache consumed by detect_intent(). NOTE(review): nothing in this
# file assigns these (background_training declares them `global` but never sets
# them) — presumably they are loaded elsewhere at startup; verify against callers.
INTENT_MODEL = None  # fine-tuned sequence-classification model
INTENT_TOKENIZER = None  # tokenizer matching INTENT_MODEL
LABEL2ID = None  # mapping: intent name -> integer class id
class TrainInput(BaseModel):
    """Request schema for intent training: the raw intent definitions."""
    intents: list  # assumes items shaped like {"name": str, "examples": [str, ...]} — TODO confirm against callers
def background_training(intents, s_config):
    """Fine-tune and persist the intent-classification model.

    Builds a labeled dataset from ``intents``, fine-tunes the base model given
    by ``s_config.INTENT_MODEL_ID``, logs a per-intent training-set accuracy
    report, saves model/tokenizer/label map to ``s_config.INTENT_MODEL_PATH``,
    and publishes the fresh artifacts into the module globals used by
    ``detect_intent``. Intended to run in a background worker: every error is
    logged instead of raised.

    Args:
        intents: list of dicts with "name" (str) and "examples" (list[str]).
        s_config: settings object providing INTENT_MODEL_ID, INTENT_MODEL_PATH
            and TRAIN_CONFIDENCE_THRESHOLD.

    Returns:
        None.
    """
    global INTENT_MODEL, INTENT_TOKENIZER, LABEL2ID
    try:
        log("🔧 Intent eğitimi başlatıldı...")
        texts, labels, label2id = [], [], {}
        for idx, intent in enumerate(intents):
            label2id[intent["name"]] = idx
            for ex in intent["examples"]:
                texts.append(ex)
                labels.append(idx)
        tokenizer = AutoTokenizer.from_pretrained(s_config.INTENT_MODEL_ID)
        config = AutoConfig.from_pretrained(s_config.INTENT_MODEL_ID)
        config.problem_type = "single_label_classification"
        config.num_labels = len(label2id)
        model = AutoModelForSequenceClassification.from_pretrained(s_config.INTENT_MODEL_ID, config=config)
        # Tokenize the whole corpus in one batched call instead of one
        # tokenizer invocation per example.
        encoded = tokenizer(texts, truncation=True, padding="max_length", max_length=128)
        tokenized = Dataset.from_dict({
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
            "label": labels,
        })
        tokenized.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
        output_dir = "/app/intent_train_output"
        os.makedirs(output_dir, exist_ok=True)
        trainer = Trainer(
            model=model,
            args=TrainingArguments(output_dir, per_device_train_batch_size=4, num_train_epochs=3, logging_steps=10, save_strategy="no", report_to=[]),
            train_dataset=tokenized,
            data_collator=default_data_collator
        )
        trainer.train()
        # ✅ Produce the success report on the training set.
        log("🔧 Başarı raporu üretiliyor...")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        # eval(): disable dropout so the report reflects inference behavior,
        # not training-mode noise.
        model.eval()
        input_ids_tensor = tokenized["input_ids"].to(device)
        attention_mask_tensor = tokenized["attention_mask"].to(device)
        with torch.no_grad():
            outputs = model(input_ids=input_ids_tensor, attention_mask=attention_mask_tensor)
        predictions = outputs.logits.argmax(dim=-1).tolist()
        actuals = tokenized["label"]
        # Precomputed inverse map: the previous list(...).index(...) scan was
        # O(n) per sample (quadratic overall).
        id2label = {idx: name for name, idx in label2id.items()}
        counts = {}
        correct = {}
        for pred, actual in zip(predictions, actuals):
            intent = id2label[int(actual)]
            counts[intent] = counts.get(intent, 0) + 1
            if pred == actual:
                correct[intent] = correct.get(intent, 0) + 1
        for intent, total in counts.items():
            accuracy = correct.get(intent, 0) / total
            log(f"📊 Intent '{intent}' doğruluk: {accuracy:.2f} — {total} örnek")
            if accuracy < s_config.TRAIN_CONFIDENCE_THRESHOLD or total < 5:
                log(f"⚠️ Yetersiz performanslı intent: '{intent}' — Doğruluk: {accuracy:.2f}, Örnek: {total}")
        log("📦 Intent modeli eğitimi kaydediliyor...")
        if os.path.exists(s_config.INTENT_MODEL_PATH):
            shutil.rmtree(s_config.INTENT_MODEL_PATH)
        model.save_pretrained(s_config.INTENT_MODEL_PATH)
        tokenizer.save_pretrained(s_config.INTENT_MODEL_PATH)
        # ensure_ascii=False keeps Turkish intent names readable on disk.
        with open(os.path.join(s_config.INTENT_MODEL_PATH, "label2id.json"), "w", encoding="utf-8") as f:
            json.dump(label2id, f, ensure_ascii=False)
        # Publish the fresh artifacts so detect_intent() uses the new model
        # immediately (the `global` declaration previously assigned nothing).
        INTENT_MODEL = model
        INTENT_TOKENIZER = tokenizer
        LABEL2ID = label2id
        log("✅ Intent eğitimi tamamlandı ve model kaydedildi.")
    except Exception as e:
        log(f"❌ Intent eğitimi hatası: {e}")
        traceback.print_exc()
async def detect_intent(text):
    """Classify ``text`` with the globally loaded intent model.

    Requires INTENT_MODEL, INTENT_TOKENIZER and LABEL2ID to be populated.

    Args:
        text: the user utterance to classify.

    Returns:
        (intent_name, confidence) — the argmax label and its softmax score.
    """
    inputs = INTENT_TOKENIZER(text, return_tensors="pt")
    # Inference only: no_grad avoids building an autograd graph per request.
    with torch.no_grad():
        outputs = INTENT_MODEL(**inputs)
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
    confidence, pred_id = torch.max(probs, dim=-1)
    id2label = {v: k for k, v in LABEL2ID.items()}
    return id2label[pred_id.item()], confidence.item()
def extract_parameters(variables_list, user_input):
    """Extract named parameters from ``user_input`` using template patterns.

    Each pattern contains placeholders like ``name:{type}`` which are compiled
    into named capture groups. The first pattern that matches the whole input
    wins.

    Args:
        variables_list: list of template strings, e.g. "my name is name:{text}".
        user_input: the raw user utterance.

    Returns:
        list of {"key": ..., "value": ...} dicts for the first matching
        pattern, or [] when no pattern matches.
    """
    for pattern in variables_list:
        regex = re.sub(r"(\w+):\{(.+?)\}", r"(?P<\1>.+?)", pattern)
        # fullmatch (not match): with a trailing non-greedy group, re.match
        # would capture only a single character — e.g. "J" out of "John".
        match = re.fullmatch(regex, user_input)
        if match:
            return [{"key": k, "value": v} for k, v in match.groupdict().items()]
    return []
def resolve_placeholders(text: str, session: dict, variables: dict) -> str:
    """Substitute ``{...}`` placeholders in ``text``.

    Supported forms: ``{variables.key}`` (looked up in ``variables``),
    ``{session.key}`` (looked up in ``session["variables"]``) and
    ``{auth_tokens.intent.token_type}`` (looked up in ``session["auth_tokens"]``).
    Any unknown or unresolvable placeholder is left in the text unchanged.
    """
    def _substitute(m):
        token = m.group(1)
        untouched = "{" + token + "}"  # returned verbatim when unresolvable
        try:
            if token.startswith("variables."):
                _, key = token.split(".", 1)
                return str(variables.get(key, untouched))
            if token.startswith("session."):
                _, key = token.split(".", 1)
                # session values live under session["variables"]
                return str(session.get("variables", {}).get(key, untouched))
            if token.startswith("auth_tokens."):
                # expects auth_tokens.<intent>.<token|refresh_token>
                pieces = token.split(".")
                if len(pieces) != 3:
                    return untouched
                return str(session.get("auth_tokens", {}).get(pieces[1], {}).get(pieces[2], untouched))
            return untouched  # unrecognized prefix
        except Exception:
            return untouched
    return re.sub(r"\{([^{}]+)\}", _substitute, text)
def validate_variable_formats(variables, variable_format_map, data_formats):
    """Validate variable values against their named format definitions.

    A format definition may constrain a value via ``valid_options`` (closed
    set, checked first) or ``pattern`` (regex, full match). Variables that are
    absent, or whose format name is undefined, are skipped.

    Args:
        variables: mapping of variable name -> value.
        variable_format_map: mapping of variable name -> format name.
        data_formats: list of format definition dicts (each with "name").

    Returns:
        (ok, errors): ``ok`` is True when no violations were found; ``errors``
        maps each failing variable name to its error message.
    """
    problems = {}
    # First definition wins for duplicate format names (mirrors a linear scan).
    definitions = {}
    for fmt in data_formats:
        definitions.setdefault(fmt["name"], fmt)
    for var_name, fmt_name in variable_format_map.items():
        value = variables.get(var_name)
        if value is None:
            continue  # missing-parameter checks happen elsewhere
        fmt = definitions.get(fmt_name)
        if fmt is None:
            continue  # undefined format: nothing to validate against
        if "valid_options" in fmt:
            # closed-set check takes precedence over pattern
            if value not in fmt["valid_options"]:
                problems[var_name] = fmt.get("error_message", f"{var_name} değeri geçersiz.")
        elif "pattern" in fmt:
            if not re.fullmatch(fmt["pattern"], value):
                problems[var_name] = fmt.get("error_message", f"{var_name} formatı geçersiz.")
    return not problems, problems