Spaces:
Running
Running
flash
Browse files
- app.py +4 -1
- midi_model.py +1 -4
app.py
CHANGED
|
@@ -223,13 +223,16 @@ if __name__ == "__main__":
|
|
| 223 |
"touhou finetune model": ["skytnt/midi-model-ft", "touhou/"],
|
| 224 |
}
|
| 225 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
|
|
|
|
| 226 |
models = {}
|
| 227 |
tokenizer = MIDITokenizer()
|
| 228 |
for name, (repo_id, path) in models_info.items():
|
| 229 |
|
| 230 |
model_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}model.ckpt")
|
| 231 |
model = MIDIModel(tokenizer).to(device=device)
|
| 232 |
-
ckpt = torch.load(model_path)
|
| 233 |
state_dict = ckpt.get("state_dict", ckpt)
|
| 234 |
model.load_state_dict(state_dict, strict=False)
|
| 235 |
model.eval()
|
|
|
|
| 223 |
"touhou finetune model": ["skytnt/midi-model-ft", "touhou/"],
|
| 224 |
}
|
| 225 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 226 |
+
if device=="cuda": # flash attn
|
| 227 |
+
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
| 228 |
+
torch.backends.cuda.enable_flash_sdp(True)
|
| 229 |
models = {}
|
| 230 |
tokenizer = MIDITokenizer()
|
| 231 |
for name, (repo_id, path) in models_info.items():
|
| 232 |
|
| 233 |
model_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}model.ckpt")
|
| 234 |
model = MIDIModel(tokenizer).to(device=device)
|
| 235 |
+
ckpt = torch.load(model_path, weights_only=True)
|
| 236 |
state_dict = ckpt.get("state_dict", ckpt)
|
| 237 |
model.load_state_dict(state_dict, strict=False)
|
| 238 |
model.eval()
|
midi_model.py
CHANGED
|
@@ -9,7 +9,7 @@ from midi_tokenizer import MIDITokenizer
|
|
| 9 |
|
| 10 |
|
| 11 |
class MIDIModel(nn.Module):
|
| 12 |
-
def __init__(self, tokenizer: MIDITokenizer, n_layer=12, n_head=16, n_embd=1024, n_inner=4096,
|
| 13 |
*args, **kwargs):
|
| 14 |
super(MIDIModel, self).__init__()
|
| 15 |
self.tokenizer = tokenizer
|
|
@@ -21,9 +21,6 @@ class MIDIModel(nn.Module):
|
|
| 21 |
hidden_size=n_embd, num_attention_heads=n_head // 4,
|
| 22 |
num_hidden_layers=n_layer // 4, intermediate_size=n_inner // 4,
|
| 23 |
pad_token_id=tokenizer.pad_id, max_position_embeddings=4096))
|
| 24 |
-
if flash:
|
| 25 |
-
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
| 26 |
-
torch.backends.cuda.enable_flash_sdp(True)
|
| 27 |
self.lm_head = nn.Linear(n_embd, tokenizer.vocab_size, bias=False)
|
| 28 |
self.device = "cpu"
|
| 29 |
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
class MIDIModel(nn.Module):
|
| 12 |
+
def __init__(self, tokenizer: MIDITokenizer, n_layer=12, n_head=16, n_embd=1024, n_inner=4096,
|
| 13 |
*args, **kwargs):
|
| 14 |
super(MIDIModel, self).__init__()
|
| 15 |
self.tokenizer = tokenizer
|
|
|
|
| 21 |
hidden_size=n_embd, num_attention_heads=n_head // 4,
|
| 22 |
num_hidden_layers=n_layer // 4, intermediate_size=n_inner // 4,
|
| 23 |
pad_token_id=tokenizer.pad_id, max_position_embeddings=4096))
|
|
|
|
|
|
|
|
|
|
| 24 |
self.lm_head = nn.Linear(n_embd, tokenizer.vocab_size, bias=False)
|
| 25 |
self.device = "cpu"
|
| 26 |
|