Nikhil Mudhalwadkar committed
Commit 0fb096b · 1 Parent(s): 219bed6
.gitattributes CHANGED
@@ -9,13 +9,9 @@
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
  *.model filter=lfs diff=lfs merge=lfs -text
  *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
  *.onnx filter=lfs diff=lfs merge=lfs -text
  *.ot filter=lfs diff=lfs merge=lfs -text
  *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
  *.pb filter=lfs diff=lfs merge=lfs -text
  *.pt filter=lfs diff=lfs merge=lfs -text
  *.pth filter=lfs diff=lfs merge=lfs -text
@@ -29,3 +25,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zstandard filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ model/pix2pix_lightning_model/version_0/checkpoints/epoch=9-step=17780.ckpt filter=lfs diff=lfs merge=lfs -text
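Note: the two added rules match what `git lfs track "*.ckpt"` appends to .gitattributes — files matching the pattern are committed as Git LFS pointers rather than full blobs, which is what allows the large Pix2Pix checkpoint to live in this repository.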
README.md CHANGED
@@ -1,13 +1,13 @@
  ---
  title: Sketch2ColourDemo
- emoji: 🏢
- colorFrom: yellow
- colorTo: gray
+ emoji: 📈
+ colorFrom: red
+ colorTo: purple
  sdk: gradio
- sdk_version: 3.1.1
+ sdk_version: 3.0.24
  app_file: app.py
  pinned: false
- license: afl-3.0
+ license: eupl-1.1
  ---
 
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,132 @@
+ from typing import Union, List
+
+ import gradio as gr
+ import matplotlib
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from pytorch_lightning.utilities.types import EPOCH_OUTPUT
+
+ matplotlib.use('Agg')
+ import numpy as np
+ from PIL import Image
+ import albumentations as A
+ import albumentations.pytorch as al_pytorch
+ import torchvision
+ from pl_bolts.models.gans import Pix2Pix
+
+ """ Class """
+
+
+ class OverpoweredPix2Pix(Pix2Pix):
+
+     def validation_step(self, batch, batch_idx):
+         """ Validation step """
+         real, condition = batch
+         with torch.no_grad():
+             loss = self._disc_step(real, condition)
+             self.log("val_PatchGAN_loss", loss)
+
+             loss = self._gen_step(real, condition)
+             self.log("val_generator_loss", loss)
+
+         return {
+             'sketch': real,
+             'colour': condition
+         }
+
+     def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None:
+         sketch = outputs[0]['sketch']
+         colour = outputs[0]['colour']
+         with torch.no_grad():
+             gen_coloured = self.gen(sketch)
+             grid_image = torchvision.utils.make_grid(
+                 [
+                     sketch[0], colour[0], gen_coloured[0],
+                 ],
+                 normalize=True
+             )
+         self.logger.experiment.add_image(f'Image Grid {str(self.current_epoch)}', grid_image, self.current_epoch)
+
+
+ """ Load the model """
+ model_checkpoint_path = "model/lightning_bolts_model/epoch=99-step=89000.ckpt"
+ # model_checkpoint_path = "model/pix2pix_lightning_model/version_0/checkpoints/epoch=199-step=355600.ckpt"
+ # model_checkpoint_path = "model/pix2pix_lightning_model/gen.pth"
+
+ model = OverpoweredPix2Pix.load_from_checkpoint(
+     model_checkpoint_path
+ )
+
+ model_chk = torch.load(
+     model_checkpoint_path, map_location=torch.device('cpu')
+ )
+ # model = gen().load_state_dict(model_chk)
+
+ model.eval()
+
+
+ def greet(name):
+     return "Hello " + name + "!!"
+
+
+ def predict(img: Image):
+     # transform img
+     image = np.asarray(img)
+     # image = image[:, image.shape[1] // 2:, :]
+     # use on inference
+     inference_transform = A.Compose([
+         A.Resize(width=256, height=256),
+         A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
+         al_pytorch.ToTensorV2(),
+     ])
+     # inverse_transform = A.Compose([
+     #     A.Normalize(
+     #         mean=[0.485, 0.456, 0.406],
+     #         std=[0.229, 0.224, 0.225]
+     #     ),
+     # ])
+     inference_img = inference_transform(
+         image=image
+     )['image'].unsqueeze(0)
+     with torch.no_grad():
+         result = model.gen(inference_img)
+         # torchvision.utils.save_image(inference_img, "inference_image.png", normalize=True)
+         torchvision.utils.save_image(result, "inference_image.png", normalize=True)
+
+     """
+     result_grid = torchvision.utils.make_grid(
+         [result[0]],
+         normalize=True
+     )
+     # plt.imsave("coloured_grid.png", (result_grid.permute(1,2,0).detach().numpy()*255).astype(int))
+     torchvision.utils.save_image(
+         result_grid, "coloured_image.png", normalize=True
+     )
+     """
+     return "inference_image.png"  # 'coloured_image.png',
+
+
+ iface = gr.Interface(
+     fn=predict,
+     inputs=gr.inputs.Image(type="pil"),
+     # inputs="sketchpad",
+     examples=[
+         "examples/thesis_test.png",
+         "examples/thesis_test2.png",
+         "examples/thesis1.png",
+         "examples/thesis4.png",
+         "examples/thesis5.png",
+         "examples/thesis6.png",
+         # "examples/1000000.png"
+     ],
+     outputs=gr.outputs.Image(type="pil"),
+     # outputs=[
+     #     "image",
+     #     # "image"
+     # ],
+     title="Colour your sketches!",
+     description="Upload a sketch and the conditional GAN will colour it for you!",
+     article="WIP repo lives here - https://github.com/nmud19/thesisGAN"
+ )
+ iface.launch()
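For reference, the `predict` pipeline above maps a uint8 sketch to a 256×256 tensor in [-1, 1] before it reaches the generator, and `torchvision.utils.save_image(..., normalize=True)` rescales the output back to a viewable image. A minimal sketch of just that pre-processing, using a synthetic all-black input so no model or checkpoint is needed:

    import numpy as np
    import albumentations as A
    import albumentations.pytorch as al_pytorch

    # same transform as app.py: resize, scale to [-1, 1], convert HWC uint8 -> CHW float tensor
    tfm = A.Compose([
        A.Resize(width=256, height=256),
        A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
        al_pytorch.ToTensorV2(),
    ])
    x = tfm(image=np.zeros((512, 512, 3), dtype=np.uint8))['image'].unsqueeze(0)
    print(x.shape, x.min().item(), x.max().item())  # torch.Size([1, 3, 256, 256]) -1.0 -1.0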
examples/__init__.py ADDED
File without changes
examples/thesis1.png ADDED
examples/thesis4.png ADDED
examples/thesis5.png ADDED
examples/thesis6.png ADDED
examples/thesis_test.png ADDED
examples/thesis_test2.png ADDED
model/lightning_bolts_model/epoch=99-step=89000.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ba59fcb2905e98d61f12bb858069893e27a4ff83042f2f6a78cbb60ae28fd947
+ size 686275171
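The checkpoint itself is committed as a Git LFS pointer: the three lines above record only the LFS spec version, the sha256 object id, and the payload size (686,275,171 bytes, roughly 686 MB), while the actual weights are kept in LFS storage.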