Spaces:
Runtime error
Runtime error
fix temp again (#6)
Browse files- fix temp again (8ad07e8b38abaea9ade6a954e7c6d6122a90b43c)
- scripts/exp/fine_tune.py +1 -0
- vampnet/interface.py +1 -1
- vampnet/modules/transformer.py +2 -2
scripts/exp/fine_tune.py
CHANGED
|
@@ -53,6 +53,7 @@ def fine_tune(audio_files_or_folders: List[str], name: str):
|
|
| 53 |
|
| 54 |
"Interface.coarse2fine_ckpt": f"./models/vampnet/c2f.pth",
|
| 55 |
"Interface.coarse2fine_lora_ckpt": f"./runs/{name}/c2f/latest/lora.pth",
|
|
|
|
| 56 |
|
| 57 |
"Interface.codec_ckpt": "./models/vampnet/codec.pth",
|
| 58 |
"AudioLoader.sources": [audio_files_or_folders],
|
|
|
|
| 53 |
|
| 54 |
"Interface.coarse2fine_ckpt": f"./models/vampnet/c2f.pth",
|
| 55 |
"Interface.coarse2fine_lora_ckpt": f"./runs/{name}/c2f/latest/lora.pth",
|
| 56 |
+
"Interface.wavebeat_ckpt": "./models/wavebeat.pth",
|
| 57 |
|
| 58 |
"Interface.codec_ckpt": "./models/vampnet/codec.pth",
|
| 59 |
"AudioLoader.sources": [audio_files_or_folders],
|
vampnet/interface.py
CHANGED
|
@@ -65,7 +65,7 @@ class Interface(torch.nn.Module):
|
|
| 65 |
):
|
| 66 |
super().__init__()
|
| 67 |
assert codec_ckpt is not None, "must provide a codec checkpoint"
|
| 68 |
-
self.codec = DAC.load(codec_ckpt)
|
| 69 |
self.codec.eval()
|
| 70 |
self.codec.to(device)
|
| 71 |
|
|
|
|
| 65 |
):
|
| 66 |
super().__init__()
|
| 67 |
assert codec_ckpt is not None, "must provide a codec checkpoint"
|
| 68 |
+
self.codec = DAC.load(Path(codec_ckpt))
|
| 69 |
self.codec.eval()
|
| 70 |
self.codec.to(device)
|
| 71 |
|
vampnet/modules/transformer.py
CHANGED
|
@@ -581,7 +581,7 @@ class VampNet(at.ml.BaseModel):
|
|
| 581 |
sampling_steps: int = 24,
|
| 582 |
start_tokens: Optional[torch.Tensor] = None,
|
| 583 |
mask: Optional[torch.Tensor] = None,
|
| 584 |
-
temperature:
|
| 585 |
typical_filtering=False,
|
| 586 |
typical_mass=0.2,
|
| 587 |
typical_min_tokens=1,
|
|
@@ -592,7 +592,7 @@ class VampNet(at.ml.BaseModel):
|
|
| 592 |
#####################
|
| 593 |
# resolve temperature #
|
| 594 |
#####################
|
| 595 |
-
|
| 596 |
logging.debug(f"temperature: {temperature}")
|
| 597 |
|
| 598 |
|
|
|
|
| 581 |
sampling_steps: int = 24,
|
| 582 |
start_tokens: Optional[torch.Tensor] = None,
|
| 583 |
mask: Optional[torch.Tensor] = None,
|
| 584 |
+
temperature: float = 2.5,
|
| 585 |
typical_filtering=False,
|
| 586 |
typical_mass=0.2,
|
| 587 |
typical_min_tokens=1,
|
|
|
|
| 592 |
#####################
|
| 593 |
# resolve temperature #
|
| 594 |
#####################
|
| 595 |
+
|
| 596 |
logging.debug(f"temperature: {temperature}")
|
| 597 |
|
| 598 |
|