File size: 11,160 Bytes
57eccf2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
# Generate_holiday_borders.py

import os
import zipfile
from io import BytesIO

import streamlit as st
from dotenv import load_dotenv, set_key
from PIL import Image

# Import from helper_utilities.py
from utils.helper_utilities import (
    get_closest_aspect_ratio, process_image, generate_flux_image,  # Replace ControlNet API call with Flux API call
    draw_crop_preview, combine_images, get_next_largest_aspect_ratio
)

# Import from configuration.py
from utils.configuration import (
    default_guidance_scale, 
    default_num_inference_steps, default_seed, 
    holiday_border_prompts
)



# Initialize session state. Each key is seeded only on the very first run so
# the stored values survive Streamlit's top-to-bottom reruns of this script.
# Defaults are built lazily (via the lambdas) so fresh lists are created only
# when a key is actually missing.
_session_defaults = {
    'uploaded_file': lambda: None,
    'card_params': lambda: [{} for _ in range(4)],
    'generated_cards': lambda: [None] * 4,
}
for _state_key, _make_default in _session_defaults.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()

# Streamlit app starts here: logo, title and a short walkthrough of the workflow.
# The emoji below are real UTF-8 characters; keep this file saved as UTF-8 so
# they do not turn back into mis-decoded byte sequences.
st.image("img/fireworksai_logo.png")
st.title("🎨 Holiday Multi-Card Generator🎨")
st.markdown(
    """Welcome to the first part of your holiday card creation journey! 🌟 Here, you'll play around with different styles, prompts, and parameters to design the perfect card border before adding a personal message in the section 'Customize Holiday Borders'. Let your creativity flow! 🎉

### How it works:

1. **🖼️ Upload Your Image:** Choose the image that will be the center of your card.
2. **✂️ Crop It:** Adjust the crop to highlight the most important part of your image.
3. **💡 Choose Your Style:** Select from festive border themes or input your own custom prompt to design something unique.
4. **⚙️ Fine-Tune Parameters:** Experiment with guidance scales, seeds, inference steps, and more for the perfect aesthetic.
5. **👀 Preview & Download:** See your generated holiday cards, tweak them until they're just right, and download the final designs and metadata in a neat ZIP file!

Once you've got the perfect look, head over to **Part B** to add your personal message and finalize your holiday card! 💌
"""
)

# Load API Key
st.divider()
st.subheader("Load Fireworks API Key")

# Define and ensure the .env directory and file exist
dotenv_path = os.path.join(os.path.dirname(__file__), '..', 'env', '.env')
os.makedirs(os.path.dirname(dotenv_path), exist_ok=True)

# Create an empty .env file if it doesn't exist yet
if not os.path.exists(dotenv_path):
    with open(dotenv_path, "w", encoding="utf-8"):
        pass  # an empty file is all that is needed
    st.success(f"Created {dotenv_path}")

# Load environment variables from the .env file (values in the file win over
# variables already present in the process environment).
load_dotenv(dotenv_path, override=True)

# Read the key; a missing or whitespace-only value counts as "not set".
fireworks_api_key = (os.getenv("FIREWORKS_API_KEY") or "").strip()

# Show the entire app but disable running parts if no API key
if not fireworks_api_key:
    # Strip so whitespace pasted along with the key is not sent to the API.
    fireworks_api_key = st.text_input("Enter Fireworks API Key", type="password").strip()

    # Optionally, allow the user to save the API key to the .env file.
    # set_key rewrites an existing FIREWORKS_API_KEY entry (or appends one,
    # adding a separating newline if needed) instead of blindly appending,
    # which duplicated entries and could merge with an unterminated last line.
    if fireworks_api_key and st.checkbox("Save API key for future use"):
        set_key(dotenv_path, "FIREWORKS_API_KEY", fireworks_api_key)
        st.success("API key saved to .env file.")
else:
    st.success(f"API key loaded successfully: partial preview {fireworks_api_key[:5]}")

# Step 1: Upload Image
st.divider()
st.subheader("🖼️ Step 1: Upload Your Picture!")
uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
if uploaded_file is not None:
    st.session_state.uploaded_file = uploaded_file
    original_image = Image.open(uploaded_file)
    img_width, img_height = original_image.size

    # Snap to the next largest valid aspect ratio; the generation step later
    # passes it to the image generator as a "width:height" string.
    aspect_ratio = get_next_largest_aspect_ratio(img_width, img_height)

    st.image(original_image, caption="Uploaded Image", use_column_width=True)

    # Step 2: Crop Image
    st.divider()
    st.subheader("✂️ Step 2: Crop It Like It's Hot!")
    # Smallest crop edge, in pixels. The X/Y position sliders stop this far
    # short of the far edge so the dependent Width/Height sliders always have
    # a range with max >= min instead of an inverted one (e.g. 10..0).
    MIN_CROP_SIZE = 10
    col1, col2 = st.columns(2)
    with col1:
        x_pos = st.slider("X position (Left-Right)", 0, max(img_width - MIN_CROP_SIZE, 0), img_width // 4)
        crop_width = st.slider("Width", MIN_CROP_SIZE, img_width - x_pos, min(img_width // 2, img_width - x_pos))
    with col2:
        y_pos = st.slider("Y position (Up-Down)", 0, max(img_height - MIN_CROP_SIZE, 0), img_height // 4)
        crop_height = st.slider("Height", MIN_CROP_SIZE, img_height - y_pos, min(img_height // 2, img_height - y_pos))

    preview_image = draw_crop_preview(original_image.copy(), x_pos, y_pos, crop_width, crop_height)
    st.image(preview_image, caption="Crop Preview", use_column_width=True)

    # Step 3: Set Card Parameters
    st.divider()
    st.subheader("⚙️ Step 3: Set Your Festive Border Design with Flux + Fireworks!")
    for i in range(4):
        with st.expander(f"Holiday Card {i + 1} Parameters"):
            card_params = st.session_state.card_params[i]

            # Fill in defaults the first time this card is shown; later reruns
            # start from the values the user last chose (stored below).
            card_params.setdefault("prompt", holiday_border_prompts[i % len(holiday_border_prompts)])
            card_params.setdefault("guidance_scale", default_guidance_scale)
            card_params.setdefault("num_inference_steps", default_num_inference_steps)
            card_params.setdefault("seed", i * 100)

            selected_prompt = st.selectbox(f"Choose a holiday-themed prompt for Holiday Card {i + 1}", options=["Custom"] + holiday_border_prompts)
            custom_prompt = st.text_input(f"Enter custom prompt for Holiday Card {i + 1}", value=card_params["prompt"]) if selected_prompt == "Custom" else selected_prompt

            # Allow the user to tweak other parameters. st.slider requires
            # value/min/max/step to share one numeric type, so the configured
            # defaults are coerced to match the literal bounds.
            guidance_scale = st.slider(
                f"Guidance Scale for Holiday Card {i + 1}",
                min_value=0.0, max_value=20.0, value=float(card_params["guidance_scale"]), step=0.1
            )
            num_inference_steps = st.slider(
                f"Number of Inference Steps for Holiday Card {i + 1}",
                min_value=1, max_value=100, value=int(card_params["num_inference_steps"]), step=1
            )
            seed = st.slider(
                f"Random Seed for Holiday Card {i + 1}",
                min_value=0, max_value=1000, value=card_params["seed"]
            )

            # Persist the current choices so they seed the widgets on the next rerun
            st.session_state.card_params[i] = {
                "prompt": custom_prompt,
                "guidance_scale": guidance_scale,
                "num_inference_steps": num_inference_steps,
                "seed": seed
            }

# Generate Holiday Cards
st.divider()
st.subheader("Preview and Share the Holiday Cheer! 🎅📬")
st.markdown(""" 
Click "Generate Holiday Cards" and watch the magic happen! Your holiday card is just moments away from spreading joy to everyone on your list. 🎄🎁✨
""")

# Generation needs both an API key and an uploaded image: the generation code
# reads original_image, aspect_ratio and the crop values, which are only
# defined once an image has been uploaded in Step 1.
missing_api_key = not fireworks_api_key or fireworks_api_key.strip() == ""
missing_image = uploaded_file is None
if missing_api_key:
    st.warning("Enter a valid Fireworks API key to enable card generation.")
if missing_image:
    st.info("Upload an image in Step 1 to enable card generation.")
generate_button = st.button("Generate Holiday Cards", disabled=missing_api_key or missing_image)

if generate_button:
    with st.spinner("Processing..."):
        cols = st.columns(4)
        image_files = []  # (file name, PNG buffer) pairs to bundle into the ZIP
        metadata = []  # parameter summary of each successfully generated card

        # Loop-invariant work, done once rather than once per card: the cropped
        # photo pasted onto every border, and the "width:height" ratio string.
        cropped_original = original_image.crop((x_pos, y_pos, x_pos + crop_width, y_pos + crop_height))
        flux_aspect_ratio = f"{aspect_ratio[0]}:{aspect_ratio[1]}"

        for i, params in enumerate(st.session_state.card_params):
            # Generate the border image using Flux with the snapped aspect ratio
            generated_image = generate_flux_image(
                model_path="flux-1-schnell-fp8",
                prompt=params['prompt'],
                steps=params['num_inference_steps'],
                guidance_scale=params['guidance_scale'],
                seed=params['seed'],
                api_key=fireworks_api_key,
                aspect_ratio=flux_aspect_ratio,
            )

            if not generated_image:
                # Clear any card kept from an earlier run so it is not mistaken
                # for a result of this run.
                st.session_state.generated_cards[i] = None
                st.error(f"Failed to generate holiday card {i + 1}. Please try again.")
                continue

            # Scale the border to the uploaded image's size and paste the cropped
            # photo at its center. resize() returns a new image, so pasting into
            # it directly does not modify generated_image.
            final_image = generated_image.resize(original_image.size)
            center_x = (final_image.width - cropped_original.width) // 2
            center_y = (final_image.height - cropped_original.height) // 2
            final_image.paste(cropped_original, (center_x, center_y))

            # Encode the card as PNG for the ZIP download
            img_byte_arr = BytesIO()
            final_image.save(img_byte_arr, format="PNG")
            img_byte_arr.seek(0)
            image_files.append((f"holiday_card_{i + 1}.png", img_byte_arr))

            card_metadata = {
                "Card": f"Holiday Card {i + 1}",
                "Prompt": params['prompt'],
                "Guidance Scale": params['guidance_scale'],
                "Inference Steps": params['num_inference_steps'],
                "Seed": params['seed']
            }
            metadata.append(card_metadata)

            st.session_state.generated_cards[i] = {
                "image": final_image,
                "metadata": card_metadata
            }

            # Display the final holiday card with its parameters
            col = cols[i]
            col.image(final_image, caption=f"Holiday Card {i + 1}", use_column_width=True)
            col.write(f"**Prompt:** {params['prompt']}")
            col.write(f"**Guidance Scale:** {params['guidance_scale']}")
            col.write(f"**Inference Steps:** {params['num_inference_steps']}")
            col.write(f"**Seed:** {params['seed']}")

        # Bundle every successful card plus a plain-text metadata summary
        if image_files:
            zip_buffer = BytesIO()
            with zipfile.ZipFile(zip_buffer, "w") as zf:
                for file_name, img_data in image_files:
                    zf.writestr(file_name, img_data.getvalue())

                metadata_str = "\n\n".join(
                    f"{m['Card']}:\nPrompt: {m['Prompt']}\nGuidance Scale: {m['Guidance Scale']}\n"
                    f"Inference Steps: {m['Inference Steps']}\nSeed: {m['Seed']}"
                    for m in metadata
                )
                zf.writestr("metadata.txt", metadata_str)

            zip_buffer.seek(0)

            # Single download button for all images and metadata
            st.download_button(
                label="Download all images and metadata as ZIP",
                data=zip_buffer,
                file_name="holiday_cards.zip",
                mime="application/zip"
            )


# Footer Section: closing thank-you note. The two trailing spaces on each
# line are Markdown hard line breaks and must be kept.
st.divider()
st.markdown(
    """
    Thank you for using the Holiday Card Generator powered by **Fireworks**! 🎉  
    Share your creations with the world and spread the holiday cheer!  
    Happy Holidays from the **Fireworks Team**. 💥  
    """
)