ciyidogan commited on
Commit
3d30435
·
verified ·
1 Parent(s): 9870b98

Create fine_tune_inference_test_with_intents.py

Browse files
fine_tune_inference_test_with_intents.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # fine_tune_inference_with_intent.py
2
+ import os, torch, threading, uvicorn, time, traceback, zipfile, random, json, shutil, asyncio, re
3
+ from fastapi import FastAPI
4
+ from fastapi.responses import HTMLResponse, JSONResponse
5
+ from pydantic import BaseModel
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification, Trainer, TrainingArguments
7
+ from peft import PeftModel
8
+ from datasets import Dataset
9
+ from datetime import datetime
10
+
11
+ # === Ortam
12
+ HF_TOKEN = os.getenv("HF_TOKEN")
13
+ os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
14
+ os.environ["TORCH_HOME"] = "/app/.torch_cache"
15
+ os.makedirs("/app/.torch_cache", exist_ok=True)
16
+
17
+ # === Ayarlar
18
+ MODEL_BASE = "TURKCELL/Turkcell-LLM-7b-v1"
19
+ USE_FINE_TUNE = False
20
+ FINE_TUNE_REPO = "UcsTurkey/trained-zips"
21
+ FINE_TUNE_ZIP = "trained_model_000_009.zip"
22
+ USE_SAMPLING = False
23
+ CONFIDENCE_THRESHOLD = -1.5
24
+ FALLBACK_ANSWERS = [
25
+ "Bu konuda maalesef bilgim yok.",
26
+ "Ne demek istediğinizi tam anlayamadım.",
27
+ "Bu soruya şu an yanıt veremiyorum."
28
+ ]
29
+
30
+ INTENT_MODEL_PATH = "intent_model"
31
+ INTENT_MODEL_ID = "dbmdz/bert-base-turkish-cased"
32
+ INTENT_MODEL = None
33
+ INTENT_TOKENIZER = None
34
+ LABEL2ID = {}
35
+ INTENT_DEFINITIONS = {}
36
+
37
+ # === FastAPI
38
+ app = FastAPI()
39
+ chat_history = []
40
+ model = None
41
+ tokenizer = None
42
+ eos_token_id = None
43
+
44
+ class Message(BaseModel):
45
+ user_input: str
46
+
47
+ class TrainInput(BaseModel):
48
+ intents: list
49
+
50
+ @app.get("/")
51
+ def health():
52
+ return {"status": "ok"}
53
+
54
+ @app.get("/start", response_class=HTMLResponse)
55
+ def root():
56
+ return """
57
+ <html><body>
58
+ <h2>Turkcell LLM Chat</h2>
59
+ <textarea id='input' rows='4' cols='60'></textarea><br>
60
+ <button onclick='send()'>Gönder</button><br><br>
61
+ <label>Model Cevabı:</label><br>
62
+ <textarea id='output' rows='10' cols='80' readonly style='white-space: pre-wrap;'></textarea>
63
+ <script>
64
+ async function send() {
65
+ const input = document.getElementById("input").value;
66
+ const res = await fetch('/chat', {
67
+ method: 'POST',
68
+ headers: { 'Content-Type': 'application/json' },
69
+ body: JSON.stringify({ user_input: input })
70
+ });
71
+ const data = await res.json();
72
+ document.getElementById('output').value = data.answer || data.response || data.error || 'Hata oluştu.';
73
+ }
74
+ </script>
75
+ </body></html>
76
+ """
77
+
78
+ @app.post("/train_intents")
79
+ def train_intents(train_input: TrainInput):
80
+ global INTENT_DEFINITIONS
81
+ try:
82
+ intents = train_input.intents
83
+ INTENT_DEFINITIONS = {intent["name"]: intent for intent in intents}
84
+ texts, labels, label2id = [], [], {}
85
+ for idx, intent in enumerate(intents):
86
+ label2id[intent["name"]] = idx
87
+ for ex in intent["examples"]:
88
+ texts.append(ex)
89
+ labels.append(idx)
90
+
91
+ dataset = Dataset.from_dict({"text": texts, "label": labels})
92
+ tokenizer = AutoTokenizer.from_pretrained(INTENT_MODEL_ID)
93
+ model = AutoModelForSequenceClassification.from_pretrained(INTENT_MODEL_ID, num_labels=len(label2id))
94
+
95
+ def tokenize(batch):
96
+ return tokenizer(batch["text"], truncation=True, padding=True)
97
+
98
+ tokenized = dataset.map(tokenize, batched=True)
99
+ args = TrainingArguments("./intent_train_output", per_device_train_batch_size=4, num_train_epochs=3, logging_steps=10, save_strategy="no", report_to=[])
100
+ trainer = Trainer(model=model, args=args, train_dataset=tokenized)
101
+ trainer.train()
102
+
103
+ if os.path.exists(INTENT_MODEL_PATH): shutil.rmtree(INTENT_MODEL_PATH)
104
+ model.save_pretrained(INTENT_MODEL_PATH)
105
+ tokenizer.save_pretrained(INTENT_MODEL_PATH)
106
+ with open(os.path.join(INTENT_MODEL_PATH, "label2id.json"), "w") as f:
107
+ json.dump(label2id, f)
108
+
109
+ return {"status": "ok", "message": "Intent modeli eğitildi."}
110
+ except Exception as e:
111
+ return JSONResponse(content={"error": str(e)}, status_code=500)
112
+
113
+ @app.post("/load_intent_model")
114
+ def load_intent_model():
115
+ global INTENT_MODEL, INTENT_TOKENIZER, LABEL2ID
116
+ try:
117
+ INTENT_TOKENIZER = AutoTokenizer.from_pretrained(INTENT_MODEL_PATH)
118
+ INTENT_MODEL = AutoModelForSequenceClassification.from_pretrained(INTENT_MODEL_PATH)
119
+ with open(os.path.join(INTENT_MODEL_PATH, "label2id.json")) as f:
120
+ LABEL2ID = json.load(f)
121
+ return {"status": "ok", "message": "Intent modeli yüklendi."}
122
+ except Exception as e:
123
+ return JSONResponse(content={"error": str(e)}, status_code=500)
124
+
125
+ async def detect_intent(text):
126
+ inputs = INTENT_TOKENIZER(text, return_tensors="pt")
127
+ outputs = INTENT_MODEL(**inputs)
128
+ pred_id = outputs.logits.argmax().item()
129
+ id2label = {v: k for k, v in LABEL2ID.items()}
130
+ return id2label[pred_id]
131
+
132
+ def extract_parameters(variables_list, user_input):
133
+ for pattern in variables_list:
134
+ regex = re.sub(r"(\w+):\{(.+?)\}", r"(?P<\1>.+?)", pattern)
135
+ match = re.match(regex, user_input)
136
+ if match:
137
+ return [{"key": k, "value": v} for k, v in match.groupdict().items()]
138
+ return []
139
+
140
+ def execute_intent(intent_name, user_input):
141
+ if intent_name in INTENT_DEFINITIONS:
142
+ definition = INTENT_DEFINITIONS[intent_name]
143
+ variables = extract_parameters(definition.get("variables", []), user_input)
144
+ log(f"🚀 execute_intent('{intent_name}', {variables})")
145
+ return {"intent": intent_name, "parameters": variables}
146
+ return {"intent": intent_name, "parameters": []}
147
+
148
+ async def generate_response(text):
149
+ messages = [{"role": "user", "content": text}]
150
+ encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt", add_special_tokens=True)
151
+ eos_token = tokenizer("<|im_end|>", add_special_tokens=False)["input_ids"][0]
152
+
153
+ input_ids = encodeds.to(model.device)
154
+ attention_mask = (input_ids != tokenizer.pad_token_id).long()
155
+
156
+ with torch.no_grad():
157
+ output = model.generate(
158
+ input_ids=input_ids,
159
+ attention_mask=attention_mask,
160
+ max_new_tokens=128,
161
+ do_sample=USE_SAMPLING,
162
+ eos_token_id=eos_token,
163
+ pad_token_id=tokenizer.pad_token_id,
164
+ return_dict_in_generate=True,
165
+ output_scores=True
166
+ )
167
+
168
+ return tokenizer.decode(output.sequences[0], skip_special_tokens=True).strip()
169
+
170
+ @app.post("/chat")
171
+ async def chat(msg: Message):
172
+ user_input = msg.user_input.strip()
173
+ try:
174
+ if model is None or tokenizer is None:
175
+ return {"error": "Model yüklenmedi."}
176
+
177
+ if INTENT_MODEL:
178
+ intent_task = asyncio.create_task(detect_intent(user_input))
179
+ response_task = asyncio.create_task(generate_response(user_input))
180
+ intent = await intent_task
181
+ if intent in INTENT_DEFINITIONS:
182
+ result = execute_intent(intent, user_input)
183
+ return result
184
+ else:
185
+ response = await response_task
186
+ return {"response": response}
187
+ else:
188
+ response = await generate_response(user_input)
189
+ return {"response": response}
190
+
191
+ except Exception as e:
192
+ traceback.print_exc()
193
+ return JSONResponse(content={"error": str(e)}, status_code=500)
194
+
195
+ def setup_model():
196
+ global model, tokenizer, eos_token_id
197
+ try:
198
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
199
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_BASE, use_fast=False)
200
+ model = AutoModelForCausalLM.from_pretrained(MODEL_BASE, torch_dtype=torch.float32).to(device)
201
+ tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
202
+ model.config.pad_token_id = tokenizer.pad_token_id
203
+ eos_token_id = tokenizer("<|im_end|>", add_special_tokens=False)["input_ids"][0]
204
+ model.eval()
205
+ except Exception as e:
206
+ traceback.print_exc()
207
+
208
+ threading.Thread(target=setup_model, daemon=True).start()
209
+ threading.Thread(target=lambda: uvicorn.run(app, host="0.0.0.0", port=7860), daemon=True).start()
210
+ while True:
211
+ time.sleep(60)