forked from mpi4jax/mpi4jax

Zero-copy MPI communication of JAX arrays, for turbo-charged HPC applications in Python ⚡

Silv3S/mpi4jax

mpi4jax


mpi4jax enables zero-copy, multi-host communication of JAX arrays, even from jitted code and from GPU memory.

But why?

The JAX framework has great performance for scientific computing workloads, but its multi-host capabilities are still limited.

With mpi4jax, you can scale your JAX-based simulations to entire CPU and GPU clusters (without ever leaving jax.jit).

In the spirit of differentiable programming, mpi4jax also supports differentiating through some MPI operations.
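The following is a general derivation, not a statement of which operations mpi4jax differentiates. It shows why a sum-reduction is a natural candidate for differentiation. With P processes, a sum-allreduce maps the local inputs x_0, ..., x_{P-1} to the same output on every rank r:

```latex
y_r = \sum_{s=0}^{P-1} x_s \quad (r = 0, \dots, P-1),
\qquad
\frac{\partial y_r}{\partial x_s} = I \;\text{ for all } r, s

% Forward mode: tangents are combined by the same sum-allreduce
\dot{y}_r = \sum_{s=0}^{P-1} \dot{x}_s

% Reverse mode: the Jacobian is an all-ones block matrix, which is its own
% transpose, so cotangents are also combined by a sum-allreduce
\bar{x}_s = \sum_{r=0}^{P-1} \bar{y}_r
```

Differentiating through a sum-allreduce therefore only requires another sum-allreduce of tangents or cotangents, so no new communication pattern is needed.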

Installation

mpi4jax is available through pip and conda:

$ pip install mpi4jax                     # pip
$ conda install -c conda-forge mpi4jax    # conda

If you use pip and don't have JAX installed already, you will also need to run:

$ pip install jaxlib

(or an equivalent GPU-enabled version; see the JAX installation instructions)
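Before running the extra `pip install jaxlib` step, you can check whether JAX is already importable. The sketch below uses only the standard library's `importlib.util.find_spec`. The helper name `missing_packages` is hypothetical and not part of mpi4jax, and it checks only importability, not versions or GPU support.

```python
import importlib.util


def missing_packages(names=("jax", "jaxlib")):
    """Return the subset of `names` that cannot be found on the import path.

    Hypothetical helper for illustration: it does not import the packages
    and does not check versions or GPU support.
    """
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Not installed:", ", ".join(missing))
    else:
        print("jax and jaxlib are both importable")
```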

In case your MPI installation is not detected correctly, it can help to install mpi4py separately. When using a pre-installed mpi4py, you must use --no-build-isolation when installing mpi4jax:

# if mpi4py is already installed
$ pip install cython
$ pip install mpi4jax --no-build-isolation
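If the MPI installation is not detected as expected, a useful first diagnostic is to ask which MPI library the installed mpi4py is linked against. The sketch below assumes mpi4py may or may not be present and returns `None` if it is missing. `MPI.Get_library_version()` is part of mpi4py's bindings; the helper name `mpi_library_version` is hypothetical.

```python
import importlib.util


def mpi_library_version():
    """Return the MPI library version string reported by mpi4py, or None.

    Hypothetical diagnostic helper. Importing mpi4py.MPI initializes MPI
    as a side effect.
    """
    if importlib.util.find_spec("mpi4py") is None:
        return None  # mpi4py is not installed in this environment
    from mpi4py import MPI

    return MPI.Get_library_version().strip()


if __name__ == "__main__":
    print(mpi_library_version())
```

The returned string typically names the implementation (for example Open MPI or MPICH), which you can compare against the MPI you intended to use.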

Our documentation includes some more advanced installation examples.

Example usage

from mpi4py import MPI
import jax
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

@jax.jit
def foo(arr):
    arr = arr + rank
    arr_sum, _ = mpi4jax.allreduce(arr, op=MPI.SUM, comm=comm)
    return arr_sum

a = jnp.zeros((3, 3))
result = foo(a)

if rank == 0:
    print(result)

Running this script on 4 processes gives:

$ mpirun -n 4 python example.py
[[6. 6. 6.]
 [6. 6. 6.]
 [6. 6. 6.]]
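The printed value follows from the definition of a sum-allreduce: rank r contributes a 3×3 array filled with r, and every rank receives the elementwise sum 0 + 1 + 2 + 3 = 6. The pure-Python sketch below reproduces that arithmetic without MPI. `simulate_allreduce_sum` is a hypothetical helper for illustration, not an mpi4jax function.

```python
def simulate_allreduce_sum(per_rank_arrays):
    """Elementwise sum of equally shaped 2D lists.

    This is the value a sum-allreduce would deliver; in real MPI every rank
    receives this same result.
    """
    rows = len(per_rank_arrays[0])
    cols = len(per_rank_arrays[0][0])
    return [
        [sum(arr[i][j] for arr in per_rank_arrays) for j in range(cols)]
        for i in range(rows)
    ]


n_procs = 4
# Mirrors `jnp.zeros((3, 3)) + rank` on each of the 4 ranks.
inputs = [[[0.0 + rank] * 3 for _ in range(3)] for rank in range(n_procs)]
result = simulate_allreduce_sum(inputs)
print(result)  # [[6.0, 6.0, 6.0], [6.0, 6.0, 6.0], [6.0, 6.0, 6.0]]
```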

allreduce is just one example of the MPI primitives you can use. See the documentation for the full list of supported operations.
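For intuition about how other collectives differ from allreduce, here is a pure-Python sketch of the standard MPI semantics of broadcast and allgather, next to allreduce, over a list of per-rank values. It models only what each rank ends up holding, not the communication itself. The function names are hypothetical, and the documentation remains the authority on which operations mpi4jax provides and with which signatures.

```python
def simulate_bcast(values, root=0):
    """Every rank ends up with the root rank's value."""
    return [values[root] for _ in values]


def simulate_allgather(values):
    """Every rank ends up with all ranks' values, ordered by rank."""
    return [list(values) for _ in values]


def simulate_allreduce(values, op=sum):
    """Every rank ends up with `op` applied across all ranks' values."""
    reduced = op(values)
    return [reduced for _ in values]


per_rank = [10, 11, 12, 13]  # value held by ranks 0..3
print(simulate_bcast(per_rank, root=2))      # [12, 12, 12, 12]
print(simulate_allgather(per_rank)[0])       # [10, 11, 12, 13]
print(simulate_allreduce(per_rank))          # [46, 46, 46, 46]
print(simulate_allreduce(per_rank, op=max))  # [13, 13, 13, 13]
```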

Community guidelines

If you have a question or feature request, or want to report a bug, feel free to open an issue.

We welcome contributions of any kind through pull requests. For information on running our tests, debugging, and contribution guidelines, please refer to the corresponding documentation page.

How to cite

If you use mpi4jax in your work, please consider citing the following article:

@article{mpi4jax,
  doi = {10.21105/joss.03419},
  url = {https://doi.org/10.21105/joss.03419},
  year = {2021},
  publisher = {The Open Journal},
  volume = {6},
  number = {65},
  pages = {3419},
  author = {Dion Häfner and Filippo Vicentini},
  title = {mpi4jax: Zero-copy MPI communication of JAX arrays},
  journal = {Journal of Open Source Software}
}
