PyTorch Devices
Device for Training
You can specify the device for training with a string or with a torch.device object:
import torch

device = "cuda"
# or
device = torch.device("cuda")
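Wherever PyTorch expects a device, it accepts either form. The sketch below shows this. It only constructs a CUDA device object without using it, which does not require a GPU, and it allocates tensors on "cpu" so that it runs on any machine.

```python
import torch

# Constructing a device object does not touch the GPU, so this works anywhere.
dev = torch.device("cuda:0")
print(dev.type, dev.index)     # cuda 0

# A string and a torch.device are interchangeable as the device argument.
# "cpu" is used here so the snippet runs on machines without a GPU.
a = torch.zeros(2, device="cpu")
b = torch.zeros(2, device=torch.device("cpu"))
print(a.device == b.device)    # True
```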
Using a Specific GPU
If you have multiple GPUs, you can choose which one to use in code or with an environment variable:
device = torch.device("cuda:0")  # GPU index, counting from 0
# or, in the shell before launching Python:
export CUDA_VISIBLE_DEVICES=0
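Setting CUDA_VISIBLE_DEVICES renumbers the GPUs a process can see: the listed physical GPUs become cuda:0, cuda:1, and so on inside that process. The sketch below assumes a machine with at least two GPUs and sets the variable from Python rather than the shell. This only takes effect if it happens before CUDA is initialized, so it is done before importing torch. If no GPU is visible, it falls back to the CPU.

```python
import os

# Expose only physical GPU 1 to this process (assumes at least two GPUs).
# This must run before CUDA is initialized, so it comes before importing torch.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import torch

# Inside this process, physical GPU 1 is now addressed as cuda:0.
# Without a visible GPU, fall back to the CPU so the snippet still runs.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
```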
Allowing Fallback Options
To make the same code run on machines with or without a GPU, pick the device with fallbacks (CUDA first, then MPS, then CPU):
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
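The fallback expression is often wrapped in a helper and called once at the start of training. The sketch below uses a hypothetical helper named get_device, a toy nn.Linear model and random data. The point it illustrates is that the model's parameters and every input batch must be on the same device; otherwise PyTorch raises an error.

```python
import torch
from torch import nn

def get_device() -> torch.device:
    """Hypothetical helper: prefer CUDA, then Apple MPS, then the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

device = get_device()
model = nn.Linear(8, 1).to(device)  # move the parameters once, up front
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Toy batch of random data; it must live on the same device as the model.
x = torch.randn(16, 8).to(device)
y = torch.randn(16, 1).to(device)

# One ordinary training step, running entirely on the chosen device.
loss = nn.functional.mse_loss(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(device, loss.item())
```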
MPS Backend
On Apple Silicon Macs, the CPU and GPU share a unified memory architecture, so the GPU can access memory directly. PyTorch can accelerate training on this GPU through its MPS backend, which is built on Apple’s Metal Performance Shaders (MPS).
First check availability. MPS must be available on your device and PyTorch must be built with MPS support:
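torch.backends.mps.is_available()  # True if this machine and macOS version can use MPS
torch.backends.mps.is_built()      # True if this PyTorch build includes MPS support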
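The two checks answer different questions, so when MPS is unavailable it can help to report which one failed. Here is a minimal sketch with a hypothetical helper named mps_status:

```python
import torch

def mps_status() -> str:
    """Hypothetical helper: explain whether the MPS backend can be used here."""
    if torch.backends.mps.is_available():
        return "MPS is available"
    if not torch.backends.mps.is_built():
        # This PyTorch binary was compiled without MPS support.
        return "PyTorch was not built with MPS support"
    # The build supports MPS, but the hardware or macOS version does not.
    return "MPS is built but not available on this machine"

print(mps_status())
```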
Just like cuda, you can specify the device as mps:
device = "mps"
# or
device = torch.device("mps")
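Once selected, the mps device is used the same way as cuda. The sketch below falls back to the CPU when MPS is unavailable so that it also runs on other machines. It copies the result back to host memory with .cpu(), which is required before calling .numpy() on a GPU tensor.

```python
import torch

# Use the Apple GPU when possible; otherwise fall back to the CPU.
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

x = torch.ones(5, device=device)  # allocated directly on the chosen device
y = x * 2                         # the multiply runs on the same device
result = y.cpu().tolist()         # copy back to host memory for plain Python use
print(device, result)             # result is [2.0, 2.0, 2.0, 2.0, 2.0]
```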