import gradio as gr
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderTiny
# For Image-to-Image, you would also import:
# from diffusers import StableDiffusionImg2ImgPipeline
from PIL import Image
import os # For better logging/debugging
from typing import Literal # For type hinting the gender choices

# --- Configuration ---
# 1. Force CPU usage for compatibility on machines without a GPU
device = "cpu"

# 2. Choose a smaller/distilled Stable Diffusion model for CPU speed
#    'nota-ai/bk-sdm-small' offers a good balance of size, speed, and reasonable quality for CPU.
#    If higher quality is essential and you can tolerate much longer generation times on CPU,
#    you might consider 'runwayml/stable-diffusion-v1-5', but be prepared for significant slowdowns
#    and potentially higher memory consumption that might require `enable_sequential_cpu_offload()`.
model_id = "nota-ai/bk-sdm-small"

# 3. Tiny VAE for drastically faster encoding/decoding on CPU. This is a crucial optimization.
tiny_vae_id = "sayakpaul/taesd-diffusers"

# --- Model Loading ---
# Load the pipeline globally when the application starts to avoid reloading on each request.
print(f"[{os.getpid()}] Loading model: {model_id} on {device}...")
try:
    # Use StableDiffusionPipeline for Text-to-Image generation (generate a new person in a style)
    # If you want to transform an uploaded image (Image-to-Image), uncomment the line below
    # and replace `StableDiffusionPipeline` with `StableDiffusionImg2ImgPipeline`.
    pipe_class = StableDiffusionPipeline
    # pipe_class = StableDiffusionImg2ImgPipeline # Uncomment this for Image-to-Image functionality

    pipe = pipe_class.from_pretrained(
        model_id,
        torch_dtype=torch.float32, # CPU usually performs best with float32
        low_cpu_mem_usage=True,    # Helps reduce peak memory usage on CPU
        safety_checker=None        # Disable safety checker to save CPU cycles and memory for faster generation
    )
    print(f"[{os.getpid()}] Main pipeline loaded.")

    # Load and assign the Tiny VAE for significant speed optimization in the VAE step
    print(f"[{os.getpid()}] Loading Tiny VAE from {tiny_vae_id}...")
    try:
        pipe.vae = AutoencoderTiny.from_pretrained(tiny_vae_id, torch_dtype=torch.float32)
        print(f"[{os.getpid()}] Tiny VAE loaded successfully.")
    except Exception as vae_e:
        print(f"[{os.getpid()}] Warning: Could not load Tiny VAE '{tiny_vae_id}': {vae_e}. Using default VAE (this will be slower).")
        # Ensure the default VAE is explicitly moved to CPU if Tiny VAE fails to load
        pipe.vae.to(device)

    # Move entire pipeline components to CPU explicitly
    pipe.to(device)

    # Set up the scheduler. DDIMScheduler is a good general-purpose choice.
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
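    # Optional (sketch, not enabled by default): other diffusers schedulers can be swapped in
    # the same way if you want to experiment, for example:
    # from diffusers import EulerDiscreteScheduler
    # pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)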

    # Optional: Enable CPU offload if you encounter Out-Of-Memory errors on CPU,
    # especially with larger models. Be aware that this will make generation significantly slower.
    # pipe.enable_sequential_cpu_offload()
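    # Optional (sketch, not enabled by default): attention slicing is another diffusers memory
    # optimization that trades a little speed for lower peak memory.
    # pipe.enable_attention_slicing()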

    print(f"[{os.getpid()}] Model fully loaded and configured on {device}.")

except Exception as e:
    print(f"[{os.getpid()}] FATAL ERROR: Failed to load models: {e}")
    # Raise an exception to prevent the application from starting if model loading fails
    raise RuntimeError(f"Failed to load Stable Diffusion model: {e}")

# --- Preset Styles ---
styles = {
    "Pixar": "pixar style portrait of",
    "Anime": "anime style portrait of",
    "Cyberpunk": "cyberpunk futuristic avatar of",
    "Disney": "disney movie character of",
    "Sketch": "pencil sketch portrait of",
    "Astronaut": "realistic astronaut with helmet, portrait of"
}

# --- Generation Function ---
def generate_avatar(image_input: Image.Image, style: str, gender: Literal["male", "female", "unspecified"]):
    """
    Generates an avatar based on a chosen style and gender.

    - If using StableDiffusionPipeline (Text-to-Image): The uploaded `image_input`
      is used only to trigger the generation and is NOT directly used to
      influence the avatar's appearance. A new person is generated based on the text.
    - If using StableDiffusionImg2ImgPipeline (Image-to-Image - commented out by default):
      The `image_input` WOULD be used as the base image for transformation.
    """
    if image_input is None:
        gr.Warning("Please upload an image to enable avatar generation. (Even if it's not directly used for content, it acts as a trigger).")
        return None

    # Base prompt from the selected style
    base_prompt = styles[style]

    # Construct the subject part of the prompt based on gender selection
    gender_subject = ""
    if gender == "male":
        gender_subject = "a man"
    elif gender == "female":
        gender_subject = "a woman"
    else: # unspecified
        gender_subject = "a person" # Model will default based on its biases if no gender specified

    # Enhance the prompt for better quality and detail in text-to-image generation
    prompt = f"{base_prompt} {gender_subject}, high quality, detailed, professional photography, studio lighting, volumetric lighting, 4k, cinematic, sharp focus"
    # Stronger negative prompt to avoid common issues like low quality, distortions, and undesired artifacts
    negative_prompt = "low resolution, blurry, distorted, bad quality, ugly, cartoon, sketch, duplicate, out of frame, bad anatomy, deformed, extra limbs, malformed hands, missing fingers, watermark, text, signature, low contrast, oversaturated"

    # Inference parameters (tuned for a balance of speed and quality on CPU)
    num_inference_steps = 25 # Generally, 20-30 steps is a good range for quality vs speed on CPU
    guidance_scale = 7.5    # Higher values make output closer to prompt, but can be less diverse
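
    # Optional (sketch, not enabled by default): for reproducible results you could seed a CPU
    # generator and pass it to the `pipe(...)` call below via its `generator=` argument.
    # The seed value 42 is arbitrary.
    # generator = torch.Generator(device="cpu").manual_seed(42)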

    print(f"[{os.getpid()}] Generating for style: '{style}', gender: '{gender}', with prompt: '{prompt}' (Steps: {num_inference_steps}, Guidance: {guidance_scale})")

    try:
        # Use torch.no_grad() or torch.inference_mode() to disable gradient calculations
        # during inference, which saves memory and speeds up computation.
        with torch.no_grad(): # For PyTorch >= 1.9, torch.inference_mode() is also an option
            if isinstance(pipe, StableDiffusionPipeline):
                # Text-to-Image generation: Image_input is ignored for content
                generated_image = pipe(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    height=512, # Stable Diffusion 1.x models are usually trained at 512x512
                    width=512
                ).images[0]
            # elif isinstance(pipe, StableDiffusionImg2ImgPipeline):
            #     # Image-to-Image generation: Uncomment this block if you switch to Img2ImgPipeline
            #     # The 'strength' parameter controls how much noise is added to the input image.
            #     # 0.0 means no change, 1.0 means complete re-imagining (like text-to-image).
            #     # A value around 0.7-0.8 is typical for style transfer.
            #     strength = 0.75
            #     generated_image = pipe(
            #         prompt=prompt,
            #         image=image_input, # Pass the uploaded image here for img2img
            #         negative_prompt=negative_prompt,
            #         num_inference_steps=num_inference_steps,
            #         guidance_scale=guidance_scale,
            #         strength=strength
            #     ).images[0]
            else:
                raise ValueError("Unsupported pipeline type. Please check model loading.")

        print(f"[{os.getpid()}] Image generation complete.")
        return generated_image

    except Exception as e:
        print(f"[{os.getpid()}] Error during image generation: {e}")
        # Raise gr.Error so the message is shown to the user in the Gradio interface
        raise gr.Error(f"An error occurred during image generation: {e}")

# --- Gradio Interface Definition ---
with gr.Blocks() as demo:
    gr.Markdown("## 🎨 Stable Diffusion Avatar Generator with Preset Styles (CPU Optimized)")
    gr.Markdown(
        "This demo uses a smaller, distilled Stable Diffusion model and is optimized for CPU inference. "
        "Generation will still take time on CPU compared to GPU (e.g., 20-60 seconds per image depending on CPU and parameters).<br>"
        "**Note:** The uploaded image is currently used only to trigger generation and is **not directly influencing the avatar's appearance**. "
        "It's here for your reference or potential future Image-to-Image features. You will get a new person in the chosen style."
    )

    with gr.Row():
        with gr.Column():
            # Image input component. type="pil" ensures a PIL Image object is passed to the function.
            image_input = gr.Image(
                label="Upload your photo",
                type="pil",
                sources=["upload", "webcam"], # Allow file upload or webcam capture
                # Optional: Add a placeholder image path if you want a default visual
                # value="assets/placeholder.jpg"
            )
            style_selector = gr.Radio(
                choices=list(styles.keys()),
                label="Choose a style",
                value="Anime", # Default selected style
                info="Select the artistic style for your avatar."
            )
            gender_selector = gr.Radio(
                choices=["male", "female", "unspecified"],
                label="Choose a Gender",
                value="male", # Default to male to address your specific issue
                info="Explicitly set the gender of the generated person. 'Unspecified' may lead to biased results from the model."
            )
            generate_btn = gr.Button("Generate Avatar", variant="primary")

        with gr.Column():
            output_image = gr.Image(label="Generated Avatar")

    # Connect the button click to the generation function, passing all inputs
    generate_btn.click(
        fn=generate_avatar,
        inputs=[image_input, style_selector, gender_selector], # Now includes gender_selector
        outputs=output_image
    )

    # Optional: Add examples for quick testing
    gr.Examples(
        examples=[
            # Example format: [image_path_or_None, style_name, gender]
            # Use None for image_input as it's not directly influencing the output in text-to-image mode
            [None, "Pixar", "male"],
            [None, "Anime", "female"],
            [None, "Cyberpunk", "unspecified"], # To show what 'unspecified' might produce
            [None, "Disney", "male"],
            [None, "Sketch", "female"],
            [None, "Astronaut", "male"]
        ],
        inputs=[image_input, style_selector, gender_selector],
        # fn=generate_avatar, # Uncomment if you want examples to run the generation live
        # outputs=output_image,
        cache_examples=False, # With `fn`/`outputs` commented out above, clicking an example only pre-fills the inputs
        label="Quick Examples (click to fill the inputs, then press Generate Avatar)"
    )

# Launch the Gradio application
# share=True will generate a public link (useful for sharing demos temporarily)
# auth=("username", "password") for basic authentication
demo.launch(inbrowser=True, show_error=True)
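
# Optional (sketch): for long-running CPU generations you might enable Gradio's request queue
# before launching, e.g.:
# demo.queue().launch(inbrowser=True, show_error=True)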