imseldrith committed on
Commit
0d79ea9
·
verified ·
1 Parent(s): 785bb48

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -0
app.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import cv2
3
+ import numpy as np
4
+ import os
5
+ from tqdm import tqdm
6
+ import torch
7
+ from basicsr.archs.ddcolor_arch import DDColor
8
+ import torch.nn.functional as F
9
+
10
class ImageColorizationPipeline(object):
    """Colorize BGR images with a DDColor network.

    The model receives a grey rendition of the input at a fixed square
    resolution and returns two chroma (a, b) channels. Those are resized to
    the input's resolution and merged with the lightness channel taken from
    the full-resolution input.
    """

    def __init__(self, model_path, input_size=256, model_size='large'):
        """Build the DDColor model and load its weights.

        Args:
            model_path: Path to a checkpoint readable by ``torch.load`` whose
                ``'params'`` entry holds the model state dict.
            input_size: Side length, in pixels, of the square image fed to
                the network.
            model_size: ``'tiny'`` selects the ``convnext-t`` encoder; any
                other value selects ``convnext-l``.
        """
        self.input_size = input_size
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.encoder_name = (
            'convnext-t' if model_size == 'tiny' else 'convnext-l')
        self.decoder_type = "MultiScaleColorDecoder"

        # Arguments shared by both decoder variants.
        model_kwargs = dict(
            encoder_name=self.encoder_name,
            input_size=[self.input_size, self.input_size],
            num_output_channels=2,
            last_norm='Spectral',
            do_normalize=False,
        )
        if self.decoder_type == 'MultiScaleColorDecoder':
            model_kwargs.update(
                decoder_name='MultiScaleColorDecoder',
                num_queries=100,
                num_scales=3,
                dec_layers=9,
            )
        else:
            model_kwargs.update(
                decoder_name='SingleColorDecoder',
                num_queries=256,
            )
        self.model = DDColor(**model_kwargs).to(self.device)

        # NOTE: torch.load unpickles the file, so only load checkpoints from
        # a trusted source.
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
        # strict=False: keys that do not match the model are skipped
        # instead of raising.
        self.model.load_state_dict(checkpoint['params'], strict=False)
        self.model.eval()

    @torch.no_grad()
    def process(self, img):
        """Return a colorized copy of ``img``.

        Args:
            img: Image as a NumPy array in OpenCV channel order with values
                on a 0-255 scale. Three-channel BGR is the native input;
                single-channel (grayscale) and four-channel (BGRA) arrays are
                converted to BGR first.

        Returns:
            A ``uint8`` BGR array with the same height and width as ``img``.

        Raises:
            ValueError: If ``img`` is not a non-empty array with 1, 3 or 4
                channels.
        """
        if not isinstance(img, np.ndarray) or img.size == 0:
            raise ValueError(
                'process() expects a non-empty image array, got %r'
                % (type(img).__name__,))

        # Normalise the channel layout so the Lab conversions below always
        # see a three-channel BGR image.
        if img.ndim == 2 or (img.ndim == 3 and img.shape[2] == 1):
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.ndim == 3 and img.shape[2] == 4:
            img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
        if img.ndim != 3 or img.shape[2] != 3:
            raise ValueError(
                'process() expects a 1-, 3- or 4-channel image, got shape %r'
                % (img.shape,))

        # Work with locals so overlapping calls on the shared instance cannot
        # see each other's dimensions; the attributes are still set for any
        # caller that reads them afterwards.
        height, width = img.shape[:2]
        self.height, self.width = height, width

        img = (img / 255.0).astype(np.float32)
        # Lightness of the full-resolution input, kept for the final merge.
        orig_l = cv2.cvtColor(img, cv2.COLOR_BGR2Lab)[:, :, :1]  # (h, w, 1)

        # Resize -> Lab -> zero the chroma -> back to a 3-channel grey image.
        small = cv2.resize(img, (self.input_size, self.input_size))
        small_l = cv2.cvtColor(small, cv2.COLOR_BGR2Lab)[:, :, :1]
        zeros = np.zeros_like(small_l)
        grey_lab = np.concatenate((small_l, zeros, zeros), axis=-1)
        grey_rgb = cv2.cvtColor(grey_lab, cv2.COLOR_LAB2RGB)

        tensor_gray_rgb = (
            torch.from_numpy(grey_rgb.transpose((2, 0, 1)))
            .float()
            .unsqueeze(0)
            .to(self.device)
        )
        # Predicted chroma, indexed below as (batch, channel, rows, cols).
        output_ab = self.model(tensor_gray_rgb).cpu()

        # Resize the chroma to the input size, attach the original
        # lightness, and convert Lab -> BGR.
        output_ab_resize = (
            F.interpolate(output_ab, size=(height, width))[0]
            .float()
            .numpy()
            .transpose(1, 2, 0)
        )
        output_lab = np.concatenate((orig_l, output_ab_resize), axis=-1)
        output_bgr = cv2.cvtColor(output_lab, cv2.COLOR_LAB2BGR)

        # Clip first so an out-of-range float cannot wrap around in the
        # uint8 cast.
        output_img = (
            np.clip(output_bgr, 0.0, 1.0) * 255.0).round().astype(np.uint8)

        return output_img
82
+
83
# Single shared pipeline, built once when the module is loaded and reused by
# every request.
colorizer = ImageColorizationPipeline(
    input_size=512,
    model_path='/content/DDColor/models/pytorch_model.pt',
)
84
+
85
+ from PIL import Image
86
+ import gradio as gr
87
+ import subprocess
88
+ import shutil, os
89
+ from gradio_imageslider import ImageSlider
90
+
91
def generate(image):
    """Colorize an uploaded image and return the before/after pair.

    Args:
        image: Filesystem path of the image to colorize, or ``None`` when
            nothing was uploaded.

    Returns:
        A ``(original, colorized)`` tuple of RGB PIL images.

    Raises:
        gr.Error: If no image was supplied or the file cannot be decoded.
    """
    if not image:
        raise gr.Error('Please upload an image first.')

    image_in = cv2.imread(image)
    # cv2.imread signals an unreadable or unsupported file by returning None
    # rather than raising.
    if image_in is None:
        raise gr.Error('Could not read the uploaded file as an image.')

    image_out = colorizer.process(image_in)
    # Keep a copy of the latest result on disk; it is overwritten on each call.
    cv2.imwrite('/content/DDColor/out.jpg', image_out)

    # OpenCV arrays are BGR; PIL expects RGB.
    image_in_pil = Image.fromarray(cv2.cvtColor(image_in, cv2.COLOR_BGR2RGB))
    image_out_pil = Image.fromarray(cv2.cvtColor(image_out, cv2.COLOR_BGR2RGB))
    return (image_in_pil, image_out_pil)
98
+
99
# Web UI: an image upload and a button, plus a before/after slider that shows
# the result of generate().
# NOTE(review): the nesting below was reconstructed from a listing whose
# indentation was lost (upload + button in one column, slider beside it in
# the same row) -- confirm against the deployed layout.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image = gr.Image(type='filepath')
            button = gr.Button()
        output_image = ImageSlider(
            show_label=False,
            type="filepath",
            interactive=False,
        )
    button.click(
        fn=generate,
        inputs=[image],
        outputs=[output_image],
    )

demo.queue().launch(
    inline=False,
    debug=True,
)