# flare/llm_model.py
import os
import json
import traceback

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
from pydantic import BaseModel

from log import log
class Message(BaseModel):
    user_input: str


class LLMModel:
    def __init__(self):
        self.model = None  # main generation model
        self.tokenizer = None
        self.eos_token_id = None
        self.intent_model = None  # intent classification model
        self.intent_tokenizer = None
        self.intent_label2id = None
    def setup(self, s_config, project_config, project_path):
        try:
            log("🧠 LLMModel setup() started")
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            log(f"📡 Using device: {device}")
            model_base = project_config["model_base"]
            token = s_config.get_auth_token()

            if s_config.work_mode == "hfcloud":
                log(f"📦 Loading Hugging Face cloud model: {model_base}")
                self.tokenizer = AutoTokenizer.from_pretrained(model_base, token=token, use_fast=False)
                self.model = AutoModelForCausalLM.from_pretrained(model_base, token=token, torch_dtype=torch.float32).to(device)
            elif s_config.work_mode in ["cloud", "on-prem"]:
                log(f"📦 Downloading or loading model: {model_base}")
                self.tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
                self.model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float32).to(device)
            else:
                raise Exception(f"Unknown work_mode: {s_config.work_mode}")

            # Fall back to the EOS token when the tokenizer defines no pad token.
            self.tokenizer.pad_token = self.tokenizer.pad_token or self.tokenizer.eos_token
            self.model.config.pad_token_id = self.tokenizer.pad_token_id
            # Use the ChatML end-of-turn token as the generation stop token.
            self.eos_token_id = self.tokenizer("<|im_end|>", add_special_tokens=False)["input_ids"][0]
            self.model.eval()
            log("✅ LLMModel setup() completed successfully.")
        except Exception as e:
            log(f"❌ LLMModel setup() error: {e}")
            traceback.print_exc()
    def load_intent_model(self, model_path):
        try:
            log(f"🔧 Loading intent model: {model_path}")
            self.intent_tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.intent_model = AutoModelForSequenceClassification.from_pretrained(model_path)
            with open(os.path.join(model_path, "label2id.json")) as f:
                self.intent_label2id = json.load(f)
            log("✅ Intent model loaded.")
        except Exception as e:
            log(f"❌ Intent model load error: {e}")
            traceback.print_exc()
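    # Illustrative sketch only (not in the original file): one plausible way the
    # loaded intent artifacts could be consumed. The method name `predict_intent`
    # and the softmax/argmax decoding are assumptions; label2id.json is inverted
    # here to map the predicted class id back to its label string.
    def predict_intent(self, text):
        inputs = self.intent_tokenizer(text, return_tensors="pt", truncation=True)
        with torch.no_grad():
            logits = self.intent_model(**inputs).logits
        probs = torch.nn.functional.softmax(logits[0], dim=-1)
        pred_id = int(probs.argmax().item())
        id2label = {v: k for k, v in self.intent_label2id.items()}
        return id2label.get(pred_id), probs[pred_id].item()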
    async def generate_response(self, text, project_config):
        messages = [{"role": "user", "content": text}]
        encodeds = self.tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True)
        input_ids = encodeds.to(self.model.device)
        # Single unpadded prompt: mask out any pad positions just in case.
        attention_mask = (input_ids != self.tokenizer.pad_token_id).long()

        with torch.no_grad():
            output = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=128,
                do_sample=project_config["use_sampling"],
                eos_token_id=self.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
                return_dict_in_generate=True,
                output_scores=True
            )

        if not project_config["use_sampling"]:
            # Greedy decoding: report the highest single-token probability seen
            # anywhere in the generated sequence as a rough confidence score.
            scores = torch.stack(output.scores, dim=1)
            probs = torch.nn.functional.softmax(scores[0], dim=-1)
            top_conf = probs.max().item()
        else:
            top_conf = None

        decoded = self.tokenizer.decode(output.sequences[0], skip_special_tokens=True).strip()
        # Drop everything up to and including the assistant turn marker so that
        # only the model's reply is returned.
        for tag in ["assistant", "<|im_start|>assistant"]:
            start = decoded.find(tag)
            if start != -1:
                decoded = decoded[start + len(tag):].strip()
                break
        return decoded, top_conf
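

# Minimal usage sketch, under assumptions: a config object exposing `work_mode`
# and `get_auth_token()`, a project config dict with `model_base` and
# `use_sampling`, and a ChatML-style instruct checkpoint (the example model name
# below is illustrative, not part of this project).
if __name__ == "__main__":
    import asyncio
    from types import SimpleNamespace

    s_config = SimpleNamespace(work_mode="on-prem", get_auth_token=lambda: None)
    project_config = {"model_base": "Qwen/Qwen2-0.5B-Instruct", "use_sampling": False}

    llm = LLMModel()
    llm.setup(s_config, project_config, project_path=None)
    reply, conf = asyncio.run(llm.generate_response("Hello!", project_config))
    print(reply, conf)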