File size: 20,653 Bytes
9bc9c3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
import os
import json # Added for JSON parsing
from google import genai
from google.genai import types
from PIL import Image
from io import BytesIO
from langchain_core.tools import tool
from langfuse import Langfuse
from langfuse.decorators import observe, langfuse_context
GOOGLE_API_KEY = os.getenv("GEMINI_API_KEY")
GEMINI_MODEL_NAME = os.getenv("GEMINI_MODEL_NAME", "gemini-2.5-pro-preview-06-05")


def _int_from_env(var_name: str, default: int) -> int:
    """
    Read an integer from the environment variable *var_name*.

    Returns *default* when the variable is unset or blank. A value that cannot be
    parsed as an integer also yields *default* (after printing a warning), so a
    malformed setting does not make the module fail at import time.
    """
    raw_value = os.getenv(var_name)
    if raw_value is None or not raw_value.strip():
        return default
    try:
        return int(raw_value)
    except ValueError:
        print(f"Attenzione: valore non valido per {var_name} ({raw_value!r}); uso il default {default}.")
        return default


GEMINI_THINKING_BUDGET = _int_from_env("GEMINI_THINKING_BUDGET", 128)

if not GOOGLE_API_KEY:
    # The module stays importable without a key; generate_drawio_from_image_and_objects
    # re-checks GOOGLE_API_KEY and returns an error message instead of calling the API.
    print("Attenzione: GEMINI_API_KEY non trovato nelle variabili d'ambiente.")

# Model used for both the generation and the verification calls. It is configurable via
# GEMINI_MODEL_NAME, matching the model used in object_detection_tools.py for consistency.
MODEL_NAME = GEMINI_MODEL_NAME

try:
    client = genai.Client(api_key=GOOGLE_API_KEY)
except Exception as e:
    # Keep the module importable; callers check `client` before using it.
    print(f"Errore durante l'inizializzazione del client GenAI: {e}")
    client = None

# Safety settings shared by every Gemini request in this module
# (threshold BLOCK_ONLY_HIGH for each listed harm category).
SAFETY_SETTINGS = [
    types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_ONLY_HIGH"),
    types.SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="BLOCK_ONLY_HIGH"),
    types.SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="BLOCK_ONLY_HIGH"),
    types.SafetySetting(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="BLOCK_ONLY_HIGH"),
]

# Langfuse tracing configuration, read from the environment.
LANGFUSE_PUBLIC_KEY = os.getenv("LANGFUSE_PUBLIC_KEY")
LANGFUSE_SECRET_KEY = os.getenv("LANGFUSE_SECRET_KEY")
LANGFUSE_HOST = os.getenv("LANGFUSE_HOST", "http://localhost:3000")  # Default to local if not set

if LANGFUSE_PUBLIC_KEY and LANGFUSE_SECRET_KEY:
    try:
        # Tracing in this module goes through @observe / langfuse_context, which use the
        # decorator's own client. A standalone Langfuse(...) instance that is never stored
        # is not used by them, so configure the decorator context directly; this makes the
        # credentials and the localhost host default above actually apply.
        langfuse_context.configure(
            public_key=LANGFUSE_PUBLIC_KEY,
            secret_key=LANGFUSE_SECRET_KEY,
            host=LANGFUSE_HOST,
        )
        print(f"Langfuse tracing enabled for {__name__}")
    except Exception as e:
        print(f"Failed to initialize Langfuse for {__name__}: {e}. Tracing will be disabled.")

import base64
import mimetypes
import os
import re
from xml.etree import ElementTree as ET
from typing import Optional, Dict, Any, List as TypingList

def convert_image_to_base64(image_path: str) -> Optional[str]:
    """
    Encode an image file as a Draw.io-style data URI.

    Draw.io style strings use ';' as a separator, so the URI is written as
    ``data:<mime>,<base64>``, without the usual ``;base64`` marker. The MIME type
    is guessed from the file name; anything that is not an ``image/*`` type
    falls back to ``image/png``.

    Args:
        image_path: Path of the image file to encode.

    Returns:
        The data URI string, or None if the file does not exist or cannot be read.
    """
    if os.path.exists(image_path):
        try:
            guessed_type, _encoding = mimetypes.guess_type(image_path)
            mime = guessed_type if guessed_type and guessed_type.startswith('image/') else 'image/png'

            with open(image_path, 'rb') as handle:
                encoded = base64.b64encode(handle.read()).decode('utf-8')

            return f"data:{mime},{encoded}"
        except Exception as exc:
            print(f"Error converting {image_path} to base64: {exc}")
            return None

    print(f"Warning: Image {image_path} not found")
    return None

def replace_image_references_in_drawio_xml(xml_content: str, base_folder: str = "output_llm") -> str:
    """
    Replace image file references in Draw.io XML with embedded base64 data URIs (regex-based).

    Recognized forms are ``image=cat.png``, ``image='cat.png'``, ``image="cat.png"`` and
    ``image=file://./cat.png``. Recognized extensions are png, jpg, jpeg, gif, bmp and svg,
    case-insensitive. Only the basename of the reference is used, and it is looked up in
    *base_folder*. Remote URLs (any other ``scheme://``) are left untouched, as are
    references whose file cannot be converted.

    Args:
        xml_content: Draw.io XML content as a string.
        base_folder: Folder in which the referenced images are looked up.

    Returns:
        The XML with images embedded as base64, or the original XML if an unexpected
        error occurs.
    """
    # Group 1: optional opening quote (the same quote must close the value).
    # Group 2: the reference itself.
    # The negative lookahead ensures the extension really ends the reference,
    # so "cat.pngx" is not treated as "cat.png".
    pattern = re.compile(
        r'image=([\'"]?)([^\'";,\s]+\.(?:png|jpg|jpeg|gif|bmp|svg))(?![^\'";,\s])\1',
        re.IGNORECASE,
    )
    # Conversion result per basename, so each file is read and encoded only once.
    converted: Dict[str, Optional[str]] = {}

    def _embed(match: "re.Match[str]") -> str:
        quote_char, reference = match.group(1), match.group(2)

        if reference.lower().startswith('file://'):
            reference = reference[len('file://'):]
        elif '://' in reference:
            return match.group(0)  # remote URL, not a local file: leave as is

        # Drop any directory part (including "./"); images are looked up in base_folder.
        filename = os.path.basename(reference)
        if filename not in converted:
            base64_data = convert_image_to_base64(os.path.join(base_folder, filename))
            converted[filename] = base64_data
            if base64_data:
                print(f"Replaced image reference: {filename} -> base64 ({len(base64_data)} chars)")
            else:
                print(f"Failed to convert image: {filename}")

        base64_data = converted[filename]
        if not base64_data:
            return match.group(0)  # keep the original reference if conversion failed
        return f'image={quote_char}{base64_data}{quote_char}'

    try:
        # re.sub does not rescan inserted text, so embedded data cannot be matched again.
        return pattern.sub(_embed, xml_content)
    except Exception as e:
        print(f"Error processing XML: {e}")
        return xml_content  # return the original on error

def replace_image_references_xml_parser(xml_content: str, base_folder: str = "output_llm") -> str:
    """
    Replace image file references in Draw.io XML with base64 data URIs using an XML parser.

    Only the ``image=`` entries of the ``style`` attribute of ``mxCell`` elements are
    rewritten, which is more precise than the regex-based approach.
    - Surrounding quotes and a ``file://`` prefix are removed.
    - The basename of the reference is looked up in *base_folder*.
    - Values that are already data URIs (``data:...``) or remote URLs are left untouched.

    Args:
        xml_content: Draw.io XML content as a string.
        base_folder: Folder in which the referenced images are looked up.

    Returns:
        The re-serialized XML. If the XML cannot be parsed, the result of
        replace_image_references_in_drawio_xml is returned instead. On any other error,
        the original XML is returned unchanged.
    """
    try:
        root = ET.fromstring(xml_content)
        # Conversion result per basename, so repeated references are encoded only once.
        converted: Dict[str, Optional[str]] = {}

        for cell in root.iter('mxCell'):
            style = cell.get('style', '')
            if 'image=' not in style:
                continue

            new_style_parts = []
            for part in style.split(';'):
                if not part.startswith('image='):
                    new_style_parts.append(part)
                    continue

                image_ref = part[len('image='):]
                # Remove a matching pair of surrounding quotes.
                if len(image_ref) >= 2 and image_ref[0] in '"\'' and image_ref[-1] == image_ref[0]:
                    image_ref = image_ref[1:-1]

                lowered = image_ref.lower()
                if lowered.startswith('file://'):
                    # Strip only the scheme; os.path.basename below discards "./" and directories
                    # without mangling names that start with a dot (e.g. ".icon.png").
                    image_ref = image_ref[len('file://'):]
                elif lowered.startswith('data:') or '://' in image_ref:
                    # Already embedded or remote: nothing to convert.
                    new_style_parts.append(part)
                    continue

                filename = os.path.basename(image_ref)
                if not filename:
                    new_style_parts.append(part)
                    continue

                if filename not in converted:
                    converted[filename] = convert_image_to_base64(os.path.join(base_folder, filename))
                base64_data = converted[filename]

                if base64_data:
                    new_style_parts.append(f'image={base64_data}')
                    print(f"XML Parser: Replaced {filename} with base64 data")
                else:
                    new_style_parts.append(part)  # keep the original if conversion fails

            cell.set('style', ';'.join(new_style_parts))

        return ET.tostring(root, encoding='unicode')

    except ET.ParseError as e:
        print(f"XML parsing error: {e}")
        # Fall back to the regex-based method for XML that does not parse.
        return replace_image_references_in_drawio_xml(xml_content, base_folder)
    except Exception as e:
        print(f"Error in XML parser method: {e}")
        return xml_content


def _strip_markdown_fence(text: Optional[str]) -> str:
    """Return *text* stripped of a surrounding Markdown code fence (```xml ... ``` or ``` ... ```)."""
    cleaned = (text or "").strip()
    # Opening fence with an optional language tag ("```xml", "```XML", plain "```").
    cleaned = re.sub(r"^```[\w-]*", "", cleaned)
    # Closing fence.
    cleaned = re.sub(r"```$", "", cleaned)
    return cleaned.strip()


def _usage_details(response: Any) -> Dict[str, int]:
    """Extract token counts from a GenAI response, treating missing usage metadata as zero."""
    usage = getattr(response, "usage_metadata", None)
    return {
        "input": getattr(usage, "prompt_token_count", None) or 0,
        "output": getattr(usage, "candidates_token_count", None) or 0,
        "total": getattr(usage, "total_token_count", None) or 0,
    }


@tool("generate_drawio_from_image_and_objects_tool", parse_docstring=True)
@observe(as_type="generation")
def generate_drawio_from_image_and_objects(original_image_path: str, object_names: list[str]) -> str:
    """
    Generates a Draw.io XML diagram from an original image and a list of detected object names.

    The function first asks a generative model to create a Draw.io XML representation
    of the scene in the original image, referencing cropped images of the specified
    objects (expected to be in the 'output_llm' folder) by filename. A second model call
    verifies and corrects that XML against the original image. Finally, all local image
    file references are replaced with their base64 encoded data, making the diagram
    self-contained, and the result is saved as 'drawio_output.drawio' in 'output_llm'.

    Args:
        original_image_path (str): The file path to the original image to be diagrammed.
        object_names (list[str]): A list of object names (e.g., ['cat.png', 'dog.png']) that have been previously detected and saved as image files in the 'output_llm' folder. These will be embedded into the diagram.

    Returns:
        True if the Draw.io XML was successfully generated and saved, otherwise a string describing the error.
    """
    if not GOOGLE_API_KEY or not client:
        return "Errore: GEMINI_API_KEY non configurato o client non inizializzato."

    try:
        with open(original_image_path, "rb") as f:
            img_bytes = f.read()
        original_image = Image.open(BytesIO(img_bytes))
        # Downscale in place (aspect ratio preserved) to keep the request payload bounded.
        original_image.thumbnail((1024, 1024), Image.Resampling.LANCZOS)

        object_image_folder = "output_llm"

        prompt_parts = [
            "Generate a Draw.io XML diagram for the provided original image.",
            "The diagram should represent the overall scene, focusing on spatial relationships and composition."
        ]

        if object_names:
            object_filenames_str = ", ".join([f"'{name}'" for name in object_names])
            prompt_parts.extend([
                f"Incorporate the following object images as assets: {object_filenames_str}.",
                f"These images are in the '{object_image_folder}' directory.",
                "Use simple filename references in the image attribute, like: image=cat.png",
                "Do NOT use base64 encoding - just use the filename directly.",
                "The image paths will be processed later to embed the actual image data."
            ])

        prompt_parts.extend([
            "Position and size elements based on their approximate location in the original image.",
            "Create complete Draw.io XML structure with proper mxGraphModel, root, and mxCell elements.",
            "Ensure all mxCell elements have unique id attributes."
        ])

        user_prompt = " ".join(prompt_parts)

        # System instructions for the first call. Images are referenced by plain filename;
        # they are embedded as base64 during post-processing, not by the model.
        simple_ref_instructions = """
You are an expert Draw.io diagram generator.
Create Draw.io XML using simple filename references for images.

Structure:
<mxfile compressed="false" host="GeminiAgent" version="1.0" type="device">
  <diagram id="diagram-1" name="Page-1">
    <mxGraphModel dx="1000" dy="600" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="850" pageHeight="1100" math="0" shadow="0">
      <root>
        <mxCell id="0" />
        <mxCell id="1" parent="0" />
        
        <mxCell id="obj_1" value="object_name" style="shape=image;html=1;imageAspect=1;aspect=fixed;image=filename.png" vertex="1" parent="1">
          <mxGeometry x="100" y="100" width="80" height="60" as="geometry" />
        </mxCell>
      </root>
    </mxGraphModel>
  </diagram>
</mxfile>

Use simple filename references like 'image=cat.png' - do NOT embed base64 data.
Position elements to match the original image layout.
"""

        # System instructions for the second call, which verifies and corrects the XML.
        verify_correct_xml_instructions = """
You are an expert Draw.io diagram verifier and corrector.
You will be given an original image and a Draw.io XML generated for that image.
Your task is to:
1. Verify if the XML accurately represents the objects, their positions, sizes, and connections as shown in the original image.
2. Correct any inaccuracies in the XML. This includes adjusting positions, sizes, shapes, or connections.
3. Ensure all image references within the XML use simple filename references (e.g., image=filename.png). Do NOT use base64 encoding.
4. Ensure the XML structure is valid Draw.io format.
Return ONLY the corrected Draw.io XML. Do not include any other text, explanations, or markdown formatting around the XML.
If the XML is already perfect, return it as is.
Focus on accuracy of representation and valid Draw.io XML output.
"""

        # FIRST LLM CALL: generate the draft XML.
        response = client.models.generate_content(
            contents=[user_prompt, original_image],
            model=MODEL_NAME,
            config=types.GenerateContentConfig(
                system_instruction=simple_ref_instructions,
                temperature=0,
                safety_settings=SAFETY_SETTINGS,
                thinking_config=types.ThinkingConfig(thinking_budget=GEMINI_THINKING_BUDGET)
            )
        )

        # Missing usage metadata must not make tracing break the tool.
        generation_usage = _usage_details(response)
        langfuse_context.update_current_observation(
            input=[user_prompt, original_image],
            model=MODEL_NAME,
            usage_details=generation_usage,
        )

        # response.text may be None (e.g. blocked or empty response); report it explicitly.
        draft_xml = _strip_markdown_fence(response.text)
        if not draft_xml:
            return "Errore: il modello non ha restituito alcun XML Draw.io."

        # SECOND LLM CALL: verify and correct the generated XML.
        print("Second LLM call: Verifying and correcting generated XML...")
        verification_prompt_parts = [
            original_image,  # the original image
            f"Generated Draw.io XML to verify and correct:\n{draft_xml}",  # the generated XML
            "Please verify this XML against the original image. Correct any errors in object placement, connections, or representation. Ensure all image references are simple filenames like 'image=filename.png'. Return only the corrected Draw.io XML."
        ]

        correction_response = client.models.generate_content(
            contents=verification_prompt_parts,
            model=MODEL_NAME,
            config=types.GenerateContentConfig(
                system_instruction=verify_correct_xml_instructions,
                temperature=0,  # low temperature for more deterministic corrections
                safety_settings=SAFETY_SETTINGS,
                thinking_config=types.ThinkingConfig(thinking_budget=GEMINI_THINKING_BUDGET)
            )
        )

        # A later update_current_observation call replaces the usage set by the earlier one,
        # so report the combined usage of both calls, with a per-step breakdown in metadata.
        correction_usage = _usage_details(correction_response)
        langfuse_context.update_current_observation(
            input=verification_prompt_parts,
            model=MODEL_NAME,
            metadata={
                "step": "xml_correction",
                "generation_usage": generation_usage,
                "correction_usage": correction_usage,
            },
            usage_details={key: generation_usage[key] + correction_usage[key] for key in generation_usage},
        )

        xml_output = _strip_markdown_fence(correction_response.text)
        if not xml_output:
            # The verifier returned nothing usable; keep the first draft rather than failing.
            print("Verification step returned no XML; falling back to the first draft.")
            xml_output = draft_xml
        print("XML verification/correction complete.")

        # POST-PROCESSING: embed the referenced images as base64.
        print("Post-processing: Converting image references to base64...")
        final_xml = replace_image_references_xml_parser(xml_output, object_image_folder)
        save_message = save_drawio_xml(final_xml, "drawio_output", output_directory="output_llm")
        print(save_message)

        # save_drawio_xml reports failures as a message instead of raising; surface them.
        if save_message.startswith("Errore"):
            return save_message

        return True

    except FileNotFoundError:
        return f"Errore: File immagine originale non trovato a {original_image_path}."
    except Exception as e:
        print(f"Errore dettagliato in generate_drawio_from_image_and_objects: {e}")
        return f"Errore durante la generazione dell'XML Draw.io: {str(e)}"


# Standalone helper to post-process Draw.io XML files that already exist on disk.
def post_process_drawio_xml_file(xml_file_path: str, base_folder: str = "output_llm", output_path: str = None) -> str:
    """
    Embed the images referenced by an existing Draw.io XML file as base64 data.

    Args:
        xml_file_path: Path of the Draw.io XML file to read.
        base_folder: Folder in which the referenced images are looked up.
        output_path: Where to write the result; when None the input file is overwritten.

    Returns:
        The path the processed XML was written to. If reading, processing or writing
        fails, an error is printed and xml_file_path is returned.
    """
    try:
        with open(xml_file_path, 'r', encoding='utf-8') as source:
            original_xml = source.read()

        embedded_xml = replace_image_references_xml_parser(original_xml, base_folder)

        destination = xml_file_path if output_path is None else output_path
        with open(destination, 'w', encoding='utf-8') as sink:
            sink.write(embedded_xml)

        print(f"Processed XML saved to: {destination}")
        return destination
    except Exception as e:
        print(f"Error processing XML file: {e}")
        return xml_file_path
    

def save_drawio_xml(xml_content: str, filename_prefix: str, output_directory: str = "output_llm") -> str:
    """
    Save a Draw.io XML string to a ``.drawio`` file.

    Args:
        xml_content (str): The Draw.io diagram XML.
        filename_prefix (str): File name prefix; the file is saved as '{filename_prefix}.drawio'.
            The extension is not added again if the prefix already ends with '.drawio'.
        output_directory (str): Directory in which to save the file, created if missing.
            An empty string means the current working directory. Defaults to 'output_llm'.

    Returns:
        str: A success message containing the absolute path of the saved file, or an
        error message starting with "Errore durante il salvataggio" if saving failed.
    """
    try:
        if output_directory:
            # exist_ok avoids the check-then-create race of a separate exists() test.
            os.makedirs(output_directory, exist_ok=True)

        filename = filename_prefix if filename_prefix.endswith(".drawio") else f"{filename_prefix}.drawio"
        file_path = os.path.join(output_directory, filename)
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(xml_content)
        return f"File Draw.io salvato con successo in: {os.path.abspath(file_path)}"
    except Exception as e:
        return f"Errore durante il salvataggio del file Draw.io: {str(e)}"