waytan22's picture
Upload folder using huggingface_hub
e730386 verified
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@File : utils.py
@Time : 2023/8/8 4:26 PM
@Author : waytan
@Contact : [email protected]
@License : (C)Copyright 2023, Tencent
@Desc : utils
"""
from contextlib import contextmanager
import math
import os
import tempfile
import typing as tp
import json
import subprocess
import torch
from torch.nn import functional as F
from torch.utils.data import Subset
def unfold(a, kernel_size, stride):
    """Extract sliding frames along the last dimension of `a`.

    Given input of size [*OT, T], output a Tensor of size [*OT, F, K]
    with K the kernel size, by extracting frames with the given stride.
    The input is zero-padded on the right (or cropped, when
    `kernel_size < stride`) so that `F = ceil(T / stride)` and the last
    frame is complete.

    see https://github.com/pytorch/pytorch/issues/60466

    Args:
        a: input tensor; frames are taken along its last dimension.
        kernel_size: number of samples in each frame (K).
        stride: hop between the starts of two consecutive frames.

    Returns:
        A strided view of shape [*OT, F, K] over the padded copy of `a`.
        Overlapping frames share memory, so avoid in-place writes on it.
    """
    *shape, length = a.shape
    n_frames = math.ceil(length / stride)
    tgt_length = (n_frames - 1) * stride + kernel_size
    a = F.pad(a, (0, tgt_length - length))
    # The strided view below needs unit stride on the last dimension.
    # Padding by a non-positive amount may keep the input's memory layout,
    # so normalise it here instead of rejecting the input.
    if a.stride(-1) != 1:
        a = a.contiguous()
    strides = list(a.stride())
    # Replace the last stride by (hop between frames, step inside a frame).
    strides = strides[:-1] + [stride, 1]
    return a.as_strided([*shape, n_frames, kernel_size], strides)
def center_trim(tensor: torch.Tensor, reference: tp.Union[torch.Tensor, int]):
    """
    Trim `tensor` symmetrically along its last dimension to match `reference`.

    `reference` is either a tensor (its last dimension gives the target
    length) or directly the target length as a number.
    When the number of samples to drop is odd, the extra sample is
    dropped from the right side.

    Raises:
        ValueError: if `tensor` is shorter than the target length.
    """
    target = reference.size(-1) if isinstance(reference, torch.Tensor) else reference
    excess = tensor.size(-1) - target
    if excess < 0:
        raise ValueError(f"tensor must be larger than reference. Delta is {excess}.")
    if excess == 0:
        # Nothing to remove: hand back the very same tensor object.
        return tensor
    # Floor half goes on the left, the remainder (possibly one more) on the right.
    left = excess // 2
    return tensor[..., left:left + target]
def pull_metric(history: tp.List[dict], name: str):
    """Collect one metric from every entry of `history`.

    `name` is a dot-separated path; each part is used as a key into the
    next nesting level of the entry (e.g. "valid.loss" reads
    `entry["valid"]["loss"]`). Returns the values in history order.
    """
    def _lookup(entry):
        # Walk down the nested mapping, one key per path component.
        node = entry
        for key in name.split("."):
            node = node[key]
        return node

    return [_lookup(entry) for entry in history]
def sizeof_fmt(num: float, suffix: str = 'B'):
    """
    Format a byte count `num` as a human readable string using binary
    (1024-based) prefixes, followed by `suffix`.
    Taken from https://stackoverflow.com/a/1094933
    """
    prefixes = ('', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi')
    index = 0
    while index < len(prefixes):
        if abs(num) < 1024.0:
            return "%3.1f%s%s" % (num, prefixes[index], suffix)
        num /= 1024.0
        index += 1
    # Anything still >= 1024 after 'Zi' is reported in 'Yi'.
    return "%.1f%s%s" % (num, 'Yi', suffix)
@contextmanager
def temp_filenames(count: int, delete=True):
    """Context manager yielding a list of `count` temporary file paths.

    The files are created (empty) on entry and their handles are closed
    immediately, so callers can freely reopen, overwrite or replace them.

    Args:
        count: number of temporary files to create.
        delete: when true, remove the files on exit, whether or not the
            body raised. Files the caller already removed are skipped.

    Yields:
        The list of created file paths.
    """
    names = []
    try:
        for _ in range(count):
            # Only the path is needed: close the handle right away so no
            # file descriptor leaks and the file is not kept open.
            with tempfile.NamedTemporaryFile(delete=False) as handle:
                names.append(handle.name)
        yield names
    finally:
        if delete:
            for name in names:
                try:
                    os.unlink(name)
                except FileNotFoundError:
                    # Already removed or moved by the caller; nothing to
                    # clean up, and the remaining files must still be
                    # deleted.
                    pass
def random_subset(dataset, max_samples: int, seed: int = 42):
    """Return at most `max_samples` items of `dataset`, picked at random.

    When the dataset already holds no more than `max_samples` items it is
    returned unchanged. Otherwise a `Subset` over the first `max_samples`
    indices of a random permutation is returned; the permutation is drawn
    from a generator seeded with `seed`, so the selection is reproducible.
    """
    total = len(dataset)
    if max_samples >= total:
        return dataset
    rng = torch.Generator()
    rng.manual_seed(seed)
    chosen = torch.randperm(total, generator=rng)[:max_samples]
    return Subset(dataset, chosen.tolist())
class DummyPoolExecutor:
    """Single-process stand-in exposing a pool-executor style interface.

    `submit` does not run anything: it records the call, and the work is
    performed in the calling thread when `result()` is invoked.
    """

    class DummyResult:
        """Deferred call returned by `DummyPoolExecutor.submit`."""

        def __init__(self, func, *args, **kwargs):
            self.func, self.args, self.kwargs = func, args, kwargs

        def result(self):
            """Run the recorded call now and return its value.

            The call is executed on every invocation; nothing is cached.
            """
            target = self.func
            return target(*self.args, **self.kwargs)

    def __init__(self, workers=0):
        """Accept `workers` for signature compatibility; it is ignored."""

    def submit(self, func, *args, **kwargs):
        """Record `func(*args, **kwargs)` and return it as a `DummyResult`."""
        deferred = DummyPoolExecutor.DummyResult(func, *args, **kwargs)
        return deferred

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        # No resources to release; exceptions from the body propagate.
        return None