amiguel commited on
Commit
b80d6c2
·
verified ·
1 Parent(s): d7a8919

Upload model_utils.py

Browse files
Files changed (1) hide show
  1. src/model_utils.py +2 -4
src/model_utils.py CHANGED
@@ -12,11 +12,9 @@ def load_hf_model(model_name, device="cpu"):
12
  except NotImplementedError:
13
  # If meta tensor error occurs, use to_empty()
14
  model = model.to_empty(device=device)
15
- device_id = 0
16
- else:
17
- device_id = -1
18
 
19
- return pipeline("text-generation", model=model, tokenizer=tokenizer, device=device_id)
 
20
 
21
  def generate_answer(text_gen, question, context):
22
  prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:"
 
12
  except NotImplementedError:
13
  # If meta tensor error occurs, use to_empty()
14
  model = model.to_empty(device=device)
 
 
 
15
 
16
+ # Don't specify device in pipeline when using accelerate
17
+ return pipeline("text-generation", model=model, tokenizer=tokenizer)
18
 
19
  def generate_answer(text_gen, question, context):
20
  prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:"