# multimodal_module.py
import os
import sys
import pickle
import subprocess
import tempfile
import shutil
import asyncio
from datetime import datetime
from typing import Dict, List
import uuid

# Core ML libs
import torch
from transformers import (
    pipeline,
    AutoModelForSeq2SeqLM,
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoFeatureExtractor,
    AutoModelForAudioClassification,
)
from diffusers import StableDiffusionPipeline, StableDiffusionInpaintPipeline

# Audio / speech
import librosa
import speech_recognition as sr
from gtts import gTTS

# Image, video, files
from PIL import Image
import imageio_ffmpeg as ffmpeg
import imageio
import moviepy.editor as mp
import fitz  # PyMuPDF for PDFs

# Misc
from langdetect import DetectorFactory

DetectorFactory.seed = 0  # make language detection deterministic

# Optional: safety-check toggle for the Stable Diffusion pipelines
USE_SAFETY_CHECKER = False


# Helper for temp files
def _tmp_path(suffix: str = "") -> str:
    return os.path.join(tempfile.gettempdir(), f"{uuid.uuid4().hex}{suffix}")

class MultiModalChatModule:
    """
    Full-power multimodal module.
    - Lazy-loads big models on first use.
    - Methods are async-friendly.
    """
    def __init__(self, chat_history_file: str = "chat_histories.pkl"):
        self.user_chat_histories: Dict[int, List[dict]] = self._load_chat_histories(chat_history_file)
        self.chat_history_file = chat_history_file
        # device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"[MultiModal] device: {self.device}")
        # placeholders for large models (lazy)
        self._voice_processor = None
        self._voice_emotion_model = None
        self._translator = None
        self._chat_tokenizer = None
        self._chat_model = None
        self._image_captioner = None
        self._sd_pipe = None
        self._sd_inpaint = None
        self._code_tokenizer = None
        self._code_model = None
        # other small helpers
        self._sr_recognizer = sr.Recognizer()
        # set common model names (you can change)
        self.model_names = {
            # the feature extractor must match the classifier checkpoint below
            "voice_emotion_processor": "superb/hubert-base-superb-er",
            "voice_emotion_model": "superb/hubert-base-superb-er",
            "translation_model": "facebook/nllb-200-distilled-600M",
            "chatbot_tokenizer": "facebook/blenderbot-400M-distill",
            "chatbot_model": "facebook/blenderbot-400M-distill",
            "image_captioner": "Salesforce/blip-image-captioning-base",
            "sd_inpaint": "runwayml/stable-diffusion-inpainting",
            "sd_text2img": "runwayml/stable-diffusion-v1-5",
            "code_model": "bigcode/starcoder",  # or use a specific StarCoder checkpoint on HF
        }
        # keep track of which heavy groups are loaded
        self._loaded = {
            "voice": False,
            "translation": False,
            "chat": False,
            "image_caption": False,
            "sd": False,
            "code": False,
        }
    # ----------------------
    # persistence
    # ----------------------
    def _load_chat_histories(self, fn: str) -> Dict[int, List[dict]]:
        try:
            with open(fn, "rb") as f:
                return pickle.load(f)
        except Exception:
            return {}

    def _save_chat_histories(self):
        try:
            with open(self.chat_history_file, "wb") as f:
                pickle.dump(self.user_chat_histories, f)
        except Exception as e:
            print("[MultiModal] Warning: failed to save chat histories:", e)
    # ----------------------
    # Lazy loaders
    # ----------------------
    def _load_voice_models(self):
        if self._loaded["voice"]:
            return
        print("[MultiModal] Loading voice/emotion models...")
        # The emotion checkpoint is a HuBERT model, so use the generic Auto
        # classes; loading it through Wav2Vec2-specific classes fails.
        self._voice_processor = AutoFeatureExtractor.from_pretrained(self.model_names["voice_emotion_processor"])
        self._voice_emotion_model = AutoModelForAudioClassification.from_pretrained(self.model_names["voice_emotion_model"]).to(self.device)
        self._loaded["voice"] = True
        print("[MultiModal] Voice models loaded.")

    # Partial ISO-639-1 -> FLORES-200 code map for NLLB; extend as needed.
    _NLLB_CODES = {
        "en": "eng_Latn", "es": "spa_Latn", "fr": "fra_Latn", "de": "deu_Latn",
        "it": "ita_Latn", "zh": "zho_Hans", "ja": "jpn_Jpan", "ko": "kor_Hang",
    }

    def _load_translation(self):
        if self._loaded["translation"]:
            return
        print("[MultiModal] Loading translation pipeline...")
        device_idx = 0 if self.device == "cuda" else -1
        self._translator = pipeline("translation", model=self.model_names["translation_model"], device=device_idx)
        self._loaded["translation"] = True
        print("[MultiModal] Translation loaded.")

    def _translate(self, text: str, src: str, tgt: str) -> str:
        """Translate via NLLB (which requires src/tgt codes); fall back to the input on failure."""
        self._load_translation()
        src_code = self._NLLB_CODES.get(src)
        tgt_code = self._NLLB_CODES.get(tgt)
        if not src_code or not tgt_code or src_code == tgt_code:
            return text
        try:
            return self._translator(text, src_lang=src_code, tgt_lang=tgt_code)[0]["translation_text"]
        except Exception:
            return text

    def _load_chatbot(self):
        if self._loaded["chat"]:
            return
        print("[MultiModal] Loading chatbot model...")
        # chatbot: keep current BlenderBot to preserve behaviour
        self._chat_tokenizer = AutoTokenizer.from_pretrained(self.model_names["chatbot_tokenizer"])
        self._chat_model = AutoModelForSeq2SeqLM.from_pretrained(self.model_names["chatbot_model"]).to(self.device)
        self._loaded["chat"] = True
        print("[MultiModal] Chatbot loaded.")
    def _load_image_captioner(self):
        if self._loaded["image_caption"]:
            return
        print("[MultiModal] Loading image captioner...")
        device_idx = 0 if self.device == "cuda" else -1
        self._image_captioner = pipeline("image-to-text", model=self.model_names["image_captioner"], device=device_idx)
        self._loaded["image_caption"] = True
        print("[MultiModal] Image captioner loaded.")

    def _load_sd(self):
        if self._loaded["sd"]:
            return
        print("[MultiModal] Loading Stable Diffusion pipelines...")
        sd_model = self.model_names["sd_text2img"]
        sd_inpaint_model = self.model_names["sd_inpaint"]
        # Use float16 on GPU for speed
        torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
        pipe_kwargs = {"torch_dtype": torch_dtype}
        if not USE_SAFETY_CHECKER:
            # dropping the checker saves memory; keep it enabled for public-facing bots
            pipe_kwargs["safety_checker"] = None
        try:
            self._sd_pipe = StableDiffusionPipeline.from_pretrained(sd_model, **pipe_kwargs).to(self.device)
        except Exception as e:
            print("[MultiModal] Warning loading text2img:", e)
            self._sd_pipe = None
        try:
            self._sd_inpaint = StableDiffusionInpaintPipeline.from_pretrained(sd_inpaint_model, **pipe_kwargs).to(self.device)
        except Exception as e:
            print("[MultiModal] Warning loading inpaint:", e)
            self._sd_inpaint = None
        self._loaded["sd"] = True
        print("[MultiModal] Stable Diffusion loaded (where possible).")
    def _load_code_model(self):
        if self._loaded["code"]:
            return
        print("[MultiModal] Loading code model...")
        # StarCoder-style model (may require HF_TOKEN and a lot of memory)
        try:
            self._code_tokenizer = AutoTokenizer.from_pretrained(self.model_names["code_model"])
            self._code_model = AutoModelForCausalLM.from_pretrained(self.model_names["code_model"]).to(self.device)
            self._loaded["code"] = True
            print("[MultiModal] Code model loaded.")
        except Exception as e:
            print("[MultiModal] Warning: could not load code model:", e)
            self._code_tokenizer = None
            self._code_model = None
    # ----------------------
    # Voice: analyze emotion, transcribe
    # ----------------------
    async def analyze_voice_emotion(self, audio_path: str) -> str:
        self._load_voice_models()
        speech, sr_ = librosa.load(audio_path, sr=16000)
        inputs = self._voice_processor(speech, sampling_rate=sr_, return_tensors="pt", padding=True).to(self.device)
        with torch.no_grad():
            logits = self._voice_emotion_model(**inputs).logits
        predicted_class = torch.argmax(logits, dim=-1).item()
        # Read the label from the checkpoint's own config; the default "superb"
        # ER model predicts neu/hap/ang/sad, so a hard-coded six-emotion table
        # would mislabel its outputs.
        label = self._voice_emotion_model.config.id2label.get(predicted_class, "unknown")
        emoji = {"neu": "😌", "hap": "😊", "ang": "😠", "sad": "😢"}.get(label, "🤔")
        return f"{emoji} {label}"
    async def process_voice_message(self, voice_file, user_id: int) -> dict:
        """
        voice_file: a python-telegram-bot style object exposing get_file() /
        download_to_drive(), as used elsewhere in this codebase.
        Returns: {text, language, emotion}
        """
        # Save OGG locally
        ogg_path = _tmp_path(".ogg")
        wav_path = _tmp_path(".wav")
        tf = await voice_file.get_file()
        await tf.download_to_drive(ogg_path)
        # Convert to WAV via ffmpeg
        try:
            ffmpeg_path = ffmpeg.get_ffmpeg_exe()
            subprocess.run([ffmpeg_path, "-y", "-i", ogg_path, wav_path], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except Exception as e:
            # fallback: try ffmpeg in PATH
            try:
                subprocess.run(["ffmpeg", "-y", "-i", ogg_path, wav_path], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            except Exception as ee:
                raise RuntimeError(f"ffmpeg conversion failed: {e} / {ee}")
        # Transcribe using SpeechRecognition's Google STT (as before) -- or integrate Whisper here
        recognizer = self._sr_recognizer
        with sr.AudioFile(wav_path) as source:
            audio = recognizer.record(source)
        detected_lang = None
        detected_text = ""
        # languages to try, in order
        lang_map = {
            "zh": {"stt": "zh-CN"},
            "ja": {"stt": "ja-JP"},
            "ko": {"stt": "ko-KR"},
            "en": {"stt": "en-US"},
            "es": {"stt": "es-ES"},
            "fr": {"stt": "fr-FR"},
            "de": {"stt": "de-DE"},
            "it": {"stt": "it-IT"},
        }
        for lang_code, lang_data in lang_map.items():
            try:
                detected_text = recognizer.recognize_google(audio, language=lang_data["stt"])
                detected_lang = lang_code
                break
            except sr.UnknownValueError:
                continue
            except Exception:
                continue
        if not detected_lang:
            # no engine recognized the audio: default to English with empty text
            detected_lang = "en"
            detected_text = ""
        # emotion
        emotion = await self.analyze_voice_emotion(wav_path)
        # remove temp files
        try:
            os.remove(ogg_path)
            os.remove(wav_path)
        except Exception:
            pass
        return {"text": detected_text, "language": detected_lang, "emotion": emotion}
    # ----------------------
    # Text chat with translation & history
    # ----------------------
    async def generate_response(self, text: str, user_id: int, lang: str = "en") -> str:
        # Ensure the chat and translation models are loaded
        self._load_chatbot()
        self._load_translation()
        if user_id not in self.user_chat_histories:
            self.user_chat_histories[user_id] = []
        self.user_chat_histories[user_id].append({"timestamp": datetime.now().isoformat(), "role": "user", "text": text, "language": lang})
        self.user_chat_histories[user_id] = self.user_chat_histories[user_id][-100:]
        self._save_chat_histories()
        # Build context: translate the last few messages to English so the
        # English-only chat model sees a consistent language
        context_texts = []
        for msg in self.user_chat_histories[user_id][-5:]:
            msg_lang = msg.get("language", "en")
            translated = self._translate(msg["text"], msg_lang, "en") if msg_lang != "en" else msg["text"]
            context_texts.append(f"{msg['role']}: {translated}")
        context = "\n".join(context_texts)
        # the last context entry is the (possibly translated) current user message
        current_text = context_texts[-1].split(": ", 1)[1]
        input_text = f"Context:\n{context}\nUser: {current_text}"
        # Tokenize + generate; BlenderBot has a short context window, so cap
        # new tokens rather than requesting a very long total length
        inputs = self._chat_tokenizer.encode(input_text, return_tensors="pt").to(self.device)
        outputs = self._chat_model.generate(inputs, max_new_tokens=256)
        response_en = self._chat_tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Translate back to the user's language if needed
        response = self._translate(response_en, "en", lang) if lang != "en" else response_en
        self.user_chat_histories[user_id].append({"timestamp": datetime.now().isoformat(), "role": "bot", "text": response, "language": lang})
        self._save_chat_histories()
        return response
    # ----------------------
    # Image captioning (existing)
    # ----------------------
    async def process_image_message(self, image_file, user_id: int) -> str:
        # Save image
        img_path = _tmp_path(".jpg")
        tf = await image_file.get_file()
        await tf.download_to_drive(img_path)
        # load captioner
        self._load_image_captioner()
        try:
            image = Image.open(img_path).convert("RGB")
            description = self._image_captioner(image)[0]["generated_text"]
        except Exception as e:
            description = f"[Error generating caption: {e}]"
        # cleanup
        try:
            os.remove(img_path)
        except Exception:
            pass
        # store in history
        if user_id not in self.user_chat_histories:
            self.user_chat_histories[user_id] = []
        self.user_chat_histories[user_id].append({"timestamp": datetime.now().isoformat(), "role": "user", "text": "[Image]", "language": "en"})
        self.user_chat_histories[user_id].append({"timestamp": datetime.now().isoformat(), "role": "bot", "text": f"Image description: {description}", "language": "en"})
        self._save_chat_histories()
        return description
    # ----------------------
    # Voice reply (TTS)
    # ----------------------
    async def generate_voice_reply(self, text: str, user_id: int, fmt: str = "ogg") -> str:
        """
        Generate a TTS audio reply using gTTS (or swap in another TTS engine if you have one).
        Returns the path to the audio file.
        """
        mp3_path = _tmp_path(".mp3")
        out_path = _tmp_path(f".{fmt}")
        try:
            tts = gTTS(text)
            tts.save(mp3_path)
            # convert to the requested format using ffmpeg (ogg/opus for Telegram voice)
            ffmpeg_path = ffmpeg.get_ffmpeg_exe()
            if fmt == "ogg":
                # convert mp3 -> ogg (opus)
                subprocess.run([ffmpeg_path, "-y", "-i", mp3_path, "-c:a", "libopus", out_path], check=True)
            elif fmt == "wav":
                subprocess.run([ffmpeg_path, "-y", "-i", mp3_path, out_path], check=True)
            else:
                # default: return the raw mp3 under the requested name
                shutil.move(mp3_path, out_path)
        except Exception as e:
            raise RuntimeError(f"TTS failed: {e}")
        finally:
            try:
                if os.path.exists(mp3_path) and os.path.exists(out_path) and mp3_path != out_path:
                    os.remove(mp3_path)
            except Exception:
                pass
        return out_path
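    # Usage sketch: path = await mm.generate_voice_reply("On my way!", user_id=42, fmt="ogg")
    # fmt="ogg" yields Opus audio suitable for Telegram voice notes; fmt="wav"
    # converts to WAV; any other value falls through to the raw gTTS mp3.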
    # ----------------------
    # Image generation (text -> image)
    # ----------------------
    async def generate_image_from_text(self, prompt: str, user_id: int, width: int = 512, height: int = 512, steps: int = 30) -> str:
        self._load_sd()
        if self._sd_pipe is None:
            raise RuntimeError("Stable Diffusion pipeline not available.")
        out_path = _tmp_path(".png")
        try:
            # the diffusion pipeline handles CPU/GPU placement internally
            result = self._sd_pipe(prompt, num_inference_steps=steps, height=height, width=width)
            image = result.images[0]
            image.save(out_path)
        except Exception as e:
            raise RuntimeError(f"Image generation failed: {e}")
        return out_path
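    # Usage sketch: png = await mm.generate_image_from_text("a foggy harbor at dawn", user_id=42)
    # Returns a temp-file path; the caller owns cleanup of the generated file.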
    # ----------------------
    # Image editing (inpainting)
    # ----------------------
    async def edit_image_inpaint(self, image_file, mask_file=None, prompt: str = "", user_id: int = 0) -> str:
        self._load_sd()
        if self._sd_inpaint is None:
            raise RuntimeError("Inpainting pipeline not available.")
        # Save files
        img_path = _tmp_path(".png")
        tf = await image_file.get_file()
        await tf.download_to_drive(img_path)
        if mask_file:
            mask_path = _tmp_path(".png")
            m_tf = await mask_file.get_file()
            await m_tf.download_to_drive(mask_path)
            # mask semantics: white pixels are repainted, black pixels are kept
            mask_image = Image.open(mask_path).convert("L")
        else:
            # default mask (all white = repaint the entire image)
            mask_image = Image.new("L", Image.open(img_path).size, color=255)
            mask_path = None
        init_image = Image.open(img_path).convert("RGB")
        # run inpaint
        out_path = _tmp_path(".png")
        try:
            result = self._sd_inpaint(prompt=prompt if prompt else " ", image=init_image, mask_image=mask_image, guidance_scale=7.5, num_inference_steps=30)
            edited = result.images[0]
            edited.save(out_path)
        except Exception as e:
            raise RuntimeError(f"Inpainting failed: {e}")
        finally:
            try:
                os.remove(img_path)
                if mask_path:
                    os.remove(mask_path)
            except Exception:
                pass
        return out_path
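    # Usage sketch: out = await mm.edit_image_inpaint(photo, mask_file=mask, prompt="replace the sky with an aurora")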
    # ----------------------
    # Video processing: extract audio, frames, summarize
    # ----------------------
    async def process_video(self, video_file, user_id: int, max_frames: int = 4) -> dict:
        """
        Accepts an uploaded video file, extracts audio (for transcription) and sample frames,
        returns a summary: {duration, fps, transcription, captions}
        """
        vid_path = _tmp_path(".mp4")
        tf = await video_file.get_file()
        await tf.download_to_drive(vid_path)
        # Extract audio (videos without an audio track are allowed)
        audio_path = _tmp_path(".wav")
        try:
            clip = mp.VideoFileClip(vid_path)
            duration = clip.duration
            fps = clip.fps
            if clip.audio is not None:
                clip.audio.write_audiofile(audio_path, logger=None)
            clip.close()
        except Exception as e:
            raise RuntimeError(f"Video processing failed: {e}")
        # Transcribe the audio with SpeechRecognition (or integrate Whisper here)
        transcribed = ""
        if os.path.exists(audio_path):
            recognizer = sr.Recognizer()
            with sr.AudioFile(audio_path) as source:
                audio = recognizer.record(source)
            try:
                transcribed = recognizer.recognize_google(audio)
            except Exception:
                transcribed = ""
        # Extract a few frames, evenly spaced
        frames = []
        try:
            clip_reader = imageio.get_reader(vid_path, "ffmpeg")
            total_frames = clip_reader.count_frames()
            step = max(1, total_frames // max_frames)
            for i in range(0, total_frames, step):
                try:
                    frame = clip_reader.get_data(i)
                    pil = Image.fromarray(frame)
                    ppath = _tmp_path(".jpg")
                    pil.save(ppath)
                    frames.append(ppath)
                    if len(frames) >= max_frames:
                        break
                except Exception:
                    continue
            clip_reader.close()
        except Exception:
            pass
        # Caption the sampled frames
        captions = []
        if frames:
            self._load_image_captioner()
            for p in frames:
                try:
                    img = Image.open(p).convert("RGB")
                    c = self._image_captioner(img)[0]["generated_text"]
                    captions.append(c)
                except Exception:
                    captions.append("")
                finally:
                    try:
                        os.remove(p)
                    except Exception:
                        pass
        # cleanup
        try:
            os.remove(vid_path)
            if os.path.exists(audio_path):
                os.remove(audio_path)
        except Exception:
            pass
        return {"duration": duration, "fps": fps, "transcription": transcribed, "captions": captions}
    # ----------------------
    # File processing (PDF, DOCX, TXT, CSV)
    # ----------------------
    async def process_file(self, file_obj, user_id: int) -> dict:
        """
        Reads a file, extracts text (supports PDF/TXT/CSV, plus DOCX if python-docx is installed),
        and returns a short summary.
        """
        # Save the file, preserving the original extension so the type checks
        # below can work; Telegram documents expose `file_name` (may be absent)
        orig_name = getattr(file_obj, "file_name", "") or ""
        fpath = _tmp_path(os.path.splitext(orig_name)[1])
        tf = await file_obj.get_file()
        await tf.download_to_drive(fpath)
        lower = fpath.lower()
        text = ""
        if lower.endswith(".pdf"):
            try:
                doc = fitz.open(fpath)
                for page in doc:
                    text += page.get_text()
            except Exception as e:
                text = f"[PDF read error: {e}]"
        elif lower.endswith((".txt", ".csv")):
            try:
                with open(fpath, "r", encoding="utf-8", errors="ignore") as fh:
                    text = fh.read()
            except Exception as e:
                text = f"[File read error: {e}]"
        elif lower.endswith(".docx"):
            try:
                import docx
                doc = docx.Document(fpath)
                text = "\n".join(p.text for p in doc.paragraphs)
            except Exception as e:
                text = f"[DOCX read error: {e}]"
        else:
            text = "[Unsupported file type]"
        # Summarize: simple truncation heuristic; a model-based summary would cost compute
        summary = text[:300] + ("..." if len(text) > 300 else "")
        try:
            os.remove(fpath)
        except Exception:
            pass
        return {"summary": summary, "full_text_length": len(text)}
    # ----------------------
    # Code assistance: generate / explain code
    # ----------------------
    async def code_complete(self, prompt: str, max_tokens: int = 512, temperature: float = 0.0) -> str:
        """
        Uses a code LLM (StarCoder or similar) to complete or generate code.
        A temperature of 0 (the default) gives greedy, deterministic decoding.
        """
        self._load_code_model()
        if not self._code_model or not self._code_tokenizer:
            raise RuntimeError("Code model not available.")
        input_ids = self._code_tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
        gen_kwargs = {"max_new_tokens": max_tokens}
        if temperature and temperature > 0:
            gen_kwargs.update(do_sample=True, temperature=temperature)
        gen = self._code_model.generate(input_ids, **gen_kwargs)
        out = self._code_tokenizer.decode(gen[0], skip_special_tokens=True)
        return out
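    # Usage sketch: code = await mm.code_complete("def fibonacci(n):", max_tokens=128)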
    # ----------------------
    # Optional: execute Python code in sandbox (WARNING: security risk)
    # ----------------------
    async def execute_python_code(self, code: str, timeout: int = 5) -> dict:
        """
        Execute Python code in a very limited sandbox subprocess.
        WARNING: Running arbitrary code is dangerous. Use only with trusted inputs
        or stronger sandboxing (containers).
        """
        # Create temp dir
        d = tempfile.mkdtemp()
        file_path = os.path.join(d, "main.py")
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(code)
        # run with a timeout; use the current interpreter for portability
        try:
            proc = await asyncio.create_subprocess_exec(
                sys.executable, file_path,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            try:
                stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=timeout)
            except asyncio.TimeoutError:
                proc.kill()
                await proc.wait()  # reap the killed process
                return {"error": "Execution timed out"}
            return {"stdout": stdout.decode("utf-8", errors="ignore"), "stderr": stderr.decode("utf-8", errors="ignore")}
        finally:
            try:
                shutil.rmtree(d)
            except Exception:
                pass
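

# ----------------------
# Minimal smoke test (a sketch: exercises only the cheap text-chat path and
# assumes the model downloads succeed; the heavy SD/StarCoder pipelines stay untouched)
# ----------------------
if __name__ == "__main__":
    async def _demo():
        mm = MultiModalChatModule(chat_history_file="demo_histories.pkl")
        reply = await mm.generate_response("Hello! What can you do?", user_id=1, lang="en")
        print("Bot:", reply)

    asyncio.run(_demo())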