FEAT: add tagging support to axolotl (#1004)
Browse files
* add tagging support to axolotl
* chore: lint
* fix method w self
---------
Co-authored-by: Wing Lian <[email protected]>
src/axolotl/core/trainer_builder.py
CHANGED
|
@@ -9,7 +9,7 @@ import math
|
|
| 9 |
import sys
|
| 10 |
from abc import abstractmethod
|
| 11 |
from dataclasses import dataclass, field
|
| 12 |
-
from functools import partial
|
| 13 |
from pathlib import Path
|
| 14 |
from typing import Optional
|
| 15 |
|
|
@@ -120,6 +120,7 @@ class AxolotlTrainer(Trainer):
|
|
| 120 |
"""
|
| 121 |
|
| 122 |
args = None # type: AxolotlTrainingArguments
|
|
|
|
| 123 |
|
| 124 |
def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
|
| 125 |
self.num_epochs = num_epochs
|
|
@@ -290,12 +291,41 @@ class AxolotlTrainer(Trainer):
|
|
| 290 |
# return (loss, outputs) if return_outputs else loss
|
| 291 |
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
| 292 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
|
| 294 |
class AxolotlMambaTrainer(AxolotlTrainer):
|
| 295 |
"""
|
| 296 |
Mamba specific trainer to handle loss calculation
|
| 297 |
"""
|
| 298 |
|
|
|
|
|
|
|
| 299 |
def compute_loss(
|
| 300 |
self,
|
| 301 |
model,
|
|
@@ -322,6 +352,8 @@ class OneCycleLRSchedulerTrainer(AxolotlTrainer):
|
|
| 322 |
Trainer subclass that uses the OneCycleLR scheduler
|
| 323 |
"""
|
| 324 |
|
|
|
|
|
|
|
| 325 |
def __init__(self, *args, **kwargs):
|
| 326 |
super().__init__(*args, **kwargs)
|
| 327 |
self.lr_scheduler = None
|
|
@@ -351,6 +383,8 @@ class ReLoRATrainer(AxolotlTrainer):
|
|
| 351 |
Trainer subclass that uses the OneCycleLR scheduler
|
| 352 |
"""
|
| 353 |
|
|
|
|
|
|
|
| 354 |
def __init__(self, *args, **kwargs):
|
| 355 |
super().__init__(*args, **kwargs)
|
| 356 |
self.lr_scheduler = None
|
|
|
|
| 9 |
import sys
|
| 10 |
from abc import abstractmethod
|
| 11 |
from dataclasses import dataclass, field
|
| 12 |
+
from functools import partial, wraps
|
| 13 |
from pathlib import Path
|
| 14 |
from typing import Optional
|
| 15 |
|
|
|
|
| 120 |
"""
|
| 121 |
|
| 122 |
args = None # type: AxolotlTrainingArguments
|
| 123 |
+
tag_names = ["axolotl"]
|
| 124 |
|
| 125 |
def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
|
| 126 |
self.num_epochs = num_epochs
|
|
|
|
| 291 |
# return (loss, outputs) if return_outputs else loss
|
| 292 |
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
| 293 |
|
| 294 |
+
def _sanitize_kwargs_for_tagging(self, tag_names, kwargs=None):
|
| 295 |
+
if isinstance(tag_names, str):
|
| 296 |
+
tag_names = [tag_names]
|
| 297 |
+
|
| 298 |
+
if kwargs is not None:
|
| 299 |
+
if "tags" not in kwargs:
|
| 300 |
+
kwargs["tags"] = tag_names
|
| 301 |
+
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
|
| 302 |
+
kwargs["tags"].extend(tag_names)
|
| 303 |
+
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
|
| 304 |
+
tag_names.append(kwargs["tags"])
|
| 305 |
+
kwargs["tags"] = tag_names
|
| 306 |
+
|
| 307 |
+
return kwargs
|
| 308 |
+
|
| 309 |
+
@wraps(Trainer.push_to_hub)
def push_to_hub(self, *args, **kwargs) -> str:
    """
    Override `push_to_hub` so that this trainer's `tag_names` are always merged
    into the `tags` keyword argument before the model is pushed to the Hub.
    See `~transformers.Trainer.push_to_hub` for the upload itself.
    """
    # Merge the class-level tags into whatever the caller supplied, then
    # hand everything to the parent implementation unchanged.
    tagged_kwargs = self._sanitize_kwargs_for_tagging(
        tag_names=self.tag_names,
        kwargs=kwargs,
    )
    return super().push_to_hub(*args, **tagged_kwargs)
|
| 320 |
+
|
| 321 |
|
| 322 |
class AxolotlMambaTrainer(AxolotlTrainer):
|
| 323 |
"""
|
| 324 |
Mamba specific trainer to handle loss calculation
|
| 325 |
"""
|
| 326 |
|
| 327 |
+
tag_names = ["axolotl", "mamba"]
|
| 328 |
+
|
| 329 |
def compute_loss(
|
| 330 |
self,
|
| 331 |
model,
|
|
|
|
| 352 |
Trainer subclass that uses the OneCycleLR scheduler
|
| 353 |
"""
|
| 354 |
|
| 355 |
+
tag_names = ["axolotl", "onecycle"]
|
| 356 |
+
|
| 357 |
def __init__(self, *args, **kwargs):
|
| 358 |
super().__init__(*args, **kwargs)
|
| 359 |
self.lr_scheduler = None
|
|
|
|
| 383 |
Trainer subclass that uses the OneCycleLR scheduler
|
| 384 |
"""
|
| 385 |
|
| 386 |
+
tag_names = ["axolotl", "relora"]
|
| 387 |
+
|
| 388 |
def __init__(self, *args, **kwargs):
|
| 389 |
super().__init__(*args, **kwargs)
|
| 390 |
self.lr_scheduler = None
|