Update fine_tune_inference_test_mistral.py
fine_tune_inference_test_mistral.py
CHANGED
@@ -1,19 +1,15 @@
-import os, torch,
+import os, torch, threading, uvicorn, time, traceback
 from fastapi import FastAPI
 from fastapi.responses import HTMLResponse, JSONResponse
 from pydantic import BaseModel
 from transformers import AutoTokenizer, AutoModelForCausalLM
-from peft import PeftModel
-from huggingface_hub import hf_hub_download
 from datetime import datetime
 import random

 # === Constants ===
 HF_TOKEN = os.environ.get("HF_TOKEN")
 MODEL_BASE = "mistralai/Mistral-7B-Instruct-v0.2"
-
-FINE_TUNE_REPO = "UcsTurkey/trained-zips"
-USE_FINE_TUNE = False  # ✅ Set to False to test the base model
+USE_FINE_TUNE = False
 USE_SAMPLING = False
 CONFIDENCE_THRESHOLD = -1.5
 FALLBACK_ANSWERS = [
@@ -47,7 +43,7 @@ def root():
         <html>
         <body>
         <h2>Mistral 7B Chat</h2>
-        <textarea id=\"input\" rows=\"4\" cols=\"60\" placeholder=\"
+        <textarea id=\"input\" rows=\"4\" cols=\"60\" placeholder=\"Write your instruction...\"></textarea><br>
         <button onclick=\"send()\">Gönder</button>
         <pre id=\"output\"></pre>
         <script>
@@ -77,7 +73,8 @@ def chat(msg: Message):
     if not user_input:
         return {"error": "Boş giriş"}

-
+    # ✅ Instruct format expected by the base model
+    prompt = f"### Instruction:\n{user_input}\n\n### Response:"
     inputs = tokenizer(prompt, return_tensors="pt")

     if not inputs or "input_ids" not in inputs:
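Note: the new prompt is a hand-written "### Instruction: / ### Response:" block. For comparison only (not part of this commit), a minimal sketch of rendering the same user input through the chat template bundled with the mistralai/Mistral-7B-Instruct-v0.2 tokenizer, assuming transformers' apply_chat_template is available in the installed version:

    # Sketch only: alternative prompt construction via the bundled chat template.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
    messages = [{"role": "user", "content": "Can you introduce yourself?"}]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # For this model the template renders roughly as "<s>[INST] ... [/INST]".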
@@ -86,18 +83,23 @@
 
     inputs = inputs.to(model.device)
 
+    generate_args = {
+        "max_new_tokens": 128,
+        "return_dict_in_generate": True,
+        "output_scores": True,
+        "suppress_tokens": [tokenizer.pad_token_id],
+        "do_sample": USE_SAMPLING
+    }
+
+    if USE_SAMPLING:
+        generate_args.update({
+            "temperature": 0.7,
+            "top_p": 0.9,
+            "top_k": 50
+        })
+
     with torch.no_grad():
-        output = model.generate(
-            **inputs,
-            max_new_tokens=128,
-            do_sample=USE_SAMPLING,
-            temperature=0.7 if USE_SAMPLING else None,
-            top_p=0.9 if USE_SAMPLING else None,
-            top_k=50 if USE_SAMPLING else None,
-            return_dict_in_generate=True,
-            output_scores=True,
-            suppress_tokens=[tokenizer.pad_token_id]
-        )
+        output = model.generate(**inputs, **generate_args)
 
     decoded = tokenizer.decode(output.sequences[0], skip_special_tokens=True)
     answer = decoded[len(prompt):].strip()
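Note: building generate_args as a dict means temperature, top_p, and top_k are only passed when USE_SAMPLING is on, instead of being sent as None. The hunk keeps return_dict_in_generate=True and output_scores=True, and the file defines CONFIDENCE_THRESHOLD = -1.5, but the code that turns the scores into a confidence value sits outside this diff. A minimal sketch of one common way to do it with transformers' compute_transition_scores; this is an assumption, not taken from the commit:

    # Sketch only: derive an average log-probability from the generation scores
    # and fall back to a canned answer when it is below CONFIDENCE_THRESHOLD.
    import random

    transition_scores = model.compute_transition_scores(
        output.sequences, output.scores, normalize_logits=True
    )
    avg_logprob = transition_scores[0].mean().item()
    if avg_logprob < CONFIDENCE_THRESHOLD:
        answer = random.choice(FALLBACK_ANSWERS)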
@@ -130,37 +132,16 @@ def setup_model():
     global model, tokenizer
     try:
         device, supports_bf16 = detect_env()
-        dtype = torch.
+        dtype = torch.float32  # more stable
+
         log(f"🧠 Ortam: {device.upper()}, dtype: {dtype}")
+        log("🧪 Sadece ana model yüklenecek...")

-        if USE_FINE_TUNE:
-
-            zip_path = hf_hub_download(
-                repo_id=FINE_TUNE_REPO,
-                filename=FINE_TUNE_ZIP,
-                repo_type="model",
-                token=HF_TOKEN
-            )
-            extract_path = "/app/extracted"
-            os.makedirs(extract_path, exist_ok=True)
-            with zipfile.ZipFile(zip_path, "r") as zip_ref:
-                zip_ref.extractall(extract_path)
-
-            tokenizer = AutoTokenizer.from_pretrained(os.path.join(extract_path, "output"))
-            if tokenizer.pad_token is None:
-                tokenizer.pad_token = tokenizer.eos_token
-
-            base = AutoModelForCausalLM.from_pretrained(MODEL_BASE, torch_dtype=dtype).to(device)
-            peft = PeftModel.from_pretrained(base, os.path.join(extract_path, "output"))
-            model = peft.model.to(device)
-
-        else:
-            log("🧪 Sadece ana model yüklenecek...")
-            tokenizer = AutoTokenizer.from_pretrained(MODEL_BASE, use_fast=False)
-            if tokenizer.pad_token is None:
-                tokenizer.pad_token = tokenizer.eos_token
-            model = AutoModelForCausalLM.from_pretrained(MODEL_BASE, torch_dtype=dtype).to(device)
+        tokenizer = AutoTokenizer.from_pretrained(MODEL_BASE, use_fast=False)
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token

+        model = AutoModelForCausalLM.from_pretrained(MODEL_BASE, torch_dtype=dtype).to(device)
         model.eval()
         log("✅ Model başarıyla yüklendi.")

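Note: setup_model() calls a detect_env() helper that is defined elsewhere in the file and is untouched by this commit. A hypothetical sketch of what such a helper typically looks like; the name is reused from the call site, the logic is assumed:

    # Sketch only: a plausible detect_env() implementation, not the project's code.
    import torch

    def detect_env():
        device = "cuda" if torch.cuda.is_available() else "cpu"
        supports_bf16 = device == "cuda" and torch.cuda.is_bf16_supported()
        return device, supports_bf16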
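Note: a quick way to exercise the updated endpoint once the app is running under uvicorn. The Message field name ("user_input") and the port are assumptions; neither is shown in this diff:

    # Sketch only: smoke-test the /chat endpoint; field name and port are assumed.
    import requests

    resp = requests.post(
        "http://localhost:8000/chat",
        json={"user_input": "Can you introduce yourself?"},
        timeout=120,
    )
    print(resp.json())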