X-iZhang commited on
Commit
9998f52
·
verified ·
1 Parent(s): d61d98e

Update factual/green_score/green.py

Browse files
Files changed (1) hide show
  1. factual/green_score/green.py +4 -3
factual/green_score/green.py CHANGED
@@ -62,6 +62,7 @@ class GREEN:
62
  category=FutureWarning,
63
  module="transformers.tokenization_utils_base",
64
  )
 
65
  self.cpu = cpu
66
  self.model_name = model_name.split("/")[-1]
67
  self.output_dir = output_dir
@@ -100,8 +101,8 @@ class GREEN:
100
  model_name,
101
  trust_remote_code=False if "Phi" in model_name else True,
102
  device_map=(
103
- {"": "cuda:{}".format(torch.cuda.current_device())}
104
- if not self.cpu
105
  else {"": "cpu"}
106
  ),
107
  torch_dtype=torch.float16,
@@ -211,7 +212,7 @@ class GREEN:
211
  return self.process_results()
212
 
213
  def tokenize_batch_as_chat(self, batch):
214
- local_rank = int(os.environ.get("LOCAL_RANK", 0)) if not self.cpu else "cpu"
215
  batch = [
216
  self.tokenizer.apply_chat_template(
217
  i, tokenize=False, add_generation_prompt=True
 
62
  category=FutureWarning,
63
  module="transformers.tokenization_utils_base",
64
  )
65
+ cpu = cpu or not torch.cuda.is_available()
66
  self.cpu = cpu
67
  self.model_name = model_name.split("/")[-1]
68
  self.output_dir = output_dir
 
101
  model_name,
102
  trust_remote_code=False if "Phi" in model_name else True,
103
  device_map=(
104
+ {"": f"cuda:{torch.cuda.current_device()}"}
105
+ if (not self.cpu and torch.cuda.is_available())
106
  else {"": "cpu"}
107
  ),
108
  torch_dtype=torch.float16,
 
212
  return self.process_results()
213
 
214
  def tokenize_batch_as_chat(self, batch):
215
+ local_rank = self.device
216
  batch = [
217
  self.tokenizer.apply_chat_template(
218
  i, tokenize=False, add_generation_prompt=True