ciyidogan committed
Commit 06fedb7 · verified · 1 Parent(s): 1f2f347

Update fine_tune_inference_test.py

Files changed (1):
  1. fine_tune_inference_test.py +44 -28
fine_tune_inference_test.py CHANGED

@@ -4,13 +4,14 @@ import uvicorn
 from fastapi import FastAPI
 from fastapi.responses import HTMLResponse, JSONResponse
 from pydantic import BaseModel
-from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline
+from transformers import AutoTokenizer, AutoModelForCausalLM
 from datasets import load_dataset
 from peft import PeftModel
 import torch
 from huggingface_hub import hf_hub_download
 import zipfile
 from datetime import datetime
+import random
 
 # ✅ Timestamped log function (flush enabled)
 def log(message):
@@ -26,10 +27,17 @@ FINE_TUNE_REPO = "UcsTurkey/trained-zips"
 RAG_DATA_FILE = "merged_dataset_000_100.parquet"
 RAG_DATA_REPO = "UcsTurkey/turkish-general-culture-tokenized"
 USE_RAG = False  # ✅ Constant that makes RAG usage optional
+CONFIDENCE_THRESHOLD = -1.5  # ✅ Threshold on the logit scores
+FALLBACK_ANSWERS = [
+    "Bu konuda maalesef bilgim yok.",
+    "Ne demek istediğinizi tam anlayamadım.",
+    "Bu soruya şu an yanıt veremiyorum."
+]
 
 app = FastAPI()
 chat_history = []
-pipe = None  # global text-generation pipeline
+model = None
+tokenizer = None
 
 class Message(BaseModel):
     user_input: str
@@ -68,8 +76,8 @@ def root():
 def chat(msg: Message):
     try:
         log(f"📦 Kullanıcı mesajı alındı: {msg}")
-        global pipe
-        if pipe is None:
+        global model, tokenizer
+        if model is None or tokenizer is None:
             log("🚫 Hata: Model henüz yüklenmedi.")
             return {"error": "Model yüklenmedi. Lütfen birkaç saniye sonra tekrar deneyin."}
 
@@ -77,13 +85,36 @@ def chat(msg: Message):
         if not user_input:
             return {"error": "Boş giriş"}
 
-        # ✅ Prompt matching the training format
         full_prompt = f"SORU: {user_input}\nCEVAP:"
         log(f"📨 Prompt: {full_prompt}")
 
-        log("📦 Cevap hazırlanıyor...")
-        result = pipe(full_prompt, max_new_tokens=200, do_sample=True, temperature=0.7)
-        answer = result[0]["generated_text"][len(full_prompt):].strip()
+        inputs = tokenizer(full_prompt, return_tensors="pt")
+        inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+        with torch.no_grad():
+            output = model.generate(
+                **inputs,
+                max_new_tokens=200,
+                do_sample=True,
+                temperature=0.7,
+                return_dict_in_generate=True,
+                output_scores=True
+            )
+
+        generated_ids = output.sequences[0]
+        generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
+        answer = generated_text[len(full_prompt):].strip()
+
+        if output.scores and len(output.scores) > 0:
+            first_token_logit = output.scores[0][0]  # logits of the first generated token
+            top_logit_score = torch.max(first_token_logit).item()
+            log(f"🔎 İlk token logit skoru: {top_logit_score:.4f}")
+
+            if top_logit_score < CONFIDENCE_THRESHOLD:
+                fallback = random.choice(FALLBACK_ANSWERS)
+                log(f"⚠️ Düşük güven: fallback cevabı gönderiliyor: {fallback}")
+                answer = fallback
+
         chat_history.append({"user": user_input, "bot": answer})
         log(f"🗨️ Soru: {user_input} → Yanıt: {answer[:60]}...")
         return {"answer": answer, "chat_history": chat_history}
@@ -93,7 +124,7 @@ def chat(msg: Message):
 
 def setup_model():
     try:
-        global pipe
+        global model, tokenizer
 
         log("📦 Fine-tune zip indiriliyor...")
         zip_path = hf_hub_download(
@@ -122,25 +153,10 @@ def setup_model():
         log("➕ LoRA adapter uygulanıyor...")
         peft_model = PeftModel.from_pretrained(base_model, os.path.join(extract_dir, "output"))
 
-        if USE_RAG:
-            log("📚 RAG dataseti yükleniyor...")
-            rag = load_dataset(
-                RAG_DATA_REPO,
-                data_files=RAG_DATA_FILE,
-                split="train",
-                token=HF_TOKEN
-            )
-            log(f"🔍 RAG boyutu: {len(rag)}")
-
-        log("🚀 Pipeline oluşturuluyor...")
-        pipe = TextGenerationPipeline(
-            model=peft_model.model,
-            tokenizer=tokenizer,
-            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-            device=0 if torch.cuda.is_available() else -1
-        )
+        model = peft_model.model
+        model.eval()
 
-        log("✅ Model ve pipeline başarıyla yüklendi.")
+        log("✅ Model başarıyla yüklendi.")
     except Exception as e:
         log(f"❌ setup_model() sırasında hata oluştu: {e}")
 
@@ -158,4 +174,4 @@ while True:
         import time
         time.sleep(60)
     except Exception as e:
-        log(f"❌ Ana bekleme döngüsünde hata: {e}")
+        log(f"❌ Ana bekleme döngüsünde hata: {e}")