File size: 4,123 Bytes
6734e84 6398aea 0a5b12b 6734e84 0a5b12b 9cb7961 6734e84 0a5b12b 6734e84 0a5b12b 6734e84 9cb7961 6734e84 cc493ed 6734e84 9cb7961 6734e84 0a5b12b 6734e84 0a5b12b 9cb7961 0a5b12b 6734e84 9cb7961 6734e84 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import os
import threading
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from datasets import load_dataset
from peft import PeftModel
import torch # eksikse gerekli
# ✅ Constants: Hugging Face credentials and repo/file identifiers.
HF_TOKEN = os.environ.get("HF_TOKEN")  # may be None if the env var is unset
MODEL_BASE = "UcsTurkey/kanarya-750m-fixed"  # base causal-LM checkpoint
FINE_TUNE_ZIP = "trained_model_000_100.zip"  # zip holding tokenizer + LoRA adapter under "output/"
FINE_TUNE_REPO = "UcsTurkey/trained-zips"  # model repo hosting the fine-tune zip
RAG_DATA_FILE = "merged_dataset_000_100.parquet"  # parquet shard loaded as the RAG dataset
RAG_DATA_REPO = "UcsTurkey/turkish-general-culture-tokenized"  # dataset repo for RAG
# ✅ FastAPI app and mutable module-level state.
app = FastAPI()
chat_history = []  # accumulated {"user": ..., "bot": ...} turns; grows without bound
pipe = None  # ❗ text-generation pipeline; set asynchronously by setup_model() in a background thread
class Message(BaseModel):
    """Request body for POST /chat: a single user utterance."""

    user_input: str  # raw text typed by the user
@app.get("/", response_class=HTMLResponse)
def root():
    """Serve the minimal single-page HTML chat UI used to test the model."""
    page = """
<html>
<head><title>Fine-Tune Chat</title></head>
<body>
<h2>📘 Fine-tune Chat Test</h2>
<textarea id="input" rows="4" cols="60" placeholder="Bir şeyler yaz..."></textarea><br><br>
<button onclick="send()">Gönder</button>
<pre id="output"></pre>
<script>
async function send() {
const input = document.getElementById("input").value;
const res = await fetch("/chat", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ user_input: input })
});
const data = await res.json();
document.getElementById("output").innerText = data.answer || data.error || "Hata oluştu.";
}
</script>
</body>
</html>
"""
    return page
@app.post("/chat")
def chat(msg: Message):
    """Generate one assistant reply, feeding the whole chat history as context.

    Returns ``{"answer": ..., "chat_history": ...}`` on success, or an
    ``{"error": ...}`` dict when the model is not ready or input is empty.
    """
    global pipe

    # The model loads asynchronously at startup; refuse requests until ready.
    if pipe is None:
        return {"error": "Model henüz yüklenmedi, lütfen birkaç saniye sonra tekrar deneyin."}

    user_input = msg.user_input.strip()
    if not user_input:
        return {"error": "Boş giriş"}

    # Re-serialize every previous turn plus the new input into one prompt string.
    past_turns = [f"Kullanıcı: {turn['user']}\nAsistan: {turn['bot']}\n" for turn in chat_history]
    full_prompt = "".join(past_turns) + f"Kullanıcı: {user_input}\nAsistan:"

    # Sample a continuation and keep only the newly generated suffix.
    result = pipe(full_prompt, max_new_tokens=200, do_sample=True, temperature=0.7)
    answer = result[0]["generated_text"][len(full_prompt):].strip()

    chat_history.append({"user": user_input, "bot": answer})
    return {"answer": answer, "chat_history": chat_history}
# ✅ Model ve RAG yükleme
def setup_model():
    """Download fine-tune artifacts, build the generation pipeline, publish it.

    Runs in a background daemon thread (started at module bottom). On success it
    assigns the module-level ``pipe`` global, which /chat polls for readiness.
    Any failure is caught and logged with a traceback: without this, an exception
    would silently kill the loader thread and leave ``pipe`` stuck at ``None``,
    so /chat would report "not loaded" forever with no diagnostic.
    """
    global pipe
    from huggingface_hub import hf_hub_download
    import traceback
    import zipfile

    try:
        print("📦 Fine-tune zip indiriliyor...")
        zip_path = hf_hub_download(
            repo_id=FINE_TUNE_REPO,
            filename=FINE_TUNE_ZIP,
            repo_type="model",
            token=HF_TOKEN,
        )

        # Unpack the zip; tokenizer and LoRA adapter live under "output/".
        extract_dir = "/app/extracted"
        os.makedirs(extract_dir, exist_ok=True)
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(extract_dir)

        print("🔁 Tokenizer yükleniyor...")
        tokenizer = AutoTokenizer.from_pretrained(os.path.join(extract_dir, "output"))

        print("🧠 Base model indiriliyor...")
        # fp16 on GPU, fp32 on CPU; hoisted once and reused for the pipeline below.
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        base_model = AutoModelForCausalLM.from_pretrained(MODEL_BASE, torch_dtype=dtype)

        print("➕ LoRA adapter uygulanıyor...")
        model = PeftModel.from_pretrained(base_model, os.path.join(extract_dir, "output"))

        print("📚 RAG dataseti yükleniyor...")
        # NOTE(review): the dataset is loaded and its size printed, but nothing in
        # this file retrieves from it — presumably planned RAG support; confirm.
        rag = load_dataset(RAG_DATA_REPO, data_files=RAG_DATA_FILE, split="train", token=HF_TOKEN)
        print(f"🔍 RAG boyutu: {len(rag)}")

        # ✅ Build the pipeline last and assign it in one step; /chat treats a
        # non-None ``pipe`` as "fully ready".
        pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            torch_dtype=dtype,
            device=0 if torch.cuda.is_available() else -1,
        )
    except Exception:
        # Surface setup failures explicitly instead of letting the daemon
        # thread die quietly with ``pipe`` left as None.
        print("❌ Model kurulumu başarısız oldu:")
        traceback.print_exc()
# ✅ Kick off model loading in the background so the server can start serving
# immediately; /chat answers with a "not loaded yet" message until it finishes.
threading.Thread(target=setup_model, daemon=True).start()

# 🧘 Run the API server in the foreground — this keeps the process alive so the
# app is not restarted once loading completes.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
|