Parallelism in Deep Learning
Sources:
- Exploring Parallel Strategies with Jax
- Training Deep Networks with Data Parallelism in Jax
Tensor sharding
Tensor sharding: In JAX, we can split a tensor `x` into multiple sub-tensors and place each one on its own device. But before we do that, we need to create a sharding object, which is basically a device placement configuration:
```python
from jax.sharding import PositionalSharding
```
We can split our tensor `x` in multiple ways. For example, we can split it column-wise along the embedding dimension:
```python
G = jax.local_device_count()
```
If we print `sharded_x.devices()`, it gives us a list of all devices, which is not very informative since it tells us nothing about how our tensor is sharded. Luckily, the `visualize_array_sharding` function from `jax.debug` gives us a clear visual picture of how `x` is sharded:
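The listings above are truncated, so here is a minimal end-to-end sketch of the same idea. It assumes a single host with 8 local devices and a toy tensor `x` of shape `(8, 16)`; the exact array contents don't matter:

```python
import jax
import jax.numpy as jnp
from jax.sharding import PositionalSharding

# Device placement configuration over all local devices (single host assumed).
sharding = PositionalSharding(jax.devices())

G = jax.local_device_count()
x = jnp.arange(8 * 16, dtype=jnp.float32).reshape(8, 16)  # toy tensor, values irrelevant

# A (1, G) sharding splits x column-wise: each device holds all rows
# but only 16 / G of the columns.
sharded_x = jax.device_put(x, sharding.reshape(1, G))

print(sharded_x.devices())                     # all devices holding x: not very informative
jax.debug.visualize_array_sharding(sharded_x)  # shows which device owns which slice of x
```

Reshaping the sharding to `(G, 1)` instead would split `x` row-wise across the devices.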
Notation
Symbol | Meaning |
---|---|
\(G\) | The number of available workers, or devices. |
\(x \in \mathbb R^{1 \times d}\) | A single data point with dimension \(d\). |
\(\mathbf{x} \in \mathbb R^{B \times d}\) | The batched input data with batch size \(B\). |
\(\mathbf{x}_{\text{partitioned}} \in \mathbb R^{S \times d}\) | The partitioned batched input data, also called a microbatch, with batch size \(S=\frac{B}{G}\). |
\(\mathbf{x}_{\text{partitioned}_i} \in \mathbb R^{S \times d}\) | The \(i\)th microbatch. |
\(z \in \mathbb R^{1 \times d}\) | A single embedding with dimension \(d\). |
\(\mathbf{z} \in \mathbb R^{B \times d}\) | The batched embeddings with batch size \(B\). |
\(\mathbf{W} \in \mathbb R^{d \times h}\) | The weight matrix of a linear layer with hidden dimension \(h\). |
\(\mathbf{B} \in \mathbb R^{1 \times h}\) | The bias of a linear layer with hidden dimension \(h\). |
\(\sigma(\cdot)\) | The activation function, which can be ReLU or anything else. |
Data parallelism
In Data Parallelism (DP), we copy the model to each worker but partition the training data \(\mathbf{x} \in \mathbb R^{B \times d}\). Each worker has its own independent microbatch, denoted \(\mathbf{x}_{\text{partitioned}_i}\), and a complete replica of the model.
During backward propagation:
- Each worker computes its loss and gradient using its microbatch \(\mathbf{x}_{\text{partitioned}_i}\).
- Each worker sends its loss and gradient to a central server. The central server calculates the mean of the losses and gradients and sends them back to each worker.
- Each device then updates its parameters using the averaged gradient (see the update rule spelled out below).
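Written out, with model parameters \(\theta\) and learning rate \(\eta\) (symbols not defined in the notation table above), one synchronized step computes
\[
\bar{g} = \frac{1}{G} \sum_{i=1}^{G} \nabla_\theta \, \mathcal{L}\!\left(\theta, \mathbf{x}_{\text{partitioned}_i}\right),
\qquad
\theta \leftarrow \theta - \eta \, \bar{g} .
\]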
NOTE: Since all workers initially have a complete replica of the model parameters, and all workers always use the same averaged gradient to update their parameters, the parameter copies will always stay identical across the workers.
```python
from typing import NamedTuple, Tuple
```
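The listing above is also truncated, so here is a minimal sketch of the kind of `pmap`-based update step the bullet points below describe. The `Params` container, the toy squared-error loss, and the learning rate `LR` are assumptions for illustration, not the original code:

```python
import functools
from typing import NamedTuple

import jax
import jax.numpy as jnp

class Params(NamedTuple):
    W: jnp.ndarray
    b: jnp.ndarray

LR = 1e-2  # hypothetical learning rate

def loss_fn(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    # Toy squared-error loss, just to have something to differentiate.
    pred = x @ params.W + params.b
    return jnp.mean((pred - y) ** 2)

# pmap runs `update` in parallel, one call per device, each on its own microbatch.
@functools.partial(jax.pmap, axis_name='num_devices')
def update(params: Params, x: jnp.ndarray, y: jnp.ndarray):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Average loss and gradients across all devices participating in this pmap.
    loss = jax.lax.pmean(loss, axis_name='num_devices')
    grads = jax.lax.pmean(grads, axis_name='num_devices')
    # Every device applies the same averaged gradient, so the replicas stay in sync.
    new_params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
    return new_params, loss
```

To call it, `params` has to be replicated across devices (e.g. with `jax.device_put_replicated`) and the batch reshaped to `(G, S, d)`, since `pmap` maps over the leading axis.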
- By using `pmap` we get the same vectorized functionality as `vmap`, but on multiple devices.
- Here, the `functools.partial` decorator wraps the `update` function with a `pmap`, passing `axis_name='num_devices'` as an input argument to `pmap`. This means that the `update` function will be applied in parallel across all devices. The `pmean` function is used to average the gradients across all devices. `pmean` is similar to `np.mean`, but it also takes an `axis_name` argument, which specifies the axis to average across. In this case we average across the `num_devices` axis, but this name is just a placeholder: we can change it to any string (e.g. `'i'` or `'data'`) and it will work the same (see the small `pmean` demo after this list).
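As a quick, self-contained illustration of the `axis_name` mechanics (a minimal sketch, assuming it runs on a multi-device host):

```python
import jax
import jax.numpy as jnp

G = jax.local_device_count()
values = jnp.arange(G, dtype=jnp.float32)  # one scalar per device: 0, 1, ..., G-1

# 'i' is an arbitrary label; it only has to match between pmap and pmean.
means = jax.pmap(lambda v: jax.lax.pmean(v, axis_name='i'), axis_name='i')(values)
print(means)  # every device returns the same value: the mean (G - 1) / 2
```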
Problem: `update` itself contains `pmean`, which should require information from multiple devices. When the `pmap`ed `update` is run, the original `update` function itself runs as well. Therefore, it should fail because of `pmean`. Why doesn't this happen?
Cost
The drawback is that during each backward pass (BP) we need to do a "synchronization", which involves all \(G\) replicated copies of the parameters. Suppose our model is very big, say 10B parameters, and we have \(G=8\): then we need to transfer roughly 80B values during every BP.
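Written as a rough formula, with \(|\theta|\) denoting the number of model parameters (a symbol not defined in the notation table), the gradient traffic per step is on the order of
\[
G \cdot |\theta| = 8 \times 10\text{B} = 80\text{B}
\]
values.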
Data Parallel is also called Distributed Data Parallel (DDP), and it is different from PyTorch's definition of data parallelism (also check out the HuggingFace explanation). The PyTorch version of DP helps to overcome slow intra-node connectivity by minimizing the amount of synchronized data and delegating a lot of data/gradient processing to one leading GPU. This, in turn, results in under-utilization of the other devices.
Tensor parallelism
The idea of sharding, the way it was applied to data tensors, can be used in a similar way with respect to the model weights. We can divide each tensor \(\mathbf{W}\) into chunks distributed across multiple devices, so instead of having the whole tensor reside on a single device, each shard of the tensor resides on its own accelerator.
Each part is processed separately and in parallel on different devices, and the results are synced at the end of the step.
Say we have a linear layer with parameters \(\mathbf{W} \in \mathbb R^{d \times h}\) and \(\mathbf{B} \in \mathbb R^{1 \times h}\). We can represent these matrices as a concatenation of \(G\) sub-matrices along the column (hidden) dimension \[ \begin{aligned} \mathbf{W} & = \left( \begin{matrix} \color{teal}{\mathbf{W}_1} & \color{salmon}{\mathbf{W}_2} & \cdots & \color{orange}{\mathbf{W}_G} \end{matrix} \right) \\ \mathbf{B} &= \left( \begin{matrix} \color{teal}{\mathbf{B}_1} & \color{salmon}{\mathbf{B}_2} & \cdots & \color{orange}{\mathbf{B}_G} \end{matrix} \right) \end{aligned} \] with \(\mathbf{W}_k \in \mathbb{R}^{d \times \frac{h}{G}}\) and \(\mathbf{B}_k \in \mathbb{R}^{1 \times \frac{h}{G}}\) for \(k=1, \ldots, G\).
With the input data \(x\) replicated over \(G\) devices, we can perform the sub-matrix computations in parallel: \[ x \mathbf{W} + \mathbf{B} = \begin{pmatrix} \color{teal}{x \mathbf{W}_1 + \mathbf{B}_1} & \color{salmon}{x \mathbf{W}_2 + \mathbf{B}_2} & \cdots & \color{orange}{x \mathbf{W}_G + \mathbf{B}_G} \end{pmatrix} . \] Using this method to compute \(x \mathbf{W} + \mathbf{B}\), we then apply the activation function \(\sigma(\cdot)\), getting the embedding \[ z=\sigma \left(x \mathbf{W} + \mathbf{B} \right) \quad \in \mathbb R^{1 \times h} . \]
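Here is a minimal JAX sketch of this column-wise split, assuming the hypothetical sizes `d = 8` and `h = 16 * G` and ReLU as \(\sigma\); each device works only on its own \(\mathbf{W}_k\) and \(\mathbf{B}_k\) shard while \(x\) is broadcast to all devices:

```python
import functools

import jax
import jax.numpy as jnp

G = jax.local_device_count()
d, h = 8, 16 * G                        # hypothetical sizes; h must be divisible by G

key_w, key_b, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(key_x, (1, d))    # a single data point, replicated on every device
W = jax.random.normal(key_w, (d, h))
B = jax.random.normal(key_b, (1, h))

# Column-wise split: W_k has shape (d, h/G), B_k has shape (1, h/G).
W_shards = jnp.stack(jnp.split(W, G, axis=1))   # (G, d, h/G)
B_shards = jnp.stack(jnp.split(B, G, axis=1))   # (G, 1, h/G)

# in_axes=(None, 0, 0): broadcast x to every device, map the shards over devices.
@functools.partial(jax.pmap, in_axes=(None, 0, 0))
def shard_forward(x, W_k, B_k):
    return jax.nn.relu(x @ W_k + B_k)           # sigma = ReLU on each shard

z_shards = shard_forward(x, W_shards, B_shards)  # (G, 1, h/G), one output slice per device
z = jnp.concatenate(list(z_shards), axis=-1)     # (1, h): the full embedding sigma(x W + B)
```

The final `concatenate` plays the role of the sync at the end of the step: it gathers the per-device output slices back into the full embedding.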
Cost
Since each worker computes its gradients and updates its parameters separately, there is no need to synchronize gradients.
However, each worker has to process the same training data \(\mathbf{x} \in \mathbb R^{B \times d}\), making the total memory cost (across all \(G\) workers) very large.