from collections.abc import Iterable
from itertools import repeat
from types import FunctionType
from typing import Tuple, Any, Optional, Union

import torch
from torch import nn, Tensor
import torch.nn.functional as F


def _log_api_usage_once(obj: Any) -> None:
    """
    Logs API usage (module and name) within an organization.

    In a large ecosystem, it's often useful to track the PyTorch and
    TorchVision APIs usage. This API provides similar functionality to the
    logging module in the Python stdlib. It can be used for debugging
    purposes to log which methods are used, and by default it is inactive
    unless the user manually subscribes a logger via the
    ``SetAPIUsageLogger`` method.

    Please note it is triggered only once for the same API call within a
    process. It does not collect any data from open-source users since it is
    a no-op by default. For more information, please refer to

    * PyTorch note: https://pytorch.org/docs/stable/notes/large_scale_deployments.html#api-usage-logging;
    * Logging policy: https://github.com/pytorch/vision/issues/5052;

    Args:
        obj (class instance or method): an object to extract info from.
    """
    module = obj.__module__
    # Anything living outside the torchvision namespace is reported under an
    # "internal" prefix.
    if not module.startswith("torchvision"):
        module = f"torchvision.internal.{module}"
    # Instances are identified by their class name, plain functions by their
    # own name.
    name = obj.__class__.__name__
    if isinstance(obj, FunctionType):
        name = obj.__name__
    torch._C._log_api_usage_once(f"{module}.{name}")


def _make_ntuple(x: Any, n: int) -> Tuple[Any, ...]:
    """
    Make n-tuple from input x.

    If x is an iterable, it is simply converted to a tuple (its length is
    not checked against ``n``; note that strings are iterables too).
    Otherwise, a tuple of length n is built with every element equal to x.

    reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8

    Args:
        x (Any): input value
        n (int): length of the resulting tuple
    """
    if isinstance(x, Iterable):
        return tuple(x)
    return tuple(repeat(x, n))


def _init_weights(model: nn.Module) -> None:
    """
    Initialize, in place, the parameters of every supported submodule.

    * ``nn.Conv2d``: Kaiming-normal weights (``fan_out``, ``relu``), zero bias.
    * ``nn.BatchNorm2d`` / ``nn.GroupNorm`` / ``nn.LayerNorm``: weight set to
      one and bias set to zero, each only when the parameter exists.
    * ``nn.Linear``: weights drawn from a normal distribution with
      ``std=0.01``, zero bias.

    Modules of any other type are left untouched.

    Args:
        model (nn.Module): module whose submodules (as yielded by
            ``model.modules()``) are initialized.
    """
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.)
        elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)):
            # Normalization layers built without affine parameters have
            # ``weight`` (and ``bias``) set to None; skip them instead of
            # crashing inside ``nn.init.constant_``.
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std=0.01)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.)


def interpolate_pos_embed(
    pos_embed: Tensor,
    size: Optional[Union[int, Tuple[int, int]]] = None,
    scale_factor: Optional[float] = None,
) -> Tensor:
    """
    Resize a 3D positional embedding with antialiased bicubic interpolation.

    A leading batch dimension is added before calling ``F.interpolate`` and
    removed from the result, so the output is 3D as well.

    Args:
        pos_embed (Tensor): 3D positional embedding; the assertion message
            describes the layout as (C, H, W).
        size (int or Tuple[int, int], optional): forwarded to
            ``F.interpolate`` as ``size``.
        scale_factor (float, optional): forwarded to ``F.interpolate`` as
            ``scale_factor``.

    Returns:
        Tensor: the interpolated positional embedding.

    Raises:
        AssertionError: if ``pos_embed`` is not a 3D tensor.
    """
    assert len(pos_embed.shape) == 3, f"Positional embedding should be 3D tensor (C, H, W), but got {pos_embed.shape}."
    return F.interpolate(
        pos_embed.unsqueeze(0),
        size=size,
        scale_factor=scale_factor,
        mode="bicubic",
        align_corners=False,
        antialias=True,
    ).squeeze(0)