Google JAX

JAX
Developer(s) Google
Stable release
0.4.24 [1] / 6 February 2024
Repository github.com/google/jax
Written in Python, C++
Operating system Linux, macOS, Windows
Platform Python, NumPy
Size 9.0 MB
Type Machine learning
License Apache 2.0
Website jax.readthedocs.io/en/latest/

Google JAX is a machine learning framework for transforming numerical functions in Python. [2] [3] [4] It is described as bringing together a modified version of autograd [5] (automatic differentiation, i.e. automatically obtaining the gradient function of a numerical function) and TensorFlow's XLA (Accelerated Linear Algebra). It is designed to follow the structure and workflow of NumPy as closely as possible (a brief comparison is sketched after the list below) and works with various existing frameworks such as TensorFlow and PyTorch. [6] [7] The primary functions of JAX are: [2]

  1. grad: automatic differentiation
  2. jit: compilation
  3. vmap: auto-vectorization
  4. pmap: SPMD programming
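
As a minimal illustration of how jax.numpy mirrors the NumPy interface (this sketch is not part of the original article; the array values are arbitrary), the same expression can be written with either module:

# illustrative sketch: the same NumPy-style expression with NumPy and with jax.numpy
import numpy as np
import jax.numpy as jnp

x_np = np.linspace(0.0, 1.0, 5)    # NumPy array in host memory
x_jnp = jnp.linspace(0.0, 1.0, 5)  # JAX array, placed on CPU/GPU/TPU

print(np.sum(x_np ** 2))   # 1.875
print(jnp.sum(x_jnp ** 2)) # 1.875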

grad

The code below demonstrates the grad function's automatic differentiation.

# imports
from jax import grad
import jax.numpy as jnp

# define the logistic function
def logistic(x):
    return jnp.exp(x) / (jnp.exp(x) + 1)

# obtain the gradient function of the logistic function
grad_logistic = grad(logistic)

# evaluate the gradient of the logistic function at x = 1
grad_log_out = grad_logistic(1.0)
print(grad_log_out)

The final line should output:

0.19661194
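
This value can be checked against the closed-form derivative of the logistic function, σ′(x) = σ(x)(1 − σ(x)). The following sketch (not part of the original example) compares the automatic and analytic results:

# sketch: compare grad's result with the analytic derivative of the logistic function
from jax import grad
import jax.numpy as jnp

def logistic(x):
    return jnp.exp(x) / (jnp.exp(x) + 1)

automatic = grad(logistic)(1.0)
analytic = logistic(1.0) * (1 - logistic(1.0))  # sigma(1) * (1 - sigma(1)), ~0.19661194
print(jnp.allclose(automatic, analytic))        # True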

jit

The code below demonstrates the jit function's optimization through fusion.

# imports
from jax import jit
import jax.numpy as jnp

# define the cube function
def cube(x):
    return x * x * x

# generate data
x = jnp.ones((10000, 10000))

# create the jit version of the cube function
jit_cube = jit(cube)

# apply the cube and jit_cube functions to the same data for speed comparison
cube(x)
jit_cube(x)

The computation time for jit_cube should be noticeably shorter than that for cube. Increasing the dimensions of the array created with jnp.ones increases the difference.
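
The example above does not itself measure execution time. A rough timing comparison can be sketched as follows (this sketch is not part of the original example); because JAX dispatches work asynchronously, block_until_ready() is used to wait for each result, and an initial call to jit_cube is made outside the timed region since the first call includes compilation:

# sketch: timing cube versus jit_cube
import timeit
from jax import jit
import jax.numpy as jnp

def cube(x):
    return x * x * x

x = jnp.ones((10000, 10000))
jit_cube = jit(cube)
jit_cube(x).block_until_ready()  # warm-up call: triggers compilation

t_plain = timeit.timeit(lambda: cube(x).block_until_ready(), number=10)
t_jit = timeit.timeit(lambda: jit_cube(x).block_until_ready(), number=10)
print(t_plain, t_jit)            # t_jit is expected to be smaller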

vmap

The code below demonstrates the vmap function's vectorization.

# imports
from functools import partial
import numpy as np
from jax import vmap
import jax.numpy as jnp

# define function
def grads(self, inputs):
    in_grad_partial = partial(self._net_grads, self._net_params)
    grad_vmap = vmap(in_grad_partial)
    rich_grads = grad_vmap(inputs)
    flat_grads = np.asarray(self._flatten_batch(rich_grads))
    assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
    return flat_grads

The GIF accompanying this section (Vectorized-addition.gif) illustrates the notion of vectorized addition.
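
Because the example above is an excerpt from a larger class (self._net_grads and self._flatten_batch are not defined here), a minimal self-contained sketch of the same idea is the following (not part of the original article): vmap turns a function written for a single input into one that maps over a leading batch dimension.

# minimal sketch: vmap maps a per-example function over a batch dimension
from jax import vmap
import jax.numpy as jnp

def dot(a, b):            # written for two single vectors
    return jnp.dot(a, b)

batched_dot = vmap(dot)   # maps over the leading axis of both arguments

a = jnp.ones((8, 3))      # a batch of 8 vectors of length 3
b = jnp.ones((8, 3))
print(batched_dot(a, b).shape)  # (8,)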

pmap

The code below demonstrates the pmap function's parallelization for matrix multiplication.

# import pmap and random from JAX; import JAX NumPy
from jax import pmap, random
import jax.numpy as jnp

# generate 2 random matrices of dimensions 5000 x 6000, one per device
random_keys = random.split(random.PRNGKey(0), 2)
matrices = pmap(lambda key: random.normal(key, (5000, 6000)))(random_keys)

# without data transfer, in parallel, perform a local matrix multiplication on each CPU/GPU
outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices)

# without data transfer, in parallel, obtain the mean for both matrices on each CPU/GPU separately
means = pmap(jnp.mean)(outputs)
print(means)

The final line should print the values:

[1.1566595 1.1805978]
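
pmap maps the function over one device per element of the leading axis, so the example above assumes at least two CPU/GPU/TPU devices are visible to JAX. The available devices can be inspected as follows (sketch, not part of the original example):

# sketch: inspect the devices that pmap can map over
import jax

print(jax.device_count())  # number of devices; must be >= the size of the mapped axis
print(jax.devices())       # list of the devices themselves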

Libraries using JAX

Several Python libraries use JAX as a backend, including:

  - Flax, a neural network library and ecosystem for JAX designed for flexibility, initially developed by Google Brain [8] [9]
  - Equinox, a library that represents parameterised functions (including neural networks) as PyTrees [10]
  - Optax, a library for gradient processing and optimisation, developed by DeepMind [11]
  - RLax, a library for developing reinforcement learning agents, developed by DeepMind [12]
  - Jraph, a library for graph neural networks, developed by DeepMind [13]
  - jaxtyping, a library that adds type annotations for the shapes and dtypes of arrays [14] [15]

Some R libraries use JAX as a backend as well, including:

  - fastrerandomize, a library that uses JAX to accelerate the selection of balanced randomizations in experimental design (rerandomization) [16]

Related Research Articles

<span class="mw-page-title-main">NumPy</span> Python library for numerical programming

NumPy is a library for the Python programming language, adding support for large, multi-dimensional arrays and matrices, along with a large collection of high-level mathematical functions to operate on these arrays. The predecessor of NumPy, Numeric, was originally created by Jim Hugunin with contributions from several other developers. In 2005, Travis Oliphant created NumPy by incorporating features of the competing Numarray into Numeric, with extensive modifications. NumPy is open-source software and has many contributors. NumPy is a NumFOCUS fiscally sponsored project.

In numerical linear algebra, the Arnoldi iteration is an eigenvalue algorithm and an important example of an iterative method. Arnoldi finds an approximation to the eigenvalues and eigenvectors of general matrices by constructing an orthonormal basis of the Krylov subspace, which makes it particularly useful when dealing with large sparse matrices.

Stochastic gradient descent is an iterative method for optimizing an objective function with suitable smoothness properties. It can be regarded as a stochastic approximation of gradient descent optimization, since it replaces the actual gradient by an estimate thereof. Especially in high-dimensional optimization problems this reduces the very high computational burden, achieving faster iterations in exchange for a lower convergence rate.

In probability theory and mathematical physics, a random matrix is a matrix-valued random variable—that is, a matrix in which some or all elements are random variables. Many important properties of physical systems can be represented mathematically as matrix problems. For example, the thermal conductivity of a lattice can be computed from the dynamical matrix of the particle-particle interactions within the lattice.

<span class="mw-page-title-main">CUDA</span> Parallel computing platform and programming model

Compute Unified Device Architecture (CUDA) is a parallel computing platform and application programming interface (API) that allows software to use certain types of graphics processing units (GPUs) for accelerated general-purpose processing, an approach called general-purpose computing on GPUs (GPGPU). The CUDA API is an extension of the C programming language that adds the ability to specify thread-level parallelism in C and also to specify GPU device-specific operations (such as moving data between the CPU and the GPU). CUDA is a software layer that gives direct access to the GPU's virtual instruction set and parallel computational elements for the execution of compute kernels. In addition to drivers and runtime kernels, the CUDA platform includes compilers, libraries and developer tools to help programmers accelerate their applications.

<span class="mw-page-title-main">Echo state network</span> Type of reservoir computer

An echo state network (ESN) is a type of reservoir computer that uses a recurrent neural network with a sparsely connected hidden layer. The connectivity and weights of hidden neurons are fixed and randomly assigned. The weights of output neurons can be learned so that the network can produce or reproduce specific temporal patterns. The main interest of this network is that although its behavior is non-linear, the only weights that are modified during training are for the synapses that connect the hidden neurons to output neurons. Thus, the error function is quadratic with respect to the parameter vector and can be differentiated easily to a linear system.

<span class="mw-page-title-main">SymPy</span> Python library for symbolic computation

SymPy is an open-source Python library for symbolic computation. It provides computer algebra capabilities either as a standalone application, as a library to other applications, or live on the web as SymPy Live or SymPy Gamma. SymPy is simple to install and to inspect because it is written entirely in Python with few dependencies. This ease of access combined with a simple and extensible code base in a well known language make SymPy a computer algebra system with a relatively low barrier to entry.

Locally Optimal Block Preconditioned Conjugate Gradient (LOBPCG) is a matrix-free method for finding the largest eigenvalues and the corresponding eigenvectors of a symmetric generalized eigenvalue problem.

<span class="mw-page-title-main">Hadamard product (matrices)</span> Matrix operation

In mathematics, the Hadamard product is a binary operation that takes in two matrices of the same dimensions and returns a matrix of the multiplied corresponding elements. This operation can be thought of as a "naive matrix multiplication" and is different from the matrix product. It is attributed to, and named after, either French mathematician Jacques Hadamard or German mathematician Issai Schur.
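
For illustration (this example is not part of the original article), the Hadamard product and the matrix product of the same pair of 2 × 2 matrices can be compared with jax.numpy:

# worked example: Hadamard (element-wise) product versus matrix product
import jax.numpy as jnp

A = jnp.array([[1., 2.], [3., 4.]])
B = jnp.array([[5., 6.], [7., 8.]])

print(A * B)             # Hadamard product: [[ 5. 12.] [21. 32.]]
print(jnp.matmul(A, B))  # matrix product:   [[19. 22.] [43. 50.]]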

<span class="mw-page-title-main">Numba</span> Open-source JIT compiler

Numba is an open-source JIT compiler that translates a subset of Python and NumPy into fast machine code using LLVM, via the llvmlite Python package. It offers a range of options for parallelising Python code for CPUs and GPUs, often with only minor code changes.
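
As an illustration of the "minor code changes" involved (this sketch is not part of the original article), a plain Python loop over a NumPy array can be compiled by adding the @njit decorator:

# sketch: compiling a Python loop with Numba's @njit decorator
import numpy as np
from numba import njit

@njit
def total(values):
    s = 0.0
    for v in values:  # this loop is compiled to machine code
        s += v
    return s

print(total(np.arange(1_000_000, dtype=np.float64)))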

<span class="mw-page-title-main">Torch (machine learning)</span> Deep learning software

Torch is an open-source machine learning library, a scientific computing framework, and a scripting language based on Lua. It provides LuaJIT interfaces to deep learning algorithms implemented in C. It was created by the Idiap Research Institute at EPFL. Torch development moved in 2017 to PyTorch, a port of the library to Python.

<span class="mw-page-title-main">TensorFlow</span> Machine learning software library

TensorFlow is a free and open-source software library for machine learning and artificial intelligence. It can be used across a range of tasks but has a particular focus on training and inference of deep neural networks.

<span class="mw-page-title-main">Keras</span> Neural network library

Keras is an open-source library that provides a Python interface for artificial neural networks. Keras acts as an interface for the TensorFlow library.

PyTorch is a machine learning library based on the Torch library, used for applications such as computer vision and natural language processing, originally developed by Meta AI and now part of the Linux Foundation umbrella. It is recognized as one of the two most popular machine learning libraries alongside TensorFlow, offering free and open-source software released under the modified BSD license. Although the Python interface is more polished and the primary focus of development, PyTorch also has a C++ interface.

The GEKKO Python package solves large-scale mixed-integer and differential algebraic equations with nonlinear programming solvers. Modes of operation include machine learning, data reconciliation, real-time optimization, dynamic simulation, and nonlinear model predictive control. In addition, the package solves Linear programming (LP), Quadratic programming (QP), Quadratically constrained quadratic program (QCQP), Nonlinear programming (NLP), Mixed integer programming (MIP), and Mixed integer linear programming (MILP). GEKKO is available in Python and installed with pip from PyPI of the Python Software Foundation.

<span class="mw-page-title-main">Transformer (deep learning architecture)</span> Machine learning algorithm used for natural-language processing

A transformer is a deep learning architecture developed by Google and based on the multi-head attention mechanism, proposed in the 2017 paper "Attention Is All You Need". Text is converted to numerical representations called tokens, and each token is converted into a vector by looking it up in a word embedding table. At each layer, each token is then contextualized within the scope of the context window with other (unmasked) tokens via a parallel multi-head attention mechanism, allowing the signal for key tokens to be amplified and less important tokens to be diminished. The transformer paper, published in 2017, builds on the softmax-based attention mechanism proposed by Bahdanau et al. in 2014 for machine translation, and on the Fast Weight Controller, similar to a transformer, proposed in 1992.

CuPy is an open-source library for GPU-accelerated computing with the Python programming language, providing support for multi-dimensional arrays, sparse matrices, and a variety of numerical algorithms implemented on top of them. CuPy shares the same API set as NumPy and SciPy, allowing it to serve as a drop-in replacement for running NumPy/SciPy code on the GPU. CuPy supports the Nvidia CUDA GPU platform and, starting in v9.0, the AMD ROCm GPU platform.
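
As a sketch of this drop-in behaviour (not part of the original article; it assumes a machine with a supported GPU and the cupy package installed), the same NumPy-style call can be executed on the GPU by moving the data with cupy.asarray:

# sketch: running a NumPy-style computation on the GPU with CuPy
import numpy as np
import cupy as cp

x_cpu = np.random.rand(1000)
x_gpu = cp.asarray(x_cpu)      # copy the array to GPU memory

print(np.linalg.norm(x_cpu))   # computed on the CPU
print(cp.linalg.norm(x_gpu))   # same call, computed on the GPU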

Tensor informally refers in machine learning to two different concepts that organize and represent data. Data may be organized in a multidimensional array (M-way array) that is informally referred to as a "data tensor"; however in the strict mathematical sense, a tensor is a multilinear mapping over a set of domain vector spaces to a range vector space. Observations, such as images, movies, volumes, sounds, and relationships among words and concepts, stored in an M-way array ("data tensor") may be analyzed either by artificial neural networks or tensor methods.

Accelerated Linear Algebra (XLA) is an advanced optimization framework within TensorFlow, a popular machine learning library developed by Google. XLA is designed to improve the performance of TensorFlow models by optimizing the computation graph at a lower level, making it particularly useful for large-scale computations and high-performance machine learning models.

References

  1. "Release 0.4.24". google/jax. GitHub. 6 February 2024.
  2. Bradbury, James; Frostig, Roy; Hawkins, Peter; Johnson, Matthew James; Leary, Chris; MacLaurin, Dougal; Necula, George; Paszke, Adam; Vanderplas, Jake; Wanderman-Milne, Skye; Zhang, Qiao (2022-06-18), "JAX: Autograd and XLA", Astrophysics Source Code Library, Google, Bibcode:2021ascl.soft11002B, archived from the original on 2022-06-18, retrieved 2022-06-18
  3. Frostig, Roy; Johnson, Matthew James; Leary, Chris (2018-02-02). "Compiling machine learning programs via high-level tracing" (PDF). MLsys: 1–3. Archived (PDF) from the original on 2022-06-21.
  4. "Using JAX to accelerate our research". www.deepmind.com. Archived from the original on 2022-06-18. Retrieved 2022-06-18.
  5. HIPS/autograd, Formerly: Harvard Intelligent Probabilistic Systems Group -- Now at Princeton, 2024-03-27, retrieved 2024-03-28
  6. Lynley, Matthew. "Google is quietly replacing the backbone of its AI product strategy after its last big push for dominance got overshadowed by Meta". Business Insider. Archived from the original on 2022-06-21. Retrieved 2022-06-21.
  7. "Why is Google's JAX so popular?". Analytics India Magazine. 2022-04-25. Archived from the original on 2022-06-18. Retrieved 2022-06-18.
  8. Flax: A neural network library and ecosystem for JAX designed for flexibility, Google, 2022-07-29, retrieved 2022-07-29
  9. Flax: A neural network library and ecosystem for JAX designed for flexibility, Google, 2022-07-29, retrieved 2022-07-29
  10. Kidger, Patrick (2022-07-29), Equinox , retrieved 2022-07-29
  11. Optax, DeepMind, 2022-07-28, retrieved 2022-07-29
  12. RLax, DeepMind, 2022-07-29, retrieved 2022-07-29
  13. Jraph - A library for graph neural networks in jax., DeepMind, 2023-08-08, retrieved 2023-08-08
  14. "typing — Support for type hints". Python documentation. Retrieved 2023-08-08.
  15. jaxtyping, Google, 2023-08-08, retrieved 2023-08-08
  16. Jerzak, Connor (2023-10-01), fastrerandomize , retrieved 2023-10-03