walidadebayo commited on
Commit
be8ba6b
·
1 Parent(s): df5d762

Add Docker support, entrypoint script, and core functionality for image processing

Browse files

- Created a Dockerfile and docker-compose.yml for containerization.
- Implemented entrypoint.py to run both Flask API and Streamlit app.
- Added core image processing functions in core.py for seam carving and object removal.
- Developed helper functions in helper.py for model downloading and image manipulation.
- Integrated Streamlit for a user interface with custom styles in st_style.py.
- Updated requirements.txt to include necessary libraries for the project.

Files changed (15) hide show
  1. .gitattributes +2 -0
  2. .gitignore +128 -0
  3. .streamlit/config.toml +6 -0
  4. Dockerfile +15 -0
  5. api.py +76 -0
  6. app.py +138 -0
  7. assets/big-lama.pt +3 -0
  8. assets/demo.gif +3 -0
  9. assets/demo.png +3 -0
  10. docker-compose.yml +14 -0
  11. entrypoint.py +37 -0
  12. requirements.txt +10 -0
  13. src/core.py +466 -0
  14. src/helper.py +87 -0
  15. src/st_style.py +42 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/demo.gif filter=lfs diff=lfs merge=lfs -text
37
+ assets/demo.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+
53
+ # Translations
54
+ *.mo
55
+ *.pot
56
+
57
+ # Django stuff:
58
+ *.log
59
+ local_settings.py
60
+ db.sqlite3
61
+ db.sqlite3-journal
62
+
63
+ # Flask stuff:
64
+ instance/
65
+ .webassets-cache
66
+
67
+ # Scrapy stuff:
68
+ .scrapy
69
+
70
+ # Sphinx documentation
71
+ docs/_build/
72
+
73
+ # PyBuilder
74
+ target/
75
+
76
+ # Jupyter Notebook
77
+ .ipynb_checkpoints
78
+
79
+ # IPython
80
+ profile_default/
81
+ ipython_config.py
82
+
83
+ # pyenv
84
+ .python-version
85
+
86
+ # pipenv
87
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
88
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
89
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
90
+ # install all needed dependencies.
91
+ #Pipfile.lock
92
+
93
+ # celery beat schedule file
94
+ celerybeat-schedule
95
+
96
+ # SageMath parsed files
97
+ *.sage.py
98
+
99
+ # Environments
100
+ .env
101
+ .venv
102
+ env/
103
+ venv/
104
+ ENV/
105
+ env.bak/
106
+ venv.bak/
107
+
108
+ # Spyder project settings
109
+ .spyderproject
110
+ .spyproject
111
+
112
+ # Rope project settings
113
+ .ropeproject
114
+
115
+ # mkdocs documentation
116
+ /site
117
+
118
+ # mypy
119
+ .mypy_cache/
120
+ .dmypy.json
121
+ dmypy.json
122
+
123
+ # Pyre type checker
124
+ .pyre/
125
+
126
+ frontend
127
+ magic-eraser-api
128
+ magic-eraser
.streamlit/config.toml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ [server]
2
+ maxUploadSize = 10
3
+
4
+ [theme]
5
+ base="light"
6
+ primaryColor="#0074ff"
Dockerfile ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM pytorch/pytorch:latest
2
+
3
+ WORKDIR /app
4
+
5
+ COPY . .
6
+
7
+ RUN pip install -r requirements.txt
8
+
9
+ # Default to running both Flask API and Streamlit app
10
+ ENV RUN_MODE=all
11
+ ENV PORT=7860
12
+ ENV STREAMLIT_PORT=8501
13
+
14
+ # Use our entrypoint script
15
+ CMD [ "python", "entrypoint.py" ]
api.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify
2
+ from flask_cors import CORS
3
+ import numpy as np
4
+ import cv2
5
+ import base64
6
+ from src.core import process_inpaint, s_image
7
+ import os
8
+
9
+ app = Flask(__name__)
10
+ CORS(app)
11
+
12
+ # endpoint for health checks
13
+ @app.route('/', methods=['GET'])
14
+ def health_check():
15
+ return jsonify({"status": "API is running"})
16
+
17
+ @app.route('/api/inpaint', methods=['POST'])
18
+ def inpaint():
19
+ # Get data from request
20
+ data = request.json
21
+ image_data = data.get('image')
22
+ mask_data = data.get('mask')
23
+
24
+ # Convert base64 to numpy arrays
25
+ image = base64_to_image(image_data)
26
+ mask = base64_to_image(mask_data)
27
+
28
+ # Process the image
29
+ result = process_inpaint(image, mask)
30
+
31
+ # Convert back to base64
32
+ result_base64 = image_to_base64(result)
33
+
34
+ return jsonify({'result': result_base64})
35
+
36
+ @app.route('/api/seam-carve', methods=['POST'])
37
+ def seam_carve():
38
+ # Get data from request
39
+ data = request.json
40
+ image_data = data.get('image')
41
+ mask_data = data.get('mask')
42
+ vs = int(data.get('vs', 0)) # vertical seams
43
+ hs = int(data.get('hs', 0)) # horizontal seams
44
+ mode = data.get('mode', 'resize') # resize or remove
45
+
46
+ # Convert base64 to numpy arrays
47
+ image = base64_to_image(image_data)
48
+ mask = base64_to_image(mask_data)
49
+
50
+ # Process the image
51
+ result = s_image(image, mask, vs, hs, mode)
52
+
53
+ # Convert back to base64
54
+ result_base64 = image_to_base64(result)
55
+
56
+ return jsonify({'result': result_base64})
57
+
58
+ def base64_to_image(base64_str):
59
+ img_bytes = base64.b64decode(base64_str.split(',')[1])
60
+ img_array = np.frombuffer(img_bytes, np.uint8)
61
+ img = cv2.imdecode(img_array, cv2.IMREAD_UNCHANGED)
62
+ return img
63
+
64
+ def image_to_base64(image):
65
+ # Convert to BGR if it's RGB
66
+ if len(image.shape) > 2 and image.shape[2] == 3:
67
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
68
+
69
+ _, buffer = cv2.imencode('.png', image)
70
+ img_bytes = base64.b64encode(buffer).decode('utf-8')
71
+ return f"data:image/png;base64,{img_bytes}"
72
+
73
+ if __name__ == '__main__':
74
+ # Use the PORT environment variable provided by Hugging Face or default to 7860
75
+ port = int(os.environ.get("PORT", 7860))
76
+ app.run(debug=True, host='0.0.0.0', port=port)
app.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ import streamlit as st
4
+ import os
5
+ from datetime import datetime
6
+ from PIL import Image
7
+ from streamlit_drawable_canvas import st_canvas
8
+ from io import BytesIO
9
+ from copy import deepcopy
10
+
11
+ from src.core import process_inpaint
12
+
13
+ # API availability
14
+ st.sidebar.info("""
15
+ ## API Endpoints
16
+ REST API endpoints are available at:
17
+ - `GET /` - Health check
18
+ - `POST /api/inpaint` - For inpainting/object removal
19
+ - `POST /api/seam-carve` - For seam carving
20
+ """)
21
+
22
+ def image_download_button(pil_image, filename: str, fmt: str, label="Download"):
23
+ if fmt not in ["jpg", "png"]:
24
+ raise Exception(f"Unknown image format (Available: {fmt} - case sensitive)")
25
+
26
+ pil_format = "JPEG" if fmt == "jpg" else "PNG"
27
+ file_format = "jpg" if fmt == "jpg" else "png"
28
+ mime = "image/jpeg" if fmt == "jpg" else "image/png"
29
+
30
+ buf = BytesIO()
31
+ pil_image.save(buf, format=pil_format)
32
+
33
+ return st.download_button(
34
+ label=label,
35
+ data=buf.getvalue(),
36
+ file_name=f'{filename}.{file_format}',
37
+ mime=mime,
38
+ )
39
+
40
+ if "button_id" not in st.session_state:
41
+ st.session_state["button_id"] = ""
42
+ if "color_to_label" not in st.session_state:
43
+ st.session_state["color_to_label"] = {}
44
+
45
+ if 'reuse_image' not in st.session_state:
46
+ st.session_state.reuse_image = None
47
+ def set_image(img):
48
+ st.session_state.reuse_image = img
49
+
50
+ st.title("Magic-Eraser")
51
+
52
+ st.image(open("assets/demo.png", "rb").read())
53
+
54
+ st.markdown(
55
+ """
56
+ You don't have to worry about mastering photo editing techniques to remove an object from your photo. **Simply mark over the areas you want to erase, and our AI will take care of the rest.**
57
+ """
58
+ )
59
+ uploaded_file = st.file_uploader("Choose image", accept_multiple_files=False, type=["png", "jpg", "jpeg"])
60
+
61
+ if uploaded_file is not None:
62
+
63
+ if st.session_state.reuse_image is not None:
64
+ img_input = Image.fromarray(st.session_state.reuse_image)
65
+ else:
66
+ bytes_data = uploaded_file.getvalue()
67
+ img_input = Image.open(BytesIO(bytes_data)).convert("RGBA")
68
+
69
+ #resize
70
+ max_size = 2000
71
+ img_width, img_height = img_input.size
72
+ if img_width > max_size or img_height > max_size:
73
+ if img_width > img_height:
74
+ new_width = max_size
75
+ new_height = int((max_size / img_width) * img_height)
76
+ else:
77
+ new_height = max_size
78
+ new_width = int((max_size / img_height) * img_width)
79
+ img_input = img_input.resize((new_width, new_height))
80
+
81
+ stroke_width = st.slider("Brush size", 1, 100, 50)
82
+
83
+ st.write("**Now draw (brush) the part of image that you want to remove.**")
84
+
85
+ canvas_bg = deepcopy(img_input)
86
+ aspect_ratio = canvas_bg.width / canvas_bg.height
87
+ streamlit_width = 720
88
+
89
+ if canvas_bg.width > streamlit_width:
90
+ canvas_bg = canvas_bg.resize((streamlit_width, int(streamlit_width / aspect_ratio)))
91
+
92
+ canvas_result = st_canvas(
93
+ stroke_color="rgba(255, 0, 255, 1)",
94
+ stroke_width=stroke_width,
95
+ background_image=canvas_bg,
96
+ width=canvas_bg.width,
97
+ height=canvas_bg.height,
98
+ drawing_mode="freedraw",
99
+ key="compute_arc_length",
100
+ )
101
+
102
+ if canvas_result.image_data is not None:
103
+ im = np.array(Image.fromarray(canvas_result.image_data.astype(np.uint8)).resize(img_input.size))
104
+ background = np.where(
105
+ (im[:, :, 0] == 0) &
106
+ (im[:, :, 1] == 0) &
107
+ (im[:, :, 2] == 0)
108
+ )
109
+ drawing = np.where(
110
+ (im[:, :, 0] == 255) &
111
+ (im[:, :, 1] == 0) &
112
+ (im[:, :, 2] == 255)
113
+ )
114
+ im[background]=[0,0,0,255]
115
+ im[drawing]=[0,0,0,0] # RGBA
116
+
117
+ reuse = False
118
+
119
+ if st.button('Submit'):
120
+
121
+ with st.spinner("AI is doing the magic!"):
122
+ output = process_inpaint(np.array(img_input), np.array(im)) #TODO Put button here
123
+ img_output = Image.fromarray(output).convert("RGB")
124
+
125
+ st.write("AI has finished the job!")
126
+ st.image(img_output)
127
+ # reuse = st.button('Edit again (Re-use this image)', on_click=set_image, args=(inpainted_img, ))
128
+
129
+ uploaded_name = os.path.splitext(uploaded_file.name)[0]
130
+ image_download_button(
131
+ pil_image=img_output,
132
+ filename=uploaded_name,
133
+ fmt="jpg",
134
+ label="Download Image"
135
+ )
136
+
137
+ st.info("**TIP**: If the result is not perfect, you can download it then "
138
+ "upload then remove the artifacts.")
assets/big-lama.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:344c77bbcb158f17dd143070d1e789f38a66c04202311ae3a258ef66667a9ea9
3
+ size 205669692
assets/demo.gif ADDED

Git LFS Details

  • SHA256: 008972f8d1d49ee16fe2cce2669ea5397788bc7885db40ce3c55aca0df6b411d
  • Pointer size: 132 Bytes
  • Size of remote file: 2.88 MB
assets/demo.png ADDED

Git LFS Details

  • SHA256: 1e895e324b9fbf8aed7e512222d89d3ea48fa96263d3a01d0a6af2cf0749a4d7
  • Pointer size: 131 Bytes
  • Size of remote file: 159 kB
docker-compose.yml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ version: '3'
3
+ services:
4
+ sthf-remove-photo-object:
5
+ build: .
6
+ container_name: sthf-remove-photo-object
7
+ restart: unless-stopped
8
+ ports:
9
+ - 41003:8501
10
+ volumes:
11
+ - .:/app
12
+ environment:
13
+ - TZ=Asia/Jakarta
14
+ # command: streamlit run sdc.py
entrypoint.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ import threading
4
+ import sys
5
+
6
+ def run_flask():
7
+ print("Starting Flask API server...")
8
+ from api import app
9
+ # Use the PORT environment variable provided by Hugging Face or default to 7860
10
+ port = int(os.environ.get("PORT", 7860))
11
+ app.run(host='0.0.0.0', port=port)
12
+
13
+ def run_streamlit():
14
+ print("Starting Streamlit app...")
15
+ # Start Streamlit on a different port
16
+ streamlit_port = int(os.environ.get("STREAMLIT_PORT", 8501))
17
+ subprocess.run([
18
+ "streamlit", "run", "app.py",
19
+ "--server.port", str(streamlit_port),
20
+ "--server.address", "0.0.0.0"
21
+ ])
22
+
23
+ if __name__ == "__main__":
24
+ mode = os.environ.get("RUN_MODE", "all").lower()
25
+
26
+ if mode == "flask":
27
+ run_flask()
28
+ elif mode == "streamlit":
29
+ run_streamlit()
30
+ else:
31
+ # Run both in different threads
32
+ flask_thread = threading.Thread(target=run_flask)
33
+ flask_thread.daemon = True
34
+ flask_thread.start()
35
+
36
+ # Run Streamlit in main thread
37
+ run_streamlit()
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ streamlit-drawable-canvas
3
+ flask
4
+ flask-cors
5
+ numpy
6
+ opencv-python
7
+ pillow
8
+ torch
9
+ pandas
10
+ scipy
src/core.py ADDED
@@ -0,0 +1,466 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import json
3
+ import os
4
+ import re
5
+ import time
6
+ import uuid
7
+ from io import BytesIO
8
+ from pathlib import Path
9
+ import cv2
10
+
11
+ # For inpainting
12
+
13
+ import numpy as np
14
+ import pandas as pd
15
+ import streamlit as st
16
+ from PIL import Image
17
+ from streamlit_drawable_canvas import st_canvas
18
+
19
+
20
+ import argparse
21
+ import io
22
+ import multiprocessing
23
+ from typing import Union
24
+
25
+ import torch
26
+
27
+ try:
28
+ torch._C._jit_override_can_fuse_on_cpu(False)
29
+ torch._C._jit_override_can_fuse_on_gpu(False)
30
+ torch._C._jit_set_texpr_fuser_enabled(False)
31
+ torch._C._jit_set_nvfuser_enabled(False)
32
+ except:
33
+ pass
34
+
35
+ from src.helper import (
36
+ download_model,
37
+ load_img,
38
+ norm_img,
39
+ numpy_to_bytes,
40
+ pad_img_to_modulo,
41
+ resize_max_size,
42
+ )
43
+
44
+ NUM_THREADS = str(multiprocessing.cpu_count())
45
+
46
+ os.environ["OMP_NUM_THREADS"] = NUM_THREADS
47
+ os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
48
+ os.environ["MKL_NUM_THREADS"] = NUM_THREADS
49
+ os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
50
+ os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
51
+ if os.environ.get("CACHE_DIR"):
52
+ os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
53
+
54
+ #BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "./lama_cleaner/app/build")
55
+
56
+ # For Seam-carving
57
+
58
+ from scipy import ndimage as ndi
59
+
60
+ SEAM_COLOR = np.array([255, 200, 200]) # seam visualization color (BGR)
61
+ SHOULD_DOWNSIZE = True # if True, downsize image for faster carving
62
+ DOWNSIZE_WIDTH = 500 # resized image width if SHOULD_DOWNSIZE is True
63
+ ENERGY_MASK_CONST = 100000.0 # large energy value for protective masking
64
+ MASK_THRESHOLD = 10 # minimum pixel intensity for binary mask
65
+ USE_FORWARD_ENERGY = True # if True, use forward energy algorithm
66
+
67
+ device = torch.device("cpu")
68
+ model_path = "./assets/big-lama.pt"
69
+ model = torch.jit.load(model_path, map_location="cpu")
70
+ model = model.to(device)
71
+ model.eval()
72
+
73
+
74
+ ########################################
75
+ # UTILITY CODE
76
+ ########################################
77
+
78
+
79
+ def visualize(im, boolmask=None, rotate=False):
80
+ vis = im.astype(np.uint8)
81
+ if boolmask is not None:
82
+ vis[np.where(boolmask == False)] = SEAM_COLOR
83
+ if rotate:
84
+ vis = rotate_image(vis, False)
85
+ cv2.imshow("visualization", vis)
86
+ cv2.waitKey(1)
87
+ return vis
88
+
89
+ def resize(image, width):
90
+ dim = None
91
+ h, w = image.shape[:2]
92
+ dim = (width, int(h * width / float(w)))
93
+ image = image.astype('float32')
94
+ return cv2.resize(image, dim)
95
+
96
+ def rotate_image(image, clockwise):
97
+ k = 1 if clockwise else 3
98
+ return np.rot90(image, k)
99
+
100
+
101
+ ########################################
102
+ # ENERGY FUNCTIONS
103
+ ########################################
104
+
105
+ def backward_energy(im):
106
+ """
107
+ Simple gradient magnitude energy map.
108
+ """
109
+ xgrad = ndi.convolve1d(im, np.array([1, 0, -1]), axis=1, mode='wrap')
110
+ ygrad = ndi.convolve1d(im, np.array([1, 0, -1]), axis=0, mode='wrap')
111
+
112
+ grad_mag = np.sqrt(np.sum(xgrad**2, axis=2) + np.sum(ygrad**2, axis=2))
113
+
114
+ # vis = visualize(grad_mag)
115
+ # cv2.imwrite("backward_energy_demo.jpg", vis)
116
+
117
+ return grad_mag
118
+
119
+ def forward_energy(im):
120
+ """
121
+ Forward energy algorithm as described in "Improved Seam Carving for Video Retargeting"
122
+ by Rubinstein, Shamir, Avidan.
123
+ Vectorized code adapted from
124
+ https://github.com/axu2/improved-seam-carving.
125
+ """
126
+ h, w = im.shape[:2]
127
+ im = cv2.cvtColor(im.astype(np.uint8), cv2.COLOR_BGR2GRAY).astype(np.float64)
128
+
129
+ energy = np.zeros((h, w))
130
+ m = np.zeros((h, w))
131
+
132
+ U = np.roll(im, 1, axis=0)
133
+ L = np.roll(im, 1, axis=1)
134
+ R = np.roll(im, -1, axis=1)
135
+
136
+ cU = np.abs(R - L)
137
+ cL = np.abs(U - L) + cU
138
+ cR = np.abs(U - R) + cU
139
+
140
+ for i in range(1, h):
141
+ mU = m[i-1]
142
+ mL = np.roll(mU, 1)
143
+ mR = np.roll(mU, -1)
144
+
145
+ mULR = np.array([mU, mL, mR])
146
+ cULR = np.array([cU[i], cL[i], cR[i]])
147
+ mULR += cULR
148
+
149
+ argmins = np.argmin(mULR, axis=0)
150
+ m[i] = np.choose(argmins, mULR)
151
+ energy[i] = np.choose(argmins, cULR)
152
+
153
+ # vis = visualize(energy)
154
+ # cv2.imwrite("forward_energy_demo.jpg", vis)
155
+
156
+ return energy
157
+
158
+ ########################################
159
+ # SEAM HELPER FUNCTIONS
160
+ ########################################
161
+
162
+ def add_seam(im, seam_idx):
163
+ """
164
+ Add a vertical seam to a 3-channel color image at the indices provided
165
+ by averaging the pixels values to the left and right of the seam.
166
+ Code adapted from https://github.com/vivianhylee/seam-carving.
167
+ """
168
+ h, w = im.shape[:2]
169
+ output = np.zeros((h, w + 1, 3))
170
+ for row in range(h):
171
+ col = seam_idx[row]
172
+ for ch in range(3):
173
+ if col == 0:
174
+ p = np.mean(im[row, col: col + 2, ch])
175
+ output[row, col, ch] = im[row, col, ch]
176
+ output[row, col + 1, ch] = p
177
+ output[row, col + 1:, ch] = im[row, col:, ch]
178
+ else:
179
+ p = np.mean(im[row, col - 1: col + 1, ch])
180
+ output[row, : col, ch] = im[row, : col, ch]
181
+ output[row, col, ch] = p
182
+ output[row, col + 1:, ch] = im[row, col:, ch]
183
+
184
+ return output
185
+
186
+ def add_seam_grayscale(im, seam_idx):
187
+ """
188
+ Add a vertical seam to a grayscale image at the indices provided
189
+ by averaging the pixels values to the left and right of the seam.
190
+ """
191
+ h, w = im.shape[:2]
192
+ output = np.zeros((h, w + 1))
193
+ for row in range(h):
194
+ col = seam_idx[row]
195
+ if col == 0:
196
+ p = np.mean(im[row, col: col + 2])
197
+ output[row, col] = im[row, col]
198
+ output[row, col + 1] = p
199
+ output[row, col + 1:] = im[row, col:]
200
+ else:
201
+ p = np.mean(im[row, col - 1: col + 1])
202
+ output[row, : col] = im[row, : col]
203
+ output[row, col] = p
204
+ output[row, col + 1:] = im[row, col:]
205
+
206
+ return output
207
+
208
+ def remove_seam(im, boolmask):
209
+ h, w = im.shape[:2]
210
+ boolmask3c = np.stack([boolmask] * 3, axis=2)
211
+ return im[boolmask3c].reshape((h, w - 1, 3))
212
+
213
+ def remove_seam_grayscale(im, boolmask):
214
+ h, w = im.shape[:2]
215
+ return im[boolmask].reshape((h, w - 1))
216
+
217
+ def get_minimum_seam(im, mask=None, remove_mask=None):
218
+ """
219
+ DP algorithm for finding the seam of minimum energy. Code adapted from
220
+ https://karthikkaranth.me/blog/implementing-seam-carving-with-python/
221
+ """
222
+ h, w = im.shape[:2]
223
+ energyfn = forward_energy if USE_FORWARD_ENERGY else backward_energy
224
+ M = energyfn(im)
225
+
226
+ if mask is not None:
227
+ M[np.where(mask > MASK_THRESHOLD)] = ENERGY_MASK_CONST
228
+
229
+ # give removal mask priority over protective mask by using larger negative value
230
+ if remove_mask is not None:
231
+ M[np.where(remove_mask > MASK_THRESHOLD)] = -ENERGY_MASK_CONST * 100
232
+
233
+ seam_idx, boolmask = compute_shortest_path(M, im, h, w)
234
+
235
+ return np.array(seam_idx), boolmask
236
+
237
+ def compute_shortest_path(M, im, h, w):
238
+ backtrack = np.zeros_like(M, dtype=np.int_)
239
+
240
+
241
+ # populate DP matrix
242
+ for i in range(1, h):
243
+ for j in range(0, w):
244
+ if j == 0:
245
+ idx = np.argmin(M[i - 1, j:j + 2])
246
+ backtrack[i, j] = idx + j
247
+ min_energy = M[i-1, idx + j]
248
+ else:
249
+ idx = np.argmin(M[i - 1, j - 1:j + 2])
250
+ backtrack[i, j] = idx + j - 1
251
+ min_energy = M[i - 1, idx + j - 1]
252
+
253
+ M[i, j] += min_energy
254
+
255
+ # backtrack to find path
256
+ seam_idx = []
257
+ boolmask = np.ones((h, w), dtype=np.bool_)
258
+ j = np.argmin(M[-1])
259
+ for i in range(h-1, -1, -1):
260
+ boolmask[i, j] = False
261
+ seam_idx.append(j)
262
+ j = backtrack[i, j]
263
+
264
+ seam_idx.reverse()
265
+ return seam_idx, boolmask
266
+
267
+ ########################################
268
+ # MAIN ALGORITHM
269
+ ########################################
270
+
271
+ def seams_removal(im, num_remove, mask=None, vis=False, rot=False):
272
+ for _ in range(num_remove):
273
+ seam_idx, boolmask = get_minimum_seam(im, mask)
274
+ if vis:
275
+ visualize(im, boolmask, rotate=rot)
276
+ im = remove_seam(im, boolmask)
277
+ if mask is not None:
278
+ mask = remove_seam_grayscale(mask, boolmask)
279
+ return im, mask
280
+
281
+
282
+ def seams_insertion(im, num_add, mask=None, vis=False, rot=False):
283
+ seams_record = []
284
+ temp_im = im.copy()
285
+ temp_mask = mask.copy() if mask is not None else None
286
+
287
+ for _ in range(num_add):
288
+ seam_idx, boolmask = get_minimum_seam(temp_im, temp_mask)
289
+ if vis:
290
+ visualize(temp_im, boolmask, rotate=rot)
291
+
292
+ seams_record.append(seam_idx)
293
+ temp_im = remove_seam(temp_im, boolmask)
294
+ if temp_mask is not None:
295
+ temp_mask = remove_seam_grayscale(temp_mask, boolmask)
296
+
297
+ seams_record.reverse()
298
+
299
+ for _ in range(num_add):
300
+ seam = seams_record.pop()
301
+ im = add_seam(im, seam)
302
+ if vis:
303
+ visualize(im, rotate=rot)
304
+ if mask is not None:
305
+ mask = add_seam_grayscale(mask, seam)
306
+
307
+ # update the remaining seam indices
308
+ for remaining_seam in seams_record:
309
+ remaining_seam[np.where(remaining_seam >= seam)] += 2
310
+
311
+ return im, mask
312
+
313
+ ########################################
314
+ # MAIN DRIVER FUNCTIONS
315
+ ########################################
316
+
317
+ def seam_carve(im, dy, dx, mask=None, vis=False):
318
+ im = im.astype(np.float64)
319
+ h, w = im.shape[:2]
320
+ assert h + dy > 0 and w + dx > 0 and dy <= h and dx <= w
321
+
322
+ if mask is not None:
323
+ mask = mask.astype(np.float64)
324
+
325
+ output = im
326
+
327
+ if dx < 0:
328
+ output, mask = seams_removal(output, -dx, mask, vis)
329
+
330
+ elif dx > 0:
331
+ output, mask = seams_insertion(output, dx, mask, vis)
332
+
333
+ if dy < 0:
334
+ output = rotate_image(output, True)
335
+ if mask is not None:
336
+ mask = rotate_image(mask, True)
337
+ output, mask = seams_removal(output, -dy, mask, vis, rot=True)
338
+ output = rotate_image(output, False)
339
+
340
+ elif dy > 0:
341
+ output = rotate_image(output, True)
342
+ if mask is not None:
343
+ mask = rotate_image(mask, True)
344
+ output, mask = seams_insertion(output, dy, mask, vis, rot=True)
345
+ output = rotate_image(output, False)
346
+
347
+ return output
348
+
349
+
350
+ def object_removal(im, rmask, mask=None, vis=False, horizontal_removal=False):
351
+ im = im.astype(np.float64)
352
+ rmask = rmask.astype(np.float64)
353
+ if mask is not None:
354
+ mask = mask.astype(np.float64)
355
+ output = im
356
+
357
+ h, w = im.shape[:2]
358
+
359
+ if horizontal_removal:
360
+ output = rotate_image(output, True)
361
+ rmask = rotate_image(rmask, True)
362
+ if mask is not None:
363
+ mask = rotate_image(mask, True)
364
+
365
+ while len(np.where(rmask > MASK_THRESHOLD)[0]) > 0:
366
+ seam_idx, boolmask = get_minimum_seam(output, mask, rmask)
367
+ if vis:
368
+ visualize(output, boolmask, rotate=horizontal_removal)
369
+ output = remove_seam(output, boolmask)
370
+ rmask = remove_seam_grayscale(rmask, boolmask)
371
+ if mask is not None:
372
+ mask = remove_seam_grayscale(mask, boolmask)
373
+
374
+ num_add = (h if horizontal_removal else w) - output.shape[1]
375
+ output, mask = seams_insertion(output, num_add, mask, vis, rot=horizontal_removal)
376
+ if horizontal_removal:
377
+ output = rotate_image(output, False)
378
+
379
+ return output
380
+
381
+
382
+
383
+ def s_image(im,mask,vs,hs,mode="resize"):
384
+ im = cv2.cvtColor(im, cv2.COLOR_RGBA2RGB)
385
+ mask = 255-mask[:,:,3]
386
+ h, w = im.shape[:2]
387
+ if SHOULD_DOWNSIZE and w > DOWNSIZE_WIDTH:
388
+ im = resize(im, width=DOWNSIZE_WIDTH)
389
+ if mask is not None:
390
+ mask = resize(mask, width=DOWNSIZE_WIDTH)
391
+
392
+ # image resize mode
393
+ if mode=="resize":
394
+ dy = hs#reverse
395
+ dx = vs#reverse
396
+ assert dy is not None and dx is not None
397
+ output = seam_carve(im, dy, dx, mask, False)
398
+
399
+
400
+ # object removal mode
401
+ elif mode=="remove":
402
+ assert mask is not None
403
+ output = object_removal(im, mask, None, False, True)
404
+
405
+ return output
406
+
407
+
408
+ ##### Inpainting helper code
409
+
410
+ def run(image, mask):
411
+ """
412
+ image: [C, H, W]
413
+ mask: [1, H, W]
414
+ return: BGR IMAGE
415
+ """
416
+ origin_height, origin_width = image.shape[1:]
417
+ image = pad_img_to_modulo(image, mod=8)
418
+ mask = pad_img_to_modulo(mask, mod=8)
419
+
420
+ mask = (mask > 0) * 1
421
+ image = torch.from_numpy(image).unsqueeze(0).to(device)
422
+ mask = torch.from_numpy(mask).unsqueeze(0).to(device)
423
+
424
+ start = time.time()
425
+ with torch.no_grad():
426
+ inpainted_image = model(image, mask)
427
+
428
+ print(f"process time: {(time.time() - start)*1000}ms")
429
+ cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
430
+ cur_res = cur_res[0:origin_height, 0:origin_width, :]
431
+ cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
432
+ cur_res = cv2.cvtColor(cur_res, cv2.COLOR_BGR2RGB)
433
+ return cur_res
434
+
435
+
436
+ def get_args_parser():
437
+ parser = argparse.ArgumentParser()
438
+ parser.add_argument("--port", default=8080, type=int)
439
+ parser.add_argument("--device", default="cuda", type=str)
440
+ parser.add_argument("--debug", action="store_true")
441
+ return parser.parse_args()
442
+
443
+
444
+ def process_inpaint(image, mask):
445
+ image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
446
+ original_shape = image.shape
447
+ interpolation = cv2.INTER_CUBIC
448
+
449
+ #size_limit: Union[int, str] = request.form.get("sizeLimit", "1080")
450
+ #if size_limit == "Original":
451
+ size_limit = max(image.shape)
452
+ #else:
453
+ # size_limit = int(size_limit)
454
+
455
+ print(f"Origin image shape: {original_shape}")
456
+ image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
457
+ print(f"Resized image shape: {image.shape}")
458
+ image = norm_img(image)
459
+
460
+ mask = 255-mask[:,:,3]
461
+ mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
462
+ mask = norm_img(mask)
463
+
464
+ res_np_img = run(image, mask)
465
+
466
+ return cv2.cvtColor(res_np_img, cv2.COLOR_BGR2RGB)
src/helper.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ from urllib.parse import urlparse
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+ from torch.hub import download_url_to_file, get_dir
9
+
10
+ LAMA_MODEL_URL = os.environ.get(
11
+ "LAMA_MODEL_URL",
12
+ "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
13
+ )
14
+
15
+
16
+ def download_model(url=LAMA_MODEL_URL):
17
+ parts = urlparse(url)
18
+ hub_dir = get_dir()
19
+ model_dir = os.path.join(hub_dir, "checkpoints")
20
+ if not os.path.isdir(model_dir):
21
+ os.makedirs(os.path.join(model_dir, "hub", "checkpoints"))
22
+ filename = os.path.basename(parts.path)
23
+ cached_file = os.path.join(model_dir, filename)
24
+ if not os.path.exists(cached_file):
25
+ sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
26
+ hash_prefix = None
27
+ download_url_to_file(url, cached_file, hash_prefix, progress=True)
28
+ return cached_file
29
+
30
+
31
+ def ceil_modulo(x, mod):
32
+ if x % mod == 0:
33
+ return x
34
+ return (x // mod + 1) * mod
35
+
36
+
37
+ def numpy_to_bytes(image_numpy: np.ndarray) -> bytes:
38
+ data = cv2.imencode(".jpg", image_numpy)[1]
39
+ image_bytes = data.tobytes()
40
+ return image_bytes
41
+
42
+
43
+ def load_img(img_bytes, gray: bool = False):
44
+ nparr = np.frombuffer(img_bytes, np.uint8)
45
+ if gray:
46
+ np_img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE)
47
+ else:
48
+ np_img = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED)
49
+ if len(np_img.shape) == 3 and np_img.shape[2] == 4:
50
+ np_img = cv2.cvtColor(np_img, cv2.COLOR_BGRA2RGB)
51
+ else:
52
+ np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)
53
+
54
+ return np_img
55
+
56
+
57
+ def norm_img(np_img):
58
+ if len(np_img.shape) == 2:
59
+ np_img = np_img[:, :, np.newaxis]
60
+ np_img = np.transpose(np_img, (2, 0, 1))
61
+ np_img = np_img.astype("float32") / 255
62
+ return np_img
63
+
64
+
65
+ def resize_max_size(
66
+ np_img, size_limit: int, interpolation=cv2.INTER_CUBIC
67
+ ) -> np.ndarray:
68
+ # Resize image's longer size to size_limit if longer size larger than size_limit
69
+ h, w = np_img.shape[:2]
70
+ if max(h, w) > size_limit:
71
+ ratio = size_limit / max(h, w)
72
+ new_w = int(w * ratio + 0.5)
73
+ new_h = int(h * ratio + 0.5)
74
+ return cv2.resize(np_img, dsize=(new_w, new_h), interpolation=interpolation)
75
+ else:
76
+ return np_img
77
+
78
+
79
+ def pad_img_to_modulo(img, mod):
80
+ channels, height, width = img.shape
81
+ out_height = ceil_modulo(height, mod)
82
+ out_width = ceil_modulo(width, mod)
83
+ return np.pad(
84
+ img,
85
+ ((0, 0), (0, out_height - height), (0, out_width - width)),
86
+ mode="symmetric",
87
+ )
src/st_style.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ button_style = """
2
+ <style>
3
+ div.stButton > button:first-child {
4
+ background-color: rgb(255, 75, 75);
5
+ color: rgb(255, 255, 255);
6
+ }
7
+ div.stButton > button:hover {
8
+ background-color: rgb(255, 75, 75);
9
+ color: rgb(255, 255, 255);
10
+ }
11
+ div.stButton > button:active {
12
+ background-color: rgb(255, 75, 75);
13
+ color: rgb(255, 255, 255);
14
+ }
15
+ div.stButton > button:focus {
16
+ background-color: rgb(255, 75, 75);
17
+ color: rgb(255, 255, 255);
18
+ }
19
+ .css-1cpxqw2:focus:not(:active) {
20
+ background-color: rgb(255, 75, 75);
21
+ border-color: rgb(255, 75, 75);
22
+ color: rgb(255, 255, 255);
23
+ }
24
+ """
25
+
26
+ style = """
27
+ <style>
28
+ #MainMenu {
29
+ visibility: hidden;
30
+ }
31
+ footer {
32
+ visibility: hidden;
33
+ }
34
+ header {
35
+ visibility: hidden;
36
+ }
37
+ </style>
38
+ """
39
+
40
+
41
+ def apply_prod_style(st):
42
+ return st.markdown(style, unsafe_allow_html=True)