File size: 446 Bytes
dbaa71b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch


def is_gpu_available() -> bool:
    """Report whether PyTorch can see a usable CUDA device.

    Returns:
        The value of ``torch.cuda.is_available()``.
    """
    cuda_ok = torch.cuda.is_available()
    return cuda_ok


def get_device_id(device: str) -> int:
    """Translate a device string into an integer device index.

    Args:
        device: One of ``"cpu"``, ``"auto"``, or ``"cuda:<index>"`` where
            ``<index>`` is a non-negative decimal integer (e.g. ``"cuda:1"``).

    Returns:
        ``-1`` for the CPU, otherwise the CUDA device index. ``"auto"``
        resolves to ``0`` when ``is_gpu_available()`` is true and ``-1``
        otherwise.

    Raises:
        ValueError: If ``device`` is not in one of the accepted forms.
    """
    prefix = "cuda:"
    if device == "cpu":
        return -1
    if device == "auto":
        return 0 if is_gpu_available() else -1
    if device.startswith(prefix):
        # Slice off only the leading prefix; str.replace would also strip
        # later occurrences and wrongly accept e.g. "cuda:cuda:1".
        device_no = device[len(prefix):]
        # isdecimal() (unlike isnumeric()) only accepts characters that int()
        # can parse, so inputs such as "cuda:½" reach the error below instead
        # of leaking an unrelated int() conversion error.
        if device_no.isdecimal():
            return int(device_no)

    # ValueError subclasses Exception, so callers catching Exception still work.
    raise ValueError(f"Invalid device: '{device}'")