ciyidogan commited on
Commit
39728bb
·
verified ·
1 Parent(s): c0112d6

Update inference_test_turkcell_with_intents.py

Browse files
inference_test_turkcell_with_intents.py CHANGED
@@ -7,25 +7,17 @@ from peft import PeftModel
7
  from datasets import Dataset
8
  from datetime import datetime
9
 
10
- # === Ortam
11
  HF_TOKEN = os.getenv("HF_TOKEN")
12
  os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
13
  os.environ["TORCH_HOME"] = "/app/.torch_cache"
14
  os.makedirs("/app/.torch_cache", exist_ok=True)
15
 
16
- # === Ayarlar
17
  MODEL_BASE = "TURKCELL/Turkcell-LLM-7b-v1"
18
  USE_FINE_TUNE = False
19
  FINE_TUNE_REPO = "UcsTurkey/trained-zips"
20
  FINE_TUNE_ZIP = "trained_model_000_009.zip"
21
  USE_SAMPLING = False
22
- GENERATION_CONFIDENCE_THRESHOLD = -1.5
23
- INTENT_CONFIDENCE_THRESHOLD = 0.5
24
- FALLBACK_ANSWERS = [
25
- "Bu konuda maalesef bilgim yok.",
26
- "Ne demek istediğinizi tam anlayamadım.",
27
- "Bu soruya şu an yanıt veremiyorum."
28
- ]
29
 
30
  INTENT_MODEL_PATH = "intent_model"
31
  INTENT_MODEL_ID = "dbmdz/bert-base-turkish-cased"
@@ -34,7 +26,16 @@ INTENT_TOKENIZER = None
34
  LABEL2ID = {}
35
  INTENT_DEFINITIONS = {}
36
 
37
- # === FastAPI
 
 
 
 
 
 
 
 
 
38
  app = FastAPI()
39
  chat_history = []
40
  model = None
@@ -75,6 +76,153 @@ def root():
75
  </body></html>
76
  """
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  @app.post("/chat")
79
  async def chat(msg: Message):
80
  user_input = msg.user_input.strip()
@@ -85,44 +233,47 @@ async def chat(msg: Message):
85
  if INTENT_MODEL:
86
  intent_task = asyncio.create_task(detect_intent(user_input))
87
  response_task = asyncio.create_task(generate_response(user_input))
88
- intent = await intent_task
89
-
90
- if intent is None:
91
- log("🟡 Intent confidence düşük. Ana modele yönlendiriliyor.")
92
- response = await response_task
93
- if isinstance(response, dict) and response.get("score", 0) < GENERATION_CONFIDENCE_THRESHOLD:
94
- return {"response": random.choice(FALLBACK_ANSWERS)}
95
- return {"response": response if isinstance(response, str) else response.get("text", "")}
96
-
97
- if intent in INTENT_DEFINITIONS:
98
  result = execute_intent(intent, user_input)
99
  return result
100
  else:
101
- response = await response_task
102
- return {"response": response if isinstance(response, str) else response.get("text", "")}
 
 
103
  else:
104
- response = await generate_response(user_input)
105
- if isinstance(response, dict) and response.get("score", 0) < GENERATION_CONFIDENCE_THRESHOLD:
106
  return {"response": random.choice(FALLBACK_ANSWERS)}
107
- return {"response": response if isinstance(response, str) else response.get("text", "")}
108
 
109
  except Exception as e:
110
  traceback.print_exc()
111
  return JSONResponse(content={"error": str(e)}, status_code=500)
112
 
113
- async def detect_intent(text):
114
- inputs = INTENT_TOKENIZER(text, return_tensors="pt")
115
- outputs = INTENT_MODEL(**inputs)
116
- logits = outputs.logits
117
- probs = torch.nn.functional.softmax(logits, dim=1)
118
- pred_id = logits.argmax().item()
119
- confidence = probs[0][pred_id].item()
120
 
121
- id2label = {v: k for k, v in LABEL2ID.items()}
122
- intent_name = id2label[pred_id]
123
- log(f"🔍 Intent tahmini: {intent_name} (confidence: {confidence:.2f})")
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
- if confidence < INTENT_CONFIDENCE_THRESHOLD:
126
- log(f"⚠️ Düşük confidence ({confidence:.2f}) nedeniyle intent boş döndü.")
127
- return None
128
- return intent_name
 
7
  from datasets import Dataset
8
  from datetime import datetime
9
 
10
+ # === Ortam ve Ayarlar ===
11
  HF_TOKEN = os.getenv("HF_TOKEN")
12
  os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
13
  os.environ["TORCH_HOME"] = "/app/.torch_cache"
14
  os.makedirs("/app/.torch_cache", exist_ok=True)
15
 
 
16
  MODEL_BASE = "TURKCELL/Turkcell-LLM-7b-v1"
17
  USE_FINE_TUNE = False
18
  FINE_TUNE_REPO = "UcsTurkey/trained-zips"
19
  FINE_TUNE_ZIP = "trained_model_000_009.zip"
20
  USE_SAMPLING = False
 
 
 
 
 
 
 
21
 
22
  INTENT_MODEL_PATH = "intent_model"
23
  INTENT_MODEL_ID = "dbmdz/bert-base-turkish-cased"
 
26
  LABEL2ID = {}
27
  INTENT_DEFINITIONS = {}
28
 
29
+ INTENT_CONFIDENCE_THRESHOLD = 0.5
30
+ LLM_CONFIDENCE_THRESHOLD = 0.2
31
+ TRAIN_CONFIDENCE_THRESHOLD = 0.7
32
+ FALLBACK_ANSWERS = [
33
+ "Bu konuda maalesef bilgim yok.",
34
+ "Ne demek istediğinizi tam anlayamadım.",
35
+ "Bu soruya şu an yanıt veremiyorum."
36
+ ]
37
+
38
+ # === FastAPI ===
39
  app = FastAPI()
40
  chat_history = []
41
  model = None
 
76
  </body></html>
77
  """
78
 
79
+ @app.post("/train_intents", status_code=202)
80
+ def train_intents(train_input: TrainInput):
81
+ global INTENT_DEFINITIONS
82
+ log("📥 POST /train_intents çağrıldı.")
83
+ intents = train_input.intents
84
+ INTENT_DEFINITIONS = {intent["name"]: intent for intent in intents}
85
+ threading.Thread(target=lambda: background_training(intents), daemon=True).start()
86
+ return {"status": "accepted", "message": "Intent eğitimi arka planda başlatıldı."}
87
+
88
+ def background_training(intents):
89
+ try:
90
+ log("🔧 Intent eğitimi başlatıldı...")
91
+ texts, labels, label2id = [], [], {}
92
+ for idx, intent in enumerate(intents):
93
+ label2id[intent["name"]] = idx
94
+ for ex in intent["examples"]:
95
+ texts.append(ex)
96
+ labels.append(idx)
97
+
98
+ dataset = Dataset.from_dict({"text": texts, "label": labels})
99
+ tokenizer = AutoTokenizer.from_pretrained(INTENT_MODEL_ID)
100
+ config = AutoConfig.from_pretrained(INTENT_MODEL_ID)
101
+ config.problem_type = "single_label_classification"
102
+ config.num_labels = len(label2id)
103
+ model = AutoModelForSequenceClassification.from_pretrained(INTENT_MODEL_ID, config=config)
104
+
105
+ tokenized_data = {"input_ids": [], "attention_mask": [], "label": []}
106
+ for row in dataset:
107
+ out = tokenizer(row["text"], truncation=True, padding="max_length", max_length=128)
108
+ tokenized_data["input_ids"].append(out["input_ids"])
109
+ tokenized_data["attention_mask"].append(out["attention_mask"])
110
+ tokenized_data["label"].append(row["label"])
111
+
112
+ tokenized = Dataset.from_dict(tokenized_data)
113
+ tokenized.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
114
+
115
+ output_dir = "/app/intent_train_output"
116
+ os.makedirs(output_dir, exist_ok=True)
117
+ trainer = Trainer(
118
+ model=model,
119
+ args=TrainingArguments(output_dir, per_device_train_batch_size=4, num_train_epochs=3, logging_steps=10, save_strategy="no", report_to=[]),
120
+ train_dataset=tokenized,
121
+ data_collator=default_data_collator
122
+ )
123
+ trainer.train()
124
+
125
+ # Raporlama
126
+ predictions = model(tokenized["input_ids"]).logits.argmax(dim=-1).tolist()
127
+ actuals = tokenized["label"]
128
+ counts = {}
129
+ correct = {}
130
+ for pred, actual in zip(predictions, actuals):
131
+ intent = list(label2id.keys())[list(label2id.values()).index(actual)]
132
+ counts[intent] = counts.get(intent, 0) + 1
133
+ if pred == actual:
134
+ correct[intent] = correct.get(intent, 0) + 1
135
+
136
+ for intent, total in counts.items():
137
+ accuracy = correct.get(intent, 0) / total
138
+ log(f"📊 Intent '{intent}' doğruluk: {accuracy:.2f} — {total} örnek")
139
+ if accuracy < TRAIN_CONFIDENCE_THRESHOLD or total < 5:
140
+ log(f"⚠️ Yetersiz performanslı intent: '{intent}' — Doğruluk: {accuracy:.2f}, Örnek: {total}")
141
+
142
+ if os.path.exists(INTENT_MODEL_PATH):
143
+ shutil.rmtree(INTENT_MODEL_PATH)
144
+ model.save_pretrained(INTENT_MODEL_PATH)
145
+ tokenizer.save_pretrained(INTENT_MODEL_PATH)
146
+ with open(os.path.join(INTENT_MODEL_PATH, "label2id.json"), "w") as f:
147
+ json.dump(label2id, f)
148
+
149
+ log("✅ Intent eğitimi tamamlandı ve model kaydedildi.")
150
+
151
+ except Exception as e:
152
+ log(f"❌ Intent eğitimi hatası: {e}")
153
+ traceback.print_exc()
154
+
155
+ @app.post("/load_intent_model")
156
+ def load_intent_model():
157
+ global INTENT_MODEL, INTENT_TOKENIZER, LABEL2ID
158
+ try:
159
+ INTENT_TOKENIZER = AutoTokenizer.from_pretrained(INTENT_MODEL_PATH)
160
+ INTENT_MODEL = AutoModelForSequenceClassification.from_pretrained(INTENT_MODEL_PATH)
161
+ with open(os.path.join(INTENT_MODEL_PATH, "label2id.json")) as f:
162
+ LABEL2ID = json.load(f)
163
+ return {"status": "ok", "message": "Intent modeli yüklendi."}
164
+ except Exception as e:
165
+ return JSONResponse(content={"error": str(e)}, status_code=500)
166
+
167
+ async def detect_intent(text):
168
+ inputs = INTENT_TOKENIZER(text, return_tensors="pt")
169
+ outputs = INTENT_MODEL(**inputs)
170
+ probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
171
+ confidence, pred_id = torch.max(probs, dim=-1)
172
+ id2label = {v: k for k, v in LABEL2ID.items()}
173
+ return id2label[pred_id.item()], confidence.item()
174
+
175
+ async def generate_response(text):
176
+ messages = [{"role": "user", "content": text}]
177
+ encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True)
178
+ eos_token = tokenizer("<|im_end|>", add_special_tokens=False)["input_ids"][0]
179
+
180
+ input_ids = encodeds.to(model.device)
181
+ attention_mask = (input_ids != tokenizer.pad_token_id).long()
182
+
183
+ with torch.no_grad():
184
+ output = model.generate(
185
+ input_ids=input_ids,
186
+ attention_mask=attention_mask,
187
+ max_new_tokens=128,
188
+ do_sample=USE_SAMPLING,
189
+ eos_token_id=eos_token,
190
+ pad_token_id=tokenizer.pad_token_id,
191
+ return_dict_in_generate=True,
192
+ output_scores=True
193
+ )
194
+
195
+ if not USE_SAMPLING:
196
+ scores = torch.stack(output.scores, dim=1)
197
+ probs = torch.nn.functional.softmax(scores[0], dim=-1)
198
+ top_conf = probs.max().item()
199
+ else:
200
+ top_conf = None
201
+
202
+ decoded = tokenizer.decode(output.sequences[0], skip_special_tokens=True).strip()
203
+ for tag in ["assistant", "<|im_start|>assistant"]:
204
+ start = decoded.find(tag)
205
+ if start != -1:
206
+ decoded = decoded[start + len(tag):].strip()
207
+ break
208
+ return decoded, top_conf
209
+
210
+ def extract_parameters(variables_list, user_input):
211
+ for pattern in variables_list:
212
+ regex = re.sub(r"(\w+):\{(.+?)\}", r"(?P<\1>.+?)", pattern)
213
+ match = re.match(regex, user_input)
214
+ if match:
215
+ return [{"key": k, "value": v} for k, v in match.groupdict().items()]
216
+ return []
217
+
218
+ def execute_intent(intent_name, user_input):
219
+ if intent_name in INTENT_DEFINITIONS:
220
+ definition = INTENT_DEFINITIONS[intent_name]
221
+ variables = extract_parameters(definition.get("variables", []), user_input)
222
+ log(f"🚀 execute_intent('{intent_name}', {variables})")
223
+ return {"intent": intent_name, "parameters": variables}
224
+ return {"intent": intent_name, "parameters": []}
225
+
226
  @app.post("/chat")
227
  async def chat(msg: Message):
228
  user_input = msg.user_input.strip()
 
233
  if INTENT_MODEL:
234
  intent_task = asyncio.create_task(detect_intent(user_input))
235
  response_task = asyncio.create_task(generate_response(user_input))
236
+ intent, intent_conf = await intent_task
237
+ log(f"🎯 Intent: {intent} (conf={intent_conf:.2f})")
238
+ if intent_conf > INTENT_CONFIDENCE_THRESHOLD and intent in INTENT_DEFINITIONS:
 
 
 
 
 
 
 
239
  result = execute_intent(intent, user_input)
240
  return result
241
  else:
242
+ response, response_conf = await response_task
243
+ if response_conf is not None and response_conf < LLM_CONFIDENCE_THRESHOLD:
244
+ return {"response": random.choice(FALLBACK_ANSWERS)}
245
+ return {"response": response}
246
  else:
247
+ response, response_conf = await generate_response(user_input)
248
+ if response_conf is not None and response_conf < LLM_CONFIDENCE_THRESHOLD:
249
  return {"response": random.choice(FALLBACK_ANSWERS)}
250
+ return {"response": response}
251
 
252
  except Exception as e:
253
  traceback.print_exc()
254
  return JSONResponse(content={"error": str(e)}, status_code=500)
255
 
256
+ def log(message):
257
+ timestamp = datetime.now().strftime("%H:%M:%S")
258
+ print(f"[{timestamp}] {message}", flush=True)
 
 
 
 
259
 
260
+ def setup_model():
261
+ global model, tokenizer, eos_token_id
262
+ try:
263
+ log("🧠 setup_model() başladı")
264
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
265
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_BASE, use_fast=False)
266
+ model = AutoModelForCausalLM.from_pretrained(MODEL_BASE, torch_dtype=torch.float32).to(device)
267
+ tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
268
+ model.config.pad_token_id = tokenizer.pad_token_id
269
+ eos_token_id = tokenizer("<|im_end|>", add_special_tokens=False)["input_ids"][0]
270
+ model.eval()
271
+ log("✅ Ana model yüklendi")
272
+ except Exception as e:
273
+ log(f"❌ setup_model() hatası: {e}")
274
+ traceback.print_exc()
275
 
276
+ threading.Thread(target=setup_model, daemon=True).start()
277
+ threading.Thread(target=lambda: uvicorn.run(app, host="0.0.0.0", port=7860), daemon=True).start()
278
+ while True:
279
+ time.sleep(60)