ciyidogan committed on
Commit
9cb7961
·
verified ·
1 Parent(s): 1b4c068

Update fine_tune_inference_test.py

Browse files
Files changed (1) hide show
  1. fine_tune_inference_test.py +15 -3
fine_tune_inference_test.py CHANGED
@@ -7,6 +7,7 @@ from pydantic import BaseModel
7
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
8
  from datasets import load_dataset
9
  from peft import PeftModel
 
10
 
11
  # ✅ Sabitler
12
  HF_TOKEN = os.environ.get("HF_TOKEN")
@@ -19,6 +20,7 @@ RAG_DATA_REPO = "UcsTurkey/turkish-general-culture-tokenized"
19
  # ✅ FastAPI app
20
  app = FastAPI()
21
  chat_history = []
 
22
 
23
  class Message(BaseModel):
24
  user_input: str
@@ -51,6 +53,10 @@ def root():
51
 
52
  @app.post("/chat")
53
  def chat(msg: Message):
 
 
 
 
54
  user_input = msg.user_input.strip()
55
  if not user_input:
56
  return {"error": "Boş giriş"}
@@ -66,7 +72,6 @@ def chat(msg: Message):
66
  chat_history.append({"user": user_input, "bot": answer})
67
  return {"answer": answer, "chat_history": chat_history}
68
 
69
-
70
  # ✅ Model ve RAG yükleme
71
  def setup_model():
72
  global pipe
@@ -89,7 +94,7 @@ def setup_model():
89
  tokenizer = AutoTokenizer.from_pretrained(os.path.join(extract_dir, "output"))
90
 
91
  print("🧠 Base model indiriliyor...")
92
- base_model = AutoModelForCausalLM.from_pretrained(MODEL_BASE, torch_dtype="auto")
93
 
94
  print("➕ LoRA adapter uygulanıyor...")
95
  model = PeftModel.from_pretrained(base_model, os.path.join(extract_dir, "output"))
@@ -98,7 +103,14 @@ def setup_model():
98
  rag = load_dataset(RAG_DATA_REPO, data_files=RAG_DATA_FILE, split="train", token=HF_TOKEN)
99
  print(f"🔍 RAG boyutu: {len(rag)}")
100
 
101
- pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
 
 
 
 
 
 
 
102
 
103
  # ✅ Uygulama başladığında modeli yükle
104
  threading.Thread(target=setup_model, daemon=True).start()
 
7
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
8
  from datasets import load_dataset
9
  from peft import PeftModel
10
+ import torch # eksikse gerekli
11
 
12
  # ✅ Sabitler
13
  HF_TOKEN = os.environ.get("HF_TOKEN")
 
20
  # ✅ FastAPI app
21
  app = FastAPI()
22
  chat_history = []
23
+ pipe = None # ❗ Global olarak tanımlıyoruz
24
 
25
  class Message(BaseModel):
26
  user_input: str
 
53
 
54
  @app.post("/chat")
55
  def chat(msg: Message):
56
+ global pipe
57
+ if pipe is None:
58
+ return {"error": "Model henüz yüklenmedi, lütfen birkaç saniye sonra tekrar deneyin."}
59
+
60
  user_input = msg.user_input.strip()
61
  if not user_input:
62
  return {"error": "Boş giriş"}
 
72
  chat_history.append({"user": user_input, "bot": answer})
73
  return {"answer": answer, "chat_history": chat_history}
74
 
 
75
  # ✅ Model ve RAG yükleme
76
  def setup_model():
77
  global pipe
 
94
  tokenizer = AutoTokenizer.from_pretrained(os.path.join(extract_dir, "output"))
95
 
96
  print("🧠 Base model indiriliyor...")
97
+ base_model = AutoModelForCausalLM.from_pretrained(MODEL_BASE, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
98
 
99
  print("➕ LoRA adapter uygulanıyor...")
100
  model = PeftModel.from_pretrained(base_model, os.path.join(extract_dir, "output"))
 
103
  rag = load_dataset(RAG_DATA_REPO, data_files=RAG_DATA_FILE, split="train", token=HF_TOKEN)
104
  print(f"🔍 RAG boyutu: {len(rag)}")
105
 
106
+ # pipeline oluşturuluyor
107
+ pipe = pipeline(
108
+ "text-generation",
109
+ model=model,
110
+ tokenizer=tokenizer,
111
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
112
+ device=0 if torch.cuda.is_available() else -1
113
+ )
114
 
115
  # ✅ Uygulama başladığında modeli yükle
116
  threading.Thread(target=setup_model, daemon=True).start()