Update factual/green_score/green.py
Browse files
factual/green_score/green.py
CHANGED
@@ -62,6 +62,7 @@ class GREEN:
|
|
62 |
category=FutureWarning,
|
63 |
module="transformers.tokenization_utils_base",
|
64 |
)
|
|
|
65 |
self.cpu = cpu
|
66 |
self.model_name = model_name.split("/")[-1]
|
67 |
self.output_dir = output_dir
|
@@ -100,8 +101,8 @@ class GREEN:
|
|
100 |
model_name,
|
101 |
trust_remote_code=False if "Phi" in model_name else True,
|
102 |
device_map=(
|
103 |
-
{"": "cuda:{}".format(torch.cuda.current_device())}
|
104 |
-
if not self.cpu
|
105 |
else {"": "cpu"}
|
106 |
),
|
107 |
torch_dtype=torch.float16,
|
@@ -211,7 +212,7 @@ class GREEN:
|
|
211 |
return self.process_results()
|
212 |
|
213 |
def tokenize_batch_as_chat(self, batch):
|
214 |
-
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
215 |
batch = [
|
216 |
self.tokenizer.apply_chat_template(
|
217 |
i, tokenize=False, add_generation_prompt=True
|
|
|
62 |
category=FutureWarning,
|
63 |
module="transformers.tokenization_utils_base",
|
64 |
)
|
65 |
+
cpu = cpu or not torch.cuda.is_available()
|
66 |
self.cpu = cpu
|
67 |
self.model_name = model_name.split("/")[-1]
|
68 |
self.output_dir = output_dir
|
|
|
101 |
model_name,
|
102 |
trust_remote_code=False if "Phi" in model_name else True,
|
103 |
device_map=(
|
104 |
+
{"": f"cuda:{torch.cuda.current_device()}"}
|
105 |
+
if (not self.cpu and torch.cuda.is_available())
|
106 |
else {"": "cpu"}
|
107 |
),
|
108 |
torch_dtype=torch.float16,
|
|
|
212 |
return self.process_results()
|
213 |
|
214 |
def tokenize_batch_as_chat(self, batch):
|
215 |
+
local_rank = self.device
|
216 |
batch = [
|
217 |
self.tokenizer.apply_chat_template(
|
218 |
i, tokenize=False, add_generation_prompt=True
|