Princeaka committed on
Commit d45eebc · verified · 1 Parent(s): 5ebcc12

Update multimodal_module.py

Files changed (1)
  1. multimodal_module.py +645 -644
multimodal_module.py CHANGED
@@ -1,644 +1,645 @@
# multimodal_module.py
import os
import pickle
import subprocess
import tempfile
import shutil
import asyncio
from datetime import datetime
+ from huggingface_hub import hf_hub_download, snapshot_download
from typing import Dict, List, Optional, Any
import io
import uuid

# Core ML libs
import torch
from transformers import (
    pipeline,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Wav2Vec2Processor,
    Wav2Vec2ForSequenceClassification,
)
from diffusers import StableDiffusionPipeline, StableDiffusionInpaintPipeline
from transformers import AutoModelForCausalLM, AutoTokenizer as HFTokenizer

# Audio / speech
import librosa
import speech_recognition as sr
from gtts import gTTS

# Image, video, files
from PIL import Image, ImageOps
import imageio_ffmpeg as ffmpeg
import imageio
import moviepy.editor as mp
import fitz  # PyMuPDF for PDFs

# Misc
from langdetect import DetectorFactory
DetectorFactory.seed = 0

# Optional: safety-check toggles
USE_SAFETY_CHECKER = False

# Helper for temp files
def _tmp_path(suffix=""):
    return os.path.join(tempfile.gettempdir(), f"{uuid.uuid4().hex}{suffix}")

class MultiModalChatModule:
    """
    Full-power multimodal module.
    - Lazy-loads big models on first use.
    - Methods are async-friendly.
    """

    def __init__(self, chat_history_file: str = "chat_histories.pkl"):
        self.user_chat_histories: Dict[int, List[dict]] = self._load_chat_histories(chat_history_file)
        self.chat_history_file = chat_history_file

        # device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"[MultiModal] device: {self.device}")

        # placeholders for large models (lazy)
        self._voice_processor = None
        self._voice_emotion_model = None

        self._translator = None

        self._chat_tokenizer = None
        self._chat_model = None
        self._chat_model_name = "bigscience/bloom"  # unused placeholder; the active chat model is set in self.model_names below

        self._image_captioner = None

        self._sd_pipe = None
        self._sd_inpaint = None

        self._code_tokenizer = None
        self._code_model = None

        # other small helpers
        self._sr_recognizer = sr.Recognizer()

        # set common model names (you can change)
        self.model_names = {
            "voice_emotion_processor": "facebook/hubert-large-ls960-ft",
            "voice_emotion_model": "superb/hubert-base-superb-er",
            "translation_model": "facebook/nllb-200-distilled-600M",
            "chatbot_tokenizer": "facebook/blenderbot-400M-distill",
            "chatbot_model": "facebook/blenderbot-400M-distill",
            "image_captioner": "Salesforce/blip-image-captioning-base",
            "sd_inpaint": "runwayml/stable-diffusion-inpainting",
            "sd_text2img": "runwayml/stable-diffusion-v1-5",
            "code_model": "bigcode/starcoder",  # or use a specific StarCoder checkpoint on HF
        }

        # keep track of which heavy groups are loaded
        self._loaded = {
            "voice": False,
            "translation": False,
            "chat": False,
            "image_caption": False,
            "sd": False,
            "code": False,
        }

    # ----------------------
    # persistence
    # ----------------------
    def _load_chat_histories(self, fn: str) -> Dict[int, List[dict]]:
        try:
            with open(fn, "rb") as f:
                return pickle.load(f)
        except Exception:
            return {}

    def _save_chat_histories(self):
        try:
            with open(self.chat_history_file, "wb") as f:
                pickle.dump(self.user_chat_histories, f)
        except Exception as e:
            print("[MultiModal] Warning: failed to save chat histories:", e)

    # ----------------------
    # Lazy loaders
    # ----------------------
    def _load_voice_models(self):
        if self._loaded["voice"]:
            return
        print("[MultiModal] Loading voice/emotion models...")
        self._voice_processor = Wav2Vec2Processor.from_pretrained(self.model_names["voice_emotion_processor"])
        self._voice_emotion_model = Wav2Vec2ForSequenceClassification.from_pretrained(self.model_names["voice_emotion_model"]).to(self.device)
        self._loaded["voice"] = True
        print("[MultiModal] Voice models loaded.")

    def _load_translation(self):
        if self._loaded["translation"]:
            return
        print("[MultiModal] Loading translation pipeline...")
        device_idx = 0 if self.device == "cuda" else -1
        self._translator = pipeline("translation", model=self.model_names["translation_model"], device=device_idx)
        self._loaded["translation"] = True
        print("[MultiModal] Translation loaded.")

    def _load_chatbot(self):
        if self._loaded["chat"]:
            return
        print("[MultiModal] Loading chatbot model...")
        # chatbot: keep current blenderbot to preserve behaviour
        self._chat_tokenizer = AutoTokenizer.from_pretrained(self.model_names["chatbot_tokenizer"])
        self._chat_model = AutoModelForSeq2SeqLM.from_pretrained(self.model_names["chatbot_model"]).to(self.device)
        self._loaded["chat"] = True
        print("[MultiModal] Chatbot loaded.")

    def _load_image_captioner(self):
        if self._loaded["image_caption"]:
            return
        print("[MultiModal] Loading image captioner...")
        device_idx = 0 if self.device == "cuda" else -1
        self._image_captioner = pipeline("image-to-text", model=self.model_names["image_captioner"], device=device_idx)
        self._loaded["image_caption"] = True
        print("[MultiModal] Image captioner loaded.")

    def _load_sd(self):
        if self._loaded["sd"]:
            return
        print("[MultiModal] Loading Stable Diffusion pipelines...")
        # text2img
        sd_model = self.model_names["sd_text2img"]
        sd_inpaint_model = self.model_names["sd_inpaint"]
        # Use float16 on GPU for speed
        torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
        try:
            self._sd_pipe = StableDiffusionPipeline.from_pretrained(sd_model, torch_dtype=torch_dtype)
            self._sd_pipe = self._sd_pipe.to(self.device)
        except Exception as e:
            print("[MultiModal] Warning loading text2img:", e)
            self._sd_pipe = None

        try:
            self._sd_inpaint = StableDiffusionInpaintPipeline.from_pretrained(sd_inpaint_model, torch_dtype=torch_dtype)
            self._sd_inpaint = self._sd_inpaint.to(self.device)
        except Exception as e:
            print("[MultiModal] Warning loading inpaint:", e)
            self._sd_inpaint = None

        self._loaded["sd"] = True
        print("[MultiModal] Stable Diffusion loaded (where possible).")

    def _load_code_model(self):
        if self._loaded["code"]:
            return
        print("[MultiModal] Loading code model...")
        # StarCoder-style model (may require HF_TOKEN or large memory)
        try:
            self._code_tokenizer = HFTokenizer.from_pretrained(self.model_names["code_model"])
            self._code_model = AutoModelForCausalLM.from_pretrained(self.model_names["code_model"]).to(self.device)
            self._loaded["code"] = True
            print("[MultiModal] Code model loaded.")
        except Exception as e:
            print("[MultiModal] Warning: could not load code model:", e)
            self._code_tokenizer = None
            self._code_model = None

    # ----------------------
    # Voice: analyze emotion, transcribe
    # ----------------------
    async def analyze_voice_emotion(self, audio_path: str) -> str:
        self._load_voice_models()
        speech, sr_ = librosa.load(audio_path, sr=16000)
        inputs = self._voice_processor(speech, sampling_rate=sr_, return_tensors="pt", padding=True).to(self.device)
        with torch.no_grad():
            logits = self._voice_emotion_model(**inputs).logits
        predicted_class = torch.argmax(logits).item()
        return {
            0: "😊 Happy",
            1: "😢 Sad",
            2: "😠 Angry",
            3: "😨 Fearful",
            4: "😌 Calm",
            5: "😲 Surprised",
        }.get(predicted_class, "🤔 Unknown")

    async def process_voice_message(self, voice_file, user_id: int) -> dict:
        """
        voice_file: an upload object exposing get_file()/download_to_drive()
        (python-telegram-bot style).
        Returns: {text, language, emotion}
        """
        # Save OGG locally
        ogg_path = _tmp_path(".ogg")
        wav_path = _tmp_path(".wav")
        tf = await voice_file.get_file()
        await tf.download_to_drive(ogg_path)

        # Convert to WAV via ffmpeg
        try:
            ffmpeg_path = ffmpeg.get_ffmpeg_exe()
            subprocess.run([ffmpeg_path, "-y", "-i", ogg_path, wav_path], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except Exception as e:
            # fallback: try ffmpeg in PATH
            try:
                subprocess.run(["ffmpeg", "-y", "-i", ogg_path, wav_path], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            except Exception as ee:
                raise RuntimeError(f"ffmpeg conversion failed: {e} / {ee}")

        # Transcribe using SpeechRecognition's Google STT (as before); Whisper could be integrated instead
        recognizer = self._sr_recognizer
        with sr.AudioFile(wav_path) as source:
            audio = recognizer.record(source)

        detected_lang = None
        detected_text = ""
        # languages to try for STT, in order
        lang_map = {
            "zh": {"stt": "zh-CN"},
            "ja": {"stt": "ja-JP"},
            "ko": {"stt": "ko-KR"},
            "en": {"stt": "en-US"},
            "es": {"stt": "es-ES"},
            "fr": {"stt": "fr-FR"},
            "de": {"stt": "de-DE"},
            "it": {"stt": "it-IT"},
        }
        for lang_code, lang_data in lang_map.items():
            try:
                detected_text = recognizer.recognize_google(audio, language=lang_data["stt"])
                detected_lang = lang_code
                break
            except sr.UnknownValueError:
                continue
            except Exception:
                continue

        if not detected_lang:
            # nothing was recognized: fall back to English with empty text
            detected_lang = "en"
            detected_text = ""

        # emotion
        emotion = await self.analyze_voice_emotion(wav_path)

        # remove temp files
        try:
            os.remove(ogg_path)
            os.remove(wav_path)
        except Exception:
            pass

        return {"text": detected_text, "language": detected_lang, "emotion": emotion}

    # ----------------------
    # Text chat with translation & history
    # ----------------------
    async def generate_response(self, text: str, user_id: int, lang: str = "en") -> str:
        # Ensure chat model loaded
        self._load_chatbot()
        self._load_translation()

        if user_id not in self.user_chat_histories:
            self.user_chat_histories[user_id] = []

        self.user_chat_histories[user_id].append({"timestamp": datetime.now().isoformat(), "role": "user", "text": text, "language": lang})
        self.user_chat_histories[user_id] = self.user_chat_histories[user_id][-100:]
        self._save_chat_histories()

        # Build context: translate last few msgs to English for consistency
        context_texts = []
        for msg in self.user_chat_histories[user_id][-5:]:
            if msg.get("language", "en") != "en":
                try:
                    translated = self._translator(msg["text"])[0]["translation_text"]
                except Exception:
                    translated = msg["text"]
            else:
                translated = msg["text"]
            context_texts.append(f"{msg['role']}: {translated}")

        context = "\n".join(context_texts)
        input_text = f"Context:\n{context}\nUser: {text if lang == 'en' else context_texts[-1].split(': ', 1)[1]}"

        # Tokenize + generate
        inputs = self._chat_tokenizer.encode(input_text, return_tensors="pt").to(self.device)
        outputs = self._chat_model.generate(inputs, max_length=1000)
        response_en = self._chat_tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Translate back to user's language if needed
        if lang != "en":
            try:
                response = self._translator(response_en)[0]["translation_text"]
            except Exception:
                response = response_en
        else:
            response = response_en

        self.user_chat_histories[user_id].append({"timestamp": datetime.now().isoformat(), "role": "bot", "text": response, "language": lang})
        self._save_chat_histories()

        return response

    # ----------------------
    # Image captioning (existing)
    # ----------------------
    async def process_image_message(self, image_file, user_id: int) -> str:
        # Save image
        img_path = _tmp_path(".jpg")
        tf = await image_file.get_file()
        await tf.download_to_drive(img_path)

        # load captioner
        self._load_image_captioner()
        try:
            image = Image.open(img_path).convert("RGB")
            description = self._image_captioner(image)[0]["generated_text"]
        except Exception as e:
            description = f"[Error generating caption: {e}]"

        # cleanup
        try:
            os.remove(img_path)
        except Exception:
            pass

        # store in history
        if user_id not in self.user_chat_histories:
            self.user_chat_histories[user_id] = []
        self.user_chat_histories[user_id].append({"timestamp": datetime.now().isoformat(), "role": "user", "text": "[Image]", "language": "en"})
        self.user_chat_histories[user_id].append({"timestamp": datetime.now().isoformat(), "role": "bot", "text": f"Image description: {description}", "language": "en"})
        self._save_chat_histories()

        return description

    # ----------------------
    # Voice reply (TTS)
    # ----------------------
    async def generate_voice_reply(self, text: str, user_id: int, fmt: str = "ogg") -> str:
        """
        Generate a TTS audio reply using gTTS (or swap in another TTS if you have one).
        Returns the path to the audio file.
        """
        mp3_path = _tmp_path(".mp3")
        out_path = _tmp_path(f".{fmt}")

        try:
            tts = gTTS(text)
            tts.save(mp3_path)
            # convert to requested format using ffmpeg (ogg/opus for Telegram voice)
            ffmpeg_path = ffmpeg.get_ffmpeg_exe()
            if fmt == "ogg":
                # convert mp3 -> ogg (opus)
                subprocess.run([ffmpeg_path, "-y", "-i", mp3_path, "-c:a", "libopus", out_path], check=True)
            elif fmt == "wav":
                subprocess.run([ffmpeg_path, "-y", "-i", mp3_path, out_path], check=True)
            else:
                # default: return mp3
                shutil.move(mp3_path, out_path)
        except Exception as e:
            raise RuntimeError(f"TTS failed: {e}")
        finally:
            try:
                if os.path.exists(mp3_path) and os.path.exists(out_path) and mp3_path != out_path:
                    os.remove(mp3_path)
            except Exception:
                pass

        return out_path

    # ----------------------
    # Image generation (text -> image)
    # ----------------------
    async def generate_image_from_text(self, prompt: str, user_id: int, width: int = 512, height: int = 512, steps: int = 30) -> str:
        self._load_sd()
        if self._sd_pipe is None:
            raise RuntimeError("Stable Diffusion pipeline not available.")

        out_path = _tmp_path(".png")
        try:
            # diffusion pipeline uses CPU/GPU internally
            result = self._sd_pipe(prompt, num_inference_steps=steps, height=height, width=width)
            image = result.images[0]
            image.save(out_path)
        except Exception as e:
            raise RuntimeError(f"Image generation failed: {e}")

        return out_path

    # ----------------------
    # Image editing (inpainting)
    # ----------------------
    async def edit_image_inpaint(self, image_file, mask_file=None, prompt: str = "", user_id: int = 0) -> str:
        self._load_sd()
        if self._sd_inpaint is None:
            raise RuntimeError("Inpainting pipeline not available.")

        # Save files
        img_path = _tmp_path(".png")
        tf = await image_file.get_file()
        await tf.download_to_drive(img_path)

        if mask_file:
            mask_path = _tmp_path(".png")
            m_tf = await mask_file.get_file()
            await m_tf.download_to_drive(mask_path)
            mask_image = Image.open(mask_path).convert("L")
        else:
            # default mask (edit entire image)
            mask_image = Image.new("L", Image.open(img_path).size, color=255)
            mask_path = None

        init_image = Image.open(img_path).convert("RGB")
        # run inpaint
        out_path = _tmp_path(".png")
        try:
            result = self._sd_inpaint(prompt=prompt if prompt else " ", image=init_image, mask_image=mask_image, guidance_scale=7.5, num_inference_steps=30)
            edited = result.images[0]
            edited.save(out_path)
        except Exception as e:
            raise RuntimeError(f"Inpainting failed: {e}")
        finally:
            try:
                os.remove(img_path)
                if mask_path:
                    os.remove(mask_path)
            except Exception:
                pass

        return out_path

    # ----------------------
    # Video processing: extract audio, frames, summarize
    # ----------------------
    async def process_video(self, video_file, user_id: int, max_frames: int = 4) -> dict:
        """
        Accepts an uploaded video file, extracts audio (for transcription) and sample frames,
        and returns a summary: {duration, fps, transcription, captions}
        """
        vid_path = _tmp_path(".mp4")
        tf = await video_file.get_file()
        await tf.download_to_drive(vid_path)

        # Extract audio
        audio_path = _tmp_path(".wav")
        try:
            clip = mp.VideoFileClip(vid_path)
            clip.audio.write_audiofile(audio_path, logger=None)
            duration = clip.duration
            fps = clip.fps
            clip.close()
        except Exception as e:
            raise RuntimeError(f"Video processing failed: {e}")

        # Transcribe audio using the same flow as process_voice_message:
        # SpeechRecognition for now; Whisper could be integrated instead
        recognizer = sr.Recognizer()
        with sr.AudioFile(audio_path) as source:
            audio = recognizer.record(source)
        transcribed = ""
        try:
            transcribed = recognizer.recognize_google(audio)
        except Exception:
            transcribed = ""

        # Extract a few frames evenly
        frames = []
        try:
            clip_reader = imageio.get_reader(vid_path, "ffmpeg")
            total_frames = clip_reader.count_frames()
            step = max(1, total_frames // max_frames)
            for i in range(0, total_frames, step):
                try:
                    frame = clip_reader.get_data(i)
                    pil = Image.fromarray(frame)
                    ppath = _tmp_path(".jpg")
                    pil.save(ppath)
                    frames.append(ppath)
                    if len(frames) >= max_frames:
                        break
                except Exception:
                    continue
            clip_reader.close()
        except Exception:
            pass

        # Use the image captioner on the frames
        captions = []
        if frames:
            self._load_image_captioner()
            for p in frames:
                try:
                    img = Image.open(p).convert("RGB")
                    c = self._image_captioner(img)[0]["generated_text"]
                    captions.append(c)
                except Exception:
                    captions.append("")
                finally:
                    try:
                        os.remove(p)
                    except Exception:
                        pass

        # cleanup
        try:
            os.remove(vid_path)
            os.remove(audio_path)
        except Exception:
            pass

        return {"duration": duration, "fps": fps, "transcription": transcribed, "captions": captions}

    # ----------------------
    # File processing (PDF, DOCX, TXT, CSV)
    # ----------------------
    async def process_file(self, file_obj, user_id: int) -> dict:
        """
        Reads a file, extracts text (supports PDF/TXT/CSV, and DOCX if python-docx is installed),
        and returns a short summary.
        """
        # Save file, preserving the upload's extension when available
        # (e.g. a Telegram Document's file_name) so the type checks below can match
        suffix = os.path.splitext(getattr(file_obj, "file_name", "") or "")[1]
        fpath = _tmp_path(suffix)
        tf = await file_obj.get_file()
        await tf.download_to_drive(fpath)
        lower = fpath.lower()

        text = ""
        if lower.endswith(".pdf"):
            try:
                doc = fitz.open(fpath)
                for page in doc:
                    text += page.get_text()
            except Exception as e:
                text = f"[PDF read error: {e}]"
        elif lower.endswith((".txt", ".csv")):
            try:
                with open(fpath, "r", encoding="utf-8", errors="ignore") as fh:
                    text = fh.read()
            except Exception as e:
                text = f"[File read error: {e}]"
        elif lower.endswith(".docx"):
            try:
                import docx
                doc = docx.Document(fpath)
                text = "\n".join([p.text for p in doc.paragraphs])
            except Exception as e:
                text = f"[DOCX read error: {e}]"
        else:
            text = "[Unsupported file type]"

        # Summarize: simple heuristic; the translator/chat model could summarize instead, at extra compute cost
        summary = text[:300] + ("..." if len(text) > 300 else "")
        try:
            os.remove(fpath)
        except Exception:
            pass

        return {"summary": summary, "full_text_length": len(text)}

    # ----------------------
    # Code assistance: generate / explain code
    # ----------------------
    async def code_complete(self, prompt: str, max_tokens: int = 512, temperature: float = 0.2) -> str:
        """
        Uses a code LLM (StarCoder or similar) to complete or generate code.
        Decoding is greedy (do_sample=False), so `temperature` is currently unused.
        """
        self._load_code_model()
        if not self._code_model or not self._code_tokenizer:
            raise RuntimeError("Code model not available.")

        input_ids = self._code_tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
        gen = self._code_model.generate(input_ids, max_new_tokens=max_tokens, do_sample=False)
        out = self._code_tokenizer.decode(gen[0], skip_special_tokens=True)
        return out

    # ----------------------
    # Optional: execute Python code in sandbox (WARNING: security risk)
    # ----------------------
    async def execute_python_code(self, code: str, timeout: int = 5) -> dict:
        """
        Execute Python code in a very limited sandbox subprocess.
        WARNING: Running arbitrary code is dangerous. Use only with trusted inputs or stronger sandboxing (containers).
        """
        # Create temp dir
        d = tempfile.mkdtemp()
        file_path = os.path.join(d, "main.py")
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(code)

        # run with timeout
        try:
            proc = await asyncio.create_subprocess_exec(
                "python3", file_path,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            try:
                stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=timeout)
            except asyncio.TimeoutError:
                proc.kill()
                await proc.wait()  # reap the killed process before returning
                return {"error": "Execution timed out"}
            return {"stdout": stdout.decode("utf-8", errors="ignore"), "stderr": stderr.decode("utf-8", errors="ignore")}
        finally:
            try:
                shutil.rmtree(d)
            except Exception:
                pass
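
A minimal usage sketch (illustrative, not part of this commit), assuming the file above is importable as multimodal_module and that the lazily loaded checkpoints can be fetched on first use:

import asyncio
from multimodal_module import MultiModalChatModule

async def main():
    mm = MultiModalChatModule()
    # First call lazily loads BlenderBot (plus the NLLB translator)
    reply = await mm.generate_response("Hello there!", user_id=42, lang="en")
    print("bot:", reply)
    # Synthesize the reply as an ogg/opus voice file via gTTS + ffmpeg
    voice_path = await mm.generate_voice_reply(reply, user_id=42, fmt="ogg")
    print("voice file:", voice_path)

asyncio.run(main())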