ciyidogan commited on
Commit
699f9fe
·
verified ·
1 Parent(s): 1fd319a

Update intent_utils.py

Browse files
Files changed (1) hide show
  1. intent_utils.py +64 -174
intent_utils.py CHANGED
@@ -1,174 +1,64 @@
1
- import os
2
- import torch
3
- import json
4
- import shutil
5
- import re
6
- import traceback
7
- from datasets import Dataset
8
- from transformers import (
9
- AutoTokenizer,
10
- AutoModelForSequenceClassification,
11
- Trainer,
12
- TrainingArguments,
13
- default_data_collator,
14
- AutoConfig,
15
- )
16
- from log import log
17
- from core import llm_models
18
-
19
- async def detect_intent(text, project_name):
20
- llm_model_instance = llm_models.get(project_name)
21
- if not llm_model_instance or not llm_model_instance.intent_model:
22
- raise Exception(f"'{project_name}' için intent modeli yüklenmemiş.")
23
-
24
- tokenizer = llm_model_instance.intent_tokenizer
25
- model = llm_model_instance.intent_model
26
- label2id = llm_model_instance.intent_label2id
27
-
28
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
29
- outputs = model(**inputs)
30
- predicted_id = outputs.logits.argmax(dim=-1).item()
31
-
32
- detected_intent = [k for k, v in label2id.items() if v == predicted_id][0]
33
- confidence = outputs.logits.softmax(dim=-1).max().item()
34
-
35
- return detected_intent, confidence
36
-
37
- def background_training(project_name, intents, model_id, output_path, confidence_threshold):
38
- try:
39
- log(f"🔧 Intent eğitimi başlatıldı (proje: {project_name})")
40
- texts, labels, label2id = [], [], {}
41
- for idx, intent in enumerate(intents):
42
- label2id[intent["name"]] = idx
43
- for ex in intent["examples"]:
44
- texts.append(ex)
45
- labels.append(idx)
46
-
47
- dataset = Dataset.from_dict({"text": texts, "label": labels})
48
- tokenizer = AutoTokenizer.from_pretrained(model_id)
49
- config = AutoConfig.from_pretrained(model_id)
50
- config.problem_type = "single_label_classification"
51
- config.num_labels = len(label2id)
52
- model = AutoModelForSequenceClassification.from_pretrained(model_id, config=config)
53
-
54
- tokenized_data = {"input_ids": [], "attention_mask": [], "label": []}
55
- for row in dataset:
56
- out = tokenizer(row["text"], truncation=True, padding="max_length", max_length=128)
57
- tokenized_data["input_ids"].append(out["input_ids"])
58
- tokenized_data["attention_mask"].append(out["attention_mask"])
59
- tokenized_data["label"].append(row["label"])
60
-
61
- tokenized = Dataset.from_dict(tokenized_data)
62
- tokenized.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
63
-
64
- if os.path.exists(output_path):
65
- shutil.rmtree(output_path)
66
- os.makedirs(output_path, exist_ok=True)
67
-
68
- trainer = Trainer(
69
- model=model,
70
- args=TrainingArguments(output_path, per_device_train_batch_size=4, num_train_epochs=3, logging_steps=10, save_strategy="no", report_to=[]),
71
- train_dataset=tokenized,
72
- data_collator=default_data_collator,
73
- )
74
- trainer.train()
75
-
76
- log("🔧 Başarı raporu üretiliyor...")
77
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
78
- model.to(device)
79
- input_ids_tensor = torch.tensor(tokenized["input_ids"]).to(device)
80
- attention_mask_tensor = torch.tensor(tokenized["attention_mask"]).to(device)
81
-
82
- with torch.no_grad():
83
- outputs = model(input_ids=input_ids_tensor, attention_mask=attention_mask_tensor)
84
- predictions = outputs.logits.argmax(dim=-1).tolist()
85
-
86
- actuals = tokenized["label"]
87
- counts, correct = {}, {}
88
- for pred, actual in zip(predictions, actuals):
89
- intent_name = list(label2id.keys())[list(label2id.values()).index(actual)]
90
- counts[intent_name] = counts.get(intent_name, 0) + 1
91
- if pred == actual:
92
- correct[intent_name] = correct.get(intent_name, 0) + 1
93
- for intent_name, total in counts.items():
94
- accuracy = correct.get(intent_name, 0) / total
95
- log(f"📊 Intent '{intent_name}' doğruluk: {accuracy:.2f} — {total} örnek")
96
- if accuracy < confidence_threshold or total < 5:
97
- log(f"⚠️ Yetersiz performanslı intent: '{intent_name}' — Doğruluk: {accuracy:.2f}, Örnek: {total}")
98
-
99
- # Eğitim sonrası model ve tokenizer'ı diske kaydet
100
- model.save_pretrained(output_path)
101
- tokenizer.save_pretrained(output_path)
102
- with open(os.path.join(output_path, "label2id.json"), "w") as f:
103
- json.dump(label2id, f)
104
-
105
- log(f"✅ Intent eğitimi tamamlandı ve '{project_name}' için model disk üzerinde hazır.")
106
-
107
- except Exception as e:
108
- log(f"❌ Intent eğitimi hatası: {e}")
109
- traceback.print_exc()
110
-
111
- def extract_parameters(variables_list, user_input):
112
- extracted_params = []
113
- for pattern in variables_list:
114
- # Örneğin: from_location:{Ankara} to_location:{İstanbul}
115
- regex = re.sub(r"(\w+):\{(.+?)\}", r"(?P<\1>.+?)", pattern)
116
- match = re.match(regex, user_input)
117
- if match:
118
- extracted_params = [{"key": k, "value": v} for k, v in match.groupdict().items()]
119
- break
120
-
121
- # Ek özel basit yakalama: iki şehir birden yazılırsa → sırayla atama
122
- if not extracted_params:
123
- city_pattern = r"(\bAnkara\b|\bİstanbul\b|\bİzmir\b)"
124
- cities = re.findall(city_pattern, user_input)
125
- if len(cities) >= 2:
126
- extracted_params = [
127
- {"key": "from_location", "value": cities[0]},
128
- {"key": "to_location", "value": cities[1]}
129
- ]
130
- return extracted_params
131
-
132
- def resolve_placeholders(text: str, session: dict, variables: dict) -> str:
133
- def replacer(match):
134
- full = match.group(1)
135
- try:
136
- if full.startswith("variables."):
137
- key = full.split(".", 1)[1]
138
- return str(variables.get(key, f"{{{full}}}"))
139
- elif full.startswith("session."):
140
- key = full.split(".", 1)[1]
141
- return str(session.get("variables", {}).get(key, f"{{{full}}}"))
142
- elif full.startswith("auth_tokens."):
143
- parts = full.split(".")
144
- if len(parts) == 3:
145
- intent, token_type = parts[1], parts[2]
146
- return str(session.get("auth_tokens", {}).get(intent, {}).get(token_type, f"{{{full}}}"))
147
- else:
148
- return f"{{{full}}}"
149
- else:
150
- return f"{{{full}}}"
151
- except Exception:
152
- return f"{{{full}}}"
153
-
154
- return re.sub(r"\{([^{}]+)\}", replacer, text)
155
-
156
- def validate_variable_formats(variables, variable_format_map, data_formats):
157
- errors = {}
158
- for var_name, format_name in variable_format_map.items():
159
- value = variables.get(var_name)
160
- if value is None:
161
- continue
162
-
163
- format_def = data_formats.get(format_name)
164
- if not format_def:
165
- continue
166
-
167
- if "valid_options" in format_def:
168
- if value not in format_def["valid_options"]:
169
- errors[var_name] = format_def.get("error_message", f"{var_name} değeri geçersiz.")
170
- elif "pattern" in format_def:
171
- if not re.fullmatch(format_def["pattern"], value):
172
- errors[var_name] = format_def.get("error_message", f"{var_name} formatı geçersiz.")
173
-
174
- return len(errors) == 0, errors
 
1
+ import re
2
+
3
+ def extract_parameters(variables_list, user_input):
4
+ extracted_params = []
5
+ for pattern in variables_list:
6
+ regex = re.sub(r"(\w+):\{(.+?)\}", r"(?P<\1>.+?)", pattern)
7
+ match = re.match(regex, user_input)
8
+ if match:
9
+ extracted_params = [{"key": k, "value": v} for k, v in match.groupdict().items()]
10
+ break
11
+
12
+ if not extracted_params:
13
+ city_pattern = r"(\bAnkara\b|\bİstanbul\b|\bİzmir\b)"
14
+ cities = re.findall(city_pattern, user_input)
15
+ if len(cities) >= 2:
16
+ extracted_params = [
17
+ {"key": "from_location", "value": cities[0]},
18
+ {"key": "to_location", "value": cities[1]}
19
+ ]
20
+ return extracted_params
21
+
22
+ def resolve_placeholders(text: str, session: dict, variables: dict) -> str:
23
+ def replacer(match):
24
+ full = match.group(1)
25
+ try:
26
+ if full.startswith("variables."):
27
+ key = full.split(".", 1)[1]
28
+ return str(variables.get(key, f"{{{full}}}"))
29
+ elif full.startswith("session."):
30
+ key = full.split(".", 1)[1]
31
+ return str(session.get("variables", {}).get(key, f"{{{full}}}"))
32
+ elif full.startswith("auth_tokens."):
33
+ parts = full.split(".")
34
+ if len(parts) == 3:
35
+ intent, token_type = parts[1], parts[2]
36
+ return str(session.get("auth_tokens", {}).get(intent, {}).get(token_type, f"{{{full}}}"))
37
+ else:
38
+ return f"{{{full}}}"
39
+ else:
40
+ return f"{{{full}}}"
41
+ except Exception:
42
+ return f"{{{full}}}"
43
+
44
+ return re.sub(r"\{([^{}]+)\}", replacer, text)
45
+
46
+ def validate_variable_formats(variables, variable_format_map, data_formats):
47
+ errors = {}
48
+ for var_name, format_name in variable_format_map.items():
49
+ value = variables.get(var_name)
50
+ if value is None:
51
+ continue
52
+
53
+ format_def = data_formats.get(format_name)
54
+ if not format_def:
55
+ continue
56
+
57
+ if "valid_options" in format_def:
58
+ if value not in format_def["valid_options"]:
59
+ errors[var_name] = format_def.get("error_message", f"{var_name} değeri geçersiz.")
60
+ elif "pattern" in format_def:
61
+ if not re.fullmatch(format_def["pattern"], value):
62
+ errors[var_name] = format_def.get("error_message", f"{var_name} formatı geçersiz.")
63
+
64
+ return len(errors) == 0, errors