Google JAX

JAX
Developer(s) Google
Stable release
0.4.24 [1] / 6 February 2024
Repository github.com/google/jax
Written in Python, C++
Operating system Linux, macOS, Windows
Platform Python, NumPy
Size 9.0 MB
Type Machine learning
License Apache 2.0
Website jax.readthedocs.io/en/latest/

Google JAX is a machine learning framework for transforming numerical functions in Python. [2] [3] [4] It is described as bringing together a modified version of autograd [5] (automatic differentiation to obtain the gradient function of a function) and TensorFlow's XLA (Accelerated Linear Algebra). It is designed to follow the structure and workflow of NumPy as closely as possible and works with various existing frameworks such as TensorFlow and PyTorch. [6] [7] The primary functions of JAX are: [2]


  1. grad: automatic differentiation
  2. jit: compilation
  3. vmap: auto-vectorization
  4. pmap: SPMD programming
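
As an informal illustration of the NumPy-style workflow mentioned above, jax.numpy can often be used where numpy would be (the array values below are arbitrary and chosen only for demonstration):

# imports
import numpy as np
import jax.numpy as jnp

# the same computation written with NumPy and with JAX's NumPy-like API
a = np.array([1.0, 2.0, 3.0])
print(np.sum(a ** 2))    # NumPy: prints 14.0

b = jnp.array([1.0, 2.0, 3.0])
print(jnp.sum(b ** 2))   # JAX: prints 14.0 as well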

grad

The code below demonstrates the grad function's automatic differentiation.

# imports
from jax import grad
import jax.numpy as jnp

# define the logistic function
def logistic(x):
    return jnp.exp(x) / (jnp.exp(x) + 1)

# obtain the gradient function of the logistic function
grad_logistic = grad(logistic)

# evaluate the gradient of the logistic function at x = 1
grad_log_out = grad_logistic(1.0)
print(grad_log_out)

The final line should output:

0.19661194
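
Because grad returns an ordinary Python function, the transformations can be composed. As a small sketch building on the code above (not part of the original example), applying grad twice yields the second derivative of the logistic function:

# second derivative of the logistic function, reusing logistic from the code above
second_derivative = grad(grad(logistic))
print(second_derivative(1.0))

This should print a value of approximately -0.0909.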

jit

The code below demonstrates the jit function's optimization through fusion.

# imports
from jax import jit
import jax.numpy as jnp

# define the cube function
def cube(x):
    return x * x * x

# generate data
x = jnp.ones((10000, 10000))

# create the jit version of the cube function
jit_cube = jit(cube)

# apply the cube and jit_cube functions to the same data for speed comparison
cube(x)
jit_cube(x)

The computation time for jit_cube should be noticeably shorter than that for cube. Increasing the size of the array created with jnp.ones will increase the difference.
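
One way to observe the difference is with Python's timeit module. Because JAX dispatches work asynchronously, block_until_ready() is called so the timing covers the actual computation; the snippet below is only a sketch building on the code above, and the first call to jit_cube also includes compilation time, so it is common to call it once before timing:

import timeit

# time the uncompiled and the jit-compiled versions
print(timeit.timeit(lambda: cube(x).block_until_ready(), number=10))
print(timeit.timeit(lambda: jit_cube(x).block_until_ready(), number=10))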

vmap

The code below demonstrates the vmap function's vectorization.

# imports
from functools import partial
from jax import vmap
import jax.numpy as jnp

# define function
def grads(self, inputs):
    in_grad_partial = partial(self._net_grads, self._net_params)
    grad_vmap = vmap(in_grad_partial)
    rich_grads = grad_vmap(inputs)
    flat_grads = jnp.asarray(self._flatten_batch(rich_grads))
    assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
    return flat_grads
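
The fragment above depends on attributes of an enclosing class (self._net_grads, self._net_params, self._flatten_batch) and is not runnable on its own. A minimal self-contained sketch of vmap, using an illustrative function and data chosen here only for demonstration, might look like this:

# imports
from jax import vmap
import jax.numpy as jnp

# a function written for a single input vector
def squared_norm(v):
    return jnp.sum(v ** 2)

# vmap maps the function over the leading axis of a batch of inputs
batched_squared_norm = vmap(squared_norm)

batch = jnp.ones((4, 3))             # 4 vectors of length 3
print(batched_squared_norm(batch))   # prints [3. 3. 3. 3.]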

The animation below illustrates the notion of vectorized addition.

[Illustration video of vectorized addition: Vectorized-addition.gif]

pmap

The code below demonstrates the pmap function's parallelization for matrix multiplication.

# import pmap and random from JAX; import JAX NumPy
from jax import pmap, random
import jax.numpy as jnp

# generate 2 random matrices of dimensions 5000 x 6000, one per device
random_keys = random.split(random.PRNGKey(0), 2)
matrices = pmap(lambda key: random.normal(key, (5000, 6000)))(random_keys)

# without data transfer, in parallel, perform a local matrix multiplication on each CPU/GPU
outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices)

# without data transfer, in parallel, obtain the mean for both matrices on each CPU/GPU separately
means = pmap(jnp.mean)(outputs)
print(means)

The final line should print the values:

[1.1566595 1.1805978]
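
pmap maps a function over the leading axis of its inputs, placing one element on each device, so the mapped axis must not exceed the number of available devices. A brief sketch of adapting the example above to however many devices are present:

import jax

# one matrix (and one random key) per available CPU/GPU/TPU device
n_devices = jax.device_count()
random_keys = random.split(random.PRNGKey(0), n_devices)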

Libraries using JAX

Several Python libraries use JAX as a backend, including:

  * Flax, a neural network library developed by Google [8] [9]
  * Equinox, a neural network library [10]
  * Optax, a library for gradient processing and optimization developed by DeepMind [11]
  * RLax, a library for reinforcement learning developed by DeepMind [12]
  * Jraph, a library for graph neural networks developed by DeepMind [13]
  * jaxtyping, which adds type annotations for the shapes and dtypes of arrays [14] [15]

Some R libraries use JAX as a backend as well, including:

  * fastrerandomize [16]


Related Research Articles

Sigmoid function: Mathematical function having a characteristic S-shaped curve or sigmoid curve

A sigmoid function is any mathematical function whose graph has a characteristic S-shaped or sigmoid curve.
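
A common example is the logistic function, which also appears in the grad section above; in standard notation:

S(x) = \frac{1}{1 + e^{-x}} = \frac{e^{x}}{e^{x} + 1}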

NumPy: Python library for numerical programming

NumPy is a library for the Python programming language, adding support for large, multi-dimensional arrays and matrices, along with a large collection of high-level mathematical functions to operate on these arrays. The predecessor of NumPy, Numeric, was originally created by Jim Hugunin with contributions from several other developers. In 2005, Travis Oliphant created NumPy by incorporating features of the competing Numarray into Numeric, with extensive modifications. NumPy is open-source software and has many contributors. NumPy is a NumFOCUS fiscally sponsored project.

Stochastic gradient descent is an iterative method for optimizing an objective function with suitable smoothness properties. It can be regarded as a stochastic approximation of gradient descent optimization, since it replaces the actual gradient by an estimate thereof. Especially in high-dimensional optimization problems this reduces the very high computational burden, achieving faster iterations in exchange for a lower convergence rate.

CUDA: Parallel computing platform and programming model

Compute Unified Device Architecture (CUDA) is a proprietary parallel computing platform and application programming interface (API) that allows software to use certain types of graphics processing units (GPUs) for accelerated general-purpose processing, an approach called general-purpose computing on GPUs (GPGPU). The CUDA API is an extension of the C programming language that adds the ability to specify thread-level parallelism and GPU device-specific operations (such as moving data between the CPU and the GPU). CUDA is a software layer that gives direct access to the GPU's virtual instruction set and parallel computational elements for the execution of compute kernels. In addition to drivers and runtime kernels, the CUDA platform includes compilers, libraries and developer tools to help programmers accelerate their applications.

Locally Optimal Block Preconditioned Conjugate Gradient (LOBPCG) is a matrix-free method for finding the largest eigenvalues and the corresponding eigenvectors of a symmetric generalized eigenvalue problem Ax = λBx.

Hadamard product (matrices): Matrix operation

In mathematics, the Hadamard product is a binary operation that takes in two matrices of the same dimensions and returns a matrix of the multiplied corresponding elements. This operation can be thought of as a "naive matrix multiplication" and is different from the matrix product. It is attributed to, and named after, either French mathematician Jacques Hadamard or German mathematician Issai Schur.
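
For example, for two 2 × 2 matrices the Hadamard product is computed element by element:

\begin{pmatrix} 1 & 2 \\ 3 & 4 \end{pmatrix} \circ \begin{pmatrix} 5 & 6 \\ 7 & 8 \end{pmatrix} = \begin{pmatrix} 1 \cdot 5 & 2 \cdot 6 \\ 3 \cdot 7 & 4 \cdot 8 \end{pmatrix} = \begin{pmatrix} 5 & 12 \\ 21 & 32 \end{pmatrix}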

Numba: Open-source JIT compiler

Numba is an open-source JIT compiler that translates a subset of Python and NumPy into fast machine code using LLVM, via the llvmlite Python package. It offers a range of options for parallelising Python code for CPUs and GPUs, often with only minor code changes.

Torch (machine learning): Deep learning software

Torch is an open-source machine learning library, a scientific computing framework, and a scripting language based on Lua. It provides LuaJIT interfaces to deep learning algorithms implemented in C. It was created by the Idiap Research Institute at EPFL. Torch development moved in 2017 to PyTorch, a port of the library to Python.

Eclipse Deeplearning4j is a programming library written in Java for the Java virtual machine (JVM). It is a framework with wide support for deep learning algorithms. Deeplearning4j includes implementations of the restricted Boltzmann machine, deep belief net, deep autoencoder, stacked denoising autoencoder and recursive neural tensor network, word2vec, doc2vec, and GloVe. These algorithms all include distributed parallel versions that integrate with Apache Hadoop and Spark.

TensorFlow: Machine learning software library

TensorFlow is a free and open-source software library for machine learning and artificial intelligence. It can be used across a range of tasks but has a particular focus on training and inference of deep neural networks.

Keras: Neural network library

Keras is an open-source library that provides a Python interface for artificial neural networks. Keras was initially independent software, was then integrated into the TensorFlow library, and later added support for more backends. "Keras 3 is a full rewrite of Keras [can be used] as a low-level cross-framework language to develop custom components such as layers, models, or metrics that can be used in native workflows in JAX, TensorFlow, or PyTorch — with one codebase." Keras 3 will be the default Keras version for TensorFlow 2.16 onwards, but Keras 2 can still be used.

Chainer is an open source deep learning framework written purely in Python on top of NumPy and CuPy Python libraries. The development is led by Japanese venture company Preferred Networks in partnership with IBM, Intel, Microsoft, and Nvidia.

PyTorch is a machine learning library based on the Torch library, used for applications such as computer vision and natural language processing, originally developed by Meta AI and now part of the Linux Foundation umbrella. It is recognized as one of the two most popular machine learning libraries alongside TensorFlow, offering free and open-source software released under the modified BSD license. Although the Python interface is more polished and the primary focus of development, PyTorch also has a C++ interface.

The GEKKO Python package solves large-scale mixed-integer and differential algebraic equations with nonlinear programming solvers. Modes of operation include machine learning, data reconciliation, real-time optimization, dynamic simulation, and nonlinear model predictive control. In addition, the package solves Linear programming (LP), Quadratic programming (QP), Quadratically constrained quadratic program (QCQP), Nonlinear programming (NLP), Mixed integer programming (MIP), and Mixed integer linear programming (MILP). GEKKO is available in Python and installed with pip from PyPI of the Python Software Foundation.

Cirq is an open-source framework for noisy intermediate scale quantum (NISQ) computers.

QuTiP: Simulation software for quantum systems

QuTiP, short for the Quantum Toolbox in Python, is an open-source computational physics software library for simulating quantum systems, particularly open quantum systems. QuTiP allows simulation of Hamiltonians with arbitrary time-dependence, allowing simulation of situations of interest in quantum optics, ion trapping, superconducting circuits and quantum nanomechanical resonators. The library includes extensive visualization facilities for simulation output.

Tensor network: Mathematical wave functions

Tensor networks or tensor network states are a class of variational wave functions used in the study of many-body quantum systems. Tensor networks extend one-dimensional matrix product states to higher dimensions while preserving some of their useful mathematical properties.

CuPy is an open source library for GPU-accelerated computing with Python programming language, providing support for multi-dimensional arrays, sparse matrices, and a variety of numerical algorithms implemented on top of them. CuPy shares the same API set as NumPy and SciPy, allowing it to be a drop-in replacement to run NumPy/SciPy code on GPU. CuPy supports Nvidia CUDA GPU platform, and AMD ROCm GPU platform starting in v9.0.

Accelerated Linear Algebra (XLA) is an advanced optimization framework within TensorFlow, a popular machine learning library developed by Google. XLA is designed to improve the performance of TensorFlow models by optimizing the computation graph at a lower level, making it particularly useful for large-scale computations and high-performance machine learning models.

References

  1. "Release 0.4.24". github.com/google/jax. 6 February 2024.
  2. Bradbury, James; Frostig, Roy; Hawkins, Peter; Johnson, Matthew James; Leary, Chris; MacLaurin, Dougal; Necula, George; Paszke, Adam; Vanderplas, Jake; Wanderman-Milne, Skye; Zhang, Qiao (2022-06-18), "JAX: Autograd and XLA", Astrophysics Source Code Library, Google, Bibcode:2021ascl.soft11002B, archived from the original on 2022-06-18, retrieved 2022-06-18
  3. Frostig, Roy; Johnson, Matthew James; Leary, Chris (2018-02-02). "Compiling machine learning programs via high-level tracing" (PDF). MLsys: 1–3. Archived (PDF) from the original on 2022-06-21.
  4. "Using JAX to accelerate our research". www.deepmind.com. 4 December 2020. Archived from the original on 2022-06-18. Retrieved 2022-06-18.
  5. HIPS/autograd, Formerly: Harvard Intelligent Probabilistic Systems Group -- Now at Princeton, 2024-03-27, retrieved 2024-03-28
  6. Lynley, Matthew. "Google is quietly replacing the backbone of its AI product strategy after its last big push for dominance got overshadowed by Meta". Business Insider. Archived from the original on 2022-06-21. Retrieved 2022-06-21.
  7. "Why is Google's JAX so popular?". Analytics India Magazine. 2022-04-25. Archived from the original on 2022-06-18. Retrieved 2022-06-18.
  8. Flax: A neural network library and ecosystem for JAX designed for flexibility, Google, 2022-07-29, retrieved 2022-07-29
  9. Flax: A neural network library and ecosystem for JAX designed for flexibility, Google, 2022-07-29, retrieved 2022-07-29
  10. Kidger, Patrick (2022-07-29), Equinox, retrieved 2022-07-29
  11. Optax, DeepMind, 2022-07-28, retrieved 2022-07-29
  12. RLax, DeepMind, 2022-07-29, retrieved 2022-07-29
  13. Jraph - A library for graph neural networks in jax., DeepMind, 2023-08-08, retrieved 2023-08-08
  14. "typing — Support for type hints". Python documentation. Retrieved 2023-08-08.
  15. jaxtyping, Google, 2023-08-08, retrieved 2023-08-08
  16. Jerzak, Connor (2023-10-01), fastrerandomize, retrieved 2023-10-03