Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import random
|
4 |
+
import torch
|
5 |
+
from PIL import Image, ImageOps
|
6 |
+
from diffusers import StableDiffusionInstructPix2PixPipeline
|
7 |
+
import streamlit as st
|
8 |
+
|
9 |
+
# Load model
|
10 |
+
pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
|
11 |
+
"timbrooks/instruct-pix2pix",
|
12 |
+
torch_dtype=torch.float16, safety_checker=None
|
13 |
+
).to("cpu")
|
14 |
+
|
15 |
+
# Main app
|
16 |
+
def main():
|
17 |
+
st.title("InstructPix2Pix Image Editing")
|
18 |
+
|
19 |
+
uploaded_image = st.file_uploader("Upload an Image", type=["png", "jpg", "jpeg"])
|
20 |
+
|
21 |
+
if not uploaded_image:
|
22 |
+
st.warning("Please upload an image to proceed.")
|
23 |
+
return
|
24 |
+
|
25 |
+
input_image = Image.open(uploaded_image).convert("RGB")
|
26 |
+
st.image(input_image, caption="Uploaded Image", width=512)
|
27 |
+
|
28 |
+
instruction = st.text_input("Enter instruction", "Make it a painting")
|
29 |
+
|
30 |
+
steps = st.slider("Steps", 20, 100, 50)
|
31 |
+
randomize_seed = st.checkbox("Randomize Seed", True)
|
32 |
+
seed = st.number_input("Seed", 0, value=random.randint(0, 10000), disabled=randomize_seed)
|
33 |
+
|
34 |
+
text_cfg_scale = st.slider("Text CFG", 1.0, 10.0, 7.5)
|
35 |
+
image_cfg_scale = st.slider("Image CFG", 0.5, 2.0, 1.5)
|
36 |
+
|
37 |
+
if st.button("Generate"):
|
38 |
+
result_image = generate(input_image, instruction, steps, randomize_seed, seed, text_cfg_scale, image_cfg_scale)
|
39 |
+
st.image(result_image, caption="Edited Image", width=512)
|
40 |
+
st.download_button("Download", data=result_image.tobytes(), file_name="edited_image.png", mime="image/png")
|
41 |
+
|
42 |
+
# Generate image
|
43 |
+
def generate(input_image, instruction, steps, randomize_seed, seed, text_cfg_scale, image_cfg_scale):
|
44 |
+
if randomize_seed:
|
45 |
+
seed = random.randint(0, 100000)
|
46 |
+
|
47 |
+
input_image = ImageOps.fit(input_image, (512, 512), method=Image.Resampling.LANCZOS)
|
48 |
+
|
49 |
+
generator = torch.manual_seed(seed)
|
50 |
+
edited_image = pipe(
|
51 |
+
instruction, image=input_image, guidance_scale=text_cfg_scale,
|
52 |
+
image_guidance_scale=image_cfg_scale, num_inference_steps=steps,
|
53 |
+
generator=generator
|
54 |
+
).images[0]
|
55 |
+
|
56 |
+
return edited_image
|
57 |
+
|
58 |
+
if __name__ == "__main__":
|
59 |
+
main()
|