Spaces:
Runtime error
Runtime error
File size: 4,733 Bytes
bc35cc2 4b86c44 bc35cc2 4b86c44 bc35cc2 4b86c44 bc35cc2 4b86c44 bc35cc2 4b86c44 bc35cc2 4b86c44 bc35cc2 4b86c44 bc35cc2 4b86c44 bc35cc2 4b86c44 bc35cc2 4b86c44 bc35cc2 4b86c44 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import os
from copy import deepcopy
from io import BytesIO

import numpy as np
import streamlit as st
from PIL import Image
from streamlit_drawable_canvas import st_canvas

from src.core import process_inpaint
def image_download_button(pil_image, filename: str, fmt: str, label="Download"):
    """Render a Streamlit download button that serves *pil_image* as a file.

    Args:
        pil_image: PIL image to encode.
        filename: Download file name without extension; ``.jpg`` or ``.png``
            is appended according to *fmt*.
        fmt: Output format, ``"jpg"`` or ``"png"`` (case-sensitive).
        label: Text shown on the button.

    Returns:
        Whatever ``st.download_button`` returns.

    Raises:
        ValueError: If *fmt* is not ``"jpg"`` or ``"png"``. (ValueError is a
            subclass of Exception, so existing ``except Exception`` callers
            still catch it.)
    """
    # fmt -> (PIL format name, MIME type); the key doubles as the file extension.
    formats = {
        "jpg": ("JPEG", "image/jpeg"),
        "png": ("PNG", "image/png"),
    }
    if fmt not in formats:
        raise ValueError(
            f"Unknown image format {fmt!r} "
            f"(Available: {', '.join(formats)} - case sensitive)"
        )
    pil_format, mime = formats[fmt]

    # JPEG cannot store alpha or palette data; Pillow refuses to save these
    # modes as JPEG, so flatten to RGB first.
    if pil_format == "JPEG" and pil_image.mode in ("RGBA", "LA", "P"):
        pil_image = pil_image.convert("RGB")

    buf = BytesIO()
    pil_image.save(buf, format=pil_format)
    return st.download_button(
        label=label,
        data=buf.getvalue(),
        file_name=f"{filename}.{fmt}",
        mime=mime,
    )
# Per-session defaults. A key is seeded only on the first run of a session;
# later reruns keep whatever value the session already holds.
_SESSION_DEFAULTS = {
    "button_id": "",
    "color_to_label": {},
    # When not None, the main script builds its input image from this value
    # instead of the uploaded file.
    "reuse_image": None,
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
def set_image(img):
    """Store *img* in session state under ``reuse_image``.

    On the next rerun, a non-None ``reuse_image`` is turned into the input
    image with ``Image.fromarray`` in place of the uploaded file. So *img* is
    presumably an array-like image (verify against callers).
    """
    st.session_state["reuse_image"] = img
def _fit_within(width, height, max_size):
    """Return ``(width, height)`` scaled so neither side exceeds *max_size*.

    The longer side becomes exactly *max_size*. The other side is scaled by
    the same factor, truncated, and clamped to at least 1 pixel so extreme
    aspect ratios don't produce an invalid 0-pixel size. Sizes already within
    the limit are returned unchanged.
    """
    if width <= max_size and height <= max_size:
        return width, height
    if width > height:
        return max_size, max(1, int((max_size / width) * height))
    return max(1, int((max_size / height) * width)), max_size


def _canvas_to_mask(image_data, size):
    """Turn the canvas drawing into the RGBA mask handed to ``process_inpaint``.

    Args:
        image_data: Canvas pixel array from ``st_canvas`` with 4 channels per
            pixel; cast to uint8 here.
        size: ``(width, height)`` of the image being inpainted; the mask is
            resized to match.

    Returns:
        A uint8 RGBA array of the requested size. Pixels with RGB (0, 0, 0)
        become ``(0, 0, 0, 255)``, brushed magenta pixels (255, 0, 255)
        become ``(0, 0, 0, 0)``, and any other pixel is left as drawn.
    """
    canvas_img = Image.fromarray(image_data.astype(np.uint8))
    # Nearest-neighbour resampling keeps every pixel an exact canvas colour.
    # A smoothing filter would blend stroke edges into colours that neither
    # equality test below matches, leaving those pixels unclassified.
    mask = np.array(canvas_img.resize(size, Image.NEAREST))
    rgb = mask[:, :, :3]
    background = np.all(rgb == (0, 0, 0), axis=-1)
    drawing = np.all(rgb == (255, 0, 255), axis=-1)
    mask[background] = (0, 0, 0, 255)
    mask[drawing] = (0, 0, 0, 0)  # RGBA
    return mask


uploaded_file = st.file_uploader("Choose image", accept_multiple_files=False, type=["png", "jpg", "jpeg"])
if uploaded_file is not None:
    # A previously stored result (see session state "reuse_image") takes
    # precedence over the uploaded file.
    if st.session_state.reuse_image is not None:
        img_input = Image.fromarray(st.session_state.reuse_image)
    else:
        bytes_data = uploaded_file.getvalue()
        img_input = Image.open(BytesIO(bytes_data)).convert("RGBA")

    # Resize the image while maintaining aspect ratio
    max_size = 2000
    new_size = _fit_within(img_input.width, img_input.height, max_size)
    if new_size != img_input.size:
        img_input = img_input.resize(new_size)

    stroke_width = st.slider("Brush size", 1, 100, 50)
    st.write("**Now draw (brush) the part of image that you want to remove.**")

    # Canvas size logic: max width is 720. Resize the height to keep the
    # aspect ratio (at least 1 px, so very wide images don't collapse to 0).
    canvas_bg = deepcopy(img_input)
    aspect_ratio = canvas_bg.width / canvas_bg.height
    streamlit_width = 720
    if canvas_bg.width > streamlit_width:
        canvas_bg = canvas_bg.resize(
            (streamlit_width, max(1, int(streamlit_width / aspect_ratio)))
        )

    canvas_result = st_canvas(
        stroke_color="rgba(255, 0, 255, 1)",
        stroke_width=stroke_width,
        background_image=canvas_bg,
        width=canvas_bg.width,
        height=canvas_bg.height,
        drawing_mode="freedraw",
        key="compute_arc_length",
    )

    if canvas_result.image_data is not None:
        # The canvas is drawn at display size; build the mask at full image size.
        im = _canvas_to_mask(canvas_result.image_data, img_input.size)

        if st.button('Submit'):
            with st.spinner("AI is doing the magic!"):
                output = process_inpaint(np.array(img_input), np.array(im))
                img_output = Image.fromarray(output).convert("RGB")

            st.write("AI has finished the job!")
            st.image(img_output)

            # Offer the result under the uploaded file's base name.
            uploaded_name = os.path.splitext(uploaded_file.name)[0]
            image_download_button(
                pil_image=img_output,
                filename=uploaded_name,
                fmt="jpg",
                label="Download Image"
            )
            st.info("**TIP**: If the result is not perfect, you can download it then "
                    "upload then remove the artifacts.")
# App-wide stylesheet:
# - dark palette with white text;
# - the app column centred and capped at 800px;
# - purple buttons with a darker hover shade.
# NOTE(review): the `canvas` rule only reaches canvases in this page's own
# document; if the drawable-canvas component renders inside an iframe, the
# rule will not apply — confirm.
_APP_CSS = """
<style>
body {
font-family: 'Arial', sans-serif;
background-color: #1a1a2e;
color: #ffffff;
}
.stApp {
max-width: 800px;
margin: 0 auto;
padding: 20px;
background-color: #2a2a3e;
border-radius: 10px;
}
canvas {
display: block;
margin: 0 auto;
max-width: 100%;
height: auto;
}
.stButton button {
background-color: #9370db;
border: none;
padding: 10px;
border-radius: 5px;
color: #ffffff;
font-weight: bold;
cursor: pointer;
}
.stButton button:hover {
background-color: #7b5fbf;
}
.stSlider {
color: #ffffff;
}
</style>
"""
st.markdown(_APP_CSS, unsafe_allow_html=True)
|