Update app.py
Browse files
app.py
CHANGED
|
@@ -12,15 +12,6 @@ from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
|
|
| 12 |
from transformers.pipelines.audio_utils import ffmpeg_read
|
| 13 |
from whisper_jax import FlaxWhisperPipline
|
| 14 |
|
| 15 |
-
cc.initialize_cache("./jax_cache")
|
| 16 |
-
checkpoint = "openai/whisper-large-v3"
|
| 17 |
-
|
| 18 |
-
BATCH_SIZE = 32
|
| 19 |
-
CHUNK_LENGTH_S = 30
|
| 20 |
-
NUM_PROC = 32
|
| 21 |
-
FILE_LIMIT_MB = 10000
|
| 22 |
-
YT_LENGTH_LIMIT_S = 15000 # 15000 s is about 4 h 10 min of YouTube audio (comment used to say 2 hours; NOTE(review): confirm the intended limit)
|
| 23 |
-
|
| 24 |
app = FastAPI(title="Whisper JAX: The Fastest Whisper API ⚡️")
|
| 25 |
|
| 26 |
logger = logging.getLogger("whisper-jax-app")
|
|
@@ -31,6 +22,14 @@ formatter = logging.Formatter("%(asctime)s;%(levelname)s;%(message)s", "%Y-%m-%d
|
|
| 31 |
ch.setFormatter(formatter)
|
| 32 |
logger.addHandler(ch)
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
pipeline = FlaxWhisperPipline(checkpoint, dtype=jnp.bfloat16, batch_size=BATCH_SIZE)
|
| 35 |
stride_length_s = CHUNK_LENGTH_S / 6
|
| 36 |
chunk_len = round(CHUNK_LENGTH_S * pipeline.feature_extractor.sampling_rate)
|
|
@@ -149,4 +148,23 @@ def download_yt_audio(yt_url, filename):
|
|
| 149 |
try:
|
| 150 |
ydl.download([yt_url])
|
| 151 |
except youtube_dl.utils.ExtractorError as err:
|
| 152 |
-
raise HTTPException(status_code=400, detail=str(err))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
from transformers.pipelines.audio_utils import ffmpeg_read
|
| 13 |
from whisper_jax import FlaxWhisperPipline
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
app = FastAPI(title="Whisper JAX: The Fastest Whisper API ⚡️")
|
| 16 |
|
| 17 |
logger = logging.getLogger("whisper-jax-app")
|
|
|
|
| 22 |
ch.setFormatter(formatter)
|
| 23 |
logger.addHandler(ch)
|
| 24 |
|
| 25 |
+
checkpoint = "openai/whisper-large-v3"
|
| 26 |
+
|
| 27 |
+
BATCH_SIZE = 32
|
| 28 |
+
CHUNK_LENGTH_S = 30
|
| 29 |
+
NUM_PROC = 32
|
| 30 |
+
FILE_LIMIT_MB = 10000
|
| 31 |
+
YT_LENGTH_LIMIT_S = 15000 # 15000 s is about 4 h 10 min of YouTube audio (comment used to say 2 hours; NOTE(review): confirm the intended limit)
|
| 32 |
+
|
| 33 |
pipeline = FlaxWhisperPipline(checkpoint, dtype=jnp.bfloat16, batch_size=BATCH_SIZE)
|
| 34 |
stride_length_s = CHUNK_LENGTH_S / 6
|
| 35 |
chunk_len = round(CHUNK_LENGTH_S * pipeline.feature_extractor.sampling_rate)
|
|
|
|
| 148 |
try:
|
| 149 |
ydl.download([yt_url])
|
| 150 |
except youtube_dl.utils.ExtractorError as err:
|
| 151 |
+
raise HTTPException(status_code=400, detail=str(err))
|
| 152 |
+
|
| 153 |
+
def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = "."):
    """Format a time offset given in seconds as ``[HH:]MM:SS<marker>mmm``.

    Args:
        seconds: Offset in seconds. ``None`` is tolerated and treated as a
            malformed timestamp.
        always_include_hours: When true, emit the ``HH:`` field even if the
            offset is shorter than one hour.
        decimal_marker: Separator placed between the seconds and the
            milliseconds fields.

    Returns:
        The formatted timestamp string, or ``seconds`` unchanged (``None``)
        when no timestamp was supplied. Negative offsets are rendered as the
        formatted magnitude prefixed with ``-``.
    """
    if seconds is None:
        # we have a malformed timestamp so just return it as is
        return seconds

    total_ms = round(seconds * 1000.0)

    # Format the magnitude and carry the sign separately: floor division on a
    # negative value would borrow from the higher fields and produce a
    # misleading, positive-looking timestamp (e.g. -1.5 -> "59:58.500").
    sign = "-" if total_ms < 0 else ""
    total_ms = abs(total_ms)

    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    whole_seconds, milliseconds = divmod(remainder, 1_000)

    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return f"{sign}{hours_marker}{minutes:02d}:{whole_seconds:02d}{decimal_marker}{milliseconds:03d}"
|