Update app.py
Browse files
app.py
CHANGED
|
@@ -7,7 +7,19 @@ from gfpgan.utils import GFPGANer
|
|
| 7 |
from realesrgan.utils import RealESRGANer
|
| 8 |
import uuid
|
| 9 |
import tempfile
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
if not os.path.exists('realesr-general-x4v3.pth'):
|
| 12 |
os.system("wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P .")
|
| 13 |
if not os.path.exists('GFPGANv1.2.pth'):
|
|
@@ -31,6 +43,36 @@ upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, ti
|
|
| 31 |
|
| 32 |
os.makedirs('output', exist_ok=True)
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
@app.route('/api/restore', methods=['POST'])
|
| 35 |
def restore_image():
|
| 36 |
try:
|
|
@@ -41,7 +83,7 @@ def restore_image():
|
|
| 41 |
file = request.files['file']
|
| 42 |
version = request.form.get('version', 'v1.4')
|
| 43 |
scale = float(request.form.get('scale', 2))
|
| 44 |
-
|
| 45 |
|
| 46 |
# 一時ファイルに保存
|
| 47 |
temp_dir = tempfile.mkdtemp()
|
|
@@ -49,7 +91,7 @@ def restore_image():
|
|
| 49 |
file.save(input_path)
|
| 50 |
|
| 51 |
# 画像処理
|
| 52 |
-
extension = os.path.splitext(os.path.basename(str(input_path))
|
| 53 |
img = cv2.imread(input_path, cv2.IMREAD_UNCHANGED)
|
| 54 |
|
| 55 |
if len(img.shape) == 3 and img.shape[2] == 4:
|
|
@@ -68,25 +110,26 @@ def restore_image():
|
|
| 68 |
if version == 'v1.2':
|
| 69 |
face_enhancer = GFPGANer(
|
| 70 |
model_path='GFPGANv1.2.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
|
|
|
|
| 71 |
elif version == 'v1.3':
|
| 72 |
face_enhancer = GFPGANer(
|
| 73 |
model_path='GFPGANv1.3.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
|
|
|
|
| 74 |
elif version == 'v1.4':
|
| 75 |
face_enhancer = GFPGANer(
|
| 76 |
model_path='GFPGANv1.4.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
|
|
|
|
| 77 |
elif version == 'RestoreFormer':
|
| 78 |
face_enhancer = GFPGANer(
|
| 79 |
model_path='RestoreFormer.pth', upscale=2, arch='RestoreFormer', channel_multiplier=2, bg_upsampler=upsampler)
|
|
|
|
| 80 |
elif version == 'CodeFormer':
|
| 81 |
-
|
| 82 |
-
model_path='CodeFormer.pth', upscale=2, arch='CodeFormer', channel_multiplier=2, bg_upsampler=upsampler)
|
| 83 |
elif version == 'RealESR-General-x4v3':
|
| 84 |
face_enhancer = GFPGANer(
|
| 85 |
-
model_path='realesr-general-x4v3.pth', upscale=2, arch='realesr-general', channel_multiplier=2, bg_upsampler=upsampler)
|
|
|
|
| 86 |
|
| 87 |
-
# 画像を拡張
|
| 88 |
-
_, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
|
| 89 |
-
|
| 90 |
# スケール調整
|
| 91 |
if scale != 2:
|
| 92 |
interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
|
|
@@ -179,6 +222,10 @@ def index():
|
|
| 179 |
<label for="scale">Rescaling factor:</label>
|
| 180 |
<input type="number" id="scale" name="scale" value="2" step="0.1" min="1" max="4" required>
|
| 181 |
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
<button type="submit" id="submitButton">Process Image</button>
|
| 183 |
</form>
|
| 184 |
<div id="loading" class="loader"></div>
|
|
@@ -198,11 +245,23 @@ def index():
|
|
| 198 |
</div>
|
| 199 |
|
| 200 |
<script>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
// フォームの変更を監視してAPI使用例を更新
|
| 202 |
function updateApiUsage() {
|
| 203 |
const fileInput = document.getElementById('file');
|
| 204 |
const version = document.getElementById('version').value;
|
| 205 |
const scale = document.getElementById('scale').value;
|
|
|
|
| 206 |
|
| 207 |
// 現在のURLからベースURLを取得(パス、パラメータ、ハッシュを含めない)
|
| 208 |
const baseUrl = window.location.origin;
|
|
@@ -216,27 +275,32 @@ def index():
|
|
| 216 |
reader.onload = function(e) {
|
| 217 |
const dataURL = e.target.result;
|
| 218 |
if (dataURL.length > 40) {
|
| 219 |
-
filePreview =
|
| 220 |
} else {
|
| 221 |
-
filePreview =
|
| 222 |
}
|
| 223 |
-
updateFetchCode(apiUrl, version, scale, filePreview);
|
| 224 |
};
|
| 225 |
reader.readAsDataURL(file);
|
| 226 |
} else {
|
| 227 |
-
updateFetchCode(apiUrl, version, scale, filePreview);
|
| 228 |
}
|
| 229 |
}
|
| 230 |
|
| 231 |
-
function updateFetchCode(apiUrl, version, scale, filePreview) {
|
| 232 |
const fetchCodeDiv = document.getElementById('fetchCode');
|
| 233 |
-
|
| 234 |
-
// JavaScript fetch example:
|
| 235 |
const formData = new FormData();
|
| 236 |
formData.append('file', ${filePreview});
|
| 237 |
formData.append('version', '${version}');
|
| 238 |
-
formData.append('scale', ${scale})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
|
|
|
|
| 240 |
fetch('${apiUrl}', {
|
| 241 |
method: 'POST',
|
| 242 |
body: formData
|
|
@@ -256,12 +320,15 @@ fetch('${apiUrl}', {
|
|
| 256 |
.catch(error => {
|
| 257 |
console.error('Error:', error.message);
|
| 258 |
});`;
|
|
|
|
|
|
|
| 259 |
}
|
| 260 |
|
| 261 |
// フォーム要素の変更を監視
|
| 262 |
document.getElementById('file').addEventListener('change', updateApiUsage);
|
| 263 |
document.getElementById('version').addEventListener('change', updateApiUsage);
|
| 264 |
document.getElementById('scale').addEventListener('input', updateApiUsage);
|
|
|
|
| 265 |
|
| 266 |
// 初期表示
|
| 267 |
updateApiUsage();
|
|
@@ -281,6 +348,11 @@ fetch('${apiUrl}', {
|
|
| 281 |
formData.append('version', document.getElementById('version').value);
|
| 282 |
formData.append('scale', document.getElementById('scale').value);
|
| 283 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
// 現在のURLからベースURLを取得(パス、パラメータ、ハッシュを含めない)
|
| 285 |
const baseUrl = window.location.origin;
|
| 286 |
const apiUrl = baseUrl + '/api/restore';
|
|
@@ -321,5 +393,4 @@ fetch('${apiUrl}', {
|
|
| 321 |
"""
|
| 322 |
|
| 323 |
if __name__ == '__main__':
|
| 324 |
-
|
| 325 |
app.run(host='0.0.0.0', port=7860, debug=True)
|
|
|
|
| 7 |
from realesrgan.utils import RealESRGANer
|
| 8 |
import uuid
|
| 9 |
import tempfile
|
| 10 |
+
from torchvision.transforms.functional import normalize
|
| 11 |
+
from torchvision import transforms
|
| 12 |
+
from PIL import Image
|
| 13 |
+
from basicsr.utils import img2tensor, tensor2img
|
| 14 |
+
from facelib.utils.face_restoration_helper import FaceRestoreHelper
|
| 15 |
+
from codeformer.archs.codeformer_arch import CodeFormer
|
| 16 |
+
|
| 17 |
+
# 依存関係のインストール
|
| 18 |
+
os.system("git clone https://github.com/sczhou/CodeFormer.git")
|
| 19 |
+
os.system("cd CodeFormer && pip install -r requirements.txt")
|
| 20 |
+
os.system("cd CodeFormer && python basicsr/setup.py develop")
|
| 21 |
+
|
| 22 |
+
# ウェイトファイルをダウンロード(毎回消えるので毎回必ず実行。)
|
| 23 |
if not os.path.exists('realesr-general-x4v3.pth'):
|
| 24 |
os.system("wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P .")
|
| 25 |
if not os.path.exists('GFPGANv1.2.pth'):
|
|
|
|
| 43 |
|
| 44 |
os.makedirs('output', exist_ok=True)
|
| 45 |
|
| 46 |
+
def restore_with_codeformer(img, scale=2, weight=0.5):
|
| 47 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 48 |
+
net = CodeFormer(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(device)
|
| 49 |
+
net.load_state_dict(torch.load('CodeFormer.pth')['params_ema'])
|
| 50 |
+
net.eval()
|
| 51 |
+
|
| 52 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 53 |
+
img = Image.fromarray(img)
|
| 54 |
+
|
| 55 |
+
face_helper = FaceRestoreHelper(
|
| 56 |
+
upscale_factor=scale, face_size=512, crop_ratio=(1, 1), use_parse=True,
|
| 57 |
+
device=device)
|
| 58 |
+
|
| 59 |
+
face_helper.clean_all()
|
| 60 |
+
face_helper.read_image(img)
|
| 61 |
+
face_helper.get_face_landmarks_5(only_center_face=False, resize=640)
|
| 62 |
+
face_helper.align_warp_face()
|
| 63 |
+
|
| 64 |
+
for idx, cropped_face in enumerate(face_helper.cropped_faces):
|
| 65 |
+
cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=False, float32=True)
|
| 66 |
+
normalize(cropped_face_t, [0.5], [0.5], inplace=True)
|
| 67 |
+
cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
|
| 68 |
+
with torch.no_grad():
|
| 69 |
+
output = net(cropped_face_t, w=weight, adain=True)[0]
|
| 70 |
+
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
|
| 71 |
+
face_helper.add_restored_face(restored_face)
|
| 72 |
+
|
| 73 |
+
restored_img = face_helper.paste_faces_to_input_image()
|
| 74 |
+
return cv2.cvtColor(restored_img, cv2.COLOR_RGB2BGR)
|
| 75 |
+
|
| 76 |
@app.route('/api/restore', methods=['POST'])
|
| 77 |
def restore_image():
|
| 78 |
try:
|
|
|
|
| 83 |
file = request.files['file']
|
| 84 |
version = request.form.get('version', 'v1.4')
|
| 85 |
scale = float(request.form.get('scale', 2))
|
| 86 |
+
weight = float(request.form.get('weight', 0.5)) # CodeFormer用のweightパラメータ
|
| 87 |
|
| 88 |
# 一時ファイルに保存
|
| 89 |
temp_dir = tempfile.mkdtemp()
|
|
|
|
| 91 |
file.save(input_path)
|
| 92 |
|
| 93 |
# 画像処理
|
| 94 |
+
extension = os.path.splitext(os.path.basename(str(input_path))[1]
|
| 95 |
img = cv2.imread(input_path, cv2.IMREAD_UNCHANGED)
|
| 96 |
|
| 97 |
if len(img.shape) == 3 and img.shape[2] == 4:
|
|
|
|
| 110 |
if version == 'v1.2':
|
| 111 |
face_enhancer = GFPGANer(
|
| 112 |
model_path='GFPGANv1.2.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
|
| 113 |
+
_, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
|
| 114 |
elif version == 'v1.3':
|
| 115 |
face_enhancer = GFPGANer(
|
| 116 |
model_path='GFPGANv1.3.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
|
| 117 |
+
_, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
|
| 118 |
elif version == 'v1.4':
|
| 119 |
face_enhancer = GFPGANer(
|
| 120 |
model_path='GFPGANv1.4.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
|
| 121 |
+
_, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
|
| 122 |
elif version == 'RestoreFormer':
|
| 123 |
face_enhancer = GFPGANer(
|
| 124 |
model_path='RestoreFormer.pth', upscale=2, arch='RestoreFormer', channel_multiplier=2, bg_upsampler=upsampler)
|
| 125 |
+
_, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
|
| 126 |
elif version == 'CodeFormer':
|
| 127 |
+
output = restore_with_codeformer(img, scale=scale, weight=weight)
|
|
|
|
| 128 |
elif version == 'RealESR-General-x4v3':
|
| 129 |
face_enhancer = GFPGANer(
|
| 130 |
+
model_path='realesr-general-x4v3.pth', upscale=2, arch='realesr-general', channel_multiplier=2, bg_upsampler=upsampler, map_location=torch.device('cpu'))
|
| 131 |
+
_, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
|
| 132 |
|
|
|
|
|
|
|
|
|
|
| 133 |
# スケール調整
|
| 134 |
if scale != 2:
|
| 135 |
interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
|
|
|
|
| 222 |
<label for="scale">Rescaling factor:</label>
|
| 223 |
<input type="number" id="scale" name="scale" value="2" step="0.1" min="1" max="4" required>
|
| 224 |
</div>
|
| 225 |
+
<div class="form-group" id="weightGroup" style="display: none;">
|
| 226 |
+
<label for="weight">CodeFormer Weight (0-1):</label>
|
| 227 |
+
<input type="number" id="weight" name="weight" value="0.5" step="0.1" min="0" max="1">
|
| 228 |
+
</div>
|
| 229 |
<button type="submit" id="submitButton">Process Image</button>
|
| 230 |
</form>
|
| 231 |
<div id="loading" class="loader"></div>
|
|
|
|
| 245 |
</div>
|
| 246 |
|
| 247 |
<script>
|
| 248 |
+
// CodeFormerが選択された時にweightパラメータを表示
|
| 249 |
+
document.getElementById('version').addEventListener('change', function() {
|
| 250 |
+
const weightGroup = document.getElementById('weightGroup');
|
| 251 |
+
if (this.value === 'CodeFormer') {
|
| 252 |
+
weightGroup.style.display = 'block';
|
| 253 |
+
} else {
|
| 254 |
+
weightGroup.style.display = 'none';
|
| 255 |
+
}
|
| 256 |
+
updateApiUsage();
|
| 257 |
+
});
|
| 258 |
+
|
| 259 |
// フォームの変更を監視してAPI使用例を更新
|
| 260 |
function updateApiUsage() {
|
| 261 |
const fileInput = document.getElementById('file');
|
| 262 |
const version = document.getElementById('version').value;
|
| 263 |
const scale = document.getElementById('scale').value;
|
| 264 |
+
const weight = document.getElementById('weight').value;
|
| 265 |
|
| 266 |
// 現在のURLからベースURLを取得(パス、パラメータ、ハッシュを含めない)
|
| 267 |
const baseUrl = window.location.origin;
|
|
|
|
| 275 |
reader.onload = function(e) {
|
| 276 |
const dataURL = e.target.result;
|
| 277 |
if (dataURL.length > 40) {
|
| 278 |
+
filePreview = "${dataURL.substring(0, 20)}...${dataURL.substring(dataURL.length - 20)}";
|
| 279 |
} else {
|
| 280 |
+
filePreview = "${dataURL}";
|
| 281 |
}
|
| 282 |
+
updateFetchCode(apiUrl, version, scale, weight, filePreview);
|
| 283 |
};
|
| 284 |
reader.readAsDataURL(file);
|
| 285 |
} else {
|
| 286 |
+
updateFetchCode(apiUrl, version, scale, weight, filePreview);
|
| 287 |
}
|
| 288 |
}
|
| 289 |
|
| 290 |
+
function updateFetchCode(apiUrl, version, scale, weight, filePreview) {
|
| 291 |
const fetchCodeDiv = document.getElementById('fetchCode');
|
| 292 |
+
let code = `// JavaScript fetch example:
|
|
|
|
| 293 |
const formData = new FormData();
|
| 294 |
formData.append('file', ${filePreview});
|
| 295 |
formData.append('version', '${version}');
|
| 296 |
+
formData.append('scale', ${scale});`;
|
| 297 |
+
|
| 298 |
+
if (version === 'CodeFormer') {
|
| 299 |
+
code += `
|
| 300 |
+
formData.append('weight', ${weight});`;
|
| 301 |
+
}
|
| 302 |
|
| 303 |
+
code += `
|
| 304 |
fetch('${apiUrl}', {
|
| 305 |
method: 'POST',
|
| 306 |
body: formData
|
|
|
|
| 320 |
.catch(error => {
|
| 321 |
console.error('Error:', error.message);
|
| 322 |
});`;
|
| 323 |
+
|
| 324 |
+
fetchCodeDiv.innerHTML = code;
|
| 325 |
}
|
| 326 |
|
| 327 |
// フォーム要素の変更を監視
|
| 328 |
document.getElementById('file').addEventListener('change', updateApiUsage);
|
| 329 |
document.getElementById('version').addEventListener('change', updateApiUsage);
|
| 330 |
document.getElementById('scale').addEventListener('input', updateApiUsage);
|
| 331 |
+
document.getElementById('weight').addEventListener('input', updateApiUsage);
|
| 332 |
|
| 333 |
// 初期表示
|
| 334 |
updateApiUsage();
|
|
|
|
| 348 |
formData.append('version', document.getElementById('version').value);
|
| 349 |
formData.append('scale', document.getElementById('scale').value);
|
| 350 |
|
| 351 |
+
// CodeFormerが選択されている場合はweightも追加
|
| 352 |
+
if (document.getElementById('version').value === 'CodeFormer') {
|
| 353 |
+
formData.append('weight', document.getElementById('weight').value);
|
| 354 |
+
}
|
| 355 |
+
|
| 356 |
// 現在のURLからベースURLを取得(パス、パラメータ、ハッシュを含めない)
|
| 357 |
const baseUrl = window.location.origin;
|
| 358 |
const apiUrl = baseUrl + '/api/restore';
|
|
|
|
| 393 |
"""
|
| 394 |
|
| 395 |
if __name__ == '__main__':
|
|
|
|
| 396 |
app.run(host='0.0.0.0', port=7860, debug=True)
|