# Base image: PyTorch 2.1.0 runtime build with CUDA 11.8 and cuDNN 8,
# as encoded in the image tag.
FROM pytorch/pytorch:2.1.0-cuda11.8-cudnn8-runtime

# All relative paths in the following instructions (and at container start)
# resolve against /app.
WORKDIR /app

# System packages needed on top of the base image.
#   build-essential / python3-dev: compiler toolchain and Python headers
#       (presumably for pip packages that build native extensions - confirm)
#   ca-certificates: listed explicitly so git over HTTPS keeps working now
#       that recommended packages are no longer pulled in implicitly
#   git: version control client (presumably for VCS-based pip requirements)
#   libsndfile1: shared library for reading/writing audio files
# update + install + list cleanup stay in one layer so the apt index is never
# stale and never persisted in the image.
RUN apt-get update \
    && apt-get install -y --no-install-recommends \
        build-essential \
        ca-certificates \
        git \
        libsndfile1 \
        python3-dev \
    && rm -rf /var/lib/apt/lists/*

# Copy only the dependency manifest first so the (slow) dependency layer is
# reused from cache until requirements.txt itself changes.
COPY requirements.txt .

# --no-cache-dir keeps pip's download cache out of the image layer.
RUN pip install --no-cache-dir -r requirements.txt

# Copy the whole build context into /app. Placed after the dependency install
# so source edits do not invalidate the pip layer.
# NOTE(review): make sure a .dockerignore excludes datasets, checkpoints, .git
# and any credentials from the build context.
COPY . .

# Pre-create /app/checkpoints/hf_cache (used as HF_HUB_CACHE) and /app/runs
# (presumably the training output directory - confirm against train.py).
RUN mkdir -p checkpoints/hf_cache runs

# Runtime environment:
#   PYTHONPATH   - puts /app on Python's module search path
#   HF_HUB_CACHE - directs the Hugging Face Hub download cache to the
#                  checkpoints/hf_cache directory under /app
ENV PYTHONPATH=/app \
    HF_HUB_CACHE=/app/checkpoints/hf_cache

# Generate /app/entrypoint.sh, which launches training with the default flags.
#
# printf is used instead of echo: whether echo expands "\n" depends on which
# shell provides /bin/sh, whereas printf's format handling is consistent.
# The Dockerfile parser joins the continued lines, so the quoted command below
# ends up as a single line inside the script.
#
# In the generated script:
#   exec  - replaces the bash wrapper with the python process so it receives
#           signals (e.g. SIGTERM from `docker stop`) directly
#   "$@"  - appends any arguments given after the image name on `docker run`
#           to the train.py command line; with no extra arguments the command
#           is the same as before
RUN printf '%s\n' \
        '#!/bin/bash' \
        'exec python train.py \
            --config ./configs/presets/config_dit_mel_seed_uvit_whisper_small_wavenet.yml \
            --dataset-dir dataset \
            --run-name training-run \
            --batch-size 2 \
            --max-steps 300 \
            --max-epochs 1000 \
            --save-every 100 \
            --num-workers 0 \
            "$@"' \
        > entrypoint.sh \
    && chmod +x entrypoint.sh

# Exec-form entrypoint. The script is written into WORKDIR (/app) during the
# build; the absolute path keeps it resolvable even if the working directory
# is overridden at `docker run` time.
ENTRYPOINT ["/app/entrypoint.sh"]