mrfakename commited on
Commit
0b36a0d
·
verified ·
1 Parent(s): a6bb81f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -6
app.py CHANGED
@@ -13,12 +13,31 @@ import spaces
13
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
14
 
15
  checkpoints_path = snapshot_download("mrfakename/openf5-v2", allow_patterns=["model_*.pt", "vocab.txt"], token=os.getenv("HF_TOKEN"))
16
-
17
  models = {}
18
- for checkpoint_path in tqdm(os.listdir(checkpoints_path), desc="Loading models"):
19
- if checkpoint_path.endswith(".pt"):
20
- model_name = checkpoint_path.split("/")[-1].replace(".pt", "")
21
- models[model_name] = F5TTS(ckpt_file=os.path.join(checkpoints_path, checkpoint_path), vocab_file=os.path.join(checkpoints_path, "vocab.txt"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  @spaces.GPU
24
  def generate_audio(model_name, ref_file, ref_text, gen_text, progress=gr.Progress()):
@@ -36,7 +55,7 @@ def generate_audio(model_name, ref_file, ref_text, gen_text, progress=gr.Progres
36
 
37
  with gr.Blocks() as demo:
38
  gr.Markdown(ABOUT)
39
- model_name = gr.Dropdown(label="Model", choices=list(models.keys()))
40
  ref_file = gr.Audio(label="Reference Audio", type="filepath")
41
  gen_text = gr.Textbox(label="Text")
42
  btn_generate = gr.Button("Generate Audio", variant="primary")
 
13
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
14
 
15
  checkpoints_path = snapshot_download("mrfakename/openf5-v2", allow_patterns=["model_*.pt", "vocab.txt"], token=os.getenv("HF_TOKEN"))
 
16
  models = {}
17
+ checkpoint_files = [f for f in os.listdir(checkpoints_path) if f.endswith(".pt")]
18
+
19
+ # Sort checkpoint files by step number
20
+ def get_step_number(filename):
21
+ name = filename.replace(".pt", "")
22
+ if name == "model_last":
23
+ return float('-inf') # Ensure model_last comes first
24
+ try:
25
+ return int(name.split("_")[1])
26
+ except (IndexError, ValueError):
27
+ return float('inf') # Put non-standard names at the end
28
+
29
+ sorted_checkpoints = sorted(checkpoint_files, key=get_step_number)
30
+
31
+ # Load models in the sorted order
32
+ for checkpoint_path in tqdm(sorted_checkpoints, desc="Loading models"):
33
+ model_name = checkpoint_path.replace(".pt", "")
34
+ # Load one model at a time to be memory efficient
35
+ models[model_name] = F5TTS(ckpt_file=os.path.join(checkpoints_path, checkpoint_path),
36
+ vocab_file=os.path.join(checkpoints_path, "vocab.txt"))
37
+
38
+
39
+
40
+
41
 
42
  @spaces.GPU
43
  def generate_audio(model_name, ref_file, ref_text, gen_text, progress=gr.Progress()):
 
55
 
56
  with gr.Blocks() as demo:
57
  gr.Markdown(ABOUT)
58
+ model_name = gr.Radio(label="Model", choices=list(models.keys()))
59
  ref_file = gr.Audio(label="Reference Audio", type="filepath")
60
  gen_text = gr.Textbox(label="Text")
61
  btn_generate = gr.Button("Generate Audio", variant="primary")