Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| from PIL import Image | |
| from io import BytesIO | |
| from diffusers import DDIMScheduler, AutoencoderKL | |
| from transformers import CLIPTextModel, CLIPTokenizer | |
| from src.mgd_pipelines.mgd_pipe import MGDPipe | |
# Initialize the model and other components
@st.cache_resource
def load_model():
    """Build the MGD generation pipeline and move it to the available device.

    The function is wrapped in ``st.cache_resource`` because Streamlit
    re-executes this script on every widget interaction; without caching,
    every click would load all model components again.

    Returns:
        An ``MGDPipe`` assembled from the text encoder, VAE, UNet, tokenizer
        and scheduler loaded below, placed on CUDA when available and on the
        CPU otherwise.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): the checkpoint ids and subfolders are kept exactly as
    # they were; confirm that each repository really exposes the requested
    # subfolder before deploying.
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", subfolder="vae")
    tokenizer = CLIPTokenizer.from_pretrained("microsoft/xclip-base-patch32", subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained("microsoft/xclip-base-patch32", subfolder="text_encoder")
    unet = torch.hub.load("aimagelab/multimodal-garment-designer", model="mgd", pretrained=True)
    scheduler = DDIMScheduler.from_pretrained("stabilityai/sd-scheduler", subfolder="scheduler")

    pipe = MGDPipe(
        text_encoder=text_encoder,
        vae=vae,
        # Cast the UNet to the VAE's dtype so both components agree.
        unet=unet.to(vae.dtype),
        tokenizer=tokenizer,
        scheduler=scheduler,
    ).to(device)
    return pipe


# Module-level handle used by the UI code; cached across Streamlit reruns.
pipe = load_model()
def _extract_images(result):
    """Return the image collection carried by a pipeline call result.

    Pipeline output objects expose their pictures on an ``images``
    attribute; anything without that attribute is assumed to already be an
    iterable of images and is returned unchanged.
    """
    return getattr(result, "images", result)


def generate_images(pipe, text_input=None, sketch=None):
    """Generate images from a text prompt, a sketch, or both.

    Args:
        pipe: Callable pipeline. It is invoked as ``pipe(prompt=[text])``
            for text input and as ``pipe(sketch=image)`` for a sketch.
        text_input: Optional prompt string; skipped when empty or ``None``.
        sketch: Optional path or binary file object readable by
            ``PIL.Image.open``; skipped when falsy.

    Returns:
        A list with the images produced for the text prompt (if any)
        followed by those produced for the sketch (if any). The list is
        empty when neither input is supplied.
    """
    # NOTE(review): when both inputs are given, the pipeline is called once
    # per input rather than once with both conditions; confirm this is the
    # intended meaning of "both".
    images = []
    if text_input:
        prompt = [text_input]
        images.extend(_extract_images(pipe(prompt=prompt)))
    if sketch:
        sketch_image = Image.open(sketch).convert("RGB")
        images.extend(_extract_images(pipe(sketch=sketch_image)))
    return images
# Streamlit UI
st.title("Sketch & Text-based Image Generation")
st.write("Generate images based on rough sketches, text input, or both.")

option = st.radio("Select Input Type", ("Sketch", "Text", "Both"))
needs_sketch = option in ("Sketch", "Both")
needs_text = option in ("Text", "Both")

# Always bind both names so the "Generate" handler never hits a NameError
# when the corresponding widget is not rendered for the selected mode.
sketch_file = None
text_input = None
if needs_sketch:
    sketch_file = st.file_uploader("Upload a Sketch", type=["png", "jpg", "jpeg"])
if needs_text:
    text_input = st.text_input("Enter Text Prompt", placeholder="Describe the image you want to generate")

if st.button("Generate"):
    # Validate every input the selected mode requires, including "Both".
    if needs_sketch and not sketch_file:
        st.error("Please upload a sketch.")
    elif needs_text and not text_input:
        st.error("Please provide text input.")
    else:
        # Generate images based on user input
        with st.spinner("Generating images..."):
            sketches = BytesIO(sketch_file.read()) if sketch_file else None
            images = generate_images(pipe, text_input=text_input, sketch=sketches)
        # Display results
        for i, img in enumerate(images, start=1):
            st.image(img, caption=f"Generated Image {i}")