Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import os | |
| from PIL import Image | |
| from src.eval import main # Import the modified main function from evl.py | |
# --- Page header --------------------------------------------------------------
st.title("Fashion Image Generator")
st.write("Upload a rough sketch, set parameters, and generate realistic garment images.")

# --- Sketch upload ------------------------------------------------------------
# Single-file uploader restricted to common raster formats.
uploaded_file = st.file_uploader(
    "Upload your rough sketch (PNG, JPG, JPEG):",
    type=["png", "jpg", "jpeg"],
)

# --- Sidebar: model / run configuration ----------------------------------------
# Every widget in this context renders in the sidebar; the values are stored in
# module-level names for the generation handler to read.
with st.sidebar:
    st.title("Model Configuration")
    pretrained_model_path = st.text_input(
        "Pretrained Model Path",
        value="runwayml/stable-diffusion-inpainting",
    )
    dataset_path = st.text_input("Dataset Path", value="./datasets/dresscode")
    output_dir = st.text_input("Output Directory", value="./outputs")
    guidance_scale_sketch = st.slider(
        "Sketch Guidance Scale",
        min_value=1.0,
        max_value=10.0,
        value=7.5,
    )
    batch_size = st.number_input("Batch Size", min_value=1, max_value=16, value=1)
    mixed_precision = st.selectbox(
        "Mixed Precision Mode",
        options=["fp16", "fp32"],
        index=0,
    )
    seed = st.number_input("Random Seed", value=42, step=1)
# --- Generation ---------------------------------------------------------------
# Directory the uploaded sketch is written to before generation.
UPLOAD_DIR = "temp_uploads"
# Base name the backend is asked to save its output under (the "save_name" arg);
# also used to locate the result when the backend returns no usable path.
SAVE_NAME = "generated_image"


def _save_upload(upload, upload_dir=UPLOAD_DIR):
    """Write an uploaded file's bytes into *upload_dir* and return the saved path.

    The filename comes from the client and is untrusted, so only its final
    path component is kept. Names that reduce to nothing usable ("", ".",
    "..") fall back to "sketch". This stops a crafted name such as
    "../../app.py" from writing outside *upload_dir*.
    """
    os.makedirs(upload_dir, exist_ok=True)
    name = os.path.basename(upload.name)
    if name in ("", ".", ".."):
        name = "sketch"
    path = os.path.join(upload_dir, name)
    with open(path, "wb") as f:
        f.write(upload.getbuffer())
    return path


def _find_output_image(returned, out_dir, save_name=SAVE_NAME):
    """Return the path of the generated image, or None if no such file exists.

    Prefers the value returned by the backend when it is a path to an
    existing file. Otherwise falls back to "<out_dir>/<save_name>.png", the
    location this UI previously hard-coded.
    """
    candidates = []
    if isinstance(returned, (str, os.PathLike)):
        candidates.append(returned)
    candidates.append(os.path.join(out_dir, f"{save_name}.png"))
    return next((p for p in candidates if os.path.isfile(p)), None)


# Run Button
if st.button("Generate Image"):
    if not uploaded_file:
        st.error("Please upload a sketch before generating an image.")
    else:
        sketch_path = _save_upload(uploaded_file)
        # NOTE(review): sketch_path is written to disk but never passed to the
        # backend in `args` below, so the uploaded sketch may not influence the
        # result. Confirm which key src.eval.main expects and forward it.

        # Arguments for the backend entry point.
        args = {
            "pretrained_model_name_or_path": pretrained_model_path,
            "dataset": "dresscode",
            "dataset_path": dataset_path,
            "output_dir": output_dir,
            "guidance_scale": 7.5,  # fixed; not exposed in the sidebar
            "guidance_scale_sketch": guidance_scale_sketch,
            "mixed_precision": mixed_precision,
            "batch_size": batch_size,
            "seed": seed,
            "save_name": SAVE_NAME,  # output file base name
        }

        try:
            with st.spinner("Generating image..."):
                output_path = main(args)
            st.write("Image generation complete!")

            image_path = _find_output_image(output_path, output_dir)
            if image_path is None:
                st.error("Image generation failed. No output file found.")
            else:
                # The context manager releases the file handle once the image
                # has been handed to Streamlit.
                # NOTE(review): `use_column_width` is deprecated in newer
                # Streamlit releases; kept for compatibility with the
                # installed version — TODO confirm and migrate.
                with Image.open(image_path) as output_image:
                    st.image(output_image, caption="Generated Image", use_column_width=True)
        except Exception as e:
            # UI boundary: show any backend/display failure to the user
            # instead of crashing the page.
            st.error(f"An error occurred: {e}")