ciyidogan commited on
Commit
3108e73
·
verified ·
1 Parent(s): 6d135b8

Create interence_test_with_intent_detection.py

Browse files
interence_test_with_intent_detection.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # intent_detection_service.py (Geliştirilmiş: Fine-tune + Intent + LLM)
2
+ import os
3
+ import json
4
+ import re
5
+ import torch
6
+ import asyncio
7
+ import shutil
8
+ import zipfile
9
+ import threading
10
+ import uvicorn
11
+ import time
12
+ import traceback
13
+ import random
14
+ from fastapi import FastAPI, Request
15
+ from fastapi.responses import JSONResponse, HTMLResponse
16
+ from pydantic import BaseModel
17
+ from datetime import datetime
18
+ from datasets import Dataset
19
+ from huggingface_hub import hf_hub_download
20
+ from transformers import (
21
+ AutoTokenizer,
22
+ AutoModelForSequenceClassification,
23
+ AutoModelForCausalLM,
24
+ Trainer,
25
+ TrainingArguments,
26
+ pipeline
27
+ )
28
+ from peft import PeftModel
29
+
30
+ # === Ayarlar ===
31
+ HF_TOKEN = os.getenv("HF_TOKEN")
32
+ MODEL_BASE = "malhajar/Mistral-7B-Instruct-v0.2-turkish"
33
+ USE_FINE_TUNE = False
34
+ FINE_TUNE_REPO = "UcsTurkey/trained-zips"
35
+ FINE_TUNE_ZIP = "trained_model_000_009.zip"
36
+ USE_SAMPLING = False
37
+ CONFIDENCE_THRESHOLD = -1.5
38
+ FALLBACK_ANSWERS = [
39
+ "Bu konuda maalesef bilgim yok.",
40
+ "Ne demek istediğinizi tam anlayamadım.",
41
+ "Bu soruya şu an yanıt veremiyorum."
42
+ ]
43
+
44
+ INTENT_MODEL_PATH = "intent_model"
45
+ INTENT_MODEL_ID = "dbmdz/bert-base-turkish-cased"
46
+ USE_CUDA = torch.cuda.is_available()
47
+ INTENT_MODEL = None
48
+ INTENT_TOKENIZER = None
49
+ LABEL2ID = {}
50
+ model = None
51
+ tokenizer = None
52
+ chat_history = []
53
+
54
+ # === FastAPI Uygulaması ===
55
+ app = FastAPI()
56
+
57
+ # === Yardımcı Fonksiyonlar ===
58
+ def log(msg):
59
+ print(f"[{datetime.now().strftime('%H:%M:%S')}] {msg}", flush=True)
60
+
61
+ def pattern_to_regex(pattern):
62
+ return re.sub(r"\{(\w+?)\}", r"(?P<\1>.+?)", pattern)
63
+
64
+ class ChatInput(BaseModel):
65
+ user_input: str
66
+
67
+ class TrainInput(BaseModel):
68
+ intents: list
69
+
70
+ @app.get("/")
71
+ def health():
72
+ return {"status": "ok"}
73
+
74
+ @app.get("/start", response_class=HTMLResponse)
75
+ def root():
76
+ return """
77
+ <html>
78
+ <body>
79
+ <h2>Mistral 7B Instruct Chat</h2>
80
+ <textarea id="input" rows="4" cols="60" placeholder="Write your instruction..."></textarea><br>
81
+ <button onclick="send()">Gönder</button><br><br>
82
+ <label>Model Cevabı:</label><br>
83
+ <textarea id="output" rows="10" cols="80" readonly style="white-space: pre-wrap;"></textarea>
84
+ <script>
85
+ async function send() {
86
+ const input = document.getElementById("input").value;
87
+ const res = await fetch('/chat', {
88
+ method: 'POST',
89
+ headers: { 'Content-Type': 'application/json' },
90
+ body: JSON.stringify({ user_input: input })
91
+ });
92
+ const data = await res.json();
93
+ document.getElementById('output').value = data.answer || data.response || data.error || 'Hata oluştu.';
94
+ }
95
+ </script>
96
+ </body>
97
+ </html>
98
+ """
99
+
100
+ @app.post("/train_intents")
101
+ def train_intents(train_input: TrainInput):
102
+ try:
103
+ intents = train_input.intents
104
+ log(f"🎯 Intent eğitimi başlatıldı. Intent sayısı: {len(intents)}")
105
+
106
+ texts, labels = [], []
107
+ label2id = {}
108
+ for idx, intent in enumerate(intents):
109
+ label2id[intent["name"]] = idx
110
+ for ex in intent["examples"]:
111
+ if "{" not in ex:
112
+ texts.append(ex)
113
+ labels.append(idx)
114
+
115
+ dataset = Dataset.from_dict({"text": texts, "label": labels})
116
+
117
+ tokenizer = AutoTokenizer.from_pretrained(INTENT_MODEL_ID)
118
+ model = AutoModelForSequenceClassification.from_pretrained(INTENT_MODEL_ID, num_labels=len(label2id))
119
+
120
+ def tokenize(batch):
121
+ return tokenizer(batch["text"], truncation=True, padding=True)
122
+
123
+ tokenized = dataset.map(tokenize, batched=True)
124
+ args = TrainingArguments("./intent_train_output", per_device_train_batch_size=4, num_train_epochs=3, logging_steps=10, save_strategy="no", report_to=[])
125
+ trainer = Trainer(model=model, args=args, train_dataset=tokenized)
126
+ trainer.train()
127
+
128
+ if os.path.exists(INTENT_MODEL_PATH):
129
+ shutil.rmtree(INTENT_MODEL_PATH)
130
+ model.save_pretrained(INTENT_MODEL_PATH)
131
+ tokenizer.save_pretrained(INTENT_MODEL_PATH)
132
+ with open(os.path.join(INTENT_MODEL_PATH, "label2id.json"), "w") as f:
133
+ json.dump(label2id, f)
134
+
135
+ log("✅ Intent modeli kaydedildi.")
136
+ return {"status": "ok", "message": "Intent modeli eğitildi ve kaydedildi."}
137
+
138
+ except Exception as e:
139
+ log(f"❌ Intent eğitimi hatası: {e}")
140
+ return JSONResponse(content={"error": str(e)}, status_code=500)
141
+
142
+ @app.post("/load_intent_model")
143
+ def load_intent_model():
144
+ global INTENT_MODEL, INTENT_TOKENIZER, LABEL2ID
145
+ try:
146
+ if not os.path.exists(INTENT_MODEL_PATH):
147
+ return JSONResponse(content={"error": "intent_model klasörü bulunamadı."}, status_code=400)
148
+
149
+ INTENT_TOKENIZER = AutoTokenizer.from_pretrained(INTENT_MODEL_PATH)
150
+ INTENT_MODEL = AutoModelForSequenceClassification.from_pretrained(INTENT_MODEL_PATH)
151
+ with open(os.path.join(INTENT_MODEL_PATH, "label2id.json")) as f:
152
+ LABEL2ID = json.load(f)
153
+ log("✅ Intent modeli belleğe yüklendi.")
154
+ return {"status": "ok", "message": "Intent modeli yüklendi."}
155
+
156
+ except Exception as e:
157
+ log(f"❌ Intent modeli yükleme hatası: {e}")
158
+ return JSONResponse(content={"error": str(e)}, status_code=500)
159
+
160
+ async def detect_intent(text):
161
+ inputs = INTENT_TOKENIZER(text, return_tensors="pt")
162
+ outputs = INTENT_MODEL(**inputs)
163
+ pred_id = outputs.logits.argmax().item()
164
+ id2label = {v: k for k, v in LABEL2ID.items()}
165
+ return id2label[pred_id]
166
+
167
+ async def generate_response(text):
168
+ messages = [{"role": "user", "content": text}]
169
+ inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True)
170
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
171
+ generate_args = {
172
+ "max_new_tokens": 512,
173
+ "return_dict_in_generate": True,
174
+ "output_scores": True,
175
+ "do_sample": USE_SAMPLING
176
+ }
177
+ if USE_SAMPLING:
178
+ generate_args.update({"temperature": 0.7, "top_p": 0.9, "top_k": 50})
179
+
180
+ with torch.no_grad():
181
+ output = model.generate(**inputs, **generate_args)
182
+
183
+ prompt_text = tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
184
+ decoded = tokenizer.decode(output.sequences[0], skip_special_tokens=True)
185
+ answer = decoded.replace(prompt_text, "").strip()
186
+
187
+ if output.scores and len(output.scores) > 0:
188
+ first_token_score = output.scores[0][0]
189
+ if torch.isnan(first_token_score).any() or torch.isinf(first_token_score).any():
190
+ log("⚠️ Geçersiz logit (NaN/Inf) tespit edildi.")
191
+ return random.choice(FALLBACK_ANSWERS)
192
+ max_score = torch.max(first_token_score).item()
193
+ log(f"🔍 İlk token skoru: {max_score:.4f}")
194
+ if max_score < CONFIDENCE_THRESHOLD:
195
+ return random.choice(FALLBACK_ANSWERS)
196
+
197
+ return answer
198
+
199
+ @app.post("/chat")
200
+ async def chat(input: ChatInput):
201
+ user_input = input.user_input.strip()
202
+ try:
203
+ if model is None or tokenizer is None:
204
+ return {"error": "Model veya tokenizer henüz yüklenmedi."}
205
+
206
+ if INTENT_MODEL:
207
+ intent_task = asyncio.create_task(detect_intent(user_input))
208
+ response_task = asyncio.create_task(generate_response(user_input))
209
+ intent = await intent_task
210
+ response = await response_task
211
+ log(f"✅ Intent: {intent}")
212
+ return {"intent": intent, "response": response}
213
+ else:
214
+ response = await generate_response(user_input)
215
+ log("💬 Intent modeli yok, yalnızca LLM cevabı verildi.")
216
+ return {"response": response}
217
+
218
+ except Exception as e:
219
+ log(f"❌ /chat hatası: {e}")
220
+ traceback.print_exc()
221
+ return JSONResponse(content={"error": str(e)}, status_code=500)
222
+
223
+ # === Model setup ===
224
+ def setup_model():
225
+ global model, tokenizer
226
+ try:
227
+ device = "cuda" if torch.cuda.is_available() else "cpu"
228
+ dtype = torch.float32
229
+
230
+ if USE_FINE_TUNE:
231
+ log("📦 Fine-tune zip indiriliyor...")
232
+ zip_path = hf_hub_download(repo_id=FINE_TUNE_REPO, filename=FINE_TUNE_ZIP, repo_type="model", token=HF_TOKEN)
233
+ extract_dir = "/app/extracted"
234
+ os.makedirs(extract_dir, exist_ok=True)
235
+ with zipfile.ZipFile(zip_path, "r") as zip_ref:
236
+ zip_ref.extractall(extract_dir)
237
+
238
+ tokenizer = AutoTokenizer.from_pretrained(os.path.join(extract_dir, "output"), use_fast=False)
239
+ base_model = AutoModelForCausalLM.from_pretrained(MODEL_BASE, torch_dtype=dtype).to(device)
240
+ model = PeftModel.from_pretrained(base_model, os.path.join(extract_dir, "output")).to(device)
241
+ else:
242
+ log("🧠 Ana model indiriliyor...")
243
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_BASE, use_fast=False)
244
+ model = AutoModelForCausalLM.from_pretrained(MODEL_BASE, torch_dtype=dtype).to(device)
245
+
246
+ tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
247
+ model.eval()
248
+ log("✅ LLM model başarıyla yüklendi.")
249
+ except Exception as e:
250
+ log(f"❌ LLM model yükleme hatası: {e}")
251
+ traceback.print_exc()
252
+
253
+ # === Sunucu başlat ===
254
+ def run():
255
+ log("===== Application Startup =====")
256
+ threading.Thread(target=setup_model, daemon=True).start()
257
+ threading.Thread(target=lambda: uvicorn.run(app, host="0.0.0.0", port=7860), daemon=True).start()
258
+ while True:
259
+ time.sleep(60)
260
+
261
+ run()