Google JAX

JAX
Developer(s) Google
Preview release
v0.4.31 / 30 July 2024
Repository github.com/google/jax
Written in Python, C++
Operating system Linux, macOS, Windows
Platform Python, NumPy
Size 9.0 MB
Type Machine learning
License Apache 2.0
Website jax.readthedocs.io/en/latest/

Google JAX is a machine learning framework for transforming numerical functions. [1] [2] [3] It is described as bringing together a modified version of autograd (which automatically obtains the gradient function of a function by differentiating it) and TensorFlow's XLA (Accelerated Linear Algebra). It is designed to follow the structure and workflow of NumPy as closely as possible and works with various existing frameworks such as TensorFlow and PyTorch. [4] [5] The primary functions of JAX are: [1]

  1. grad: automatic differentiation
  2. jit: compilation
  3. vmap: auto-vectorization
  4. pmap: Single program, multiple data (SPMD) programming
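The following short sketch is an illustration rather than an example from the cited sources. It shows two of these ideas together. The same function is written once with NumPy and once with jax.numpy, and grad, jit and vmap are then stacked on the JAX version. The function and the inputs are arbitrary choices made for the example.

```python
import jax
import jax.numpy as jnp
import numpy as np

# The same computation written against NumPy and against jax.numpy.
def f_np(x):
    return np.sum(np.sin(x) ** 2)

def f_jnp(x):
    return jnp.sum(jnp.sin(x) ** 2)

x = np.linspace(0.0, 1.0, 5, dtype=np.float32)

# jax.numpy mirrors the NumPy API, so both agree up to float32 rounding.
print(f_np(x), f_jnp(x))

# grad, jit and vmap are function transformations and can be composed:
# the gradient function is compiled with jit, then mapped over a batch.
df = jax.jit(jax.grad(f_jnp))
batch = jnp.stack([x, 2 * x])      # shape (2, 5): two input vectors
batched_df = jax.vmap(df)(batch)   # one gradient vector per row
print(batched_df.shape)            # (2, 5)
```

The gradient of sum(sin(x)^2) is sin(2x) elementwise, so the first row of batched_df can be checked against that closed form.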

grad

The code below demonstrates the grad function's automatic differentiation.

# imports
from jax import grad
import jax.numpy as jnp

# define the logistic function
def logistic(x):
    return jnp.exp(x) / (jnp.exp(x) + 1)

# obtain the gradient function of the logistic function
grad_logistic = grad(logistic)

# evaluate the gradient of the logistic function at x = 1
grad_log_out = grad_logistic(1.0)
print(grad_log_out)

The final line should output:

0.19661194
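The value can be checked against the closed-form derivative of the logistic function, s(x)(1 - s(x)), where s is the logistic function. Because grad returns an ordinary function, it can also be applied twice to get the second derivative, s(x)(1 - s(x))(1 - 2s(x)). The sketch below is a hedged illustration of both points. The helper logistic_prime is introduced only for this comparison.

```python
import jax
import jax.numpy as jnp

def logistic(x):
    return jnp.exp(x) / (jnp.exp(x) + 1)

# Hypothetical helper: closed-form derivative s(x) * (1 - s(x))
def logistic_prime(x):
    s = logistic(x)
    return s * (1 - s)

grad_logistic = jax.grad(logistic)
# Automatic and closed-form derivatives at x = 1 should agree.
print(grad_logistic(1.0), logistic_prime(1.0))

# grad composes with itself: this is the second derivative.
second = jax.grad(jax.grad(logistic))
s1 = logistic(1.0)
# Closed form of the second derivative: s(x) * (1 - s(x)) * (1 - 2 * s(x))
print(second(1.0), s1 * (1 - s1) * (1 - 2 * s1))
```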

jit

The code below demonstrates the jit function's optimization through fusion.

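# imports
from jax import jit
import jax.numpy as jnp

# define the cube function
def cube(x):
    return x * x * x

# generate data
x = jnp.ones((10000, 10000))

# create the jit version of the cube function
jit_cube = jit(cube)

# apply the cube and jit_cube functions to the same data for speed comparison;
# block_until_ready() waits for JAX's asynchronous dispatch to finish, so the
# measured time covers the actual computation
cube(x).block_until_ready()
jit_cube(x).block_until_ready()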

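The computation time for jit_cube(x) should be noticeably shorter than that for cube(x), because the compiled version fuses the two multiplications into a single operation. This holds after the first call to jit_cube, which also includes compilation time. Increasing the array dimensions passed to jnp.ones will further widen the difference.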
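JAX dispatches work asynchronously, and jit compiles a function on its first call. A fair comparison therefore waits for results and leaves the first call of the compiled function out of the measurement. The sketch below is one way to arrange this. The helper time_call, the number of repeats, and the smaller 2000 x 2000 array are illustrative choices. Measured times will depend on hardware.

```python
import time

import jax
import jax.numpy as jnp

def cube(x):
    return x * x * x

jit_cube = jax.jit(cube)

def time_call(fn, x, repeats=5):
    """Hypothetical helper: best wall-clock time in seconds over several calls."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn(x).block_until_ready()  # wait until the result is actually computed
        best = min(best, time.perf_counter() - start)
    return best

# Smaller than the 10000 x 10000 array above, to keep the sketch light.
x = jnp.ones((2000, 2000))

# The first call traces and compiles cube; run it once so compilation is not timed.
jit_cube(x).block_until_ready()

t_plain = time_call(cube, x)
t_jit = time_call(jit_cube, x)
print(f"cube: {t_plain:.6f} s, jit_cube: {t_jit:.6f} s")
```

Taking the best of several runs reduces noise from other activity on the machine. Both versions compute the same values, so only their speed differs.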

vmap

The code below demonstrates the vmap function's vectorization.

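# imports
from functools import partial
import jax
import numpy as np

# define function (a method of a model class that provides a per-example
# gradient function self._net_grads, its parameters self._net_params,
# and a flattening helper self._flatten_batch)
def grads(self, inputs):
    # fix the network parameters as the first argument of the gradient function
    in_grad_partial = partial(self._net_grads, self._net_params)
    # map the gradient function over the leading (batch) axis of inputs
    grad_vmap = jax.vmap(in_grad_partial)
    rich_grads = grad_vmap(inputs)
    # flatten into one row of gradients per input example
    flat_grads = np.asarray(self._flatten_batch(rich_grads))
    assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
    return flat_grads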

The GIF on the right of this section illustrates the notion of vectorized addition.

[Figure: Illustration video of vectorized addition]
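The two ideas in this section can be shown in one self-contained sketch. The first is vectorized addition. The second is mapping a per-example gradient over a batch, which is what the grads method does. The names add_pair and loss, the linear model, and the data are all hypothetical choices made for illustration. In in_axes=(None, 0, 0), the None keeps the parameters w shared across examples, and each 0 maps over the leading axis of xs and ys.

```python
import jax
import jax.numpy as jnp

# Vectorized addition: add_pair is written for single vectors,
# and vmap maps it over the leading (batch) axis.
def add_pair(a, b):
    return a + b

a = jnp.arange(6.0).reshape(3, 2)  # batch of 3 vectors of length 2
b = jnp.ones((3, 2))
print(jax.vmap(add_pair)(a, b))    # rows [1, 2], [3, 4], [5, 6]

# Per-example gradients: the squared-error loss of a linear model
# is written for a single example (x, y).
def loss(w, x, y):
    return (jnp.dot(w, x) - y) ** 2

w = jnp.array([0.5, -1.0])
xs = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
ys = jnp.array([1.0, 0.0, -1.0])

# Share w across examples (None) and map over axis 0 of xs and ys.
per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))(w, xs, ys)
print(per_example_grads.shape)  # (3, 2): one gradient of w per example
```

For the first example, the gradient is 2(w·x - y)x = 2(-1.5 - 1)[1, 2] = [-5, -10]. This gives a quick hand check of the mapped result.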

pmap

The code below demonstrates the pmap function's parallelization of matrix multiplication.

# import pmap and random from JAX; import JAX NumPy
from jax import pmap, random
import jax.numpy as jnp

# generate 2 random matrices of dimensions 5000 x 6000, one per device
# (this requires at least 2 available devices)
random_keys = random.split(random.PRNGKey(0), 2)
matrices = pmap(lambda key: random.normal(key, (5000, 6000)))(random_keys)

# without data transfer, in parallel, perform a local matrix multiplication on each CPU/GPU
outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices)

# without data transfer, in parallel, obtain the mean for both matrices on each CPU/GPU separately
means = pmap(jnp.mean)(outputs)
print(means)

The final line should print the values:

[1.1566595 1.1805978]
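The example above hard-codes two devices. The sketch below is a hedged, device-count-agnostic variant. It sizes the key batch with jax.local_device_count(), and it also shows a collective, jax.lax.pmean, for the case where devices do need to exchange data. The collective refers to the mapped axis by the name passed as axis_name. The axis name "devices", the helper local_then_global_mean, and the smaller matrix sizes are illustrative assumptions. On a CPU-only machine, several host devices can be simulated by setting XLA_FLAGS=--xla_force_host_platform_device_count=2 before JAX is imported.

```python
import jax
import jax.numpy as jnp
from jax import pmap, random

# pmap needs one slice of the leading axis per device, so the batch size
# is taken from the number of local devices instead of being hard-coded.
n_devices = jax.local_device_count()

keys = random.split(random.PRNGKey(0), n_devices)
# Smaller than 5000 x 6000 to keep the sketch light.
matrices = pmap(lambda key: random.normal(key, (500, 600)))(keys)

def local_then_global_mean(x):
    # Per-device work with no data transfer, as in the example above.
    local_mean = jnp.mean(jnp.dot(x, x.T))
    # Collective: average the per-device means across the "devices" axis.
    return local_mean, jax.lax.pmean(local_mean, axis_name="devices")

local_means, global_means = pmap(local_then_global_mean, axis_name="devices")(matrices)
print(local_means)   # one value per device
print(global_means)  # the same averaged value, replicated on every device
```

Keeping the per-device computation and the collective in one mapped function lets the whole step be compiled once. The communication stays limited to a single scalar per device.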


References

  1. Bradbury, James; Frostig, Roy; Hawkins, Peter; Johnson, Matthew James; Leary, Chris; MacLaurin, Dougal; Necula, George; Paszke, Adam; Vanderplas, Jake; Wanderman-Milne, Skye; Zhang, Qiao (2022-06-18). "JAX: Autograd and XLA". Astrophysics Source Code Library. Google. Bibcode:2021ascl.soft11002B. Archived from the original on 2022-06-18. Retrieved 2022-06-18.
  2. Frostig, Roy; Johnson, Matthew James; Leary, Chris (2018-02-02). "Compiling machine learning programs via high-level tracing" (PDF). MLsys: 1–3. Archived (PDF) from the original on 2022-06-21.
  3. "Using JAX to accelerate our research". www.deepmind.com. Archived from the original on 2022-06-18. Retrieved 2022-06-18.
  4. Lynley, Matthew. "Google is quietly replacing the backbone of its AI product strategy after its last big push for dominance got overshadowed by Meta". Business Insider. Archived from the original on 2022-06-21. Retrieved 2022-06-21.
  5. "Why is Google's JAX so popular?". Analytics India Magazine. 2022-04-25. Archived from the original on 2022-06-18. Retrieved 2022-06-18.