Spaces:
Build error
Build error
Upload folder using huggingface_hub (#1)
Browse files- Upload folder using huggingface_hub (cb2df8ba8d0f21c17856586ba796d689100b836c)
This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- LICENSE +201 -0
- MODEL_ZOO.md +13 -0
- VERSION +1 -0
- assets/anime_landscapes.png +3 -0
- assets/network_arch.jpg +0 -0
- assets/teaser.png +3 -0
- assets/test_images/Abandoned Boy Holding a Stuffed Toy Animal. London 1945.jpg +0 -0
- assets/test_images/Acrobats Balance On Top Of The Empire State Building, 1934.jpg +0 -0
- assets/test_images/Ansel Adams _ Moore Photography.jpeg +0 -0
- assets/test_images/Audrey Hepburn.jpg +0 -0
- assets/test_images/Broadway at the United States Hotel Saratoga Springs, N.Y. ca 1900-1915.jpg +0 -0
- assets/test_images/Buffalo Bank Buffalo, New York, circa 1908. Erie County Savings Bank, Niagara Street.jpg +0 -0
- assets/test_images/Detroit circa 1915.jpg +0 -0
- Crafting a Future.jpeg +0 -0
- assets/test_images/February 1936. Nipomo, Calif. Destitute pea pickers living in tent in migrant camp. Mother of seven children. Age 32.jpg +0 -0
- assets/test_images/Helen Keller meeting Charlie Chaplin in 1919.jpg +0 -0
- assets/test_images/Louis Armstrong practicing in his dressing room, ca 1946.jpg +0 -0
- assets/test_images/New York Riverfront December 15, 1931.jpg +0 -0
- assets/test_images/colorized-historical-photos-vintage-photography-39.jpg +0 -0
- basicsr/__init__.py +12 -0
- basicsr/__pycache__/__init__.cpython-310.pyc +0 -0
- basicsr/__pycache__/train.cpython-310.pyc +0 -0
- basicsr/archs/__init__.py +25 -0
- basicsr/archs/__pycache__/__init__.cpython-310.pyc +0 -0
- basicsr/archs/__pycache__/ddcolor_arch.cpython-310.pyc +0 -0
- basicsr/archs/__pycache__/discriminator_arch.cpython-310.pyc +0 -0
- basicsr/archs/__pycache__/vgg_arch.cpython-310.pyc +0 -0
- basicsr/archs/ddcolor_arch.py +385 -0
- basicsr/archs/ddcolor_arch_utils/__int__.py +0 -0
- basicsr/archs/ddcolor_arch_utils/__pycache__/convnext.cpython-310.pyc +0 -0
- basicsr/archs/ddcolor_arch_utils/__pycache__/convnext.cpython-38.pyc +0 -0
- basicsr/archs/ddcolor_arch_utils/__pycache__/position_encoding.cpython-310.pyc +0 -0
- basicsr/archs/ddcolor_arch_utils/__pycache__/position_encoding.cpython-38.pyc +0 -0
- basicsr/archs/ddcolor_arch_utils/__pycache__/transformer.cpython-310.pyc +0 -0
- basicsr/archs/ddcolor_arch_utils/__pycache__/transformer.cpython-38.pyc +0 -0
- basicsr/archs/ddcolor_arch_utils/__pycache__/transformer_utils.cpython-310.pyc +0 -0
- basicsr/archs/ddcolor_arch_utils/__pycache__/transformer_utils.cpython-38.pyc +0 -0
- basicsr/archs/ddcolor_arch_utils/__pycache__/unet.cpython-310.pyc +0 -0
- basicsr/archs/ddcolor_arch_utils/__pycache__/unet.cpython-38.pyc +0 -0
- basicsr/archs/ddcolor_arch_utils/convnext.py +155 -0
- basicsr/archs/ddcolor_arch_utils/position_encoding.py +52 -0
- basicsr/archs/ddcolor_arch_utils/transformer.py +368 -0
- basicsr/archs/ddcolor_arch_utils/transformer_utils.py +192 -0
- basicsr/archs/ddcolor_arch_utils/unet.py +208 -0
- basicsr/archs/ddcolor_arch_utils/util.py +63 -0
- basicsr/archs/discriminator_arch.py +28 -0
- basicsr/archs/vgg_arch.py +165 -0
- basicsr/data/__init__.py +101 -0
- basicsr/data/__pycache__/__init__.cpython-310.pyc +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
assets/anime_landscapes.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
assets/teaser.png filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
MODEL_ZOO.md
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## DDColor Model Zoo
|
2 |
+
|
3 |
+
| Model | Description | Note |
|
4 |
+
| ---------------------- | :------------------ | :-----|
|
5 |
+
| [ddcolor_paper.pth](https://huggingface.co/piddnad/DDColor-models/resolve/main/ddcolor_paper.pth) | DDColor-L trained on ImageNet | paper model, use it only if you want to reproduce some of the images in the paper.
|
6 |
+
| [ddcolor_modelscope.pth](https://huggingface.co/piddnad/DDColor-models/resolve/main/ddcolor_modelscope.pth) (***default***) | DDColor-L trained on ImageNet | We trained this model using the same data cleaning scheme as [BigColor](https://github.com/KIMGEONUNG/BigColor/issues/2#issuecomment-1196287574), so it can get the best qualitative results with little degrading FID performance. Use this model by default if you want to test images outside the ImageNet. It can also be easily downloaded through ModelScope [in this way](README.md#inference-with-modelscope-library).
|
7 |
+
| [ddcolor_artistic.pth](https://huggingface.co/piddnad/DDColor-models/resolve/main/ddcolor_artistic.pth) | DDColor-L trained on ImageNet + private data | We trained this model with an extended dataset containing many high-quality artistic images. Also, we didn't use colorfulness loss during training, so there may be fewer unreasonable color artifacts. Use this model if you want to try different colorization results.
|
8 |
+
| [ddcolor_paper_tiny.pth](https://huggingface.co/piddnad/DDColor-models/resolve/main/ddcolor_paper_tiny.pth) | DDColor-T trained on ImageNet | The most lightweight version of ddcolor model, using the same training scheme as ddcolor_paper.
|
9 |
+
|
10 |
+
## Discussions
|
11 |
+
|
12 |
+
* About Colorfulness Loss (CL): CL can encourage more "colorful" results and help improve CF scores, however, it sometimes leads to the generation of unpleasant color blocks (eg. red color artifacts). If something goes wrong, I personally recommend trying to remove it during training.
|
13 |
+
|
VERSION
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
1.3.4.6
|
assets/anime_landscapes.png
ADDED
![]() |
Git LFS Details
|
assets/network_arch.jpg
ADDED
![]() |
assets/teaser.png
ADDED
![]() |
Git LFS Details
|
assets/test_images/Abandoned Boy Holding a Stuffed Toy Animal. London 1945.jpg
ADDED
![]() |
assets/test_images/Acrobats Balance On Top Of The Empire State Building, 1934.jpg
ADDED
![]() |
assets/test_images/Ansel Adams _ Moore Photography.jpeg
ADDED
![]() |
assets/test_images/Audrey Hepburn.jpg
ADDED
![]() |
assets/test_images/Broadway at the United States Hotel Saratoga Springs, N.Y. ca 1900-1915.jpg
ADDED
![]() |
assets/test_images/Buffalo Bank Buffalo, New York, circa 1908. Erie County Savings Bank, Niagara Street.jpg
ADDED
![]() |
assets/test_images/Detroit circa 1915.jpg
ADDED
![]() |
Crafting a Future.jpeg
RENAMED
File without changes
|
assets/test_images/February 1936. Nipomo, Calif. Destitute pea pickers living in tent in migrant camp. Mother of seven children. Age 32.jpg
ADDED
![]() |
assets/test_images/Helen Keller meeting Charlie Chaplin in 1919.jpg
ADDED
![]() |
assets/test_images/Louis Armstrong practicing in his dressing room, ca 1946.jpg
ADDED
![]() |
assets/test_images/New York Riverfront December 15, 1931.jpg
ADDED
![]() |
assets/test_images/colorized-historical-photos-vintage-photography-39.jpg
ADDED
![]() |
basicsr/__init__.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://github.com/xinntao/BasicSR
|
2 |
+
# flake8: noqa
|
3 |
+
from .archs import *
|
4 |
+
from .data import *
|
5 |
+
from .losses import *
|
6 |
+
from .metrics import *
|
7 |
+
from .models import *
|
8 |
+
# from .ops import *
|
9 |
+
# from .test import *
|
10 |
+
from .train import *
|
11 |
+
from .utils import *
|
12 |
+
# from .version import __gitsha__, __version__
|
basicsr/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (256 Bytes). View file
|
|
basicsr/__pycache__/train.cpython-310.pyc
ADDED
Binary file (6.55 kB). View file
|
|
basicsr/archs/__init__.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
from copy import deepcopy
|
3 |
+
from os import path as osp
|
4 |
+
|
5 |
+
from basicsr.utils import get_root_logger, scandir
|
6 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
7 |
+
|
8 |
+
__all__ = ['build_network']
|
9 |
+
|
10 |
+
# automatically scan and import arch modules for registry
|
11 |
+
# scan all the files under the 'archs' folder and collect files ending with
|
12 |
+
# '_arch.py'
|
13 |
+
arch_folder = osp.dirname(osp.abspath(__file__))
|
14 |
+
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
|
15 |
+
# import all the arch modules
|
16 |
+
_arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames]
|
17 |
+
|
18 |
+
|
19 |
+
def build_network(opt):
|
20 |
+
opt = deepcopy(opt)
|
21 |
+
network_type = opt.pop('type')
|
22 |
+
net = ARCH_REGISTRY.get(network_type)(**opt)
|
23 |
+
logger = get_root_logger()
|
24 |
+
logger.info(f'Network [{net.__class__.__name__}] is created.')
|
25 |
+
return net
|
basicsr/archs/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (1.13 kB). View file
|
|
basicsr/archs/__pycache__/ddcolor_arch.cpython-310.pyc
ADDED
Binary file (10.5 kB). View file
|
|
basicsr/archs/__pycache__/discriminator_arch.cpython-310.pyc
ADDED
Binary file (1.36 kB). View file
|
|
basicsr/archs/__pycache__/vgg_arch.cpython-310.pyc
ADDED
Binary file (4.87 kB). View file
|
|
basicsr/archs/ddcolor_arch.py
ADDED
@@ -0,0 +1,385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from basicsr.archs.ddcolor_arch_utils.unet import Hook, CustomPixelShuffle_ICNR, UnetBlockWide, NormType, custom_conv_layer
|
5 |
+
from basicsr.archs.ddcolor_arch_utils.convnext import ConvNeXt
|
6 |
+
from basicsr.archs.ddcolor_arch_utils.transformer_utils import SelfAttentionLayer, CrossAttentionLayer, FFNLayer, MLP
|
7 |
+
from basicsr.archs.ddcolor_arch_utils.position_encoding import PositionEmbeddingSine
|
8 |
+
from basicsr.archs.ddcolor_arch_utils.transformer import Transformer
|
9 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
10 |
+
|
11 |
+
|
12 |
+
@ARCH_REGISTRY.register()
|
13 |
+
class DDColor(nn.Module):
|
14 |
+
|
15 |
+
def __init__(self,
|
16 |
+
encoder_name='convnext-l',
|
17 |
+
decoder_name='MultiScaleColorDecoder',
|
18 |
+
num_input_channels=3,
|
19 |
+
input_size=(256, 256),
|
20 |
+
nf=512,
|
21 |
+
num_output_channels=3,
|
22 |
+
last_norm='Weight',
|
23 |
+
do_normalize=False,
|
24 |
+
num_queries=256,
|
25 |
+
num_scales=3,
|
26 |
+
dec_layers=9,
|
27 |
+
encoder_from_pretrain=False):
|
28 |
+
super().__init__()
|
29 |
+
|
30 |
+
self.encoder = Encoder(encoder_name, ['norm0', 'norm1', 'norm2', 'norm3'], from_pretrain=encoder_from_pretrain)
|
31 |
+
self.encoder.eval()
|
32 |
+
test_input = torch.randn(1, num_input_channels, *input_size)
|
33 |
+
self.encoder(test_input)
|
34 |
+
|
35 |
+
self.decoder = Decoder(
|
36 |
+
self.encoder.hooks,
|
37 |
+
nf=nf,
|
38 |
+
last_norm=last_norm,
|
39 |
+
num_queries=num_queries,
|
40 |
+
num_scales=num_scales,
|
41 |
+
dec_layers=dec_layers,
|
42 |
+
decoder_name=decoder_name
|
43 |
+
)
|
44 |
+
self.refine_net = nn.Sequential(custom_conv_layer(num_queries + 3, num_output_channels, ks=1, use_activ=False, norm_type=NormType.Spectral))
|
45 |
+
|
46 |
+
self.do_normalize = do_normalize
|
47 |
+
self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
|
48 |
+
self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
49 |
+
|
50 |
+
def normalize(self, img):
|
51 |
+
return (img - self.mean) / self.std
|
52 |
+
|
53 |
+
def denormalize(self, img):
|
54 |
+
return img * self.std + self.mean
|
55 |
+
|
56 |
+
def forward(self, x):
|
57 |
+
if x.shape[1] == 3:
|
58 |
+
x = self.normalize(x)
|
59 |
+
|
60 |
+
self.encoder(x)
|
61 |
+
out_feat = self.decoder()
|
62 |
+
coarse_input = torch.cat([out_feat, x], dim=1)
|
63 |
+
out = self.refine_net(coarse_input)
|
64 |
+
|
65 |
+
if self.do_normalize:
|
66 |
+
out = self.denormalize(out)
|
67 |
+
return out
|
68 |
+
|
69 |
+
|
70 |
+
class Decoder(nn.Module):
|
71 |
+
|
72 |
+
def __init__(self,
|
73 |
+
hooks,
|
74 |
+
nf=512,
|
75 |
+
blur=True,
|
76 |
+
last_norm='Weight',
|
77 |
+
num_queries=256,
|
78 |
+
num_scales=3,
|
79 |
+
dec_layers=9,
|
80 |
+
decoder_name='MultiScaleColorDecoder'):
|
81 |
+
super().__init__()
|
82 |
+
self.hooks = hooks
|
83 |
+
self.nf = nf
|
84 |
+
self.blur = blur
|
85 |
+
self.last_norm = getattr(NormType, last_norm)
|
86 |
+
self.decoder_name = decoder_name
|
87 |
+
|
88 |
+
self.layers = self.make_layers()
|
89 |
+
embed_dim = nf // 2
|
90 |
+
|
91 |
+
self.last_shuf = CustomPixelShuffle_ICNR(embed_dim, embed_dim, blur=self.blur, norm_type=self.last_norm, scale=4)
|
92 |
+
|
93 |
+
if self.decoder_name == 'MultiScaleColorDecoder':
|
94 |
+
self.color_decoder = MultiScaleColorDecoder(
|
95 |
+
in_channels=[512, 512, 256],
|
96 |
+
num_queries=num_queries,
|
97 |
+
num_scales=num_scales,
|
98 |
+
dec_layers=dec_layers,
|
99 |
+
)
|
100 |
+
else:
|
101 |
+
self.color_decoder = SingleColorDecoder(
|
102 |
+
in_channels=hooks[-1].feature.shape[1],
|
103 |
+
num_queries=num_queries,
|
104 |
+
)
|
105 |
+
|
106 |
+
|
107 |
+
def forward(self):
|
108 |
+
encode_feat = self.hooks[-1].feature
|
109 |
+
out0 = self.layers[0](encode_feat)
|
110 |
+
out1 = self.layers[1](out0)
|
111 |
+
out2 = self.layers[2](out1)
|
112 |
+
out3 = self.last_shuf(out2)
|
113 |
+
|
114 |
+
if self.decoder_name == 'MultiScaleColorDecoder':
|
115 |
+
out = self.color_decoder([out0, out1, out2], out3)
|
116 |
+
else:
|
117 |
+
out = self.color_decoder(out3, encode_feat)
|
118 |
+
|
119 |
+
return out
|
120 |
+
|
121 |
+
def make_layers(self):
|
122 |
+
decoder_layers = []
|
123 |
+
|
124 |
+
e_in_c = self.hooks[-1].feature.shape[1]
|
125 |
+
in_c = e_in_c
|
126 |
+
|
127 |
+
out_c = self.nf
|
128 |
+
setup_hooks = self.hooks[-2::-1]
|
129 |
+
for layer_index, hook in enumerate(setup_hooks):
|
130 |
+
feature_c = hook.feature.shape[1]
|
131 |
+
if layer_index == len(setup_hooks) - 1:
|
132 |
+
out_c = out_c // 2
|
133 |
+
decoder_layers.append(
|
134 |
+
UnetBlockWide(
|
135 |
+
in_c, feature_c, out_c, hook, blur=self.blur, self_attention=False, norm_type=NormType.Spectral))
|
136 |
+
in_c = out_c
|
137 |
+
return nn.Sequential(*decoder_layers)
|
138 |
+
|
139 |
+
|
140 |
+
class Encoder(nn.Module):
|
141 |
+
|
142 |
+
def __init__(self, encoder_name, hook_names, from_pretrain, **kwargs):
|
143 |
+
super().__init__()
|
144 |
+
|
145 |
+
if encoder_name == 'convnext-t' or encoder_name == 'convnext':
|
146 |
+
self.arch = ConvNeXt()
|
147 |
+
elif encoder_name == 'convnext-s':
|
148 |
+
self.arch = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768])
|
149 |
+
elif encoder_name == 'convnext-b':
|
150 |
+
self.arch = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024])
|
151 |
+
elif encoder_name == 'convnext-l':
|
152 |
+
self.arch = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536])
|
153 |
+
else:
|
154 |
+
raise NotImplementedError
|
155 |
+
|
156 |
+
self.encoder_name = encoder_name
|
157 |
+
self.hook_names = hook_names
|
158 |
+
self.hooks = self.setup_hooks()
|
159 |
+
|
160 |
+
if from_pretrain:
|
161 |
+
self.load_pretrain_model()
|
162 |
+
|
163 |
+
def setup_hooks(self):
|
164 |
+
hooks = [Hook(self.arch._modules[name]) for name in self.hook_names]
|
165 |
+
return hooks
|
166 |
+
|
167 |
+
def forward(self, x):
|
168 |
+
return self.arch(x)
|
169 |
+
|
170 |
+
def load_pretrain_model(self):
|
171 |
+
if self.encoder_name == 'convnext-t' or self.encoder_name == 'convnext':
|
172 |
+
self.load('pretrain/convnext_tiny_22k_224.pth')
|
173 |
+
elif self.encoder_name == 'convnext-s':
|
174 |
+
self.load('pretrain/convnext_small_22k_224.pth')
|
175 |
+
elif self.encoder_name == 'convnext-b':
|
176 |
+
self.load('pretrain/convnext_base_22k_224.pth')
|
177 |
+
elif self.encoder_name == 'convnext-l':
|
178 |
+
self.load('pretrain/convnext_large_22k_224.pth')
|
179 |
+
else:
|
180 |
+
raise NotImplementedError
|
181 |
+
print('Loaded pretrained convnext model.')
|
182 |
+
|
183 |
+
def load(self, path):
|
184 |
+
from basicsr.utils import get_root_logger
|
185 |
+
logger = get_root_logger()
|
186 |
+
if not path:
|
187 |
+
logger.info("No checkpoint found. Initializing model from scratch")
|
188 |
+
return
|
189 |
+
logger.info("[Encoder] Loading from {} ...".format(path))
|
190 |
+
checkpoint = torch.load(path, map_location=torch.device("cpu"))
|
191 |
+
checkpoint_state_dict = checkpoint['model'] if 'model' in checkpoint.keys() else checkpoint
|
192 |
+
incompatible = self.arch.load_state_dict(checkpoint_state_dict, strict=False)
|
193 |
+
|
194 |
+
if incompatible.missing_keys:
|
195 |
+
msg = "Some model parameters or buffers are not found in the checkpoint:\n"
|
196 |
+
msg += str(incompatible.missing_keys)
|
197 |
+
logger.warning(msg)
|
198 |
+
if incompatible.unexpected_keys:
|
199 |
+
msg = "The checkpoint state_dict contains keys that are not used by the model:\n"
|
200 |
+
msg += str(incompatible.unexpected_keys)
|
201 |
+
logger.warning(msg)
|
202 |
+
|
203 |
+
|
204 |
+
class MultiScaleColorDecoder(nn.Module):
|
205 |
+
|
206 |
+
def __init__(
|
207 |
+
self,
|
208 |
+
in_channels,
|
209 |
+
hidden_dim=256,
|
210 |
+
num_queries=100,
|
211 |
+
nheads=8,
|
212 |
+
dim_feedforward=2048,
|
213 |
+
dec_layers=9,
|
214 |
+
pre_norm=False,
|
215 |
+
color_embed_dim=256,
|
216 |
+
enforce_input_project=True,
|
217 |
+
num_scales=3
|
218 |
+
):
|
219 |
+
super().__init__()
|
220 |
+
|
221 |
+
# positional encoding
|
222 |
+
N_steps = hidden_dim // 2
|
223 |
+
self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
|
224 |
+
|
225 |
+
# define Transformer decoder here
|
226 |
+
self.num_heads = nheads
|
227 |
+
self.num_layers = dec_layers
|
228 |
+
self.transformer_self_attention_layers = nn.ModuleList()
|
229 |
+
self.transformer_cross_attention_layers = nn.ModuleList()
|
230 |
+
self.transformer_ffn_layers = nn.ModuleList()
|
231 |
+
|
232 |
+
for _ in range(self.num_layers):
|
233 |
+
self.transformer_self_attention_layers.append(
|
234 |
+
SelfAttentionLayer(
|
235 |
+
d_model=hidden_dim,
|
236 |
+
nhead=nheads,
|
237 |
+
dropout=0.0,
|
238 |
+
normalize_before=pre_norm,
|
239 |
+
)
|
240 |
+
)
|
241 |
+
self.transformer_cross_attention_layers.append(
|
242 |
+
CrossAttentionLayer(
|
243 |
+
d_model=hidden_dim,
|
244 |
+
nhead=nheads,
|
245 |
+
dropout=0.0,
|
246 |
+
normalize_before=pre_norm,
|
247 |
+
)
|
248 |
+
)
|
249 |
+
self.transformer_ffn_layers.append(
|
250 |
+
FFNLayer(
|
251 |
+
d_model=hidden_dim,
|
252 |
+
dim_feedforward=dim_feedforward,
|
253 |
+
dropout=0.0,
|
254 |
+
normalize_before=pre_norm,
|
255 |
+
)
|
256 |
+
)
|
257 |
+
|
258 |
+
self.decoder_norm = nn.LayerNorm(hidden_dim)
|
259 |
+
|
260 |
+
self.num_queries = num_queries
|
261 |
+
# learnable color query features
|
262 |
+
self.query_feat = nn.Embedding(num_queries, hidden_dim)
|
263 |
+
# learnable color query p.e.
|
264 |
+
self.query_embed = nn.Embedding(num_queries, hidden_dim)
|
265 |
+
|
266 |
+
# level embedding
|
267 |
+
self.num_feature_levels = num_scales
|
268 |
+
self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
|
269 |
+
|
270 |
+
# input projections
|
271 |
+
self.input_proj = nn.ModuleList()
|
272 |
+
for i in range(self.num_feature_levels):
|
273 |
+
if in_channels[i] != hidden_dim or enforce_input_project:
|
274 |
+
self.input_proj.append(nn.Conv2d(in_channels[i], hidden_dim, kernel_size=1))
|
275 |
+
nn.init.kaiming_uniform_(self.input_proj[-1].weight, a=1)
|
276 |
+
if self.input_proj[-1].bias is not None:
|
277 |
+
nn.init.constant_(self.input_proj[-1].bias, 0)
|
278 |
+
else:
|
279 |
+
self.input_proj.append(nn.Sequential())
|
280 |
+
|
281 |
+
# output FFNs
|
282 |
+
self.color_embed = MLP(hidden_dim, hidden_dim, color_embed_dim, 3)
|
283 |
+
|
284 |
+
def forward(self, x, img_features):
|
285 |
+
# x is a list of multi-scale feature
|
286 |
+
assert len(x) == self.num_feature_levels
|
287 |
+
src = []
|
288 |
+
pos = []
|
289 |
+
|
290 |
+
for i in range(self.num_feature_levels):
|
291 |
+
pos.append(self.pe_layer(x[i], None).flatten(2))
|
292 |
+
src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])
|
293 |
+
|
294 |
+
# flatten NxCxHxW to HWxNxC
|
295 |
+
pos[-1] = pos[-1].permute(2, 0, 1)
|
296 |
+
src[-1] = src[-1].permute(2, 0, 1)
|
297 |
+
|
298 |
+
_, bs, _ = src[0].shape
|
299 |
+
|
300 |
+
# QxNxC
|
301 |
+
query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
|
302 |
+
output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)
|
303 |
+
|
304 |
+
for i in range(self.num_layers):
|
305 |
+
level_index = i % self.num_feature_levels
|
306 |
+
# attention: cross-attention first
|
307 |
+
output = self.transformer_cross_attention_layers[i](
|
308 |
+
output, src[level_index],
|
309 |
+
memory_mask=None,
|
310 |
+
memory_key_padding_mask=None,
|
311 |
+
pos=pos[level_index], query_pos=query_embed
|
312 |
+
)
|
313 |
+
output = self.transformer_self_attention_layers[i](
|
314 |
+
output, tgt_mask=None,
|
315 |
+
tgt_key_padding_mask=None,
|
316 |
+
query_pos=query_embed
|
317 |
+
)
|
318 |
+
# FFN
|
319 |
+
output = self.transformer_ffn_layers[i](
|
320 |
+
output
|
321 |
+
)
|
322 |
+
|
323 |
+
decoder_output = self.decoder_norm(output)
|
324 |
+
decoder_output = decoder_output.transpose(0, 1) # [N, bs, C] -> [bs, N, C]
|
325 |
+
color_embed = self.color_embed(decoder_output)
|
326 |
+
out = torch.einsum("bqc,bchw->bqhw", color_embed, img_features)
|
327 |
+
|
328 |
+
return out
|
329 |
+
|
330 |
+
|
331 |
+
class SingleColorDecoder(nn.Module):
|
332 |
+
|
333 |
+
def __init__(
|
334 |
+
self,
|
335 |
+
in_channels=768,
|
336 |
+
hidden_dim=256,
|
337 |
+
num_queries=256, # 100
|
338 |
+
nheads=8,
|
339 |
+
dropout=0.1,
|
340 |
+
dim_feedforward=2048,
|
341 |
+
enc_layers=0,
|
342 |
+
dec_layers=6,
|
343 |
+
pre_norm=False,
|
344 |
+
deep_supervision=True,
|
345 |
+
enforce_input_project=True,
|
346 |
+
):
|
347 |
+
|
348 |
+
super().__init__()
|
349 |
+
|
350 |
+
N_steps = hidden_dim // 2
|
351 |
+
self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
|
352 |
+
|
353 |
+
transformer = Transformer(
|
354 |
+
d_model=hidden_dim,
|
355 |
+
dropout=dropout,
|
356 |
+
nhead=nheads,
|
357 |
+
dim_feedforward=dim_feedforward,
|
358 |
+
num_encoder_layers=enc_layers,
|
359 |
+
num_decoder_layers=dec_layers,
|
360 |
+
normalize_before=pre_norm,
|
361 |
+
return_intermediate_dec=deep_supervision,
|
362 |
+
)
|
363 |
+
self.num_queries = num_queries
|
364 |
+
self.transformer = transformer
|
365 |
+
|
366 |
+
self.query_embed = nn.Embedding(num_queries, hidden_dim)
|
367 |
+
|
368 |
+
if in_channels != hidden_dim or enforce_input_project:
|
369 |
+
self.input_proj = nn.Conv2d(in_channels, hidden_dim, kernel_size=1)
|
370 |
+
nn.init.kaiming_uniform_(self.input_proj.weight, a=1)
|
371 |
+
if self.input_proj.bias is not None:
|
372 |
+
nn.init.constant_(self.input_proj.bias, 0)
|
373 |
+
else:
|
374 |
+
self.input_proj = nn.Sequential()
|
375 |
+
|
376 |
+
|
377 |
+
def forward(self, img_features, encode_feat):
|
378 |
+
pos = self.pe_layer(encode_feat)
|
379 |
+
src = encode_feat
|
380 |
+
mask = None
|
381 |
+
hs, memory = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos)
|
382 |
+
color_embed = hs[-1]
|
383 |
+
color_preds = torch.einsum('bqc,bchw->bqhw', color_embed, img_features)
|
384 |
+
return color_preds
|
385 |
+
|
basicsr/archs/ddcolor_arch_utils/__int__.py
ADDED
File without changes
|
basicsr/archs/ddcolor_arch_utils/__pycache__/convnext.cpython-310.pyc
ADDED
Binary file (6.08 kB). View file
|
|
basicsr/archs/ddcolor_arch_utils/__pycache__/convnext.cpython-38.pyc
ADDED
Binary file (6.2 kB). View file
|
|
basicsr/archs/ddcolor_arch_utils/__pycache__/position_encoding.cpython-310.pyc
ADDED
Binary file (2.03 kB). View file
|
|
basicsr/archs/ddcolor_arch_utils/__pycache__/position_encoding.cpython-38.pyc
ADDED
Binary file (2.03 kB). View file
|
|
basicsr/archs/ddcolor_arch_utils/__pycache__/transformer.cpython-310.pyc
ADDED
Binary file (8.96 kB). View file
|
|
basicsr/archs/ddcolor_arch_utils/__pycache__/transformer.cpython-38.pyc
ADDED
Binary file (8.81 kB). View file
|
|
basicsr/archs/ddcolor_arch_utils/__pycache__/transformer_utils.cpython-310.pyc
ADDED
Binary file (6.4 kB). View file
|
|
basicsr/archs/ddcolor_arch_utils/__pycache__/transformer_utils.cpython-38.pyc
ADDED
Binary file (6.57 kB). View file
|
|
basicsr/archs/ddcolor_arch_utils/__pycache__/unet.cpython-310.pyc
ADDED
Binary file (7.4 kB). View file
|
|
basicsr/archs/ddcolor_arch_utils/__pycache__/unet.cpython-38.pyc
ADDED
Binary file (7.37 kB). View file
|
|
basicsr/archs/ddcolor_arch_utils/convnext.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
|
3 |
+
# All rights reserved.
|
4 |
+
|
5 |
+
# This source code is licensed under the license found in the
|
6 |
+
# LICENSE file in the root directory of this source tree.
|
7 |
+
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from timm.models.layers import trunc_normal_, DropPath
|
13 |
+
|
14 |
+
class Block(nn.Module):
|
15 |
+
r""" ConvNeXt Block. There are two equivalent implementations:
|
16 |
+
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
|
17 |
+
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
|
18 |
+
We use (2) as we find it slightly faster in PyTorch
|
19 |
+
|
20 |
+
Args:
|
21 |
+
dim (int): Number of input channels.
|
22 |
+
drop_path (float): Stochastic depth rate. Default: 0.0
|
23 |
+
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
|
24 |
+
"""
|
25 |
+
def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
|
26 |
+
super().__init__()
|
27 |
+
self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
|
28 |
+
self.norm = LayerNorm(dim, eps=1e-6)
|
29 |
+
self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
|
30 |
+
self.act = nn.GELU()
|
31 |
+
self.pwconv2 = nn.Linear(4 * dim, dim)
|
32 |
+
self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
|
33 |
+
requires_grad=True) if layer_scale_init_value > 0 else None
|
34 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
input = x
|
38 |
+
x = self.dwconv(x)
|
39 |
+
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
|
40 |
+
x = self.norm(x)
|
41 |
+
x = self.pwconv1(x)
|
42 |
+
x = self.act(x)
|
43 |
+
x = self.pwconv2(x)
|
44 |
+
if self.gamma is not None:
|
45 |
+
x = self.gamma * x
|
46 |
+
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
|
47 |
+
|
48 |
+
x = input + self.drop_path(x)
|
49 |
+
return x
|
50 |
+
|
51 |
+
class ConvNeXt(nn.Module):
|
52 |
+
r""" ConvNeXt
|
53 |
+
A PyTorch impl of : `A ConvNet for the 2020s` -
|
54 |
+
https://arxiv.org/pdf/2201.03545.pdf
|
55 |
+
Args:
|
56 |
+
in_chans (int): Number of input image channels. Default: 3
|
57 |
+
num_classes (int): Number of classes for classification head. Default: 1000
|
58 |
+
depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
|
59 |
+
dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
|
60 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.
|
61 |
+
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
|
62 |
+
head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
|
63 |
+
"""
|
64 |
+
def __init__(self, in_chans=3, num_classes=1000,
|
65 |
+
depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
|
66 |
+
layer_scale_init_value=1e-6, head_init_scale=1.,
|
67 |
+
):
|
68 |
+
super().__init__()
|
69 |
+
|
70 |
+
self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
|
71 |
+
stem = nn.Sequential(
|
72 |
+
nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
|
73 |
+
LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
|
74 |
+
)
|
75 |
+
self.downsample_layers.append(stem)
|
76 |
+
for i in range(3):
|
77 |
+
downsample_layer = nn.Sequential(
|
78 |
+
LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
|
79 |
+
nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
|
80 |
+
)
|
81 |
+
self.downsample_layers.append(downsample_layer)
|
82 |
+
|
83 |
+
self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
|
84 |
+
dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
|
85 |
+
cur = 0
|
86 |
+
for i in range(4):
|
87 |
+
stage = nn.Sequential(
|
88 |
+
*[Block(dim=dims[i], drop_path=dp_rates[cur + j],
|
89 |
+
layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
|
90 |
+
)
|
91 |
+
self.stages.append(stage)
|
92 |
+
cur += depths[i]
|
93 |
+
|
94 |
+
# add norm layers for each output
|
95 |
+
out_indices = (0, 1, 2, 3)
|
96 |
+
for i in out_indices:
|
97 |
+
layer = LayerNorm(dims[i], eps=1e-6, data_format="channels_first")
|
98 |
+
# layer = nn.Identity()
|
99 |
+
layer_name = f'norm{i}'
|
100 |
+
self.add_module(layer_name, layer)
|
101 |
+
|
102 |
+
self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
|
103 |
+
# self.head_cls = nn.Linear(dims[-1], 4)
|
104 |
+
|
105 |
+
self.apply(self._init_weights)
|
106 |
+
# self.head_cls.weight.data.mul_(head_init_scale)
|
107 |
+
# self.head_cls.bias.data.mul_(head_init_scale)
|
108 |
+
|
109 |
+
def _init_weights(self, m):
|
110 |
+
if isinstance(m, (nn.Conv2d, nn.Linear)):
|
111 |
+
trunc_normal_(m.weight, std=.02)
|
112 |
+
nn.init.constant_(m.bias, 0)
|
113 |
+
|
114 |
+
def forward_features(self, x):
|
115 |
+
for i in range(4):
|
116 |
+
x = self.downsample_layers[i](x)
|
117 |
+
x = self.stages[i](x)
|
118 |
+
|
119 |
+
# add extra norm
|
120 |
+
norm_layer = getattr(self, f'norm{i}')
|
121 |
+
# x = norm_layer(x)
|
122 |
+
norm_layer(x)
|
123 |
+
|
124 |
+
return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)
|
125 |
+
|
126 |
+
def forward(self, x):
|
127 |
+
x = self.forward_features(x)
|
128 |
+
# x = self.head_cls(x)
|
129 |
+
return x
|
130 |
+
|
131 |
+
class LayerNorm(nn.Module):
|
132 |
+
r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
|
133 |
+
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
|
134 |
+
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
|
135 |
+
with shape (batch_size, channels, height, width).
|
136 |
+
"""
|
137 |
+
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
|
138 |
+
super().__init__()
|
139 |
+
self.weight = nn.Parameter(torch.ones(normalized_shape))
|
140 |
+
self.bias = nn.Parameter(torch.zeros(normalized_shape))
|
141 |
+
self.eps = eps
|
142 |
+
self.data_format = data_format
|
143 |
+
if self.data_format not in ["channels_last", "channels_first"]:
|
144 |
+
raise NotImplementedError
|
145 |
+
self.normalized_shape = (normalized_shape, )
|
146 |
+
|
147 |
+
def forward(self, x):
|
148 |
+
if self.data_format == "channels_last": # B H W C
|
149 |
+
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
150 |
+
elif self.data_format == "channels_first": # B C H W
|
151 |
+
u = x.mean(1, keepdim=True)
|
152 |
+
s = (x - u).pow(2).mean(1, keepdim=True)
|
153 |
+
x = (x - u) / torch.sqrt(s + self.eps)
|
154 |
+
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
155 |
+
return x
|
basicsr/archs/ddcolor_arch_utils/position_encoding.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# Modified from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py
|
3 |
+
"""
|
4 |
+
Various positional encodings for the transformer.
|
5 |
+
"""
|
6 |
+
import math
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from torch import nn
|
10 |
+
|
11 |
+
|
12 |
+
class PositionEmbeddingSine(nn.Module):
|
13 |
+
"""
|
14 |
+
This is a more standard version of the position embedding, very similar to the one
|
15 |
+
used by the Attention is all you need paper, generalized to work on images.
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
19 |
+
super().__init__()
|
20 |
+
self.num_pos_feats = num_pos_feats
|
21 |
+
self.temperature = temperature
|
22 |
+
self.normalize = normalize
|
23 |
+
if scale is not None and normalize is False:
|
24 |
+
raise ValueError("normalize should be True if scale is passed")
|
25 |
+
if scale is None:
|
26 |
+
scale = 2 * math.pi
|
27 |
+
self.scale = scale
|
28 |
+
|
29 |
+
def forward(self, x, mask=None):
|
30 |
+
if mask is None:
|
31 |
+
mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
|
32 |
+
not_mask = ~mask
|
33 |
+
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
34 |
+
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
35 |
+
if self.normalize:
|
36 |
+
eps = 1e-6
|
37 |
+
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
38 |
+
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
39 |
+
|
40 |
+
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
41 |
+
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
42 |
+
|
43 |
+
pos_x = x_embed[:, :, :, None] / dim_t
|
44 |
+
pos_y = y_embed[:, :, :, None] / dim_t
|
45 |
+
pos_x = torch.stack(
|
46 |
+
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
|
47 |
+
).flatten(3)
|
48 |
+
pos_y = torch.stack(
|
49 |
+
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
|
50 |
+
).flatten(3)
|
51 |
+
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
52 |
+
return pos
|
basicsr/archs/ddcolor_arch_utils/transformer.py
ADDED
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# Modified from: https://github.com/facebookresearch/detr/blob/master/models/transformer.py
|
3 |
+
"""
|
4 |
+
Transformer class.
|
5 |
+
Copy-paste from torch.nn.Transformer with modifications:
|
6 |
+
* positional encodings are passed in MHattention
|
7 |
+
* extra LN at the end of encoder is removed
|
8 |
+
* decoder returns a stack of activations from all decoding layers
|
9 |
+
"""
|
10 |
+
import copy
|
11 |
+
from typing import List, Optional
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.nn.functional as F
|
15 |
+
from torch import Tensor, nn
|
16 |
+
|
17 |
+
|
18 |
+
class Transformer(nn.Module):
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
d_model=512,
|
22 |
+
nhead=8,
|
23 |
+
num_encoder_layers=6,
|
24 |
+
num_decoder_layers=6,
|
25 |
+
dim_feedforward=2048,
|
26 |
+
dropout=0.1,
|
27 |
+
activation="relu",
|
28 |
+
normalize_before=False,
|
29 |
+
return_intermediate_dec=False,
|
30 |
+
):
|
31 |
+
super().__init__()
|
32 |
+
|
33 |
+
encoder_layer = TransformerEncoderLayer(
|
34 |
+
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
35 |
+
)
|
36 |
+
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
37 |
+
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
|
38 |
+
|
39 |
+
decoder_layer = TransformerDecoderLayer(
|
40 |
+
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
41 |
+
)
|
42 |
+
decoder_norm = nn.LayerNorm(d_model)
|
43 |
+
self.decoder = TransformerDecoder(
|
44 |
+
decoder_layer,
|
45 |
+
num_decoder_layers,
|
46 |
+
decoder_norm,
|
47 |
+
return_intermediate=return_intermediate_dec,
|
48 |
+
)
|
49 |
+
|
50 |
+
self._reset_parameters()
|
51 |
+
|
52 |
+
self.d_model = d_model
|
53 |
+
self.nhead = nhead
|
54 |
+
|
55 |
+
def _reset_parameters(self):
|
56 |
+
for p in self.parameters():
|
57 |
+
if p.dim() > 1:
|
58 |
+
nn.init.xavier_uniform_(p)
|
59 |
+
|
60 |
+
def forward(self, src, mask, query_embed, pos_embed):
|
61 |
+
# flatten NxCxHxW to HWxNxC
|
62 |
+
bs, c, h, w = src.shape
|
63 |
+
src = src.flatten(2).permute(2, 0, 1)
|
64 |
+
pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
|
65 |
+
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
66 |
+
if mask is not None:
|
67 |
+
mask = mask.flatten(1)
|
68 |
+
|
69 |
+
tgt = torch.zeros_like(query_embed)
|
70 |
+
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
|
71 |
+
hs = self.decoder(
|
72 |
+
tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed
|
73 |
+
)
|
74 |
+
return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)
|
75 |
+
|
76 |
+
|
77 |
+
class TransformerEncoder(nn.Module):
|
78 |
+
def __init__(self, encoder_layer, num_layers, norm=None):
|
79 |
+
super().__init__()
|
80 |
+
self.layers = _get_clones(encoder_layer, num_layers)
|
81 |
+
self.num_layers = num_layers
|
82 |
+
self.norm = norm
|
83 |
+
|
84 |
+
def forward(
|
85 |
+
self,
|
86 |
+
src,
|
87 |
+
mask: Optional[Tensor] = None,
|
88 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
89 |
+
pos: Optional[Tensor] = None,
|
90 |
+
):
|
91 |
+
output = src
|
92 |
+
|
93 |
+
for layer in self.layers:
|
94 |
+
output = layer(
|
95 |
+
output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos
|
96 |
+
)
|
97 |
+
|
98 |
+
if self.norm is not None:
|
99 |
+
output = self.norm(output)
|
100 |
+
|
101 |
+
return output
|
102 |
+
|
103 |
+
|
104 |
+
class TransformerDecoder(nn.Module):
|
105 |
+
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
|
106 |
+
super().__init__()
|
107 |
+
self.layers = _get_clones(decoder_layer, num_layers)
|
108 |
+
self.num_layers = num_layers
|
109 |
+
self.norm = norm
|
110 |
+
self.return_intermediate = return_intermediate
|
111 |
+
|
112 |
+
def forward(
|
113 |
+
self,
|
114 |
+
tgt,
|
115 |
+
memory,
|
116 |
+
tgt_mask: Optional[Tensor] = None,
|
117 |
+
memory_mask: Optional[Tensor] = None,
|
118 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
119 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
120 |
+
pos: Optional[Tensor] = None,
|
121 |
+
query_pos: Optional[Tensor] = None,
|
122 |
+
):
|
123 |
+
output = tgt
|
124 |
+
|
125 |
+
intermediate = []
|
126 |
+
|
127 |
+
for layer in self.layers:
|
128 |
+
output = layer(
|
129 |
+
output,
|
130 |
+
memory,
|
131 |
+
tgt_mask=tgt_mask,
|
132 |
+
memory_mask=memory_mask,
|
133 |
+
tgt_key_padding_mask=tgt_key_padding_mask,
|
134 |
+
memory_key_padding_mask=memory_key_padding_mask,
|
135 |
+
pos=pos,
|
136 |
+
query_pos=query_pos,
|
137 |
+
)
|
138 |
+
if self.return_intermediate:
|
139 |
+
intermediate.append(self.norm(output))
|
140 |
+
|
141 |
+
if self.norm is not None:
|
142 |
+
output = self.norm(output)
|
143 |
+
if self.return_intermediate:
|
144 |
+
intermediate.pop()
|
145 |
+
intermediate.append(output)
|
146 |
+
|
147 |
+
if self.return_intermediate:
|
148 |
+
return torch.stack(intermediate)
|
149 |
+
|
150 |
+
return output.unsqueeze(0)
|
151 |
+
|
152 |
+
|
153 |
+
class TransformerEncoderLayer(nn.Module):
|
154 |
+
def __init__(
|
155 |
+
self,
|
156 |
+
d_model,
|
157 |
+
nhead,
|
158 |
+
dim_feedforward=2048,
|
159 |
+
dropout=0.1,
|
160 |
+
activation="relu",
|
161 |
+
normalize_before=False,
|
162 |
+
):
|
163 |
+
super().__init__()
|
164 |
+
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
165 |
+
# Implementation of Feedforward model
|
166 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
167 |
+
self.dropout = nn.Dropout(dropout)
|
168 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
169 |
+
|
170 |
+
self.norm1 = nn.LayerNorm(d_model)
|
171 |
+
self.norm2 = nn.LayerNorm(d_model)
|
172 |
+
self.dropout1 = nn.Dropout(dropout)
|
173 |
+
self.dropout2 = nn.Dropout(dropout)
|
174 |
+
|
175 |
+
self.activation = _get_activation_fn(activation)
|
176 |
+
self.normalize_before = normalize_before
|
177 |
+
|
178 |
+
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
179 |
+
return tensor if pos is None else tensor + pos
|
180 |
+
|
181 |
+
def forward_post(
|
182 |
+
self,
|
183 |
+
src,
|
184 |
+
src_mask: Optional[Tensor] = None,
|
185 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
186 |
+
pos: Optional[Tensor] = None,
|
187 |
+
):
|
188 |
+
q = k = self.with_pos_embed(src, pos)
|
189 |
+
src2 = self.self_attn(
|
190 |
+
q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
|
191 |
+
)[0]
|
192 |
+
src = src + self.dropout1(src2)
|
193 |
+
src = self.norm1(src)
|
194 |
+
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
|
195 |
+
src = src + self.dropout2(src2)
|
196 |
+
src = self.norm2(src)
|
197 |
+
return src
|
198 |
+
|
199 |
+
def forward_pre(
|
200 |
+
self,
|
201 |
+
src,
|
202 |
+
src_mask: Optional[Tensor] = None,
|
203 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
204 |
+
pos: Optional[Tensor] = None,
|
205 |
+
):
|
206 |
+
src2 = self.norm1(src)
|
207 |
+
q = k = self.with_pos_embed(src2, pos)
|
208 |
+
src2 = self.self_attn(
|
209 |
+
q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
|
210 |
+
)[0]
|
211 |
+
src = src + self.dropout1(src2)
|
212 |
+
src2 = self.norm2(src)
|
213 |
+
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
|
214 |
+
src = src + self.dropout2(src2)
|
215 |
+
return src
|
216 |
+
|
217 |
+
def forward(
|
218 |
+
self,
|
219 |
+
src,
|
220 |
+
src_mask: Optional[Tensor] = None,
|
221 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
222 |
+
pos: Optional[Tensor] = None,
|
223 |
+
):
|
224 |
+
if self.normalize_before:
|
225 |
+
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
|
226 |
+
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
|
227 |
+
|
228 |
+
|
229 |
+
class TransformerDecoderLayer(nn.Module):
|
230 |
+
def __init__(
|
231 |
+
self,
|
232 |
+
d_model,
|
233 |
+
nhead,
|
234 |
+
dim_feedforward=2048,
|
235 |
+
dropout=0.1,
|
236 |
+
activation="relu",
|
237 |
+
normalize_before=False,
|
238 |
+
):
|
239 |
+
super().__init__()
|
240 |
+
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
241 |
+
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
242 |
+
# Implementation of Feedforward model
|
243 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
244 |
+
self.dropout = nn.Dropout(dropout)
|
245 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
246 |
+
|
247 |
+
self.norm1 = nn.LayerNorm(d_model)
|
248 |
+
self.norm2 = nn.LayerNorm(d_model)
|
249 |
+
self.norm3 = nn.LayerNorm(d_model)
|
250 |
+
self.dropout1 = nn.Dropout(dropout)
|
251 |
+
self.dropout2 = nn.Dropout(dropout)
|
252 |
+
self.dropout3 = nn.Dropout(dropout)
|
253 |
+
|
254 |
+
self.activation = _get_activation_fn(activation)
|
255 |
+
self.normalize_before = normalize_before
|
256 |
+
|
257 |
+
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
258 |
+
return tensor if pos is None else tensor + pos
|
259 |
+
|
260 |
+
def forward_post(
|
261 |
+
self,
|
262 |
+
tgt,
|
263 |
+
memory,
|
264 |
+
tgt_mask: Optional[Tensor] = None,
|
265 |
+
memory_mask: Optional[Tensor] = None,
|
266 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
267 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
268 |
+
pos: Optional[Tensor] = None,
|
269 |
+
query_pos: Optional[Tensor] = None,
|
270 |
+
):
|
271 |
+
q = k = self.with_pos_embed(tgt, query_pos)
|
272 |
+
tgt2 = self.self_attn(
|
273 |
+
q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
|
274 |
+
)[0]
|
275 |
+
tgt = tgt + self.dropout1(tgt2)
|
276 |
+
tgt = self.norm1(tgt)
|
277 |
+
tgt2 = self.multihead_attn(
|
278 |
+
query=self.with_pos_embed(tgt, query_pos),
|
279 |
+
key=self.with_pos_embed(memory, pos),
|
280 |
+
value=memory,
|
281 |
+
attn_mask=memory_mask,
|
282 |
+
key_padding_mask=memory_key_padding_mask,
|
283 |
+
)[0]
|
284 |
+
tgt = tgt + self.dropout2(tgt2)
|
285 |
+
tgt = self.norm2(tgt)
|
286 |
+
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
287 |
+
tgt = tgt + self.dropout3(tgt2)
|
288 |
+
tgt = self.norm3(tgt)
|
289 |
+
return tgt
|
290 |
+
|
291 |
+
def forward_pre(
|
292 |
+
self,
|
293 |
+
tgt,
|
294 |
+
memory,
|
295 |
+
tgt_mask: Optional[Tensor] = None,
|
296 |
+
memory_mask: Optional[Tensor] = None,
|
297 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
298 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
299 |
+
pos: Optional[Tensor] = None,
|
300 |
+
query_pos: Optional[Tensor] = None,
|
301 |
+
):
|
302 |
+
tgt2 = self.norm1(tgt)
|
303 |
+
q = k = self.with_pos_embed(tgt2, query_pos)
|
304 |
+
tgt2 = self.self_attn(
|
305 |
+
q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
|
306 |
+
)[0]
|
307 |
+
tgt = tgt + self.dropout1(tgt2)
|
308 |
+
tgt2 = self.norm2(tgt)
|
309 |
+
tgt2 = self.multihead_attn(
|
310 |
+
query=self.with_pos_embed(tgt2, query_pos),
|
311 |
+
key=self.with_pos_embed(memory, pos),
|
312 |
+
value=memory,
|
313 |
+
attn_mask=memory_mask,
|
314 |
+
key_padding_mask=memory_key_padding_mask,
|
315 |
+
)[0]
|
316 |
+
tgt = tgt + self.dropout2(tgt2)
|
317 |
+
tgt2 = self.norm3(tgt)
|
318 |
+
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
319 |
+
tgt = tgt + self.dropout3(tgt2)
|
320 |
+
return tgt
|
321 |
+
|
322 |
+
def forward(
|
323 |
+
self,
|
324 |
+
tgt,
|
325 |
+
memory,
|
326 |
+
tgt_mask: Optional[Tensor] = None,
|
327 |
+
memory_mask: Optional[Tensor] = None,
|
328 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
329 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
330 |
+
pos: Optional[Tensor] = None,
|
331 |
+
query_pos: Optional[Tensor] = None,
|
332 |
+
):
|
333 |
+
if self.normalize_before:
|
334 |
+
return self.forward_pre(
|
335 |
+
tgt,
|
336 |
+
memory,
|
337 |
+
tgt_mask,
|
338 |
+
memory_mask,
|
339 |
+
tgt_key_padding_mask,
|
340 |
+
memory_key_padding_mask,
|
341 |
+
pos,
|
342 |
+
query_pos,
|
343 |
+
)
|
344 |
+
return self.forward_post(
|
345 |
+
tgt,
|
346 |
+
memory,
|
347 |
+
tgt_mask,
|
348 |
+
memory_mask,
|
349 |
+
tgt_key_padding_mask,
|
350 |
+
memory_key_padding_mask,
|
351 |
+
pos,
|
352 |
+
query_pos,
|
353 |
+
)
|
354 |
+
|
355 |
+
|
356 |
+
def _get_clones(module, N):
|
357 |
+
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
358 |
+
|
359 |
+
|
360 |
+
def _get_activation_fn(activation):
|
361 |
+
"""Return an activation function given a string"""
|
362 |
+
if activation == "relu":
|
363 |
+
return F.relu
|
364 |
+
if activation == "gelu":
|
365 |
+
return F.gelu
|
366 |
+
if activation == "glu":
|
367 |
+
return F.glu
|
368 |
+
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
|
basicsr/archs/ddcolor_arch_utils/transformer_utils.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
from torch import nn, Tensor
|
3 |
+
from torch.nn import functional as F
|
4 |
+
|
5 |
+
class SelfAttentionLayer(nn.Module):
|
6 |
+
|
7 |
+
def __init__(self, d_model, nhead, dropout=0.0,
|
8 |
+
activation="relu", normalize_before=False):
|
9 |
+
super().__init__()
|
10 |
+
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
11 |
+
|
12 |
+
self.norm = nn.LayerNorm(d_model)
|
13 |
+
self.dropout = nn.Dropout(dropout)
|
14 |
+
|
15 |
+
self.activation = _get_activation_fn(activation)
|
16 |
+
self.normalize_before = normalize_before
|
17 |
+
|
18 |
+
self._reset_parameters()
|
19 |
+
|
20 |
+
def _reset_parameters(self):
|
21 |
+
for p in self.parameters():
|
22 |
+
if p.dim() > 1:
|
23 |
+
nn.init.xavier_uniform_(p)
|
24 |
+
|
25 |
+
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
26 |
+
return tensor if pos is None else tensor + pos
|
27 |
+
|
28 |
+
def forward_post(self, tgt,
|
29 |
+
tgt_mask: Optional[Tensor] = None,
|
30 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
31 |
+
query_pos: Optional[Tensor] = None):
|
32 |
+
q = k = self.with_pos_embed(tgt, query_pos)
|
33 |
+
tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
|
34 |
+
key_padding_mask=tgt_key_padding_mask)[0]
|
35 |
+
tgt = tgt + self.dropout(tgt2)
|
36 |
+
tgt = self.norm(tgt)
|
37 |
+
|
38 |
+
return tgt
|
39 |
+
|
40 |
+
def forward_pre(self, tgt,
|
41 |
+
tgt_mask: Optional[Tensor] = None,
|
42 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
43 |
+
query_pos: Optional[Tensor] = None):
|
44 |
+
tgt2 = self.norm(tgt)
|
45 |
+
q = k = self.with_pos_embed(tgt2, query_pos)
|
46 |
+
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
|
47 |
+
key_padding_mask=tgt_key_padding_mask)[0]
|
48 |
+
tgt = tgt + self.dropout(tgt2)
|
49 |
+
|
50 |
+
return tgt
|
51 |
+
|
52 |
+
def forward(self, tgt,
|
53 |
+
tgt_mask: Optional[Tensor] = None,
|
54 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
55 |
+
query_pos: Optional[Tensor] = None):
|
56 |
+
if self.normalize_before:
|
57 |
+
return self.forward_pre(tgt, tgt_mask,
|
58 |
+
tgt_key_padding_mask, query_pos)
|
59 |
+
return self.forward_post(tgt, tgt_mask,
|
60 |
+
tgt_key_padding_mask, query_pos)
|
61 |
+
|
62 |
+
|
63 |
+
class CrossAttentionLayer(nn.Module):
|
64 |
+
|
65 |
+
def __init__(self, d_model, nhead, dropout=0.0,
|
66 |
+
activation="relu", normalize_before=False):
|
67 |
+
super().__init__()
|
68 |
+
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
69 |
+
|
70 |
+
self.norm = nn.LayerNorm(d_model)
|
71 |
+
self.dropout = nn.Dropout(dropout)
|
72 |
+
|
73 |
+
self.activation = _get_activation_fn(activation)
|
74 |
+
self.normalize_before = normalize_before
|
75 |
+
|
76 |
+
self._reset_parameters()
|
77 |
+
|
78 |
+
def _reset_parameters(self):
|
79 |
+
for p in self.parameters():
|
80 |
+
if p.dim() > 1:
|
81 |
+
nn.init.xavier_uniform_(p)
|
82 |
+
|
83 |
+
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
84 |
+
return tensor if pos is None else tensor + pos
|
85 |
+
|
86 |
+
def forward_post(self, tgt, memory,
|
87 |
+
memory_mask: Optional[Tensor] = None,
|
88 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
89 |
+
pos: Optional[Tensor] = None,
|
90 |
+
query_pos: Optional[Tensor] = None):
|
91 |
+
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
|
92 |
+
key=self.with_pos_embed(memory, pos),
|
93 |
+
value=memory, attn_mask=memory_mask,
|
94 |
+
key_padding_mask=memory_key_padding_mask)[0]
|
95 |
+
tgt = tgt + self.dropout(tgt2)
|
96 |
+
tgt = self.norm(tgt)
|
97 |
+
|
98 |
+
return tgt
|
99 |
+
|
100 |
+
def forward_pre(self, tgt, memory,
|
101 |
+
memory_mask: Optional[Tensor] = None,
|
102 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
103 |
+
pos: Optional[Tensor] = None,
|
104 |
+
query_pos: Optional[Tensor] = None):
|
105 |
+
tgt2 = self.norm(tgt)
|
106 |
+
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
|
107 |
+
key=self.with_pos_embed(memory, pos),
|
108 |
+
value=memory, attn_mask=memory_mask,
|
109 |
+
key_padding_mask=memory_key_padding_mask)[0]
|
110 |
+
tgt = tgt + self.dropout(tgt2)
|
111 |
+
|
112 |
+
return tgt
|
113 |
+
|
114 |
+
def forward(self, tgt, memory,
|
115 |
+
memory_mask: Optional[Tensor] = None,
|
116 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
117 |
+
pos: Optional[Tensor] = None,
|
118 |
+
query_pos: Optional[Tensor] = None):
|
119 |
+
if self.normalize_before:
|
120 |
+
return self.forward_pre(tgt, memory, memory_mask,
|
121 |
+
memory_key_padding_mask, pos, query_pos)
|
122 |
+
return self.forward_post(tgt, memory, memory_mask,
|
123 |
+
memory_key_padding_mask, pos, query_pos)
|
124 |
+
|
125 |
+
|
126 |
+
class FFNLayer(nn.Module):
|
127 |
+
|
128 |
+
def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
|
129 |
+
activation="relu", normalize_before=False):
|
130 |
+
super().__init__()
|
131 |
+
# Implementation of Feedforward model
|
132 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
133 |
+
self.dropout = nn.Dropout(dropout)
|
134 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
135 |
+
|
136 |
+
self.norm = nn.LayerNorm(d_model)
|
137 |
+
|
138 |
+
self.activation = _get_activation_fn(activation)
|
139 |
+
self.normalize_before = normalize_before
|
140 |
+
|
141 |
+
self._reset_parameters()
|
142 |
+
|
143 |
+
def _reset_parameters(self):
|
144 |
+
for p in self.parameters():
|
145 |
+
if p.dim() > 1:
|
146 |
+
nn.init.xavier_uniform_(p)
|
147 |
+
|
148 |
+
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
149 |
+
return tensor if pos is None else tensor + pos
|
150 |
+
|
151 |
+
def forward_post(self, tgt):
|
152 |
+
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
153 |
+
tgt = tgt + self.dropout(tgt2)
|
154 |
+
tgt = self.norm(tgt)
|
155 |
+
return tgt
|
156 |
+
|
157 |
+
def forward_pre(self, tgt):
|
158 |
+
tgt2 = self.norm(tgt)
|
159 |
+
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
160 |
+
tgt = tgt + self.dropout(tgt2)
|
161 |
+
return tgt
|
162 |
+
|
163 |
+
def forward(self, tgt):
|
164 |
+
if self.normalize_before:
|
165 |
+
return self.forward_pre(tgt)
|
166 |
+
return self.forward_post(tgt)
|
167 |
+
|
168 |
+
|
169 |
+
def _get_activation_fn(activation):
|
170 |
+
"""Return an activation function given a string"""
|
171 |
+
if activation == "relu":
|
172 |
+
return F.relu
|
173 |
+
if activation == "gelu":
|
174 |
+
return F.gelu
|
175 |
+
if activation == "glu":
|
176 |
+
return F.glu
|
177 |
+
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
178 |
+
|
179 |
+
|
180 |
+
class MLP(nn.Module):
|
181 |
+
""" Very simple multi-layer perceptron (also called FFN)"""
|
182 |
+
|
183 |
+
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
|
184 |
+
super().__init__()
|
185 |
+
self.num_layers = num_layers
|
186 |
+
h = [hidden_dim] * (num_layers - 1)
|
187 |
+
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
188 |
+
|
189 |
+
def forward(self, x):
|
190 |
+
for i, layer in enumerate(self.layers):
|
191 |
+
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
192 |
+
return x
|
basicsr/archs/ddcolor_arch_utils/unet.py
ADDED
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enum import Enum
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.nn import functional as F
|
5 |
+
import collections
|
6 |
+
|
7 |
+
|
8 |
+
NormType = Enum('NormType', 'Batch BatchZero Weight Spectral')
|
9 |
+
|
10 |
+
|
11 |
+
class Hook:
|
12 |
+
feature = None
|
13 |
+
|
14 |
+
def __init__(self, module):
|
15 |
+
self.hook = module.register_forward_hook(self.hook_fn)
|
16 |
+
|
17 |
+
def hook_fn(self, module, input, output):
|
18 |
+
if isinstance(output, torch.Tensor):
|
19 |
+
self.feature = output
|
20 |
+
elif isinstance(output, collections.OrderedDict):
|
21 |
+
self.feature = output['out']
|
22 |
+
|
23 |
+
def remove(self):
|
24 |
+
self.hook.remove()
|
25 |
+
|
26 |
+
|
27 |
+
class SelfAttention(nn.Module):
|
28 |
+
"Self attention layer for nd."
|
29 |
+
|
30 |
+
def __init__(self, n_channels: int):
|
31 |
+
super().__init__()
|
32 |
+
self.query = conv1d(n_channels, n_channels // 8)
|
33 |
+
self.key = conv1d(n_channels, n_channels // 8)
|
34 |
+
self.value = conv1d(n_channels, n_channels)
|
35 |
+
self.gamma = nn.Parameter(torch.tensor([0.]))
|
36 |
+
|
37 |
+
def forward(self, x):
|
38 |
+
#Notation from https://arxiv.org/pdf/1805.08318.pdf
|
39 |
+
size = x.size()
|
40 |
+
x = x.view(*size[:2], -1)
|
41 |
+
f, g, h = self.query(x), self.key(x), self.value(x)
|
42 |
+
beta = F.softmax(torch.bmm(f.permute(0, 2, 1).contiguous(), g), dim=1)
|
43 |
+
o = self.gamma * torch.bmm(h, beta) + x
|
44 |
+
return o.view(*size).contiguous()
|
45 |
+
|
46 |
+
|
47 |
+
def batchnorm_2d(nf: int, norm_type: NormType = NormType.Batch):
|
48 |
+
"A batchnorm2d layer with `nf` features initialized depending on `norm_type`."
|
49 |
+
bn = nn.BatchNorm2d(nf)
|
50 |
+
with torch.no_grad():
|
51 |
+
bn.bias.fill_(1e-3)
|
52 |
+
bn.weight.fill_(0. if norm_type == NormType.BatchZero else 1.)
|
53 |
+
return bn
|
54 |
+
|
55 |
+
|
56 |
+
def init_default(m: nn.Module, func=nn.init.kaiming_normal_) -> None:
|
57 |
+
"Initialize `m` weights with `func` and set `bias` to 0."
|
58 |
+
if func:
|
59 |
+
if hasattr(m, 'weight'): func(m.weight)
|
60 |
+
if hasattr(m, 'bias') and hasattr(m.bias, 'data'): m.bias.data.fill_(0.)
|
61 |
+
return m
|
62 |
+
|
63 |
+
|
64 |
+
def icnr(x, scale=2, init=nn.init.kaiming_normal_):
|
65 |
+
"ICNR init of `x`, with `scale` and `init` function."
|
66 |
+
ni, nf, h, w = x.shape
|
67 |
+
ni2 = int(ni / (scale**2))
|
68 |
+
k = init(torch.zeros([ni2, nf, h, w])).transpose(0, 1)
|
69 |
+
k = k.contiguous().view(ni2, nf, -1)
|
70 |
+
k = k.repeat(1, 1, scale**2)
|
71 |
+
k = k.contiguous().view([nf, ni, h, w]).transpose(0, 1)
|
72 |
+
x.data.copy_(k)
|
73 |
+
|
74 |
+
|
75 |
+
def conv1d(ni: int, no: int, ks: int = 1, stride: int = 1, padding: int = 0, bias: bool = False):
|
76 |
+
"Create and initialize a `nn.Conv1d` layer with spectral normalization."
|
77 |
+
conv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias)
|
78 |
+
nn.init.kaiming_normal_(conv.weight)
|
79 |
+
if bias: conv.bias.data.zero_()
|
80 |
+
return nn.utils.spectral_norm(conv)
|
81 |
+
|
82 |
+
|
83 |
+
def custom_conv_layer(
|
84 |
+
ni: int,
|
85 |
+
nf: int,
|
86 |
+
ks: int = 3,
|
87 |
+
stride: int = 1,
|
88 |
+
padding: int = None,
|
89 |
+
bias: bool = None,
|
90 |
+
is_1d: bool = False,
|
91 |
+
norm_type=NormType.Batch,
|
92 |
+
use_activ: bool = True,
|
93 |
+
transpose: bool = False,
|
94 |
+
init=nn.init.kaiming_normal_,
|
95 |
+
self_attention: bool = False,
|
96 |
+
extra_bn: bool = False,
|
97 |
+
):
|
98 |
+
"Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and batchnorm (if `bn`) layers."
|
99 |
+
if padding is None:
|
100 |
+
padding = (ks - 1) // 2 if not transpose else 0
|
101 |
+
bn = norm_type in (NormType.Batch, NormType.BatchZero) or extra_bn == True
|
102 |
+
if bias is None:
|
103 |
+
bias = not bn
|
104 |
+
conv_func = nn.ConvTranspose2d if transpose else nn.Conv1d if is_1d else nn.Conv2d
|
105 |
+
conv = init_default(
|
106 |
+
conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding),
|
107 |
+
init,
|
108 |
+
)
|
109 |
+
|
110 |
+
if norm_type == NormType.Weight:
|
111 |
+
conv = nn.utils.weight_norm(conv)
|
112 |
+
elif norm_type == NormType.Spectral:
|
113 |
+
conv = nn.utils.spectral_norm(conv)
|
114 |
+
layers = [conv]
|
115 |
+
if use_activ:
|
116 |
+
layers.append(nn.ReLU(True))
|
117 |
+
if bn:
|
118 |
+
layers.append((nn.BatchNorm1d if is_1d else nn.BatchNorm2d)(nf))
|
119 |
+
if self_attention:
|
120 |
+
layers.append(SelfAttention(nf))
|
121 |
+
return nn.Sequential(*layers)
|
122 |
+
|
123 |
+
|
124 |
+
def conv_layer(ni: int,
|
125 |
+
nf: int,
|
126 |
+
ks: int = 3,
|
127 |
+
stride: int = 1,
|
128 |
+
padding: int = None,
|
129 |
+
bias: bool = None,
|
130 |
+
is_1d: bool = False,
|
131 |
+
norm_type=NormType.Batch,
|
132 |
+
use_activ: bool = True,
|
133 |
+
transpose: bool = False,
|
134 |
+
init=nn.init.kaiming_normal_,
|
135 |
+
self_attention: bool = False):
|
136 |
+
"Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and batchnorm (if `bn`) layers."
|
137 |
+
if padding is None: padding = (ks - 1) // 2 if not transpose else 0
|
138 |
+
bn = norm_type in (NormType.Batch, NormType.BatchZero)
|
139 |
+
if bias is None: bias = not bn
|
140 |
+
conv_func = nn.ConvTranspose2d if transpose else nn.Conv1d if is_1d else nn.Conv2d
|
141 |
+
conv = init_default(conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding), init)
|
142 |
+
if norm_type == NormType.Weight: conv = nn.utils.weight_norm(conv)
|
143 |
+
elif norm_type == NormType.Spectral: conv = nn.utils.spectral_norm(conv)
|
144 |
+
layers = [conv]
|
145 |
+
if use_activ: layers.append(nn.ReLU(True))
|
146 |
+
if bn: layers.append((nn.BatchNorm1d if is_1d else nn.BatchNorm2d)(nf))
|
147 |
+
if self_attention: layers.append(SelfAttention(nf))
|
148 |
+
return nn.Sequential(*layers)
|
149 |
+
|
150 |
+
|
151 |
+
def _conv(ni: int, nf: int, ks: int = 3, stride: int = 1, **kwargs):
|
152 |
+
return conv_layer(ni, nf, ks=ks, stride=stride, norm_type=NormType.Spectral, **kwargs)
|
153 |
+
|
154 |
+
|
155 |
+
class CustomPixelShuffle_ICNR(nn.Module):
|
156 |
+
"Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, `icnr` init, and `weight_norm`."
|
157 |
+
|
158 |
+
def __init__(self,
|
159 |
+
ni: int,
|
160 |
+
nf: int = None,
|
161 |
+
scale: int = 2,
|
162 |
+
blur: bool = True,
|
163 |
+
norm_type=NormType.Spectral,
|
164 |
+
extra_bn=False):
|
165 |
+
super().__init__()
|
166 |
+
self.conv = custom_conv_layer(
|
167 |
+
ni, nf * (scale**2), ks=1, use_activ=False, norm_type=norm_type, extra_bn=extra_bn)
|
168 |
+
icnr(self.conv[0].weight)
|
169 |
+
self.shuf = nn.PixelShuffle(scale)
|
170 |
+
self.do_blur = blur
|
171 |
+
# Blurring over (h*w) kernel
|
172 |
+
# "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts"
|
173 |
+
# - https://arxiv.org/abs/1806.02658
|
174 |
+
self.pad = nn.ReplicationPad2d((1, 0, 1, 0))
|
175 |
+
self.blur = nn.AvgPool2d(2, stride=1)
|
176 |
+
self.relu = nn.ReLU(True)
|
177 |
+
|
178 |
+
def forward(self, x):
|
179 |
+
x = self.shuf(self.relu(self.conv(x)))
|
180 |
+
return self.blur(self.pad(x)) if self.do_blur else x
|
181 |
+
|
182 |
+
|
183 |
+
class UnetBlockWide(nn.Module):
|
184 |
+
"A quasi-UNet block, using `PixelShuffle_ICNR upsampling`."
|
185 |
+
|
186 |
+
def __init__(self,
|
187 |
+
up_in_c: int,
|
188 |
+
x_in_c: int,
|
189 |
+
n_out: int,
|
190 |
+
hook,
|
191 |
+
blur: bool = False,
|
192 |
+
self_attention: bool = False,
|
193 |
+
norm_type=NormType.Spectral):
|
194 |
+
super().__init__()
|
195 |
+
|
196 |
+
self.hook = hook
|
197 |
+
up_out = n_out
|
198 |
+
self.shuf = CustomPixelShuffle_ICNR(up_in_c, up_out, blur=blur, norm_type=norm_type, extra_bn=True)
|
199 |
+
self.bn = batchnorm_2d(x_in_c)
|
200 |
+
ni = up_out + x_in_c
|
201 |
+
self.conv = custom_conv_layer(ni, n_out, norm_type=norm_type, self_attention=self_attention, extra_bn=True)
|
202 |
+
self.relu = nn.ReLU()
|
203 |
+
|
204 |
+
def forward(self, up_in):
|
205 |
+
s = self.hook.feature
|
206 |
+
up_out = self.shuf(up_in)
|
207 |
+
cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
|
208 |
+
return self.conv(cat_x)
|
basicsr/archs/ddcolor_arch_utils/util.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from skimage import color
|
4 |
+
|
5 |
+
|
6 |
+
def rgb2lab(img_rgb):
|
7 |
+
img_lab = color.rgb2lab(img_rgb)
|
8 |
+
return img_lab[:, :, :1], img_lab[:, :, 1:]
|
9 |
+
|
10 |
+
|
11 |
+
def tensor_lab2rgb(labs, illuminant="D65", observer="2"):
|
12 |
+
"""
|
13 |
+
Args:
|
14 |
+
lab : (B, C, H, W)
|
15 |
+
Returns:
|
16 |
+
tuple : (B, C, H, W)
|
17 |
+
"""
|
18 |
+
illuminants = \
|
19 |
+
{"A": {'2': (1.098466069456375, 1, 0.3558228003436005),
|
20 |
+
'10': (1.111420406956693, 1, 0.3519978321919493)},
|
21 |
+
"D50": {'2': (0.9642119944211994, 1, 0.8251882845188288),
|
22 |
+
'10': (0.9672062750333777, 1, 0.8142801513128616)},
|
23 |
+
"D55": {'2': (0.956797052643698, 1, 0.9214805860173273),
|
24 |
+
'10': (0.9579665682254781, 1, 0.9092525159847462)},
|
25 |
+
"D65": {'2': (0.95047, 1., 1.08883), # This was: `lab_ref_white`
|
26 |
+
'10': (0.94809667673716, 1, 1.0730513595166162)},
|
27 |
+
"D75": {'2': (0.9497220898840717, 1, 1.226393520724154),
|
28 |
+
'10': (0.9441713925645873, 1, 1.2064272211720228)},
|
29 |
+
"E": {'2': (1.0, 1.0, 1.0),
|
30 |
+
'10': (1.0, 1.0, 1.0)}}
|
31 |
+
xyz_from_rgb = np.array([[0.412453, 0.357580, 0.180423], [0.212671, 0.715160, 0.072169],
|
32 |
+
[0.019334, 0.119193, 0.950227]])
|
33 |
+
|
34 |
+
rgb_from_xyz = np.array([[3.240481340, -0.96925495, 0.055646640], [-1.53715152, 1.875990000, -0.20404134],
|
35 |
+
[-0.49853633, 0.041555930, 1.057311070]])
|
36 |
+
B, C, H, W = labs.shape
|
37 |
+
arrs = labs.permute((0, 2, 3, 1)).contiguous() # (B, 3, H, W) -> (B, H, W, 3)
|
38 |
+
L, a, b = arrs[:, :, :, 0:1], arrs[:, :, :, 1:2], arrs[:, :, :, 2:]
|
39 |
+
y = (L + 16.) / 116.
|
40 |
+
x = (a / 500.) + y
|
41 |
+
z = y - (b / 200.)
|
42 |
+
invalid = z.data < 0
|
43 |
+
z[invalid] = 0
|
44 |
+
xyz = torch.cat([x, y, z], dim=3)
|
45 |
+
mask = xyz.data > 0.2068966
|
46 |
+
mask_xyz = xyz.clone()
|
47 |
+
mask_xyz[mask] = torch.pow(xyz[mask], 3.0)
|
48 |
+
mask_xyz[~mask] = (xyz[~mask] - 16.0 / 116.) / 7.787
|
49 |
+
xyz_ref_white = illuminants[illuminant][observer]
|
50 |
+
for i in range(C):
|
51 |
+
mask_xyz[:, :, :, i] = mask_xyz[:, :, :, i] * xyz_ref_white[i]
|
52 |
+
|
53 |
+
rgb_trans = torch.mm(mask_xyz.view(-1, 3), torch.from_numpy(rgb_from_xyz).type_as(xyz)).view(B, H, W, C)
|
54 |
+
rgb = rgb_trans.permute((0, 3, 1, 2)).contiguous()
|
55 |
+
mask = rgb.data > 0.0031308
|
56 |
+
mask_rgb = rgb.clone()
|
57 |
+
mask_rgb[mask] = 1.055 * torch.pow(rgb[mask], 1 / 2.4) - 0.055
|
58 |
+
mask_rgb[~mask] = rgb[~mask] * 12.92
|
59 |
+
neg_mask = mask_rgb.data < 0
|
60 |
+
large_mask = mask_rgb.data > 1
|
61 |
+
mask_rgb[neg_mask] = 0
|
62 |
+
mask_rgb[large_mask] = 1
|
63 |
+
return mask_rgb
|
basicsr/archs/discriminator_arch.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torchvision import models
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from basicsr.archs.ddcolor_arch_utils.unet import _conv
|
7 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
8 |
+
|
9 |
+
|
10 |
+
@ARCH_REGISTRY.register()
|
11 |
+
class DynamicUNetDiscriminator(nn.Module):
|
12 |
+
|
13 |
+
def __init__(self, n_channels: int = 3, nf: int = 256, n_blocks: int = 3):
|
14 |
+
super().__init__()
|
15 |
+
layers = [_conv(n_channels, nf, ks=4, stride=2)]
|
16 |
+
for i in range(n_blocks):
|
17 |
+
layers += [
|
18 |
+
_conv(nf, nf, ks=3, stride=1),
|
19 |
+
_conv(nf, nf * 2, ks=4, stride=2, self_attention=(i == 0)),
|
20 |
+
]
|
21 |
+
nf *= 2
|
22 |
+
layers += [_conv(nf, nf, ks=3, stride=1), _conv(nf, 1, ks=4, bias=False, padding=0, use_activ=False)]
|
23 |
+
self.layers = nn.Sequential(*layers)
|
24 |
+
|
25 |
+
def forward(self, x):
|
26 |
+
out = self.layers(x)
|
27 |
+
out = out.view(out.size(0), -1)
|
28 |
+
return out
|
basicsr/archs/vgg_arch.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
from collections import OrderedDict
|
4 |
+
from torch import nn as nn
|
5 |
+
from torchvision.models import vgg as vgg
|
6 |
+
|
7 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
8 |
+
|
9 |
+
VGG_PRETRAIN_PATH = {
|
10 |
+
'vgg19': './pretrain/vgg19-dcbb9e9d.pth',
|
11 |
+
'vgg16_bn': './pretrain/vgg16_bn-6c64b313.pth'
|
12 |
+
}
|
13 |
+
|
14 |
+
NAMES = {
|
15 |
+
'vgg11': [
|
16 |
+
'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
|
17 |
+
'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
|
18 |
+
'pool5'
|
19 |
+
],
|
20 |
+
'vgg13': [
|
21 |
+
'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
|
22 |
+
'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
|
23 |
+
'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
|
24 |
+
],
|
25 |
+
'vgg16': [
|
26 |
+
'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
|
27 |
+
'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
|
28 |
+
'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
|
29 |
+
'pool5'
|
30 |
+
],
|
31 |
+
'vgg19': [
|
32 |
+
'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
|
33 |
+
'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
|
34 |
+
'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
|
35 |
+
'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
|
36 |
+
]
|
37 |
+
}
|
38 |
+
|
39 |
+
|
40 |
+
def insert_bn(names):
|
41 |
+
"""Insert bn layer after each conv.
|
42 |
+
|
43 |
+
Args:
|
44 |
+
names (list): The list of layer names.
|
45 |
+
|
46 |
+
Returns:
|
47 |
+
list: The list of layer names with bn layers.
|
48 |
+
"""
|
49 |
+
names_bn = []
|
50 |
+
for name in names:
|
51 |
+
names_bn.append(name)
|
52 |
+
if 'conv' in name:
|
53 |
+
position = name.replace('conv', '')
|
54 |
+
names_bn.append('bn' + position)
|
55 |
+
return names_bn
|
56 |
+
|
57 |
+
|
58 |
+
@ARCH_REGISTRY.register()
|
59 |
+
class VGGFeatureExtractor(nn.Module):
|
60 |
+
"""VGG network for feature extraction.
|
61 |
+
|
62 |
+
In this implementation, we allow users to choose whether use normalization
|
63 |
+
in the input feature and the type of vgg network. Note that the pretrained
|
64 |
+
path must fit the vgg type.
|
65 |
+
|
66 |
+
Args:
|
67 |
+
layer_name_list (list[str]): Forward function returns the corresponding
|
68 |
+
features according to the layer_name_list.
|
69 |
+
Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
|
70 |
+
vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
|
71 |
+
use_input_norm (bool): If True, normalize the input image. Importantly,
|
72 |
+
the input feature must in the range [0, 1]. Default: True.
|
73 |
+
range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
|
74 |
+
Default: False.
|
75 |
+
requires_grad (bool): If true, the parameters of VGG network will be
|
76 |
+
optimized. Default: False.
|
77 |
+
remove_pooling (bool): If true, the max pooling operations in VGG net
|
78 |
+
will be removed. Default: False.
|
79 |
+
pooling_stride (int): The stride of max pooling operation. Default: 2.
|
80 |
+
"""
|
81 |
+
|
82 |
+
def __init__(self,
|
83 |
+
layer_name_list,
|
84 |
+
vgg_type='vgg19',
|
85 |
+
use_input_norm=True,
|
86 |
+
range_norm=False,
|
87 |
+
requires_grad=False,
|
88 |
+
remove_pooling=False,
|
89 |
+
pooling_stride=2):
|
90 |
+
super(VGGFeatureExtractor, self).__init__()
|
91 |
+
|
92 |
+
self.layer_name_list = layer_name_list
|
93 |
+
self.use_input_norm = use_input_norm
|
94 |
+
self.range_norm = range_norm
|
95 |
+
|
96 |
+
self.names = NAMES[vgg_type.replace('_bn', '')]
|
97 |
+
if 'bn' in vgg_type:
|
98 |
+
self.names = insert_bn(self.names)
|
99 |
+
|
100 |
+
# only borrow layers that will be used to avoid unused params
|
101 |
+
max_idx = 0
|
102 |
+
for v in layer_name_list:
|
103 |
+
idx = self.names.index(v)
|
104 |
+
if idx > max_idx:
|
105 |
+
max_idx = idx
|
106 |
+
|
107 |
+
if os.path.exists(VGG_PRETRAIN_PATH[vgg_type]):
|
108 |
+
vgg_net = getattr(vgg, vgg_type)(pretrained=False)
|
109 |
+
state_dict = torch.load(VGG_PRETRAIN_PATH[vgg_type], map_location=lambda storage, loc: storage)
|
110 |
+
vgg_net.load_state_dict(state_dict)
|
111 |
+
else:
|
112 |
+
vgg_net = getattr(vgg, vgg_type)(pretrained=True)
|
113 |
+
|
114 |
+
features = vgg_net.features[:max_idx + 1]
|
115 |
+
|
116 |
+
modified_net = OrderedDict()
|
117 |
+
for k, v in zip(self.names, features):
|
118 |
+
if 'pool' in k:
|
119 |
+
# if remove_pooling is true, pooling operation will be removed
|
120 |
+
if remove_pooling:
|
121 |
+
continue
|
122 |
+
else:
|
123 |
+
# in some cases, we may want to change the default stride
|
124 |
+
modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
|
125 |
+
else:
|
126 |
+
modified_net[k] = v
|
127 |
+
|
128 |
+
self.vgg_net = nn.Sequential(modified_net)
|
129 |
+
|
130 |
+
if not requires_grad:
|
131 |
+
self.vgg_net.eval()
|
132 |
+
for param in self.parameters():
|
133 |
+
param.requires_grad = False
|
134 |
+
else:
|
135 |
+
self.vgg_net.train()
|
136 |
+
for param in self.parameters():
|
137 |
+
param.requires_grad = True
|
138 |
+
|
139 |
+
if self.use_input_norm:
|
140 |
+
# the mean is for image with range [0, 1]
|
141 |
+
self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
|
142 |
+
# the std is for image with range [0, 1]
|
143 |
+
self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
144 |
+
|
145 |
+
def forward(self, x):
|
146 |
+
"""Forward function.
|
147 |
+
|
148 |
+
Args:
|
149 |
+
x (Tensor): Input tensor with shape (n, c, h, w).
|
150 |
+
|
151 |
+
Returns:
|
152 |
+
Tensor: Forward results.
|
153 |
+
"""
|
154 |
+
if self.range_norm:
|
155 |
+
x = (x + 1) / 2
|
156 |
+
if self.use_input_norm:
|
157 |
+
x = (x - self.mean) / self.std
|
158 |
+
|
159 |
+
output = {}
|
160 |
+
for key, layer in self.vgg_net._modules.items():
|
161 |
+
x = layer(x)
|
162 |
+
if key in self.layer_name_list:
|
163 |
+
output[key] = x.clone()
|
164 |
+
|
165 |
+
return output
|
basicsr/data/__init__.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
import numpy as np
|
3 |
+
import random
|
4 |
+
import torch
|
5 |
+
import torch.utils.data
|
6 |
+
from copy import deepcopy
|
7 |
+
from functools import partial
|
8 |
+
from os import path as osp
|
9 |
+
|
10 |
+
from basicsr.data.prefetch_dataloader import PrefetchDataLoader
|
11 |
+
from basicsr.utils import get_root_logger, scandir
|
12 |
+
from basicsr.utils.dist_util import get_dist_info
|
13 |
+
from basicsr.utils.registry import DATASET_REGISTRY
|
14 |
+
|
15 |
+
__all__ = ['build_dataset', 'build_dataloader']
|
16 |
+
|
17 |
+
# automatically scan and import dataset modules for registry
|
18 |
+
# scan all the files under the data folder with '_dataset' in file names
|
19 |
+
data_folder = osp.dirname(osp.abspath(__file__))
|
20 |
+
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
|
21 |
+
# import all the dataset modules
|
22 |
+
_dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]
|
23 |
+
|
24 |
+
|
25 |
+
def build_dataset(dataset_opt):
|
26 |
+
"""Build dataset from options.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
dataset_opt (dict): Configuration for dataset. It must contain:
|
30 |
+
name (str): Dataset name.
|
31 |
+
type (str): Dataset type.
|
32 |
+
"""
|
33 |
+
dataset_opt = deepcopy(dataset_opt)
|
34 |
+
dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
|
35 |
+
logger = get_root_logger()
|
36 |
+
logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} ' 'is built.')
|
37 |
+
return dataset
|
38 |
+
|
39 |
+
|
40 |
+
def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
|
41 |
+
"""Build dataloader.
|
42 |
+
|
43 |
+
Args:
|
44 |
+
dataset (torch.utils.data.Dataset): Dataset.
|
45 |
+
dataset_opt (dict): Dataset options. It contains the following keys:
|
46 |
+
phase (str): 'train' or 'val'.
|
47 |
+
num_worker_per_gpu (int): Number of workers for each GPU.
|
48 |
+
batch_size_per_gpu (int): Training batch size for each GPU.
|
49 |
+
num_gpu (int): Number of GPUs. Used only in the train phase.
|
50 |
+
Default: 1.
|
51 |
+
dist (bool): Whether in distributed training. Used only in the train
|
52 |
+
phase. Default: False.
|
53 |
+
sampler (torch.utils.data.sampler): Data sampler. Default: None.
|
54 |
+
seed (int | None): Seed. Default: None
|
55 |
+
"""
|
56 |
+
phase = dataset_opt['phase']
|
57 |
+
rank, _ = get_dist_info()
|
58 |
+
if phase == 'train':
|
59 |
+
if dist: # distributed training
|
60 |
+
batch_size = dataset_opt['batch_size_per_gpu']
|
61 |
+
num_workers = dataset_opt['num_worker_per_gpu']
|
62 |
+
else: # non-distributed training
|
63 |
+
multiplier = 1 if num_gpu == 0 else num_gpu
|
64 |
+
batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
|
65 |
+
num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
|
66 |
+
dataloader_args = dict(
|
67 |
+
dataset=dataset,
|
68 |
+
batch_size=batch_size,
|
69 |
+
shuffle=False,
|
70 |
+
num_workers=num_workers,
|
71 |
+
sampler=sampler,
|
72 |
+
drop_last=True)
|
73 |
+
if sampler is None:
|
74 |
+
dataloader_args['shuffle'] = True
|
75 |
+
dataloader_args['worker_init_fn'] = partial(
|
76 |
+
worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
|
77 |
+
elif phase in ['val', 'test']: # validation
|
78 |
+
dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
|
79 |
+
else:
|
80 |
+
raise ValueError(f'Wrong dataset phase: {phase}. ' "Supported ones are 'train', 'val' and 'test'.")
|
81 |
+
|
82 |
+
dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
|
83 |
+
dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False)
|
84 |
+
|
85 |
+
prefetch_mode = dataset_opt.get('prefetch_mode')
|
86 |
+
if prefetch_mode == 'cpu': # CPUPrefetcher
|
87 |
+
num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
|
88 |
+
logger = get_root_logger()
|
89 |
+
logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}')
|
90 |
+
return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
|
91 |
+
else:
|
92 |
+
# prefetch_mode=None: Normal dataloader
|
93 |
+
# prefetch_mode='cuda': dataloader for CUDAPrefetcher
|
94 |
+
return torch.utils.data.DataLoader(**dataloader_args)
|
95 |
+
|
96 |
+
|
97 |
+
def worker_init_fn(worker_id, num_workers, rank, seed):
|
98 |
+
# Set the worker seed to num_workers * rank + worker_id + seed
|
99 |
+
worker_seed = num_workers * rank + worker_id + seed
|
100 |
+
np.random.seed(worker_seed)
|
101 |
+
random.seed(worker_seed)
|
basicsr/data/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (3.56 kB). View file
|
|