# fine_tune_inference_test.py
import os
import threading
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from datasets import load_dataset
from peft import PeftModel
import torch  # needed for the dtype / device checks below
# ✅ Constants
HF_TOKEN = os.environ.get("HF_TOKEN")
MODEL_BASE = "UcsTurkey/kanarya-750m-fixed"
FINE_TUNE_ZIP = "trained_model_000_100.zip"
FINE_TUNE_REPO = "UcsTurkey/trained-zips"
RAG_DATA_FILE = "merged_dataset_000_100.parquet"
RAG_DATA_REPO = "UcsTurkey/turkish-general-culture-tokenized"
# ✅ FastAPI app
app = FastAPI()
chat_history = []  # in-memory history, shared across all clients (test-only)
pipe = None  # ❗ defined globally; filled in by setup_model() on a background thread
class Message(BaseModel):
    user_input: str
@app.get("/", response_class=HTMLResponse)
def root():
    return """
    <html>
        <head><title>Fine-Tune Chat</title></head>
        <body>
            <h2>📘 Fine-tune Chat Test</h2>
            <textarea id="input" rows="4" cols="60" placeholder="Type something..."></textarea><br><br>
            <button onclick="send()">Send</button>
            <pre id="output"></pre>
            <script>
                async function send() {
                    const input = document.getElementById("input").value;
                    const res = await fetch("/chat", {
                        method: "POST",
                        headers: { "Content-Type": "application/json" },
                        body: JSON.stringify({ user_input: input })
                    });
                    const data = await res.json();
                    document.getElementById("output").innerText = data.answer || data.error || "An error occurred.";
                }
            </script>
        </body>
    </html>
    """
@app.post("/chat")
def chat(msg: Message):
    global pipe
    if pipe is None:
        return {"error": "Model not loaded yet, please try again in a few seconds."}
    user_input = msg.user_input.strip()
    if not user_input:
        return {"error": "Empty input"}
    # Rebuild the whole conversation as a plain-text prompt. The Turkish
    # "Kullanıcı:" / "Asistan:" turn prefixes are left untranslated: they are
    # part of the prompt format this script feeds the Turkish model.
    full_prompt = ""
    for turn in chat_history:
        full_prompt += f"Kullanıcı: {turn['user']}\nAsistan: {turn['bot']}\n"
    full_prompt += f"Kullanıcı: {user_input}\nAsistan:"
    result = pipe(full_prompt, max_new_tokens=200, do_sample=True, temperature=0.7)
    # Strip the echoed prompt so only the newly generated answer remains.
    answer = result[0]["generated_text"][len(full_prompt):].strip()
    chat_history.append({"user": user_input, "bot": answer})
    return {"answer": answer, "chat_history": chat_history}
# ✅ Model and RAG loading
def setup_model():
    global pipe
    from huggingface_hub import hf_hub_download
    import zipfile

    print("📦 Downloading fine-tune zip...")
    zip_path = hf_hub_download(
        repo_id=FINE_TUNE_REPO,
        filename=FINE_TUNE_ZIP,
        repo_type="model",
        token=HF_TOKEN
    )
    extract_dir = "/app/extracted"
    os.makedirs(extract_dir, exist_ok=True)
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extract_dir)
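    # Layout assumed by the paths below: the zip unpacks to an "output/"
    # folder that contains both the tokenizer files and the LoRA adapter.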
print("🔁 Tokenizer yükleniyor...")
tokenizer = AutoTokenizer.from_pretrained(os.path.join(extract_dir, "output"))
print("🧠 Base model indiriliyor...")
base_model = AutoModelForCausalLM.from_pretrained(MODEL_BASE, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
print("➕ LoRA adapter uygulanıyor...")
model = PeftModel.from_pretrained(base_model, os.path.join(extract_dir, "output"))
print("📚 RAG dataseti yükleniyor...")
rag = load_dataset(RAG_DATA_REPO, data_files=RAG_DATA_FILE, split="train", token=HF_TOKEN)
print(f"🔍 RAG boyutu: {len(rag)}")
    # ✅ Build the generation pipeline
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device=0 if torch.cuda.is_available() else -1
    )
# ✅ Load the model when the app starts (non-blocking, in a daemon thread)
threading.Thread(target=setup_model, daemon=True).start()

# 🧘 Keep the app running so it does not restart after training
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)