# Princeaka's picture
# Update app.py
# 052fcb6 verified
import os
import shutil
import asyncio
import gradio as gr
from multimodal_module import MultiModalChatModule
# Environment configuration (already safe but keep).
# These are applied BEFORE the module is instantiated so that anything the
# constructor loads sees them; setting CUDA_VISIBLE_DEVICES after a model has
# already been placed on a device would have no effect.
# NOTE(review): libraries pulled in by `multimodal_module` may read some of
# these at import time; if so, they would need to be set above that import.
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # Disable GPU
os.environ["IMAGEIO_FFMPEG_EXE"] = "/usr/bin/ffmpeg"  # Explicit path
os.environ["FFMPEG_BINARY"] = "/usr/bin/ffmpeg"  # Backup for older versions

# Initialize module (single shared instance used by every Gradio callback)
mm = MultiModalChatModule()
# A tiny async-compatible "file-like" wrapper so your multimodal_module methods
# (which expect objects with an async download_to_drive(...) method) work
class AsyncPathWrapper:
    """Expose a local file path through an async ``download_to_drive`` method.

    Gradio hands callbacks plain file paths, while the multimodal module is
    written against objects that offer ``await obj.download_to_drive(dst)``.
    This adapter bridges the two.

    Attributes:
        path: Path of the source file that will be copied on download.
    """

    def __init__(self, path: str):
        self.path = path

    async def download_to_drive(self, dst_path: str) -> None:
        """Copy the wrapped file to ``dst_path``, creating parent directories.

        The copy itself is synchronous; the method is ``async`` only so that
        it can be awaited like the interface it stands in for.

        Args:
            dst_path: Destination file path. May be a bare filename, in which
                case the file is written relative to the current directory.

        Raises:
            OSError: Propagated unchanged from ``os.makedirs`` / ``shutil.copy``
                so the caller's own error handling can deal with it.
        """
        parent_dir = os.path.dirname(dst_path)
        # dirname() is "" for a bare filename and os.makedirs("") raises
        # FileNotFoundError, so only create the directory when there is one.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        shutil.copy(self.path, dst_path)
# Helper to call async methods from sync Gradio callbacks
def run_async(fn, *args, **kwargs):
    """Drive the coroutine function ``fn`` to completion and return its result.

    ``fn`` is called with the given positional and keyword arguments and the
    resulting coroutine is executed with ``asyncio.run``.
    """
    coroutine = fn(*args, **kwargs)
    return asyncio.run(coroutine)
# Wrappers that adapt Gradio returned file paths to the module's expected interface
def _wrap_audio(audio_path):
if not audio_path:
return None
return AsyncPathWrapper(audio_path)
def _wrap_image(image_path):
if not image_path:
return None
return AsyncPathWrapper(image_path)
def _wrap_file(file_path):
if not file_path:
return None
return AsyncPathWrapper(file_path)
# Gradio binding functions
def process_voice(audio_filepath, user_id):
    """Gradio callback for the Voice tab.

    Wraps the audio path so it offers ``download_to_drive`` (which, per the
    original note, ``mm.process_voice_message`` expects), converts ``user_id``
    to ``int`` and returns whatever the module coroutine produces.
    """
    audio = _wrap_audio(audio_filepath)
    handler = mm.process_voice_message
    uid = int(user_id)
    return run_async(handler, audio, uid)
def process_image(image_filepath, user_id):
    """Gradio callback for the image "Describe" tab.

    Passes the wrapped image path and the integer user id to
    ``mm.process_image_message`` and returns its result.
    """
    image = _wrap_image(image_filepath)
    handler = mm.process_image_message
    uid = int(user_id)
    return run_async(handler, image, uid)
def chat(text, user_id, lang):
    """Gradio callback for the Text Chat tab.

    Forwards the message, the integer user id and the language code to
    ``mm.generate_response`` and returns its result.
    """
    handler = mm.generate_response
    uid = int(user_id)
    return run_async(handler, text, uid, lang)
def generate_image(prompt, user_id):
    """Gradio callback for the image "Generate" tab.

    Runs ``mm.generate_image_from_text`` with the prompt and the integer user
    id, returning its result for display in the output image component.
    """
    handler = mm.generate_image_from_text
    uid = int(user_id)
    return run_async(handler, prompt, uid)
def process_file(file_path, user_id):
    """Gradio callback for the Files tab.

    Wraps the uploaded file, converts ``user_id`` to ``int`` and returns the
    result of ``mm.process_file``.
    """
    document = _wrap_file(file_path)
    handler = mm.process_file
    uid = int(user_id)
    return run_async(handler, document, uid)
# UI definition: one tab per modality, each wiring a button to the matching
# callback above. Every tab carries its own "User ID" textbox, whose value is
# converted to int inside the callback.
# NOTE(review): nesting below is reconstructed; the Row is assumed to hold
# only the User ID and Language controls — confirm against the intended layout.
with gr.Blocks(title="Multimodal AI Assistant") as app:
    gr.Markdown("## 🚀 Multimodal AI Assistant (Space-friendly)")

    with gr.Tab("💬 Text Chat"):
        with gr.Row():
            user_id_txt = gr.Textbox(label="User ID", value="123")
            lang = gr.Dropdown(["en", "es", "fr", "de"], label="Language", value="en")
        chat_input = gr.Textbox(label="Your Message")
        chat_output = gr.Textbox(label="AI Response", interactive=False)
        chat_btn = gr.Button("Send")
        chat_btn.click(fn=chat, inputs=[chat_input, user_id_txt, lang], outputs=chat_output)

    with gr.Tab("🎙️ Voice"):
        # NOTE(review): `source=` is the older Gradio keyword; newer releases
        # use `sources=[...]` instead — confirm against the pinned version.
        voice_input = gr.Audio(source="microphone", type="filepath", label="Speak or upload an audio file")
        voice_user = gr.Textbox(label="User ID", value="123")
        voice_output = gr.JSON(label="Analysis Results")
        voice_btn = gr.Button("Process")
        voice_btn.click(fn=process_voice, inputs=[voice_input, voice_user], outputs=voice_output)

    with gr.Tab("🖼️ Images"):
        with gr.Tab("Describe"):
            img_input = gr.Image(type="filepath", label="Upload an image")
            img_user = gr.Textbox(label="User ID", value="123")
            img_output = gr.Textbox(label="Description")
            img_btn = gr.Button("Describe")
            img_btn.click(fn=process_image, inputs=[img_input, img_user], outputs=img_output)
        with gr.Tab("Generate"):
            gen_prompt = gr.Textbox(label="Prompt")
            gen_user = gr.Textbox(label="User ID", value="123")
            gen_output = gr.Image(label="Generated Image")
            gen_btn = gr.Button("Generate")
            gen_btn.click(fn=generate_image, inputs=[gen_prompt, gen_user], outputs=gen_output)

    with gr.Tab("📄 Files"):
        file_input = gr.File(file_count="single", label="Upload a document (pdf, txt, docx)")
        file_user = gr.Textbox(label="User ID", value="123")
        file_output = gr.JSON(label="File Processing Result")
        file_btn = gr.Button("Process File")
        file_btn.click(fn=process_file, inputs=[file_input, file_user], outputs=file_output)
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from $PORT when it is set
    # and falls back to 7860 otherwise, so the same entry point also runs locally.
    listen_port = int(os.environ.get("PORT", 7860))
    app.launch(server_name="0.0.0.0", server_port=listen_port)