Spaces:
Running
Running
import streamlit as st | |
import torch | |
import io | |
import os | |
from PIL import Image | |
# Browser-tab settings for the app (tab title, tab icon, page layout).
_PAGE_SETTINGS = {
    "page_title": "Portrait Generator",
    "page_icon": "🖼️",
    "layout": "centered",
}

# Pre-filled text for the two prompt boxes; users can edit both in the UI.
_DEFAULT_NEGATIVE_PROMPT = (
    "lowres, bad anatomy, bad hands, missing fingers, extra digit, "
    "fewer digits, cropped, worst quality, low quality, watermark, "
    "signature, out of frame"
)
_DEFAULT_PROMPT = (
    "Masterpiece portrait of a beautiful young woman with flowing hair, "
    "detailed face, photorealistic, 8k, professional photography"
)

st.set_page_config(**_PAGE_SETTINGS)

# Heading and one-line description shown at the top of the page.
st.title("AI Portrait Generator")
st.markdown("Generate beautiful portraits using the AWPortraitCN2 model")

# Generation settings live in the sidebar. The resulting module-level names
# (steps, guidance_scale, negative_prompt, seed) are read later when the
# pipeline is invoked.
_sidebar = st.sidebar
_sidebar.header("Generation Settings")
steps = _sidebar.slider(
    "Inference Steps",
    min_value=20,
    max_value=100,
    value=40,
)
guidance_scale = _sidebar.slider(
    "Guidance Scale",
    min_value=1.0,
    max_value=15.0,
    value=7.5,
    step=0.5,
)
negative_prompt = _sidebar.text_area(
    "Negative Prompt",
    value=_DEFAULT_NEGATIVE_PROMPT,
)
# A seed of -1 is treated by the generation code as "no fixed seed".
seed = _sidebar.number_input(
    "Random Seed (leave at -1 for random)",
    min_value=-1,
    value=-1,
)

# Main prompt input, rendered in the page body rather than the sidebar.
prompt = st.text_area(
    "Describe the portrait you want to generate",
    value=_DEFAULT_PROMPT,
)
# Function to load the model, reusing an already-loaded pipeline when possible
def load_model():
    """Return the AWPortraitCN2 diffusion pipeline, loading it at most once.

    The expensive ``from_pretrained`` + ``.to(device)`` work is wrapped in a
    helper decorated with ``st.cache_resource`` so that repeated calls (for
    example, one per "Generate" click) reuse the pipeline that is already in
    memory instead of rebuilding it every time. Only successful loads are
    cached: a failed attempt raises inside the helper, so the next call
    retries it.

    Two loading strategies are tried in order:

    1. ``AutoPipelineForText2Image`` (with the ``fp16`` weight variant when
       CUDA is available).
    2. ``StableDiffusionPipeline`` as a fallback.

    Returns:
        The loaded pipeline, moved to ``"cuda"`` when available and ``"cpu"``
        otherwise, or ``None`` if both strategies fail. Failures are reported
        in the UI through ``st.error`` / ``st.info`` rather than raised.

    Note:
        The cached pipeline object is shared by every caller of this
        function; it is not copied per call.
    """
    model_id = "Shakker-Labs/AWPortraitCN2"
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    # Half precision only makes sense on GPU; fall back to float32 on CPU.
    dtype = torch.float16 if use_cuda else torch.float32

    # The decorator is applied inside the function so that importing this
    # module does not itself require the cache to be set up. The outer
    # spinner shown by the caller is enough, hence show_spinner=False.
    @st.cache_resource(show_spinner=False)
    def _load_cached(use_auto_pipeline):
        """Build the pipeline with the requested strategy and move it to the device."""
        # diffusers is imported lazily so an import failure is reported
        # through the same error path as a loading failure.
        if use_auto_pipeline:
            from diffusers import AutoPipelineForText2Image

            pipe = AutoPipelineForText2Image.from_pretrained(
                model_id,
                torch_dtype=dtype,
                use_safetensors=True,
                variant="fp16" if use_cuda else None,
            )
        else:
            from diffusers import StableDiffusionPipeline

            pipe = StableDiffusionPipeline.from_pretrained(
                model_id,
                torch_dtype=dtype,
                use_safetensors=True,
            )
        return pipe.to(device)

    try:
        return _load_cached(True)
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        st.info("Debug info: Using modern API with AutoPipelineForText2Image")

    # Fallback to traditional StableDiffusionPipeline if needed
    try:
        st.info("Trying alternative method...")
        return _load_cached(False)
    except Exception as e2:
        st.error(f"Alternative method also failed: {str(e2)}")
        return None
# Session-state key under which the most recent result is kept. Storing the
# result there lets it survive the script rerun that happens when the user
# clicks the download button (the "Generate" button reads False on that rerun,
# so anything rendered only inside its branch would disappear).
_PORTRAIT_STATE_KEY = "generated_portrait"

# Generate button
if st.button("Generate Portrait", type="primary"):
    # A new attempt invalidates whatever was generated before, so a failed
    # run never leaves a stale portrait on screen next to the error message.
    if _PORTRAIT_STATE_KEY in st.session_state:
        del st.session_state[_PORTRAIT_STATE_KEY]

    with st.spinner("Loading model and generating portrait..."):
        try:
            # Load the model
            pipeline = load_model()
            if pipeline is None:
                st.error("Failed to load the model. Check the logs for details.")
                st.stop()

            # Set seed if specified; -1 means "let the pipeline pick".
            generator = None
            if seed != -1:
                device = "cuda" if torch.cuda.is_available() else "cpu"
                generator = torch.Generator(device).manual_seed(seed)

            # Generate the image
            image = pipeline(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=steps,
                guidance_scale=guidance_scale,
                generator=generator,
            ).images[0]

            # Encode once as PNG for the download button and remember both
            # the image and its bytes for subsequent reruns.
            buf = io.BytesIO()
            image.save(buf, format="PNG")
            st.session_state[_PORTRAIT_STATE_KEY] = {
                "image": image,
                "png": buf.getvalue(),
            }
        except Exception as e:
            st.error(f"An error occurred during generation: {str(e)}")
            st.info("Make sure you have enough GPU memory (T4 or better recommended).")

# Render the latest result outside the button branch so it stays visible
# after reruns triggered by other widgets (notably the download button).
if _PORTRAIT_STATE_KEY in st.session_state:
    _portrait = st.session_state[_PORTRAIT_STATE_KEY]

    # Display the generated image
    st.image(_portrait["image"], caption="Generated Portrait", use_column_width=True)

    # Option to download
    st.download_button(
        label="Download Portrait",
        data=_portrait["png"],
        file_name="generated_portrait.png",
        mime="image/png",
    )
# Footer: report which device the app is running on.
_has_gpu = torch.cuda.is_available()

# Both variants of the footer start with the same horizontal rule.
st.markdown("---")

if _has_gpu:
    # Describe CUDA device 0: its name and total memory in GB (bytes / 1e9).
    _gpu_name = torch.cuda.get_device_name(0)
    _gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    st.markdown(f"""
### Hardware Info
- Running on: {_gpu_name}
- Memory: {_gpu_memory_gb:.2f} GB
""")
else:
    st.markdown("⚠️ Running on CPU. For better performance, use a GPU runtime.")