ciyidogan committed on
Commit
0a5b12b
·
verified ·
1 Parent(s): 6398aea

Update fine_tune_inference_test.py

Browse files
Files changed (1) hide show
  1. fine_tune_inference_test.py +11 -6
fine_tune_inference_test.py CHANGED
@@ -2,18 +2,18 @@ import os
2
  import threading
3
  import uvicorn
4
  from fastapi import FastAPI, Request
5
- from fastapi.responses import HTMLResponse
6
  from pydantic import BaseModel
7
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
8
  from datasets import load_dataset
9
- from fastapi.responses import JSONResponse
10
 
11
  # ✅ Sabitler
12
  HF_TOKEN = os.environ.get("HF_TOKEN")
13
  MODEL_BASE = "UcsTurkey/kanarya-750m-fixed"
14
- FINE_TUNE_ZIP = "trained_model_000_100.zip" # 👈 Değiştirilebilir
15
  FINE_TUNE_REPO = "UcsTurkey/trained-zips"
16
- RAG_DATA_FILE = "merged_dataset_000_100.parquet" # 👈 Değiştirilebilir
17
  RAG_DATA_REPO = "UcsTurkey/turkish-general-culture-tokenized"
18
 
19
  # ✅ FastAPI app
@@ -85,9 +85,14 @@ def setup_model():
85
  with zipfile.ZipFile(zip_path, "r") as zip_ref:
86
  zip_ref.extractall(extract_dir)
87
 
88
- print("🔁 Tokenizer ve model yükleniyor...")
89
  tokenizer = AutoTokenizer.from_pretrained(os.path.join(extract_dir, "output"))
90
- model = AutoModelForCausalLM.from_pretrained(os.path.join(extract_dir, "output"))
 
 
 
 
 
91
 
92
  print("📚 RAG dataseti yükleniyor...")
93
  rag = load_dataset(RAG_DATA_REPO, data_files=RAG_DATA_FILE, split="train", token=HF_TOKEN)
 
2
  import threading
3
  import uvicorn
4
  from fastapi import FastAPI, Request
5
+ from fastapi.responses import HTMLResponse, JSONResponse
6
  from pydantic import BaseModel
7
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
8
  from datasets import load_dataset
9
+ from peft import PeftModel
10
 
11
  # ✅ Sabitler
12
  HF_TOKEN = os.environ.get("HF_TOKEN")
13
  MODEL_BASE = "UcsTurkey/kanarya-750m-fixed"
14
+ FINE_TUNE_ZIP = "trained_model_000_100.zip"
15
  FINE_TUNE_REPO = "UcsTurkey/trained-zips"
16
+ RAG_DATA_FILE = "merged_dataset_000_100.parquet"
17
  RAG_DATA_REPO = "UcsTurkey/turkish-general-culture-tokenized"
18
 
19
  # ✅ FastAPI app
 
85
  with zipfile.ZipFile(zip_path, "r") as zip_ref:
86
  zip_ref.extractall(extract_dir)
87
 
88
+ print("🔁 Tokenizer yükleniyor...")
89
  tokenizer = AutoTokenizer.from_pretrained(os.path.join(extract_dir, "output"))
90
+
91
+ print("🧠 Base model indiriliyor...")
92
+ base_model = AutoModelForCausalLM.from_pretrained(MODEL_BASE, torch_dtype="auto")
93
+
94
+ print("➕ LoRA adapter uygulanıyor...")
95
+ model = PeftModel.from_pretrained(base_model, os.path.join(extract_dir, "output"))
96
 
97
  print("📚 RAG dataseti yükleniyor...")
98
  rag = load_dataset(RAG_DATA_REPO, data_files=RAG_DATA_FILE, split="train", token=HF_TOKEN)