"""Multimodal AI helpers: image captioning, OCR, QR codes, document scanning,
mind maps, speech-to-text, text-to-speech, language detection, image tagging,
and diagram generation."""

import os
import torch
import numpy as np
import time
import io
import base64
from typing import Dict, List, Any, Union, Optional, Tuple
from pathlib import Path
from PIL import Image
import qrcode
import cv2
import matplotlib.pyplot as plt
from transformers import BlipProcessor, BlipForConditionalGeneration, WhisperProcessor, WhisperForConditionalGeneration
from gtts import gTTS

from utils.config import AI_MODELS
from utils.logging import get_logger, log_performance, log_ai_model_usage
from utils.error_handling import handle_ai_model_exceptions, AIModelError, ValidationError

logger = get_logger(__name__)

# Pin inference to the first GPU when CUDA is available; hide GPUs otherwise.
os.environ["CUDA_VISIBLE_DEVICES"] = "" if not torch.cuda.is_available() else "0"

# Shared model loader / cache.
from utils.ai_models import MODEL_CACHE, get_model


@handle_ai_model_exceptions
def analyze_image(image, max_length: Optional[int] = None) -> str:
    """
    Generate a caption for an image using the BLIP model.

    Args:
        image: PIL Image, numpy array, or path to an image file
        max_length: Maximum caption length (uses the config default if None)

    Returns:
        Generated caption

    Raises:
        AIModelError: If there's an error generating the caption
    """
    task = "image_captioning"
    model_config = AI_MODELS[task]
    model_name = model_config["name"]

    if max_length is None:
        max_length = model_config.get("max_length", 50)

    logger.debug("Generating caption for image")
    start_time = time.time()

    model, processor = get_model(task)

    try:
        # Normalise the input to a PIL RGB image.
        if isinstance(image, (str, Path)):
            image = Image.open(image).convert('RGB')
        elif not isinstance(image, Image.Image):
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)
            else:
                raise ValidationError("Unsupported image format")

        inputs = processor(image, return_tensors="pt")

        # Generate the caption with beam search.
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_length=max_length,
                num_beams=5,
                early_stopping=True
            )

        caption = processor.decode(output[0], skip_special_tokens=True)

        elapsed_time = (time.time() - start_time) * 1000
        log_performance("analyze_image", elapsed_time)
        log_ai_model_usage(model_name, "image_captioning", len(output[0]))

        logger.debug(f"Image caption generated successfully in {elapsed_time:.2f}ms")
        return caption
    except Exception as e:
        logger.error(f"Error generating image caption: {str(e)}")
        raise AIModelError("Error generating image caption", {"original_error": str(e)}) from e


@handle_ai_model_exceptions
def extract_text_from_image(image) -> str:
    """
    Extract text from an image using OCR (EasyOCR).

    Args:
        image: PIL Image, numpy array, or path to an image file

    Returns:
        Extracted text

    Raises:
        AIModelError: If there's an error extracting text
    """
    logger.debug("Extracting text from image")
    start_time = time.time()

    try:
        # easyocr is imported lazily (only when OCR is actually requested).
        import easyocr

        # Resolve the input to an image file path for the OCR reader,
        # persisting in-memory images to a temporary file.
        if isinstance(image, (str, Path)):
            image_path = str(image)
            image = Image.open(image).convert('RGB')
        elif isinstance(image, Image.Image):
            temp_path = os.path.join(os.path.dirname(__file__), "temp_ocr_image.jpg")
            image.save(temp_path)
            image_path = temp_path
        elif isinstance(image, np.ndarray):
            temp_path = os.path.join(os.path.dirname(__file__), "temp_ocr_image.jpg")
            cv2.imwrite(temp_path, image)
            image_path = temp_path
        else:
            raise ValidationError("Unsupported image format")

        reader = easyocr.Reader(['en'])
        results = reader.readtext(image_path)

        # Each result is (bounding_box, text, confidence); keep the text only.
        extracted_text = "\n".join([result[1] for result in results])

        # Clean up the temporary file if one was created.
        if 'temp_path' in locals() and os.path.exists(temp_path):
            os.remove(temp_path)

        elapsed_time = (time.time() - start_time) * 1000
        log_performance("extract_text_from_image", elapsed_time)

        logger.debug(f"Text extracted successfully in {elapsed_time:.2f}ms")
        return extracted_text
    except Exception as e:
        logger.error(f"Error extracting text from image: {str(e)}")
        raise AIModelError("Error extracting text from image", {"original_error": str(e)}) from e


@handle_ai_model_exceptions
def generate_qr_code(data: str, box_size: int = 10, border: int = 4) -> Image.Image:
    """
    Generate a QR code from text data

    Args:
        data: Text data to encode in the QR code
        box_size: Size of each box in the QR code
        border: Border size of the QR code

    Returns:
        PIL Image containing the QR code

    Raises:
        AIModelError: If there's an error generating the QR code
    """
    logger.debug(f"Generating QR code for data: {data[:20]}...")
    start_time = time.time()

    try:
        qr = qrcode.QRCode(
            version=1,
            error_correction=qrcode.constants.ERROR_CORRECT_L,
            box_size=box_size,
            border=border,
        )

        qr.add_data(data)
        qr.make(fit=True)

        img = qr.make_image(fill_color="black", back_color="white")

        elapsed_time = (time.time() - start_time) * 1000
        log_performance("generate_qr_code", elapsed_time)

        logger.debug(f"QR code generated successfully in {elapsed_time:.2f}ms")
        return img
    except Exception as e:
        logger.error(f"Error generating QR code: {str(e)}")
        raise AIModelError("Error generating QR code", {"original_error": str(e)}) from e


@handle_ai_model_exceptions
def scan_document(image) -> Dict[str, Any]:
    """
    Scan a document from an image, extract text and detect document boundaries

    Args:
        image: PIL Image, numpy array, or path to an image file

    Returns:
        Dictionary with extracted text, processed image, and detection flag

    Raises:
        AIModelError: If there's an error scanning the document
    """
    logger.debug("Scanning document from image")
    start_time = time.time()

    try:
        # Normalise the input to an OpenCV BGR array.
        if isinstance(image, (str, Path)):
            img = cv2.imread(str(image))
        elif isinstance(image, Image.Image):
            img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        elif isinstance(image, np.ndarray):
            img = image
        else:
            raise ValidationError("Unsupported image format")

        # Edge detection pipeline: grayscale -> Gaussian blur -> Canny.
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        blur = cv2.GaussianBlur(gray, (5, 5), 0)
        edges = cv2.Canny(blur, 75, 200)

        # Find contours and inspect them from largest to smallest.
        contours, _ = cv2.findContours(edges, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
        contours = sorted(contours, key=cv2.contourArea, reverse=True)

        doc_contour = None

        # The first contour that approximates to four points is taken as the page outline.
        for contour in contours:
            perimeter = cv2.arcLength(contour, True)
            approx = cv2.approxPolyDP(contour, 0.02 * perimeter, True)

            if len(approx) == 4:
                doc_contour = approx
                break

        if doc_contour is not None:
            # Highlight the detected document boundary on a copy of the image.
            img_with_contour = img.copy()
            cv2.drawContours(img_with_contour, [doc_contour], -1, (0, 255, 0), 2)

            processed_img = Image.fromarray(cv2.cvtColor(img_with_contour, cv2.COLOR_BGR2RGB))
        else:
            processed_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

        # Run OCR on the grayscale image.
        extracted_text = extract_text_from_image(gray)

        elapsed_time = (time.time() - start_time) * 1000
        log_performance("scan_document", elapsed_time)

        logger.debug(f"Document scanned successfully in {elapsed_time:.2f}ms")

        return {
            "text": extracted_text,
            "processed_image": processed_img,
            "document_detected": doc_contour is not None
        }
    except Exception as e:
        logger.error(f"Error scanning document: {str(e)}")
        raise AIModelError("Error scanning document", {"original_error": str(e)}) from e


@handle_ai_model_exceptions
def generate_mind_map(topics: List[str], connections: Optional[List[Tuple[int, int]]] = None) -> Image.Image:
    """
    Generate a mind map visualization from topics and their connections

    Args:
        topics: List of topic strings
        connections: List of (source, target) index pairs connecting topics

    Returns:
        PIL Image containing the mind map

    Raises:
        AIModelError: If there's an error generating the mind map
    """
    logger.debug(f"Generating mind map with {len(topics)} topics")
    start_time = time.time()

    try:
        plt.figure(figsize=(12, 8))

        # Default layout: treat the first topic as the central node and connect
        # every other topic to it.
        if connections is None:
            connections = [(0, i) for i in range(1, len(topics))]

        import networkx as nx
        G = nx.Graph()

        for i, topic in enumerate(topics):
            G.add_node(i, label=topic)

        for source, target in connections:
            G.add_edge(source, target)

        # A fixed seed keeps the layout reproducible between calls.
        pos = nx.spring_layout(G, seed=42)

        nx.draw_networkx_nodes(G, pos, node_size=2000, node_color='skyblue', alpha=0.8)
        nx.draw_networkx_edges(G, pos, width=2, alpha=0.5, edge_color='gray')

        labels = {i: data['label'] for i, data in G.nodes(data=True)}
        nx.draw_networkx_labels(G, pos, labels, font_size=10, font_weight='bold')

        plt.axis('off')

        # Render the figure into an in-memory PNG and hand it back as a PIL image.
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
        buf.seek(0)

        mind_map_img = Image.open(buf)

        plt.close()

        elapsed_time = (time.time() - start_time) * 1000
        log_performance("generate_mind_map", elapsed_time)

        logger.debug(f"Mind map generated successfully in {elapsed_time:.2f}ms")
        return mind_map_img
    except Exception as e:
        logger.error(f"Error generating mind map: {str(e)}")
        raise AIModelError("Error generating mind map", {"original_error": str(e)}) from e


@handle_ai_model_exceptions
def transcribe_speech(audio_file) -> str:
    """
    Transcribe speech from an audio file using the Whisper model.

    Args:
        audio_file: Path to the audio file

    Returns:
        Transcribed text

    Raises:
        AIModelError: If there's an error transcribing the speech
    """
    task = "speech_to_text"
    model_name = AI_MODELS[task]["name"]

    logger.debug("Transcribing speech from audio file")
    start_time = time.time()

    model, processor = get_model(task)

    try:
        # Whisper expects 16 kHz mono audio, so resample on load.
        if isinstance(audio_file, (str, Path)):
            import librosa
            audio_array, sampling_rate = librosa.load(audio_file, sr=16000)
        else:
            raise ValidationError("Unsupported audio format")

        input_features = processor(audio_array, sampling_rate=16000, return_tensors="pt").input_features

        with torch.no_grad():
            predicted_ids = model.generate(input_features)

        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

        elapsed_time = (time.time() - start_time) * 1000
        log_performance("transcribe_speech", elapsed_time)
        log_ai_model_usage(model_name, "speech_to_text", len(predicted_ids[0]))

        logger.debug(f"Speech transcribed successfully in {elapsed_time:.2f}ms")
        return transcription
    except Exception as e:
        logger.error(f"Error transcribing speech: {str(e)}")
        raise AIModelError("Error transcribing speech", {"original_error": str(e)}) from e


@handle_ai_model_exceptions
def text_to_speech(text: str, lang: str = 'en', slow: bool = False) -> bytes:
    """
    Convert text to speech using gTTS

    Args:
        text: Text to convert to speech
        lang: Language code (default: 'en')
        slow: Whether to speak slowly (default: False)

    Returns:
        MP3 audio data as bytes

    Raises:
        AIModelError: If there's an error converting text to speech
    """
    logger.debug(f"Converting text to speech: {text[:50]}...")
    start_time = time.time()

    try:
        # Stream the generated MP3 into an in-memory buffer instead of a file.
        audio_io = io.BytesIO()

        tts = gTTS(text=text, lang=lang, slow=slow)
        tts.write_to_fp(audio_io)

        audio_io.seek(0)
        audio_data = audio_io.read()

        elapsed_time = (time.time() - start_time) * 1000
        log_performance("text_to_speech", elapsed_time)

        logger.debug(f"Text converted to speech successfully in {elapsed_time:.2f}ms")
        return audio_data
    except Exception as e:
        logger.error(f"Error converting text to speech: {str(e)}")
        raise AIModelError("Error converting text to speech", {"original_error": str(e)}) from e


@handle_ai_model_exceptions
def detect_language(audio_file) -> str:
    """
    Detect the language spoken in an audio file.

    Args:
        audio_file: Path to the audio file

    Returns:
        Detected language code (e.g. 'en')

    Raises:
        AIModelError: If there's an error detecting the language
    """
    logger.debug("Detecting language from audio file")
    start_time = time.time()

    try:
        # Transcribe first, then detect the language of the transcribed text.
        transcription = transcribe_speech(audio_file)

        from langdetect import detect
        language_code = detect(transcription)

        elapsed_time = (time.time() - start_time) * 1000
        log_performance("detect_language", elapsed_time)

        logger.debug(f"Language detected successfully in {elapsed_time:.2f}ms: {language_code}")
        return language_code
    except Exception as e:
        logger.error(f"Error detecting language: {str(e)}")
        raise AIModelError("Error detecting language", {"original_error": str(e)}) from e


@handle_ai_model_exceptions
def tag_image(image) -> List[str]:
    """
    Generate tags for an image using image captioning and NLP

    Args:
        image: PIL Image, numpy array, or path to an image file

    Returns:
        List of tags

    Raises:
        AIModelError: If there's an error generating tags
    """
    logger.debug("Generating tags for image")
    start_time = time.time()

    try:
        # Caption the image first, then mine the caption for keywords.
        caption = analyze_image(image)

        import nltk
        from nltk.corpus import stopwords
        from nltk.tokenize import word_tokenize

        # Fetch the required NLTK resources on first use.
        try:
            nltk.data.find('tokenizers/punkt')
        except LookupError:
            nltk.download('punkt')

        try:
            nltk.data.find('corpora/stopwords')
        except LookupError:
            nltk.download('stopwords')

        tokens = word_tokenize(caption.lower())

        # Keep alphabetic, non-stopword tokens and drop duplicates.
        stop_words = set(stopwords.words('english'))
        filtered_tokens = [word for word in tokens if word.isalpha() and word not in stop_words]

        tags = list(set(filtered_tokens))

        elapsed_time = (time.time() - start_time) * 1000
        log_performance("tag_image", elapsed_time)

        logger.debug(f"Image tags generated successfully in {elapsed_time:.2f}ms")
        return tags
    except Exception as e:
        logger.error(f"Error generating image tags: {str(e)}")
        raise AIModelError("Error generating image tags", {"original_error": str(e)}) from e


@handle_ai_model_exceptions
def create_diagram(diagram_type: str, data: Dict[str, Any]) -> Image.Image:
    """
    Create a diagram based on the specified type and data

    Args:
        diagram_type: Type of diagram ('flowchart', 'bar_chart', or 'pie_chart')
        data: Data for the diagram

    Returns:
        PIL Image containing the diagram

    Raises:
        AIModelError: If there's an error creating the diagram
    """
    logger.debug(f"Creating {diagram_type} diagram")
    start_time = time.time()

    try:
        plt.figure(figsize=(12, 8))

        if diagram_type == 'flowchart':
            # Build a directed graph from the supplied nodes and edges.
            import networkx as nx
            G = nx.DiGraph()

            for node in data.get('nodes', []):
                G.add_node(node['id'], label=node.get('label', node['id']))

            for edge in data.get('edges', []):
                G.add_edge(edge['source'], edge['target'], label=edge.get('label', ''))

            pos = nx.spring_layout(G, seed=42)

            nx.draw_networkx_nodes(G, pos, node_size=2000, node_color='lightblue', alpha=0.8)
            nx.draw_networkx_edges(G, pos, width=2, alpha=0.5, edge_color='gray', arrowsize=20)

            labels = {node: attrs['label'] for node, attrs in G.nodes(data=True)}
            nx.draw_networkx_labels(G, pos, labels, font_size=10, font_weight='bold')

            edge_labels = {(u, v): d['label'] for u, v, d in G.edges(data=True) if 'label' in d}
            nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=8)

            plt.axis('off')

        elif diagram_type == 'bar_chart':
            plt.bar(data.get('labels', []), data.get('values', []), color=data.get('colors', 'blue'))
            plt.xlabel(data.get('x_label', ''))
            plt.ylabel(data.get('y_label', ''))
            plt.title(data.get('title', 'Bar Chart'))

        elif diagram_type == 'pie_chart':
            plt.pie(data.get('values', []), labels=data.get('labels', []), autopct='%1.1f%%',
                    shadow=data.get('shadow', False), startangle=data.get('start_angle', 90))
            plt.axis('equal')
            plt.title(data.get('title', 'Pie Chart'))

        else:
            raise ValidationError(f"Unsupported diagram type: {diagram_type}")

        # Render the figure into an in-memory PNG and hand it back as a PIL image.
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
        buf.seek(0)

        diagram_img = Image.open(buf)

        plt.close()

        elapsed_time = (time.time() - start_time) * 1000
        log_performance(f"create_{diagram_type}_diagram", elapsed_time)

        logger.debug(f"{diagram_type.capitalize()} diagram created successfully in {elapsed_time:.2f}ms")
        return diagram_img
    except Exception as e:
        logger.error(f"Error creating {diagram_type} diagram: {str(e)}")
        raise AIModelError(f"Error creating {diagram_type} diagram", {"original_error": str(e)}) from e