Update fine_tune_inference_test_mistral.py
fine_tune_inference_test_mistral.py
CHANGED
@@ -5,9 +5,15 @@ from pydantic import BaseModel
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from peft import PeftModel
 from huggingface_hub import hf_hub_download
+from datetime import datetime
 
-# ===
+# === Environment
 HF_TOKEN = os.getenv("HF_TOKEN")
+os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
+os.environ["TORCH_HOME"] = "/app/.torch_cache"
+os.makedirs("/app/.torch_cache", exist_ok=True)
+
+# === Settings
 MODEL_BASE = "mistralai/Mistral-7B-Instruct-v0.2"
 USE_FINE_TUNE = False
 FINE_TUNE_REPO = "UcsTurkey/trained-zips"
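The new environment block matters on hosted containers such as a Hugging Face Space, where the default cache directory is often not writable; pointing TORCH_HOME at a path under /app keeps torch downloads inside the app sandbox. A minimal sketch of the same pattern; the HF_HOME line is an extra assumption, not something this commit sets:

import os

# Redirect framework caches to writable paths (paths are illustrative).
os.environ["TORCH_HOME"] = "/app/.torch_cache"  # torch.hub downloads
os.environ["HF_HOME"] = "/app/.hf_cache"        # assumption: transformers/hub cache
for path in ("/app/.torch_cache", "/app/.hf_cache"):
    os.makedirs(path, exist_ok=True)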
@@ -23,8 +29,7 @@ FALLBACK_ANSWERS = [
 # === Log
 def log(message):
     timestamp = time.strftime("%H:%M:%S")
-    print(f"[{timestamp}] {message}")
-    os.sys.stdout.flush()
+    print(f"[{timestamp}] {message}", flush=True)
 
 # === FastAPI
 app = FastAPI()
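The old pair relied on os.sys, an incidental alias for the sys module; print(..., flush=True) performs the same flush in one idiomatic call, so log lines appear in real time even under buffered container stdout. The explicit equivalent, for comparison only:

import sys
import time

def log(message):
    timestamp = time.strftime("%H:%M:%S")
    print(f"[{timestamp}] {message}")
    sys.stdout.flush()  # identical effect to print(..., flush=True)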
@@ -58,7 +63,7 @@ def root():
         body: JSON.stringify({ user_input: input })
       });
       const data = await res.json();
-      document.getElementById('output').value = data.answer || data.
+      document.getElementById('output').value = data.answer || data.error || 'An error occurred.';
     }
   </script>
 </body>
@@ -77,13 +82,24 @@ def chat(msg: Message):
         return {"error": "Empty input"}
 
     messages = [{"role": "user", "content": user_input}]
-    input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True)
+    input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True)
+    if isinstance(input_ids, torch.Tensor):
+        input_ids = input_ids.to(model.device)
+        attention_mask = (input_ids != tokenizer.pad_token_id).long()
+        inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
+    else:
+        inputs = {k: v.to(model.device) for k, v in input_ids.items()}
+        if "attention_mask" not in inputs:
+            inputs["attention_mask"] = (inputs["input_ids"] != tokenizer.pad_token_id).long()
 
     generate_args = {
         "max_new_tokens": 128,
         "return_dict_in_generate": True,
         "output_scores": True,
-        "do_sample": USE_SAMPLING
+        "do_sample": USE_SAMPLING,
+        "pad_token_id": tokenizer.pad_token_id,
+        "eos_token_id": tokenizer.eos_token_id,
+        "renormalize_logits": True
     }
 
     if USE_SAMPLING:
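The new branch handles both shapes apply_chat_template may return, a bare id tensor or a dict-like encoding, and normalizes either into a dict with an explicit attention_mask so generate can tell padding from content; passing pad_token_id and eos_token_id also silences the usual open-ended-generation warnings. A condensed sketch of the same normalization, assuming tokenizer and model are already loaded:

import torch

def build_inputs(tokenizer, model, messages):
    out = tokenizer.apply_chat_template(
        messages, return_tensors="pt", add_generation_prompt=True
    )
    if isinstance(out, torch.Tensor):  # bare tensor of token ids
        ids = out.to(model.device)
        mask = (ids != tokenizer.pad_token_id).long()  # 1 = real token, 0 = pad
        return {"input_ids": ids, "attention_mask": mask}
    inputs = {k: v.to(model.device) for k, v in out.items()}  # dict-like encoding
    if "attention_mask" not in inputs:
        inputs["attention_mask"] = (inputs["input_ids"] != tokenizer.pad_token_id).long()
    return inputs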
@@ -94,10 +110,11 @@ def chat(msg: Message):
         })
 
     with torch.no_grad():
-        output = model.generate(
+        output = model.generate(**inputs, **generate_args)
 
     decoded = tokenizer.decode(output.sequences[0], skip_special_tokens=True)
-
+    input_text = tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
+    answer = decoded.replace(input_text, "").strip()
 
     if output.scores and len(output.scores) > 0:
         first_token_score = output.scores[0][0]
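Removing the prompt with decoded.replace(input_text, "") works in the common case but can clip the answer if the prompt text happens to recur inside it. An alternative worth noting, offered as a suggestion rather than what this commit does, is to slice off the prompt tokens before decoding, since for decoder-only models output.sequences begins with the prompt:

# Decode only the newly generated tokens (assumes `inputs` and `output`
# as built in the code above).
prompt_len = inputs["input_ids"].shape[1]
new_tokens = output.sequences[0][prompt_len:]
answer = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()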
@@ -119,14 +136,13 @@ def chat(msg: Message):
         return {"error": str(e)}
 
 def detect_env():
-
-    return device
+    return "cuda" if torch.cuda.is_available() else "cpu"
 
 def setup_model():
     global model, tokenizer
     try:
         device = detect_env()
-        dtype = torch.float32
+        dtype = torch.float32  # you can switch to torch.bfloat16 if you wish
 
         if USE_FINE_TUNE:
             log("📦 Downloading fine-tune zip...")
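detect_env previously returned a device name that was never defined; it now returns a concrete device string. A sketch extending the same idea with a matching dtype choice; the mps branch is an extra assumption beyond what the file checks:

import torch

def detect_env():
    # Prefer CUDA, then Apple Metal, then CPU.
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

def pick_dtype(device):
    # bfloat16 roughly halves memory on supporting GPUs; float32 is the safe default.
    return torch.bfloat16 if device == "cuda" else torch.float32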
@@ -144,13 +160,13 @@ def setup_model():
             tokenizer = AutoTokenizer.from_pretrained(os.path.join(extract_dir, "output"), use_fast=False)
             base_model = AutoModelForCausalLM.from_pretrained(MODEL_BASE, torch_dtype=dtype).to(device)
             model = PeftModel.from_pretrained(base_model, os.path.join(extract_dir, "output")).to(device)
-
         else:
             log("🧠 Downloading base model...")
             tokenizer = AutoTokenizer.from_pretrained(MODEL_BASE, use_fast=False)
             model = AutoModelForCausalLM.from_pretrained(MODEL_BASE, torch_dtype=dtype).to(device)
 
         tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
+        model.config.pad_token_id = tokenizer.pad_token_id
         model.eval()
         log("✅ Model loaded successfully.")
     except Exception as e:
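Mistral's tokenizer ships without a pad token, so the code reuses EOS for padding; writing the id back into model.config.pad_token_id keeps generate from guessing. The fine-tune path in miniature, assuming (as the file does) that the downloaded zip extracts a PEFT adapter plus tokenizer under an output/ directory; extract_dir is illustrative:

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# extract_dir/output mirrors the file's assumed adapter layout.
tokenizer = AutoTokenizer.from_pretrained("extract_dir/output", use_fast=False)
base = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
model = PeftModel.from_pretrained(base, "extract_dir/output")  # LoRA adapter on top

tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id
model.eval()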