# Spaces: Running on Zero
# File size: 3,181 Bytes (a7dedf9)
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from typing import Tuple, Any, Optional, Union
from types import FunctionType
from itertools import repeat
from collections.abc import Iterable
def _log_api_usage_once(obj: Any) -> None:
    """
    Report a single API-usage event for ``obj`` through PyTorch's usage logger.

    The event key has the form ``"<module>.<name>"``:

    * ``<module>`` is ``obj.__module__``; when it does not start with
      ``"torchvision"`` it is prefixed with ``"torchvision.internal."``.
    * ``<name>`` is the function name for plain Python functions and the
      class name of ``obj`` for everything else.

    Tracking which PyTorch / TorchVision APIs are used is handy in large
    ecosystems and for debugging. The logger is inactive by default (a no-op
    that collects nothing from open-source users) unless a logger is
    subscribed manually via the `SetAPIUsageLogger method
    <https://github.com/pytorch/pytorch/blob/eb3b9fe719b21fae13c7a7cf3253f970290a573e/c10/util/Logging.cpp#L114>`_,
    and it fires only once per process for the same API call.

    Further reading:

    * PyTorch note: https://pytorch.org/docs/stable/notes/large_scale_deployments.html#api-usage-logging
    * Logging policy: https://github.com/pytorch/vision/issues/5052

    Args:
        obj (class instance or method): object whose module and name are logged.
    """
    # Read ``__module__`` first so an object lacking it fails before anything else.
    prefix = obj.__module__
    if not prefix.startswith("torchvision"):
        prefix = f"torchvision.internal.{prefix}"

    # Plain functions are identified by their own name, other objects by their class.
    label = obj.__name__ if isinstance(obj, FunctionType) else obj.__class__.__name__

    torch._C._log_api_usage_once(f"{prefix}.{label}")
def _make_ntuple(x: Any, n: int) -> Tuple[Any, ...]:
    """
    Normalise ``x`` into a tuple.

    An iterable ``x`` is converted to a tuple as-is (its length is not
    checked against ``n``). Any other value is repeated ``n`` times.
    reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8

    Args:
        x (Any): a single value or an iterable of values
        n (int): number of repetitions used when ``x`` is not iterable
    """
    if not isinstance(x, Iterable):
        # Scalar-like input: broadcast it to the requested length.
        return (x,) * n
    return tuple(x)
def _init_weights(model: nn.Module) -> None:
    """
    Initialise, in place, the parameters of the supported layers in ``model``.

    Every sub-module returned by ``model.modules()`` is visited:

    * ``nn.Conv2d``: Kaiming-normal weights (``mode="fan_out"``,
      ``nonlinearity="relu"``), bias set to 0.
    * ``nn.BatchNorm2d`` / ``nn.GroupNorm`` / ``nn.LayerNorm``: weight set to 1,
      bias set to 0.
    * ``nn.Linear``: weights drawn from a normal distribution with std 0.01,
      bias set to 0.

    Modules of any other type are left untouched. A parameter that is ``None``
    (e.g. a layer built with ``bias=False``, or a normalisation layer built
    without affine parameters) is skipped.

    Args:
        model (nn.Module): module whose sub-modules are initialised.
    """
    norm_layers = (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)

    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.)
        elif isinstance(m, norm_layers):
            # Norm layers created with affine=False / elementwise_affine=False
            # have ``weight`` set to None; guard it the same way as the bias.
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std=0.01)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.)
def interpolate_pos_embed(
    pos_embed: Tensor,
    size: Optional[Union[int, Tuple[int, int]]] = None,
    scale_factor: Optional[float] = None,
) -> Tensor:
    """
    Resize a 3D positional embedding with antialiased bicubic interpolation.

    A leading dimension is added temporarily so the tensor can be passed to
    ``F.interpolate`` and is removed again before returning.

    Args:
        pos_embed (Tensor): 3D tensor, described as (C, H, W) by the check below.
        size (int or Tuple[int, int], optional): forwarded to ``F.interpolate``
            as ``size``.
        scale_factor (float, optional): forwarded to ``F.interpolate`` as
            ``scale_factor``.

    Returns:
        Tensor: the interpolated embedding, again 3D.
    """
    assert pos_embed.ndim == 3, f"Positional embedding should be 3D tensor (C, H, W), but got {pos_embed.shape}."

    # Indexing with None prepends the extra leading dimension.
    batched = pos_embed[None]
    resized = F.interpolate(
        batched,
        size=size,
        scale_factor=scale_factor,
        mode="bicubic",
        align_corners=False,
        antialias=True,
    )
    return resized.squeeze(0)
|