
zamba.models.utils

configure_accelerator_and_devices_from_gpus(gpus)

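Derive the accelerator and number of devices for pl.Trainer from the user-specified number of GPUs. A positive gpus value selects the "gpu" accelerator with that many devices. Zero or any other non-positive value falls back to the "cpu" accelerator with devices="auto".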

Source code in zamba/models/utils.py
def configure_accelerator_and_devices_from_gpus(gpus):
    """Derive accelerator and number of devices for pl.Trainer from user-specified number of gpus."""
    if gpus > 0:
        accelerator = "gpu"
        devices = gpus
    else:
        accelerator = "cpu"
        devices = "auto"
    return accelerator, devices
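The mapping from gpus to the (accelerator, devices) pair can be checked directly. The sketch below copies the function verbatim so it runs without zamba installed. The commented pl.Trainer call assumes PyTorch Lightning's accelerator and devices keyword arguments and is illustrative only; it is not executed here.

```python
def configure_accelerator_and_devices_from_gpus(gpus):
    """Derive accelerator and number of devices for pl.Trainer from user-specified number of gpus."""
    if gpus > 0:
        accelerator = "gpu"
        devices = gpus
    else:
        accelerator = "cpu"
        devices = "auto"
    return accelerator, devices


# A positive GPU count is passed through unchanged as the device count.
print(configure_accelerator_and_devices_from_gpus(2))  # ('gpu', 2)

# Zero (or any non-positive value) falls back to CPU and leaves the
# device count as "auto" for Lightning to resolve.
print(configure_accelerator_and_devices_from_gpus(0))  # ('cpu', 'auto')

# Illustrative use (assumes pytorch_lightning is installed; not executed here):
# import pytorch_lightning as pl
# accelerator, devices = configure_accelerator_and_devices_from_gpus(1)
# trainer = pl.Trainer(accelerator=accelerator, devices=devices)
```

Because the function only branches on gpus > 0, negative values are treated the same as zero and also select the CPU.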