# A Simple JAX Tutorial for PyTorch Users [2]: How to Train a Neural Network

## Background

In the previous post we covered the basics of JAX, summarized by a few keywords: the NumPy API, transformations, and XLA. This time let's get practical and see how to train a neural network. First, recall the steps for training a network in PyTorch:

1. Implement the network model
2. Implement the data loading pipeline
3. Use an optimizer/scheduler to update the model parameters/learning rate
4. Implement the training and evaluation loops

Below we walk through each of these in JAX, using an MLP trained on MNIST as the running example.

## Implementing the model with the NumPy API

MNIST is a 10-class classification problem where each image is 28 × 28 = 784 pixels. We design a simple MLP: a four-layer network, counting the input layer.

```python
import jax
from jax import numpy as jnp
from jax import grad, jit, vmap

# Create a PRNGKey (PRNG state)
key = jax.random.PRNGKey(0)

# Create the model parameters. Excluding the input layer there are really
# three Linear layers, each with a (w, b) pair: three groups of parameters.
def random_layer_params(m, n, key, scale=1e-2):
    """A helper function to randomly initialize weights and biases
    for a dense neural network layer.
    """
    w_key, b_key = jax.random.split(key)  # explicitly update the PRNG state
    return scale * jax.random.normal(w_key, (n, m)), scale * jax.random.normal(b_key, (n,))

def init_network_params(sizes, key):
    """Initialize all layers for a fully-connected neural network with sizes "sizes"."""
    keys = jax.random.split(key, len(sizes))  # split can create several keys at once
    return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

layer_sizes = [784, 512, 512, 10]
key, init_key = jax.random.split(key)  # init_key used for initialization
params = init_network_params(layer_sizes, init_key)
print(len(params), len(params[0]), len(params[1]), len(params[2]))
# 3 2 2 2
print(params[0][0].shape, params[0][1].shape)
# (512, 784) (512,)

# Create the network; in other words, write the forward pass
def relu(x):
    return jnp.maximum(0, x)

# Note: x below is a single image; we don't need to hand-write a batched version
def model_forward(params, x):
    # per-example predictions
    for w, b in params[:-1]:
        x = jnp.dot(w, x) + b
        x = relu(x)
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, x) + final_b
    return logits

# The forward pass is done; let's test it
key, test_key = jax.random.split(key)
random_flattened_image = jax.random.normal(test_key, (784,))
preds = model_forward(params, random_flattened_image)
print(preds.shape)
# (10,)
```

As we know, networks take batched inputs, so let's use vmap to build a batch-capable model_forward:

```python
# Create a random batch, shape=(32, 784)
random_batched_flattened_images = jax.random.normal(jax.random.PRNGKey(1), (32, 784))
model_forward(params, random_batched_flattened_images)  # error
# TypeError: Incompatible shapes for dot: got (512, 784) and (32, 784).

# Create a batch-capable model_forward; with vmap it's this easy.
# in_axes=(None, 0) means params are shared while x is mapped over axis 0.
batched_forward = vmap(model_forward, in_axes=(None, 0), out_axes=0)
batched_preds = batched_forward(params, random_batched_flattened_images)
print(batched_preds.shape)
# (32, 10)
```

## Reusing PyTorch for data loading

Strictly speaking, JAX is not a framework designed for deep learning, and it ships no dataset-handling functions or classes. But with NumPy ndarrays as the bridge, we can connect PyTorch's Dataset/DataLoader to JAX DeviceArrays:

PyTorch preprocessing --> numpy.ndarray --> jax.numpy array

The trick is simple: pass a custom collate_fn when creating the DataLoader, so that it returns numpy arrays instead of torch Tensors.

One more thing to note: as discussed in the previous post, JAX encourages explicit PRNG state. So instead of relying on the DataLoader's built-in shuffle, we define a custom Sampler.

```python
import numpy as np
from torch.utils.data import DataLoader, Sampler, SequentialSampler
from torchvision.datasets import MNIST

class FlattenAndCast(object):
    def __call__(self, pic):
        return np.ravel(np.array(pic, dtype=jnp.float32))

# Make the DataLoader return numpy arrays instead of torch Tensors
def numpy_collate(batch):
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple, list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)

class JAXRandomSampler(Sampler):
    def __init__(self, data_source, rng_key):
        self.data_source = data_source
        self.rng_key = rng_key

    def __len__(self):
        return len(self.data_source)

    def __iter__(self):
        self.rng_key, current_rng = jax.random.split(self.rng_key)
        return iter(jax.random.permutation(current_rng, jnp.arange(len(self))).tolist())

class NumpyLoader(DataLoader):
    def __init__(self, dataset, rng_key=None, batch_size=1, shuffle=False, **kwargs):
        if shuffle:
            sampler = JAXRandomSampler(dataset, rng_key)
        else:
            sampler = SequentialSampler(dataset)
        super().__init__(dataset, batch_size, sampler=sampler, **kwargs)

# Put it together with torchvision and NumpyLoader
mnist_dataset_train = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())
key, loader_key = jax.random.split(key)
train_loader = NumpyLoader(mnist_dataset_train, loader_key, batch_size=32, shuffle=True,
                           num_workers=0, collate_fn=numpy_collate, drop_last=True)
mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False, transform=FlattenAndCast())
eval_loader = NumpyLoader(mnist_dataset_test, batch_size=128, shuffle=False, num_workers=0,
                          collate_fn=numpy_collate, drop_last=False)
```

Note that we set num_workers=0 here.

## Updating model parameters with an optimizer

Let's implement a simple SGD. First the loss:

```python
from jax.scipy.special import logsumexp

def loss(params, images, targets):
    logits = batched_forward(params, images)
    # log-softmax: normalize each example's logits (axis=1), not the whole batch
    preds = logits - logsumexp(logits, axis=1, keepdims=True)
    return -jnp.mean(preds * targets)
```
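With one-hot targets this is the familiar softmax cross-entropy, up to a constant: `jnp.mean` averages over both the batch dimension \(N\) and the class dimension \(K = 10\), so the value equals the standard cross-entropy scaled by \(1/K\), which changes the gradient scale but not the minimizer:

\[ \text{loss}(\theta) = -\frac{1}{NK} \sum_{i=1}^{N} \sum_{k=1}^{K} t_{ik} \left( z_{ik} - \log \sum_{j=1}^{K} e^{z_{ij}} \right) \]

where \(z\) are the logits and \(t\) the one-hot targets. The SGD update itself: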

```python
@jit
def sgd_update(params, x, y, lr):
    grads = grad(loss)(params, x, y)
    return [(w - lr * dw, b - lr * db)
            for (w, b), (dw, db) in zip(params, grads)]
```

This version of sgd_update works fine. But suppose we add another layer to the model, one with three parameters, \( w_{5} \cdot (w_{4} \cdot x) + b_{4} \). Now

\[ params = [(w_{1}, b_{1}), (w_{2}, b_{2}), (w_{3}, b_{3}), (w_{4}, w_{5}, b_{4})] \]

and the list comprehension on the last line of sgd_update becomes awkward: it needs an if/else, something like this:

```python
@jit
def sgd_update(params, x, y, lr):
    grads = grad(loss)(params, x, y)
    return [(p[0] - lr * g[0], p[1] - lr * g[1]) if len(p) == 2 else
            (p[0] - lr * g[0], p[1] - lr * g[1], p[2] - lr * g[2])
            for p, g in zip(params, grads)]
```

And if the network gets more complex, say a Transformer with a dozen-plus layers of all shapes and sizes, how many if/else branches would that take?

### PyTree

JAX has a data structure called the "PyTree", together with the built-in jax.tree_util.tree_* module, which is full of functions operating on pytree structures and makes parameter management elegant.

To be clear, PyTree is not one specific data type but a concept, an umbrella term for a family of types: Python's list, tuple, and dict are all pytrees. This is easy to picture, since a pytree is a "tree" structure and a linear sequence is a tree too. Plain values (ints, floats, and the like), ndarrays, strings, and instances of user-defined classes are not pytrees; they are called leaves.

A pytree can nest other pytrees and leaves: a list may hold a few floats, or even other lists. A leaf, by contrast, is a standalone number or array and cannot contain lists or tuples.

A few examples:

[1, "a", object()] # 这个list属于pytree,它含有3个leaf: 1, "a", object()

(1, (2, 3), ()) # 这个tuple属于pytree,含有三个leaf: 1, 2, 3

[1, {"k1": 2, "k2": (3, 4)}, 5] # 这个list含有5个leaf: 1, 2, 3, 4, 5

JAX also lets you register your own classes as pytrees. That deserves its own post, because PyTree is JAX's core data structure: the jaxpr we met in the previous article accepts only pytrees as inputs and returns pytrees as outputs.
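Still, as a teaser before that post, here is a minimal sketch; the `Params` class and its flatten/unflatten functions are made up for illustration:

```python
from jax import tree_util

class Params:
    def __init__(self, w, b):
        self.w, self.b = w, b

tree_util.register_pytree_node(
    Params,
    lambda p: ((p.w, p.b), None),             # flatten: (children, aux_data)
    lambda aux, children: Params(*children),  # unflatten: rebuild from children
)

print(tree_util.tree_leaves(Params(1.0, 2.0)))
# [1.0, 2.0]
```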

Now let's look at the functions in the jax.tree_util.tree_* module, for instance tree_map, a map function designed for pytrees:

```python
ptree = (1, (2, 3), (), [(2, 3, 4), 5], {"key": 2})
jax.tree_util.tree_map(lambda x: x + 2, ptree)
# (3, (4, 5), (), [(4, 5, 6), 7], {'key': 4})
```

Oh ho, pretty handy, right?

Let's update sgd_update accordingly:

```python
@jit
def sgd_update(params, x, y, lr):
    grads = grad(loss)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
```
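To check that this really kills the if/else, here is a toy example with the awkward three-parameter layer from above (the shapes are made up for illustration). tree_map pairs params and grads leaf by leaf, regardless of each layer's tuple length:

```python
toy_params = [(jnp.ones((2, 2)), jnp.ones(2)),                    # (w1, b1)
              (jnp.ones((2, 2)), jnp.ones((2, 2)), jnp.ones(2))]  # (w4, w5, b4)
toy_grads = jax.tree_util.tree_map(jnp.ones_like, toy_params)     # stand-in gradients
updated = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, toy_params, toy_grads)
```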


Please type "elegant" in the chat. Thank you!

## Training and evaluation loops

OK, next we string the code together into a training loop and an evaluation loop:

```python
def one_hot(x, k=10, dtype=jnp.float32):
    """Create a one-hot encoding of x of size k."""
    return jnp.array(x[:, None] == jnp.arange(k), dtype)

def accuracy(params, loader):
    total_acc = 0
    total_num = 0
    for x, y in loader:
        predicted_class = jnp.argmax(batched_forward(params, x), axis=1)
        total_num += len(x)
        total_acc += jnp.sum(predicted_class == y)
    return total_acc / total_num

lr = 0.01
n_classes = 10
for epoch in range(5):
    for idx, (x, y) in enumerate(train_loader):
        y = one_hot(y, n_classes)
        params = sgd_update(params, x, y, lr)
        lr = lr * 0.999 if lr > 1e-3 else 1e-3  # very simple lr scheduler
        if idx % 100 == 0:  # evaluation (every 100 batches, matching the logs below)
            train_acc = accuracy(params, train_loader)
            eval_acc = accuracy(params, eval_loader)
            print("Epoch {} - batch_idx {}, Training set acc {}, eval set accuracy {}".format(
                epoch, idx, train_acc, eval_acc))
```

```
# Training logs
Epoch 0 - batch_idx 0, Training set acc 0.09814999997615814, eval set accuracy 0.09950000047683716
Epoch 0 - batch_idx 100, Training set acc 0.8302666544914246, eval set accuracy 0.8362999558448792
Epoch 0 - batch_idx 200, Training set acc 0.8892666697502136, eval set accuracy 0.8940999507904053
Epoch 0 - batch_idx 300, Training set acc 0.8997166752815247, eval set accuracy 0.9006999731063843
Epoch 0 - batch_idx 400, Training set acc 0.9085167050361633, eval set accuracy 0.9128999710083008
Epoch 0 - batch_idx 500, Training set acc 0.9076499938964844, eval set accuracy 0.911300003528595
Epoch 0 - batch_idx 600, Training set acc 0.9230999946594238, eval set accuracy 0.9253000020980835
Epoch 0 - batch_idx 700, Training set acc 0.9269000291824341, eval set accuracy 0.9298999905586243
Epoch 0 - batch_idx 800, Training set acc 0.9295666813850403, eval set accuracy 0.9334999918937683
Epoch 0 - batch_idx 900, Training set acc 0.9290666580200195, eval set accuracy 0.9296999573707581
Epoch 0 - batch_idx 1000, Training set acc 0.9342833161354065, eval set accuracy 0.9357999563217163
# Accuracy reaches about 95% after roughly 3 epochs
```

That's the whole pipeline for training a network with the raw JAX NumPy API. It feels like a trip back to the last century.

A hallmark of PyTorch is its concise, elegant API design, doing the most work with the fewest classes, most famously building models around nn.Module. Can we imitate that?

## JAX && Flax && Optax

The JAX NumPy API looks awfully low-level next to torch.nn.Module, so a number of JAX-based deep learning frameworks have sprung up (Flax, Haiku, Equinox, ...). It's a bit reminiscent of the high-level API wars of the TensorFlow 1 era, though not as dramatic; by now most people have settled on the "JAX + Flax + Optax" trio:

- jax.numpy provides the array operations, analogous to torch.*
- from flax import linen as nn lines up with torch.nn.*
- optax.* lines up with torch.optim.* and also ships assorted loss functions
- transformations such as jit, grad, vmap, and pmap act as the glue

Let's refactor the training pipeline above with these.

First, build the model with nn.Module:

- A Module is a dataclass, so hyperparameters (such as the sizes of the Dense layers) become fields; don't forget the type annotations.
- Because a Module is a dataclass, `__init__` is off the table. Instead, create the layers (sub-Modules) the model needs in the `setup` method, the counterpart of `torch.nn.Module.__init__`.
- Implement the forward computation in the `__call__` method, the counterpart of `torch.nn.Module.forward`.
- Note that a Flax Module does not hold the model parameters; they must be created explicitly with the `init` method. `init` essentially calls `__call__` once, with one extra PRNGKey argument: if `__call__` has the signature `(self, *args, **kwargs)`, then `init` has the signature `(rngkey, *args, **kwargs)`.

Why the PRNGKey? The `init` method creates the model parameters, i.e. it performs parameter initialization, and that step needs random number generation, something like `some_params = jax.random.normal(key, (2, 2))`; hence the key.

Once the model and its parameters exist, calling the model works differently too: not `model(x)`, but the apply method, `model.apply()`. By default apply invokes `__call__`, with the model parameters passed in explicitly: if `__call__` has the signature `(self, *args, **kwargs)`, then apply has the signature `(variables, *args, rngs=..., mutable=..., **kwargs)`, where `rngs` is for layers that need randomness (e.g. Dropout) and `mutable` is for layers that carry state variables, such as BatchNorm.

```python
import jax
from jax import numpy as jnp
from jax import grad, jit, vmap
from flax import linen as nn
from typing import Sequence

# Create a PRNGKey (PRNG state)
key = jax.random.PRNGKey(0)

class MLP(nn.Module):
    layer_sizes: Sequence[int] = None  # don't forget the type annotation: Sequence[int]

    def setup(self):
        # Dense() is given only the output size; Flax infers the input size
        self.layers = [nn.Dense(features=size) for size in self.layer_sizes[1:]]

    def __call__(self, x):
        for layer in self.layers[:-1]:
            x = layer(x)
            x = nn.relu(x)
        return self.layers[-1](x)

layer_sizes = [784, 512, 512, 10]
# Create the model
model = MLP(layer_sizes)

# Use `init` and a dummy_x to create the model parameters.
# Note that Dense() was never told its input size;
# `init` essentially calls `__call__`, using dummy_x to infer the parameter shapes.
key, init_key = jax.random.split(key)
dummy_x = jax.random.uniform(init_key, (784,))
key, init_key = jax.random.split(key)  # init_key used for initialization
params = model.init(init_key, dummy_x)

# Create a random batch, shape=(32, 784).
# Like a PyTorch Module, the model handles batched data automatically,
# so no manual vmap is needed here.
random_batched_flattened_images = jax.random.normal(jax.random.PRNGKey(1), (32, 784))
model.apply(params, random_batched_flattened_images).shape
# (32, 10)
```
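Note that params is itself just a pytree, so the tree utilities from earlier apply directly. A quick aside (the exact key names like layers_0 come from how Flax names list attributes; the output below is abbreviated by hand):

```python
jax.tree_util.tree_map(jnp.shape, params)
# {'params': {'layers_0': {'bias': (512,), 'kernel': (784, 512)},
#             'layers_1': {'bias': (512,), 'kernel': (512, 512)},
#             'layers_2': {'bias': (10,), 'kernel': (512, 10)}}}
```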

Next we use Optax to create the optimizer and the learning-rate schedule:

```python
import optax

lr = 1e-3
# Learning-rate schedule
lr_decay_fn = optax.linear_schedule(
    init_value=lr,
    end_value=1e-5,
    transition_steps=200,
)
# Go straight for Adam
optimizer = optax.adam(
    learning_rate=lr_decay_fn,
)
```
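An Optax schedule is just a function from step count to learning rate, so we can sanity-check it directly; the values below follow from linear interpolation between init_value and end_value:

```python
print(lr_decay_fn(0))    # 0.001
print(lr_decay_fn(100))  # ~0.000505, halfway through the decay
print(lr_decay_fn(200))  # 1e-05, the end value
```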

Never mind for now whether Adam is the best fit; after the hard days of hand-rolling SGD, using Optax feels positively indulgent!

The data loading code can be reused unchanged.

## TrainState

What if we want to checkpoint the model? In PyTorch we would save the state_dicts of the Model, the optimizer, and the lr_scheduler. Flax lumps all of these together as the training-time state and wraps them in a TrainState class, which makes checkpointing convenient.

```python
from flax.training import train_state

state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)
```

Now the model's forward pass goes through state.apply_fn instead of model.apply, and the gradient update is simply state.apply_gradients(grads=grads). The training step then simplifies to:

```python
from jax import value_and_grad

def train_step(state, x, y):
    """Computes gradients and loss for a single batch."""
    def loss_fn(params):
        logits = state.apply_fn(params, x)
        one_hot = jax.nn.one_hot(y, 10)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss

    grad_fn = value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state, loss

# donate_argnums enables buffer reuse; the input state's buffers are donated to the output state
jit_train_step = jit(train_step, donate_argnums=(0,))

@jax.jit
def apply_model(state, x):
    """Computes predictions for a single batch."""
    logits = state.apply_fn(state.params, x)
    return jnp.argmax(logits, -1)
```

```python
def eval_model(state, loader):
    total_acc = 0.
    total_num = 0.
    for x, y in loader:
        y_pred = apply_model(state, x)
        total_num += len(x)
        total_acc += jnp.sum(y_pred == y)
    return total_acc / total_num

for epoch in range(5):
    for idx, (x, y) in enumerate(train_loader):
        state, loss = jit_train_step(state, x, y)
        if idx % 20 == 0:  # evaluation
            train_acc = eval_model(state, train_loader)
            eval_acc = eval_model(state, eval_loader)
            print("Epoch {} - batch_idx {}, loss {}, Training set acc {}, eval set accuracy {}".format(
                epoch, idx, loss, train_acc, eval_acc))
```
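Since everything needed to resume training lives in state, checkpointing it is a single call. A minimal sketch using flax.training.checkpoints (recent Flax versions steer this job toward Orbax; the directory and step below are just placeholders):

```python
from flax.training import checkpoints

# Save the whole TrainState: params, optimizer state, and step counter
checkpoints.save_checkpoint(ckpt_dir='/tmp/mlp_ckpt', target=state, step=state.step, overwrite=True)

# Restore: passing `target=state` supplies the pytree structure to fill in
restored_state = checkpoints.restore_checkpoint(ckpt_dir='/tmp/mlp_ckpt', target=state)
```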

## Dropout and BatchNorm

As mentioned earlier, every bit of random number generation in JAX needs an explicit PRNGKey; so how do we handle a network that contains Dropout? And unlike a PyTorch Module, a Flax Module does not store the model's weights, so every apply call takes the parameters as input; then what do we do about the running statistics of BatchNorm?

First, let's rework the MLP above to add Dropout and BatchNorm:

```python
class MLP(nn.Module):
    def setup(self):
        self.layer1 = nn.Dense(features=512)
        self.dropout1 = nn.Dropout(rate=0.3)
        self.norm1 = nn.BatchNorm()
        self.layer2 = nn.Dense(features=512)
        self.dropout2 = nn.Dropout(rate=0.4)
        self.norm2 = nn.BatchNorm()
        self.layer3 = nn.Dense(features=10)

    def __call__(self, x, train: bool = True):
        """train switches between train mode and eval mode."""
        x = nn.relu(self.layer1(x))
        x = self.dropout1(x, deterministic=not train)
        x = self.norm1(x, use_running_average=not train)
        x = nn.relu(self.layer2(x))
        x = self.dropout2(x, deterministic=not train)
        x = self.norm2(x, use_running_average=not train)
        x = self.layer3(x)
        return x
```

### init

Parameter initialization now needs extra care. Dropout drops nothing during evaluation and inference, but during training every forward pass involves a random operation, so when calling init we must supply a separate PRNGKey under the name "dropout".

```python
# Create the model
model = MLP()

# Use `init` and a dummy_x to create the model parameters
key, init_key = jax.random.split(key)
dummy_x = jax.random.uniform(init_key, (784,))
key, init_key, drop_key = jax.random.split(key, 3)  # split into 3 keys

# The name "dropout" is fixed by Flax
variables = model.init({"params": init_key, "dropout": drop_key}, dummy_x, train=True)
```

Looking at variables now, there is an extra "batch_stats" collection: the BatchNorm statistics moving_mean and moving_var.

```python
variables.keys()
# frozen_dict_keys(['params', 'batch_stats'])
variables['batch_stats'].keys()
# frozen_dict_keys(['norm1', 'norm2'])
variables['batch_stats']['norm1'].keys()
# frozen_dict_keys(['mean', 'var'])
```

### apply

Calling apply for the forward pass also needs attention:

- pass a PRNGKey to Dropout via rngs;
- declare "batch_stats" as mutable, so it can be updated during the forward pass;
- besides y, apply now also returns the updated "batch_stats":

```python
key, drop_key = jax.random.split(key)
y, non_trainable_params = model.apply(variables, dummy_x, train=True, rngs={"dropout": drop_key},
                                      mutable=['batch_stats'])
non_trainable_params.keys()
# frozen_dict_keys(['batch_stats'])
```

### Updating the parameters

Since variables now contains "batch_stats", we first define a new TrainState subclass that carries batch_stats. Also, batch_stats are not model weights and must not take part in the optimizer's parameter update, so the training step has to change as well:

```python
import flax
from typing import Any

class CustomTrainState(train_state.TrainState):
    batch_stats: flax.core.FrozenDict[str, Any]

state = CustomTrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=optimizer,
    batch_stats=variables['batch_stats'],
)

def train_step(state, x, y, dropout_key):
    """Computes gradients and loss for a single batch."""
    def loss_fn(params):
        logits, new_state = state.apply_fn({"params": params, "batch_stats": state.batch_stats},
                                           x, train=True, rngs={"dropout": dropout_key},
                                           mutable=["batch_stats"])
        one_hot = jax.nn.one_hot(y, 10)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, new_state

    # `value_and_grad` returns the loss alongside the gradients;
    # has_aux=True lets loss_fn also hand back the updated batch_stats
    grad_fn = value_and_grad(loss_fn, has_aux=True)
    (loss, new_state), grads = grad_fn(state.params)
    new_state = state.apply_gradients(grads=grads, batch_stats=new_state["batch_stats"])
    return new_state, loss

# donate_argnums enables buffer reuse; the input state's buffers are donated to the output state
jit_train_step = jit(train_step, donate_argnums=(0,))

@jax.jit
def apply_model(state, x):
    """Computes predictions for a single batch."""
    logits = state.apply_fn({"params": state.params, "batch_stats": state.batch_stats},
                            x, train=False)  # train=False, i.e. eval mode
    return jnp.argmax(logits, -1)
```

```python
for epoch in range(5):
    for idx, (x, y) in enumerate(train_loader):
        key, dropout_key = jax.random.split(key)
        state, loss = jit_train_step(state, x, y, dropout_key)
        if idx % 100 == 0:  # evaluation
            train_acc = eval_model(state, train_loader)
            eval_acc = eval_model(state, eval_loader)
            print("Epoch {} - batch_idx {}, loss {}, Training set acc {}, eval set accuracy {}".format(
                epoch, idx, loss, train_acc, eval_acc))
```

```
# some logs
Epoch 0 - batch_idx 0, loss 2.559518337249756, Training set acc 0.3179420530796051, eval set accuracy 0.31070002913475037
Epoch 0 - batch_idx 100, loss 0.3981797695159912, Training set acc 0.9382011890411377, eval set accuracy 0.9367000460624695
Epoch 0 - batch_idx 200, loss 0.29799991846084595, Training set acc 0.9520065784454346, eval set accuracy 0.9492000341415405
Epoch 0 - batch_idx 300, loss 0.22030052542686462, Training set acc 0.9536759257316589, eval set accuracy 0.9513000249862671
Epoch 0 - batch_idx 400, loss 0.22531506419181824, Training set acc 0.9540432095527649, eval set accuracy 0.950700044631958
Epoch 1 - batch_idx 0, loss 0.2441655695438385, Training set acc 0.954594075679779, eval set accuracy 0.9508000612258911
Epoch 1 - batch_idx 100, loss 0.14692620933055878, Training set acc 0.9552618265151978, eval set accuracy 0.9508000612258911
```

## Source code

That wraps up this simple example of training a neural network with JAX + Flax + Optax. The code above is available on GitHub :)

jax-tutorials-for-pytorchers

## References

[1] JAX documentation: JAX reference documentation

[2] Flax documentation: Flax documentation

[3] Optax documentation: https://optax.readthedocs.io/en/latest/

[4] PyTorch Dataloaders for Jax: https://colab.research.google.com/github/kk1694/blog/blob/master/_notebooks/2021-05-03-Pytorch_Dataloaders_for_Jax.ipynb

[5] Flax basics: https://colab.research.google.com/github/google/flax/blob/main/docs/notebooks/flax_basics.ipynb

[6] The Annotated MNIST in the Flax documentation

[7] The Dropout and BatchNorm examples draw on "Machine Learning with Flax - From Zero to Hero" and the HuggingFace BERT Flax implementation
