# mistral7b/fine_tune_inference_test_mistral.py
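"""FastAPI inference test server for a Turkish-instruction-tuned Mistral 7B.

Serves a minimal chat page at /start and a JSON /chat endpoint. The model is
loaded in a background thread; optionally a LoRA fine-tune, downloaded as a
zip from the Hugging Face Hub, is applied on top of the base model.
"""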
import os, torch, threading, uvicorn, time, traceback, zipfile, random
from fastapi import FastAPI
from fastapi.responses import HTMLResponse, JSONResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from huggingface_hub import hf_hub_download
# === Settings
HF_TOKEN = os.getenv("HF_TOKEN")
MODEL_BASE = "malhajar/Mistral-7B-Instruct-v0.2-turkish"
USE_FINE_TUNE = False            # load a LoRA fine-tune on top of the base model
FINE_TUNE_REPO = "UcsTurkey/trained-zips"
FINE_TUNE_ZIP = "trained_model_000_009.zip"
USE_SAMPLING = False             # greedy decoding by default
CONFIDENCE_THRESHOLD = -1.5      # heuristic: max raw logit of the first generated token
FALLBACK_ANSWERS = [  # canned Turkish replies, kept in Turkish to match the model's language
    "Bu konuda maalesef bilgim yok.",          # "Unfortunately, I have no knowledge on this topic."
    "Ne demek istediğinizi tam anlayamadım.",  # "I couldn't quite understand what you meant."
    "Bu soruya şu an yanıt veremiyorum."       # "I can't answer this question right now."
]
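# Fallbacks fire in /chat when the first generated token's logits contain
# NaN/Inf, or when its max raw logit falls below CONFIDENCE_THRESHOLD.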
# === Logging
def log(message):
    timestamp = time.strftime("%H:%M:%S")
    print(f"[{timestamp}] {message}", flush=True)  # flush so logs appear promptly in container output
# === FastAPI
app = FastAPI()
chat_history = []
model = None
tokenizer = None
class Message(BaseModel):
user_input: str
@app.get("/")
def health():
return {"status": "ok"}
@app.get("/start", response_class=HTMLResponse)
def root():
return """
<html>
<body>
<h2>Mistral 7B Instruct Chat</h2>
<textarea id=\"input\" rows=\"4\" cols=\"60\" placeholder=\"Write your instruction...\"></textarea><br>
<button onclick=\"send()\">Gönder</button>
<pre id=\"output\"></pre>
<script>
async function send() {
const input = document.getElementById(\"input\").value;
const res = await fetch('/chat', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ user_input: input })
});
const data = await res.json();
document.getElementById('output').innerText = data.answer || data.error || 'Hata oluştu.';
}
</script>
</body>
</html>
"""
@app.post("/chat")
def chat(msg: Message):
global model, tokenizer
try:
        if model is None or tokenizer is None:
            return {"error": "Model or tokenizer not loaded yet."}
        user_input = msg.user_input.strip()
        if not user_input:
            return {"error": "Empty input"}
        messages = [{"role": "user", "content": user_input}]
        # return_dict=True is required so apply_chat_template returns a mapping
        # (input_ids + attention_mask) rather than a bare tensor of token ids,
        # which the .items() call below expects.
        inputs = tokenizer.apply_chat_template(
            messages, return_tensors="pt", add_generation_prompt=True, return_dict=True
        )
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        generate_args = {
            "max_new_tokens": 512,
            "return_dict_in_generate": True,
            "output_scores": True,  # keep per-step logits for the confidence check below
            "do_sample": USE_SAMPLING
        }
        if USE_SAMPLING:
            generate_args.update({
                "temperature": 0.7,
                "top_p": 0.9,
                "top_k": 50
            })
with torch.no_grad():
output = model.generate(**inputs, **generate_args)
        # Decode only the newly generated tokens; string-replacing the prompt out
        # of the full decode can corrupt the answer if the prompt text recurs in it.
        prompt_len = inputs["input_ids"].shape[1]
        answer = tokenizer.decode(output.sequences[0][prompt_len:], skip_special_tokens=True).strip()
        if output.scores and len(output.scores) > 0:
            first_token_score = output.scores[0][0]
            if torch.isnan(first_token_score).any() or torch.isinf(first_token_score).any():
                log("⚠️ Invalid logits (NaN/Inf) detected.")
                return {"answer": random.choice(FALLBACK_ANSWERS)}
            # Heuristic confidence: max raw logit of the first generated token.
            max_score = torch.max(first_token_score).item()
            log(f"🔍 First-token score: {max_score:.4f}")
            if max_score < CONFIDENCE_THRESHOLD:
                answer = random.choice(FALLBACK_ANSWERS)
        chat_history.append({"user": user_input, "bot": answer})
        log(f"Question: {user_input} → Answer: {answer[:60]}...")
        return {"answer": answer, "chat_history": chat_history}
except Exception as e:
log(f"❌ /chat hatası: {e}")
traceback.print_exc()
return {"error": str(e)}
def detect_env():
device = "cuda" if torch.cuda.is_available() else "cpu"
return device
def setup_model():
global model, tokenizer
try:
device = detect_env()
        dtype = torch.float32  # full precision everywhere; simple, but memory-heavy for a 7B model
        if USE_FINE_TUNE:
            log("📦 Downloading fine-tune zip...")
zip_path = hf_hub_download(
repo_id=FINE_TUNE_REPO,
filename=FINE_TUNE_ZIP,
repo_type="model",
token=HF_TOKEN
)
extract_dir = "/app/extracted"
os.makedirs(extract_dir, exist_ok=True)
            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                zip_ref.extractall(extract_dir)
            # The zip is expected to contain an "output/" folder holding the
            # tokenizer files and the LoRA adapter weights.
            tokenizer = AutoTokenizer.from_pretrained(os.path.join(extract_dir, "output"), use_fast=False)
            base_model = AutoModelForCausalLM.from_pretrained(MODEL_BASE, torch_dtype=dtype).to(device)
            model = PeftModel.from_pretrained(base_model, os.path.join(extract_dir, "output")).to(device)
        else:
            log("🧠 Downloading base model...")
            tokenizer = AutoTokenizer.from_pretrained(MODEL_BASE, use_fast=False)
            model = AutoModelForCausalLM.from_pretrained(MODEL_BASE, torch_dtype=dtype).to(device)
        # Mistral tokenizers ship without a pad token; fall back to EOS.
        tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
        model.eval()
        log("✅ Model loaded successfully.")
except Exception as e:
log(f"❌ Model yüklenirken hata: {e}")
traceback.print_exc()
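# The model loads in a daemon thread (started below) so the HTTP server can come
# up immediately; until loading finishes, /chat answers with a "not loaded yet" error.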
def run_server():
log("🌐 Uvicorn başlatılıyor...")
uvicorn.run(app, host="0.0.0.0", port=7860)
log("===== Application Startup =====")
threading.Thread(target=setup_model, daemon=True).start()
threading.Thread(target=run_server, daemon=True).start()
while True:
time.sleep(60)
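# To try this locally (a sketch; the Hugging Face Space launches the script directly):
#   HF_TOKEN=... python fine_tune_inference_test_mistral.py
# then open http://localhost:7860/start in a browser for the minimal chat UI.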