Update fine_tune_inference_test.py
Browse files- fine_tune_inference_test.py +44 -28
fine_tune_inference_test.py
CHANGED
@@ -4,13 +4,14 @@ import uvicorn
|
|
4 |
from fastapi import FastAPI
|
5 |
from fastapi.responses import HTMLResponse, JSONResponse
|
6 |
from pydantic import BaseModel
|
7 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM
|
8 |
from datasets import load_dataset
|
9 |
from peft import PeftModel
|
10 |
import torch
|
11 |
from huggingface_hub import hf_hub_download
|
12 |
import zipfile
|
13 |
from datetime import datetime
|
|
|
14 |
|
15 |
# ✅ Zamanlı log fonksiyonu (flush destekli)
|
16 |
def log(message):
|
@@ -26,10 +27,17 @@ FINE_TUNE_REPO = "UcsTurkey/trained-zips"
|
|
26 |
RAG_DATA_FILE = "merged_dataset_000_100.parquet"
|
27 |
RAG_DATA_REPO = "UcsTurkey/turkish-general-culture-tokenized"
|
28 |
USE_RAG = False # ✅ RAG kullanımını opsiyonel hale getiren sabit
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
app = FastAPI()
|
31 |
chat_history = []
|
32 |
-
|
|
|
33 |
|
34 |
class Message(BaseModel):
|
35 |
user_input: str
|
@@ -68,8 +76,8 @@ def root():
|
|
68 |
def chat(msg: Message):
|
69 |
try:
|
70 |
log(f"📦 Kullanıcı mesajı alındı: {msg}")
|
71 |
-
global
|
72 |
-
if
|
73 |
log("🚫 Hata: Model henüz yüklenmedi.")
|
74 |
return {"error": "Model yüklenmedi. Lütfen birkaç saniye sonra tekrar deneyin."}
|
75 |
|
@@ -77,13 +85,36 @@ def chat(msg: Message):
|
|
77 |
if not user_input:
|
78 |
return {"error": "Boş giriş"}
|
79 |
|
80 |
-
# ✅ Eğitimdeki formatla uyumlu prompt
|
81 |
full_prompt = f"SORU: {user_input}\nCEVAP:"
|
82 |
log(f"📨 Prompt: {full_prompt}")
|
83 |
|
84 |
-
|
85 |
-
|
86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
chat_history.append({"user": user_input, "bot": answer})
|
88 |
log(f"🗨️ Soru: {user_input} → Yanıt: {answer[:60]}...")
|
89 |
return {"answer": answer, "chat_history": chat_history}
|
@@ -93,7 +124,7 @@ def chat(msg: Message):
|
|
93 |
|
94 |
def setup_model():
|
95 |
try:
|
96 |
-
global
|
97 |
|
98 |
log("📦 Fine-tune zip indiriliyor...")
|
99 |
zip_path = hf_hub_download(
|
@@ -122,25 +153,10 @@ def setup_model():
|
|
122 |
log("➕ LoRA adapter uygulanıyor...")
|
123 |
peft_model = PeftModel.from_pretrained(base_model, os.path.join(extract_dir, "output"))
|
124 |
|
125 |
-
|
126 |
-
|
127 |
-
rag = load_dataset(
|
128 |
-
RAG_DATA_REPO,
|
129 |
-
data_files=RAG_DATA_FILE,
|
130 |
-
split="train",
|
131 |
-
token=HF_TOKEN
|
132 |
-
)
|
133 |
-
log(f"🔍 RAG boyutu: {len(rag)}")
|
134 |
-
|
135 |
-
log("🚀 Pipeline oluşturuluyor...")
|
136 |
-
pipe = TextGenerationPipeline(
|
137 |
-
model=peft_model.model,
|
138 |
-
tokenizer=tokenizer,
|
139 |
-
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
140 |
-
device=0 if torch.cuda.is_available() else -1
|
141 |
-
)
|
142 |
|
143 |
-
log("✅ Model
|
144 |
except Exception as e:
|
145 |
log(f"❌ setup_model() sırasında hata oluştu: {e}")
|
146 |
|
@@ -158,4 +174,4 @@ while True:
|
|
158 |
import time
|
159 |
time.sleep(60)
|
160 |
except Exception as e:
|
161 |
-
log(f"❌ Ana bekleme döngüsünde hata: {e}")
|
|
|
4 |
from fastapi import FastAPI
|
5 |
from fastapi.responses import HTMLResponse, JSONResponse
|
6 |
from pydantic import BaseModel
|
7 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
8 |
from datasets import load_dataset
|
9 |
from peft import PeftModel
|
10 |
import torch
|
11 |
from huggingface_hub import hf_hub_download
|
12 |
import zipfile
|
13 |
from datetime import datetime
|
14 |
+
import random
|
15 |
|
16 |
# ✅ Zamanlı log fonksiyonu (flush destekli)
|
17 |
def log(message):
|
|
|
27 |
RAG_DATA_FILE = "merged_dataset_000_100.parquet"
|
28 |
RAG_DATA_REPO = "UcsTurkey/turkish-general-culture-tokenized"
|
29 |
USE_RAG = False # ✅ RAG kullanımını opsiyonel hale getiren sabit
|
30 |
+
CONFIDENCE_THRESHOLD = -1.5 # ✅ Logit skorlarına göre eşik değeri
|
31 |
+
FALLBACK_ANSWERS = [
|
32 |
+
"Bu konuda maalesef bilgim yok.",
|
33 |
+
"Ne demek istediğinizi tam anlayamadım.",
|
34 |
+
"Bu soruya şu an yanıt veremiyorum."
|
35 |
+
]
|
36 |
|
37 |
app = FastAPI()
|
38 |
chat_history = []
|
39 |
+
model = None
|
40 |
+
tokenizer = None
|
41 |
|
42 |
class Message(BaseModel):
|
43 |
user_input: str
|
|
|
76 |
def chat(msg: Message):
|
77 |
try:
|
78 |
log(f"📦 Kullanıcı mesajı alındı: {msg}")
|
79 |
+
global model, tokenizer
|
80 |
+
if model is None or tokenizer is None:
|
81 |
log("🚫 Hata: Model henüz yüklenmedi.")
|
82 |
return {"error": "Model yüklenmedi. Lütfen birkaç saniye sonra tekrar deneyin."}
|
83 |
|
|
|
85 |
if not user_input:
|
86 |
return {"error": "Boş giriş"}
|
87 |
|
|
|
88 |
full_prompt = f"SORU: {user_input}\nCEVAP:"
|
89 |
log(f"📨 Prompt: {full_prompt}")
|
90 |
|
91 |
+
inputs = tokenizer(full_prompt, return_tensors="pt")
|
92 |
+
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
93 |
+
|
94 |
+
with torch.no_grad():
|
95 |
+
output = model.generate(
|
96 |
+
**inputs,
|
97 |
+
max_new_tokens=200,
|
98 |
+
do_sample=True,
|
99 |
+
temperature=0.7,
|
100 |
+
return_dict_in_generate=True,
|
101 |
+
output_scores=True
|
102 |
+
)
|
103 |
+
|
104 |
+
generated_ids = output.sequences[0]
|
105 |
+
generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
|
106 |
+
answer = generated_text[len(full_prompt):].strip()
|
107 |
+
|
108 |
+
if output.scores and len(output.scores) > 0:
|
109 |
+
first_token_logit = output.scores[0][0] # ilk tokenin logits
|
110 |
+
top_logit_score = torch.max(first_token_logit).item()
|
111 |
+
log(f"🔎 İlk token logit skoru: {top_logit_score:.4f}")
|
112 |
+
|
113 |
+
if top_logit_score < CONFIDENCE_THRESHOLD:
|
114 |
+
fallback = random.choice(FALLBACK_ANSWERS)
|
115 |
+
log(f"⚠️ Düşük güven: fallback cevabı gönderiliyor: {fallback}")
|
116 |
+
answer = fallback
|
117 |
+
|
118 |
chat_history.append({"user": user_input, "bot": answer})
|
119 |
log(f"🗨️ Soru: {user_input} → Yanıt: {answer[:60]}...")
|
120 |
return {"answer": answer, "chat_history": chat_history}
|
|
|
124 |
|
125 |
def setup_model():
|
126 |
try:
|
127 |
+
global model, tokenizer
|
128 |
|
129 |
log("📦 Fine-tune zip indiriliyor...")
|
130 |
zip_path = hf_hub_download(
|
|
|
153 |
log("➕ LoRA adapter uygulanıyor...")
|
154 |
peft_model = PeftModel.from_pretrained(base_model, os.path.join(extract_dir, "output"))
|
155 |
|
156 |
+
model = peft_model.model
|
157 |
+
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
|
159 |
+
log("✅ Model başarıyla yüklendi.")
|
160 |
except Exception as e:
|
161 |
log(f"❌ setup_model() sırasında hata oluştu: {e}")
|
162 |
|
|
|
174 |
import time
|
175 |
time.sleep(60)
|
176 |
except Exception as e:
|
177 |
+
log(f"❌ Ana bekleme döngüsünde hata: {e}")
|