ciyidogan committed
Commit 01fdd15 · verified · 1 Parent(s): 65269fa

Update fine_tune_inference_test_mistral.py

Files changed (1)
  1. fine_tune_inference_test_mistral.py +28 -47
fine_tune_inference_test_mistral.py CHANGED
@@ -1,19 +1,15 @@
-import os, torch, zipfile, threading, uvicorn, time, traceback
+import os, torch, threading, uvicorn, time, traceback
 from fastapi import FastAPI
 from fastapi.responses import HTMLResponse, JSONResponse
 from pydantic import BaseModel
 from transformers import AutoTokenizer, AutoModelForCausalLM
-from peft import PeftModel
-from huggingface_hub import hf_hub_download
 from datetime import datetime
 import random
 
 # === Sabitler ===
 HF_TOKEN = os.environ.get("HF_TOKEN")
 MODEL_BASE = "mistralai/Mistral-7B-Instruct-v0.2"
-FINE_TUNE_ZIP = "trained_model_000_009.zip"
-FINE_TUNE_REPO = "UcsTurkey/trained-zips"
-USE_FINE_TUNE = False # ✅ Ana modeli test etmek için False yap
+USE_FINE_TUNE = False
 USE_SAMPLING = False
 CONFIDENCE_THRESHOLD = -1.5
 FALLBACK_ANSWERS = [
@@ -47,7 +43,7 @@ def root():
     <html>
     <body>
     <h2>Mistral 7B Chat</h2>
-    <textarea id=\"input\" rows=\"4\" cols=\"60\" placeholder=\"SORU: ...\"></textarea><br>
+    <textarea id=\"input\" rows=\"4\" cols=\"60\" placeholder=\"Write your instruction...\"></textarea><br>
     <button onclick=\"send()\">Gönder</button>
     <pre id=\"output\"></pre>
     <script>
@@ -77,7 +73,8 @@ def chat(msg: Message):
     if not user_input:
         return {"error": "Boş giriş"}
 
-    prompt = f"SORU: {user_input}\nCEVAP:"
+    # Ana modelin beklediği instruct formatı
+    prompt = f"### Instruction:\n{user_input}\n\n### Response:"
     inputs = tokenizer(prompt, return_tensors="pt")
 
     if not inputs or "input_ids" not in inputs:
@@ -86,18 +83,23 @@ def chat(msg: Message):
 
     inputs = inputs.to(model.device)
 
+    generate_args = {
+        "max_new_tokens": 128,
+        "return_dict_in_generate": True,
+        "output_scores": True,
+        "suppress_tokens": [tokenizer.pad_token_id],
+        "do_sample": USE_SAMPLING
+    }
+
+    if USE_SAMPLING:
+        generate_args.update({
+            "temperature": 0.7,
+            "top_p": 0.9,
+            "top_k": 50
+        })
+
     with torch.no_grad():
-        output = model.generate(
-            **inputs,
-            max_new_tokens=128,
-            do_sample=USE_SAMPLING,
-            temperature=0.7 if USE_SAMPLING else None,
-            top_p=0.9 if USE_SAMPLING else None,
-            top_k=50 if USE_SAMPLING else None,
-            return_dict_in_generate=True,
-            output_scores=True,
-            suppress_tokens=[tokenizer.pad_token_id]
-        )
+        output = model.generate(**inputs, **generate_args)
 
     decoded = tokenizer.decode(output.sequences[0], skip_special_tokens=True)
     answer = decoded[len(prompt):].strip()
@@ -130,37 +132,16 @@ def setup_model():
     global model, tokenizer
     try:
         device, supports_bf16 = detect_env()
-        dtype = torch.bfloat16 if supports_bf16 else torch.float32
+        dtype = torch.float32 # daha kararlı
+
         log(f"🧠 Ortam: {device.upper()}, dtype: {dtype}")
+        log("🧪 Sadece ana model yüklenecek...")
 
-        if USE_FINE_TUNE:
-            log("📦 Fine-tune zip indiriliyor...")
-            zip_path = hf_hub_download(
-                repo_id=FINE_TUNE_REPO,
-                filename=FINE_TUNE_ZIP,
-                repo_type="model",
-                token=HF_TOKEN
-            )
-            extract_path = "/app/extracted"
-            os.makedirs(extract_path, exist_ok=True)
-            with zipfile.ZipFile(zip_path, "r") as zip_ref:
-                zip_ref.extractall(extract_path)
-
-            tokenizer = AutoTokenizer.from_pretrained(os.path.join(extract_path, "output"))
-            if tokenizer.pad_token is None:
-                tokenizer.pad_token = tokenizer.eos_token
-
-            base = AutoModelForCausalLM.from_pretrained(MODEL_BASE, torch_dtype=dtype).to(device)
-            peft = PeftModel.from_pretrained(base, os.path.join(extract_path, "output"))
-            model = peft.model.to(device)
-
-        else:
-            log("🧪 Sadece ana model yüklenecek...")
-            tokenizer = AutoTokenizer.from_pretrained(MODEL_BASE, use_fast=False)
-            if tokenizer.pad_token is None:
-                tokenizer.pad_token = tokenizer.eos_token
-            model = AutoModelForCausalLM.from_pretrained(MODEL_BASE, torch_dtype=dtype).to(device)
+        tokenizer = AutoTokenizer.from_pretrained(MODEL_BASE, use_fast=False)
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
 
+        model = AutoModelForCausalLM.from_pretrained(MODEL_BASE, torch_dtype=dtype).to(device)
         model.eval()
         log("✅ Model başarıyla yüklendi.")
 
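
The chat() hunk above replaces the per-argument "... if USE_SAMPLING else None" ternaries with a kwargs dict that only includes sampling parameters when sampling is actually enabled, so transformers is not handed temperature/top_p/top_k in greedy mode. A minimal, self-contained sketch of that pattern in isolation (the helper name build_generate_args and the pad-token id 0 are illustrative only, not part of the file):

def build_generate_args(use_sampling: bool, pad_token_id: int) -> dict:
    # Base arguments are always passed; sampling knobs are added only when
    # do_sample is True, mirroring the generate_args dict in the diff above.
    args = {
        "max_new_tokens": 128,
        "return_dict_in_generate": True,
        "output_scores": True,
        "suppress_tokens": [pad_token_id],
        "do_sample": use_sampling,
    }
    if use_sampling:
        args.update({"temperature": 0.7, "top_p": 0.9, "top_k": 50})
    return args

print(build_generate_args(False, 0))  # greedy decoding: no temperature/top_p/top_k keys
print(build_generate_args(True, 0))   # sampling: the three sampling keys are included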
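
For a quick end-to-end check of the simplified base-model setup, the endpoint can be exercised with a plain HTTP client. This is a hypothetical smoke test: the route path /chat, the port 8000, and the request field name user_input are assumptions, since the route decorator and the Message model definition are outside this diff.

import requests

# All of these values are assumptions (see note above); adjust to the actual
# route, port, and Message field used by the app.
resp = requests.post(
    "http://localhost:8000/chat",
    json={"user_input": "What is the capital of Turkey?"},
    timeout=120,
)
print(resp.status_code)
print(resp.json())  # exact response shape is defined by the handler, not shown in this diff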