# flare/llm_model.py
import torch
import traceback

from transformers import AutoTokenizer, AutoModelForCausalLM
from pydantic import BaseModel

from log import log


class Message(BaseModel):
    user_input: str

class LLMModel:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.eos_token_id = None

    def setup(self, s_config, project_config, project_path):
        try:
            log("🧠 LLMModel setup() started")
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            log(f"📡 Using device: {device}")

            model_base = project_config["model_base"]
            token = s_config.get_auth_token()

            if s_config.work_mode == "hfcloud":
                log(f"📦 Loading Hugging Face cloud model: {model_base}")
                self.tokenizer = AutoTokenizer.from_pretrained(model_base, token=token, use_fast=False)
                self.model = AutoModelForCausalLM.from_pretrained(model_base, token=token, torch_dtype=torch.float32).to(device)
            else:
                log(f"📦 Downloading or loading model: {model_base}")
                self.tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
                self.model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float32).to(device)

            # Fall back to the EOS token when the tokenizer defines no pad token.
            self.tokenizer.pad_token = self.tokenizer.pad_token or self.tokenizer.eos_token
            self.model.config.pad_token_id = self.tokenizer.pad_token_id

            # Use the ChatML end-of-turn marker as the generation stop token.
            self.eos_token_id = self.tokenizer("<|im_end|>", add_special_tokens=False)["input_ids"][0]

            self.model.eval()
            log("✅ LLMModel setup() completed successfully.")
        except Exception as e:
            log(f"❌ LLMModel setup() error: {e}")
            traceback.print_exc()

    async def generate_response_with_messages(self, messages, project_config):
        # Render the chat history into a prompt using the model's chat template.
        encodeds = self.tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True)
        input_ids = encodeds.to(self.model.device)
        attention_mask = (input_ids != self.tokenizer.pad_token_id).long()

        with torch.no_grad():
            output = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=256,
                do_sample=project_config["use_sampling"],
                eos_token_id=self.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
                return_dict_in_generate=True,
                output_scores=True,
            )

        # Note: output.sequences[0] contains the prompt followed by the
        # completion, so the decoded string includes the input conversation.
        decoded = self.tokenizer.decode(output.sequences[0], skip_special_tokens=True).strip()
        return decoded
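

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: it wires the class
    # together with stand-in config objects shaped like the attributes accessed
    # above (`work_mode`, `get_auth_token()`, `model_base`, `use_sampling`).
    # The model name is a placeholder; substitute whatever ChatML-style
    # checkpoint the project actually targets.
    import asyncio

    class _StubSettings:
        work_mode = "local"  # any value other than "hfcloud" skips token auth

        def get_auth_token(self):
            return None

    project_config = {
        "model_base": "teknium/OpenHermes-2.5-Mistral-7B",  # assumed ChatML model
        "use_sampling": False,
    }

    llm = LLMModel()
    llm.setup(_StubSettings(), project_config, project_path=".")

    messages = [{"role": "user", "content": "Hello!"}]
    reply = asyncio.run(llm.generate_response_with_messages(messages, project_config))
    print(reply)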