Sofia Casadei
committed on
Commit
·
437ed2e
1
Parent(s):
7338a56
fix: use hf-cloudflare turn server
Browse files
main.py
CHANGED
|
@@ -29,8 +29,6 @@ from transformers.utils import is_flash_attn_2_available
|
|
| 29 |
|
| 30 |
from utils.logger_config import setup_logging
|
| 31 |
from utils.device import get_device, get_torch_and_np_dtypes
|
| 32 |
-
from utils.turn_server import get_credential_function, get_rtc_credentials
|
| 33 |
-
|
| 34 |
|
| 35 |
load_dotenv()
|
| 36 |
setup_logging()
|
|
@@ -40,7 +38,6 @@ logger = logging.getLogger(__name__)
|
|
| 40 |
UI_MODE = os.getenv("UI_MODE", "fastapi").lower() # gradio | fastapi
|
| 41 |
UI_TYPE = os.getenv("UI_TYPE", "base").lower() # base | screen
|
| 42 |
APP_MODE = os.getenv("APP_MODE", "local").lower() # local | deployed
|
| 43 |
-
TURN_SERVER_PROVIDER = os.getenv("TURN_SERVER_PROVIDER", "hf-cloudflare").lower() # hf-cloudflare | cloudflare | hf | twilio
|
| 44 |
MODEL_ID = os.getenv("MODEL_ID", "openai/whisper-large-v3-turbo")
|
| 45 |
LANGUAGE = os.getenv("LANGUAGE", "english").lower()
|
| 46 |
|
|
@@ -48,7 +45,6 @@ device = get_device(force_cpu=False)
|
|
| 48 |
torch_dtype, np_dtype = get_torch_and_np_dtypes(device, use_bfloat16=False)
|
| 49 |
logger.info(f"Using device: {device}, torch_dtype: {torch_dtype}, np_dtype: {np_dtype}")
|
| 50 |
|
| 51 |
-
|
| 52 |
attention = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
|
| 53 |
logger.info(f"Using attention: {attention}")
|
| 54 |
|
|
@@ -87,7 +83,6 @@ warmup_audio = np.zeros((16000,), dtype=np_dtype) # 1s of silence
|
|
| 87 |
transcribe_pipeline(warmup_audio)
|
| 88 |
logger.info("Model warmup complete")
|
| 89 |
|
| 90 |
-
|
| 91 |
async def transcribe(audio: tuple[int, np.ndarray]):
|
| 92 |
sample_rate, audio_array = audio
|
| 93 |
logger.info(f"Sample rate: {sample_rate}Hz, Shape: {audio_array.shape}")
|
|
@@ -104,11 +99,6 @@ async def transcribe(audio: tuple[int, np.ndarray]):
|
|
| 104 |
)
|
| 105 |
yield AdditionalOutputs(outputs["text"].strip())
|
| 106 |
|
| 107 |
-
async def get_credentials():
|
| 108 |
-
return await get_cloudflare_turn_credentials_async(hf_token=os.getenv("HF_TOKEN"))
|
| 109 |
-
|
| 110 |
-
server_credentials = get_cloudflare_turn_credentials(ttl=360_000) if APP_MODE == "deployed" else None
|
| 111 |
-
|
| 112 |
logger.info("Initializing FastRTC stream")
|
| 113 |
stream = Stream(
|
| 114 |
handler=ReplyOnPause(
|
|
@@ -146,8 +136,7 @@ stream = Stream(
|
|
| 146 |
gr.Textbox(label="Transcript"),
|
| 147 |
],
|
| 148 |
additional_outputs_handler=lambda current, new: current + " " + new,
|
| 149 |
-
rtc_configuration=
|
| 150 |
-
server_rtc_configuration=server_credentials,
|
| 151 |
concurrency_limit=6
|
| 152 |
)
|
| 153 |
|
|
@@ -162,8 +151,9 @@ async def index():
|
|
| 162 |
elif UI_TYPE == "screen":
|
| 163 |
html_content = open("static/index-screen.html").read()
|
| 164 |
|
| 165 |
-
|
| 166 |
-
|
|
|
|
| 167 |
|
| 168 |
@app.get("/transcript")
|
| 169 |
def _(webrtc_id: str):
|
|
|
|
| 29 |
|
| 30 |
from utils.logger_config import setup_logging
|
| 31 |
from utils.device import get_device, get_torch_and_np_dtypes
|
|
|
|
|
|
|
| 32 |
|
| 33 |
load_dotenv()
|
| 34 |
setup_logging()
|
|
|
|
| 38 |
UI_MODE = os.getenv("UI_MODE", "fastapi").lower() # gradio | fastapi
|
| 39 |
UI_TYPE = os.getenv("UI_TYPE", "base").lower() # base | screen
|
| 40 |
APP_MODE = os.getenv("APP_MODE", "local").lower() # local | deployed
|
|
|
|
| 41 |
MODEL_ID = os.getenv("MODEL_ID", "openai/whisper-large-v3-turbo")
|
| 42 |
LANGUAGE = os.getenv("LANGUAGE", "english").lower()
|
| 43 |
|
|
|
|
| 45 |
torch_dtype, np_dtype = get_torch_and_np_dtypes(device, use_bfloat16=False)
|
| 46 |
logger.info(f"Using device: {device}, torch_dtype: {torch_dtype}, np_dtype: {np_dtype}")
|
| 47 |
|
|
|
|
| 48 |
attention = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
|
| 49 |
logger.info(f"Using attention: {attention}")
|
| 50 |
|
|
|
|
| 83 |
transcribe_pipeline(warmup_audio)
|
| 84 |
logger.info("Model warmup complete")
|
| 85 |
|
|
|
|
| 86 |
async def transcribe(audio: tuple[int, np.ndarray]):
|
| 87 |
sample_rate, audio_array = audio
|
| 88 |
logger.info(f"Sample rate: {sample_rate}Hz, Shape: {audio_array.shape}")
|
|
|
|
| 99 |
)
|
| 100 |
yield AdditionalOutputs(outputs["text"].strip())
|
| 101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
logger.info("Initializing FastRTC stream")
|
| 103 |
stream = Stream(
|
| 104 |
handler=ReplyOnPause(
|
|
|
|
| 136 |
gr.Textbox(label="Transcript"),
|
| 137 |
],
|
| 138 |
additional_outputs_handler=lambda current, new: current + " " + new,
|
| 139 |
+
rtc_configuration=get_cloudflare_turn_credentials_async(hf_token=os.getenv("HF_TOKEN")) if APP_MODE == "deployed" else None,
|
|
|
|
| 140 |
concurrency_limit=6
|
| 141 |
)
|
| 142 |
|
|
|
|
| 151 |
elif UI_TYPE == "screen":
|
| 152 |
html_content = open("static/index-screen.html").read()
|
| 153 |
|
| 154 |
+
rtc_configuration = get_cloudflare_turn_credentials_async(hf_token=os.getenv("HF_TOKEN")) if APP_MODE == "deployed" else None
|
| 155 |
+
html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_configuration))
|
| 156 |
+
return HTMLResponse(content=html_content)
|
| 157 |
|
| 158 |
@app.get("/transcript")
|
| 159 |
def _(webrtc_id: str):
|