# flare/llm_model.py
import torch
import traceback
from transformers import AutoTokenizer, AutoModelForCausalLM
from pydantic import BaseModel

from log import log


class Message(BaseModel):
    user_input: str

class LLMModel:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.eos_token_id = None
    def setup(self, s_config, project_config, project_path):
        try:
            log("🧠 LLMModel setup() started")
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            log(f"📡 Using device: {device}")

            model_base = project_config["model_base"]
            token = s_config.get_auth_token()

            if s_config.work_mode == "hfcloud":
                log(f"📦 Loading Hugging Face cloud model: {model_base}")
                self.tokenizer = AutoTokenizer.from_pretrained(model_base, token=token, use_fast=False)
                self.model = AutoModelForCausalLM.from_pretrained(model_base, token=token, torch_dtype=torch.float32).to(device)
            else:
                log(f"📦 Downloading or loading model: {model_base}")
                self.tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
                self.model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float32).to(device)

            # Fall back to the EOS token if the tokenizer defines no pad token.
            self.tokenizer.pad_token = self.tokenizer.pad_token or self.tokenizer.eos_token
            self.model.config.pad_token_id = self.tokenizer.pad_token_id

            # Stop generation at the ChatML end-of-turn marker.
            self.eos_token_id = self.tokenizer("<|im_end|>", add_special_tokens=False)["input_ids"][0]

            self.model.eval()
            log("✅ LLMModel setup() completed successfully.")
        except Exception as e:
            log(f"❌ LLMModel setup() error: {e}")
            traceback.print_exc()

    async def generate_response_with_messages(self, messages, project_config, system_prompt):
        all_messages = [{"role": "system", "content": system_prompt}] + messages

        # Render the conversation with the model's chat template and append
        # the assistant generation prompt.
        encodeds = self.tokenizer.apply_chat_template(all_messages, return_tensors="pt", add_generation_prompt=True)
        input_ids = encodeds.to(self.model.device)
        attention_mask = (input_ids != self.tokenizer.pad_token_id).long()

        with torch.no_grad():
            output = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=256,
                do_sample=project_config["use_sampling"],
                eos_token_id=self.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
                return_dict_in_generate=True,
                output_scores=True,
            )

        # Decode only the newly generated tokens; decoding the full sequence
        # would echo the prompt back into the response.
        new_tokens = output.sequences[0][input_ids.shape[-1]:]
        decoded = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        return decoded
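

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal example of how LLMModel is wired up end to end. `_StubConfig` and
# the model name below are assumptions for illustration: the real `s_config`
# and `project_config` objects come from elsewhere in the flare application.
if __name__ == "__main__":
    import asyncio

    class _StubConfig:
        work_mode = "local"  # anything other than "hfcloud" loads without an auth token

        def get_auth_token(self):
            return None

    # A small ChatML-style model is assumed so "<|im_end|>" exists in the vocabulary.
    project_config = {"model_base": "Qwen/Qwen1.5-0.5B-Chat", "use_sampling": False}

    llm = LLMModel()
    llm.setup(_StubConfig(), project_config, project_path=".")

    reply = asyncio.run(llm.generate_response_with_messages(
        [{"role": "user", "content": "Hello!"}],
        project_config,
        system_prompt="You are a helpful assistant.",
    ))
    print(reply)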