JAX logo

| Developer(s) | Google, Nvidia [1] |
|---|---|
| Preview release | v0.4.31 / 30 July 2024 |
| Repository | jax on GitHub |
| Written in | Python, C++ |
| Operating system | Linux, macOS, Windows |
| Platform | Python, NumPy |
| Size | 9.0 MB |
| Type | Machine learning |
| License | Apache 2.0 |
| Website | jax |
JAX is a machine learning framework for transforming numerical functions, developed by Google with some contributions from Nvidia. [2] [3] [4] It is described as bringing together a modified version of autograd (automatic differentiation, which obtains the gradient function of a numerical function) and OpenXLA's XLA (Accelerated Linear Algebra). It is designed to follow the structure and workflow of NumPy as closely as possible and works with existing frameworks such as TensorFlow and PyTorch. [5] [6] The primary functions of JAX are: [2]

- grad: automatic differentiation
- jit: just-in-time compilation
- vmap: automatic vectorization
- pmap: parallelization across multiple devices
The code below demonstrates the grad function's automatic differentiation.
```python
# imports
from jax import grad
import jax.numpy as jnp

# define the logistic function
def logistic(x):
    return jnp.exp(x) / (jnp.exp(x) + 1)

# obtain the gradient function of the logistic function
grad_logistic = grad(logistic)

# evaluate the gradient of the logistic function at x = 1
grad_log_out = grad_logistic(1.0)
print(grad_log_out)
```
The final line should output:
0.19661194
The code below demonstrates the jit function's optimization through fusion.
```python
# imports
from jax import jit
import jax.numpy as jnp

# define the cube function
def cube(x):
    return x * x * x

# generate data
x = jnp.ones((10000, 10000))

# create the jit version of the cube function
jit_cube = jit(cube)

# apply the cube and jit_cube functions to the same data for speed comparison
cube(x)
jit_cube(x)
```
The computation time for jit_cube should be noticeably shorter than that for cube. Increasing the size of the array passed to jnp.ones will further widen the difference.
The code below demonstrates the vmap function's automatic vectorization.
```python
# imports
from functools import partial
import jax
import numpy as np

# define a method that computes per-example gradients over a batch of inputs
def grads(self, inputs):
    # fix the network parameters as the first argument of the gradient function
    in_grad_partial = partial(self._net_grads, self._net_params)
    # vectorize the gradient computation over the batch dimension
    grad_vmap = jax.vmap(in_grad_partial)
    rich_grads = grad_vmap(inputs)
    flat_grads = np.asarray(self._flatten_batch(rich_grads))
    assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
    return flat_grads
```
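Because the snippet above is a method of a larger class and is not runnable on its own, the following is a minimal self-contained sketch of vmap; the function and the example arrays are illustrative and not part of the original example.

```python
from jax import vmap
import jax.numpy as jnp

def square_dist(a, b):
    # squared Euclidean distance between two vectors
    return jnp.sum((a - b) ** 2)

# vectorize over the leading (batch) axis of both arguments
batched_square_dist = vmap(square_dist, in_axes=(0, 0))

xs = jnp.arange(12.0).reshape(4, 3)  # batch of 4 vectors
ys = jnp.ones((4, 3))                # batch of 4 vectors
print(batched_square_dist(xs, ys))   # 4 squared distances, one per pair
```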
The code below demonstrates the pmap function's parallelization of matrix multiplication across multiple devices.
```python
# import pmap and random from JAX; import JAX NumPy
from jax import pmap, random
import jax.numpy as jnp

# generate 2 random matrices of dimensions 5000 x 6000, one per device
random_keys = random.split(random.PRNGKey(0), 2)
matrices = pmap(lambda key: random.normal(key, (5000, 6000)))(random_keys)

# without data transfer, in parallel, perform a local matrix multiplication on each CPU/GPU
outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices)

# without data transfer, in parallel, obtain the mean for both matrices on each CPU/GPU separately
means = pmap(jnp.mean)(outputs)
print(means)
```
The final line should print the values:
[1.1566595 1.1805978]
In linear algebra, the outer product of two coordinate vectors is the matrix whose entries are all products of an element in the first vector with an element in the second vector. If the two coordinate vectors have dimensions n and m, then their outer product is an n × m matrix. More generally, given two tensors, their outer product is a tensor. The outer product of tensors is also referred to as their tensor product, and can be used to define the tensor algebra.
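As a short illustration in JAX (the vectors here are arbitrary examples, not from the original text), jnp.outer computes the outer product of two vectors:

```python
import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0])  # n = 3
b = jnp.array([4.0, 5.0])       # m = 2

# outer product: a 3 x 2 matrix whose (i, j) entry is a[i] * b[j]
print(jnp.outer(a, b))
# [[ 4.  5.]
#  [ 8. 10.]
#  [12. 15.]]
```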
NumPy is a library for the Python programming language, adding support for large, multi-dimensional arrays and matrices, along with a large collection of high-level mathematical functions to operate on these arrays. The predecessor of NumPy, Numeric, was originally created by Jim Hugunin with contributions from several other developers. In 2005, Travis Oliphant created NumPy by incorporating features of the competing Numarray into Numeric, with extensive modifications. NumPy is open-source software and has many contributors. NumPy is fiscally sponsored by NumFOCUS.
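For example (a trivial illustration, not from the original article), NumPy provides n-dimensional arrays and vectorized operations over them:

```python
import numpy as np

a = np.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
b = np.ones((2, 3))

# element-wise arithmetic and reductions operate on whole arrays at once
print(a + b)
print(a.mean(axis=0))           # column means: [1.5, 2.5, 3.5]
```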
In Itô calculus, the Euler–Maruyama method is a method for the approximate numerical solution of a stochastic differential equation (SDE). It is an extension of the Euler method for ordinary differential equations to stochastic differential equations, and is named after Leonhard Euler and Gisiro Maruyama. The same generalization cannot be made for an arbitrary deterministic method.
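A minimal sketch of the method in NumPy; the SDE (an Ornstein–Uhlenbeck process) and all parameter values are chosen only for illustration:

```python
import numpy as np

# simulate dX_t = theta * (mu - X_t) dt + sigma dW_t with Euler-Maruyama
theta, mu, sigma = 1.0, 0.0, 0.3
dt, n_steps = 0.01, 1000

rng = np.random.default_rng(0)
x = np.empty(n_steps + 1)
x[0] = 1.0
for i in range(n_steps):
    dw = rng.normal(0.0, np.sqrt(dt))  # Brownian increment over one step
    x[i + 1] = x[i] + theta * (mu - x[i]) * dt + sigma * dw

print(x[-1])
```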
In computing, CUDA is a proprietary parallel computing platform and application programming interface (API) that allows software to use certain types of graphics processing units (GPUs) for accelerated general-purpose processing, an approach called general-purpose computing on GPUs. CUDA was created by Nvidia in 2006. When it was first introduced, the name was an acronym for Compute Unified Device Architecture, but Nvidia later dropped the common use of the acronym and now rarely expands it.
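In the JAX context, CUDA-capable GPUs are used automatically when a GPU build of JAX is installed; the sketch below (assuming at least one accelerator is visible to JAX) shows how devices can be inspected and targeted explicitly:

```python
import jax
import jax.numpy as jnp

# list the devices JAX can see (GPUs on a CUDA build, otherwise CPUs)
print(jax.devices())

# explicitly place an array on the first available device; subsequent
# operations on it run there without further data transfer
x = jax.device_put(jnp.ones((1000, 1000)), jax.devices()[0])
print((x @ x.T).sum())
```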
Locally Optimal Block Preconditioned Conjugate Gradient (LOBPCG) is a matrix-free method for finding the largest eigenvalues and the corresponding eigenvectors of a symmetric generalized eigenvalue problem Ax = λBx for a given pair (A, B) of complex Hermitian or real symmetric matrices, where the matrix B is also assumed positive-definite.
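A small illustrative example using SciPy's implementation of the method (the test matrix, its size, and the number of requested eigenpairs are arbitrary):

```python
import numpy as np
from scipy.sparse.linalg import lobpcg

# build a symmetric positive-definite test matrix
rng = np.random.default_rng(0)
m = rng.standard_normal((100, 100))
a = m @ m.T + 100 * np.eye(100)

# random initial guess for the 3 desired eigenvectors
x0 = rng.standard_normal((100, 3))

# largest=True requests the largest eigenvalues of a x = lambda x
eigenvalues, eigenvectors = lobpcg(a, x0, largest=True, tol=1e-8, maxiter=200)
print(eigenvalues)
```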
In computer programming, ellipsis notation is used to denote ranges, an unspecified number of arguments, or a parent directory. Most programming languages require the ellipsis to be written as a series of periods; a single (Unicode) ellipsis character cannot be used.
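In Python and NumPy, for instance, the ellipsis is a real object (Ellipsis) and is commonly used when indexing arrays to stand for "all remaining axes"; the example below is illustrative:

```python
import numpy as np

x = np.arange(24).reshape(2, 3, 4)

# '...' expands to as many full slices as needed,
# so x[..., 0] is equivalent to x[:, :, 0]
print(x[..., 0])
print(x[..., 0].shape)  # (2, 3)

# the literal '...' evaluates to the singleton Ellipsis object
print(... is Ellipsis)  # True
```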
scikit-learn is a free and open-source machine learning library for the Python programming language. It features various classification, regression and clustering algorithms including support-vector machines, random forests, gradient boosting, k-means and DBSCAN, and is designed to interoperate with the Python numerical and scientific libraries NumPy and SciPy. Scikit-learn is a NumFOCUS fiscally sponsored project.
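A minimal illustrative usage (the dataset and the choice of model are arbitrary):

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# fit a random forest classifier on the classic iris dataset
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

clf = RandomForestClassifier(n_estimators=100, random_state=0)
clf.fit(X_train, y_train)
print(clf.score(X_test, y_test))  # mean accuracy on the held-out split
```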
Theano is a Python library and optimizing compiler for manipulating and evaluating mathematical expressions, especially matrix-valued ones. In Theano, computations are expressed using a NumPy-esque syntax and compiled to run efficiently on either CPU or GPU architectures.
Numba is an open-source JIT compiler that translates a subset of Python and NumPy into fast machine code using LLVM, via the llvmlite Python package. It offers a range of options for parallelising Python code for CPUs and GPUs, often with only minor code changes.
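For example (an illustrative sketch), decorating a NumPy-based function with numba.njit compiles it to machine code the first time it is called:

```python
import numpy as np
from numba import njit

@njit
def array_sum(a):
    # explicit loop that Numba compiles to fast machine code
    total = 0.0
    for i in range(a.shape[0]):
        total += a[i]
    return total

x = np.random.rand(1_000_000)
print(array_sum(x))  # first call compiles; later calls run at native speed
```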
TensorFlow is a software library for machine learning and artificial intelligence. It can be used across a range of tasks, but is used mainly for training and inference of neural networks. It is one of the most popular deep learning frameworks, alongside others such as PyTorch and PaddlePaddle. It is free and open-source software released under the Apache License 2.0.
PyMC is a probabilistic programming language written in Python. It can be used for Bayesian statistical modeling and probabilistic machine learning.
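A minimal sketch of a PyMC model (the data, priors, and sampler settings are arbitrary illustrations):

```python
import numpy as np
import pymc as pm

# observed data drawn from a normal distribution with unknown mean
data = np.random.default_rng(0).normal(loc=2.0, scale=1.0, size=100)

with pm.Model():
    mu = pm.Normal("mu", mu=0.0, sigma=10.0)            # prior on the mean
    pm.Normal("obs", mu=mu, sigma=1.0, observed=data)   # likelihood
    idata = pm.sample(1000, tune=1000)                  # posterior sampling

print(idata.posterior["mu"].mean())
```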
PyTorch is a machine learning library based on the Torch library, used for applications such as computer vision and natural language processing, originally developed by Meta AI and now part of the Linux Foundation umbrella. It is one of the most popular deep learning frameworks, alongside others such as TensorFlow and PaddlePaddle, offering free and open-source software released under the modified BSD license. Although the Python interface is more polished and the primary focus of development, PyTorch also has a C++ interface.
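For comparison with the JAX grad example above, an illustrative PyTorch snippet computes the same logistic-function derivative via PyTorch's autograd system:

```python
import torch

x = torch.tensor(1.0, requires_grad=True)

# the same logistic function as in the JAX example above
y = torch.exp(x) / (torch.exp(x) + 1)

# reverse-mode automatic differentiation
y.backward()
print(x.grad)  # tensor(0.1966), matching the JAX grad output
```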
Differentiable programming is a programming paradigm in which a numeric computer program can be differentiated throughout via automatic differentiation. This allows for gradient-based optimization of parameters in the program, often via gradient descent, as well as other learning approaches that are based on higher order derivative information. Differentiable programming has found use in a wide variety of areas, particularly scientific computing and machine learning. One of the early proposals to adopt such a framework in a systematic fashion to improve upon learning algorithms was made by the Advanced Concepts Team at the European Space Agency in early 2016.
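A minimal illustration in JAX (the quadratic objective and the learning rate are arbitrary): because the whole program is differentiable, gradient descent can be applied directly to its parameters.

```python
from jax import grad
import jax.numpy as jnp

# a differentiable objective: squared distance of a parameter vector from a target
target = jnp.array([1.0, -2.0, 3.0])

def loss(params):
    return jnp.sum((params - target) ** 2)

params = jnp.zeros(3)
learning_rate = 0.1

# plain gradient descent using the automatically derived gradient
for _ in range(100):
    params = params - learning_rate * grad(loss)(params)

print(params)  # approaches [1.0, -2.0, 3.0]
```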
Cirq is an open-source framework for noisy intermediate scale quantum (NISQ) computers.
Dask is an open-source Python library for parallel computing. Dask scales Python code from multi-core local machines to large distributed clusters in the cloud. Dask provides a familiar user interface by mirroring the APIs of other libraries in the PyData ecosystem including: Pandas, scikit-learn and NumPy. It also exposes low-level APIs that help programmers run custom algorithms in parallel.
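For instance (an illustrative sketch), dask.array mirrors the NumPy interface while splitting the work across chunks:

```python
import dask.array as da

# a 10000 x 10000 array split into 1000 x 1000 chunks
x = da.random.random((10_000, 10_000), chunks=(1_000, 1_000))

# operations build a task graph; compute() executes it in parallel
result = (x + x.T).mean(axis=0)
print(result.compute()[:5])
```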
QuTiP, short for the Quantum Toolbox in Python, is an open-source computational physics software library for simulating quantum systems, particularly open quantum systems. QuTiP allows simulation of Hamiltonians with arbitrary time-dependence, allowing simulation of situations of interest in quantum optics, ion trapping, superconducting circuits and quantum nanomechanical resonators. The library includes extensive visualization facilities for the output of simulations.
CuPy is an open-source library for GPU-accelerated computing with the Python programming language, providing support for multi-dimensional arrays, sparse matrices, and a variety of numerical algorithms implemented on top of them. CuPy shares the same API set as NumPy and SciPy, allowing it to serve as a drop-in replacement to run NumPy/SciPy code on the GPU. CuPy supports the Nvidia CUDA GPU platform and, starting in v9.0, the AMD ROCm GPU platform.
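A small illustration of the drop-in nature (assuming a CUDA-capable GPU and CuPy installed):

```python
import cupy as cp

# the same NumPy-style code, but arrays live and compute on the GPU
x = cp.arange(6).reshape(2, 3).astype(cp.float32)
y = cp.ones((3, 2), dtype=cp.float32)

z = x @ y             # matrix multiply on the GPU
print(cp.asnumpy(z))  # copy the result back to host memory
```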
XLA is an open-source compiler for machine learning developed by the OpenXLA project. XLA is designed to improve the performance of machine learning models by optimizing their computation graphs at a lower level, making it particularly useful for large-scale computations and high-performance machine learning models. Key features of XLA include whole-graph optimizations such as operator fusion, efficient memory management, and portability across hardware backends including CPUs, GPUs, and TPUs.