import os
import torch
import numpy as np
import time
import io
from typing import Dict, List, Any, Optional, Tuple
from pathlib import Path
from PIL import Image
import qrcode
import cv2
import matplotlib.pyplot as plt
from gtts import gTTS

from utils.config import AI_MODELS
from utils.logging import get_logger, log_performance, log_ai_model_usage
from utils.error_handling import handle_ai_model_exceptions, AIModelError, ValidationError

# Initialize logger
logger = get_logger(__name__)

# Hide GPUs from CUDA when none is usable. Note that this only takes effect
# if it runs before any library initializes CUDA in this process; once
# torch.cuda.is_available() has been called, the visible device list is fixed.
if not torch.cuda.is_available():
    os.environ["CUDA_VISIBLE_DEVICES"] = ""

# Model loading is delegated to the shared cache in ai_models
from utils.ai_models import get_model

@handle_ai_model_exceptions
def analyze_image(image, max_length: Optional[int] = None) -> str:
    """
    Generate a caption for an image using BLIP model
    
    Args:
        image: PIL Image or path to image file
        max_length: Maximum length of caption (uses config default if None)
        
    Returns:
        Generated caption
        
    Raises:
        AIModelError: If there's an error generating the caption
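
    Example (illustrative; "photo.jpg" is a placeholder path):
        >>> caption = analyze_image("photo.jpg")  # doctest: +SKIP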
    """
    task = "image_captioning"
    model_config = AI_MODELS[task]
    model_name = model_config["name"]
    
    # Use config defaults if not provided
    if max_length is None:
        max_length = model_config.get("max_length", 50)
    
    logger.debug(f"Generating caption for image")
    start_time = time.time()
    
    # Load the model and processor
    model, processor = get_model(task)
    
    try:
        # Convert to PIL Image if path is provided
        if isinstance(image, (str, Path)):
            image = Image.open(image).convert('RGB')
        elif not isinstance(image, Image.Image):
            # Try to convert from numpy array or other format
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)
            else:
                raise ValidationError("Unsupported image format")
        
        # Process the image
        inputs = processor(image, return_tensors="pt")
        
        # Generate caption
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_length=max_length,
                num_beams=5,
                early_stopping=True
            )
        
        # Decode the caption
        caption = processor.decode(output[0], skip_special_tokens=True)
        
        # Log performance and usage
        elapsed_time = (time.time() - start_time) * 1000  # Convert to ms
        log_performance("analyze_image", elapsed_time)
        log_ai_model_usage(model_name, "image_captioning", len(output[0]))
        
        logger.debug(f"Image caption generated successfully in {elapsed_time:.2f}ms")
        return caption
    except Exception as e:
        logger.error(f"Error generating image caption: {str(e)}")
        raise AIModelError(f"Error generating image caption", {"original_error": str(e)}) from e

@handle_ai_model_exceptions
def extract_text_from_image(image) -> str:
    """
    Extract text from an image using OCR
    
    Args:
        image: PIL Image or path to image file
        
    Returns:
        Extracted text
        
    Raises:
        AIModelError: If there's an error extracting text
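
    Example (illustrative; "scan.png" is a placeholder path):
        >>> text = extract_text_from_image("scan.png")  # doctest: +SKIP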
    """
    logger.debug(f"Extracting text from image")
    start_time = time.time()
    
    temp_path = None
    try:
        # Import EasyOCR here to avoid loading it unless needed
        import easyocr
        
        # Resolve the input to a file path that EasyOCR can read
        if isinstance(image, (str, Path)):
            image_path = str(image)
        elif isinstance(image, Image.Image):
            # Save PIL image to a temporary file
            temp_path = os.path.join(os.path.dirname(__file__), "temp_ocr_image.jpg")
            image.save(temp_path)
            image_path = temp_path
        elif isinstance(image, np.ndarray):
            # Save numpy array to a temporary file
            temp_path = os.path.join(os.path.dirname(__file__), "temp_ocr_image.jpg")
            cv2.imwrite(temp_path, image)
            image_path = temp_path
        else:
            raise ValidationError("Unsupported image format")
        
        # Initialize the OCR reader (this loads the detection/recognition models)
        reader = easyocr.Reader(['en'])
        
        # Extract text
        results = reader.readtext(image_path)
        
        # Combine all detected text
        extracted_text = "\n".join(result[1] for result in results)
        
        # Log performance
        elapsed_time = (time.time() - start_time) * 1000  # Convert to ms
        log_performance("extract_text_from_image", elapsed_time)
        
        logger.debug(f"Text extracted successfully in {elapsed_time:.2f}ms")
        return extracted_text
    except Exception as e:
        logger.error(f"Error extracting text from image: {str(e)}")
        raise AIModelError("Error extracting text from image", {"original_error": str(e)}) from e
    finally:
        # Remove the temporary file even if extraction failed
        if temp_path is not None and os.path.exists(temp_path):
            os.remove(temp_path)

@handle_ai_model_exceptions
def generate_qr_code(data: str, box_size: int = 10, border: int = 4) -> Image.Image:
    """
    Generate a QR code from text data
    
    Args:
        data: Text data to encode in the QR code
        box_size: Size of each box in the QR code
        border: Border size of the QR code
        
    Returns:
        PIL Image containing the QR code
        
    Raises:
        AIModelError: If there's an error generating the QR code
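
    Example (illustrative):
        >>> img = generate_qr_code("https://example.com", box_size=8)  # doctest: +SKIP
        >>> img.save("qr.png")  # doctest: +SKIP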
    """
    logger.debug(f"Generating QR code for data: {data[:20]}...")
    start_time = time.time()
    
    try:
        # Create QR code instance
        qr = qrcode.QRCode(
            version=1,
            error_correction=qrcode.constants.ERROR_CORRECT_L,
            box_size=box_size,
            border=border,
        )
        
        # Add data to the QR code
        qr.add_data(data)
        qr.make(fit=True)
        
        # Create an image from the QR Code instance
        img = qr.make_image(fill_color="black", back_color="white")
        
        # Log performance
        elapsed_time = (time.time() - start_time) * 1000  # Convert to ms
        log_performance("generate_qr_code", elapsed_time)
        
        logger.debug(f"QR code generated successfully in {elapsed_time:.2f}ms")
        return img
    except Exception as e:
        logger.error(f"Error generating QR code: {str(e)}")
        raise AIModelError(f"Error generating QR code", {"original_error": str(e)}) from e

@handle_ai_model_exceptions
def scan_document(image) -> Dict[str, Any]:
    """
    Scan a document from an image, extract text and detect document boundaries
    
    Args:
        image: PIL Image or path to image file
        
    Returns:
        Dictionary with extracted text and processed image
        
    Raises:
        AIModelError: If there's an error scanning the document
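
    Example (illustrative; "receipt.jpg" is a placeholder path):
        >>> result = scan_document("receipt.jpg")  # doctest: +SKIP
        >>> result["document_detected"], result["text"]  # doctest: +SKIP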
    """
    logger.debug(f"Scanning document from image")
    start_time = time.time()
    
    try:
        # Convert to OpenCV format if needed
        if isinstance(image, (str, Path)):
            img = cv2.imread(str(image))
        elif isinstance(image, Image.Image):
            img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        elif isinstance(image, np.ndarray):
            img = image
        else:
            raise ValidationError("Unsupported image format")
        
        # Convert to grayscale
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        
        # Apply Gaussian blur
        blur = cv2.GaussianBlur(gray, (5, 5), 0)
        
        # Apply edge detection
        edges = cv2.Canny(blur, 75, 200)
        
        # Find contours
        contours, _ = cv2.findContours(edges, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
        contours = sorted(contours, key=cv2.contourArea, reverse=True)
        
        # Initialize document contour
        doc_contour = None
        
        # Find the document contour (largest contour with 4 corners)
        for contour in contours:
            perimeter = cv2.arcLength(contour, True)
            approx = cv2.approxPolyDP(contour, 0.02 * perimeter, True)
            
            if len(approx) == 4:
                doc_contour = approx
                break
        
        # Process the document if contour found
        if doc_contour is not None:
            # Draw the contour on a copy of the original image
            img_with_contour = img.copy()
            cv2.drawContours(img_with_contour, [doc_contour], -1, (0, 255, 0), 2)
            
            # Convert back to PIL for consistency
            processed_img = Image.fromarray(cv2.cvtColor(img_with_contour, cv2.COLOR_BGR2RGB))
        else:
            # If no document contour found, use original image
            processed_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        
        # Extract text from the document
        extracted_text = extract_text_from_image(gray)
        
        # Log performance
        elapsed_time = (time.time() - start_time) * 1000  # Convert to ms
        log_performance("scan_document", elapsed_time)
        
        logger.debug(f"Document scanned successfully in {elapsed_time:.2f}ms")
        
        return {
            "text": extracted_text,
            "processed_image": processed_img,
            "document_detected": doc_contour is not None
        }
    except Exception as e:
        logger.error(f"Error scanning document: {str(e)}")
        raise AIModelError(f"Error scanning document", {"original_error": str(e)}) from e

@handle_ai_model_exceptions
def generate_mind_map(topics: List[str], connections: Optional[List[Tuple[int, int]]] = None) -> Image.Image:
    """
    Generate a mind map visualization from topics and their connections
    
    Args:
        topics: List of topic strings
        connections: List of tuples indicating connections between topics by index
        
    Returns:
        PIL Image containing the mind map
        
    Raises:
        AIModelError: If there's an error generating the mind map
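
    Example (illustrative; a radial map is built from the first topic when
    connections is None):
        >>> img = generate_mind_map(["AI", "Vision", "Speech"],
        ...                         connections=[(0, 1), (0, 2)])  # doctest: +SKIP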
    """
    logger.debug(f"Generating mind map with {len(topics)} topics")
    start_time = time.time()
    
    try:
        # Create a new figure
        plt.figure(figsize=(12, 8))
        
        # If no connections provided, create a radial structure from first topic
        if connections is None:
            connections = [(0, i) for i in range(1, len(topics))]
        
        # Create a graph using networkx
        import networkx as nx
        G = nx.Graph()
        
        # Add nodes (topics)
        for i, topic in enumerate(topics):
            G.add_node(i, label=topic)
        
        # Add edges (connections)
        for source, target in connections:
            G.add_edge(source, target)
        
        # Create positions for nodes
        pos = nx.spring_layout(G, seed=42)  # For reproducibility
        
        # Draw the graph
        nx.draw_networkx_nodes(G, pos, node_size=2000, node_color='skyblue', alpha=0.8)
        nx.draw_networkx_edges(G, pos, width=2, alpha=0.5, edge_color='gray')
        
        # Add labels
        labels = {i: data['label'] for i, data in G.nodes(data=True)}
        nx.draw_networkx_labels(G, pos, labels, font_size=10, font_weight='bold')
        
        # Remove axis
        plt.axis('off')
        
        # Save the figure to a buffer
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
        buf.seek(0)
        
        # Convert buffer to PIL Image; load() forces a full read so the
        # image remains valid after the buffer goes out of scope
        mind_map_img = Image.open(buf)
        mind_map_img.load()
        
        # Close the figure to free memory
        plt.close()
        
        # Log performance
        elapsed_time = (time.time() - start_time) * 1000  # Convert to ms
        log_performance("generate_mind_map", elapsed_time)
        
        logger.debug(f"Mind map generated successfully in {elapsed_time:.2f}ms")
        return mind_map_img
    except Exception as e:
        logger.error(f"Error generating mind map: {str(e)}")
        raise AIModelError(f"Error generating mind map", {"original_error": str(e)}) from e

@handle_ai_model_exceptions
def transcribe_speech(audio_file) -> str:
    """
    Transcribe speech from an audio file using Whisper model
    
    Args:
        audio_file: Path to an audio file (file-like objects are not currently supported)
        
    Returns:
        Transcribed text
        
    Raises:
        AIModelError: If there's an error transcribing the speech
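
    Example (illustrative; "speech.wav" is a placeholder path):
        >>> transcription = transcribe_speech("speech.wav")  # doctest: +SKIP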
    """
    task = "speech_to_text"
    model_name = AI_MODELS[task]["name"]
    
    logger.debug(f"Transcribing speech from audio file")
    start_time = time.time()
    
    # Load the model and processor
    model, processor = get_model(task)
    
    try:
        # Load audio file
        if isinstance(audio_file, (str, Path)):
            # Load audio file using librosa
            import librosa
            audio_array, sampling_rate = librosa.load(audio_file, sr=16000)
        else:
            raise ValidationError("Unsupported audio format")
        
        # Process the audio
        input_features = processor(audio_array, sampling_rate=16000, return_tensors="pt").input_features
        
        # Generate transcription
        with torch.no_grad():
            predicted_ids = model.generate(input_features)
        
        # Decode the transcription
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        
        # Log performance and usage
        elapsed_time = (time.time() - start_time) * 1000  # Convert to ms
        log_performance("transcribe_speech", elapsed_time)
        log_ai_model_usage(model_name, "speech_to_text", len(predicted_ids[0]))
        
        logger.debug(f"Speech transcribed successfully in {elapsed_time:.2f}ms")
        return transcription
    except Exception as e:
        logger.error(f"Error transcribing speech: {str(e)}")
        raise AIModelError(f"Error transcribing speech", {"original_error": str(e)}) from e

@handle_ai_model_exceptions
def text_to_speech(text: str, lang: str = 'en', slow: bool = False) -> bytes:
    """
    Convert text to speech using gTTS
    
    Args:
        text: Text to convert to speech
        lang: Language code (default: 'en')
        slow: Whether to speak slowly (default: False)
        
    Returns:
        Audio data as bytes
        
    Raises:
        AIModelError: If there's an error converting text to speech
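
    Example (illustrative; gTTS produces MP3 data, written here to a
    placeholder path):
        >>> audio = text_to_speech("Hello world", lang='en')  # doctest: +SKIP
        >>> with open("hello.mp3", "wb") as f:  # doctest: +SKIP
        ...     f.write(audio)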
    """
    logger.debug(f"Converting text to speech: {text[:50]}...")
    start_time = time.time()
    
    try:
        # Create a BytesIO object to store the audio file
        audio_io = io.BytesIO()
        
        # Create gTTS object
        tts = gTTS(text=text, lang=lang, slow=slow)
        
        # Save the audio to the BytesIO object
        tts.write_to_fp(audio_io)
        
        # Reset the pointer to the beginning of the BytesIO object
        audio_io.seek(0)
        
        # Get the audio data as bytes
        audio_data = audio_io.read()
        
        # Log performance
        elapsed_time = (time.time() - start_time) * 1000  # Convert to ms
        log_performance("text_to_speech", elapsed_time)
        
        logger.debug(f"Text converted to speech successfully in {elapsed_time:.2f}ms")
        return audio_data
    except Exception as e:
        logger.error(f"Error converting text to speech: {str(e)}")
        raise AIModelError(f"Error converting text to speech", {"original_error": str(e)}) from e

@handle_ai_model_exceptions
def detect_language(audio_file) -> str:
    """
    Detect the language of speech in an audio file
    
    The audio is first transcribed with Whisper, and the language of the
    transcribed text is then identified with langdetect.
    
    Args:
        audio_file: Path to an audio file
        
    Returns:
        Detected language code (e.g. 'en')
        
    Raises:
        AIModelError: If there's an error detecting the language
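
    Example (illustrative; "speech.wav" is a placeholder path):
        >>> language_code = detect_language("speech.wav")  # doctest: +SKIP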
    """
    logger.debug(f"Detecting language from audio file")
    start_time = time.time()
    
    try:
        # First transcribe the speech
        transcription = transcribe_speech(audio_file)
        
        # Use langdetect to identify the language
        from langdetect import detect
        language_code = detect(transcription)
        
        # Log performance
        elapsed_time = (time.time() - start_time) * 1000  # Convert to ms
        log_performance("detect_language", elapsed_time)
        
        logger.debug(f"Language detected successfully in {elapsed_time:.2f}ms: {language_code}")
        return language_code
    except Exception as e:
        logger.error(f"Error detecting language: {str(e)}")
        raise AIModelError(f"Error detecting language", {"original_error": str(e)}) from e

@handle_ai_model_exceptions
def tag_image(image) -> List[str]:
    """
    Generate tags for an image using image captioning and NLP
    
    Args:
        image: PIL Image or path to image file
        
    Returns:
        List of tags
        
    Raises:
        AIModelError: If there's an error generating tags
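
    Example (illustrative; "photo.jpg" is a placeholder path):
        >>> tags = tag_image("photo.jpg")  # doctest: +SKIP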
    """
    logger.debug(f"Generating tags for image")
    start_time = time.time()
    
    try:
        # First generate a caption for the image
        caption = analyze_image(image)
        
        # Use NLP to extract keywords as tags
        import nltk
        from nltk.corpus import stopwords
        from nltk.tokenize import word_tokenize
        
        # Download necessary NLTK data if not already present
        # (newer NLTK releases need 'punkt_tab' alongside 'punkt')
        for resource, path in [('punkt', 'tokenizers/punkt'),
                               ('punkt_tab', 'tokenizers/punkt_tab'),
                               ('stopwords', 'corpora/stopwords')]:
            try:
                nltk.data.find(path)
            except LookupError:
                nltk.download(resource)
        
        # Tokenize the caption
        tokens = word_tokenize(caption.lower())
        
        # Remove stopwords and non-alphabetic tokens
        stop_words = set(stopwords.words('english'))
        filtered_tokens = [word for word in tokens if word.isalpha() and word not in stop_words]
        
        # Get unique tags
        tags = list(set(filtered_tokens))
        
        # Log performance
        elapsed_time = (time.time() - start_time) * 1000  # Convert to ms
        log_performance("tag_image", elapsed_time)
        
        logger.debug(f"Image tags generated successfully in {elapsed_time:.2f}ms")
        return tags
    except Exception as e:
        logger.error(f"Error generating image tags: {str(e)}")
        raise AIModelError(f"Error generating image tags", {"original_error": str(e)}) from e

@handle_ai_model_exceptions
def create_diagram(diagram_type: str, data: Dict[str, Any]) -> Image.Image:
    """
    Create a diagram based on the specified type and data
    
    Args:
        diagram_type: Type of diagram ('flowchart', 'sequence', 'class', etc.)
        data: Data for the diagram
        
    Returns:
        PIL Image containing the diagram
        
    Raises:
        AIModelError: If there's an error creating the diagram
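
    Example (illustrative; shows the expected shape of `data` for a flowchart):
        >>> img = create_diagram("flowchart", {
        ...     "nodes": [{"id": "a", "label": "Start"}, {"id": "b", "label": "End"}],
        ...     "edges": [{"source": "a", "target": "b", "label": "next"}],
        ... })  # doctest: +SKIP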
    """
    logger.debug(f"Creating {diagram_type} diagram")
    start_time = time.time()
    
    try:
        # Create a new figure
        plt.figure(figsize=(12, 8))
        
        if diagram_type == 'flowchart':
            # Create a flowchart using networkx
            import networkx as nx
            G = nx.DiGraph()
            
            # Add nodes
            for node in data.get('nodes', []):
                G.add_node(node['id'], label=node.get('label', node['id']))
            
            # Add edges
            for edge in data.get('edges', []):
                G.add_edge(edge['source'], edge['target'], label=edge.get('label', ''))
            
            # Create positions for nodes
            pos = nx.spring_layout(G, seed=42)  # For reproducibility
            
            # Draw the graph
            nx.draw_networkx_nodes(G, pos, node_size=2000, node_color='lightblue', alpha=0.8)
            nx.draw_networkx_edges(G, pos, width=2, alpha=0.5, edge_color='gray', arrowsize=20)
            
            # Add labels
            labels = {node: attrs['label'] for node, attrs in G.nodes(data=True)}  # 'attrs' avoids shadowing the 'data' argument
            nx.draw_networkx_labels(G, pos, labels, font_size=10, font_weight='bold')
            
            # Add edge labels
            edge_labels = {(u, v): d['label'] for u, v, d in G.edges(data=True) if 'label' in d}
            nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=8)
            
        elif diagram_type == 'bar_chart':
            # Create a bar chart
            plt.bar(data.get('labels', []), data.get('values', []), color=data.get('colors', 'blue'))
            plt.xlabel(data.get('x_label', ''))
            plt.ylabel(data.get('y_label', ''))
            plt.title(data.get('title', 'Bar Chart'))
            
        elif diagram_type == 'pie_chart':
            # Create a pie chart
            plt.pie(data.get('values', []), labels=data.get('labels', []), autopct='%1.1f%%', 
                   shadow=data.get('shadow', False), startangle=data.get('start_angle', 90))
            plt.axis('equal')  # Equal aspect ratio ensures that pie is drawn as a circle
            plt.title(data.get('title', 'Pie Chart'))
            
        else:
            raise ValidationError(f"Unsupported diagram type: {diagram_type}")
        
        # Remove axis for flowcharts
        if diagram_type == 'flowchart':
            plt.axis('off')
        
        # Save the figure to a buffer
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
        buf.seek(0)
        
        # Convert buffer to PIL Image; load() forces a full read so the
        # image remains valid after the buffer goes out of scope
        diagram_img = Image.open(buf)
        diagram_img.load()
        
        # Close the figure to free memory
        plt.close()
        
        # Log performance
        elapsed_time = (time.time() - start_time) * 1000  # Convert to ms
        log_performance(f"create_{diagram_type}_diagram", elapsed_time)
        
        logger.debug(f"{diagram_type.capitalize()} diagram created successfully in {elapsed_time:.2f}ms")
        return diagram_img
    except Exception as e:
        logger.error(f"Error creating {diagram_type} diagram: {str(e)}")
        raise AIModelError(f"Error creating {diagram_type} diagram", {"original_error": str(e)}) from e