Русская Википедия:Google JAX

Материал из Онлайн справочника
Перейти к навигацииПерейти к поиску

Шаблон:Infobox software Google JAX — фреймворк машинного обучения для преобразования числовых функций.[1][2][3] Он описывается как объединение измененной версии autograd (автоматическое получение градиентной функции через дифференцирование функции) и TensorFlow's XLA (Ускоренная линейная алгебра(Accelerated Linear Algebra)). Он спроектирован, чтобы максимально соответствовать структуре и рабочему процессу NumPy для работы с различными существующими фреймворками, такими как TensorFlow и PyTorch.[4][5] Основными функциями JAX являются:[1]

  1. grad: автоматическое дифференцирование
  2. jit: компиляция
  3. vmap: автоматическая векторизация
  4. pmap: SPMD программирование

grad

Шаблон:Main Код представленный ниже демонстрирует функцию автоматического дифференцирования пакета grad.

# imports
from jax import grad
import jax.numpy as jnp

# define the logistic function
def logistic(x):  
    return jnp.exp(x) / (jnp.exp(x) + 1)

# obtain the gradient function of the logistic function
grad_logistic = grad(logistic)

# evaluate the gradient of the logistic function at x = 1 
grad_log_out = grad_logistic(1.0)   
print(grad_log_out)

Код должен напечатать:

0.19661194

jit

Шаблон:Main Код представленный ниже демонстрирует функцию оптимизации через слияние пакета jit.

# imports
from jax import jit
import jax.numpy as jnp

# define the cube function
def cube(x):
    return x * x * x

# generate data
x = jnp.ones((10000, 10000))

# create the jit version of the cube function
jit_cube = jit(cube)

# apply the cube and jit_cube functions to the same data for speed comparison
cube(x)
jit_cube(x)

Вычислительное время для jit_cube (строка 17) должно быть заметно короче, чем для cube (строка 16). Увеличение значения в строке 7, будет увеличивать разницу.

vmap

Шаблон:Main Код представленный ниже демонстрирует функцию векторизации пакета vmap.

# imports
from functools import partial
from jax import vmap
import jax.numpy as jnp

# define function
def grads(self, inputs):
    in_grad_partial = partial(self._net_grads, self._net_params)
    grad_vmap = jax.vmap(in_grad_partial)
    rich_grads = grad_vmap(inputs)
    flat_grads = np.asarray(self._flatten_batch(rich_grads))
    assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
    return flat_grads

Изображение в правой части раздела иллюстрирует идея векторизованного сложения.

Файл:Vectorized-addition.gif
Иллюстрационное видео векторизованного сложения

pmap

Шаблон:Main Код представленный ниже демонстрирует распараллеливание для умножения матриц пакета pmap.

# import pmap and random from JAX; import JAX NumPy
from jax import pmap, random
import jax.numpy as jnp

# generate 2 random matrices of dimensions 5000 x 6000, one per device
random_keys = random.split(random.PRNGKey(0), 2)
matrices = pmap(lambda key: random.normal(key, (5000, 6000)))(random_keys)

# without data transfer, in parallel, perform a local matrix multiplication on each CPU/GPU
outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices)

# without data transfer, in parallel, obtain the mean for both matrices on each CPU/GPU separately
means = pmap(jnp.mean)(outputs)
print(means)

Последняя строка должна напечатать значенияː

[1.1566595 1.1805978]

Библиотеки использующие Jax

Несколько библиотек Python используют Jax в качестве бэкенда, включая:

  • Flax, высокоуровневая библиотека для нейронных сетей изначально разработанная Google Brain.[6]
  • Haiku, объектно-ориентированная библиотека для нейронных сетей разработанная DeepMind.[7]
  • Equinox, библиотека которая вращается вокруг идеи представления параметризованных функций (включая нейронные сети) как PyTrees. Она была создана Патриком Кидгером.[8]
  • Optax, библиотека для градиентной обработки и оптимизации разработанная DeepMind.[9]
  • RLax, библиотека для разработки агентов для обучения с подкреплением, разработанная DeepMind.[10]

См. также

Примечания

Шаблон:Примечания

Ссылки

Шаблон:Google