ciyidogan committed
Commit 4c1fee6 · verified · 1 Parent(s): 0a70c08

Update fine_tune_inference_test_mistral.py

Files changed (1)
  1. fine_tune_inference_test_mistral.py +153 -0
fine_tune_inference_test_mistral.py CHANGED
@@ -0,0 +1,153 @@
+ import os, time, torch, zipfile, threading, uvicorn  # "time" is needed for the keep-alive loop at the bottom
+ from fastapi import FastAPI
+ from fastapi.responses import HTMLResponse, JSONResponse
+ from pydantic import BaseModel
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from peft import PeftModel
+ from huggingface_hub import hf_hub_download
+ from datetime import datetime
+ import random
+
+ # === Constants ===
+ HF_TOKEN = os.environ.get("HF_TOKEN")
+ MODEL_BASE = "mistralai/Mistral-7B-Instruct-v0.2"
+ FINE_TUNE_ZIP = "trained_model_000_009.zip"
+ FINE_TUNE_REPO = "UcsTurkey/trained-zips"
+ USE_SAMPLING = False
+ CONFIDENCE_THRESHOLD = -1.5
+ FALLBACK_ANSWERS = [
+     "Bu konuda maalesef bilgim yok.",
+     "Ne demek istediğinizi tam anlayamadım.",
+     "Bu soruya şu an yanıt veremiyorum."
+ ]
+
+ # === Logging ===
+ def log(message):
+     timestamp = datetime.now().strftime("%H:%M:%S")
+     print(f"[{timestamp}] {message}", flush=True)  # flush so logs show up immediately in container output
+
+ # === FastAPI ===
+ app = FastAPI()
+ chat_history = []
+ model = None
+ tokenizer = None
+
+ class Message(BaseModel):
+     user_input: str
+
+ @app.get("/")
+ def health():
+     return {"status": "ok"}
+
+ @app.get("/start", response_class=HTMLResponse)
+ def root():
+     return """
+     <html>
+       <body>
+         <h2>Mistral 7B Chat</h2>
+         <textarea id=\"input\" rows=\"4\" cols=\"60\" placeholder=\"SORU: ...\"></textarea><br>
+         <button onclick=\"send()\">Gönder</button>
+         <pre id=\"output\"></pre>
+         <script>
+           async function send() {
+             const input = document.getElementById(\"input\").value;
+             const res = await fetch('/chat', {
+               method: 'POST',
+               headers: { 'Content-Type': 'application/json' },
+               body: JSON.stringify({ user_input: input })
+             });
+             const data = await res.json();
+             document.getElementById('output').innerText = data.answer || data.error || 'Hata oluştu.';
+           }
+         </script>
+       </body>
+     </html>
+     """
+
+ @app.post("/chat")
+ def chat(msg: Message):
+     global model, tokenizer
+     try:
+         if model is None:
+             return {"error": "Model yüklenmedi"}
+         user_input = msg.user_input.strip()
+         if not user_input:
+             return {"error": "Boş giriş"}
+         prompt = f"SORU: {user_input}\nCEVAP:"
+         inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+         with torch.no_grad():
+             output = model.generate(
+                 **inputs,
+                 max_new_tokens=128,
+                 do_sample=USE_SAMPLING,
+                 temperature=0.7 if USE_SAMPLING else None,
+                 top_p=0.9 if USE_SAMPLING else None,
+                 top_k=50 if USE_SAMPLING else None,
+                 return_dict_in_generate=True,
+                 output_scores=True,
+                 suppress_tokens=[tokenizer.pad_token_id]
+             )
+         decoded = tokenizer.decode(output.sequences[0], skip_special_tokens=True)
+         answer = decoded[len(prompt):].strip()
+
+         # Crude confidence gate: inspect the logits of the first generated token and
+         # fall back to a canned answer if they are invalid or below the threshold.
+         if output.scores and len(output.scores) > 0:
+             first_token_score = output.scores[0][0]
+             if torch.isnan(first_token_score).any() or torch.isinf(first_token_score).any():
+                 log("⚠️ Geçersiz logit (NaN/Inf) tespit edildi.")
+                 return {"answer": random.choice(FALLBACK_ANSWERS)}
+             max_score = torch.max(first_token_score).item()
+             log(f"🔍 İlk token skoru: {max_score:.4f}")
+             if max_score < CONFIDENCE_THRESHOLD:
+                 answer = random.choice(FALLBACK_ANSWERS)
+
+         chat_history.append({"user": user_input, "bot": answer})
+         log(f"Soru: {user_input} → Cevap: {answer[:60]}...")
+         return {"answer": answer, "chat_history": chat_history}
+     except Exception as e:
+         log(f"❌ /chat hatası: {e}")
+         return {"error": str(e)}
+
+ def detect_env():
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     supports_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8
+     return device, supports_bf16
+
+ def setup_model():
+     global model, tokenizer
+     try:
+         log("📦 Zip indiriliyor...")
+         # Download the zipped fine-tune output from the Hub and unpack it locally.
+         zip_path = hf_hub_download(
+             repo_id=FINE_TUNE_REPO,
+             filename=FINE_TUNE_ZIP,
+             repo_type="model",
+             token=HF_TOKEN
+         )
+         extract_path = "/app/extracted"
+         os.makedirs(extract_path, exist_ok=True)
+         with zipfile.ZipFile(zip_path, "r") as zip_ref:
+             zip_ref.extractall(extract_path)
+         tokenizer = AutoTokenizer.from_pretrained(os.path.join(extract_path, "output"))
+         if tokenizer.pad_token is None:
+             tokenizer.pad_token = tokenizer.eos_token
+
+         device, supports_bf16 = detect_env()
+         dtype = torch.bfloat16 if supports_bf16 else torch.float32
+         log(f"🧠 Ortam: {device.upper()}, dtype: {dtype}")
+         # Load the base model, then attach the extracted PEFT adapter on top of it.
+         base = AutoModelForCausalLM.from_pretrained(MODEL_BASE, torch_dtype=dtype).to(device)
+         peft = PeftModel.from_pretrained(base, os.path.join(extract_path, "output"))
+         model = peft.model.to(device)
+         model.eval()
+         log("✅ Model yüklendi.")
+     except Exception as e:
+         log(f"❌ Model setup hatası: {e}")
+
+ def run_server():
+     log("🌐 Uvicorn başlatılıyor...")
+     uvicorn.run(app, host="0.0.0.0", port=7860)
+
+ log("🚀 Başlatılıyor...")
+ threading.Thread(target=setup_model, daemon=True).start()
+ threading.Thread(target=run_server, daemon=True).start()
+ # Both worker threads are daemons, so keep the main thread alive.
+ while True:
+     time.sleep(60)
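
For reference, a minimal client sketch for exercising the /chat endpoint once the server above is running. It is not part of the commit; it assumes the service is reachable at http://localhost:7860 and that the `requests` package is installed.

# Hypothetical client, not included in the commit above.
import requests

resp = requests.post(
    "http://localhost:7860/chat",
    json={"user_input": "Merhaba, nasılsın?"},  # any question; the server wraps it in "SORU: ...\nCEVAP:"
    timeout=120,
)
data = resp.json()
# The endpoint returns either {"answer": ..., "chat_history": [...]} or {"error": ...}.
print(data.get("answer") or data.get("error"))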