Update app.py
Browse files
app.py
CHANGED
@@ -1,41 +1,16 @@
|
|
1 |
import os
|
2 |
-
|
3 |
import cv2
|
4 |
-
import gradio as gr
|
5 |
import torch
|
|
|
6 |
from basicsr.archs.srvgg_arch import SRVGGNetCompact
|
7 |
from gfpgan.utils import GFPGANer
|
8 |
from realesrgan.utils import RealESRGANer
|
|
|
|
|
9 |
|
10 |
-
|
11 |
-
# download weights
|
12 |
-
if not os.path.exists('realesr-general-x4v3.pth'):
|
13 |
-
os.system("wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P .")
|
14 |
-
if not os.path.exists('GFPGANv1.2.pth'):
|
15 |
-
os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth -P .")
|
16 |
-
if not os.path.exists('GFPGANv1.3.pth'):
|
17 |
-
os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth -P .")
|
18 |
-
if not os.path.exists('GFPGANv1.4.pth'):
|
19 |
-
os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P .")
|
20 |
-
if not os.path.exists('RestoreFormer.pth'):
|
21 |
-
os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth -P .")
|
22 |
-
if not os.path.exists('CodeFormer.pth'):
|
23 |
-
os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth -P .")
|
24 |
-
|
25 |
-
torch.hub.download_url_to_file(
|
26 |
-
'https://thumbs.dreamstime.com/b/tower-bridge-traditional-red-bus-black-white-colors-view-to-tower-bridge-london-black-white-colors-108478942.jpg',
|
27 |
-
'a1.jpg')
|
28 |
-
torch.hub.download_url_to_file(
|
29 |
-
'https://media.istockphoto.com/id/523514029/photo/london-skyline-b-w.jpg?s=612x612&w=0&k=20&c=kJS1BAtfqYeUDaORupj0sBPc1hpzJhBUUqEFfRnHzZ0=',
|
30 |
-
'a2.jpg')
|
31 |
-
torch.hub.download_url_to_file(
|
32 |
-
'https://i.guim.co.uk/img/media/06f614065ed82ca0e917b149a32493c791619854/0_0_3648_2789/master/3648.jpg?width=700&quality=85&auto=format&fit=max&s=05764b507c18a38590090d987c8b6202',
|
33 |
-
'a3.jpg')
|
34 |
-
torch.hub.download_url_to_file(
|
35 |
-
'https://i.pinimg.com/736x/46/96/9e/46969eb94aec2437323464804d27706d--victorian-london-victorian-era.jpg',
|
36 |
-
'a4.jpg')
|
37 |
|
38 |
-
#
|
39 |
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
|
40 |
model_path = 'realesr-general-x4v3.pth'
|
41 |
half = True if torch.cuda.is_available() else False
|
@@ -43,17 +18,30 @@ upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, ti
|
|
43 |
|
44 |
os.makedirs('output', exist_ok=True)
|
45 |
|
46 |
-
|
47 |
-
|
48 |
-
def inference(img, version, scale):
|
49 |
-
# weight /= 100
|
50 |
-
print(img, version, scale)
|
51 |
try:
|
52 |
-
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
if len(img.shape) == 3 and img.shape[2] == 4:
|
55 |
img_mode = 'RGBA'
|
56 |
-
elif len(img.shape) == 2:
|
57 |
img_mode = None
|
58 |
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
59 |
else:
|
@@ -63,80 +51,172 @@ def inference(img, version, scale):
|
|
63 |
if h < 300:
|
64 |
img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
|
65 |
|
|
|
66 |
if version == 'v1.2':
|
67 |
face_enhancer = GFPGANer(
|
68 |
-
|
69 |
elif version == 'v1.3':
|
70 |
face_enhancer = GFPGANer(
|
71 |
-
|
72 |
elif version == 'v1.4':
|
73 |
face_enhancer = GFPGANer(
|
74 |
-
|
75 |
elif version == 'RestoreFormer':
|
76 |
face_enhancer = GFPGANer(
|
77 |
-
|
78 |
elif version == 'CodeFormer':
|
79 |
-
|
80 |
-
|
81 |
elif version == 'RealESR-General-x4v3':
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
try:
|
86 |
-
# _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True, weight=weight)
|
87 |
-
_, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
|
88 |
-
except RuntimeError as error:
|
89 |
-
print('Error', error)
|
90 |
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
else:
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
return
|
107 |
-
|
108 |
-
|
109 |
-
return
|
110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
# gr.inputs.Radio(['v1.2', 'v1.3', 'v1.4', 'RestoreFormer', 'CodeFormer'], type="value", default='v1.4', label='version'),
|
127 |
-
gr.inputs.Radio(['v1.2', 'v1.3', 'v1.4', 'RestoreFormer','CodeFormer','RealESR-General-x4v3'], type="value", default='v1.4', label='version'),
|
128 |
-
gr.inputs.Number(label="Rescaling factor", default=2),
|
129 |
-
# gr.Slider(0, 100, label='Weight, only for CodeFormer. 0 for better quality, 100 for better identity', default=50)
|
130 |
-
], [
|
131 |
-
gr.outputs.Image(type="numpy", label="Output (The whole image)"),
|
132 |
-
gr.outputs.File(label="Download the output image")
|
133 |
-
],
|
134 |
-
title=title,
|
135 |
-
description=description,
|
136 |
-
article=article,
|
137 |
-
# examples=[['AI-generate.jpg', 'v1.4', 2, 50], ['lincoln.jpg', 'v1.4', 2, 50], ['Blake_Lively.jpg', 'v1.4', 2, 50],
|
138 |
-
# ['10045.png', 'v1.4', 2, 50]]).launch()
|
139 |
-
examples=[['a1.jpg', 'v1.4', 2], ['a2.jpg', 'v1.4', 2], ['a3.jpg', 'v1.4', 2],['a4.jpg', 'v1.4', 2]])
|
140 |
|
141 |
-
|
142 |
-
demo.launch()
|
|
|
1 |
import os
|
|
|
2 |
import cv2
|
|
|
3 |
import torch
|
4 |
+
from flask import Flask, request, jsonify, send_file
|
5 |
from basicsr.archs.srvgg_arch import SRVGGNetCompact
|
6 |
from gfpgan.utils import GFPGANer
|
7 |
from realesrgan.utils import RealESRGANer
|
8 |
+
import uuid
|
9 |
+
import tempfile
|
10 |
|
11 |
+
app = Flask(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
+
# モデルの初期化
|
14 |
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
|
15 |
model_path = 'realesr-general-x4v3.pth'
|
16 |
half = True if torch.cuda.is_available() else False
|
|
|
18 |
|
19 |
os.makedirs('output', exist_ok=True)
|
20 |
|
21 |
+
@app.route('/api/restore', methods=['POST'])
|
22 |
+
def restore_image():
|
|
|
|
|
|
|
23 |
try:
|
24 |
+
# リクエストからパラメータを取得
|
25 |
+
if 'file' not in request.files:
|
26 |
+
return jsonify({'error': 'No file uploaded'}), 400
|
27 |
+
|
28 |
+
file = request.files['file']
|
29 |
+
version = request.form.get('version', 'v1.4')
|
30 |
+
scale = float(request.form.get('scale', 2))
|
31 |
+
# weight = float(request.form.get('weight', 50)) / 100 # CodeFormer用のweightパラメータが必要な場合
|
32 |
+
|
33 |
+
# 一時ファイルに保存
|
34 |
+
temp_dir = tempfile.mkdtemp()
|
35 |
+
input_path = os.path.join(temp_dir, file.filename)
|
36 |
+
file.save(input_path)
|
37 |
+
|
38 |
+
# 画像処理
|
39 |
+
extension = os.path.splitext(os.path.basename(str(input_path)))[1]
|
40 |
+
img = cv2.imread(input_path, cv2.IMREAD_UNCHANGED)
|
41 |
+
|
42 |
if len(img.shape) == 3 and img.shape[2] == 4:
|
43 |
img_mode = 'RGBA'
|
44 |
+
elif len(img.shape) == 2:
|
45 |
img_mode = None
|
46 |
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
47 |
else:
|
|
|
51 |
if h < 300:
|
52 |
img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
|
53 |
|
54 |
+
# バージョンに応じてモデルを選択
|
55 |
if version == 'v1.2':
|
56 |
face_enhancer = GFPGANer(
|
57 |
+
model_path='GFPGANv1.2.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
|
58 |
elif version == 'v1.3':
|
59 |
face_enhancer = GFPGANer(
|
60 |
+
model_path='GFPGANv1.3.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
|
61 |
elif version == 'v1.4':
|
62 |
face_enhancer = GFPGANer(
|
63 |
+
model_path='GFPGANv1.4.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
|
64 |
elif version == 'RestoreFormer':
|
65 |
face_enhancer = GFPGANer(
|
66 |
+
model_path='RestoreFormer.pth', upscale=2, arch='RestoreFormer', channel_multiplier=2, bg_upsampler=upsampler)
|
67 |
elif version == 'CodeFormer':
|
68 |
+
face_enhancer = GFPGANer(
|
69 |
+
model_path='CodeFormer.pth', upscale=2, arch='CodeFormer', channel_multiplier=2, bg_upsampler=upsampler)
|
70 |
elif version == 'RealESR-General-x4v3':
|
71 |
+
face_enhancer = GFPGANer(
|
72 |
+
model_path='realesr-general-x4v3.pth', upscale=2, arch='realesr-general', channel_multiplier=2, bg_upsampler=upsampler)
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
|
74 |
+
# 画像を拡張
|
75 |
+
_, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
|
76 |
+
|
77 |
+
# スケール調整
|
78 |
+
if scale != 2:
|
79 |
+
interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
|
80 |
+
h, w = img.shape[0:2]
|
81 |
+
output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
|
82 |
+
|
83 |
+
# 出力ファイルを保存
|
84 |
+
output_filename = f'output_{uuid.uuid4().hex}'
|
85 |
+
if img_mode == 'RGBA':
|
86 |
+
output_path = os.path.join('output', f'{output_filename}.png')
|
87 |
+
cv2.imwrite(output_path, output)
|
88 |
+
mimetype = 'image/png'
|
89 |
else:
|
90 |
+
output_path = os.path.join('output', f'{output_filename}.jpg')
|
91 |
+
cv2.imwrite(output_path, output)
|
92 |
+
mimetype = 'image/jpeg'
|
93 |
+
|
94 |
+
# 結果を返す
|
95 |
+
return send_file(output_path, mimetype=mimetype, as_attachment=True, download_name=os.path.basename(output_path))
|
96 |
+
|
97 |
+
except Exception as e:
|
98 |
+
return jsonify({'error': str(e)}), 500
|
99 |
|
100 |
+
@app.route('/')
|
101 |
+
def index():
|
102 |
+
return """
|
103 |
+
<!DOCTYPE html>
|
104 |
+
<html>
|
105 |
+
<head>
|
106 |
+
<title>Image Upscaling & Restoration API</title>
|
107 |
+
<style>
|
108 |
+
body { font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; }
|
109 |
+
.container { border: 1px solid #ddd; padding: 20px; border-radius: 5px; }
|
110 |
+
.form-group { margin-bottom: 15px; }
|
111 |
+
label { display: block; margin-bottom: 5px; }
|
112 |
+
input, select { width: 100%; padding: 8px; box-sizing: border-box; }
|
113 |
+
button { background-color: #4CAF50; color: white; padding: 10px 15px; border: none; border-radius: 4px; cursor: pointer; }
|
114 |
+
button:hover { background-color: #45a049; }
|
115 |
+
#result { margin-top: 20px; }
|
116 |
+
#preview { max-width: 100%; margin-top: 10px; }
|
117 |
+
</style>
|
118 |
+
</head>
|
119 |
+
<body>
|
120 |
+
<h1>Image Upscaling & Restoration API</h1>
|
121 |
+
<div class="container">
|
122 |
+
<form id="uploadForm" enctype="multipart/form-data">
|
123 |
+
<div class="form-group">
|
124 |
+
<label for="file">Upload Image:</label>
|
125 |
+
<input type="file" id="file" name="file" required>
|
126 |
+
</div>
|
127 |
+
<div class="form-group">
|
128 |
+
<label for="version">Version:</label>
|
129 |
+
<select id="version" name="version">
|
130 |
+
<option value="v1.2">v1.2</option>
|
131 |
+
<option value="v1.3">v1.3</option>
|
132 |
+
<option value="v1.4" selected>v1.4</option>
|
133 |
+
<option value="RestoreFormer">RestoreFormer</option>
|
134 |
+
<option value="CodeFormer">CodeFormer</option>
|
135 |
+
<option value="RealESR-General-x4v3">RealESR-General-x4v3</option>
|
136 |
+
</select>
|
137 |
+
</div>
|
138 |
+
<div class="form-group">
|
139 |
+
<label for="scale">Rescaling factor:</label>
|
140 |
+
<input type="number" id="scale" name="scale" value="2" step="0.1" min="1" max="4" required>
|
141 |
+
</div>
|
142 |
+
<!-- CodeFormer用のweightパラメータが必要な場合 -->
|
143 |
+
<!--
|
144 |
+
<div class="form-group">
|
145 |
+
<label for="weight">Weight (only for CodeFormer):</label>
|
146 |
+
<input type="range" id="weight" name="weight" min="0" max="100" value="50">
|
147 |
+
<span id="weightValue">50</span>
|
148 |
+
</div>
|
149 |
+
-->
|
150 |
+
<button type="submit">Process Image</button>
|
151 |
+
</form>
|
152 |
+
|
153 |
+
<div id="result">
|
154 |
+
<h3>Result:</h3>
|
155 |
+
<div id="outputContainer" style="display: none;">
|
156 |
+
<img id="preview" src="" alt="Processed Image">
|
157 |
+
<a id="downloadLink" href="#" download>Download Image</a>
|
158 |
+
</div>
|
159 |
+
</div>
|
160 |
+
</div>
|
161 |
+
|
162 |
+
<script>
|
163 |
+
document.getElementById('uploadForm').addEventListener('submit', function(e) {
|
164 |
+
e.preventDefault();
|
165 |
+
|
166 |
+
const formData = new FormData();
|
167 |
+
formData.append('file', document.getElementById('file').files[0]);
|
168 |
+
formData.append('version', document.getElementById('version').value);
|
169 |
+
formData.append('scale', document.getElementById('scale').value);
|
170 |
+
// formData.append('weight', document.getElementById('weight').value); // CodeFormer用
|
171 |
+
|
172 |
+
fetch('/api/restore', {
|
173 |
+
method: 'POST',
|
174 |
+
body: formData
|
175 |
+
})
|
176 |
+
.then(response => {
|
177 |
+
if (!response.ok) {
|
178 |
+
return response.json().then(err => { throw new Error(err.error || 'Unknown error'); });
|
179 |
+
}
|
180 |
+
return response.blob();
|
181 |
+
})
|
182 |
+
.then(blob => {
|
183 |
+
const url = URL.createObjectURL(blob);
|
184 |
+
const preview = document.getElementById('preview');
|
185 |
+
const downloadLink = document.getElementById('downloadLink');
|
186 |
+
const outputContainer = document.getElementById('outputContainer');
|
187 |
+
|
188 |
+
preview.src = url;
|
189 |
+
downloadLink.href = url;
|
190 |
+
downloadLink.download = 'restored_' + document.getElementById('file').files[0].name;
|
191 |
+
outputContainer.style.display = 'block';
|
192 |
+
})
|
193 |
+
.catch(error => {
|
194 |
+
alert('Error: ' + error.message);
|
195 |
+
});
|
196 |
+
});
|
197 |
+
|
198 |
+
// CodeFormer用のweightパラメータが必要な場合
|
199 |
+
// document.getElementById('weight').addEventListener('input', function() {
|
200 |
+
// document.getElementById('weightValue').textContent = this.value;
|
201 |
+
// });
|
202 |
+
</script>
|
203 |
+
</body>
|
204 |
+
</html>
|
205 |
+
"""
|
206 |
|
207 |
+
if __name__ == '__main__':
|
208 |
+
# ウェイトファイルをダウンロード(存在しない場合)
|
209 |
+
if not os.path.exists('realesr-general-x4v3.pth'):
|
210 |
+
os.system("wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P .")
|
211 |
+
if not os.path.exists('GFPGANv1.2.pth'):
|
212 |
+
os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth -P .")
|
213 |
+
if not os.path.exists('GFPGANv1.3.pth'):
|
214 |
+
os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth -P .")
|
215 |
+
if not os.path.exists('GFPGANv1.4.pth'):
|
216 |
+
os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P .")
|
217 |
+
if not os.path.exists('RestoreFormer.pth'):
|
218 |
+
os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth -P .")
|
219 |
+
if not os.path.exists('CodeFormer.pth'):
|
220 |
+
os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth -P .")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
221 |
|
222 |
+
app.run(host='0.0.0.0', port=5000, debug=True)
|
|