Parallelism in Deep Learning

Source:

  1. Exploring Parallel Strategies with Jax
  2. Training Deep Networks with Data Parallelism in Jax

Tensor sharding

Tensor sharding: In JAX, we can split a tensor x into multiple sub-tensors and place each one on its own device. Before we do that, we need to create a sharding object, which is essentially a device-placement configuration:

import jax
from jax.sharding import PositionalSharding

sharding = PositionalSharding(jax.devices())

We can split our tensor x in multiple ways. For example, we can split it column-wise along the embedding dimension:

G = jax.local_device_count()
sharded_x = jax.device_put(x, sharding.reshape(1, G))

If we print sharded_x.devices(), we simply get all the devices back, which is not very informative since it tells us nothing about how the tensor is actually laid out. Luckily, the visualize_array_sharding function from jax.debug gives us a clear visual picture of how x is sharded.
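For example, here is a minimal sketch (the array shape is an illustrative assumption, chosen so the embedding dimension divides evenly across the devices):

import jax
import jax.numpy as jnp
from jax.sharding import PositionalSharding

G = jax.local_device_count()
sharding = PositionalSharding(jax.devices())

# toy (batch, embedding) array; the 16 columns are split evenly across G devices
x = jnp.arange(8 * 16.0).reshape(8, 16)
sharded_x = jax.device_put(x, sharding.reshape(1, G))

# prints an ASCII diagram showing which device holds which slice of x
jax.debug.visualize_array_sharding(sharded_x)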

Notation

Symbol Meaning
\(G\) The number of available workers (devices).
\(x \in \mathbb R^{1 \times d}\) A single data point with dimension \(d\).
\(\mathbf{x} \in \mathbb R^{B \times d}\) The batched input data with batch size \(B\).
\(\mathbf{x}_{\text{partitioned}} \in \mathbb R^{S \times d}\) The partitioned batched input data, also called a microbatch, with batch size \(S=\frac{B}{G}\).
\(\mathbf{x}_{\text{partitioned}_i} \in \mathbb R^{S \times d}\) The \(i\)-th microbatch.
\({z} \in \mathbb R^{1 \times d}\) A single embedding with dimension \(d\).
\(\mathbf{z} \in \mathbb R^{B \times d}\) The batched embeddings with batch size \(B\).
\(\mathbf{W} \in \mathbb R^{d \times h}\) The weight matrix of a linear layer with hidden dimension \(h\).
\(\mathbf{B} \in \mathbb R^{1 \times h}\) The bias of a linear layer with hidden dimension \(h\).
\(\sigma(.)\) The activation function (e.g. ReLU).

Data parallelism

In Data Parallelism (DP), we replicate the model on every worker but partition the training data \(\mathbf{x} \in \mathbb R^{B \times d}\). Each worker gets its own microbatch, denoted \(\mathbf{x}_{\text{partitioned}_i}\), and a complete replica of the model.

During backward propagation:

  1. Each worker computes its loss and gradient using its microbatch \(\mathbf{x}_{\text{partitioned}_i}\).
  2. Each worker sends its loss and gradient to a central server. The central server calculates the mean of the losses and gradients and sends them back to each worker.
  3. Each worker then updates its parameters using the averaged gradients.

NOTE: Since all workers start with a complete replica of the model parameters and always apply the same averaged gradient, the parameter copies stay identical across workers.

from typing import NamedTuple, Tuple
import functools

import jax
import jax.numpy as jnp

LEARNING_RATE = 1e-3  # assumed SGD step size (example value)

# class for storing model parameters
class Params(NamedTuple):
    weight: jnp.ndarray
    bias: jnp.ndarray

# function for initializing model parameters
def init(rng) -> Params:
    """Returns the initial model params."""
    weights_key, bias_key = jax.random.split(rng)
    weight = jax.random.normal(weights_key, ())
    bias = jax.random.normal(bias_key, ())
    return Params(weight, bias)

# function for computing the MSE loss
def loss_fn(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> jnp.ndarray:
    """Computes the least squares error of the model's predictions on x against y."""
    pred = params.weight * xs + params.bias
    return jnp.mean((pred - ys) ** 2)

# function for performing one SGD update step (fwd & bwd pass)
@functools.partial(jax.pmap, axis_name='num_devices')
def update(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> Tuple[Params, jnp.ndarray]:
    # per-device loss and gradient on this device's microbatch
    loss, grads = jax.value_and_grad(loss_fn)(params, xs, ys)
    # average gradients and loss across all devices
    grads = jax.lax.pmean(grads, axis_name='num_devices')
    loss = jax.lax.pmean(loss, axis_name='num_devices')
    # SGD step with the averaged gradients
    new_params = jax.tree_util.tree_map(
        lambda param, g: param - g * LEARNING_RATE, params, grads)

    return new_params, loss
  • By using pmap we get the same vectorized functionality as vmap, but across multiple devices.
  • The functools.partial decorator wraps update in a pmap with axis_name='num_devices'. As a result, update is applied in parallel on every device, each device receiving its own microbatch. jax.lax.pmean then averages the gradients (and the loss) across all devices; it is similar to np.mean, except that it takes an axis_name argument specifying which mapped axis to reduce over. Here we average over the 'num_devices' axis, but the name is just a label: any string (e.g. 'i' or 'data') works the same, as long as it matches the axis_name passed to pmap. A driver sketch for calling update follows below.
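A minimal driver sketch (not from the source, and relying on the definitions in the block above) showing how the pmapped update might be called; the data, shapes, and constants are illustrative assumptions:

import numpy as np

G = jax.local_device_count()
S = 64                      # assumed microbatch size per device

# replicate the initial parameters onto every device (leading axis of size G)
params = init(jax.random.PRNGKey(42))
replicated_params = jax.tree_util.tree_map(lambda p: jnp.array([p] * G), params)

# fake linear-regression data, reshaped into G microbatches of size S
xs = np.random.normal(size=(G * S,))
ys = 3.0 * xs + 1.0
xs, ys = xs.reshape(G, S), ys.reshape(G, S)

for step in range(1000):
    replicated_params, loss = update(replicated_params, xs, ys)

# loss has shape (G,); every entry is identical thanks to the pmean
print(loss)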

Problem: update itself contains pmean, which needs information from multiple devices. When the pmapped update is run, the body of the original update is what gets traced and executed, so shouldn't it fail at the pmean? Why doesn't this happen?

Cost

The drawback is that every backward pass requires a synchronization step involving all \(G\) replicated copies of the parameters. Suppose our model has 10B parameters and \(G=8\): then on the order of 80B gradient values have to be communicated on every backward pass.
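As a rough back-of-the-envelope estimate (assuming fp32 gradients at 4 bytes per value and a naive scheme in which the central server receives every worker's full gradient): \[ G \times P \times 4\,\text{bytes} = 8 \times (10 \times 10^9) \times 4\,\text{bytes} \approx 320\,\text{GB of gradient traffic per step}. \]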

Data Parallelism as described here is also called Distributed Data Parallel (DDP), and it differs from PyTorch's definition of data parallelism (also check out the HuggingFace explanation). The PyTorch version of DP tries to work around slow intra-node connectivity by minimizing the amount of synchronized data and delegating much of the data/gradient processing to one leading GPU, which in turn leads to under-utilization of the other devices.

Tensor parallelism

The idea of sharding, the way it was applied to data tensors, can be used in a similar way with respect to the model weights. We can divide each tensor \(\mathbf{W}\) into chunks distributed across multiple devices, so instead of having the whole tensor reside on a single device, each shard of the tensor resides on its own accelerator.

Each part is processed separately and in parallel on different devices, and the results are synced at the end of the step.

Say we have a linear layer with parameters \(\mathbf{W} \in \mathbb R^{d \times h}\) and \(\mathbf{B} \in \mathbb R^{1 \times h}\). We can write both as a concatenation of \(G\) sub-matrices along the hidden (column) dimension \[ \begin{aligned} \mathbf{W} & = \left( \begin{matrix} \color{teal}{\mathbf{W}_1} & \color{salmon}{\mathbf{W}_2} & \cdots & \color{orange}{\mathbf{W}_G} \end{matrix} \right) \\ \mathbf{B} &= \left( \begin{matrix} \color{teal}{\mathbf{B}_1} & \color{salmon}{\mathbf{B}_2} & \cdots & \color{orange}{\mathbf{B}_G} \end{matrix} \right) \end{aligned} \] with \(\mathbf{W}_k \in \mathbb{R}^{d \times \frac{h}{G}}\) and \(\mathbf{B}_k \in \mathbb{R}^{1 \times \frac{h}{G}}\) for \(k=1, \ldots, G\).

With the input data \(x\) replicated over \(G\) devices, we can perform the sub-matrix computations in parallel: \[ x \mathbf{W} + \mathbf{B} = \begin{pmatrix} \color{teal}{x \mathbf{W}_1 + \mathbf{B}_1} & \color{salmon}{x \mathbf{W}_2 + \mathbf{B}_2} & \cdots & \color{orange}{x \mathbf{W}_G + \mathbf{B}_G} \end{pmatrix} . \] Having computed \(x \mathbf{W} + \mathbf{B}\) this way, we apply the activation function \(\sigma(.)\) to obtain the embedding \[ z=\sigma \left(x \mathbf{W} + \mathbf{B} \right) \quad \in \mathbb R^{1 \times h} . \]
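A minimal sketch (not from the source) of this column-wise split in JAX, reusing the PositionalSharding object from the tensor-sharding section; the toy sizes d and h and the ReLU activation are illustrative assumptions:

import jax
import jax.numpy as jnp
from jax.sharding import PositionalSharding

G = jax.local_device_count()
d, h = 8, 16                       # toy sizes; h must be divisible by G
kx, kw, kb = jax.random.split(jax.random.PRNGKey(0), 3)

x = jax.random.normal(kx, (1, d))
W = jax.random.normal(kw, (d, h))
B = jax.random.normal(kb, (1, h))

sharding = PositionalSharding(jax.devices()).reshape(1, G)

# replicate x on every device; split W and B column-wise into G shards
x = jax.device_put(x, sharding.replicate())
W = jax.device_put(W, sharding)
B = jax.device_put(B, sharding)

# each device computes x @ W_k + B_k for its own shard; under jit the output
# stays sharded the same way, so no gather is needed at this point
z = jax.jit(lambda x, W, B: jax.nn.relu(x @ W + B))(x, W, B)
jax.debug.visualize_array_sharding(z)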

Cost

Since each worker computes the gradients and updates the parameters only for its own shard, there is no need to synchronize gradients.

However, every worker has to process the same training data \(\mathbf{x} \in \mathbb R^{B \times d}\), so the inputs and activations are replicated and the total memory cost across all \(G\) workers becomes very large.