# Hugging Face Space status: Paused
import logging
import tempfile

import google.generativeai as genai
import gradio as gr
import requests
# Configure root logging once, at import time: DEBUG level, timestamped lines.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)

# Style names offered in the UI dropdown. The selected one is interpolated
# into both the Gemini enhancement request and the final image prompt.
STYLES = [
    "Photorealistic",
    "Oil Painting",
    "Watercolor",
    "Anime",
    "Studio Ghibli",
    "Black and White",
    "Polaroid",
    "Sketch",
    "3D Render",
    "Pixel Art",
    "Cyberpunk",
    "Steampunk",
    "Art Nouveau",
    "Pop Art",
    "Minimalist",
]

# Initial value of the "Negative Prompt" textbox, which is forwarded to the
# image request's `negative_prompt` field. Assembled from adjacent literals;
# the value (including the leading and trailing newline) matches the
# original triple-quoted text exactly.
DEFAULT_NEGATIVE_PROMPT = (
    "\n"
    "ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame,\n"
    "extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature,\n"
    "cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face\n"
)
def enhance_prompt(google_api_key, prompt, style):
    """Rewrite *prompt* with Gemini into a more detailed, style-aware prompt.

    Args:
        google_api_key: Google AI API key handed to ``genai.configure``.
        prompt: The user's original image prompt.
        style: Style name to incorporate (the UI offers the ``STYLES`` list).

    Returns:
        The model's answer, whitespace-stripped and with a leading
        "Enhanced prompt:"-style preamble removed. If nothing is left after
        that cleanup, the stripped original prompt is returned instead, so the
        image request never receives an empty prompt.

    Raises:
        Exception: Whatever the Gemini client raises, including any error
            raised while reading ``response.text``, is logged with its
            traceback and re-raised unchanged.
    """
    # NOTE(review): genai.configure() sets the key on the library's shared
    # default configuration; if this handler is ever run concurrently with
    # different keys, confirm the calls cannot interfere.
    genai.configure(api_key=google_api_key)
    model = genai.GenerativeModel("gemini-2.0-flash-lite")
    enhanced_prompt_request = f"""
    Task: Enhance the following prompt for image generation.
    Style: {style}
    Original prompt: '{prompt}'
    Instructions:
    1. Expand the prompt to be more detailed and vivid.
    2. Incorporate elements of the specified style.
    3. Maintain the original intent of the prompt.
    4. Provide ONLY the enhanced prompt, without any explanations or options.
    5. Keep the enhanced prompt concise, ideally under 100 words.
    Enhanced prompt:
    """
    # Preambles to drop if the model starts its answer with one of them;
    # matched case-insensitively and checked in this order.
    prefixes_to_remove = ["Enhanced prompt:", "Here's the enhanced prompt:", "The enhanced prompt is:"]

    try:
        response = model.generate_content(enhanced_prompt_request)
        enhanced_prompt = response.text.strip()
    except Exception:
        # logging.exception records the traceback, not just the message.
        logging.exception("Error in enhance_prompt")
        raise

    for prefix in prefixes_to_remove:
        if enhanced_prompt.lower().startswith(prefix.lower()):
            enhanced_prompt = enhanced_prompt[len(prefix):].strip()

    if not enhanced_prompt:
        # An empty answer would otherwise become an empty image prompt.
        logging.warning("Gemini returned an empty enhanced prompt; falling back to the original prompt")
        enhanced_prompt = (prompt or "").strip()

    logging.info("Enhanced prompt: %s", enhanced_prompt)
    return enhanced_prompt
def generate_image(stability_api_key, enhanced_prompt, style, negative_prompt, timeout=120):
    """Render an image with Stability AI's SD3 endpoint and return its bytes.

    Args:
        stability_api_key: Stability AI API key, sent as a Bearer token.
        enhanced_prompt: Prompt text (in this app, the output of ``enhance_prompt``).
        style: Style name appended to the prompt as ``", Style: <style>"``.
        negative_prompt: Text sent as the request's ``negative_prompt`` field.
        timeout: Seconds ``requests`` waits to connect and between received
            bytes before giving up. Without a timeout, a stalled request
            blocks the caller indefinitely.

    Returns:
        The raw response body when the server replies with an ``image/*``
        content type (JPEG is requested via ``output_format``).

    Raises:
        Exception: ``"Request failed: ..."`` for transport errors, timeouts
            and non-2xx responses (the response body is included), or
            ``"Unexpected content type: ..."`` when a successful response is
            not an image.
    """
    url = "https://api.stability.ai/v2beta/stable-image/generate/sd3"
    headers = {
        # Request image bytes directly; the content-type check below relies on it.
        "Accept": "image/*",
        "Authorization": f"Bearer {stability_api_key}",
    }
    # NOTE(review): confirm which of these fields the v2beta SD3 endpoint
    # honors (width/height/num_images/steps may be ignored there).
    data = {
        "prompt": f"{enhanced_prompt}, Style: {style}",
        "negative_prompt": negative_prompt,
        "model": "sd3.5-large-turbo",
        "output_format": "jpeg",
        "width": 1024,
        "height": 1024,
        "num_images": 1,
        "steps": 4,  # SD3.5 Large Turbo generates high-quality images in just 4 steps
    }

    try:
        # files={"none": ''} makes requests encode the body as
        # multipart/form-data, as in Stability's own examples.
        response = requests.post(url, headers=headers, files={"none": ''}, data=data, timeout=timeout)
    except requests.exceptions.RequestException as e:
        logging.error("Request failed: %s", e)
        raise Exception(f"Request failed: {str(e)}") from e

    # Default to "" so a missing header cannot cause AttributeError below.
    content_type = response.headers.get("content-type", "")
    logging.debug("Response headers: %s", response.headers)
    logging.debug("Response content type: %s", content_type)

    try:
        response.raise_for_status()
    except requests.exceptions.HTTPError as e:
        # HTTPError's text only has the status and URL; any explanation the
        # API returns is in the response body, so include it.
        detail = f"{e}. Response: {response.text}"
        logging.error("Request failed: %s", detail)
        raise Exception(f"Request failed: {detail}") from e

    if content_type.startswith("image/"):
        return response.content

    error_message = response.text
    logging.error("Unexpected content type: %s. Response: %s", content_type, error_message)
    raise Exception(f"Unexpected content type: {content_type}. Response: {error_message}")
def process_and_generate(google_api_key, stability_api_key, prompt, style, negative_prompt):
    """Gradio click handler: enhance the prompt, render it, and save the image.

    Args:
        google_api_key: Google AI API key for ``enhance_prompt``.
        stability_api_key: Stability AI API key for ``generate_image``.
        prompt: The user's prompt.
        style: Selected style name.
        negative_prompt: Negative prompt forwarded to ``generate_image``.

    Returns:
        An ``(image_path, text)`` pair matching the two output components.
        On success: the path of a newly written JPEG and the enhanced prompt.
        On failure: ``None`` (no image) and a message describing the error,
        preceded by the enhanced prompt if enhancement had already succeeded.
        Pipeline exceptions are logged and converted into that message.
    """
    if not prompt or not prompt.strip():
        return None, "Please enter a prompt."

    # Initialized up front so the error path knows whether enhancement
    # finished (replaces the fragile `'enhanced_prompt' in locals()` check).
    enhanced_prompt = None
    try:
        enhanced_prompt = enhance_prompt(google_api_key, prompt, style)
        image_bytes = generate_image(stability_api_key, enhanced_prompt, style, negative_prompt)
        # Unique temp file per request instead of a fixed "generated_image.jpeg"
        # in the working directory, so overlapping requests cannot overwrite
        # each other's output.
        # NOTE(review): these files are not deleted; add cleanup if disk use matters.
        with tempfile.NamedTemporaryFile(suffix=".jpeg", delete=False) as image_file:
            image_file.write(image_bytes)
        return image_file.name, enhanced_prompt
    except Exception as e:
        logging.exception("Error in process_and_generate")
        # Return None for the image: the error text is not a file path the
        # filepath-typed Image output could display. Report it in the text box.
        if enhanced_prompt is None:
            return None, f"Error occurred before prompt enhancement: {e}"
        return None, f"{enhanced_prompt}\n\nError: {e}"
# Two-column UI: inputs (API keys, prompt, style, negative prompt) on the
# left; the generated image and the enhanced prompt on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Stability AI SD3.5 Large Turbo Image Generator with Google Gemini Prompt Enhancement")
    with gr.Row():
        with gr.Column(scale=1):
            google_api_key = gr.Textbox(label="Google AI API Key", type="password")
            stability_api_key = gr.Textbox(label="Stability AI API Key", type="password")
            prompt = gr.Textbox(label="Prompt")
            # Explicit default so the handler never receives style=None,
            # regardless of the installed Gradio version's Dropdown default.
            style = gr.Dropdown(label="Style", choices=STYLES, value=STYLES[0])
            negative_prompt = gr.Textbox(label="Negative Prompt", value=DEFAULT_NEGATIVE_PROMPT)
            submit_btn = gr.Button("Generate Image")
        with gr.Column(scale=1):
            image_output = gr.Image(label="Generated Image", type="filepath")
            enhanced_prompt_output = gr.Textbox(label="Enhanced Prompt")
    # Inputs are passed positionally, in process_and_generate's parameter order.
    submit_btn.click(
        process_and_generate,
        inputs=[google_api_key, stability_api_key, prompt, style, negative_prompt],
        outputs=[image_output, enhanced_prompt_output],
    )

# Launch only when run as a script (e.g. `python app.py`), not on import.
if __name__ == "__main__":
    demo.launch()