tts_labeling / utils /gdrive_downloader.py
Navid Arabi
add audio player
7a295c7
raw
history blame
3.48 kB
# gdrive_downloader.py
from __future__ import annotations
import io
import re
import numpy as np
from pydub import AudioSegment
from googleapiclient.discovery import build
from googleapiclient.http import MediaIoBaseDownload
def extract_folder_id(url_or_id: str) -> str:
"""
اگر کاربر لینک فولدر بدهد ← ID را برمی‌گرداند.
اگر خودش ID باشد همان را برمی‌گرداند.
"""
s = url_or_id.strip()
if "/" not in s and "?" not in s:
return s # احتمالاً خودش ID است
m = re.search(r"/folders/([a-zA-Z0-9_-]{10,})", s)
if not m:
raise ValueError("Cannot extract folder id from url")
return m.group(1)
class PublicFolderAudioLoader:
"""
دانلودر فایل صوتی از فولدر عمومی گوگل‌درایو بدون ذخیره روی دیسک.
Parameters
----------
api_key : str
Google API Key (کیِ عمومی؛ نه OAuth, نه سرویس‌اکانت).
"""
def __init__(self, api_key: str) -> None:
self.svc = build("drive", "v3", developerKey=api_key, cache_discovery=False)
# ---------- helpers ---------- #
def _file_id_by_name(self, folder_id: str, filename: str) -> str:
q = (
f"'{folder_id}' in parents "
f"and name = '{filename}' "
f"and trashed = false"
)
rsp = (
self.svc.files()
.list(q=q, fields="files(id,name)", pageSize=5, supportsAllDrives=True)
.execute()
)
files = rsp.get("files", [])
if not files:
raise FileNotFoundError(f"'{filename}' not found in folder {folder_id}")
return files[0]["id"]
def _download_to_buf(self, file_id: str) -> io.BytesIO:
request = self.svc.files().get_media(fileId=file_id, supportsAllDrives=True)
buf = io.BytesIO()
downloader = MediaIoBaseDownload(buf, request)
done = False
while not done:
_, done = downloader.next_chunk()
buf.seek(0)
return buf
# ---------- public ---------- #
def load_audio(
self,
folder_url_or_id: str,
filename: str,
) -> tuple[int, np.ndarray]:
# """
# فایل را به `(sample_rate, np.ndarray)` نرمال‌شده در بازه‌ی [-1,1] تبدیل می‌کند.
# """
folder_id = extract_folder_id(folder_url_or_id)
file_id = self._file_id_by_name(folder_id, filename)
buf = self._download_to_buf(file_id)
seg = AudioSegment.from_file(buf)
samples = np.array(seg.get_array_of_samples())
# اگر چندکاناله بود، شکل دهیم
if seg.channels > 1:
samples = samples.reshape(-1, seg.channels)
# ---------------------- نرمال‌سازی ----------------------
if np.issubdtype(samples.dtype, np.integer):
max_int = np.iinfo(samples.dtype).max # ← قبل از cast
samples = samples.astype(np.float32)
samples /= max_int # ← از max_int استفاده می‌کنیم
else:
# در حالت float
max_val = np.abs(samples).max()
if max_val > 1:
samples = samples / max_val
samples = samples.astype(np.float32)
# --------------------------------------------------------
return seg.frame_rate, samples