fine-tune-inference-test / fine_tune_inference_test.py
ciyidogan's picture
Update fine_tune_inference_test.py
53de39f verified
raw
history blame
2.6 kB
import os
import threading
import uvicorn
from fastapi import FastAPI, Request
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from datasets import load_dataset
from fastapi.responses import JSONResponse
# Configuration constants.
# Access token for the Hugging Face Hub, read from the environment
# (None when the variable is not set).
HF_TOKEN = os.getenv("HF_TOKEN")
MODEL_BASE = "UcsTurkey/kanarya-750m-fixed"

# Archive holding the fine-tuned model, and the Hub repo it is downloaded from.
FINE_TUNE_REPO = "UcsTurkey/trained-zips"
FINE_TUNE_ZIP = "trained_model_000_100.zip"  # may be changed

# Parquet file of the RAG dataset, and the Hub repo it is loaded from.
RAG_DATA_REPO = "UcsTurkey/turkish-general-culture-tokenized"
RAG_DATA_FILE = "merged_dataset_000_100.parquet"  # may be changed
# Conversation turns accumulated by the /chat endpoint. Each entry is a dict
# with "user" and "bot" keys. This is one module-level list, so every request
# served by this process reads and appends to the same history.
chat_history = []

# FastAPI application object that the route decorators below register on.
app = FastAPI()
# Request body schema for POST /chat.
class Message(BaseModel):
    # Text typed by the user; the handler strips surrounding whitespace
    # before using it.
    user_input: str
@app.get("/")
def health():
    """Report that the web process is up.

    Returns a constant payload; it does not check whether the model
    pipeline has been loaded.
    """
    return dict(status="ok")
@app.post("/chat")
def chat(msg: Message):
    """Generate the assistant's reply to one user message.

    The prompt is the stored conversation rendered as alternating
    "Kullanıcı:" / "Asistan:" lines, followed by the new user message and an
    open "Asistan:" cue for the model to complete. The new turn is appended
    to the module-level ``chat_history`` before returning.

    Args:
        msg: Request body carrying the user's text in ``user_input``.

    Returns:
        ``{"answer": ..., "chat_history": ...}`` on success;
        ``{"error": "Boş giriş"}`` when the input is empty after stripping;
        a ``JSONResponse`` with status 503 when the generation pipeline has
        not been created yet.
    """
    user_input = msg.user_input.strip()
    if not user_input:
        return {"error": "Boş giriş"}

    # `pipe` is bound as a module global only when setup_model() completes on
    # its background thread. Until then (or if loading failed) the name does
    # not exist, so look it up defensively instead of raising NameError and
    # answering with an opaque 500.
    generator = globals().get("pipe")
    if generator is None:
        return JSONResponse(
            status_code=503,
            content={"error": "Model henüz yüklenmedi"},
        )

    # NOTE(review): chat_history is never trimmed, so the prompt grows with
    # every turn; confirm whether it should be capped to the model's context.
    past_turns = "".join(
        f"Kullanıcı: {turn['user']}\nAsistan: {turn['bot']}\n"
        for turn in chat_history
    )
    full_prompt = f"{past_turns}Kullanıcı: {user_input}\nAsistan:"

    result = generator(
        full_prompt, max_new_tokens=200, do_sample=True, temperature=0.7
    )
    # Drop the first len(full_prompt) characters so only the continuation is
    # kept; this relies on the pipeline output echoing the prompt first.
    answer = result[0]["generated_text"][len(full_prompt):].strip()

    chat_history.append({"user": user_input, "bot": answer})
    return {"answer": answer, "chat_history": chat_history}
# Model and RAG dataset loading.
def setup_model():
    """Download the fine-tuned model, load it, and publish the pipeline.

    Steps, in order:
      1. Download ``FINE_TUNE_ZIP`` from ``FINE_TUNE_REPO`` on the Hub.
      2. Extract the archive into ``/app/extracted``.
      3. Load the tokenizer and causal LM from the ``output`` sub-directory.
      4. Load the RAG dataset and print its size (the dataset is not used
         for anything else in this function).
      5. Bind the text-generation pipeline to the module-level name ``pipe``.

    Returns nothing; the result is the ``pipe`` global.
    """
    global pipe
    from huggingface_hub import hf_hub_download
    import zipfile

    print("📦 Fine-tune zip indiriliyor...")
    archive_path = hf_hub_download(
        repo_id=FINE_TUNE_REPO,
        filename=FINE_TUNE_ZIP,
        repo_type="model",
        token=HF_TOKEN,
    )

    target_dir = "/app/extracted"
    os.makedirs(target_dir, exist_ok=True)
    with zipfile.ZipFile(archive_path, "r") as archive:
        archive.extractall(target_dir)

    print("🔁 Tokenizer ve model yükleniyor...")
    # Tokenizer and weights are both read from the same extracted directory.
    weights_dir = os.path.join(target_dir, "output")
    tokenizer = AutoTokenizer.from_pretrained(weights_dir)
    model = AutoModelForCausalLM.from_pretrained(weights_dir)

    print("📚 RAG dataseti yükleniyor...")
    rag = load_dataset(
        RAG_DATA_REPO,
        data_files=RAG_DATA_FILE,
        split="train",
        token=HF_TOKEN,
    )
    print(f"🔍 RAG boyutu: {len(rag)}")

    # Assigned last, so `pipe` only exists after every step above succeeded.
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
# Load the model when the application starts: setup_model() runs on a daemon
# thread launched as soon as this module is executed or imported.
_loader_thread = threading.Thread(target=setup_model, daemon=True)
_loader_thread.start()

# Original author's note (translated): wait here so the application does not
# restart after training. When run as a script, serve the app on port 7860.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)