# multimodal_module.py
import os
import sys
import pickle
import subprocess
import tempfile
import shutil
import asyncio
from datetime import datetime
from typing import Dict, List, Optional, Any
import io
import uuid
# Core ML libs
import torch
from transformers import (
pipeline,
AutoModelForSeq2SeqLM,
AutoTokenizer,
Wav2Vec2Processor,
    AutoModelForAudioClassification,
)
from diffusers import StableDiffusionPipeline, StableDiffusionInpaintPipeline
from transformers import AutoModelForCausalLM, AutoTokenizer as HFTokenizer
# Audio / speech
import librosa
import speech_recognition as sr
from gtts import gTTS
# Image, video, files
from PIL import Image, ImageOps
import imageio_ffmpeg as ffmpeg
import imageio
import moviepy.editor as mp
import fitz # PyMuPDF for PDFs
# Misc
from langdetect import DetectorFactory
DetectorFactory.seed = 0
# Optional: safety-checker toggle (not currently wired into the SD pipelines)
USE_SAFETY_CHECKER = False
# Helper for temp files
def _tmp_path(suffix=""):
return os.path.join(tempfile.gettempdir(), f"{uuid.uuid4().hex}{suffix}")
class MultiModalChatModule:
"""
Full-power multimodal module.
- Lazy-loads big models on first use.
- Methods are async-friendly.
"""
def __init__(self, chat_history_file: str = "chat_histories.pkl"):
self.user_chat_histories: Dict[int, List[dict]] = self._load_chat_histories(chat_history_file)
self.chat_history_file = chat_history_file
# device
self.device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[MultiModal] device: {self.device}")
# placeholders for large models (lazy)
self._voice_processor = None
self._voice_emotion_model = None
self._translator = None
self._chat_tokenizer = None
self._chat_model = None
self._chat_model_name = "bigscience/bloom" # placeholder; will set proper below
self._image_captioner = None
self._sd_pipe = None
self._sd_inpaint = None
self._code_tokenizer = None
self._code_model = None
# other small helpers
self._sr_recognizer = sr.Recognizer()
# set common model names (you can change)
self.model_names = {
"voice_emotion_processor": "facebook/hubert-large-ls960-ft",
"voice_emotion_model": "superb/hubert-base-superb-er",
"translation_model": "facebook/nllb-200-distilled-600M",
"chatbot_tokenizer": "facebook/blenderbot-400M-distill",
"chatbot_model": "facebook/blenderbot-400M-distill",
"image_captioner": "Salesforce/blip-image-captioning-base",
"sd_inpaint": "runwayml/stable-diffusion-inpainting",
"sd_text2img": "runwayml/stable-diffusion-v1-5",
"code_model": "bigcode/starcoder", # Or use a specific StarCoder checkpoint on HF
        }
        # NLLB-200 expects FLORES-200 language codes; map the short codes used
        # elsewhere in this module to the form the translation pipeline needs.
        self._nllb_codes = {
            "en": "eng_Latn", "es": "spa_Latn", "fr": "fra_Latn", "de": "deu_Latn",
            "it": "ita_Latn", "zh": "zho_Hans", "ja": "jpn_Jpan", "ko": "kor_Hang",
        }
# keep track of which heavy groups are loaded
self._loaded = {
"voice": False,
"translation": False,
"chat": False,
"image_caption": False,
"sd": False,
"code": False,
}
# ----------------------
# persistence
# ----------------------
def _load_chat_histories(self, fn: str) -> Dict[int, List[dict]]:
try:
with open(fn, "rb") as f:
return pickle.load(f)
except Exception:
return {}
def _save_chat_histories(self):
try:
with open(self.chat_history_file, "wb") as f:
pickle.dump(self.user_chat_histories, f)
except Exception as e:
print("[MultiModal] Warning: failed to save chat histories:", e)
# ----------------------
# Lazy loaders
# ----------------------
def _load_voice_models(self):
if self._loaded["voice"]:
return
print("[MultiModal] Loading voice/emotion models...")
self._voice_processor = Wav2Vec2Processor.from_pretrained(self.model_names["voice_emotion_processor"])
        # The emotion checkpoint (superb/hubert-base-superb-er) is HuBERT-based, so load it
        # through AutoModelForAudioClassification to avoid a Wav2Vec2 architecture mismatch.
        self._voice_emotion_model = AutoModelForAudioClassification.from_pretrained(self.model_names["voice_emotion_model"]).to(self.device)
self._loaded["voice"] = True
print("[MultiModal] Voice models loaded.")
def _load_translation(self):
if self._loaded["translation"]:
return
print("[MultiModal] Loading translation pipeline...")
device_idx = 0 if self.device == "cuda" else -1
self._translator = pipeline("translation", model=self.model_names["translation_model"], device=device_idx)
self._loaded["translation"] = True
print("[MultiModal] Translation loaded.")
def _load_chatbot(self):
if self._loaded["chat"]:
return
print("[MultiModal] Loading chatbot model...")
# chatbot: keep current blenderbot to preserve behaviour
self._chat_tokenizer = AutoTokenizer.from_pretrained(self.model_names["chatbot_tokenizer"])
self._chat_model = AutoModelForSeq2SeqLM.from_pretrained(self.model_names["chatbot_model"]).to(self.device)
self._loaded["chat"] = True
print("[MultiModal] Chatbot loaded.")
def _load_image_captioner(self):
if self._loaded["image_caption"]:
return
print("[MultiModal] Loading image captioner...")
device_idx = 0 if self.device == "cuda" else -1
self._image_captioner = pipeline("image-to-text", model=self.model_names["image_captioner"], device=device_idx)
self._loaded["image_caption"] = True
print("[MultiModal] Image captioner loaded.")
def _load_sd(self):
if self._loaded["sd"]:
return
print("[MultiModal] Loading Stable Diffusion pipelines...")
# text2img
sd_model = self.model_names["sd_text2img"]
sd_inpaint_model = self.model_names["sd_inpaint"]
# Use float16 on GPU for speed
torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
try:
self._sd_pipe = StableDiffusionPipeline.from_pretrained(sd_model, torch_dtype=torch_dtype)
self._sd_pipe = self._sd_pipe.to(self.device)
except Exception as e:
print("[MultiModal] Warning loading text2img:", e)
self._sd_pipe = None
try:
self._sd_inpaint = StableDiffusionInpaintPipeline.from_pretrained(sd_inpaint_model, torch_dtype=torch_dtype)
self._sd_inpaint = self._sd_inpaint.to(self.device)
except Exception as e:
print("[MultiModal] Warning loading inpaint:", e)
self._sd_inpaint = None
self._loaded["sd"] = True
print("[MultiModal] Stable Diffusion loaded (where possible).")
def _load_code_model(self):
if self._loaded["code"]:
return
print("[MultiModal] Loading code model...")
# StarCoder style model (may require HF_TOKEN or large memory)
try:
self._code_tokenizer = HFTokenizer.from_pretrained(self.model_names["code_model"])
self._code_model = AutoModelForCausalLM.from_pretrained(self.model_names["code_model"]).to(self.device)
self._loaded["code"] = True
print("[MultiModal] Code model loaded.")
except Exception as e:
print("[MultiModal] Warning: could not load code model:", e)
self._code_tokenizer = None
self._code_model = None
# ----------------------
# Voice: analyze emotion, transcribe
# ----------------------
async def analyze_voice_emotion(self, audio_path: str) -> str:
self._load_voice_models()
speech, sr_ = librosa.load(audio_path, sr=16000)
inputs = self._voice_processor(speech, sampling_rate=sr_, return_tensors="pt", padding=True).to(self.device)
with torch.no_grad():
logits = self._voice_emotion_model(**inputs).logits
        predicted_class = torch.argmax(logits).item()
        # Prefer the checkpoint's own id2label mapping (the configured
        # superb/hubert-base-superb-er model exposes four classes: neu/hap/ang/sad);
        # fall back to the original six-way table for other checkpoints.
        id2label = getattr(self._voice_emotion_model.config, "id2label", None) or {}
        label = str(id2label.get(predicted_class, "")).lower()
        emoji_by_label = {
            "neu": "😌 Neutral",
            "hap": "😊 Happy",
            "ang": "😠 Angry",
            "sad": "😢 Sad",
        }
        if label[:3] in emoji_by_label:
            return emoji_by_label[label[:3]]
        return {
            0: "😊 Happy",
            1: "😢 Sad",
            2: "😠 Angry",
            3: "😨 Fearful",
            4: "😌 Calm",
            5: "😲 Surprised",
        }.get(predicted_class, "🤔 Unknown")
async def process_voice_message(self, voice_file, user_id: int) -> dict:
"""
        voice_file: an upload object exposing async get_file()/download_to_drive()
        (python-telegram-bot style), as used by the other media handlers in this module.
Returns: {text, language, emotion}
"""
# Save OGG locally
ogg_path = _tmp_path(".ogg")
wav_path = _tmp_path(".wav")
tf = await voice_file.get_file()
await tf.download_to_drive(ogg_path)
# Convert to WAV via ffmpeg
try:
ffmpeg_path = ffmpeg.get_ffmpeg_exe()
subprocess.run([ffmpeg_path, "-y", "-i", ogg_path, wav_path], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
except Exception as e:
# fallback: try ffmpeg in PATH
try:
subprocess.run(["ffmpeg", "-y", "-i", ogg_path, wav_path], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
except Exception as ee:
raise RuntimeError(f"ffmpeg conversion failed: {e} / {ee}")
        # Transcribe with SpeechRecognition's Google STT (as before); Whisper could be integrated here instead.
recognizer = self._sr_recognizer
with sr.AudioFile(wav_path) as source:
audio = recognizer.record(source)
detected_lang = None
detected_text = ""
        # Languages to try, in order
lang_map = {
"zh": {"stt": "zh-CN"},
"ja": {"stt": "ja-JP"},
"ko": {"stt": "ko-KR"},
"en": {"stt": "en-US"},
"es": {"stt": "es-ES"},
"fr": {"stt": "fr-FR"},
"de": {"stt": "de-DE"},
"it": {"stt": "it-IT"},
}
for lang_code, lang_data in lang_map.items():
try:
detected_text = recognizer.recognize_google(audio, language=lang_data["stt"])
detected_lang = lang_code
break
except sr.UnknownValueError:
continue
except Exception:
continue
if not detected_lang:
            # Nothing was recognized: default to English with empty text
            # (a langdetect-based fallback could be wired in here).
detected_lang = "en"
detected_text = ""
# emotion
emotion = await self.analyze_voice_emotion(wav_path)
# remove temp files
try:
os.remove(ogg_path)
os.remove(wav_path)
except Exception:
pass
return {"text": detected_text, "language": detected_lang, "emotion": emotion}
# ----------------------
# Text chat with translation & history
# ----------------------
async def generate_response(self, text: str, user_id: int, lang: str = "en") -> str:
# Ensure chat model loaded
self._load_chatbot()
self._load_translation()
if user_id not in self.user_chat_histories:
self.user_chat_histories[user_id] = []
self.user_chat_histories[user_id].append({"timestamp": datetime.now().isoformat(), "role": "user", "text": text, "language": lang})
self.user_chat_histories[user_id] = self.user_chat_histories[user_id][-100:]
self._save_chat_histories()
# Build context: translate last few msgs to English for consistency
context_texts = []
for msg in self.user_chat_histories[user_id][-5:]:
if msg.get("language", "en") != "en":
                try:
                    # NLLB-200 needs explicit FLORES-200 source/target codes,
                    # otherwise the pipeline raises and we silently fall back to the raw text.
                    translated = self._translator(
                        msg["text"],
                        src_lang=self._nllb_codes.get(msg.get("language", "en"), "eng_Latn"),
                        tgt_lang="eng_Latn",
                    )[0]["translation_text"]
                except Exception:
                    translated = msg["text"]
else:
translated = msg["text"]
context_texts.append(f"{msg['role']}: {translated}")
context = "\n".join(context_texts)
input_text = f"Context:\n{context}\nUser: {text if lang == 'en' else context_texts[-1].split(': ', 1)[1]}"
# Tokenize + generate
        # BlenderBot has a short (~128 token) context window, so truncate the prompt
        # and cap new tokens rather than requesting max_length=1000.
        inputs = self._chat_tokenizer.encode(input_text, return_tensors="pt", truncation=True).to(self.device)
        outputs = self._chat_model.generate(inputs, max_new_tokens=128)
response_en = self._chat_tokenizer.decode(outputs[0], skip_special_tokens=True)
# Translate back to user's language if needed
if lang != "en":
            try:
                response = self._translator(
                    response_en,
                    src_lang="eng_Latn",
                    tgt_lang=self._nllb_codes.get(lang, "eng_Latn"),
                )[0]["translation_text"]
            except Exception:
                response = response_en
else:
response = response_en
self.user_chat_histories[user_id].append({"timestamp": datetime.now().isoformat(), "role": "bot", "text": response, "language": lang})
self._save_chat_histories()
return response
# ----------------------
# Image captioning (existing)
# ----------------------
async def process_image_message(self, image_file, user_id: int) -> str:
# Save image
img_path = _tmp_path(".jpg")
tf = await image_file.get_file()
await tf.download_to_drive(img_path)
# load captioner
self._load_image_captioner()
try:
image = Image.open(img_path).convert("RGB")
description = self._image_captioner(image)[0]["generated_text"]
except Exception as e:
description = f"[Error generating caption: {e}]"
# cleanup
try:
os.remove(img_path)
except Exception:
pass
# store in history
if user_id not in self.user_chat_histories:
self.user_chat_histories[user_id] = []
self.user_chat_histories[user_id].append({"timestamp": datetime.now().isoformat(), "role": "user", "text": "[Image]", "language": "en"})
self.user_chat_histories[user_id].append({"timestamp": datetime.now().isoformat(), "role": "bot", "text": f"Image description: {description}", "language": "en"})
self._save_chat_histories()
return description
# ----------------------
# Voice reply (TTS)
# ----------------------
async def generate_voice_reply(self, text: str, user_id: int, fmt: str = "ogg") -> str:
"""
        Generate a TTS audio reply using gTTS (or swap in another TTS engine if you have one).
Returns path to audio file.
"""
mp3_path = _tmp_path(".mp3")
out_path = _tmp_path(f".{fmt}")
try:
tts = gTTS(text)
tts.save(mp3_path)
# convert to requested format using ffmpeg (ogg/opus for Telegram voice)
ffmpeg_path = ffmpeg.get_ffmpeg_exe()
if fmt == "ogg":
# convert mp3 -> ogg (opus)
subprocess.run([ffmpeg_path, "-y", "-i", mp3_path, "-c:a", "libopus", out_path], check=True)
elif fmt == "wav":
subprocess.run([ffmpeg_path, "-y", "-i", mp3_path, out_path], check=True)
else:
# default: return mp3
shutil.move(mp3_path, out_path)
except Exception as e:
# fallback: raise
raise RuntimeError(f"TTS failed: {e}")
finally:
try:
if os.path.exists(mp3_path) and os.path.exists(out_path) and mp3_path != out_path:
os.remove(mp3_path)
except Exception:
pass
return out_path
# ----------------------
# Image generation (text -> image)
# ----------------------
async def generate_image_from_text(self, prompt: str, user_id: int, width: int = 512, height: int = 512, steps: int = 30) -> str:
self._load_sd()
if self._sd_pipe is None:
raise RuntimeError("Stable Diffusion pipeline not available.")
out_path = _tmp_path(".png")
try:
# diffusion pipeline uses CPU/GPU internally
result = self._sd_pipe(prompt, num_inference_steps=steps, height=height, width=width)
image = result.images[0]
image.save(out_path)
except Exception as e:
raise RuntimeError(f"Image generation failed: {e}")
return out_path
# ----------------------
# Image editing (inpainting)
# ----------------------
async def edit_image_inpaint(self, image_file, mask_file=None, prompt: str = "", user_id: int = 0) -> str:
self._load_sd()
if self._sd_inpaint is None:
raise RuntimeError("Inpainting pipeline not available.")
# Save files
img_path = _tmp_path(".png")
tf = await image_file.get_file()
await tf.download_to_drive(img_path)
if mask_file:
mask_path = _tmp_path(".png")
m_tf = await mask_file.get_file()
await m_tf.download_to_drive(mask_path)
mask_image = Image.open(mask_path).convert("L")
else:
# default mask (edit entire image)
mask_image = Image.new("L", Image.open(img_path).size, color=255)
mask_path = None
init_image = Image.open(img_path).convert("RGB")
# run inpaint
out_path = _tmp_path(".png")
try:
result = self._sd_inpaint(prompt=prompt if prompt else " ", image=init_image, mask_image=mask_image, guidance_scale=7.5, num_inference_steps=30)
edited = result.images[0]
edited.save(out_path)
except Exception as e:
raise RuntimeError(f"Inpainting failed: {e}")
finally:
try:
os.remove(img_path)
if mask_path:
os.remove(mask_path)
except Exception:
pass
return out_path
# ----------------------
# Video processing: extract audio, frames, summarize
# ----------------------
async def process_video(self, video_file, user_id: int, max_frames: int = 4) -> dict:
"""
        Accepts an uploaded video file, extracts the audio track (for transcription) and a few
        sample frames, and returns a summary: {duration, fps, transcription, captions}.
"""
vid_path = _tmp_path(".mp4")
tf = await video_file.get_file()
await tf.download_to_drive(vid_path)
# Extract audio
audio_path = _tmp_path(".wav")
try:
            clip = mp.VideoFileClip(vid_path)
            # NOTE: clip.audio is None for silent videos, which lands in the except below.
            clip.audio.write_audiofile(audio_path, logger=None)
            duration = clip.duration
            fps = clip.fps
            clip.close()
except Exception as e:
raise RuntimeError(f"Video processing failed: {e}")
# Transcribe audio using the same process_voice_message flow: use SpeechRecognition or integrate Whisper
# For now we'll try SpeechRecognition on the audio
recognizer = sr.Recognizer()
with sr.AudioFile(audio_path) as source:
audio = recognizer.record(source)
transcribed = ""
try:
transcribed = recognizer.recognize_google(audio)
except Exception:
transcribed = ""
# Extract a few frames evenly
frames = []
try:
clip_reader = imageio.get_reader(vid_path, "ffmpeg")
total_frames = clip_reader.count_frames()
step = max(1, total_frames // max_frames)
for i in range(0, total_frames, step):
try:
frame = clip_reader.get_data(i)
pil = Image.fromarray(frame)
ppath = _tmp_path(".jpg")
pil.save(ppath)
frames.append(ppath)
if len(frames) >= max_frames:
break
except Exception:
continue
clip_reader.close()
except Exception:
pass
# Use image captioner on the frames
captions = []
if frames:
self._load_image_captioner()
for p in frames:
try:
img = Image.open(p).convert("RGB")
c = self._image_captioner(img)[0]["generated_text"]
captions.append(c)
except Exception:
captions.append("")
finally:
try:
os.remove(p)
except Exception:
pass
# cleanup
try:
os.remove(vid_path)
os.remove(audio_path)
except Exception:
pass
return {"duration": duration, "fps": fps, "transcription": transcribed, "captions": captions}
# ----------------------
# File processing (PDF, DOCX, TXT, CSV)
# ----------------------
async def process_file(self, file_obj, user_id: int) -> dict:
"""
        Reads a file and extracts text (supports PDF/TXT/CSV, plus DOCX when python-docx is installed),
and returns a short summary.
"""
        # Save the file, keeping the original extension when the upload object exposes one
        # (e.g. a Telegram Document's file_name); the type checks below rely on the suffix,
        # and a bare temp path has none.
        suffix = os.path.splitext(getattr(file_obj, "file_name", "") or "")[1]
        fpath = _tmp_path(suffix)
        tf = await file_obj.get_file()
        await tf.download_to_drive(fpath)
        lower = fpath.lower()
        text = ""
        if lower.endswith(".pdf"):
try:
doc = fitz.open(fpath)
for page in doc:
text += page.get_text()
except Exception as e:
text = f"[PDF read error: {e}]"
        elif lower.endswith((".txt", ".csv")):
try:
with open(fpath, "r", encoding="utf-8", errors="ignore") as fh:
text = fh.read()
except Exception as e:
text = f"[File read error: {e}]"
elif fpath.endswith(".docx"):
try:
import docx
doc = docx.Document(fpath)
text = "\n".join([p.text for p in doc.paragraphs])
except Exception as e:
text = f"[DOCX read error: {e}]"
else:
text = "[Unsupported file type]"
# Summarize: simple heuristic or use translator/chat model to summarize (but that costs compute)
summary = text[:300] + ("..." if len(text) > 300 else "")
try:
os.remove(fpath)
except Exception:
pass
return {"summary": summary, "full_text_length": len(text)}
# ----------------------
# Code assistance: generate / explain code
# ----------------------
async def code_complete(self, prompt: str, max_tokens: int = 512, temperature: float = 0.2) -> str:
"""
Uses a code LLM (StarCoder or similar) to complete or generate code.
"""
self._load_code_model()
if not self._code_model or not self._code_tokenizer:
raise RuntimeError("Code model not available.")
input_ids = self._code_tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
        # Honour the temperature argument: sample when it is positive, otherwise decode greedily.
        if temperature and temperature > 0:
            gen = self._code_model.generate(input_ids, max_new_tokens=max_tokens, do_sample=True, temperature=temperature)
        else:
            gen = self._code_model.generate(input_ids, max_new_tokens=max_tokens, do_sample=False)
out = self._code_tokenizer.decode(gen[0], skip_special_tokens=True)
return out
# ----------------------
# Optional: execute Python code in sandbox (WARNING: security risk)
# ----------------------
async def execute_python_code(self, code: str, timeout: int = 5) -> dict:
"""
Execute Python code in a very limited sandbox subprocess.
WARNING: Running arbitrary code is dangerous. Use only with trusted inputs or stronger sandboxing (containers).
"""
# Create temp dir
d = tempfile.mkdtemp()
file_path = os.path.join(d, "main.py")
with open(file_path, "w", encoding="utf-8") as f:
f.write(code)
# run with timeout
try:
            proc = await asyncio.create_subprocess_exec(
                # Use the current interpreter so the sandbox also works inside venvs/containers.
                sys.executable or "python3", file_path,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
try:
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=timeout)
            except asyncio.TimeoutError:
                proc.kill()
                await proc.communicate()  # reap the killed process
                return {"error": "Execution timed out"}
return {"stdout": stdout.decode("utf-8", errors="ignore"), "stderr": stderr.decode("utf-8", errors="ignore")}
finally:
try:
shutil.rmtree(d)
except Exception:
pass
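# ----------------------
# Minimal usage sketch (not part of the module API)
# ----------------------
# A hedged example of how the async methods are meant to be driven with asyncio.
# It assumes the configured checkpoints can be downloaded on first use, and it uses a
# tiny local stand-in for the Telegram-style upload objects (anything exposing async
# get_file()/download_to_drive() works, since that is all the media handlers rely on).
if __name__ == "__main__":
    class _LocalUpload:
        """Wrap a local file path so it looks like the upload objects used above."""
        def __init__(self, path: str, file_name: Optional[str] = None):
            self._path = path
            self.file_name = file_name or os.path.basename(path)

        async def get_file(self):
            return self

        async def download_to_drive(self, dest: str):
            shutil.copy(self._path, dest)

    async def _demo():
        mm = MultiModalChatModule()
        # Text chat: lazy-loads the chatbot and translation models on first call.
        reply = await mm.generate_response("Hello! What can you do?", user_id=0, lang="en")
        print("Bot reply:", reply)
        # File summarisation through the local stand-in upload object.
        sample = _tmp_path(".txt")
        with open(sample, "w", encoding="utf-8") as fh:
            fh.write("This is a small sample document used only for the demo.")
        print("File summary:", await mm.process_file(_LocalUpload(sample), user_id=0))
        # Sandboxed code execution (see the security warning on execute_python_code).
        print("Sandbox output:", await mm.execute_python_code("print(2 + 2)"))

    asyncio.run(_demo())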