Spaces:
Running
on
Zero
Running
on
Zero
from threading import Thread | |
import gradio as gr | |
import spaces | |
import torch | |
from transformers import AutoModelForImageTextToText, AutoProcessor | |
from transformers.generation.streamers import TextIteratorStreamer | |
BASE_GEMMA_MODEL_ID = "google/gemma-3n-E2B-it" | |
GEMMA_MODEL_ID = "bilguun/gemma-3n-E2B-it-audio-en-mn" | |
print("Loading processor and model...") | |
processor = AutoProcessor.from_pretrained(BASE_GEMMA_MODEL_ID) | |
model = AutoModelForImageTextToText.from_pretrained( | |
GEMMA_MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto" | |
) | |
print("Model loaded successfully!") | |
def process_audio(audio_file, prompt_type, custom_prompt, max_tokens): | |
if audio_file is None: | |
return "Please upload an audio file." | |
prompts = { | |
"Transcribe": "Transcribe this audio.", | |
"Transcribe EN to MN": "Transcribe this audio into English and translate into Mongolian.", | |
"Transcribe MN to EN": "Transcribe this audio into Mongolian and translate into English.", | |
} | |
if prompt_type == "Custom": | |
if not custom_prompt.strip(): | |
return "Please provide a custom prompt." | |
selected_prompt = custom_prompt.strip() | |
else: | |
selected_prompt = prompts[prompt_type] | |
messages = [ | |
{ | |
"role": "user", | |
"content": [ | |
{"type": "audio", "audio": audio_file}, | |
{"type": "text", "text": selected_prompt}, | |
], | |
} | |
] | |
input_ids = processor.apply_chat_template( | |
messages, | |
add_generation_prompt=True, | |
tokenize=True, | |
return_dict=True, | |
return_tensors="pt", | |
) | |
input_ids = input_ids.to(model.device, dtype=model.dtype) | |
streamer = TextIteratorStreamer( | |
processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True | |
) | |
generate_kwargs = dict( | |
input_ids, | |
streamer=streamer, | |
max_new_tokens=max_tokens, | |
disable_compile=True, | |
) | |
t = Thread(target=model.generate, kwargs=generate_kwargs) | |
t.start() | |
output = "" | |
for delta in streamer: | |
output += delta | |
yield output | |
with gr.Blocks(title="Gemma 3n Audio Transcription & Translation") as demo: | |
gr.Markdown("# Gemma 3n E2B - English-Mongolian Audio Transcription & Translation") | |
gr.Markdown( | |
"Upload an audio file and select the processing type to get transcription and/or translation." | |
) | |
with gr.Row(): | |
with gr.Column(): | |
audio_input = gr.Audio( | |
label="Audio", | |
type="filepath", | |
sources=["upload", "microphone"], | |
max_length=300, | |
) | |
prompt_dropdown = gr.Dropdown( | |
choices=["Transcribe", "Transcribe EN to MN", "Transcribe MN to EN", "Custom"], | |
value="Transcribe", | |
label="Prompt Type", | |
) | |
custom_prompt_input = gr.Textbox( | |
label="Custom Prompt", | |
placeholder="Enter your custom prompt here...", | |
lines=2, | |
visible=False, | |
) | |
process_btn = gr.Button("Process Audio", variant="primary") | |
with gr.Column(): | |
output_text = gr.Textbox( | |
label="Generated Output", | |
lines=10, | |
max_lines=20, | |
placeholder="Transcribed text will appear here...", | |
show_copy_button=True, | |
interactive=False, | |
) | |
with gr.Row(): | |
with gr.Accordion("Additional Settings", open=False): | |
max_tokens_slider = gr.Slider( | |
minimum=16, | |
maximum=512, | |
value=128, | |
step=16, | |
label="Max New Tokens", | |
info="Maximum number of tokens to generate", | |
) | |
def update_custom_prompt_visibility(prompt_type): | |
return gr.update(visible=prompt_type == "Custom") | |
prompt_dropdown.change( | |
fn=update_custom_prompt_visibility, | |
inputs=prompt_dropdown, | |
outputs=custom_prompt_input, | |
) | |
process_btn.click( | |
fn=process_audio, | |
inputs=[audio_input, prompt_dropdown, custom_prompt_input, max_tokens_slider], | |
outputs=output_text, | |
) | |
gr.Examples( | |
examples=[ | |
[ | |
"https://github.com/bilguun0203/gemma3n-audio-mn/raw/refs/heads/main/audio_samples/en1.wav", | |
"Transcribe", | |
"", | |
128, | |
], | |
[ | |
"https://github.com/bilguun0203/gemma3n-audio-mn/raw/refs/heads/main/audio_samples/en3.wav", | |
"Transcribe EN to MN", | |
"", | |
128, | |
], | |
[ | |
"https://github.com/bilguun0203/gemma3n-audio-mn/raw/refs/heads/main/audio_samples/mn2.wav", | |
"Transcribe", | |
"", | |
128, | |
], | |
[ | |
"https://github.com/bilguun0203/gemma3n-audio-mn/raw/refs/heads/main/audio_samples/mn2.wav", | |
"Transcribe MN to EN", | |
"", | |
128, | |
], | |
], | |
inputs=[ | |
audio_input, | |
prompt_dropdown, | |
custom_prompt_input, | |
max_tokens_slider, | |
], | |
outputs=output_text, | |
fn=process_audio, | |
cache_examples=True, | |
cache_mode="eager", # Cache examples eagerly for model warmup | |
label="Example Audio Files", | |
) | |
if __name__ == "__main__": | |
demo.launch() | |