unknown commited on
Commit
96946f8
·
1 Parent(s): c4d7252

import transcribe from utils_infer

Browse files
Files changed (1) hide show
  1. src/f5_tts/train/finetune_gradio.py +5 -29
src/f5_tts/train/finetune_gradio.py CHANGED
@@ -26,12 +26,13 @@ from datasets import Dataset as Dataset_
26
  from datasets.arrow_writer import ArrowWriter
27
  from safetensors.torch import save_file
28
  from scipy.io import wavfile
29
- from transformers import pipeline
30
  from cached_path import cached_path
31
  from f5_tts.api import F5TTS
32
  from f5_tts.model.utils import convert_char_to_pinyin
 
33
  from importlib.resources import files
34
 
 
35
  training_process = None
36
  system = platform.system()
37
  python_executable = sys.executable or "python"
@@ -47,8 +48,6 @@ file_train = "src/f5_tts/train/finetune_cli.py"
47
 
48
  device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
49
 
50
- pipe = None
51
-
52
 
53
  # Save settings from a JSON file
54
  def save_settings(
@@ -390,17 +389,15 @@ def start_training(
390
  logger="wandb",
391
  ch_8bit_adam=False,
392
  ):
393
- global training_process, tts_api, stop_signal, pipe
394
 
395
- if tts_api is not None or pipe is not None:
396
  if tts_api is not None:
397
  del tts_api
398
- if pipe is not None:
399
- del pipe
400
  gc.collect()
401
  torch.cuda.empty_cache()
402
  tts_api = None
403
- pipe = None
404
 
405
  path_project = os.path.join(path_data, dataset_name)
406
 
@@ -652,27 +649,6 @@ def create_data_project(name, tokenizer_type):
652
  return gr.update(choices=project_list, value=name)
653
 
654
 
655
- def transcribe(file_audio, language="english"):
656
- global pipe
657
-
658
- if pipe is None:
659
- pipe = pipeline(
660
- "automatic-speech-recognition",
661
- model="openai/whisper-large-v3-turbo",
662
- torch_dtype=torch.float16,
663
- device=device,
664
- )
665
-
666
- text_transcribe = pipe(
667
- file_audio,
668
- chunk_length_s=30,
669
- batch_size=128,
670
- generate_kwargs={"task": "transcribe", "language": language},
671
- return_timestamps=False,
672
- )["text"].strip()
673
- return text_transcribe
674
-
675
-
676
  def transcribe_all(name_project, audio_files, language, user=False, progress=gr.Progress()):
677
  path_project = os.path.join(path_data, name_project)
678
  path_dataset = os.path.join(path_project, "dataset")
 
26
  from datasets.arrow_writer import ArrowWriter
27
  from safetensors.torch import save_file
28
  from scipy.io import wavfile
 
29
  from cached_path import cached_path
30
  from f5_tts.api import F5TTS
31
  from f5_tts.model.utils import convert_char_to_pinyin
32
+ from f5_tts.infer.utils_infer import transcribe
33
  from importlib.resources import files
34
 
35
+
36
  training_process = None
37
  system = platform.system()
38
  python_executable = sys.executable or "python"
 
48
 
49
  device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
50
 
 
 
51
 
52
  # Save settings from a JSON file
53
  def save_settings(
 
389
  logger="wandb",
390
  ch_8bit_adam=False,
391
  ):
392
+ global training_process, tts_api, stop_signal
393
 
394
+ if tts_api is not None:
395
  if tts_api is not None:
396
  del tts_api
397
+
 
398
  gc.collect()
399
  torch.cuda.empty_cache()
400
  tts_api = None
 
401
 
402
  path_project = os.path.join(path_data, dataset_name)
403
 
 
649
  return gr.update(choices=project_list, value=name)
650
 
651
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
652
  def transcribe_all(name_project, audio_files, language, user=False, progress=gr.Progress()):
653
  path_project = os.path.join(path_data, name_project)
654
  path_dataset = os.path.join(path_project, "dataset")