SY23 / app.py
brightlembo's picture
Update app.py
c3a1adc verified
raw
history blame
6.66 kB
import gradio as gr
import torch
from transformers import (
BlipProcessor,
BlipForQuestionAnswering,
pipeline
)
from modelscope.pipelines import pipeline as ms_pipeline
from PIL import Image
import os
import logging
import tempfile
import shutil
import atexit
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class TempFileManager:
def __init__(self):
self.temp_dir = tempfile.mkdtemp(prefix='multimodal_app_')
atexit.register(self.cleanup)
def get_path(self, filename):
return os.path.join(self.temp_dir, filename)
def cleanup(self):
try:
if os.path.exists(self.temp_dir):
shutil.rmtree(self.temp_dir, ignore_errors=True)
except Exception as e:
logger.error(f"Erreur lors du nettoyage des fichiers temporaires: {str(e)}")
class MultimodalProcessor:
def __init__(self):
self.temp_manager = TempFileManager()
self.load_models()
def load_models(self):
"""Charge les modèles"""
try:
logger.info("Chargement des modèles...")
self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
self.blip_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
self.audio_transcriber = pipeline("automatic-speech-recognition",
model="openai/whisper-base")
self.video_pipeline = ms_pipeline(
'text-to-video-synthesis',
model='damo/text-to-video-synthesis'
)
logger.info("Modèles chargés avec succès")
except Exception as e:
logger.error(f"Erreur lors du chargement des modèles: {str(e)}")
raise
def analyze_image(self, image):
"""Analyse une image avec BLIP"""
if image is None:
return ""
try:
questions = [
"What is in the picture?",
"What are the main colors?",
"What is the setting or background?"
]
responses = {}
for question in questions:
inputs = self.blip_processor(images=image, text=question, return_tensors="pt")
outputs = self.blip_model.generate(**inputs)
answer = self.blip_processor.decode(outputs[0], skip_special_tokens=True)
responses[question] = answer
description = (
f"This image shows {responses['What is in the picture?']}. "
f"The main colors are {responses['What are the main colors?']}. "
f"The setting is {responses['What is the setting or background?']}."
)
return description
except Exception as e:
logger.error(f"Erreur lors de l'analyse de l'image: {str(e)}")
return "Erreur lors de l'analyse de l'image."
def transcribe_audio(self, audio_path):
"""Transcrit un fichier audio avec Whisper"""
if audio_path is None:
return ""
try:
return self.audio_transcriber(audio_path)["text"]
except Exception as e:
logger.error(f"Erreur lors de la transcription audio: {str(e)}")
return "Erreur lors de la transcription audio."
def generate_video(self, prompt):
"""Génère une vidéo avec ModelScope"""
if not prompt:
return None
try:
output_path = self.temp_manager.get_path("output.mp4")
result = self.video_pipeline({
'text': prompt,
'output_path': output_path
})
if not os.path.exists(output_path):
raise Exception("La vidéo n'a pas été générée correctement")
# Copie la vidéo vers un emplacement permanent si nécessaire
permanent_path = f"outputs/video_{hash(prompt)}.mp4"
os.makedirs(os.path.dirname(permanent_path), exist_ok=True)
shutil.copy2(output_path, permanent_path)
return permanent_path
except Exception as e:
logger.error(f"Erreur lors de la génération de vidéo: {str(e)}")
return None
def process_inputs(self, image, audio, text):
"""Traite les entrées multimodales"""
try:
combined_parts = []
if image is not None:
image_desc = self.analyze_image(image)
if image_desc:
combined_parts.append(f"Scene: {image_desc}")
if audio is not None:
audio_text = self.transcribe_audio(audio)
if audio_text:
combined_parts.append(f"Audio narration: {audio_text}")
if text:
combined_parts.append(f"Additional context: {text}")
final_prompt = " ".join(combined_parts) if combined_parts else "Empty scene with neutral background"
output_video = self.generate_video(final_prompt)
return output_video, final_prompt
except Exception as e:
logger.error(f"Erreur lors du traitement des entrées: {str(e)}")
return None, "Une erreur est survenue lors du traitement des entrées."
finally:
# Nettoyage explicite des fichiers temporaires après chaque traitement
self.temp_manager.cleanup()
def create_interface():
"""Crée l'interface Gradio"""
processor = MultimodalProcessor()
interface = gr.Interface(
fn=processor.process_inputs,
inputs=[
gr.Image(type="pil", label="Télécharger une image"),
gr.Audio(type="filepath", label="Télécharger un fichier audio"),
gr.Textbox(label="Entrez du texte additionnel")
],
outputs=[
gr.Video(label="Vidéo générée"),
gr.Textbox(label="Description utilisée")
],
title="Générateur de Vidéo Multimodal",
description="""
Téléchargez une image, un fichier audio et/ou ajoutez du texte.
L'application va:
1. Analyser l'image pour en extraire une description
2. Transcrire l'audio en texte
3. Combiner ces éléments avec votre texte
4. Générer une vidéo basée sur la description combinée
"""
)
return interface
if __name__ == "__main__":
interface = create_interface()
interface.launch()