soiz1 commited on
Commit
914381e
·
verified ·
1 Parent(s): 8f711a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -109
app.py CHANGED
@@ -1,101 +1,45 @@
1
- import sys
2
  import os
3
-
4
- # 依存関係のインストール
5
- os.system("git clone https://github.com/sczhou/CodeFormer.git")
6
- os.system("cd CodeFormer && pip install -r requirements.txt")
7
- os.system("cd CodeFormer && python basicsr/setup.py develop")
8
- sys.path.append(os.path.abspath('CodeFormer'))
9
- sys.path.append(os.path.abspath('CodeFormer/CodeFormer'))
10
- # ウェイトファイルをダウンロード(毎回消えるので毎回必ず実行。)
11
- if not os.path.exists('realesr-general-x4v3.pth'):
12
- os.system("wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P .")
13
- if not os.path.exists('GFPGANv1.2.pth'):
14
- os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth -P .")
15
- if not os.path.exists('GFPGANv1.3.pth'):
16
- os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth -P .")
17
- if not os.path.exists('GFPGANv1.4.pth'):
18
- os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P .")
19
- if not os.path.exists('RestoreFormer.pth'):
20
- os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth -P .")
21
- if not os.path.exists('CodeFormer.pth'):
22
- os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth -P .")
23
-
24
  import cv2
25
  import torch
26
- from flask import Flask, request, jsonify, send_file
27
  from basicsr.archs.srvgg_arch import SRVGGNetCompact
28
  from gfpgan.utils import GFPGANer
29
  from realesrgan.utils import RealESRGANer
30
- import uuid
31
  import tempfile
32
- from torchvision.transforms.functional import normalize
33
- from torchvision import transforms
34
- from PIL import Image
35
- from basicsr.utils import img2tensor, tensor2img
36
- from facexlib.utils.face_restoration_helper import FaceRestoreHelper
37
- from codeformer.archs.codeformer_arch import CodeFormer
38
 
39
  app = Flask(__name__)
40
 
41
- # モデルの初期化
42
  model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
43
  model_path = 'realesr-general-x4v3.pth'
44
  half = True if torch.cuda.is_available() else False
45
  upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
46
 
 
47
  os.makedirs('output', exist_ok=True)
48
 
49
- def restore_with_codeformer(img, scale=2, weight=0.5):
50
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
51
- net = CodeFormer(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(device)
52
- net.load_state_dict(torch.load('CodeFormer.pth')['params_ema'])
53
- net.eval()
54
-
55
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
56
- img = Image.fromarray(img)
57
-
58
- face_helper = FaceRestoreHelper(
59
- upscale_factor=scale, face_size=512, crop_ratio=(1, 1), use_parse=True,
60
- device=device)
61
 
62
- face_helper.clean_all()
63
- face_helper.read_image(img)
64
- face_helper.get_face_landmarks_5(only_center_face=False, resize=640)
65
- face_helper.align_warp_face()
66
 
67
- for idx, cropped_face in enumerate(face_helper.cropped_faces):
68
- cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=False, float32=True)
69
- normalize(cropped_face_t, [0.5], [0.5], inplace=True)
70
- cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
71
- with torch.no_grad():
72
- output = net(cropped_face_t, w=weight, adain=True)[0]
73
- restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
74
- face_helper.add_restored_face(restored_face)
75
-
76
- restored_img = face_helper.paste_faces_to_input_image()
77
- return cv2.cvtColor(restored_img, cv2.COLOR_RGB2BGR)
78
 
79
- @app.route('/api/restore', methods=['POST'])
80
- def restore_image():
81
  try:
82
- # リクエストからパラメータを取得
83
- if 'file' not in request.files:
84
- return jsonify({'error': 'No file uploaded'}), 400
85
-
86
- file = request.files['file']
87
- version = request.form.get('version', 'v1.4')
88
- scale = float(request.form.get('scale', 2))
89
- weight = float(request.form.get('weight', 0.5)) # CodeFormer用のweightパラメータ
90
-
91
- # 一時ファイルに保存
92
- temp_dir = tempfile.mkdtemp()
93
- input_path = os.path.join(temp_dir, file.filename)
94
- file.save(input_path)
95
-
96
- # 画像処理
97
- extension = os.path.splitext(os.path.basename(str(input_path)))[1]
98
- img = cv2.imread(input_path, cv2.IMREAD_UNCHANGED)
99
 
100
  if len(img.shape) == 3 and img.shape[2] == 4:
101
  img_mode = 'RGBA'
@@ -109,56 +53,59 @@ def restore_image():
109
  if h < 300:
110
  img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
111
 
112
- # バージョンに応じてモデルを選択
113
  if version == 'v1.2':
114
  face_enhancer = GFPGANer(
115
  model_path='GFPGANv1.2.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
116
- _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
117
  elif version == 'v1.3':
118
  face_enhancer = GFPGANer(
119
  model_path='GFPGANv1.3.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
120
- _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
121
  elif version == 'v1.4':
122
  face_enhancer = GFPGANer(
123
  model_path='GFPGANv1.4.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
124
- _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
125
  elif version == 'RestoreFormer':
126
  face_enhancer = GFPGANer(
127
  model_path='RestoreFormer.pth', upscale=2, arch='RestoreFormer', channel_multiplier=2, bg_upsampler=upsampler)
128
- _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
129
  elif version == 'CodeFormer':
130
- output = restore_with_codeformer(img, scale=scale, weight=weight)
 
131
  elif version == 'RealESR-General-x4v3':
132
  face_enhancer = GFPGANer(
133
- model_path='realesr-general-x4v3.pth', upscale=2, arch='realesr-general', channel_multiplier=2, bg_upsampler=upsampler, map_location=torch.device('cpu'))
134
- _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
 
 
 
 
 
 
 
 
135
 
136
- # スケール調整
137
- if scale != 2:
138
- interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
139
- h, w = img.shape[0:2]
140
- output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
 
 
 
 
 
 
141
 
142
- # 出力ファイルを保存
143
- output_filename = f'output_{uuid.uuid4().hex}'
144
  if img_mode == 'RGBA':
145
- output_path = os.path.join('output', f'{output_filename}.png')
146
- cv2.imwrite(output_path, output)
147
- mimetype = 'image/png'
148
  else:
149
- output_path = os.path.join('output', f'{output_filename}.jpg')
150
- cv2.imwrite(output_path, output)
151
- mimetype = 'image/jpeg'
152
-
153
- # 結果を返す
154
- return send_file(output_path, mimetype=mimetype, as_attachment=True, download_name=os.path.basename(output_path))
155
-
156
- except Exception as e:
157
- return jsonify({'error': str(e)}), 500
158
 
159
  @app.route('/')
160
  def index():
161
- return """
162
  <!DOCTYPE html>
163
  <html>
164
  <head>
@@ -278,9 +225,9 @@ def index():
278
  reader.onload = function(e) {
279
  const dataURL = e.target.result;
280
  if (dataURL.length > 40) {
281
- filePreview = "${dataURL.substring(0, 20)}...${dataURL.substring(dataURL.length - 20)}";
282
  } else {
283
- filePreview = "${dataURL}";
284
  }
285
  updateFetchCode(apiUrl, version, scale, weight, filePreview);
286
  };
@@ -393,7 +340,42 @@ fetch('${apiUrl}', {
393
  </script>
394
  </body>
395
  </html>
396
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
397
 
398
  if __name__ == '__main__':
399
- app.run(host='0.0.0.0', port=7860, debug=True)
 
 
1
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import cv2
3
  import torch
4
+ from flask import Flask, request, jsonify, send_file, render_template_string
5
  from basicsr.archs.srvgg_arch import SRVGGNetCompact
6
  from gfpgan.utils import GFPGANer
7
  from realesrgan.utils import RealESRGANer
 
8
  import tempfile
9
+ import uuid
 
 
 
 
 
10
 
11
  app = Flask(__name__)
12
 
13
+ # Initialize models
14
  model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
15
  model_path = 'realesr-general-x4v3.pth'
16
  half = True if torch.cuda.is_available() else False
17
  upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
18
 
19
+ # Ensure output directory exists
20
  os.makedirs('output', exist_ok=True)
21
 
22
+ # Download weights if not exists
23
+ def download_weights():
24
+ weights = {
25
+ 'realesr-general-x4v3.pth': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
26
+ 'GFPGANv1.2.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth',
27
+ 'GFPGANv1.3.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
28
+ 'GFPGANv1.4.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth',
29
+ 'RestoreFormer.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth',
30
+ 'CodeFormer.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth'
31
+ }
 
 
32
 
33
+ for weight_file, url in weights.items():
34
+ if not os.path.exists(weight_file):
35
+ os.system(f"wget {url} -O {weight_file}")
 
36
 
37
+ download_weights()
 
 
 
 
 
 
 
 
 
 
38
 
39
+ def process_image(img_path, version, scale, weight=0.5):
 
40
  try:
41
+ extension = os.path.splitext(os.path.basename(str(img_path)))[1]
42
+ img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  if len(img.shape) == 3 and img.shape[2] == 4:
45
  img_mode = 'RGBA'
 
53
  if h < 300:
54
  img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
55
 
 
56
  if version == 'v1.2':
57
  face_enhancer = GFPGANer(
58
  model_path='GFPGANv1.2.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
 
59
  elif version == 'v1.3':
60
  face_enhancer = GFPGANer(
61
  model_path='GFPGANv1.3.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
 
62
  elif version == 'v1.4':
63
  face_enhancer = GFPGANer(
64
  model_path='GFPGANv1.4.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
 
65
  elif version == 'RestoreFormer':
66
  face_enhancer = GFPGANer(
67
  model_path='RestoreFormer.pth', upscale=2, arch='RestoreFormer', channel_multiplier=2, bg_upsampler=upsampler)
 
68
  elif version == 'CodeFormer':
69
+ face_enhancer = GFPGANer(
70
+ model_path='CodeFormer.pth', upscale=2, arch='CodeFormer', channel_multiplier=2, bg_upsampler=upsampler)
71
  elif version == 'RealESR-General-x4v3':
72
  face_enhancer = GFPGANer(
73
+ model_path='realesr-general-x4v3.pth', upscale=2, arch='realesr-general', channel_multiplier=2, bg_upsampler=upsampler)
74
+
75
+ try:
76
+ if version == 'CodeFormer':
77
+ _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True, weight=weight)
78
+ else:
79
+ _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
80
+ except RuntimeError as error:
81
+ print('Error', error)
82
+ raise Exception(f"Enhancement error: {str(error)}")
83
 
84
+ try:
85
+ if scale != 2:
86
+ interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
87
+ h, w = img.shape[0:2]
88
+ output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
89
+ except Exception as error:
90
+ print('wrong scale input.', error)
91
+
92
+ # Save to temporary file
93
+ output_filename = f"output_{uuid.uuid4().hex}.jpg"
94
+ output_path = os.path.join('output', output_filename)
95
 
 
 
96
  if img_mode == 'RGBA':
97
+ cv2.imwrite(output_path, output, [int(cv2.IMWRITE_PNG_COMPRESSION), 9])
 
 
98
  else:
99
+ cv2.imwrite(output_path, output, [int(cv2.IMWRITE_JPEG_QUALITY), 95])
100
+
101
+ return output_path
102
+ except Exception as error:
103
+ print('Global exception', error)
104
+ raise Exception(f"Processing error: {str(error)}")
 
 
 
105
 
106
  @app.route('/')
107
  def index():
108
+ return render_template_string('''
109
  <!DOCTYPE html>
110
  <html>
111
  <head>
 
225
  reader.onload = function(e) {
226
  const dataURL = e.target.result;
227
  if (dataURL.length > 40) {
228
+ filePreview = `"${dataURL.substring(0, 20)}...${dataURL.substring(dataURL.length - 20)}"`;
229
  } else {
230
+ filePreview = `"${dataURL}"`;
231
  }
232
  updateFetchCode(apiUrl, version, scale, weight, filePreview);
233
  };
 
340
  </script>
341
  </body>
342
  </html>
343
+ ''')
344
+
345
+ @app.route('/api/restore', methods=['POST'])
346
+ def api_restore():
347
+ if 'file' not in request.files:
348
+ return jsonify({'error': 'No file uploaded'}), 400
349
+
350
+ file = request.files['file']
351
+ version = request.form.get('version', 'v1.4')
352
+ scale = float(request.form.get('scale', 2))
353
+ weight = float(request.form.get('weight', 0.5)) if version == 'CodeFormer' else None
354
+
355
+ if file.filename == '':
356
+ return jsonify({'error': 'No selected file'}), 400
357
+
358
+ try:
359
+ # Save uploaded file to temp location
360
+ temp_dir = tempfile.mkdtemp()
361
+ input_path = os.path.join(temp_dir, file.filename)
362
+ file.save(input_path)
363
+
364
+ # Process image
365
+ output_path = process_image(input_path, version, scale, weight)
366
+
367
+ # Return the processed image
368
+ return send_file(output_path, mimetype='image/jpeg')
369
+
370
+ except Exception as e:
371
+ return jsonify({'error': str(e)}), 500
372
+
373
+ finally:
374
+ # Clean up temp files
375
+ if 'input_path' in locals() and os.path.exists(input_path):
376
+ os.remove(input_path)
377
+ if 'temp_dir' in locals() and os.path.exists(temp_dir):
378
+ os.rmdir(temp_dir)
379
 
380
  if __name__ == '__main__':
381
+ app.run(host='0.0.0.0', port=5000, debug=True)