# flare/llm_model.py
import torch
import traceback
from transformers import AutoTokenizer, AutoModelForCausalLM
from log import log
from pydantic import BaseModel


class Message(BaseModel):
    user_input: str


class LLMModel:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.eos_token_id = None

    def setup(self, s_config, project_config):
        """Load the tokenizer and model according to the configured work mode."""
        try:
            log("🧠 LLMModel setup() started")
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            log(f"📡 Using device: {device}")

            model_base = project_config["model_base"]

            if s_config.work_mode == "hfcloud":
                # Private model on the Hugging Face Hub: authenticate with the project token.
                token = s_config.get_auth_token()
                log(f"📦 Loading Hugging Face cloud model: {model_base}")
                self.tokenizer = AutoTokenizer.from_pretrained(model_base, use_auth_token=token, use_fast=False)
                self.model = AutoModelForCausalLM.from_pretrained(model_base, use_auth_token=token, torch_dtype=torch.float32).to(device)
            elif s_config.work_mode == "cloud":
                log(f"📦 Downloading model from another cloud environment: {model_base}")
                self.tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
                self.model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float32).to(device)
            elif s_config.work_mode == "on-prem":
                log(f"📦 On-prem model path: {model_base}")
                self.tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
                self.model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float32).to(device)
            else:
                raise ValueError(f"Unknown work_mode: {s_config.work_mode}")

            # Fall back to the EOS token when the tokenizer defines no pad token.
            self.tokenizer.pad_token = self.tokenizer.pad_token or self.tokenizer.eos_token
            self.model.config.pad_token_id = self.tokenizer.pad_token_id
            # ChatML-style end-of-turn marker used to stop generation.
            self.eos_token_id = self.tokenizer("<|im_end|>", add_special_tokens=False)["input_ids"][0]

            self.model.eval()
            log("✅ LLMModel setup() completed successfully.")
        except Exception as e:
            log(f"❌ LLMModel setup() error: {e}")
            traceback.print_exc()
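
    # Note: setup() and generate_response() only rely on the following pieces of
    # their config arguments (a minimal sketch inferred from this file, not the
    # project's actual config classes):
    #   s_config.work_mode              -> "hfcloud" | "cloud" | "on-prem"
    #   s_config.get_auth_token()       -> Hugging Face token (hfcloud mode only)
    #   project_config["model_base"]    -> model id on the Hub or local path
    #   project_config["use_sampling"]  -> bool, toggles sampling vs. greedy decoding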

    async def generate_response(self, text, project_config):
        """Generate a reply for a single user message and return (text, confidence)."""
        messages = [{"role": "user", "content": text}]
        encodeds = self.tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True)
        input_ids = encodeds.to(self.model.device)
        attention_mask = (input_ids != self.tokenizer.pad_token_id).long()

        with torch.no_grad():
            output = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=128,
                do_sample=project_config["use_sampling"],
                eos_token_id=self.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
                return_dict_in_generate=True,
                output_scores=True
            )

        if not project_config["use_sampling"]:
            # Greedy decoding: use the highest softmax probability across the
            # generated steps as a rough confidence score.
            scores = torch.stack(output.scores, dim=1)
            probs = torch.nn.functional.softmax(scores[0], dim=-1)
            top_conf = probs.max().item()
        else:
            top_conf = None

        decoded = self.tokenizer.decode(output.sequences[0], skip_special_tokens=True).strip()

        # Strip everything up to and including the assistant marker so only the reply remains.
        for tag in ["assistant", "<|im_start|>assistant"]:
            start = decoded.find(tag)
            if start != -1:
                decoded = decoded[start + len(tag):].strip()
                break

        return decoded, top_conf
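

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): shows how this class
# might be wired up. The config objects below are hypothetical stand-ins for
# the real s_config / project_config used by the Flare service, and the model
# path is a placeholder.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio
    from types import SimpleNamespace

    # Hypothetical on-prem configuration; replace with the project's real config objects.
    s_config = SimpleNamespace(work_mode="on-prem", get_auth_token=lambda: None)
    project_config = {"model_base": "/models/my-chat-model", "use_sampling": False}

    llm = LLMModel()
    llm.setup(s_config, project_config)

    reply, confidence = asyncio.run(llm.generate_response("Hello!", project_config))
    print(reply, confidence)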