bird-of-paradise committed on
Commit
f757722
·
verified ·
1 Parent(s): 9b6b77c

clean up default args

Browse files
Files changed (1) hide show
  1. src/retool_trainer.py +12 -7
src/retool_trainer.py CHANGED
@@ -65,12 +65,18 @@ class ReToolTrainer(Trainer): # Change this line
65
  # Store processing_class for compatibility
66
  self.processing_class = processing_class or self.tokenizer
67
 
 
 
 
 
 
 
 
 
 
 
68
  # Add reward function handling (since Trainer doesn't have this)
69
  self.reward_funcs = reward_funcs or [self._binary_reward_function]
70
-
71
- # Rest of the ReTool-specific code stays exactly the same!
72
- self.eos_id = eos_id or self.processing_class.eos_token_id
73
-
74
 
75
  # ReTool specific attributes
76
  self.eos_id = eos_id or self.processing_class.eos_token_id
@@ -99,16 +105,15 @@ class ReToolTrainer(Trainer): # Change this line
99
  do_sample=True,
100
  pad_token_id=self.processing_class.pad_token_id,
101
  bos_token_id=self.processing_class.bos_token_id,
102
- eos_token_id=[self.eos_id, self.code_id[1]], # Stop on EOS or </code>
103
  temperature=self.temperature,
104
  top_p=self.top_p,
105
  top_k=self.top_k,
106
  min_p=self.min_p,
107
  return_dict_in_generate=True,
108
  use_cache=True,
 
109
  )
110
-
111
-
112
  def _get_interpreter_token_ids(self) -> list[int]:
113
  """Get token IDs for <interpreter> and </interpreter> tags."""
114
  start_token = self.processing_class.encode("<interpreter>", add_special_tokens=False)[0]
 
65
  # Store processing_class for compatibility
66
  self.processing_class = processing_class or self.tokenizer
67
 
68
+ # Processing class
69
+ if processing_class is None:
70
+ self.processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
71
+ else:
72
+ # Store processing_class for compatibility
73
+ self.processing_class = processing_class or self.tokenizer
74
+ if processing_class.pad_token is None:
75
+ self.processing_class.pad_token = processing_class.eos_token
76
+
77
+
78
  # Add reward function handling (since Trainer doesn't have this)
79
  self.reward_funcs = reward_funcs or [self._binary_reward_function]
 
 
 
 
80
 
81
  # ReTool specific attributes
82
  self.eos_id = eos_id or self.processing_class.eos_token_id
 
105
  do_sample=True,
106
  pad_token_id=self.processing_class.pad_token_id,
107
  bos_token_id=self.processing_class.bos_token_id,
108
+ eos_token_id=self.eos_id, # default stop on EOS
109
  temperature=self.temperature,
110
  top_p=self.top_p,
111
  top_k=self.top_k,
112
  min_p=self.min_p,
113
  return_dict_in_generate=True,
114
  use_cache=True,
115
+ cache_implementation=args.cache_implementation, #args.cache_implementation = 'Offloaded Cache'
116
  )
 
 
117
  def _get_interpreter_token_ids(self) -> list[int]:
118
  """Get token IDs for <interpreter> and </interpreter> tags."""
119
  start_token = self.processing_class.encode("<interpreter>", add_special_tokens=False)[0]