There are two use-cases for torch_xla for the pytorch backend in Keras, namely the Distribution API and re-enabling JIT.

Distribution API

In Keras

The keras.distribution.distribution_lib follows the JAX/TF distribution APIs with concepts such as DeviceMesh and TensorLayout, in which the user specifies:
DeviceMesh: An n-dimensional "grid" of the target devices
TensorLayout: How the array (tensor) is sharded (or replicated) across the DeviceMesh
Once the TensorLayout is applied to a tensor, the compiler (in this case XLA) understands the tensor topology (how the tensor is sharded/replicated across the devices) and executes/optimizes computations accordingly.
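To make this concrete, here is a minimal sketch of the two concepts. The 2x4 mesh shape and the axis names are illustrative assumptions, not Keras defaults; adjust them to the devices actually available:

```python
import keras

# Assumes 8 accelerators are visible; adjust the mesh shape to whatever
# keras.distribution.list_devices() returns on your machine.
devices = keras.distribution.list_devices()

# A 2x4 "grid" of devices: one axis for data parallelism, one for model parallelism.
mesh = keras.distribution.DeviceMesh(
    shape=(2, 4), axis_names=("data", "model"), devices=devices
)

# Shard dim 0 of a tensor across the "data" axis of the mesh and replicate
# dim 1 (None means "do not shard this dimension").
layout = keras.distribution.TensorLayout(axes=("data", None), device_mesh=mesh)
```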
In PyTorch
This works a bit differently in PyTorch Distributed (PTD). The two main PTD offerings are DistributedDataParallel (DDP) and FullyShardedDataParallel (FSDP). Both are wrapper nn.Modules that wrap an underlying (presumably non-distributed) model (an nn.Module) and implement the distributed strategy indicated by their respective names. Both DDP and FSDP follow a programming paradigm akin to MPI and therefore do not fit into the Keras distribution APIs.
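A minimal sketch of that wrapper pattern, assuming the script is launched with torchrun so the process group can initialize from environment variables:

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# torchrun sets MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE for this call.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Linear(128, 10).cuda()  # a plain, non-distributed nn.Module

# Either wrapper takes the underlying module and adds the distributed strategy:
ddp_model = DDP(model)    # replicate params, all_reduce gradients
# ...or, alternatively:
fsdp_model = FSDP(model)  # shard params/grads/optimizer state, all_gather on demand
```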
The forward() and backward() functions of DDP and FSDP are implemented in such a way that they handle:
Sharding (in the case of FSDP)
Injection of distributed collectives (e.g. all_reduce, all_gather, etc.) at the appropriate places, as per the specified distribution strategy: data-parallel or fully-sharded-data-parallel.
Binning/partitioning of the tensors (both model and data) and (async-)scheduling of collectives/operators to optimize for performance.
NOTE: PyTorch does have DTensor and DeviceMesh APIs, but they are relatively new and experimental and are currently "hidden" under the torch.distributed._tensor package. It's worth noting that FullyShardedDataParallel offers an experimental "hybrid" strategy (see _HYBRID_SHARD_ZERO2) that accepts a 2-D DeviceMesh and runs FSDP across one dimension and DDP across the other.
torch_xla
torch_xla is a separate package that users can install in addition to torch. Its main function today is to allow users to run PyTorch on TPUs, although technically CPU and GPU are also supported through XLA. Under the hood, torch_xla uses LazyTensor to trace the forward and backward graph. It then hands the graph to XLA for compilation. Therefore torch_xla APIs (specifically the SPMD API, see example here) are a better match for Keras' distribution APIs.
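For illustration, a minimal sketch of the LazyTensor flow: ops are only recorded, and the graph is compiled and executed by XLA once the step is marked.

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()              # XLA device (TPU, or CPU/GPU via XLA)

x = torch.randn(4, 4, device=device)
y = (x @ x).sum()                     # only traced so far (LazyTensor), not executed

xm.mark_step()                        # cut the graph here; XLA compiles and runs it
print(y.item())                       # materializes the result on the host
```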
Re-enable JIT

Important: this section discusses JIT in the context of NON-distributed execution (i.e. a single worker).
Currently JIT for the pytorch backend is turned off in keras.backend.torch.trainer#make_train_function due to torch.compile() failing. torch.compile is a new API as of torch-2.0 that uses torchdynamo to trace the graph (using the torch.fx IR), which it then hands over to a "backend compiler". PyTorch ships with torchinductor as the built-in compiler, and users can implement/add their own backends. For instance, openxla (which uses torch_xla under the hood) is one of the experimental dynamo backends.
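For reference, a sketch of how the backend is selected. Treat the "openxla" backend name as an assumption: it is what recent torch_xla releases register, but the name has changed across versions.

```python
import torch

model = torch.nn.Linear(128, 10)

# Default backend: torchdynamo traces the graph and hands it to torchinductor.
compiled_default = torch.compile(model)

# Same tracing, but the graph is handed to XLA via torch_xla's dynamo backend.
compiled_openxla = torch.compile(model, backend="openxla")
```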
Given that torch_xla SPMD APIs will be used to implement the distribution_lib for the torch backend in keras.distribution, it makes sense to use torch_xla (e.g. LazyTensor) directly to enable JIT for keras.backend.torch.trainer.
This requires the Keras PyTorch trainer to (sketched after the list below):
Create a torch_xla.core.xla_model.xla_device() before the training loop
Move the data and model to the XLA device
Call xm.mark_step() (or xm.optimizer_step() which calls xm.mark_step() internally) at the end of each step.
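A hypothetical sketch of those three steps; the model, data, and loss below are placeholders, not the actual Keras trainer internals:

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()                                   # (1) before the training loop

model = torch.nn.Linear(128, 10).to(device)                # (2) model on the XLA device
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()

dataset = torch.utils.data.TensorDataset(torch.randn(64, 128), torch.randn(64, 10))
loader = torch.utils.data.DataLoader(dataset, batch_size=16)

for x, y in loader:
    x, y = x.to(device), y.to(device)                      # (2) data on the XLA device
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    # (3) barrier=True makes optimizer_step() call xm.mark_step() internally.
    xm.optimizer_step(optimizer, barrier=True)
```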
Gotchas
Below are a few gotchas to investigate prior to committing to using torch_xla for JIT (distribution API is fine).
For the torch backend, Keras converts the batch tensor returned by the dataloader into numpy. Need to see whether torch_xla requires this to be converted back to a torch.Tensor and placed on the XLA device (a possible conversion is sketched after these gotchas). See:
a. TorchDataLoaderAdapter#get_numpy_iterator
b. EpochIterator
c. torch.trainer
Note: torch tensors can operate with numpy arrays but not vice-versa (i.e. as long as the numpy array is the RHS operand). What I'm not sure about is how the lazy tensor will kick in if one of the operands is a numpy array.
Performance on GPU. Presumably users will expect better performance when JIT is turned on. This may or may not be true if we use torch_xla on GPU; however, it is mostly true for inductor.
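If gotcha (1) does turn out to require the conversion, it is cheap to sketch; the np_batch name below is a placeholder for whatever TorchDataLoaderAdapter#get_numpy_iterator yields:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

np_batch = np.random.rand(16, 128).astype("float32")   # placeholder numpy batch

# Wrap the numpy array (zero-copy on the host) and place it on the XLA device.
xla_batch = torch.from_numpy(np_batch).to(device)
```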