from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch

def load_hf_model(model_name, device="cpu"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    if device == "cuda" and torch.cuda.is_available():
        try:
            model = model.to(device)
        except NotImplementedError:
            # .to() raises NotImplementedError when parameters are still meta
            # tensors ("Cannot copy out of meta tensor; no data!"). to_empty()
            # moves the module by allocating uninitialized storage on the
            # target device, as the error message suggests; the weights still
            # have to be materialized (e.g. by accelerate's loading hooks)
            # before the model produces meaningful output.
            model = model.to_empty(device=device)

    # Don't also pass `device=` to the pipeline: when accelerate manages
    # placement, setting it here conflicts with where the model already lives.
    return pipeline("text-generation", model=model, tokenizer=tokenizer)
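
# Alternative sketch (assumption: the `accelerate` package is installed).
# Loading with device_map lets accelerate place the weights at load time,
# which avoids the meta-tensor branch above entirely. This helper and its
# name are illustrative, not part of the original snippet.
def load_hf_model_device_map(model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
    # As above, no `device=` here: accelerate already owns placement.
    return pipeline("text-generation", model=model, tokenizer=tokenizer)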

def generate_answer(text_gen, question, context):
    prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:"
    result = text_gen(prompt, max_new_tokens=256, do_sample=True, temperature=0.7)
    # generated_text echoes the prompt, so keep only what follows "Answer:".
    return result[0]["generated_text"].split("Answer:")[-1].strip()
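
# Usage sketch. Assumptions: "distilgpt2" as a small causal LM and the example
# strings below are illustrative choices, not taken from the original snippet.
if __name__ == "__main__":
    text_gen = load_hf_model("distilgpt2", device="cuda")
    context = "The Eiffel Tower is located in Paris, France."
    question = "Where is the Eiffel Tower?"
    print(generate_answer(text_gen, question, context))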