Update fine_tune_inference_test.py
fine_tune_inference_test.py  +15 -3  CHANGED
@@ -7,6 +7,7 @@ from pydantic import BaseModel
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 from datasets import load_dataset
 from peft import PeftModel
+import torch  # required if missing
 
 # ✅ Constants
 HF_TOKEN = os.environ.get("HF_TOKEN")
@@ -19,6 +20,7 @@ RAG_DATA_REPO = "UcsTurkey/turkish-general-culture-tokenized"
 # ✅ FastAPI app
 app = FastAPI()
 chat_history = []
+pipe = None  # ❗ we define it globally
 
 class Message(BaseModel):
     user_input: str
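The module-level sentinel is what makes the new guard in /chat possible: setup_model() runs in a daemon thread (see the final hunk), so a request can arrive before loading finishes, and without pipe = None an early handler call would raise a NameError instead of returning a clean error. A minimal sketch of the pattern, independent of this repo (the stand-in loader and handler below are hypothetical):

import threading
import time

pipe = None  # module-level sentinel, readable before the loader finishes

def setup_model():
    global pipe
    time.sleep(5)               # stands in for the slow model download
    pipe = lambda s: s.upper()  # stands in for the real text-generation pipeline

threading.Thread(target=setup_model, daemon=True).start()

def handle(text):
    if pipe is None:
        return {"error": "Model not loaded yet"}
    return {"answer": pipe(text)}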
@@ -51,6 +53,10 @@ def root():
 
 @app.post("/chat")
 def chat(msg: Message):
+    global pipe
+    if pipe is None:
+        return {"error": "Model not loaded yet, please try again in a few seconds."}
+
     user_input = msg.user_input.strip()
     if not user_input:
         return {"error": "Empty input"}
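Returning the "not loaded" message in a 200 response works, but HTTP has a dedicated status for this situation. A variant sketch, not part of this commit, that raises FastAPI's HTTPException with status 503 so clients and load balancers know to retry:

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI()
pipe = None  # filled in by the background loader, as in the diff

class Message(BaseModel):
    user_input: str

@app.post("/chat")
def chat(msg: Message):
    if pipe is None:
        # 503 Service Unavailable signals "try again later" to any HTTP client
        raise HTTPException(status_code=503, detail="Model is still loading.")
    return {"answer": pipe(msg.user_input)[0]["generated_text"]}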
@@ -66,7 +72,6 @@ def chat(msg: Message):
     chat_history.append({"user": user_input, "bot": answer})
     return {"answer": answer, "chat_history": chat_history}
 
-
 # ✅ Model and RAG loading
 def setup_model():
     global pipe
@@ -89,7 +94,7 @@ def setup_model():
     tokenizer = AutoTokenizer.from_pretrained(os.path.join(extract_dir, "output"))
 
     print("🧠 Downloading base model...")
-    base_model = AutoModelForCausalLM.from_pretrained(MODEL_BASE, torch_dtype=
+    base_model = AutoModelForCausalLM.from_pretrained(MODEL_BASE, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
 
     print("➕ Applying LoRA adapter...")
     model = PeftModel.from_pretrained(base_model, os.path.join(extract_dir, "output"))
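Both the model load above and the pipeline created below pick precision and device the same way. The shared logic, pulled out as a sketch (the variable names are mine, not from the commit):

import torch

# float16 halves memory on GPU; CPUs generally lack fast float16 kernels,
# so float32 is the safe default there
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# transformers' pipeline() convention: device 0 = first GPU, -1 = CPU
device = 0 if torch.cuda.is_available() else -1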
@@ -98,7 +103,14 @@ def setup_model():
     rag = load_dataset(RAG_DATA_REPO, data_files=RAG_DATA_FILE, split="train", token=HF_TOKEN)
     print(f"🔍 RAG size: {len(rag)}")
 
-
+    # ✅ building the pipeline
+    pipe = pipeline(
+        "text-generation",
+        model=model,
+        tokenizer=tokenizer,
+        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+        device=0 if torch.cuda.is_available() else -1
+    )
 
 # ✅ Load the model when the app starts
 threading.Thread(target=setup_model, daemon=True).start()
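Taken together, the changes let the endpoint degrade gracefully during warm-up instead of crashing. A hypothetical smoke test of that behavior, assuming the app is served locally at http://127.0.0.1:8000 (the URL and the requests dependency are assumptions, not part of the commit):

import time
import requests

URL = "http://127.0.0.1:8000/chat"

# Right after startup the daemon thread is still loading the model, so the
# first responses should carry the "not loaded yet" error rather than a crash.
for _ in range(30):
    reply = requests.post(URL, json={"user_input": "Merhaba!"}).json()
    if "error" not in reply:
        break
    time.sleep(2)  # give setup_model() time to finish

print(reply)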