# ai_systems/src/model_utils.py
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch

def load_hf_model(model_name, device="cpu"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    # Move to GPU only when requested and actually available
    if device == "cuda" and torch.cuda.is_available():
        try:
            model = model.to(device)
        except NotImplementedError:
            # Meta tensors can't be moved with .to(); to_empty() reallocates
            # the parameters on the target device instead
            model = model.to_empty(device=device)

    # Don't pass a device to pipeline() when accelerate manages placement
    return pipeline("text-generation", model=model, tokenizer=tokenizer)

def generate_answer(text_gen, question, context):
    prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:"
    result = text_gen(prompt, max_new_tokens=256, do_sample=True, temperature=0.7)
    # The pipeline returns the full prompt plus completion; keep only the
    # text after the final "Answer:" marker
    return result[0]["generated_text"].split("Answer:")[-1].strip()
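
# Usage sketch, not part of the original file: the model name ("distilgpt2")
# and the question/context strings below are illustrative assumptions.
if __name__ == "__main__":
    text_gen = load_hf_model("distilgpt2", device="cuda")
    answer = generate_answer(
        text_gen,
        question="What is the capital of France?",
        context="Paris is the capital and largest city of France.",
    )
    print(answer)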