File size: 2,072 Bytes
1a0f0c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from __future__ import annotations

import io
import random

import streamlit as st
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline
from PIL import Image, ImageOps

# Load model
@st.cache_resource
def _load_pipeline():
    """Load the InstructPix2Pix pipeline once and reuse it across reruns.

    Streamlit re-executes this script on every widget interaction, so the
    load is wrapped in ``st.cache_resource`` to avoid re-reading the model
    weights each time.

    The device and dtype are chosen together: float16 on CUDA, float32 on
    CPU. Half precision on CPU is unsupported for several ops in many
    PyTorch builds (and slow where it is supported), so pairing float16
    with a CPU device is avoided.

    Returns:
        The loaded ``StableDiffusionInstructPix2PixPipeline`` placed on the
        selected device, with the safety checker disabled.
    """
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    dtype = torch.float16 if use_cuda else torch.float32

    pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
        "timbrooks/instruct-pix2pix",
        torch_dtype=dtype,
        safety_checker=None,
    )
    return pipeline.to(device)


pipe = _load_pipeline()

# Main app
def main():
    """Render the Streamlit UI for instruction-based image editing.

    Flow: the user uploads an image, enters an edit instruction and tunes
    the sampling controls; pressing "Generate" calls ``generate`` and shows
    the edited image together with a PNG download button. Returns early
    (after showing a warning) when no image has been uploaded yet.
    """
    st.title("InstructPix2Pix Image Editing")

    uploaded_image = st.file_uploader("Upload an Image", type=["png", "jpg", "jpeg"])

    if not uploaded_image:
        st.warning("Please upload an image to proceed.")
        return

    input_image = Image.open(uploaded_image).convert("RGB")
    st.image(input_image, caption="Uploaded Image", width=512)

    instruction = st.text_input("Enter instruction", "Make it a painting")

    steps = st.slider("Steps", 20, 100, 50)
    randomize_seed = st.checkbox("Randomize Seed", True)

    # The script re-runs on every interaction. A default that changes between
    # runs makes Streamlit treat the seed box as a brand-new widget and drop
    # whatever the user typed, so the random default is drawn once per
    # session and then kept stable.
    if "default_seed" not in st.session_state:
        st.session_state["default_seed"] = random.randint(0, 10000)
    seed = st.number_input(
        "Seed",
        0,
        value=st.session_state["default_seed"],
        disabled=randomize_seed,
    )

    text_cfg_scale = st.slider("Text CFG", 1.0, 10.0, 7.5)
    image_cfg_scale = st.slider("Image CFG", 0.5, 2.0, 1.5)

    if st.button("Generate"):
        result_image = generate(
            input_image,
            instruction,
            steps,
            randomize_seed,
            seed,
            text_cfg_scale,
            image_cfg_scale,
        )
        st.image(result_image, caption="Edited Image", width=512)

        # Image.tobytes() yields raw pixel data, not a PNG file. Encode the
        # image as PNG so the download matches its file name and MIME type.
        png_buffer = io.BytesIO()
        result_image.save(png_buffer, format="PNG")
        st.download_button(
            "Download",
            data=png_buffer.getvalue(),
            file_name="edited_image.png",
            mime="image/png",
        )

# Generate image
def generate(input_image, instruction, steps, randomize_seed, seed, text_cfg_scale, image_cfg_scale):
    """Apply a text instruction to an image with the InstructPix2Pix pipeline.

    Args:
        input_image: PIL image to edit. It is cropped/resized to 512x512
            with ``ImageOps.fit`` before being passed to the pipeline.
        instruction: Text prompt describing the edit.
        steps: Number of inference steps.
        randomize_seed: When true, ``seed`` is ignored and a random seed in
            [0, 100000] is drawn instead.
        seed: Seed used when ``randomize_seed`` is false.
        text_cfg_scale: Passed to the pipeline as ``guidance_scale``.
        image_cfg_scale: Passed to the pipeline as ``image_guidance_scale``.

    Returns:
        The first image produced by the pipeline.
    """
    if randomize_seed:
        seed = random.randint(0, 100000)

    input_image = ImageOps.fit(input_image, (512, 512), method=Image.Resampling.LANCZOS)

    # Use a dedicated generator rather than torch.manual_seed(): the result is
    # equally reproducible for a given seed, but the process-wide default RNG
    # is not reseeded as a side effect.
    generator = torch.Generator().manual_seed(int(seed))

    output = pipe(
        instruction,
        image=input_image,
        guidance_scale=text_cfg_scale,
        image_guidance_scale=image_cfg_scale,
        num_inference_steps=steps,
        generator=generator,
    )
    return output.images[0]

# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()