Spaces:
Build error
Build error
Refactor SFTTrainer configuration in train.py to remove data_collator from the SFT config, preventing duplication and enhancing clarity in trainer setup.
Browse files
train.py
CHANGED
@@ -200,6 +200,10 @@ def create_trainer(
|
|
200 |
**cfg.training.sft.data_collator,
|
201 |
)
|
202 |
|
|
|
|
|
|
|
|
|
203 |
trainer = SFTTrainer(
|
204 |
model=model,
|
205 |
tokenizer=tokenizer,
|
@@ -207,7 +211,7 @@ def create_trainer(
|
|
207 |
eval_dataset=dataset["validation"],
|
208 |
args=training_args,
|
209 |
data_collator=data_collator,
|
210 |
-
**
|
211 |
)
|
212 |
logger.info("Trainer created successfully")
|
213 |
return trainer
|
|
|
200 |
**cfg.training.sft.data_collator,
|
201 |
)
|
202 |
|
203 |
+
# Create SFT config without data_collator to avoid duplication
|
204 |
+
sft_config = OmegaConf.to_container(cfg.training.sft, resolve=True)
|
205 |
+
sft_config.pop('data_collator', None) # Remove data_collator from config
|
206 |
+
|
207 |
trainer = SFTTrainer(
|
208 |
model=model,
|
209 |
tokenizer=tokenizer,
|
|
|
211 |
eval_dataset=dataset["validation"],
|
212 |
args=training_args,
|
213 |
data_collator=data_collator,
|
214 |
+
**sft_config,
|
215 |
)
|
216 |
logger.info("Trainer created successfully")
|
217 |
return trainer
|