jbilcke-hf's picture
jbilcke-hf HF Staff
we are going to hack into finetrainers
9fd1204
raw
history blame
2.31 kB
import os
from typing import List, Union
import numpy as np
import wandb
from diffusers.utils import export_to_video
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from PIL import Image
def save_model_card(
    args,
    repo_id: str,
    videos: Union[List[str], List[Image.Image], List[np.ndarray], None],
    validation_prompts: List[str],
    fps: int = 30,
) -> None:
    """Write a Hugging Face model card (``README.md``) for a LoRA finetune.

    Any validation video that is not already a string is exported to
    ``<args.output_dir>/final_video_{i}.mp4`` and referenced from the card's
    widget gallery by that relative file name; string entries are used as the
    widget URL verbatim.

    Args:
        args: Training arguments; ``output_dir`` and
            ``pretrained_model_name_or_path`` are read from it.
        repo_id: Repo id (or local path) handed to ``load_or_create_model_card``
            and used in the "Download LoRA" link.
        videos: Validation videos, paired by position with
            ``validation_prompts``. Each entry is either a string or whatever
            ``diffusers.utils.export_to_video`` accepts as frames. ``None`` or
            an empty sequence produces no gallery entries.
        validation_prompts: Caption for each video. ``None`` is tolerated and
            treated as blank captions.
        fps: Frame rate used when exporting non-string videos.
    """
    output_dir = str(args.output_dir)

    widget_dict = []
    # Explicit None/len check (not plain truthiness) so an ndarray container
    # does not raise "truth value of an array is ambiguous".
    if videos is not None and len(videos) > 0:
        # A missing prompt list yields blank captions instead of a TypeError in zip().
        prompts = validation_prompts if validation_prompts is not None else [None] * len(videos)
        for i, (video, validation_prompt) in enumerate(zip(videos, prompts)):
            if isinstance(video, str):
                video_url = video
            else:
                video_url = f"final_video_{i}.mp4"
                export_to_video(video, os.path.join(output_dir, video_url), fps=fps)
            widget_dict.append(
                {
                    # Fall back to a single space when the prompt is empty/None.
                    "text": validation_prompt if validation_prompt else " ",
                    "output": {"url": video_url},
                }
            )

    # Blank lines between sections are required: in CommonMark an HTML block such
    # as `<Gallery />` only ends at a blank line, so without one the following
    # heading and text would be swallowed into the raw-HTML block.
    model_description = f"""
# LoRA Finetune

<Gallery />

## Model description

This is a lora finetune of model: `{args.pretrained_model_name_or_path}`.

The model was trained using [`finetrainers`](https://github.com/a-r-r-o-w/finetrainers).

## Download model

[Download LoRA]({repo_id}/tree/main) in the Files & Versions tab.

## Usage

Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed.

```py
TODO
```

For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers.
"""
    # wandb.run is None when no run was initialised (e.g. training without wandb
    # reporting); getattr(None, "url", None) returns None instead of raising.
    run_url = getattr(wandb.run, "url", None)
    if run_url:
        model_description += f"""
Find out the wandb run URL and training configurations [here]({run_url}).
"""

    model_card = load_or_create_model_card(
        repo_id_or_path=repo_id,
        from_training=True,
        base_model=args.pretrained_model_name_or_path,
        model_description=model_description,
        widget=widget_dict,
    )
    tags = [
        "text-to-video",
        "diffusers-training",
        "diffusers",
        "lora",
        "template:sd-lora",
    ]

    model_card = populate_model_card(model_card, tags=tags)
    model_card.save(os.path.join(output_dir, "README.md"))