# Source: Hugging Face Space (page status when captured: Running)
import json
import logging
import os
from typing import Dict, List, Optional

import google.generativeai as genai
from dotenv import load_dotenv

from hardware_detector import HardwareDetector
from optimization_knowledge import get_optimization_guide

# Optional imports for tool calling
try:
    import requests
    from urllib.parse import urljoin, urlparse
    from bs4 import BeautifulSoup
    TOOLS_AVAILABLE = True
except ImportError:
    TOOLS_AVAILABLE = False
    requests = None
    urlparse = None
    BeautifulSoup = None
# Load variables from a .env file (if present) before anything reads the
# environment; main() later looks up the API key via os.getenv.
load_dotenv()

# Root logging configuration: DEBUG and above is written both to the
# 'auto_diffusers.log' file and to a stream handler.
logging.basicConfig(
    handlers=[
        logging.FileHandler('auto_diffusers.log'),
        logging.StreamHandler(),
    ],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)

# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)
class AutoDiffusersGenerator:
    """Generate hardware-tuned diffusers code by prompting a Gemini model.

    On construction the Gemini client is configured and a ``HardwareDetector``
    is created. When the optional tool dependencies imported successfully
    (``TOOLS_AVAILABLE``), the model is additionally given three callable
    tools that fetch HuggingFace documentation and model metadata.
    """

    # Gemini model identifier used for every request.
    MODEL_NAME = 'gemini-2.5-flash-preview-05-20'

    # Hosts (and their sub-domains) the documentation tool may fetch from.
    ALLOWED_DOC_HOSTS = ('huggingface.co', 'hf.co')

    # Maximum number of characters of page text returned by the docs tool.
    MAX_DOC_CHARS = 5000

    def __init__(self, api_key: str):
        """Configure Gemini (with tools when available) and detect hardware.

        Args:
            api_key: Gemini API key passed to ``genai.configure``.

        Raises:
            Exception: Any error from configuring Gemini or from creating the
                hardware detector is logged and re-raised unchanged.
        """
        logger.info("Initializing AutoDiffusersGenerator")
        logger.debug(f"API key length: {len(api_key) if api_key else 'None'}")

        # Defined up front so the attribute exists even when tools are
        # unavailable; _create_tools() replaces it with the real mapping.
        self.tool_functions = {}

        try:
            genai.configure(api_key=api_key)
            # Define tools for Gemini to use (if available)
            if TOOLS_AVAILABLE:
                self.tools = self._create_tools()
                # Initialize model with tools
                self.model = genai.GenerativeModel(
                    self.MODEL_NAME,
                    tools=self.tools,
                )
                logger.info("Successfully configured Gemini AI model with tools")
            else:
                self.tools = None
                # Initialize model without tools
                self.model = genai.GenerativeModel(self.MODEL_NAME)
                logger.warning("Tool calling dependencies not available, running without tools")
        except Exception as e:
            logger.error(f"Failed to configure Gemini AI: {e}")
            raise

        try:
            self.hardware_detector = HardwareDetector()
            logger.info("Hardware detector initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize hardware detector: {e}")
            raise

    @staticmethod
    def _is_huggingface_url(url: str) -> bool:
        """Return True when *url* is an http(s) URL on an allowed HuggingFace host.

        The host must equal one of ``ALLOWED_DOC_HOSTS`` or be a sub-domain of
        one. A plain substring test is not enough: it would also accept hosts
        such as ``huggingface.co.evil.com``.
        """
        # Local stdlib import: the module-level ``urlparse`` is set to None
        # when the optional tool dependencies are missing.
        from urllib.parse import urlparse as parse_url

        try:
            parsed = parse_url(url)
            scheme = parsed.scheme
            host = parsed.hostname
        except (ValueError, TypeError, AttributeError):
            return False

        if scheme not in ('http', 'https') or not isinstance(host, str):
            return False

        host = host.lower().rstrip('.')
        return any(
            host == allowed or host.endswith('.' + allowed)
            for allowed in AutoDiffusersGenerator.ALLOWED_DOC_HOSTS
        )

    @staticmethod
    def _profile_for_vram(vram_gb: float) -> str:
        """Map an amount of GPU memory (in GB) to an optimization profile name."""
        if vram_gb >= 16:
            return 'performance'
        if vram_gb >= 8:
            return 'balanced'
        return 'memory_efficient'

    def _create_tools(self):
        """Create function tools for Gemini to use.

        Returns:
            The tool schema list passed to ``genai.GenerativeModel``; an empty
            list when the optional dependencies are unavailable.

        Side effects:
            Stores the Python callables backing the schema in
            ``self.tool_functions`` keyed by tool name.
        """
        logger.debug("Creating tools for Gemini")
        if not TOOLS_AVAILABLE:
            logger.warning("Tools dependencies not available, returning empty tools")
            return []

        # Shared by both HTTP-backed tools.
        request_headers = {
            'User-Agent': 'Auto-Diffusers-Config/1.0 (Educational Tool)'
        }

        def fetch_huggingface_docs(url: str) -> str:
            """Fetch documentation from HuggingFace URLs."""
            logger.info("🌐 TOOL CALL: fetch_huggingface_docs")
            logger.info(f"📋 Requested URL: {url}")
            try:
                # Validate URL is from HuggingFace
                parsed = urlparse(url)
                logger.debug(f"URL validation - Domain: {parsed.netloc}, Path: {parsed.path}")
                if not self._is_huggingface_url(url):
                    error_msg = "Error: URL must be from huggingface.co domain"
                    logger.warning(f"❌ URL validation failed: {error_msg}")
                    return error_msg
                logger.info(f"✅ URL validation passed for domain: {parsed.netloc}")

                logger.info(f"🔄 Fetching content from: {url}")
                response = requests.get(url, headers=request_headers, timeout=10)
                response.raise_for_status()
                logger.info(f"✅ HTTP {response.status_code} - Successfully fetched {len(response.text)} characters")

                # Parse HTML content
                logger.info("🔍 Parsing HTML content...")
                soup = BeautifulSoup(response.text, 'html.parser')

                # Extract text-bearing elements only, one element per line.
                texts = [
                    text
                    for text in (
                        element.get_text().strip()
                        for element in soup.find_all(
                            ['p', 'pre', 'code', 'h1', 'h2', 'h3', 'h4', 'li']
                        )
                    )
                    if text
                ]
                content = "".join(f"{text}\n" for text in texts)
                logger.info(f"📄 Extracted content from {len(texts)} HTML elements")

                # Limit content length
                original_length = len(content)
                if original_length > self.MAX_DOC_CHARS:
                    content = content[:self.MAX_DOC_CHARS] + "...[truncated]"
                    logger.info(
                        f"✂️ Content truncated from {original_length} to {self.MAX_DOC_CHARS} characters"
                    )
                logger.info(f"📊 Final processed content: {len(content)} characters")

                # Log a single-line preview of the fetched content
                preview = content[:200].replace('\n', ' ')
                logger.info(f"📋 Content preview: {preview}...")

                # Log content sections found
                sections = [
                    title
                    for title in (
                        header.get_text().strip()
                        for header in soup.find_all(['h1', 'h2', 'h3'])
                    )
                    if title
                ]
                if sections:
                    logger.info(f"📑 Found sections: {', '.join(sections[:5])}{'...' if len(sections) > 5 else ''}")

                logger.info("✅ Content extraction completed successfully")
                return content
            except Exception as e:
                # Tools report failures as text so the model can react to them.
                logger.error(f"❌ Error fetching {url}: {type(e).__name__}: {e}")
                return f"Error fetching documentation: {str(e)}"

        def fetch_model_info(model_id: str) -> str:
            """Fetch model information from HuggingFace API."""
            logger.info("🤖 TOOL CALL: fetch_model_info")
            logger.info(f"📋 Requested model: {model_id}")
            try:
                # Use HuggingFace API to get model info
                api_url = f"https://huggingface.co/api/models/{model_id}"
                logger.info(f"🔄 Fetching model info from: {api_url}")
                response = requests.get(api_url, headers=request_headers, timeout=10)
                response.raise_for_status()
                logger.info(f"✅ HTTP {response.status_code} - Model API response received")

                model_data = response.json()
                logger.info(f"📊 Raw API response contains {len(model_data)} fields")

                # Extract relevant information
                info = {
                    'model_id': model_data.get('id', model_id),
                    'pipeline_tag': model_data.get('pipeline_tag', 'unknown'),
                    'tags': model_data.get('tags', []),
                    'library_name': model_data.get('library_name', 'unknown'),
                    'downloads': model_data.get('downloads', 0),
                    'likes': model_data.get('likes', 0),
                }
                logger.info("📋 Extracted model info:")
                logger.info(f" - Pipeline: {info['pipeline_tag']}")
                logger.info(f" - Library: {info['library_name']}")
                logger.info(f" - Downloads: {info['downloads']:,}")
                logger.info(f" - Likes: {info['likes']:,}")
                logger.info(f" - Tags: {len(info['tags'])} tags")

                result = json.dumps(info, indent=2)
                logger.info(f"✅ Model info formatting completed ({len(result)} characters)")
                return result
            except Exception as e:
                logger.error(f"Error fetching model info for {model_id}: {e}")
                return f"Error fetching model information: {str(e)}"

        def search_optimization_guides(query: str) -> str:
            """Search for optimization guides and best practices."""
            logger.info("🔍 TOOL CALL: search_optimization_guides")
            logger.info(f"📋 Search query: '{query}'")
            try:
                # Search common optimization documentation URLs
                docs_urls = [
                    "https://huggingface.co/docs/diffusers/optimization/fp16",
                    "https://huggingface.co/docs/diffusers/optimization/memory",
                    "https://huggingface.co/docs/diffusers/optimization/torch2",
                    "https://huggingface.co/docs/diffusers/optimization/mps",
                    "https://huggingface.co/docs/diffusers/optimization/xformers",
                ]
                logger.info(f"🔎 Searching through {len(docs_urls)} optimization guide URLs...")

                # A URL matches when any lower-cased query word occurs in it.
                keywords = query.lower().split()
                results = []
                matched_urls = []
                for url in docs_urls:
                    if not any(keyword in url for keyword in keywords):
                        logger.debug(f"⏭️ URL skipped (no match): {url}")
                        continue

                    logger.info(f"✅ URL matched query: {url}")
                    matched_urls.append(url)
                    content = fetch_huggingface_docs(url)
                    # fetch_huggingface_docs signals failure with an "Error..." string.
                    if content.startswith("Error"):
                        logger.warning(f"❌ Failed to fetch content from {url}")
                    else:
                        results.append(f"From {url}:\n{content[:1000]}...\n")
                        logger.info(f"📄 Successfully processed content from {url}")

                logger.info(f"📊 Search completed: {len(matched_urls)} URLs matched, {len(results)} successful fetches")
                if not results:
                    logger.warning("❌ No specific optimization guides found for the query")
                    return "No specific optimization guides found for the query"

                final_result = "\n".join(results)
                logger.info(f"✅ Returning combined content ({len(final_result)} characters)")
                return final_result
            except Exception as e:
                logger.error(f"Error searching optimization guides: {e}")
                return f"Error searching guides: {str(e)}"

        # Define tools schema for Gemini (simplified for now)
        tools = [
            {
                "function_declarations": [
                    {
                        "name": "fetch_huggingface_docs",
                        "description": "Fetch current documentation from HuggingFace URLs for diffusers library, models, or optimization guides",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "url": {
                                    "type": "string",
                                    "description": "The HuggingFace documentation URL to fetch"
                                }
                            },
                            "required": ["url"]
                        }
                    },
                    {
                        "name": "fetch_model_info",
                        "description": "Fetch current model information and metadata from HuggingFace API",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "model_id": {
                                    "type": "string",
                                    "description": "The HuggingFace model ID (e.g., 'black-forest-labs/FLUX.1-schnell')"
                                }
                            },
                            "required": ["model_id"]
                        }
                    },
                    {
                        "name": "search_optimization_guides",
                        "description": "Search for optimization guides and best practices for diffusers models",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "query": {
                                    "type": "string",
                                    "description": "Search query for optimization topics (e.g., 'memory', 'fp16', 'torch compile')"
                                }
                            },
                            "required": ["query"]
                        }
                    }
                ]
            }
        ]

        # Store function implementations for execution
        self.tool_functions = {
            'fetch_huggingface_docs': fetch_huggingface_docs,
            'fetch_model_info': fetch_model_info,
            'search_optimization_guides': search_optimization_guides,
        }
        logger.info(f"Created {len(tools[0]['function_declarations'])} tools for Gemini")
        return tools

    def generate_optimized_code(self,
                                model_name: str,
                                prompt_text: str,
                                image_size: tuple = (768, 1360),
                                num_inference_steps: int = 4,
                                use_manual_specs: bool = False,
                                manual_specs: Optional[Dict] = None,
                                memory_analysis: Optional[Dict] = None) -> str:
        """Generate optimized diffusers code based on hardware specs and memory analysis.

        Args:
            model_name: Model identifier inserted into the prompt.
            prompt_text: Image prompt inserted into the prompt.
            image_size: Two-item size; rendered as ``{size[0]}x{size[1]}``.
            num_inference_steps: Step count inserted into the prompt.
            use_manual_specs: Use ``manual_specs`` instead of auto-detection.
                Only honoured when ``manual_specs`` is also non-empty.
            manual_specs: Hardware description dict used when requested.
            memory_analysis: Optional analysis dict added to the prompt.

        Returns:
            The text returned by Gemini, or a string starting with
            ``"Error generating code: "`` when the API interaction fails.

        Raises:
            KeyError: Building the prompt happens outside the error handler,
                so a hardware dict lacking a required key propagates.
        """
        logger.info(f"Starting code generation for model: {model_name}")
        logger.debug(f"Parameters: prompt='{prompt_text[:50]}...', size={image_size}, steps={num_inference_steps}")
        logger.debug(f"Manual specs: {use_manual_specs}, Memory analysis provided: {memory_analysis is not None}")

        # Get hardware specifications
        if use_manual_specs and manual_specs:
            logger.info("Using manual hardware specifications")
            hardware_specs = manual_specs
            logger.debug(f"Manual specs: {hardware_specs}")

            # Determine optimization profile based on manual specs
            gpu_info = hardware_specs.get('gpu_info')
            if gpu_info:
                vram_gb = gpu_info[0]['memory_mb'] / 1024
                logger.debug(f"GPU detected with {vram_gb:.1f} GB VRAM")
                optimization_profile = self._profile_for_vram(vram_gb)
            else:
                optimization_profile = 'cpu_only'
                logger.info("No GPU detected, using CPU-only profile")
            logger.info(f"Selected optimization profile: {optimization_profile}")
        else:
            logger.info("Using automatic hardware detection")
            hardware_specs = self.hardware_detector.specs
            optimization_profile = self.hardware_detector.get_optimization_profile()
            logger.debug(f"Detected specs: {hardware_specs}")
            logger.info(f"Auto-detected optimization profile: {optimization_profile}")

        # Create the prompt for Gemini API
        logger.debug("Creating generation prompt for Gemini API")
        system_prompt = self._create_generation_prompt(
            model_name, prompt_text, image_size, num_inference_steps,
            hardware_specs, optimization_profile, memory_analysis
        )
        logger.debug(f"Prompt length: {len(system_prompt)} characters")

        # Log the actual prompt being sent to Gemini API
        logger.info("=" * 80)
        logger.info("PROMPT SENT TO GEMINI API:")
        logger.info("=" * 80)
        logger.info(system_prompt)
        logger.info("=" * 80)

        try:
            logger.info("Sending request to Gemini API")
            response = self.model.generate_content(system_prompt)

            # Handle tool calling if present and tools are available.
            # Guard against a response that carries no candidates at all.
            parts = []
            if self.tools and response.candidates:
                parts = response.candidates[0].content.parts

            for part in parts:
                function_call = getattr(part, 'function_call', None)
                if not function_call:
                    continue

                function_name = function_call.name
                function_args = dict(function_call.args)
                logger.info("🛠️ " + "=" * 60)
                logger.info(f"🛠️ GEMINI REQUESTED TOOL CALL: {function_name}")
                logger.info("🛠️ " + "=" * 60)
                logger.info(f"📋 Tool arguments: {function_args}")

                # Only the first call that names a known tool is executed;
                # its follow-up answer is the final result.
                if function_name in self.tool_functions:
                    return self._run_tool_call(system_prompt, function_name, function_args)

            # No tool calling, return direct response
            logger.info("Successfully received response from Gemini API (no tools used)")
            logger.debug(f"Response length: {len(response.text)} characters")
            return response.text
        except Exception as e:
            logger.error(f"Error generating code: {str(e)}")
            return f"Error generating code: {str(e)}"

    def _run_tool_call(self, system_prompt: str, function_name: str, function_args: Dict) -> str:
        """Run one registered tool and re-prompt Gemini with its result appended.

        Returns the text of the follow-up response. Exceptions are left to the
        caller, which converts them into an error string.
        """
        logger.info("✅ Tool function found, executing...")
        tool_result = self.tool_functions[function_name](**function_args)
        logger.info("🛠️ " + "=" * 60)
        logger.info(f"🛠️ TOOL EXECUTION COMPLETED: {function_name}")
        logger.info("🛠️ " + "=" * 60)
        logger.info(f"📊 Tool result length: {len(str(tool_result))} characters")

        # Log a single-line preview of the tool result
        preview = str(tool_result)[:300].replace('\n', ' ')
        logger.info(f"📋 Tool result preview: {preview}...")
        logger.info("🛠️ " + "=" * 60)

        # Create a follow-up conversation with the tool result
        follow_up_prompt = (
            f"\n{system_prompt}\n"
            "ADDITIONAL CONTEXT FROM TOOLS:\n"
            f"Tool: {function_name}\n"
            f"Result: {tool_result}\n"
            "Please use this current information to generate the most up-to-date and optimized code.\n"
        )

        # Log the follow-up prompt
        logger.info("=" * 80)
        logger.info("FOLLOW-UP PROMPT SENT TO GEMINI API (WITH TOOL RESULTS):")
        logger.info("=" * 80)
        logger.info(follow_up_prompt)
        logger.info("=" * 80)

        # Generate final response with tool context
        logger.info("Generating final response with tool context")
        final_response = self.model.generate_content(follow_up_prompt)
        logger.info("Successfully received final response from Gemini API")
        logger.debug(f"Final response length: {len(final_response.text)} characters")
        return final_response.text

    def _create_generation_prompt(self,
                                  model_name: str,
                                  prompt_text: str,
                                  image_size: tuple,
                                  num_inference_steps: int,
                                  hardware_specs: Dict,
                                  optimization_profile: str,
                                  memory_analysis: Optional[Dict] = None) -> str:
        """Create the prompt for Gemini API to generate optimized code.

        Args:
            model_name: Model identifier shown in the task section.
            prompt_text: Image prompt shown in the task section.
            image_size: Two-item size rendered as ``{size[0]}x{size[1]}``.
            num_inference_steps: Step count shown in the task section.
            hardware_specs: Must contain ``platform``, ``architecture``,
                ``cpu_count``, ``cuda_available`` and ``mps_available``;
                ``gpu_info`` and ``user_dtype`` are optional.
            optimization_profile: Profile name shown in the hardware section.
            memory_analysis: Optional dict with ``memory_info`` and
                ``recommendations`` sub-dicts; only truthy entries are listed.

        Returns:
            The complete prompt text.

        Raises:
            KeyError: If a required ``hardware_specs`` key is missing.
        """
        base_prompt = f"""
You are an expert in optimizing diffusers library code for different hardware configurations.
NOTE: This system includes curated optimization knowledge from HuggingFace documentation.
TASK: Generate optimized Python code for running a diffusion model with the following specifications:
- Model: {model_name}
- Prompt: "{prompt_text}"
- Image size: {image_size[0]}x{image_size[1]}
- Inference steps: {num_inference_steps}
HARDWARE SPECIFICATIONS:
- Platform: {hardware_specs['platform']} ({hardware_specs['architecture']})
- CPU Cores: {hardware_specs['cpu_count']}
- CUDA Available: {hardware_specs['cuda_available']}
- MPS Available: {hardware_specs['mps_available']}
- Optimization Profile: {optimization_profile}
"""
        gpu_info = hardware_specs.get('gpu_info')
        if gpu_info:
            base_prompt += f"- GPU: {gpu_info[0]['name']} ({gpu_info[0]['memory_mb']/1024:.1f} GB VRAM)\n"

        # Add user dtype preference if specified
        if hardware_specs.get('user_dtype'):
            base_prompt += f"- User specified dtype: {hardware_specs['user_dtype']}\n"

        # Add memory analysis information
        if memory_analysis:
            memory_info = memory_analysis.get('memory_info', {})
            recommendations = memory_analysis.get('recommendations', {})
            base_prompt += "\nMEMORY ANALYSIS:\n"
            if memory_info.get('estimated_inference_memory_fp16_gb'):
                base_prompt += f"- Model Memory Requirements: {memory_info['estimated_inference_memory_fp16_gb']} GB (FP16 inference)\n"
            if memory_info.get('memory_fp16_gb'):
                base_prompt += f"- Model Weights Size: {memory_info['memory_fp16_gb']} GB (FP16)\n"
            if recommendations.get('recommendations'):
                base_prompt += f"- Memory Recommendation: {', '.join(recommendations['recommendations'])}\n"
            if recommendations.get('recommended_precision'):
                base_prompt += f"- Recommended Precision: {recommendations['recommended_precision']}\n"
            if recommendations.get('cpu_offload'):
                base_prompt += f"- CPU Offloading Required: {recommendations['cpu_offload']}\n"
            if recommendations.get('attention_slicing'):
                base_prompt += f"- Attention Slicing Recommended: {recommendations['attention_slicing']}\n"
            if recommendations.get('vae_slicing'):
                base_prompt += f"- VAE Slicing Recommended: {recommendations['vae_slicing']}\n"

        base_prompt += f"""
OPTIMIZATION KNOWLEDGE BASE:
{get_optimization_guide()}
IMPORTANT: For FLUX.1-schnell models, do NOT include guidance_scale parameter as it's not needed.
Using the OPTIMIZATION KNOWLEDGE BASE above, generate Python code that:
1. **Selects the best optimization techniques** for the specific hardware profile
2. **Applies appropriate memory optimizations** based on available VRAM
3. **Uses optimal data types** for the target hardware:
- User specified dtype (if provided): Use exactly as specified
- Apple Silicon (MPS): prefer torch.bfloat16
- NVIDIA GPUs: prefer torch.float16 or torch.bfloat16
- CPU only: use torch.float32
4. **Implements hardware-specific optimizations** (CUDA, MPS, CPU)
5. **Follows model-specific guidelines** (e.g., FLUX guidance_scale handling)
IMPORTANT GUIDELINES:
- Reference the OPTIMIZATION KNOWLEDGE BASE to select appropriate techniques
- Include all necessary imports
- Add brief comments explaining optimization choices
- Generate compact, production-ready code
- Inline values where possible for concise code
- Generate ONLY the Python code, no explanations before or after the code block
"""
        return base_prompt

    def run_interactive_mode(self):
        """Run the generator in interactive mode.

        Prints the detected hardware, prompts on stdin for the model name,
        prompt text, image size and step count (blank input selects the
        default), then prints the generated code.
        """
        print("=== Auto-Diffusers Code Generator ===")
        print("This tool generates optimized diffusers code based on your hardware.\n")

        # Check hardware
        print("=== Hardware Detection ===")
        self.hardware_detector.print_specs()

        # NOTE(review): no manual specs are collected here, so answering 'y'
        # still falls back to auto-detection inside generate_optimized_code
        # (it requires a non-empty manual_specs dict as well).
        use_manual = input("\nUse manual hardware input? (y/n): ").lower() == 'y'

        # Get user inputs
        print("\n=== Model Configuration ===")
        model_name = input("Model name (default: black-forest-labs/FLUX.1-schnell): ").strip()
        if not model_name:
            model_name = "black-forest-labs/FLUX.1-schnell"

        prompt_text = input("Prompt text (default: A cat holding a sign that says hello world): ").strip()
        if not prompt_text:
            prompt_text = "A cat holding a sign that says hello world"

        try:
            width = int(input("Image width (default: 1360): ") or "1360")
            height = int(input("Image height (default: 768): ") or "768")
            steps = int(input("Inference steps (default: 4): ") or "4")
        except ValueError:
            # Any non-numeric entry resets all three values to their defaults.
            width, height, steps = 1360, 768, 4

        print("\n=== Generating Optimized Code ===")
        # Generate code
        optimized_code = self.generate_optimized_code(
            model_name=model_name,
            prompt_text=prompt_text,
            image_size=(height, width),
            num_inference_steps=steps,
            use_manual_specs=use_manual
        )

        print("\n" + "=" * 60)
        print("OPTIMIZED DIFFUSERS CODE:")
        print("=" * 60)
        print(optimized_code)
        print("=" * 60)
def main():
    """Resolve a Gemini API key and start the interactive generator.

    The key is taken from the GOOGLE_API_KEY environment variable, then from
    GEMINI_API_KEY, and finally from an interactive prompt. When no key is
    obtained a message is printed and the function returns without starting.
    """
    # Environment first (GOOGLE_API_KEY wins, GEMINI_API_KEY is the fallback).
    api_key = os.getenv('GOOGLE_API_KEY') or os.getenv('GEMINI_API_KEY')

    # Last resort: ask the user directly.
    if not api_key:
        api_key = input("Enter your Gemini API key: ").strip()

    if not api_key:
        print("API key is required!")
        return

    AutoDiffusersGenerator(api_key).run_interactive_mode()


if __name__ == "__main__":
    main()