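"""Hardware-aware diffusers code generator driven by the Gemini API.

AutoDiffusersGenerator detects local hardware (or accepts manual specs),
optionally lets Gemini fetch live HuggingFace documentation via tool calls,
and returns optimized Python code for running a given diffusion model.
"""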
import os
import logging
from dotenv import load_dotenv
import google.generativeai as genai
from hardware_detector import HardwareDetector
from optimization_knowledge import get_optimization_guide
from typing import Dict, Optional
import json

# Optional imports for tool calling
try:
    import requests
    from urllib.parse import urlparse
    from bs4 import BeautifulSoup
    TOOLS_AVAILABLE = True
except ImportError:
    TOOLS_AVAILABLE = False
    requests = None
    urlparse = None
    BeautifulSoup = None

load_dotenv()

# Configure logging
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('auto_diffusers.log'),
        logging.StreamHandler()
    ]
)
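# DEBUG-level logging mirrors full prompts, tool calls, and responses to both
# auto_diffusers.log and stderr; raise the level for quieter runs.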
logger = logging.getLogger(__name__)


class AutoDiffusersGenerator:
    def __init__(self, api_key: str):
        logger.info("Initializing AutoDiffusersGenerator")
        logger.debug(f"API key length: {len(api_key) if api_key else 'None'}")
        
        try:
            genai.configure(api_key=api_key)
            
            # Define tools for Gemini to use (if available)
            if TOOLS_AVAILABLE:
                self.tools = self._create_tools()
                # Initialize model with tools
                self.model = genai.GenerativeModel(
                    'gemini-2.5-flash-preview-05-20',
                    tools=self.tools
                )
                logger.info("Successfully configured Gemini AI model with tools")
            else:
                self.tools = None
                # Initialize model without tools
                self.model = genai.GenerativeModel('gemini-2.5-flash-preview-05-20')
                logger.warning("Tool calling dependencies not available, running without tools")
        except Exception as e:
            logger.error(f"Failed to configure Gemini AI: {e}")
            raise
        
        try:
            self.hardware_detector = HardwareDetector()
            logger.info("Hardware detector initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize hardware detector: {e}")
            raise
    
    def _create_tools(self):
        """Create function tools for Gemini to use."""
        logger.debug("Creating tools for Gemini")
        
        if not TOOLS_AVAILABLE:
            logger.warning("Tools dependencies not available, returning empty tools")
            return []
        
        def fetch_huggingface_docs(url: str) -> str:
            """Fetch documentation from HuggingFace URLs."""
            logger.info("🌐 TOOL CALL: fetch_huggingface_docs")
            logger.info(f"πŸ“‹ Requested URL: {url}")
            
            try:
                # Validate URL is from HuggingFace
                parsed = urlparse(url)
                logger.debug(f"URL validation - Domain: {parsed.netloc}, Path: {parsed.path}")
                
                if parsed.netloc not in ('huggingface.co', 'hf.co') and not parsed.netloc.endswith(('.huggingface.co', '.hf.co')):
                    error_msg = "Error: URL must be from huggingface.co domain"
                    logger.warning(f"❌ URL validation failed: {error_msg}")
                    return error_msg
                
                logger.info(f"βœ… URL validation passed for domain: {parsed.netloc}")
                
                headers = {
                    'User-Agent': 'Auto-Diffusers-Config/1.0 (Educational Tool)'
                }
                
                logger.info(f"πŸ”„ Fetching content from: {url}")
                response = requests.get(url, headers=headers, timeout=10)
                response.raise_for_status()
                logger.info(f"βœ… HTTP {response.status_code} - Successfully fetched {len(response.text)} characters")
                
                # Parse HTML content
                logger.info("πŸ” Parsing HTML content...")
                soup = BeautifulSoup(response.text, 'html.parser')
                
                # Extract main content (remove navigation, footers, etc.)
                content = ""
                element_count = 0
                for element in soup.find_all(['p', 'pre', 'code', 'h1', 'h2', 'h3', 'h4', 'li']):
                    text = element.get_text().strip()
                    if text:
                        content += text + "\n"
                        element_count += 1
                
                logger.info(f"πŸ“„ Extracted content from {element_count} HTML elements")
                
                # Limit content length
                original_length = len(content)
                if len(content) > 5000:
                    content = content[:5000] + "...[truncated]"
                    logger.info(f"βœ‚οΈ Content truncated from {original_length} to 5000 characters")
                
                logger.info(f"πŸ“Š Final processed content: {len(content)} characters")
                
                # Log a preview of the fetched content
                preview = content[:200].replace('\n', ' ')
                logger.info(f"πŸ“‹ Content preview: {preview}...")
                
                # Log content sections found
                sections = []
                for header in soup.find_all(['h1', 'h2', 'h3']):
                    header_text = header.get_text().strip()
                    if header_text:
                        sections.append(header_text)
                
                if sections:
                    logger.info(f"πŸ“‘ Found sections: {', '.join(sections[:5])}{'...' if len(sections) > 5 else ''}")
                
                logger.info("βœ… Content extraction completed successfully")
                return content
                
            except Exception as e:
                logger.error(f"❌ Error fetching {url}: {type(e).__name__}: {e}")
                return f"Error fetching documentation: {str(e)}"
        
        def fetch_model_info(model_id: str) -> str:
            """Fetch model information from HuggingFace API."""
            logger.info("πŸ€– TOOL CALL: fetch_model_info")
            logger.info(f"πŸ“‹ Requested model: {model_id}")
            try:
                # Use HuggingFace API to get model info
                api_url = f"https://huggingface.co/api/models/{model_id}"
                logger.info(f"πŸ”„ Fetching model info from: {api_url}")
                headers = {
                    'User-Agent': 'Auto-Diffusers-Config/1.0 (Educational Tool)'
                }
                
                response = requests.get(api_url, headers=headers, timeout=10)
                response.raise_for_status()
                logger.info(f"βœ… HTTP {response.status_code} - Model API response received")
                
                model_data = response.json()
                logger.info(f"πŸ“Š Raw API response contains {len(model_data)} fields")
                
                # Extract relevant information
                info = {
                    'model_id': model_data.get('id', model_id),
                    'pipeline_tag': model_data.get('pipeline_tag', 'unknown'),
                    'tags': model_data.get('tags', []),
                    'library_name': model_data.get('library_name', 'unknown'),
                    'downloads': model_data.get('downloads', 0),
                    'likes': model_data.get('likes', 0)
                }
                
                logger.info(f"πŸ“‹ Extracted model info:")
                logger.info(f"   - Pipeline: {info['pipeline_tag']}")
                logger.info(f"   - Library: {info['library_name']}")
                logger.info(f"   - Downloads: {info['downloads']:,}")
                logger.info(f"   - Likes: {info['likes']:,}")
                logger.info(f"   - Tags: {len(info['tags'])} tags")
                
                result = json.dumps(info, indent=2)
                logger.info(f"βœ… Model info formatting completed ({len(result)} characters)")
                return result
                
            except Exception as e:
                logger.error(f"Error fetching model info for {model_id}: {e}")
                return f"Error fetching model information: {str(e)}"
        
        def search_optimization_guides(query: str) -> str:
            """Search for optimization guides and best practices."""
            logger.info("πŸ” TOOL CALL: search_optimization_guides")
            logger.info(f"πŸ“‹ Search query: '{query}'")
            try:
                # Search common optimization documentation URLs
                docs_urls = [
                    "https://huggingface.co/docs/diffusers/optimization/fp16",
                    "https://huggingface.co/docs/diffusers/optimization/memory",
                    "https://huggingface.co/docs/diffusers/optimization/torch2",
                    "https://huggingface.co/docs/diffusers/optimization/mps",
                    "https://huggingface.co/docs/diffusers/optimization/xformers"
                ]
                
                logger.info(f"πŸ”Ž Searching through {len(docs_urls)} optimization guide URLs...")
                
                results = []
                matched_urls = []
                for url in docs_urls:
                    if any(keyword in url for keyword in query.lower().split()):
                        logger.info(f"βœ… URL matched query: {url}")
                        matched_urls.append(url)
                        content = fetch_huggingface_docs(url)
                        if not content.startswith("Error"):
                            results.append(f"From {url}:\n{content[:1000]}...\n")
                            logger.info(f"πŸ“„ Successfully processed content from {url}")
                        else:
                            logger.warning(f"❌ Failed to fetch content from {url}")
                    else:
                        logger.debug(f"⏭️ URL skipped (no match): {url}")
                
                logger.info(f"πŸ“Š Search completed: {len(matched_urls)} URLs matched, {len(results)} successful fetches")
                
                if results:
                    final_result = "\n".join(results)
                    logger.info(f"βœ… Returning combined content ({len(final_result)} characters)")
                    return final_result
                else:
                    logger.warning("❌ No specific optimization guides found for the query")
                    return "No specific optimization guides found for the query"
                
            except Exception as e:
                logger.error(f"Error searching optimization guides: {e}")
                return f"Error searching guides: {str(e)}"
        
        # Define tools schema for Gemini (simplified for now)
        tools = [
            {
                "function_declarations": [
                    {
                        "name": "fetch_huggingface_docs",
                        "description": "Fetch current documentation from HuggingFace URLs for diffusers library, models, or optimization guides",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "url": {
                                    "type": "string",
                                    "description": "The HuggingFace documentation URL to fetch"
                                }
                            },
                            "required": ["url"]
                        }
                    },
                    {
                        "name": "fetch_model_info",
                        "description": "Fetch current model information and metadata from HuggingFace API",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "model_id": {
                                    "type": "string",
                                    "description": "The HuggingFace model ID (e.g., 'black-forest-labs/FLUX.1-schnell')"
                                }
                            },
                            "required": ["model_id"]
                        }
                    },
                    {
                        "name": "search_optimization_guides",
                        "description": "Search for optimization guides and best practices for diffusers models",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "query": {
                                    "type": "string",
                                    "description": "Search query for optimization topics (e.g., 'memory', 'fp16', 'torch compile')"
                                }
                            },
                            "required": ["query"]
                        }
                    }
                ]
            }
        ]
        
        # Store function implementations for execution
        self.tool_functions = {
            'fetch_huggingface_docs': fetch_huggingface_docs,
            'fetch_model_info': fetch_model_info,
            'search_optimization_guides': search_optimization_guides
        }
        
        logger.info(f"Created {len(tools[0]['function_declarations'])} tools for Gemini")
        return tools
        
    def generate_optimized_code(self, 
                              model_name: str, 
                              prompt_text: str,
                              image_size: tuple = (768, 1360),
                              num_inference_steps: int = 4,
                              use_manual_specs: bool = False,
                              manual_specs: Optional[Dict] = None,
                              memory_analysis: Optional[Dict] = None) -> str:
        """Generate optimized diffusers code based on hardware specs and memory analysis."""
        
        logger.info(f"Starting code generation for model: {model_name}")
        logger.debug(f"Parameters: prompt='{prompt_text[:50]}...', size={image_size}, steps={num_inference_steps}")
        logger.debug(f"Manual specs: {use_manual_specs}, Memory analysis provided: {memory_analysis is not None}")
        
        # Get hardware specifications
        if use_manual_specs and manual_specs:
            logger.info("Using manual hardware specifications")
            hardware_specs = manual_specs
            logger.debug(f"Manual specs: {hardware_specs}")
            
            # Determine optimization profile based on manual specs
            if hardware_specs.get('gpu_info'):
                vram_gb = hardware_specs['gpu_info'][0]['memory_mb'] / 1024
                logger.debug(f"GPU detected with {vram_gb:.1f} GB VRAM")
                
                if vram_gb >= 16:
                    optimization_profile = 'performance'
                elif vram_gb >= 8:
                    optimization_profile = 'balanced'
                else:
                    optimization_profile = 'memory_efficient'
            else:
                optimization_profile = 'cpu_only'
                logger.info("No GPU detected, using CPU-only profile")
                
            logger.info(f"Selected optimization profile: {optimization_profile}")
        else:
            logger.info("Using automatic hardware detection")
            hardware_specs = self.hardware_detector.specs
            optimization_profile = self.hardware_detector.get_optimization_profile()
            logger.debug(f"Detected specs: {hardware_specs}")
            logger.info(f"Auto-detected optimization profile: {optimization_profile}")
        
        # Create the prompt for Gemini API
        logger.debug("Creating generation prompt for Gemini API")
        system_prompt = self._create_generation_prompt(
            model_name, prompt_text, image_size, num_inference_steps, 
            hardware_specs, optimization_profile, memory_analysis
        )
        logger.debug(f"Prompt length: {len(system_prompt)} characters")
        
        # Log the actual prompt being sent to Gemini API
        logger.info("=" * 80)
        logger.info("PROMPT SENT TO GEMINI API:")
        logger.info("=" * 80)
        logger.info(system_prompt)
        logger.info("=" * 80)
        
        try:
            logger.info("Sending request to Gemini API")
            response = self.model.generate_content(system_prompt)
            
            # Handle tool calling if present and tools are available
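            # Only the first function_call part is executed; its result is
            # folded into a follow-up prompt, and that second response is
            # returned without another tool round-trip.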
            if self.tools and response.candidates[0].content.parts:
                for part in response.candidates[0].content.parts:
                    if hasattr(part, 'function_call') and part.function_call:
                        function_name = part.function_call.name
                        function_args = dict(part.function_call.args)
                        
                        logger.info("πŸ› οΈ " + "=" * 60)
                        logger.info(f"πŸ› οΈ GEMINI REQUESTED TOOL CALL: {function_name}")
                        logger.info("πŸ› οΈ " + "=" * 60)
                        logger.info(f"πŸ“‹ Tool arguments: {function_args}")
                        
                        if function_name in self.tool_functions:
                            logger.info(f"βœ… Tool function found, executing...")
                            tool_result = self.tool_functions[function_name](**function_args)
                            logger.info("πŸ› οΈ " + "=" * 60)
                            logger.info(f"πŸ› οΈ TOOL EXECUTION COMPLETED: {function_name}")
                            logger.info("πŸ› οΈ " + "=" * 60)
                            logger.info(f"πŸ“Š Tool result length: {len(str(tool_result))} characters")
                            
                            # Log a preview of the tool result
                            preview = str(tool_result)[:300].replace('\n', ' ')
                            logger.info(f"πŸ“‹ Tool result preview: {preview}...")
                            logger.info("πŸ› οΈ " + "=" * 60)
                            
                            # Create a follow-up conversation with the tool result
                            follow_up_prompt = f"""
                            {system_prompt}
                            
                            ADDITIONAL CONTEXT FROM TOOLS:
                            Tool: {function_name}
                            Result: {tool_result}
                            
                            Please use this current information to generate the most up-to-date and optimized code.
                            """
                            
                            # Log the follow-up prompt
                            logger.info("=" * 80)
                            logger.info("FOLLOW-UP PROMPT SENT TO GEMINI API (WITH TOOL RESULTS):")
                            logger.info("=" * 80)
                            logger.info(follow_up_prompt)
                            logger.info("=" * 80)                            
                            # Generate final response with tool context
                            logger.info("Generating final response with tool context")
                            final_response = self.model.generate_content(follow_up_prompt)
                            logger.info("Successfully received final response from Gemini API")
                            logger.debug(f"Final response length: {len(final_response.text)} characters")
                            return final_response.text
            
            # No tool calling, return direct response
            logger.info("Successfully received response from Gemini API (no tools used)")
            logger.debug(f"Response length: {len(response.text)} characters")
            return response.text
            
        except Exception as e:
            logger.error(f"Error generating code: {str(e)}")
            return f"Error generating code: {str(e)}"
    
    def _create_generation_prompt(self, 
                                model_name: str, 
                                prompt_text: str,
                                image_size: tuple,
                                num_inference_steps: int,
                                hardware_specs: Dict,
                                optimization_profile: str,
                                memory_analysis: Optional[Dict] = None) -> str:
        """Create the prompt for Gemini API to generate optimized code."""
        
        base_prompt = f"""
You are an expert in optimizing diffusers library code for different hardware configurations.

NOTE: This system includes curated optimization knowledge from HuggingFace documentation.

TASK: Generate optimized Python code for running a diffusion model with the following specifications:
- Model: {model_name}
- Prompt: "{prompt_text}"
- Image size: {image_size[0]}x{image_size[1]}
- Inference steps: {num_inference_steps}

HARDWARE SPECIFICATIONS:
- Platform: {hardware_specs['platform']} ({hardware_specs['architecture']})
- CPU Cores: {hardware_specs['cpu_count']}
- CUDA Available: {hardware_specs['cuda_available']}
- MPS Available: {hardware_specs['mps_available']}
- Optimization Profile: {optimization_profile}
"""

        if hardware_specs.get('gpu_info'):
            base_prompt += f"- GPU: {hardware_specs['gpu_info'][0]['name']} ({hardware_specs['gpu_info'][0]['memory_mb']/1024:.1f} GB VRAM)\n"

        # Add user dtype preference if specified
        if hardware_specs.get('user_dtype'):
            base_prompt += f"- User specified dtype: {hardware_specs['user_dtype']}\n"

        # Add memory analysis information
        if memory_analysis:
            memory_info = memory_analysis.get('memory_info', {})
            recommendations = memory_analysis.get('recommendations', {})
            
            base_prompt += f"\nMEMORY ANALYSIS:\n"
            if memory_info.get('estimated_inference_memory_fp16_gb'):
                base_prompt += f"- Model Memory Requirements: {memory_info['estimated_inference_memory_fp16_gb']} GB (FP16 inference)\n"
            if memory_info.get('memory_fp16_gb'):
                base_prompt += f"- Model Weights Size: {memory_info['memory_fp16_gb']} GB (FP16)\n"
            if recommendations.get('recommendations'):
                base_prompt += f"- Memory Recommendation: {', '.join(recommendations['recommendations'])}\n"
            if recommendations.get('recommended_precision'):
                base_prompt += f"- Recommended Precision: {recommendations['recommended_precision']}\n"
            if recommendations.get('cpu_offload'):
                base_prompt += f"- CPU Offloading Required: {recommendations['cpu_offload']}\n"
            if recommendations.get('attention_slicing'):
                base_prompt += f"- Attention Slicing Recommended: {recommendations['attention_slicing']}\n"
            if recommendations.get('vae_slicing'):
                base_prompt += f"- VAE Slicing Recommended: {recommendations['vae_slicing']}\n"

        base_prompt += f"""
OPTIMIZATION KNOWLEDGE BASE:
{get_optimization_guide()}

IMPORTANT: For FLUX.1-schnell models, do NOT include guidance_scale parameter as it's not needed.

Using the OPTIMIZATION KNOWLEDGE BASE above, generate Python code that:

1. **Selects the best optimization techniques** for the specific hardware profile
2. **Applies appropriate memory optimizations** based on available VRAM
3. **Uses optimal data types** for the target hardware:
   - User specified dtype (if provided): Use exactly as specified
   - Apple Silicon (MPS): prefer torch.bfloat16
   - NVIDIA GPUs: prefer torch.float16 or torch.bfloat16 
   - CPU only: use torch.float32
4. **Implements hardware-specific optimizations** (CUDA, MPS, CPU)
5. **Follows model-specific guidelines** (e.g., FLUX guidance_scale handling)

IMPORTANT GUIDELINES:
- Reference the OPTIMIZATION KNOWLEDGE BASE to select appropriate techniques
- Include all necessary imports
- Add brief comments explaining optimization choices
- Generate compact, production-ready code
- Inline values where possible for concise code
- Generate ONLY the Python code, no explanations before or after the code block
"""
        
        return base_prompt
    
    def run_interactive_mode(self):
        """Run the generator in interactive mode."""
        print("=== Auto-Diffusers Code Generator ===")
        print("This tool generates optimized diffusers code based on your hardware.\n")
        
        # Check hardware
        print("=== Hardware Detection ===")
        self.hardware_detector.print_specs()
        
        use_manual = input("\nUse manual hardware input? (y/n): ").lower() == 'y'
        
        # Get user inputs
        print("\n=== Model Configuration ===")
        model_name = input("Model name (default: black-forest-labs/FLUX.1-schnell): ").strip()
        if not model_name:
            model_name = "black-forest-labs/FLUX.1-schnell"
            
        prompt_text = input("Prompt text (default: A cat holding a sign that says hello world): ").strip()
        if not prompt_text:
            prompt_text = "A cat holding a sign that says hello world"
            
        try:
            width = int(input("Image width (default: 1360): ") or "1360")
            height = int(input("Image height (default: 768): ") or "768")
            steps = int(input("Inference steps (default: 4): ") or "4")
        except ValueError:
            width, height, steps = 1360, 768, 4
            
        print("\n=== Generating Optimized Code ===")
        
        # Generate code
        optimized_code = self.generate_optimized_code(
            model_name=model_name,
            prompt_text=prompt_text,
            image_size=(height, width),
            num_inference_steps=steps,
            use_manual_specs=use_manual
        )
        
        print("\n" + "="*60)
        print("OPTIMIZED DIFFUSERS CODE:")
        print("="*60)
        print(optimized_code)
        print("="*60)


def main():
    # Get API key from .env file
    api_key = os.getenv('GOOGLE_API_KEY')
    if not api_key:
        api_key = os.getenv('GEMINI_API_KEY')  # fallback
    if not api_key:
        api_key = input("Enter your Gemini API key: ").strip()
        if not api_key:
            print("API key is required!")
            return
    
    generator = AutoDiffusersGenerator(api_key)
    generator.run_interactive_mode()


if __name__ == "__main__":
    main()
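

# Programmatic usage sketch (module filename assumed to be auto_diffusers.py;
# requires GOOGLE_API_KEY or GEMINI_API_KEY in .env):
#
#   from auto_diffusers import AutoDiffusersGenerator
#   gen = AutoDiffusersGenerator(api_key=os.getenv("GOOGLE_API_KEY"))
#   code = gen.generate_optimized_code(
#       model_name="black-forest-labs/FLUX.1-schnell",
#       prompt_text="A cat holding a sign that says hello world",
#   )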