Spaces:
Runtime error
Runtime error
Commit
·
08d9656
1
Parent(s):
e95cc03
fix app.py
Browse filescreate upload tool
- app.py +1 -2
- train.py +2 -1
- upload_to_HF.py +56 -0
app.py
CHANGED
|
@@ -4,7 +4,7 @@ import gradio as gr
|
|
| 4 |
from pathlib import Path
|
| 5 |
from denoisers.SpectralGating import SpectralGating
|
| 6 |
|
| 7 |
-
|
| 8 |
|
| 9 |
def denoising_transform(audio):
|
| 10 |
src_path = Path(__file__).parent.resolve() / Path("cache_wav/original/{}.wav".format(str(uuid.uuid4())))
|
|
@@ -32,6 +32,5 @@ demo = gr.Interface(
|
|
| 32 |
)
|
| 33 |
|
| 34 |
if __name__ == "__main__":
|
| 35 |
-
model = SpectralGating()
|
| 36 |
demo.launch()
|
| 37 |
|
|
|
|
| 4 |
from pathlib import Path
|
| 5 |
from denoisers.SpectralGating import SpectralGating
|
| 6 |
|
| 7 |
+
model = SpectralGating()
|
| 8 |
|
| 9 |
def denoising_transform(audio):
|
| 10 |
src_path = Path(__file__).parent.resolve() / Path("cache_wav/original/{}.wav".format(str(uuid.uuid4())))
|
|
|
|
| 32 |
)
|
| 33 |
|
| 34 |
if __name__ == "__main__":
|
|
|
|
| 35 |
demo.launch()
|
| 36 |
|
train.py
CHANGED
|
@@ -34,7 +34,8 @@ def init_wandb(cfg):
|
|
| 34 |
def train(cfg: DictConfig):
|
| 35 |
device = torch.device(f'cuda:{cfg.gpu}' if torch.cuda.is_available() else 'cpu')
|
| 36 |
init_wandb(cfg)
|
| 37 |
-
checkpoint_saver = CheckpointSaver(dirpath=cfg['training']['model_save_path'], run_name=wandb.run.name
|
|
|
|
| 38 |
metrics = Metrics(source_rate=cfg['dataloader']['sample_rate']).to(device)
|
| 39 |
|
| 40 |
model = get_model(cfg['model']).to(device)
|
|
|
|
| 34 |
def train(cfg: DictConfig):
|
| 35 |
device = torch.device(f'cuda:{cfg.gpu}' if torch.cuda.is_available() else 'cpu')
|
| 36 |
init_wandb(cfg)
|
| 37 |
+
checkpoint_saver = CheckpointSaver(dirpath=cfg['training']['model_save_path'], run_name=wandb.run.name,
|
| 38 |
+
decreasing=False)
|
| 39 |
metrics = Metrics(source_rate=cfg['dataloader']['sample_rate']).to(device)
|
| 40 |
|
| 41 |
model = get_model(cfg['model']).to(device)
|
upload_to_HF.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import wandb
|
| 3 |
+
from huggingface_hub import HfApi
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import huggingface_hub
|
| 6 |
+
import ssl
|
| 7 |
+
import os
|
| 8 |
+
os.environ['CURL_CA_BUNDLE'] = ''
|
| 9 |
+
|
| 10 |
+
ssl._create_default_https_context = ssl._create_unverified_context
|
| 11 |
+
|
| 12 |
+
class Uploader:
|
| 13 |
+
def __init__(self, entity, project, run_name, repo_id, username):
|
| 14 |
+
self.entity = entity
|
| 15 |
+
self.project = project
|
| 16 |
+
self.run_name = run_name
|
| 17 |
+
self.hf_api = HfApi()
|
| 18 |
+
self.wandb_api = wandb.Api()
|
| 19 |
+
self.repo_id = repo_id
|
| 20 |
+
self.username = username
|
| 21 |
+
huggingface_hub.login(os.environ.get('HUGGINGFACE_TOKEN'))
|
| 22 |
+
|
| 23 |
+
def get_model_from_wandb_run(self):
|
| 24 |
+
runs = self.wandb_api.runs(f"{self.entity}/{self.project}",
|
| 25 |
+
# order='+summary_metrics.train_pesq'
|
| 26 |
+
)
|
| 27 |
+
run = [run for run in runs if run.name == self.run_name][0]
|
| 28 |
+
artifacts = run.logged_artifacts()
|
| 29 |
+
best_model = [artifact for artifact in artifacts if artifact.type == 'model'][0]
|
| 30 |
+
artifact_dir = best_model.download()
|
| 31 |
+
model_path = list(Path(artifact_dir).glob("*.pt"))[0].absolute().as_posix()
|
| 32 |
+
print(f"Model validation score = {best_model.metadata['Validation score']}")
|
| 33 |
+
return model_path
|
| 34 |
+
|
| 35 |
+
def upload_to_HF(self):
|
| 36 |
+
model_path = self.get_model_from_wandb_run()
|
| 37 |
+
self.hf_api.upload_file(
|
| 38 |
+
path_or_fileobj=model_path,
|
| 39 |
+
path_in_repo=Path(model_path).name,
|
| 40 |
+
repo_id=f'{self.username}/{self.repo_id}',
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
def create_repo(self):
|
| 44 |
+
self.hf_api.create_repo(repo_id=self.repo_id, exist_ok=True)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
if __name__ == '__main__':
|
| 49 |
+
uploader = Uploader(entity='borisovmaksim',
|
| 50 |
+
project='denoising',
|
| 51 |
+
run_name='wav_normalization',
|
| 52 |
+
repo_id='demucs',
|
| 53 |
+
username='BorisovMaksim')
|
| 54 |
+
uploader.create_repo()
|
| 55 |
+
uploader.upload_to_HF()
|
| 56 |
+
|