ciyidogan commited on
Commit
4977ad5
·
verified ·
1 Parent(s): 0037e99

Update intent.py

Browse files
Files changed (1) hide show
  1. intent.py +163 -146
intent.py CHANGED
@@ -1,146 +1,163 @@
1
- import os
2
- import torch
3
- import json
4
- import shutil
5
- import re
6
- import traceback
7
- from datasets import Dataset
8
- from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments, default_data_collator, AutoConfig
9
- from log import log
10
-
11
- INTENT_MODELS = {} # project_name -> (model, tokenizer, label2id)
12
-
13
- async def detect_intent(text):
14
- # Bu fonksiyon bir örnek; çağırırken ilgili proje için model alınmalı
15
- raise NotImplementedError("detect_intent çağrısı, proje bazlı model ile yapılmalıdır.")
16
-
17
- def background_training(project_name, intents, model_id, output_path, confidence_threshold):
18
- try:
19
- log(f"🔧 Intent eğitimi başlatıldı (proje: {project_name})")
20
- texts, labels, label2id = [], [], {}
21
- for idx, intent in enumerate(intents):
22
- label2id[intent["name"]] = idx
23
- for ex in intent["examples"]:
24
- texts.append(ex)
25
- labels.append(idx)
26
-
27
- dataset = Dataset.from_dict({"text": texts, "label": labels})
28
- tokenizer = AutoTokenizer.from_pretrained(model_id)
29
- config = AutoConfig.from_pretrained(model_id)
30
- config.problem_type = "single_label_classification"
31
- config.num_labels = len(label2id)
32
- model = AutoModelForSequenceClassification.from_pretrained(model_id, config=config)
33
-
34
- tokenized_data = {"input_ids": [], "attention_mask": [], "label": []}
35
- for row in dataset:
36
- out = tokenizer(row["text"], truncation=True, padding="max_length", max_length=128)
37
- tokenized_data["input_ids"].append(out["input_ids"])
38
- tokenized_data["attention_mask"].append(out["attention_mask"])
39
- tokenized_data["label"].append(row["label"])
40
-
41
- tokenized = Dataset.from_dict(tokenized_data)
42
- tokenized.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
43
-
44
- if os.path.exists(output_path):
45
- shutil.rmtree(output_path)
46
- os.makedirs(output_path, exist_ok=True)
47
-
48
- trainer = Trainer(
49
- model=model,
50
- args=TrainingArguments(output_path, per_device_train_batch_size=4, num_train_epochs=3, logging_steps=10, save_strategy="no", report_to=[]),
51
- train_dataset=tokenized,
52
- data_collator=default_data_collator
53
- )
54
- trainer.train()
55
-
56
- # Başarı raporu
57
- log("🔧 Başarı raporu üretiliyor...")
58
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
59
- model.to(device)
60
- input_ids_tensor = torch.tensor(tokenized["input_ids"]).to(device)
61
- attention_mask_tensor = torch.tensor(tokenized["attention_mask"]).to(device)
62
-
63
- with torch.no_grad():
64
- outputs = model(input_ids=input_ids_tensor, attention_mask=attention_mask_tensor)
65
- predictions = outputs.logits.argmax(dim=-1).tolist()
66
-
67
- actuals = tokenized["label"]
68
- counts, correct = {}, {}
69
- for pred, actual in zip(predictions, actuals):
70
- intent_name = list(label2id.keys())[list(label2id.values()).index(actual)]
71
- counts[intent_name] = counts.get(intent_name, 0) + 1
72
- if pred == actual:
73
- correct[intent_name] = correct.get(intent_name, 0) + 1
74
- for intent_name, total in counts.items():
75
- accuracy = correct.get(intent_name, 0) / total
76
- log(f"📊 Intent '{intent_name}' doğruluk: {accuracy:.2f} — {total} örnek")
77
- if accuracy < confidence_threshold or total < 5:
78
- log(f"⚠️ Yetersiz performanslı intent: '{intent_name}' — Doğruluk: {accuracy:.2f}, Örnek: {total}")
79
-
80
- model.save_pretrained(output_path)
81
- tokenizer.save_pretrained(output_path)
82
- with open(os.path.join(output_path, "label2id.json"), "w") as f:
83
- json.dump(label2id, f)
84
-
85
- INTENT_MODELS[project_name] = {
86
- "model": model,
87
- "tokenizer": tokenizer,
88
- "label2id": label2id
89
- }
90
- log(f"✅ Intent eğitimi tamamlandı ve '{project_name}' modeli yüklendi.")
91
-
92
- except Exception as e:
93
- log(f" Intent eğitimi hatası: {e}")
94
- traceback.print_exc()
95
-
96
- def extract_parameters(variables_list, user_input):
97
- for pattern in variables_list:
98
- regex = re.sub(r"(\w+):\{(.+?)\}", r"(?P<\1>.+?)", pattern)
99
- match = re.match(regex, user_input)
100
- if match:
101
- return [{"key": k, "value": v} for k, v in match.groupdict().items()]
102
- return []
103
-
104
- def resolve_placeholders(text: str, session: dict, variables: dict) -> str:
105
- def replacer(match):
106
- full = match.group(1)
107
- try:
108
- if full.startswith("variables."):
109
- key = full.split(".", 1)[1]
110
- return str(variables.get(key, f"{{{full}}}"))
111
- elif full.startswith("session."):
112
- key = full.split(".", 1)[1]
113
- return str(session.get("variables", {}).get(key, f"{{{full}}}"))
114
- elif full.startswith("auth_tokens."):
115
- parts = full.split(".")
116
- if len(parts) == 3:
117
- intent, token_type = parts[1], parts[2]
118
- return str(session.get("auth_tokens", {}).get(intent, {}).get(token_type, f"{{{full}}}"))
119
- else:
120
- return f"{{{full}}}"
121
- else:
122
- return f"{{{full}}}"
123
- except Exception:
124
- return f"{{{full}}}"
125
-
126
- return re.sub(r"\{([^{}]+)\}", replacer, text)
127
-
128
- def validate_variable_formats(variables, variable_format_map, data_formats):
129
- errors = {}
130
- for var_name, format_name in variable_format_map.items():
131
- value = variables.get(var_name)
132
- if value is None:
133
- continue
134
-
135
- format_def = data_formats.get(format_name)
136
- if not format_def:
137
- continue
138
-
139
- if "valid_options" in format_def:
140
- if value not in format_def["valid_options"]:
141
- errors[var_name] = format_def.get("error_message", f"{var_name} değeri geçersiz.")
142
- elif "pattern" in format_def:
143
- if not re.fullmatch(format_def["pattern"], value):
144
- errors[var_name] = format_def.get("error_message", f"{var_name} formatı geçersiz.")
145
-
146
- return len(errors) == 0, errors
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import json
4
+ import shutil
5
+ import re
6
+ import traceback
7
+ from datasets import Dataset
8
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments, default_data_collator, AutoConfig
9
+ from log import log
10
+
11
+ INTENT_MODELS = {} # project_name -> (model, tokenizer, label2id)
12
+
13
+ from core import INTENT_MODELS
14
+
15
+ async def detect_intent(text, project_name):
16
+ project_model = INTENT_MODELS.get(project_name)
17
+ if not project_model:
18
+ raise Exception(f"'{project_name}' için intent modeli yüklenmemiş.")
19
+
20
+ tokenizer = project_model["tokenizer"]
21
+ model = project_model["model"]
22
+ label2id = project_model["label2id"]
23
+
24
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
25
+ outputs = model(**inputs)
26
+ predicted_id = outputs.logits.argmax(dim=-1).item()
27
+
28
+ # ID → intent adı
29
+ detected_intent = [k for k, v in label2id.items() if v == predicted_id][0]
30
+ confidence = outputs.logits.softmax(dim=-1).max().item()
31
+
32
+ return detected_intent, confidence
33
+
34
+ def background_training(project_name, intents, model_id, output_path, confidence_threshold):
35
+ try:
36
+ log(f"🔧 Intent eğitimi başlatıldı (proje: {project_name})")
37
+ texts, labels, label2id = [], [], {}
38
+ for idx, intent in enumerate(intents):
39
+ label2id[intent["name"]] = idx
40
+ for ex in intent["examples"]:
41
+ texts.append(ex)
42
+ labels.append(idx)
43
+
44
+ dataset = Dataset.from_dict({"text": texts, "label": labels})
45
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
46
+ config = AutoConfig.from_pretrained(model_id)
47
+ config.problem_type = "single_label_classification"
48
+ config.num_labels = len(label2id)
49
+ model = AutoModelForSequenceClassification.from_pretrained(model_id, config=config)
50
+
51
+ tokenized_data = {"input_ids": [], "attention_mask": [], "label": []}
52
+ for row in dataset:
53
+ out = tokenizer(row["text"], truncation=True, padding="max_length", max_length=128)
54
+ tokenized_data["input_ids"].append(out["input_ids"])
55
+ tokenized_data["attention_mask"].append(out["attention_mask"])
56
+ tokenized_data["label"].append(row["label"])
57
+
58
+ tokenized = Dataset.from_dict(tokenized_data)
59
+ tokenized.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
60
+
61
+ if os.path.exists(output_path):
62
+ shutil.rmtree(output_path)
63
+ os.makedirs(output_path, exist_ok=True)
64
+
65
+ trainer = Trainer(
66
+ model=model,
67
+ args=TrainingArguments(output_path, per_device_train_batch_size=4, num_train_epochs=3, logging_steps=10, save_strategy="no", report_to=[]),
68
+ train_dataset=tokenized,
69
+ data_collator=default_data_collator
70
+ )
71
+ trainer.train()
72
+
73
+ # Başarı raporu
74
+ log("🔧 Başarı raporu üretiliyor...")
75
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
76
+ model.to(device)
77
+ input_ids_tensor = torch.tensor(tokenized["input_ids"]).to(device)
78
+ attention_mask_tensor = torch.tensor(tokenized["attention_mask"]).to(device)
79
+
80
+ with torch.no_grad():
81
+ outputs = model(input_ids=input_ids_tensor, attention_mask=attention_mask_tensor)
82
+ predictions = outputs.logits.argmax(dim=-1).tolist()
83
+
84
+ actuals = tokenized["label"]
85
+ counts, correct = {}, {}
86
+ for pred, actual in zip(predictions, actuals):
87
+ intent_name = list(label2id.keys())[list(label2id.values()).index(actual)]
88
+ counts[intent_name] = counts.get(intent_name, 0) + 1
89
+ if pred == actual:
90
+ correct[intent_name] = correct.get(intent_name, 0) + 1
91
+ for intent_name, total in counts.items():
92
+ accuracy = correct.get(intent_name, 0) / total
93
+ log(f"📊 Intent '{intent_name}' doğruluk: {accuracy:.2f} — {total} örnek")
94
+ if accuracy < confidence_threshold or total < 5:
95
+ log(f"⚠️ Yetersiz performanslı intent: '{intent_name}' — Doğruluk: {accuracy:.2f}, Örnek: {total}")
96
+
97
+ model.save_pretrained(output_path)
98
+ tokenizer.save_pretrained(output_path)
99
+ with open(os.path.join(output_path, "label2id.json"), "w") as f:
100
+ json.dump(label2id, f)
101
+
102
+ INTENT_MODELS[project_name] = {
103
+ "model": model,
104
+ "tokenizer": tokenizer,
105
+ "label2id": label2id
106
+ }
107
+ log(f"✅ Intent eğitimi tamamlandı ve '{project_name}' modeli yüklendi.")
108
+
109
+ except Exception as e:
110
+ log(f"❌ Intent eğitimi hatası: {e}")
111
+ traceback.print_exc()
112
+
113
+ def extract_parameters(variables_list, user_input):
114
+ for pattern in variables_list:
115
+ regex = re.sub(r"(\w+):\{(.+?)\}", r"(?P<\1>.+?)", pattern)
116
+ match = re.match(regex, user_input)
117
+ if match:
118
+ return [{"key": k, "value": v} for k, v in match.groupdict().items()]
119
+ return []
120
+
121
+ def resolve_placeholders(text: str, session: dict, variables: dict) -> str:
122
+ def replacer(match):
123
+ full = match.group(1)
124
+ try:
125
+ if full.startswith("variables."):
126
+ key = full.split(".", 1)[1]
127
+ return str(variables.get(key, f"{{{full}}}"))
128
+ elif full.startswith("session."):
129
+ key = full.split(".", 1)[1]
130
+ return str(session.get("variables", {}).get(key, f"{{{full}}}"))
131
+ elif full.startswith("auth_tokens."):
132
+ parts = full.split(".")
133
+ if len(parts) == 3:
134
+ intent, token_type = parts[1], parts[2]
135
+ return str(session.get("auth_tokens", {}).get(intent, {}).get(token_type, f"{{{full}}}"))
136
+ else:
137
+ return f"{{{full}}}"
138
+ else:
139
+ return f"{{{full}}}"
140
+ except Exception:
141
+ return f"{{{full}}}"
142
+
143
+ return re.sub(r"\{([^{}]+)\}", replacer, text)
144
+
145
+ def validate_variable_formats(variables, variable_format_map, data_formats):
146
+ errors = {}
147
+ for var_name, format_name in variable_format_map.items():
148
+ value = variables.get(var_name)
149
+ if value is None:
150
+ continue
151
+
152
+ format_def = data_formats.get(format_name)
153
+ if not format_def:
154
+ continue
155
+
156
+ if "valid_options" in format_def:
157
+ if value not in format_def["valid_options"]:
158
+ errors[var_name] = format_def.get("error_message", f"{var_name} değeri geçersiz.")
159
+ elif "pattern" in format_def:
160
+ if not re.fullmatch(format_def["pattern"], value):
161
+ errors[var_name] = format_def.get("error_message", f"{var_name} formatı geçersiz.")
162
+
163
+ return len(errors) == 0, errors