try onnx again
- README.md +1 -1
- app_onnx.py +6 -9
README.md
@@ -5,7 +5,7 @@ colorFrom: red
 colorTo: indigo
 sdk: gradio
 sdk_version: 4.43.0
-app_file:
+app_file: app_onnx.py
 pinned: true
 license: apache-2.0
 ---
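The README change repoints the Space's `app_file` front-matter field, which tells Hugging Face Spaces which script to launch, so `app_onnx.py` becomes the entry point.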
app_onnx.py
@@ -170,9 +170,10 @@ def run(model_name, tab, mid_seq, continuation_state, continuation_select, instr
         key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels,
         seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
     model = models[model_name]
-    model_base = model[0]
-    model_token = model[1]
+    model_base = rt.InferenceSession(model[0], providers=providers)
+    model_token = rt.InferenceSession(model[1], providers=providers)
     tokenizer = model[2]
+    model = [model_base, model_token, tokenizer]
     bpm = int(bpm)
     if time_sig == "auto":
         time_sig = None
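The substance of this hunk is deferred session construction: `run()` now gets ONNX file paths in `model[0]` and `model[1]` and builds the `onnxruntime` sessions itself, instead of pulling ready-made sessions out of `models`. A minimal sketch of the pattern, assuming `onnxruntime` is installed (`load_sessions` is an illustrative helper name, not a function from the Space):

```python
import onnxruntime as rt

# Prefer CUDA; onnxruntime skips any provider the build or machine
# cannot supply and falls through to the next entry in the list.
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']

def load_sessions(base_path: str, token_path: str):
    # Building sessions at call time rather than at import time means
    # the CUDA context is created in the process serving the request.
    model_base = rt.InferenceSession(base_path, providers=providers)
    model_token = rt.InferenceSession(token_path, providers=providers)
    return model_base, model_token
```

The cost is that every call re-loads and re-optimizes the graph, so this trades per-request startup work for GPU initialization happening only where it is actually used.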
@@ -426,22 +427,18 @@ if __name__ == "__main__":
         ]
     }
     models = {}
-    providers = ['CPUExecutionProvider']
+    providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
 
     for name, (repo_id, path, config, loras) in models_info.items():
         model_base_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_base.onnx")
         model_token_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_token.onnx")
-        model_base = rt.InferenceSession(model_base_path, providers=providers)
-        model_token = rt.InferenceSession(model_token_path, providers=providers)
         tokenizer = get_tokenizer(config)
-        models[name] = [model_base, model_token, tokenizer]
+        models[name] = [model_base_path, model_token_path, tokenizer]
         for lora_name, lora_repo in loras.items():
             model_base_path = hf_hub_download_retry(repo_id=lora_repo, filename=f"onnx/model_base.onnx")
             model_token_path = hf_hub_download_retry(repo_id=lora_repo, filename=f"onnx/model_token.onnx")
-            model_base = rt.InferenceSession(model_base_path, providers=providers)
-            model_token = rt.InferenceSession(model_token_path, providers=providers)
             tokenizer = get_tokenizer(config)
-            models[f"{name} with {lora_name} lora"] = [model_base, model_token, tokenizer]
+            models[f"{name} with {lora_name} lora"] = [model_base_path, model_token_path, tokenizer]
 
     load_javascript()
     app = gr.Blocks()
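On the loading side, `models` now stores `[model_base_path, model_token_path, tokenizer]` triples instead of live sessions, and the provider list flips from CPU-only to CUDA-first. Whether CUDA can actually be used depends on the installed onnxruntime build; a quick way to check (`get_available_providers` is a standard onnxruntime call, not something this commit adds):

```python
import onnxruntime as rt

# Execution providers compiled into this build, in priority order.
# onnxruntime-gpu typically includes 'CUDAExecutionProvider'; the plain
# CPU wheel reports only ['CPUExecutionProvider'].
print(rt.get_available_providers())
```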
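Net effect: session creation moves from startup into `run()` with CUDA preferred, so no CUDA context exists until a request actually runs; the price is that each request rebuilds its `InferenceSession` and pays the model-load cost again.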