#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@File    :   utils.py
@Time    :   2023/8/8 4:26 PM
@Author  :   waytan
@Contact :   waytan@tencent.com
@License :   (C)Copyright 2023, Tencent
@Desc    :   utils
"""
from contextlib import contextmanager
import math
import os
import tempfile
import typing as tp
import json
import subprocess

import torch
from torch.nn import functional as F
from torch.utils.data import Subset


def unfold(a, kernel_size, stride):
    """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
    with K the kernel size, by extracting frames with the given stride.

    This will pad the input on the right so that `F = ceil(T / stride)`.

    The result is a strided view over the padded tensor, so frames share
    memory whenever `stride < kernel_size`.

    see https://github.com/pytorch/pytorch/issues/60466
    """
    *shape, length = a.shape
    n_frames = math.ceil(length / stride)
    # Total length required for the last frame to be complete.
    tgt_length = (n_frames - 1) * stride + kernel_size
    a = F.pad(a, (0, tgt_length - length))
    strides = list(a.stride())
    assert strides[-1] == 1, 'data should be contiguous'
    # Split the last dimension in two: moving one step along the frame axis
    # advances `stride` elements, one step inside a frame advances 1.
    strides = strides[:-1] + [stride, 1]
    return a.as_strided([*shape, n_frames, kernel_size], strides)


def center_trim(tensor: torch.Tensor, reference: tp.Union[torch.Tensor, int]):
    """
    Center trim `tensor` with respect to `reference`, along the last dimension.
    `reference` can also be a number, representing the length to trim to.
    If the size difference != 0 mod 2, the extra sample is removed on the right side.

    Raises:
        ValueError: if `tensor` is shorter than `reference` on the last dimension.
    """
    ref_size: int
    if isinstance(reference, torch.Tensor):
        ref_size = reference.size(-1)
    else:
        ref_size = reference
    delta = tensor.size(-1) - ref_size
    if delta < 0:
        raise ValueError("tensor must be larger than reference. " f"Delta is {delta}.")
    if delta:
        # `delta // 2` samples go from the left, the remainder from the right.
        # Guarded by `if delta` because a `-0` stop index would yield an
        # empty slice instead of the full tensor.
        tensor = tensor[..., delta // 2:-(delta - delta // 2)]
    return tensor


def pull_metric(history: tp.List[dict], name: str):
    """Extract one (possibly nested) metric from every entry of `history`.

    `name` is a dot separated path: `"a.b"` reads `entry["a"]["b"]` for
    each entry. Returns the list of values, in the order of `history`.
    A missing key raises whatever the indexed object raises (`KeyError`
    for a dict).
    """
    # The path does not depend on the entry, so split it only once.
    parts = name.split(".")
    out = []
    for metrics in history:
        metric = metrics
        for part in parts:
            metric = metric[part]
        out.append(metric)
    return out


def sizeof_fmt(num: float, suffix: str = 'B'):
    """
    Given `num` bytes, return human readable size.
    Taken from https://stackoverflow.com/a/1094933
    """
    for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
        if abs(num) < 1024.0:
            return "%3.1f%s%s" % (num, unit, suffix)
        num /= 1024.0
    # Anything that is still >= 1024 Zi is reported in Yi.
    return "%.1f%s%s" % (num, 'Yi', suffix)


@contextmanager
def temp_filenames(count: int, delete=True):
    """Context manager yielding a list of `count` temporary file names.

    The files are created (empty) on entry. If `delete` is true they are
    removed on exit, including when the `with` body raises; otherwise they
    are left on disk for the caller to clean up.
    """
    names = []
    try:
        for _ in range(count):
            # Only the name is needed: close the handle right away instead of
            # leaving it open until garbage collection.
            with tempfile.NamedTemporaryFile(delete=False) as handle:
                names.append(handle.name)
        yield names
    finally:
        if delete:
            for name in names:
                try:
                    os.unlink(name)
                except FileNotFoundError:
                    # The caller already removed or moved the file; keep
                    # going so the remaining files are still cleaned up.
                    pass


def random_subset(dataset, max_samples: int, seed: int = 42):
    """Return a reproducible random subset with at most `max_samples` items.

    When `max_samples >= len(dataset)` the dataset itself is returned.
    Otherwise a `Subset` is built from the first `max_samples` indices of
    a permutation drawn from a generator seeded with `seed`, so the same
    seed always selects the same items.
    """
    if max_samples >= len(dataset):
        return dataset

    generator = torch.Generator().manual_seed(seed)
    perm = torch.randperm(len(dataset), generator=generator)
    return Subset(dataset, perm[:max_samples].tolist())


class DummyPoolExecutor:
    """In-process stand-in exposing the `submit` / `result` / context
    manager calls of a pool executor, without any worker.
    """

    class DummyResult:
        """Deferred call: stores the function and its arguments."""

        def __init__(self, func, *args, **kwargs):
            self.func = func
            self.args = args
            self.kwargs = kwargs

        def result(self):
            """Call the stored function now and return its value.

            Nothing is cached: each call to `result()` runs the function
            again.
            """
            return self.func(*self.args, **self.kwargs)

    def __init__(self, workers=0):
        # `workers` is accepted for signature compatibility and ignored.
        pass

    def submit(self, func, *args, **kwargs):
        """Wrap the call in a `DummyResult`; it only runs on `result()`."""
        return DummyPoolExecutor.DummyResult(func, *args, **kwargs)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Returns None, so exceptions raised in the `with` body propagate.
        return