import os

import numpy as np
import torch
import streamlit as st
from PIL import Image
from accelerate import Accelerator
from diffusers import DDIMScheduler, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer

from src.mgd_pipelines.mgd_pipe import MGDPipe
from src.mgd_pipelines.mgd_pipe_disentangled import MGDPipeDisentangled
from src.utils.set_seeds import set_seed
from src.utils.image_from_pipe import generate_images_from_mgd_pipe
from src.datasets.dresscode import DressCodeDataset

# Set environment variables
os.environ["TOKENIZERS_PARALLELISM"] = "true"
os.environ["WANDB_START_METHOD"] = "thread"
# Function to process inputs and run inference
def run_inference(prompt, sketch_image=None, category="dresses", seed=None, mixed_precision="fp16"):
    # Initialize accelerator
    accelerator = Accelerator(mixed_precision=mixed_precision)
    device = accelerator.device
    # Load the text modules, VAE, and scheduler. MGD builds on Stable
    # Diffusion, so all components are loaded here from the inpainting
    # checkpoint the MGD repository uses (swap in your own base model if needed).
    base_model = "runwayml/stable-diffusion-inpainting"
    tokenizer = CLIPTokenizer.from_pretrained(base_model, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(base_model, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae")
    val_scheduler = DDIMScheduler.from_pretrained(base_model, subfolder="scheduler")

    # Load the pretrained MGD denoising UNet from the authors' GitHub repo via torch.hub
    unet = torch.hub.load("aimagelab/multimodal-garment-designer", "mgd", pretrained=True)
    # Freeze VAE and text encoder (inference only)
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    # Set seed for reproducibility
    if seed is not None:
        set_seed(seed)

    # Load the appropriate dataset split (category must be a list)
    category = [category]
    test_dataset = DressCodeDataset(
        dataroot_path="path_to_dataset",  # replace with your DressCode root
        phase="test",
        category=category,
        size=(512, 384),
    )
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)
    # Move models to the device; the UNet runs in eval mode
    text_encoder.to(device)
    vae.to(device)
    unet.to(device).eval()
    # Handle the optional sketch input
    sketch_tensor = None
    if sketch_image is not None:
        # PIL's resize takes (width, height); the dataset size above is (height, width) = (512, 384)
        sketch_image = sketch_image.convert("RGB").resize((384, 512))
        # HWC uint8 -> 1xCxHxW float in [0, 1] (a simple normalization
        # assumption; match whatever preprocessing your pipeline expects)
        sketch_array = np.array(sketch_image).astype(np.float32) / 255.0
        sketch_tensor = torch.from_numpy(sketch_array).permute(2, 0, 1).unsqueeze(0).to(device)
    # Build the validation pipeline. The disentangled variant is used here;
    # swap in MGDPipe for the standard pipeline.
    val_pipe = MGDPipeDisentangled(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=val_scheduler,
    ).to(device)
    val_pipe.enable_attention_slicing()
    # Run generation. The prompt/sketch keyword arguments assume the repo's
    # generate_images_from_mgd_pipe helper accepts these overrides.
    generated_images = generate_images_from_mgd_pipe(
        test_dataloader=test_dataloader,
        pipe=val_pipe,
        guidance_scale=7.5,
        seed=seed,
        sketch_image=sketch_tensor,
        prompt=prompt,
    )
    return generated_images[0]  # single-image output assumed
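
# Each press of the "Generate Image" button rebuilds every model inside
# run_inference. A minimal optional sketch of a cached loader using
# Streamlit's st.cache_resource, so the heavy objects survive reruns; it
# mirrors the checkpoint assumptions above and is not wired into
# run_inference:
@st.cache_resource
def load_cached_pipeline(mixed_precision="fp16"):
    accelerator = Accelerator(mixed_precision=mixed_precision)
    base_model = "runwayml/stable-diffusion-inpainting"
    tokenizer = CLIPTokenizer.from_pretrained(base_model, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(base_model, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae")
    scheduler = DDIMScheduler.from_pretrained(base_model, subfolder="scheduler")
    unet = torch.hub.load("aimagelab/multimodal-garment-designer", "mgd", pretrained=True)
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    # Assemble the pipeline once; Streamlit reuses the returned object
    pipe = MGDPipeDisentangled(
        text_encoder=text_encoder.to(accelerator.device),
        vae=vae.to(accelerator.device),
        unet=unet.to(accelerator.device).eval(),
        tokenizer=tokenizer,
        scheduler=scheduler,
    ).to(accelerator.device)
    pipe.enable_attention_slicing()
    return pipe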
# Streamlit UI
st.title("Fashion Image Generator")
st.write("Generate colorful fashion images based on a rough sketch and/or a text prompt.")

# Upload a sketch image
uploaded_sketch = st.file_uploader("Upload a rough sketch (optional)", type=["png", "jpg", "jpeg"])

# Text input for the prompt
prompt = st.text_input("Enter a prompt (optional)", "A red dress with floral patterns")

# Input options
category = st.text_input("Enter category (optional):", "dresses")
seed = st.slider("Seed", min_value=1, max_value=100, step=1, value=42)
# Accelerate expects "fp16" or "no" (full precision) for mixed_precision
precision = st.selectbox("Select precision:", ["fp16", "no"])
# Show the uploaded sketch, if any; default to None so run_inference can
# fall back to text-only generation
sketch_image = None
if uploaded_sketch is not None:
    sketch_image = Image.open(uploaded_sketch)
    st.image(sketch_image, caption="Uploaded Sketch", use_column_width=True)

# Button to generate the image
if st.button("Generate Image"):
    with st.spinner("Generating image..."):
        # Run inference with the sketch, the prompt, or both
        result_image = run_inference(prompt, sketch_image, category, seed, precision)
        st.image(result_image, caption="Generated Image", use_column_width=True)
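
# To try the app locally (assuming this file is saved as app.py, the MGD
# repo's src/ package is importable, and the DressCode path above points at
# a real copy of the dataset):
#
#   pip install streamlit torch numpy pillow accelerate diffusers transformers
#   streamlit run app.py
#
# Streamlit then serves the UI on http://localhost:8501 by default.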