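"""auto_diffusers.py -- generate hardware-optimized Hugging Face diffusers code with Gemini.

`AutoDiffusersGenerator` detects (or accepts manually supplied) hardware
specifications, optionally augments the model's context with live HuggingFace
documentation via Gemini tool calling, and asks a Gemini model to emit
optimized Python code for a given diffusion model and prompt.
"""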
import os
import logging
from dotenv import load_dotenv
import google.generativeai as genai
from hardware_detector import HardwareDetector
from optimization_knowledge import get_optimization_guide
from typing import Dict, Optional
import json
# Optional imports for tool calling
try:
    import requests
    from urllib.parse import urlparse
    from bs4 import BeautifulSoup
    TOOLS_AVAILABLE = True
except ImportError:
    TOOLS_AVAILABLE = False
    requests = None
    urlparse = None
    BeautifulSoup = None
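
# Read GOOGLE_API_KEY / GEMINI_API_KEY from a local .env, if present (used by main()).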
load_dotenv()
# Configure logging
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('auto_diffusers.log'),
        logging.StreamHandler()
    ]
)
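# NOTE: the root logger is set to DEBUG and writes to both auto_diffusers.log
# and the console; full prompts and tool results are logged below, so the log
# file can grow quickly. Raise the level for quieter deployments.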
logger = logging.getLogger(__name__)


class AutoDiffusersGenerator:
    def __init__(self, api_key: str):
        logger.info("Initializing AutoDiffusersGenerator")
        logger.debug(f"API key length: {len(api_key) if api_key else 'None'}")

        try:
            genai.configure(api_key=api_key)
            # Define tools for Gemini to use (if available)
            if TOOLS_AVAILABLE:
                self.tools = self._create_tools()
                # Initialize model with tools
                self.model = genai.GenerativeModel(
                    'gemini-2.5-flash-preview-05-20',
                    tools=self.tools
                )
                logger.info("Successfully configured Gemini AI model with tools")
            else:
                self.tools = None
                # Initialize model without tools
                self.model = genai.GenerativeModel('gemini-2.5-flash-preview-05-20')
                logger.warning("Tool calling dependencies not available, running without tools")
        except Exception as e:
            logger.error(f"Failed to configure Gemini AI: {e}")
            raise

        try:
            self.hardware_detector = HardwareDetector()
            logger.info("Hardware detector initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize hardware detector: {e}")
            raise

    def _create_tools(self):
        """Create function tools for Gemini to use."""
        logger.debug("Creating tools for Gemini")

        if not TOOLS_AVAILABLE:
            logger.warning("Tools dependencies not available, returning empty tools")
            return []

        def fetch_huggingface_docs(url: str) -> str:
            """Fetch documentation from HuggingFace URLs."""
            logger.info("🌐 TOOL CALL: fetch_huggingface_docs")
            logger.info(f"📋 Requested URL: {url}")
            try:
                # Validate that the URL host is huggingface.co/hf.co (or a
                # subdomain); a plain substring check would pass look-alike domains.
                parsed = urlparse(url)
                logger.debug(f"URL validation - Domain: {parsed.netloc}, Path: {parsed.path}")
                allowed_domains = ('huggingface.co', 'hf.co')
                if not any(parsed.netloc == domain or parsed.netloc.endswith('.' + domain)
                           for domain in allowed_domains):
                    error_msg = "Error: URL must be from huggingface.co domain"
                    logger.warning(f"❌ URL validation failed: {error_msg}")
                    return error_msg

                logger.info(f"✅ URL validation passed for domain: {parsed.netloc}")

                headers = {
                    'User-Agent': 'Auto-Diffusers-Config/1.0 (Educational Tool)'
                }
                logger.info(f"🔄 Fetching content from: {url}")
                response = requests.get(url, headers=headers, timeout=10)
                response.raise_for_status()
                logger.info(f"✅ HTTP {response.status_code} - Successfully fetched {len(response.text)} characters")

                # Parse HTML content
                logger.info("🔍 Parsing HTML content...")
                soup = BeautifulSoup(response.text, 'html.parser')

                # Extract main content (remove navigation, footers, etc.)
                content = ""
                element_count = 0
                for element in soup.find_all(['p', 'pre', 'code', 'h1', 'h2', 'h3', 'h4', 'li']):
                    text = element.get_text().strip()
                    if text:
                        content += text + "\n"
                        element_count += 1
                logger.info(f"📄 Extracted content from {element_count} HTML elements")

                # Limit content length
                original_length = len(content)
                if len(content) > 5000:
                    content = content[:5000] + "...[truncated]"
                    logger.info(f"✂️ Content truncated from {original_length} to 5000 characters")
                logger.info(f"📊 Final processed content: {len(content)} characters")

                # Log a preview of the fetched content
                preview = content[:200].replace('\n', ' ')
                logger.info(f"📋 Content preview: {preview}...")

                # Log content sections found
                sections = []
                for header in soup.find_all(['h1', 'h2', 'h3']):
                    header_text = header.get_text().strip()
                    if header_text:
                        sections.append(header_text)
                if sections:
                    logger.info(f"📑 Found sections: {', '.join(sections[:5])}{'...' if len(sections) > 5 else ''}")

                logger.info("✅ Content extraction completed successfully")
                return content

            except Exception as e:
                logger.error(f"❌ Error fetching {url}: {type(e).__name__}: {e}")
                return f"Error fetching documentation: {str(e)}"

        def fetch_model_info(model_id: str) -> str:
            """Fetch model information from HuggingFace API."""
            logger.info("🤖 TOOL CALL: fetch_model_info")
            logger.info(f"📋 Requested model: {model_id}")
            try:
                # Use HuggingFace API to get model info
                api_url = f"https://huggingface.co/api/models/{model_id}"
                logger.info(f"🔄 Fetching model info from: {api_url}")

                headers = {
                    'User-Agent': 'Auto-Diffusers-Config/1.0 (Educational Tool)'
                }
                response = requests.get(api_url, headers=headers, timeout=10)
                response.raise_for_status()
                logger.info(f"✅ HTTP {response.status_code} - Model API response received")

                model_data = response.json()
                logger.info(f"📊 Raw API response contains {len(model_data)} fields")

                # Extract relevant information
                info = {
                    'model_id': model_data.get('id', model_id),
                    'pipeline_tag': model_data.get('pipeline_tag', 'unknown'),
                    'tags': model_data.get('tags', []),
                    'library_name': model_data.get('library_name', 'unknown'),
                    'downloads': model_data.get('downloads', 0),
                    'likes': model_data.get('likes', 0)
                }

                logger.info("📋 Extracted model info:")
                logger.info(f"  - Pipeline: {info['pipeline_tag']}")
                logger.info(f"  - Library: {info['library_name']}")
                logger.info(f"  - Downloads: {info['downloads']:,}")
                logger.info(f"  - Likes: {info['likes']:,}")
                logger.info(f"  - Tags: {len(info['tags'])} tags")

                result = json.dumps(info, indent=2)
                logger.info(f"✅ Model info formatting completed ({len(result)} characters)")
                return result

            except Exception as e:
                logger.error(f"Error fetching model info for {model_id}: {e}")
                return f"Error fetching model information: {str(e)}"

        def search_optimization_guides(query: str) -> str:
            """Search for optimization guides and best practices."""
            logger.info("🔍 TOOL CALL: search_optimization_guides")
            logger.info(f"📋 Search query: '{query}'")
            try:
                # Search common optimization documentation URLs
                docs_urls = [
                    "https://huggingface.co/docs/diffusers/optimization/fp16",
                    "https://huggingface.co/docs/diffusers/optimization/memory",
                    "https://huggingface.co/docs/diffusers/optimization/torch2",
                    "https://huggingface.co/docs/diffusers/optimization/mps",
                    "https://huggingface.co/docs/diffusers/optimization/xformers"
                ]
                logger.info(f"🔎 Searching through {len(docs_urls)} optimization guide URLs...")

                results = []
                matched_urls = []
                for url in docs_urls:
                    if any(keyword in url for keyword in query.lower().split()):
                        logger.info(f"✅ URL matched query: {url}")
                        matched_urls.append(url)
                        content = fetch_huggingface_docs(url)
                        if not content.startswith("Error"):
                            results.append(f"From {url}:\n{content[:1000]}...\n")
                            logger.info(f"📄 Successfully processed content from {url}")
                        else:
                            logger.warning(f"❌ Failed to fetch content from {url}")
                    else:
                        logger.debug(f"⏭️ URL skipped (no match): {url}")

                logger.info(f"📊 Search completed: {len(matched_urls)} URLs matched, {len(results)} successful fetches")
                if results:
                    final_result = "\n".join(results)
                    logger.info(f"✅ Returning combined content ({len(final_result)} characters)")
                    return final_result
                else:
                    logger.warning("❌ No specific optimization guides found for the query")
                    return "No specific optimization guides found for the query"

            except Exception as e:
                logger.error(f"Error searching optimization guides: {e}")
                return f"Error searching guides: {str(e)}"

        # Define tools schema for Gemini (simplified for now)
        tools = [
            {
                "function_declarations": [
                    {
                        "name": "fetch_huggingface_docs",
                        "description": "Fetch current documentation from HuggingFace URLs for diffusers library, models, or optimization guides",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "url": {
                                    "type": "string",
                                    "description": "The HuggingFace documentation URL to fetch"
                                }
                            },
                            "required": ["url"]
                        }
                    },
                    {
                        "name": "fetch_model_info",
                        "description": "Fetch current model information and metadata from HuggingFace API",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "model_id": {
                                    "type": "string",
                                    "description": "The HuggingFace model ID (e.g., 'black-forest-labs/FLUX.1-schnell')"
                                }
                            },
                            "required": ["model_id"]
                        }
                    },
                    {
                        "name": "search_optimization_guides",
                        "description": "Search for optimization guides and best practices for diffusers models",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "query": {
                                    "type": "string",
                                    "description": "Search query for optimization topics (e.g., 'memory', 'fp16', 'torch compile')"
                                }
                            },
                            "required": ["query"]
                        }
                    }
                ]
            }
        ]

        # Store function implementations for execution
        self.tool_functions = {
            'fetch_huggingface_docs': fetch_huggingface_docs,
            'fetch_model_info': fetch_model_info,
            'search_optimization_guides': search_optimization_guides
        }

        logger.info(f"Created {len(tools[0]['function_declarations'])} tools for Gemini")
        return tools

    def generate_optimized_code(self,
                                model_name: str,
                                prompt_text: str,
                                image_size: tuple = (768, 1360),
                                num_inference_steps: int = 4,
                                use_manual_specs: bool = False,
                                manual_specs: Optional[Dict] = None,
                                memory_analysis: Optional[Dict] = None) -> str:
        """Generate optimized diffusers code based on hardware specs and memory analysis."""
logger.info(f"Starting code generation for model: {model_name}")
logger.debug(f"Parameters: prompt='{prompt_text[:50]}...', size={image_size}, steps={num_inference_steps}")
logger.debug(f"Manual specs: {use_manual_specs}, Memory analysis provided: {memory_analysis is not None}")
# Get hardware specifications
if use_manual_specs and manual_specs:
logger.info("Using manual hardware specifications")
hardware_specs = manual_specs
logger.debug(f"Manual specs: {hardware_specs}")
# Determine optimization profile based on manual specs
if hardware_specs.get('gpu_info') and hardware_specs['gpu_info']:
vram_gb = hardware_specs['gpu_info'][0]['memory_mb'] / 1024
logger.debug(f"GPU detected with {vram_gb:.1f} GB VRAM")
if vram_gb >= 16:
optimization_profile = 'performance'
elif vram_gb >= 8:
optimization_profile = 'balanced'
else:
optimization_profile = 'memory_efficient'
else:
optimization_profile = 'cpu_only'
logger.info("No GPU detected, using CPU-only profile")
logger.info(f"Selected optimization profile: {optimization_profile}")
else:
logger.info("Using automatic hardware detection")
hardware_specs = self.hardware_detector.specs
optimization_profile = self.hardware_detector.get_optimization_profile()
logger.debug(f"Detected specs: {hardware_specs}")
logger.info(f"Auto-detected optimization profile: {optimization_profile}")
# Create the prompt for Gemini API
logger.debug("Creating generation prompt for Gemini API")
system_prompt = self._create_generation_prompt(
model_name, prompt_text, image_size, num_inference_steps,
hardware_specs, optimization_profile, memory_analysis
)
logger.debug(f"Prompt length: {len(system_prompt)} characters")
# Log the actual prompt being sent to Gemini API
logger.info("=" * 80)
logger.info("PROMPT SENT TO GEMINI API:")
logger.info("=" * 80)
logger.info(system_prompt)
logger.info("=" * 80)

        try:
            logger.info("Sending request to Gemini API")
            response = self.model.generate_content(system_prompt)

            # Handle tool calling if present and tools are available. Only the
            # first executed function call is honored: its result is fed back
            # to the model through a follow-up prompt and the reply is returned.
            if self.tools and response.candidates[0].content.parts:
                for part in response.candidates[0].content.parts:
                    if hasattr(part, 'function_call') and part.function_call:
                        function_name = part.function_call.name
                        function_args = dict(part.function_call.args)
                        logger.info("🛠️ " + "=" * 60)
                        logger.info(f"🛠️ GEMINI REQUESTED TOOL CALL: {function_name}")
                        logger.info("🛠️ " + "=" * 60)
                        logger.info(f"📋 Tool arguments: {function_args}")

                        if function_name in self.tool_functions:
                            logger.info("✅ Tool function found, executing...")
                            tool_result = self.tool_functions[function_name](**function_args)
                            logger.info("🛠️ " + "=" * 60)
                            logger.info(f"🛠️ TOOL EXECUTION COMPLETED: {function_name}")
                            logger.info("🛠️ " + "=" * 60)
                            logger.info(f"📊 Tool result length: {len(str(tool_result))} characters")

                            # Log a preview of the tool result
                            preview = str(tool_result)[:300].replace('\n', ' ')
                            logger.info(f"📋 Tool result preview: {preview}...")
                            logger.info("🛠️ " + "=" * 60)

                            # Create a follow-up conversation with the tool result
                            follow_up_prompt = f"""
{system_prompt}
ADDITIONAL CONTEXT FROM TOOLS:
Tool: {function_name}
Result: {tool_result}
Please use this current information to generate the most up-to-date and optimized code.
"""

                            # Log the follow-up prompt
                            logger.info("=" * 80)
                            logger.info("FOLLOW-UP PROMPT SENT TO GEMINI API (WITH TOOL RESULTS):")
                            logger.info("=" * 80)
                            logger.info(follow_up_prompt)
                            logger.info("=" * 80)

                            # Generate final response with tool context
                            logger.info("Generating final response with tool context")
                            final_response = self.model.generate_content(follow_up_prompt)
                            logger.info("Successfully received final response from Gemini API")
                            logger.debug(f"Final response length: {len(final_response.text)} characters")
                            return final_response.text

            # No tool calling, return direct response
            logger.info("Successfully received response from Gemini API (no tools used)")
            logger.debug(f"Response length: {len(response.text)} characters")
            return response.text

        except Exception as e:
            logger.error(f"Error generating code: {str(e)}")
            return f"Error generating code: {str(e)}"

    def _create_generation_prompt(self,
                                  model_name: str,
                                  prompt_text: str,
                                  image_size: tuple,
                                  num_inference_steps: int,
                                  hardware_specs: Dict,
                                  optimization_profile: str,
                                  memory_analysis: Optional[Dict] = None) -> str:
        """Create the prompt for Gemini API to generate optimized code."""
        base_prompt = f"""
You are an expert in optimizing diffusers library code for different hardware configurations.
NOTE: This system includes curated optimization knowledge from HuggingFace documentation.

TASK: Generate optimized Python code for running a diffusion model with the following specifications:
- Model: {model_name}
- Prompt: "{prompt_text}"
- Image size: {image_size[0]}x{image_size[1]}
- Inference steps: {num_inference_steps}

HARDWARE SPECIFICATIONS:
- Platform: {hardware_specs['platform']} ({hardware_specs['architecture']})
- CPU Cores: {hardware_specs['cpu_count']}
- CUDA Available: {hardware_specs['cuda_available']}
- MPS Available: {hardware_specs['mps_available']}
- Optimization Profile: {optimization_profile}
"""

        if hardware_specs.get('gpu_info'):
            base_prompt += f"- GPU: {hardware_specs['gpu_info'][0]['name']} ({hardware_specs['gpu_info'][0]['memory_mb']/1024:.1f} GB VRAM)\n"

        # Add user dtype preference if specified
        if hardware_specs.get('user_dtype'):
            base_prompt += f"- User specified dtype: {hardware_specs['user_dtype']}\n"

        # Add memory analysis information
        if memory_analysis:
            memory_info = memory_analysis.get('memory_info', {})
            recommendations = memory_analysis.get('recommendations', {})

            base_prompt += "\nMEMORY ANALYSIS:\n"
            if memory_info.get('estimated_inference_memory_fp16_gb'):
                base_prompt += f"- Model Memory Requirements: {memory_info['estimated_inference_memory_fp16_gb']} GB (FP16 inference)\n"
            if memory_info.get('memory_fp16_gb'):
                base_prompt += f"- Model Weights Size: {memory_info['memory_fp16_gb']} GB (FP16)\n"
            if recommendations.get('recommendations'):
                base_prompt += f"- Memory Recommendation: {', '.join(recommendations['recommendations'])}\n"
            if recommendations.get('recommended_precision'):
                base_prompt += f"- Recommended Precision: {recommendations['recommended_precision']}\n"
            if recommendations.get('cpu_offload'):
                base_prompt += f"- CPU Offloading Required: {recommendations['cpu_offload']}\n"
            if recommendations.get('attention_slicing'):
                base_prompt += f"- Attention Slicing Recommended: {recommendations['attention_slicing']}\n"
            if recommendations.get('vae_slicing'):
                base_prompt += f"- VAE Slicing Recommended: {recommendations['vae_slicing']}\n"

        base_prompt += f"""
OPTIMIZATION KNOWLEDGE BASE:
{get_optimization_guide()}

IMPORTANT: For FLUX.1-schnell models, do NOT include guidance_scale parameter as it's not needed.

Using the OPTIMIZATION KNOWLEDGE BASE above, generate Python code that:
1. **Selects the best optimization techniques** for the specific hardware profile
2. **Applies appropriate memory optimizations** based on available VRAM
3. **Uses optimal data types** for the target hardware:
   - User specified dtype (if provided): Use exactly as specified
   - Apple Silicon (MPS): prefer torch.bfloat16
   - NVIDIA GPUs: prefer torch.float16 or torch.bfloat16
   - CPU only: use torch.float32
4. **Implements hardware-specific optimizations** (CUDA, MPS, CPU)
5. **Follows model-specific guidelines** (e.g., FLUX guidance_scale handling)

IMPORTANT GUIDELINES:
- Reference the OPTIMIZATION KNOWLEDGE BASE to select appropriate techniques
- Include all necessary imports
- Add brief comments explaining optimization choices
- Generate compact, production-ready code
- Inline values where possible for concise code
- Generate ONLY the Python code, no explanations before or after the code block
"""
        return base_prompt

    def run_interactive_mode(self):
        """Run the generator in interactive mode."""
        print("=== Auto-Diffusers Code Generator ===")
        print("This tool generates optimized diffusers code based on your hardware.\n")

        # Check hardware
        print("=== Hardware Detection ===")
        self.hardware_detector.print_specs()

        use_manual = input("\nUse manual hardware input? (y/n): ").lower() == 'y'

        # Get user inputs
        print("\n=== Model Configuration ===")
        model_name = input("Model name (default: black-forest-labs/FLUX.1-schnell): ").strip()
        if not model_name:
            model_name = "black-forest-labs/FLUX.1-schnell"

        prompt_text = input("Prompt text (default: A cat holding a sign that says hello world): ").strip()
        if not prompt_text:
            prompt_text = "A cat holding a sign that says hello world"

        try:
            width = int(input("Image width (default: 1360): ") or "1360")
            height = int(input("Image height (default: 768): ") or "768")
            steps = int(input("Inference steps (default: 4): ") or "4")
        except ValueError:
            width, height, steps = 1360, 768, 4

        print("\n=== Generating Optimized Code ===")

        # Generate code; note image_size is (height, width), matching the
        # (768, 1360) default of generate_optimized_code
        optimized_code = self.generate_optimized_code(
            model_name=model_name,
            prompt_text=prompt_text,
            image_size=(height, width),
            num_inference_steps=steps,
            use_manual_specs=use_manual
        )

        print("\n" + "=" * 60)
        print("OPTIMIZED DIFFUSERS CODE:")
        print("=" * 60)
        print(optimized_code)
        print("=" * 60)


def main():
    # Get API key from .env file
    api_key = os.getenv('GOOGLE_API_KEY')
    if not api_key:
        api_key = os.getenv('GEMINI_API_KEY')  # fallback
    if not api_key:
        api_key = input("Enter your Gemini API key: ").strip()
    if not api_key:
        print("API key is required!")
        return

    generator = AutoDiffusersGenerator(api_key)
    generator.run_interactive_mode()


if __name__ == "__main__":
    main()
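
# Programmatic (non-interactive) use, a minimal sketch based on the signatures
# above; assumes GOOGLE_API_KEY is set in the environment or a local .env:
#
#   generator = AutoDiffusersGenerator(os.getenv("GOOGLE_API_KEY"))
#   code = generator.generate_optimized_code(
#       model_name="black-forest-labs/FLUX.1-schnell",
#       prompt_text="A cat holding a sign that says hello world",
#       image_size=(768, 1360),   # (height, width)
#       num_inference_steps=4,
#   )
#   print(code)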