from dataclasses import dataclass
from pathlib import Path
from experimaestro import Config, Param
from experimaestro.compat import cached_property
import torch
from experimaestro.taskglobals import Env as TaskEnv
import torch.distributed as dist
import torch.multiprocessing as mp
import tempfile
from xpmir.context import Context
from xpmir.utils.utils import easylog
# Module-level logger used by the device classes and the process launcher below
logger = easylog()
@dataclass
class DeviceInformation:
    """Describes the device handed to a callback, and the place of the
    current process among all the processes taking part in the computation"""

    device: torch.device
    """The torch device to compute on"""

    main: bool
    """True for the main process (every other process is a slave)"""

    count: int = 1
    """Total number of processes"""

    rank: int = 0
    """Rank of this process when using distributed processing"""
class ComputationContext(Context):
    """A context that additionally declares the `device_information`
    attribute, i.e. the :class:`DeviceInformation` to compute with"""

    device_information: DeviceInformation
@dataclass
class DistributedDeviceInformation(DeviceInformation):
    """Device information built for a process started by `mp_launcher`,
    i.e. once the torch distributed process group has been initialized.

    Adds no field of its own to :class:`DeviceInformation`.
    """
class Device(Config):
    """Device to use, as well as specific options (e.g. parallelism)

    This base implementation always computes on the CPU, within the current
    process.
    """

    @cached_property
    def value(self):
        """The torch device (always the CPU for the base class)"""
        return torch.device("cpu")

    n_processes = 1
    """Number of processes"""

    def execute(self, callback, *args, **kwargs):
        """Runs the callback in the current process

        :param callback: called as ``callback(device_information, *args,
            **kwargs)``, where ``device_information`` is a
            :class:`DeviceInformation` for this device, flagged as the main
            process
        :param args: extra positional arguments forwarded to the callback
        :param kwargs: extra keyword arguments forwarded to the callback
        :return: nothing (the value returned by the callback is discarded)
        """
        callback(DeviceInformation(self.value, True), *args, **kwargs)
def mp_launcher(rank, path, world_size, callback, taskenv, args, kwargs):
    """Entry point of each process started for distributed CUDA computation

    Installs the task environment, joins the process group, runs the callback
    on the GPU whose index is the process rank, and finally leaves the group.

    :param rank: rank of this process (0 is the main process)
    :param path: path of the file used for the ``file://`` rendez-vous of the
        process group
    :param world_size: total number of processes in the group
    :param callback: called as ``callback(device_information, *args,
        **kwargs)`` with a :class:`DistributedDeviceInformation`
    :param taskenv: the task environment of the parent process, installed as
        the environment of this process
    :param args: positional arguments forwarded to the callback
    :param kwargs: keyword arguments forwarded to the callback
    """
    logger.info("Started process for rank %d [%s]", rank, path)

    TaskEnv._instance = taskenv
    # Rank 0 is the main process (see `main=` below): every *other* rank is a
    # slave. The flag used to be set to `rank == 0`, which inverted the roles.
    taskenv.slave = rank != 0

    logger.info("Initializing process group [%d]", rank)
    dist.init_process_group(
        "gloo", init_method=f"file://{path}", rank=rank, world_size=world_size
    )

    try:
        logger.info("Calling callback [%d]", rank)
        # One GPU per process, selected by the process rank
        device = torch.device(f"cuda:{rank}")
        callback(
            DistributedDeviceInformation(
                device=device, main=rank == 0, rank=rank, count=world_size
            ),
            *args,
            **kwargs,
        )
    finally:
        # Cleanup, even when the callback raises
        dist.destroy_process_group()
class CudaDevice(Device):
    """CUDA device"""

    gpu_determ: Param[bool] = False
    """Sets the cuDNN deterministic flag"""

    cpu_fallback: Param[bool] = False
    """Fallback to CPU if no GPU is available"""

    distributed: Param[bool] = False
    """Flag for using DistributedDataParallel

    When the number of GPUs is greater than one, use
    torch.nn.parallel.DistributedDataParallel when `distributed` is `True`.
    When False, use `torch.nn.DataParallel`
    """

    @cached_property
    def value(self):
        """Called by experimaestro to substitute object at run time

        :return: the CUDA device, or the CPU device when no GPU is available
            and `cpu_fallback` is set
        :raises AssertionError: when no GPU is available and `cpu_fallback`
            is not set
        """
        if not torch.cuda.is_available():
            if not self.cpu_fallback:
                # Not accepting fallbacks
                raise AssertionError("No GPU available")
            logger.error("No GPU available. Falling back on CPU.")
            return torch.device("cpu")

        # Set the deterministic flag
        torch.backends.cudnn.deterministic = self.gpu_determ
        if self.gpu_determ:
            logger.debug("using GPU (deterministic)")
        else:
            logger.debug("using GPU (non-deterministic)")

        return torch.device("cuda")

    @cached_property
    def n_processes(self):
        """Number of processes"""
        if self.distributed:
            # `execute` runs a single process when at most one GPU is visible,
            # so never report less than one process
            return max(torch.cuda.device_count(), 1)
        return 1

    def execute(self, callback, *args, **kwargs):
        """Runs the callback, using one process per GPU when distributed

        With at most one visible GPU, or when `distributed` is not set, the
        callback runs in the current process with a
        :class:`DeviceInformation` built from `value`. Otherwise, one process
        per GPU is started and each of them runs `mp_launcher`.

        :param callback: called as ``callback(device_information, *args,
            **kwargs)``
        :param args: extra positional arguments forwarded to the callback
        :param kwargs: extra keyword arguments forwarded to the callback
        :return: None in the single-process case, otherwise whatever
            ``torch.multiprocessing.start_processes`` returns
        :raises AssertionError: in the single-process case, when no GPU is
            available and `cpu_fallback` is not set (raised by `value`)
        """
        # Setup distributed computation
        # See https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
        n_gpus = torch.cuda.device_count()

        # `<= 1` rather than `== 1`: with no visible GPU, starting zero
        # processes would silently skip the callback (and the CPU fallback /
        # "No GPU available" error handled by `value`)
        if n_gpus <= 1 or not self.distributed:
            callback(DeviceInformation(self.value, True), *args, **kwargs)
            return

        # The temporary directory holds the file used as rendez-vous point by
        # the process group
        with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as directory:
            logger.info("Setting up distributed CUDA computing (%d GPUs)", n_gpus)
            return mp.start_processes(
                mp_launcher,
                args=(
                    str((Path(directory) / "link").absolute()),
                    n_gpus,
                    callback,
                    TaskEnv.instance(),
                    args,
                    kwargs,
                ),
                nprocs=n_gpus,
                join=True,
                start_method=mp.get_start_method(),
            )
# Default device is the CPU: the base `Device` resolves to `torch.device("cpu")`
DEFAULT_DEVICE = Device()