import torch
from torch import nn, Tensor
import torch.nn.functional as F
from typing import Tuple, Any, Optional, Union
from types import FunctionType
from itertools import repeat
from collections.abc import Iterable


def _log_api_usage_once(obj: Any) -> None:
    """
    Logs API usage (module and name) within an organization.
    In a large ecosystem it is often useful to track usage of the PyTorch and
    TorchVision APIs. This API provides functionality similar to the logging
    module in the Python stdlib. It can be used for debugging purposes to log
    which methods are used. By default it is inactive, unless the user
    manually subscribes a logger via the `SetAPIUsageLogger method <https://github.com/pytorch/pytorch/blob/eb3b9fe719b21fae13c7a7cf3253f970290a573e/c10/util/Logging.cpp#L114>`_.
    Note that it is triggered only once per API call within a process.
    It collects no data from open-source users, since it is a no-op by default.
    For more information, please refer to
    * PyTorch note: https://pytorch.org/docs/stable/notes/large_scale_deployments.html#api-usage-logging;
    * Logging policy: https://github.com/pytorch/vision/issues/5052;

    Args:
        obj (class instance or method): an object to extract info from.
    """
    module = obj.__module__
    if not module.startswith("torchvision"):
        module = f"torchvision.internal.{module}"
    name = obj.__class__.__name__
    if isinstance(obj, FunctionType):
        name = obj.__name__
    torch._C._log_api_usage_once(f"{module}.{name}")
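
# Usage sketch (illustrative, not part of the original file; _ExampleModel is
# hypothetical): torchvision models call _log_api_usage_once(self) at the top
# of __init__, so each model class is logged at most once per process.
class _ExampleModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        _log_api_usage_once(self)  # no-op unless an API usage logger is subscribed
        self.fc = nn.Linear(8, 8)
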
def _make_ntuple(x: Any, n: int) -> Tuple[Any, ...]:
    """
    Make an n-tuple from input x. If x is an iterable, it is converted to a
    tuple as-is. Otherwise, a tuple of length n is built by repeating x.
    Reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8

    Args:
        x (Any): input value
        n (int): length of the resulting tuple
    """
    if isinstance(x, Iterable):
        return tuple(x)
    return tuple(repeat(x, n))
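
# Usage sketch (illustrative, not part of the original file): normalize a
# scalar or per-dimension argument, as is commonly done for kernel sizes and
# paddings. Note that strings are also Iterable, so callers are expected to
# pass numbers or tuples of numbers.
def _example_make_ntuple() -> None:
    assert _make_ntuple(3, n=2) == (3, 3)       # scalar broadcast to a 2-tuple
    assert _make_ntuple((1, 2), n=2) == (1, 2)  # iterables pass through unchanged
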
def _init_weights(model: nn.Module) -> None:
    """Initialize conv, norm, and linear layers of the model in place."""
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            # Kaiming/He initialization, suited to ReLU non-linearities.
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)):
            # Identity affine transform: unit scale, zero shift.
            nn.init.constant_(m.weight, 1.0)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std=0.01)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
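
# Usage sketch (illustrative, not part of the original file): apply the
# initialization in place after building a model; the layer choices below are
# arbitrary placeholders.
def _example_init_weights() -> None:
    model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.Linear(16, 10))
    _init_weights(model)  # Kaiming for convs, ones/zeros for norms, N(0, 0.01) for linears
    assert torch.all(model[1].weight == 1.0)
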
def interpolate_pos_embed(
    pos_embed: Tensor,
    size: Optional[Union[int, Tuple[int, int]]] = None,
    scale_factor: Optional[float] = None,
) -> Tensor:
    """Bicubically resize a (C, H, W) positional-embedding grid to a new spatial size."""
    assert pos_embed.ndim == 3, f"Positional embedding should be a 3D tensor (C, H, W), but got {pos_embed.shape}."
    # F.interpolate expects a batch dimension, so add one around the call and remove it after.
    return F.interpolate(
        pos_embed.unsqueeze(0),
        size=size,
        scale_factor=scale_factor,
        mode="bicubic",
        align_corners=False,
        antialias=True,
    ).squeeze(0)
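
# Usage sketch (illustrative; the shapes are an assumption, not from the
# original source): resize a ViT-style positional-embedding grid, e.g. from a
# 14x14 training grid to 16x16 for a larger input resolution.
def _example_interpolate_pos_embed() -> None:
    pos_embed = torch.randn(768, 14, 14)                       # (C, H, W)
    resized = interpolate_pos_embed(pos_embed, size=(16, 16))  # bicubic, antialiased
    assert resized.shape == (768, 16, 16)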