clean up default args
Browse files - src/retool_trainer.py +12 -7
src/retool_trainer.py
CHANGED
@@ -65,12 +65,18 @@ class ReToolTrainer(Trainer): # Change this line
|
|
65 |
# Store processing_class for compatibility
|
66 |
self.processing_class = processing_class or self.tokenizer
|
67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
# Add reward function handling (since Trainer doesn't have this)
|
69 |
self.reward_funcs = reward_funcs or [self._binary_reward_function]
|
70 |
-
|
71 |
-
# Rest of the ReTool-specific code stays exactly the same!
|
72 |
-
self.eos_id = eos_id or self.processing_class.eos_token_id
|
73 |
-
|
74 |
|
75 |
# ReTool specific attributes
|
76 |
self.eos_id = eos_id or self.processing_class.eos_token_id
|
@@ -99,16 +105,15 @@ class ReToolTrainer(Trainer): # Change this line
|
|
99 |
do_sample=True,
|
100 |
pad_token_id=self.processing_class.pad_token_id,
|
101 |
bos_token_id=self.processing_class.bos_token_id,
|
102 |
-
eos_token_id=
|
103 |
temperature=self.temperature,
|
104 |
top_p=self.top_p,
|
105 |
top_k=self.top_k,
|
106 |
min_p=self.min_p,
|
107 |
return_dict_in_generate=True,
|
108 |
use_cache=True,
|
|
|
109 |
)
|
110 |
-
|
111 |
-
|
112 |
def _get_interpreter_token_ids(self) -> list[int]:
|
113 |
"""Get token IDs for <interpreter> and </interpreter> tags."""
|
114 |
start_token = self.processing_class.encode("<interpreter>", add_special_tokens=False)[0]
|
|
|
65 |
# Store processing_class for compatibility
|
66 |
self.processing_class = processing_class or self.tokenizer
|
67 |
|
68 |
+
# Processing class
|
69 |
+
if processing_class is None:
|
70 |
+
self.processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
|
71 |
+
else:
|
72 |
+
# Store processing_class for compatibility
|
73 |
+
self.processing_class = processing_class or self.tokenizer
|
74 |
+
if processing_class.pad_token is None:
|
75 |
+
self.processing_class.pad_token = processing_class.eos_token
|
76 |
+
|
77 |
+
|
78 |
# Add reward function handling (since Trainer doesn't have this)
|
79 |
self.reward_funcs = reward_funcs or [self._binary_reward_function]
|
|
|
|
|
|
|
|
|
80 |
|
81 |
# ReTool specific attributes
|
82 |
self.eos_id = eos_id or self.processing_class.eos_token_id
|
|
|
105 |
do_sample=True,
|
106 |
pad_token_id=self.processing_class.pad_token_id,
|
107 |
bos_token_id=self.processing_class.bos_token_id,
|
108 |
+
eos_token_id=self.eos_id, # default stop on EOS
|
109 |
temperature=self.temperature,
|
110 |
top_p=self.top_p,
|
111 |
top_k=self.top_k,
|
112 |
min_p=self.min_p,
|
113 |
return_dict_in_generate=True,
|
114 |
use_cache=True,
|
115 |
+
cache_implementation=args.cache_implementation,  # KV-cache type name taken from args, e.g. "offloaded" (verify against the installed transformers version)
|
116 |
)
|
|
|
|
|
117 |
def _get_interpreter_token_ids(self) -> list[int]:
|
118 |
"""Get token IDs for <interpreter> and </interpreter> tags."""
|
119 |
start_token = self.processing_class.encode("<interpreter>", add_special_tokens=False)[0]
|