amiguel committed
Commit d7a8919 · verified · 1 parent: 817be3b

Upload model_utils.py

Files changed (1):
src/model_utils.py +14 -1
src/model_utils.py CHANGED
@@ -1,9 +1,22 @@
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+import torch
 
 def load_hf_model(model_name, device="cpu"):
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     model = AutoModelForCausalLM.from_pretrained(model_name)
-    return pipeline("text-generation", model=model, tokenizer=tokenizer, device=0 if device=="cuda" else -1)
+
+    # Handle meta tensors properly
+    if device == "cuda" and torch.cuda.is_available():
+        try:
+            model = model.to(device)
+        except NotImplementedError:
+            # If meta tensor error occurs, use to_empty()
+            model = model.to_empty(device=device)
+        device_id = 0
+    else:
+        device_id = -1
+
+    return pipeline("text-generation", model=model, tokenizer=tokenizer, device=device_id)
 
 def generate_answer(text_gen, question, context):
     prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:"
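
For reference, a minimal usage sketch of the two helpers after this change. It is illustrative only: the model id "distilgpt2" is a placeholder, the import path assumes the repo root is on PYTHONPATH, and generate_answer is assumed to return the generated text (its body is truncated in this diff).

from src.model_utils import load_hf_model, generate_answer

# Placeholder model id; any causal-LM checkpoint works here.
text_gen = load_hf_model("distilgpt2", device="cuda")

# With the new logic, device_id is 0 (first GPU) when CUDA is available,
# otherwise -1, which tells the transformers pipeline to run on CPU.
answer = generate_answer(
    text_gen,
    question="What does load_hf_model return?",
    context="load_hf_model builds a transformers text-generation pipeline.",
)
print(answer)

One design note: Module.to_empty(device=...) allocates uninitialized storage on the target device rather than copying weights, so the except branch assumes the parameters are materialized (for example by reloading the state dict) before inference.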