import torch
def get_device(device=None):
    """Return the device to use, auto-detecting one when none is given.

    If ``device`` is not ``None`` it is returned unchanged. Otherwise the
    first available backend is picked in the order CUDA -> MPS -> CPU and
    returned as a string: ``"cuda"``, ``"mps"`` or ``"cpu"``.

    Args:
        device: An explicit device to pass through untouched, for example a
            string such as ``"cuda:1"`` or a ``torch.device``. ``None``
            (the default) requests auto-detection.

    Returns:
        ``device`` itself when it is not ``None``, otherwise the name of the
        auto-detected backend.
    """
    if device is not None:
        return device

    if torch.cuda.is_available():
        return "cuda"

    # torch.backends.mps only exists in newer PyTorch releases; look it up
    # defensively so older installs fall back to CPU instead of raising
    # AttributeError.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available() and mps.is_built():
        return "mps"

    return "cpu"