ciyidogan commited on
Commit
d0e6741
·
verified ·
1 Parent(s): 9cb7961

Update fine_tune_inference_test.py

Browse files
Files changed (1) hide show
  1. fine_tune_inference_test.py +26 -70
fine_tune_inference_test.py CHANGED
@@ -1,13 +1,22 @@
1
  import os
2
  import threading
3
  import uvicorn
4
- from fastapi import FastAPI, Request
5
- from fastapi.responses import HTMLResponse, JSONResponse
6
  from pydantic import BaseModel
7
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
8
  from datasets import load_dataset
9
  from peft import PeftModel
10
- import torch # eksikse gerekli
 
 
 
 
 
 
 
 
 
11
 
12
  # ✅ Sabitler
13
  HF_TOKEN = os.environ.get("HF_TOKEN")
@@ -17,10 +26,9 @@ FINE_TUNE_REPO = "UcsTurkey/trained-zips"
17
  RAG_DATA_FILE = "merged_dataset_000_100.parquet"
18
  RAG_DATA_REPO = "UcsTurkey/turkish-general-culture-tokenized"
19
 
20
- # ✅ FastAPI app
21
  app = FastAPI()
22
  chat_history = []
23
- pipe = None # Global olarak tanımlıyoruz
24
 
25
  class Message(BaseModel):
26
  user_input: str
@@ -53,68 +61,16 @@ def root():
53
 
54
  @app.post("/chat")
55
  def chat(msg: Message):
56
- global pipe
57
- if pipe is None:
58
- return {"error": "Model henüz yüklenmedi, lütfen birkaç saniye sonra tekrar deneyin."}
59
-
60
- user_input = msg.user_input.strip()
61
- if not user_input:
62
- return {"error": "Boş giriş"}
63
-
64
- full_prompt = ""
65
- for turn in chat_history:
66
- full_prompt += f"Kullanıcı: {turn['user']}\nAsistan: {turn['bot']}\n"
67
- full_prompt += f"Kullanıcı: {user_input}\nAsistan:"
68
-
69
- result = pipe(full_prompt, max_new_tokens=200, do_sample=True, temperature=0.7)
70
- answer = result[0]["generated_text"][len(full_prompt):].strip()
71
-
72
- chat_history.append({"user": user_input, "bot": answer})
73
- return {"answer": answer, "chat_history": chat_history}
74
-
75
- # ✅ Model ve RAG yükleme
76
- def setup_model():
77
- global pipe
78
- from huggingface_hub import hf_hub_download
79
- import zipfile
80
-
81
- print("📦 Fine-tune zip indiriliyor...")
82
- zip_path = hf_hub_download(
83
- repo_id=FINE_TUNE_REPO,
84
- filename=FINE_TUNE_ZIP,
85
- repo_type="model",
86
- token=HF_TOKEN
87
- )
88
- extract_dir = "/app/extracted"
89
- os.makedirs(extract_dir, exist_ok=True)
90
- with zipfile.ZipFile(zip_path, "r") as zip_ref:
91
- zip_ref.extractall(extract_dir)
92
-
93
- print("🔁 Tokenizer yükleniyor...")
94
- tokenizer = AutoTokenizer.from_pretrained(os.path.join(extract_dir, "output"))
95
-
96
- print("🧠 Base model indiriliyor...")
97
- base_model = AutoModelForCausalLM.from_pretrained(MODEL_BASE, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
98
-
99
- print("➕ LoRA adapter uygulanıyor...")
100
- model = PeftModel.from_pretrained(base_model, os.path.join(extract_dir, "output"))
101
-
102
- print("📚 RAG dataseti yükleniyor...")
103
- rag = load_dataset(RAG_DATA_REPO, data_files=RAG_DATA_FILE, split="train", token=HF_TOKEN)
104
- print(f"🔍 RAG boyutu: {len(rag)}")
105
-
106
- # ✅ pipeline oluşturuluyor
107
- pipe = pipeline(
108
- "text-generation",
109
- model=model,
110
- tokenizer=tokenizer,
111
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
112
- device=0 if torch.cuda.is_available() else -1
113
- )
114
-
115
- # ✅ Uygulama başladığında modeli yükle
116
- threading.Thread(target=setup_model, daemon=True).start()
117
-
118
- # 🧘 Eğitim sonrası uygulama restart olmasın diye bekleme
119
- if __name__ == "__main__":
120
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  import os
2
  import threading
3
  import uvicorn
4
+ from fastapi import FastAPI
5
+ from fastapi.responses import HTMLResponse
6
  from pydantic import BaseModel
7
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
8
  from datasets import load_dataset
9
  from peft import PeftModel
10
+ import torch
11
+ from huggingface_hub import hf_hub_download
12
+ import zipfile
13
+ from datetime import datetime
14
+
15
+ # ✅ Zamanlı log fonksiyonu (flush destekli)
16
+ def log(message):
17
+ timestamp = datetime.now().strftime("%H:%M:%S")
18
+ print(f"[{timestamp}] {message}")
19
+ os.sys.stdout.flush()
20
 
21
  # ✅ Sabitler
22
  HF_TOKEN = os.environ.get("HF_TOKEN")
 
26
  RAG_DATA_FILE = "merged_dataset_000_100.parquet"
27
  RAG_DATA_REPO = "UcsTurkey/turkish-general-culture-tokenized"
28
 
 
29
  app = FastAPI()
30
  chat_history = []
31
+ pipe = None # global text-generation pipeline
32
 
33
  class Message(BaseModel):
34
  user_input: str
 
61
 
62
  @app.post("/chat")
63
  def chat(msg: Message):
64
+ try:
65
+ global pipe
66
+ if pipe is None:
67
+ log("🚫 Hata: Model henüz yüklenmedi.")
68
+ return {"error": "Model yüklenmedi. Lütfen birkaç saniye sonra tekrar deneyin."}
69
+
70
+ user_input = msg.user_input.strip()
71
+ if not user_input:
72
+ return {"error": "Boş giriş"}
73
+
74
+ full_prompt = ""
75
+ for turn in chat_history:
76
+ full_prompt += f"Kullanıcı: {turn['user']}\nAsistan: {turn['bot']}\n"