[pytorch backend] Integrate torch_xla for distribution APIs and jit #18511

Open
kiukchung opened this issue Sep 27, 2023 · 3 comments
kiukchung commented Sep 27, 2023

There are two use cases for torch_xla in the PyTorch backend for Keras, namely:

  1. Implement the distribution API
  2. Re-enable JIT in the trainer

Distribution API

In Keras

The keras.distribution.distribution_lib follows JAX/TF distribution APIs with concepts such as DeviceMesh and TensorLayout, in which the user specifies:

  1. DeviceMesh: An n-dimensional "grid" of the target devices
  2. TensorLayout: How the array (tensor) is sharded (or replicated) across the DeviceMesh

Once the TensorLayout is applied to a tensor, the compiler (in this case XLA) understands the tensor topology (how the tensor is sharded/replicated across the devices) and executes/optimizes computations accordingly.
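A rough sketch of how these two concepts compose, using the names from keras.distribution.distribution_lib (exact constructor signatures may differ between versions):

```python
from keras.distribution import DeviceMesh, TensorLayout, list_devices

# 1. DeviceMesh: arrange the available devices in a 2x4 grid with named axes.
devices = list_devices()  # assumes 8 accelerator devices are visible
mesh = DeviceMesh(shape=(2, 4), axis_names=("data", "model"), devices=devices)

# 2. TensorLayout: shard dim 0 along the "data" axis, replicate dim 1.
layout = TensorLayout(axes=("data", None), device_mesh=mesh)
```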

In PyTorch

This works a bit differently in PyTorch Distributed (PTD). The two main PTD offerings are DistributedDataParallel (DDP) and FullyShardedDataParallel (FSDP). Both are wrapper nn.Modules that wrap an underlying (presumably non-distributed) model (an nn.Module) and implement the distribution strategy their names describe. Both DDP and FSDP follow a programming paradigm akin to MPI and therefore do not fit into the Keras distribution APIs (see the sketch after the list below).

The forward() and backward() functions of DDP and FSDP are implemented in such a way that they deal with:

  1. Sharding (in the case of FSDP)
  2. Injection of distributed collectives (e.g. all_reduce, all_gather, etc.) at the appropriate places as per the specified distribution strategy: data-parallel or fully-sharded data-parallel.
  3. Binning/partitioning of the tensors (both model and data) and (async-)scheduling of collectives/operators to optimize for performance.
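For contrast, a minimal DDP sketch illustrating the wrap-the-module, MPI-style paradigm described above (each process runs the same script, typically launched with torchrun, which sets the rank environment variables):

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="gloo")  # "nccl" on GPU
model = nn.Linear(16, 4)                 # the underlying, non-distributed module
ddp_model = DDP(model)                   # wrapper that all_reduces gradients in backward()

out = ddp_model(torch.randn(8, 16))      # forward() delegates to the wrapped module
out.sum().backward()                     # backward() injects the gradient all_reduce
```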

NOTE: PyTorch does have DTensor and DeviceMesh APIs, but they are relatively new and experimental and are currently "hidden" under the torch.distributed._tensor package. It's worth noting that FullyShardedDataParallel offers an experimental "hybrid" strategy (see _HYBRID_SHARD_ZERO2) that accepts a 2-D DeviceMesh and runs FSDP across one dimension and DDP across the other.

torch_xla

torch_xla is a separate package that users can install in addition to torch. Its main function today is to let users run PyTorch on TPUs, although technically CPU and GPU are also supported through XLA. Under the hood, torch_xla uses LazyTensor to trace the forward and backward graph, then hands the graph to XLA for compilation. Therefore the torch_xla APIs (specifically the SPMD API, see example here) are a better match for Keras' distribution APIs.
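A minimal sketch of the torch_xla SPMD style; module paths and helper names follow the SPMD user guide and have moved between torch_xla releases, so treat them as approximate:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr

xr.use_spmd()  # switch the runtime into SPMD mode

num_devices = xr.global_runtime_device_count()
# DeviceMesh-like object: a (num_devices, 1) grid with named axes.
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ("data", "model"))

t = torch.randn(128, 64, device=xm.xla_device())
# TensorLayout-like annotation: shard dim 0 along mesh dim 0 ("data"), replicate dim 1.
xs.mark_sharding(t, mesh, (0, None))
```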

Note: torch_xla also works with PTD's native DistributedDataParallel and has its own torch_xla.distributed.fsdp.XlaFullyShardedDataParallel which is API-similar to PTD's native FullyShardedDataParallel.

Re-enable JIT

Important: this section discusses JIT in the NON-distributed (i.e. single-worker) context.

Currently JIT for the PyTorch backend is turned off in keras.backend.torch.trainer#make_train_function due to torch.compile() failing. torch.compile is a new API as of torch-2.0 that uses torchdynamo to trace the graph (using the torch.fx IR), which it then hands over to a "backend compiler". PyTorch ships with torchinductor as the built-in compiler, and users can implement/add their own backends. For instance, openxla (which uses torch_xla under the hood) is one of the experimental dynamo backends.
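For illustration, a sketch of what the dynamo path looks like (the XLA backend name has changed across releases; "openxla" is the name used in recent torch/torch_xla versions):

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = nn.Linear(16, 4).to(device)

# dynamo traces the FX graph and hands it to the experimental XLA backend,
# which compiles it via torch_xla instead of the default inductor backend.
compiled = torch.compile(model, backend="openxla")
out = compiled(torch.randn(8, 16, device=device))
```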

Given that torch_xla SPMD APIs will be used to implement the distribution_lib for the torch backend in keras.distribution, it makes sense to use torch_xla (e.g. LazyTensor) directly to enable JIT for keras.backend.torch.trainer.

This requires the Keras PyTorch trainer to (sketched below):

  1. Create a torch_xla.core.xla_model.xla_device() before the training loop
  2. Move the data and model to the xla device
  3. Call xm.mark_step() (or xm.optimizer_step() which calls xm.mark_step() internally) at the end of each step.
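A minimal single-device sketch of that loop; the model, optimizer, and data here are placeholders, not the actual Keras trainer internals:

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()                      # 1. create the XLA device before the loop
model = nn.Linear(16, 4).to(device)           # 2. move the model (and each batch) to it
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

for _ in range(3):                            # stand-in for the real data iterator
    x = torch.randn(8, 16, device=device)
    y = torch.randn(8, 4, device=device)
    optimizer.zero_grad()
    loss_fn(model(x), y).backward()
    optimizer.step()
    xm.mark_step()                            # 3. materialize the lazily traced step graph
```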

Gotchas

Below are a few gotchas to investigate prior to committing to using torch_xla for JIT (distribution API is fine).

  1. For the torch backend, Keras converts the batch tensor returned by the dataloader into numpy. We need to see whether torch_xla requires this to be converted back to a torch.Tensor and placed on the xla device (see the sketch after this list). See:
    a. TorchDataLoaderAdapter#get_numpy_iterator
    b. EpochIterator
    c. torch.trainer

    Note: torch tensors can operate with numpy arrays but not vice versa (i.e. only as long as the numpy array is the RHS operand). What I'm not sure of is how the lazy tensor will kick in if one of the operands is a numpy array.

  2. Performance on GPU. Presumably users will expect better performance when JIT is turned on. This may or may not be true if we use torch_xla on GPU; however, it is mostly true for inductor.
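Regarding gotcha 1, the conversion that may be needed is roughly the following (a sketch of the idea, not what the adapter currently does):

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm

batch = np.random.rand(8, 16).astype("float32")  # what the numpy iterator yields today

# Convert the numpy batch back into a torch.Tensor and place it on the XLA
# device so that subsequent ops produce lazy (traced) XLA tensors.
x = torch.from_numpy(batch).to(xm.xla_device())
```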

@kiukchung (Contributor, Author)

Possible duplicate of #18510.

@kiukchung (Contributor, Author)

cc @fchollet


qlzh727 commented Sep 27, 2023

Also cc @yeounoh from the PyTorch XLA team.
