bugfixes
Browse files
- scripts/finetune.py +3 -2
scripts/finetune.py
CHANGED
|
@@ -427,9 +427,10 @@ def train(
|
|
| 427 |
max_packed_sequence_len = min(max_packed_sequence_len, cfg.sequence_len) # make sure we don't accidentally set it larger than sequence_len
|
| 428 |
ds_hash = str(md5((str(max_packed_sequence_len) + "@" + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))).encode('utf-8')).hexdigest())
|
| 429 |
prepared_ds_path = Path(cfg.dataset_prepared_path) / ds_hash if cfg.dataset_prepared_path else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
|
|
|
|
| 430 |
if any(prepared_ds_path.glob("*")):
|
| 431 |
logging.info("Loading prepared dataset from disk...")
|
| 432 |
-
dataset = load_from_disk(…)
|
| 433 |
logging.info("Prepared dataset loaded from disk...")
|
| 434 |
else:
|
| 435 |
logging.info("Loading raw datasets...")
|
|
@@ -437,7 +438,7 @@ def train(
|
|
| 437 |
for d in cfg.datasets:
|
| 438 |
ds_from_hub = False
|
| 439 |
try:
|
| 440 |
-
|
| 441 |
ds_from_hub = True
|
| 442 |
except FileNotFoundError:
|
| 443 |
pass
|
|
|
|
| 427 |
max_packed_sequence_len = min(max_packed_sequence_len, cfg.sequence_len) # make sure we don't accidentally set it larger than sequence_len
|
| 428 |
ds_hash = str(md5((str(max_packed_sequence_len) + "@" + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))).encode('utf-8')).hexdigest())
|
| 429 |
prepared_ds_path = Path(cfg.dataset_prepared_path) / ds_hash if cfg.dataset_prepared_path else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
|
| 430 |
+
|
| 431 |
if any(prepared_ds_path.glob("*")):
|
| 432 |
logging.info("Loading prepared dataset from disk...")
|
| 433 |
+
dataset = load_from_disk(str(prepared_ds_path))
|
| 434 |
logging.info("Prepared dataset loaded from disk...")
|
| 435 |
else:
|
| 436 |
logging.info("Loading raw datasets...")
|
|
|
|
| 438 |
for d in cfg.datasets:
|
| 439 |
ds_from_hub = False
|
| 440 |
try:
|
| 441 |
+
load_dataset(d.path, streaming=True)
|
| 442 |
ds_from_hub = True
|
| 443 |
except FileNotFoundError:
|
| 444 |
pass
|