ciyidogan committed · Commit fe4457a (verified) · Parent: ab83111

Update fine_tune_inference_test_mistral.py

Files changed (1): fine_tune_inference_test_mistral.py (+28 -12)
fine_tune_inference_test_mistral.py CHANGED
@@ -5,9 +5,15 @@ from pydantic import BaseModel
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from peft import PeftModel
 from huggingface_hub import hf_hub_download
+from datetime import datetime
 
-# === Settings
+# === Environment
 HF_TOKEN = os.getenv("HF_TOKEN")
+os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
+os.environ["TORCH_HOME"] = "/app/.torch_cache"
+os.makedirs("/app/.torch_cache", exist_ok=True)
+
+# === Settings
 MODEL_BASE = "mistralai/Mistral-7B-Instruct-v0.2"
 USE_FINE_TUNE = False
 FINE_TUNE_REPO = "UcsTurkey/trained-zips"
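
Note: TORCH_HOME only redirects torch.hub downloads; the Mistral weights pulled by transformers are cached under HF_HOME (or the older TRANSFORMERS_CACHE) instead. If the intent is to keep all model files inside /app, a minimal sketch of the equivalent setup, with /app/.hf_cache as an illustrative path:

    import os

    # Hypothetical cache directory; set it before transformers is imported
    # so the library picks it up when resolving its cache location.
    os.makedirs("/app/.hf_cache", exist_ok=True)
    os.environ["HF_HOME"] = "/app/.hf_cache"
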
@@ -23,8 +29,7 @@ FALLBACK_ANSWERS = [
 # === Log
 def log(message):
     timestamp = time.strftime("%H:%M:%S")
-    print(f"[{timestamp}] {message}")
-    os.sys.stdout.flush()
+    print(f"[{timestamp}] {message}", flush=True)
 
 # === FastAPI
 app = FastAPI()
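
The flush=True keyword is the idiomatic replacement for the removed os.sys.stdout.flush() call (which worked only because the os module happens to re-export sys). If the helper ever needs levels or handlers, the standard logging module gives the same timestamped output; a minimal sketch:

    import logging

    # Mirrors the [HH:MM:SS] prefix of the hand-rolled log() helper.
    logging.basicConfig(format="[%(asctime)s] %(message)s",
                        datefmt="%H:%M:%S", level=logging.INFO)
    log = logging.getLogger(__name__).info

    log("Model loading...")  # -> [12:34:56] Model loading...
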
@@ -58,7 +63,7 @@ def root():
         body: JSON.stringify({ user_input: input })
       });
       const data = await res.json();
-      document.getElementById('output').value = data.answer || data.response || data.error || 'An error occurred.';
+      document.getElementById('output').value = data.answer || data.error || 'An error occurred.';
     }
   </script>
 </body>
@@ -77,13 +82,24 @@ def chat(msg: Message):
         return {"error": "Empty input"}
 
     messages = [{"role": "user", "content": user_input}]
-    input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(model.device)
+    input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True)
+    if isinstance(input_ids, torch.Tensor):
+        input_ids = input_ids.to(model.device)
+        attention_mask = (input_ids != tokenizer.pad_token_id).long()
+        inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
+    else:
+        inputs = {k: v.to(model.device) for k, v in input_ids.items()}
+        if "attention_mask" not in inputs:
+            inputs["attention_mask"] = (inputs["input_ids"] != tokenizer.pad_token_id).long()
 
     generate_args = {
         "max_new_tokens": 128,
         "return_dict_in_generate": True,
         "output_scores": True,
-        "do_sample": USE_SAMPLING
+        "do_sample": USE_SAMPLING,
+        "pad_token_id": tokenizer.pad_token_id,
+        "eos_token_id": tokenizer.eos_token_id,
+        "renormalize_logits": True
     }
 
     if USE_SAMPLING:
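
The isinstance branch exists because apply_chat_template returns a bare tensor by default but a dict-like BatchEncoding in some configurations. In recent transformers releases the method also accepts return_dict=True, which always yields input_ids plus a ready-made attention_mask and would collapse the two branches; a sketch under that assumption:

    # Assumes a transformers version whose apply_chat_template supports return_dict.
    encoded = tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True,
        return_dict=True,  # always returns input_ids + attention_mask
    )
    inputs = {k: v.to(model.device) for k, v in encoded.items()}

For a single unpadded prompt the mask reconstructed from pad_token_id is all ones anyway, so both routes agree here.
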
@@ -94,10 +110,11 @@ def chat(msg: Message):
         })
 
     with torch.no_grad():
-        output = model.generate(input_ids=input_ids, **generate_args)
+        output = model.generate(**inputs, **generate_args)
 
     decoded = tokenizer.decode(output.sequences[0], skip_special_tokens=True)
-    answer = decoded.split("</s>")[-1].strip()
+    input_text = tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
+    answer = decoded.replace(input_text, "").strip()
 
     if output.scores and len(output.scores) > 0:
         first_token_score = output.scores[0][0]
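
Stripping the prompt with decoded.replace(input_text, "") can corrupt the answer whenever the model echoes part of the prompt verbatim. Since output.sequences begins with the prompt tokens for a decoder-only model, slicing by prompt length avoids string matching entirely; a sketch:

    # output.sequences[0] = prompt tokens followed by the newly generated tokens.
    prompt_len = inputs["input_ids"].shape[1]
    new_tokens = output.sequences[0][prompt_len:]
    answer = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
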
@@ -119,14 +136,13 @@ def chat(msg: Message):
         return {"error": str(e)}
 
 def detect_env():
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    return device
+    return "cuda" if torch.cuda.is_available() else "cpu"
 
 def setup_model():
     global model, tokenizer
     try:
         device = detect_env()
-        dtype = torch.float32
+        dtype = torch.float32  # You can switch to torch.bfloat16 if you like
 
         if USE_FINE_TUNE:
             log("📦 Downloading fine-tune zip...")
@@ -144,13 +160,13 @@ def setup_model():
             tokenizer = AutoTokenizer.from_pretrained(os.path.join(extract_dir, "output"), use_fast=False)
             base_model = AutoModelForCausalLM.from_pretrained(MODEL_BASE, torch_dtype=dtype).to(device)
             model = PeftModel.from_pretrained(base_model, os.path.join(extract_dir, "output")).to(device)
-
         else:
             log("🧠 Downloading base model...")
             tokenizer = AutoTokenizer.from_pretrained(MODEL_BASE, use_fast=False)
             model = AutoModelForCausalLM.from_pretrained(MODEL_BASE, torch_dtype=dtype).to(device)
 
         tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
+        model.config.pad_token_id = tokenizer.pad_token_id
         model.eval()
         log("✅ Model loaded successfully.")
     except Exception as e:
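
For reference, a client-side call matching the user_input field the page's JavaScript sends. The route name /chat and the port are assumptions (the route decorator is outside the visible hunks, and 7860 is the usual Hugging Face Spaces port):

    import requests

    # Hypothetical host/port; adjust to wherever the FastAPI app is served.
    resp = requests.post(
        "http://localhost:7860/chat",
        json={"user_input": "Hello, how are you?"},
        timeout=120,
    )
    print(resp.json())  # {"answer": "..."} on success, {"error": "..."} otherwise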
 