|
from __future__ import annotations

import io
import random

import streamlit as st
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline
from PIL import Image, ImageOps
|
|
|
|
|
@st.cache_resource
def _load_pipeline():
    """Load the InstructPix2Pix pipeline and return it, cached by Streamlit.

    Streamlit re-executes this script on every widget interaction; wrapping
    the load in ``st.cache_resource`` means the model weights are read only
    once instead of on every rerun.

    Returns:
        The ``StableDiffusionInstructPix2PixPipeline`` placed on the CPU.
    """
    # float32 instead of float16: the pipeline is placed on the CPU, and
    # diffusers warns that float16 pipelines cannot be run reliably there.
    return StableDiffusionInstructPix2PixPipeline.from_pretrained(
        "timbrooks/instruct-pix2pix",
        torch_dtype=torch.float32,
        safety_checker=None,
    ).to("cpu")


# Module-level handle used by generate().
pipe = _load_pipeline()
|
|
|
|
|
def _to_png_bytes(image):
    """Encode a PIL image as PNG and return the encoded bytes."""
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return buffer.getvalue()


def main():
    """Render the Streamlit page: upload, edit settings, generate, download.

    Stops after showing a warning when no image has been uploaded yet.
    When the "Generate" button is pressed, calls ``generate`` with the
    current widget values, shows the edited image and offers it as a PNG
    download.
    """
    st.title("InstructPix2Pix Image Editing")

    uploaded_image = st.file_uploader("Upload an Image", type=["png", "jpg", "jpeg"])
    if uploaded_image is None:
        st.warning("Please upload an image to proceed.")
        return

    input_image = Image.open(uploaded_image).convert("RGB")
    st.image(input_image, caption="Uploaded Image", width=512)

    instruction = st.text_input("Enter instruction", "Make it a painting")

    steps = st.slider("Steps", 20, 100, 50)
    randomize_seed = st.checkbox("Randomize Seed", True)

    # Draw the default seed once per session. Re-drawing it on every rerun
    # changes the widget's default value, which makes Streamlit treat the
    # number input as a new widget and discard whatever the user typed.
    if "default_seed" not in st.session_state:
        st.session_state["default_seed"] = random.randint(0, 10000)
    seed = st.number_input(
        "Seed",
        0,
        value=st.session_state["default_seed"],
        disabled=randomize_seed,
    )

    text_cfg_scale = st.slider("Text CFG", 1.0, 10.0, 7.5)
    image_cfg_scale = st.slider("Image CFG", 0.5, 2.0, 1.5)

    if st.button("Generate"):
        result_image = generate(
            input_image,
            instruction,
            steps,
            randomize_seed,
            seed,
            text_cfg_scale,
            image_cfg_scale,
        )
        st.image(result_image, caption="Edited Image", width=512)
        # Image.tobytes() yields raw pixel data, not a PNG file; encode the
        # image properly so the download matches its name and MIME type.
        st.download_button(
            "Download",
            data=_to_png_bytes(result_image),
            file_name="edited_image.png",
            mime="image/png",
        )
|
|
|
|
|
def generate(input_image, instruction, steps, randomize_seed, seed, text_cfg_scale, image_cfg_scale):
    """Run the InstructPix2Pix pipeline on one image and return the result.

    Args:
        input_image: PIL image to edit; it is fitted to 512x512 with
            ``ImageOps.fit`` (LANCZOS resampling) before inference.
        instruction: Text prompt passed to the pipeline.
        steps: Value passed as ``num_inference_steps``.
        randomize_seed: When truthy, ``seed`` is ignored and a random seed
            in ``[0, 100000]`` is drawn instead.
        seed: Seed used when ``randomize_seed`` is falsy.
        text_cfg_scale: Value passed as ``guidance_scale``.
        image_cfg_scale: Value passed as ``image_guidance_scale``.

    Returns:
        The first image of the pipeline output.
    """
    chosen_seed = random.randint(0, 100000) if randomize_seed else seed

    fitted_image = ImageOps.fit(
        input_image,
        (512, 512),
        method=Image.Resampling.LANCZOS,
    )

    output = pipe(
        instruction,
        image=fitted_image,
        guidance_scale=text_cfg_scale,
        image_guidance_scale=image_cfg_scale,
        num_inference_steps=steps,
        # torch.manual_seed seeds the default RNG and returns its generator.
        generator=torch.manual_seed(chosen_seed),
    )
    return output.images[0]
|
|
|
# Script entry point: build the page only when this file is run as a script.
if __name__ == "__main__":
    main()