MoCo and BYOL
Sources:
- For MoCo:
- MoCo v1 2019 paper
- MoCo v2 2020 paper
- Paper explained — Momentum Contrast for Unsupervised Visual Representation Learning [MoCo] by Nazim Bendib
- For BYOL:
- BYOL 2020 paper
- Neural Networks Intuitions: 10. BYOL- Paper Explanation by Raghul Asokan
- Understanding self-supervised and contrastive learning with "Bootstrap Your Own Latent" (BYOL) by imbue
MoCo and BYOL are both famous contrastive learning and self-supervised learning frameworks. They introduce interesting designs:
- Using an online network and a slowly updated target network to mitigate the collapse problem in contrastive learning.
- Using the same architecture for these two networks while only training the online network. The target network is updated via an EMA (exponential moving average) of the online network's parameters.
This design is heavily used in later research such as DINO and is worth learning about.
You may want to read my post on SimCLR first to get a deeper understanding of contrastive learning.
Motivation of MoCo (and BYOL)
Contrastive learning methods suffer from the collapsed-representation problem. The authors of BYOL found that using a fixed, randomly initialized network as a target network and training another network to match its outputs can mitigate this problem.
To make this clearer:
1. Take a fixed, randomly initialized network and call it the target network (or momentum network).
2. Take a trainable network and call it the online network.
3. Pass an input image through the target and online networks to extract a target embedding and a predicted embedding respectively.
4. Minimize the distance between the two embeddings, e.g., with a Euclidean distance or cosine similarity loss.
Even though this approach does not result in a collapsed solution, it does not produce useful representations as we are relying on a random network for targets.
However, we can apply a trick. We know that the target network is not good at producing representations but is resilient to collapse, while the online network is good at producing representations but is susceptible to collapse. As a result, we can:
1. Use a copy of the online network as our target network for some iterations.
2. During these iterations, conduct the contrastive learning process described above, but only train the online network.
3. After that, use EMA to copy the parameters of the online network into the target network.
4. Repeat steps 1–3.
This is the idea behind MoCo (and, later, BYOL).
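A minimal sketch of this loop in PyTorch, assuming a toy linear encoder and a cosine-similarity loss (all names here are illustrative, not from either paper):

```python
import copy
import torch
import torch.nn.functional as F

def ema_update(online, target, m=0.99):
    # target <- m * target + (1 - m) * online, parameter-wise
    with torch.no_grad():
        for p_t, p_o in zip(target.parameters(), online.parameters()):
            p_t.mul_(m).add_(p_o, alpha=1 - m)

online = torch.nn.Linear(128, 64)      # stand-in for the online encoder
target = copy.deepcopy(online)         # same architecture; never trained directly
for p in target.parameters():
    p.requires_grad = False

opt = torch.optim.SGD(online.parameters(), lr=0.03)
for x in torch.randn(100, 8, 128):     # toy stream of batches (batch size 8)
    pred = online(x)
    with torch.no_grad():
        tgt = target(x)                # targets come from the slow network
    loss = (1 - F.cosine_similarity(pred, tgt, dim=1)).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    ema_update(online, target)         # drag the target slowly toward the online net
```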
MoCo
MoCo (short for Momentum Contrast) derives from the above idea. In its implementation, it adopts the idea of dictionary learning and aims to build dynamic dictionaries. Contrastive learning can be thought of as training an encoder for a dictionary look-up task.
The "keys" (tokens) in the dictionary are sampled from data (e.g., images or patches) and are represented by an encoder network. Unsupervised learning trains encoders to perform dictionary look-up: an encoded “query” should be similar to its matching key and dissimilar to others. Learning is formulated as minimizing a contrastive loss.
- The dictionary is built as a queue: since the size of the dictionary is limited by memory, it still needs to cover the whole dataset as well as possible. An intuitive solution is to use a queue as the dictionary and update it repeatedly during training: at each training step, the most recent encoded keys are enqueued (pushed) and the oldest ones are dequeued (see the sketch after this list).
- We use a query encoder as the online network and a key encoder as the momentum network. They share the same architecture, and each takes an image (or a patch) as input and outputs its representation.
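Here is a sketch of the queue update, assuming a fixed-size buffer of key embeddings (the names follow the spirit of the paper's pseudocode, but the details are illustrative):

```python
import torch
import torch.nn.functional as F

K, dim = 4096, 128                     # dictionary size and embedding dimension
queue = F.normalize(torch.randn(dim, K), dim=0)   # columns are stored keys
ptr = 0

def dequeue_and_enqueue(keys):
    """Overwrite the oldest entries with the newest encoded keys.

    keys: (batch, dim). Assumes K is divisible by the batch size.
    """
    global ptr
    b = keys.shape[0]
    queue[:, ptr:ptr + b] = keys.T     # enqueue the newest keys
    ptr = (ptr + b) % K                # circular buffer: oldest keys get dequeued
```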
Contrastive loss
Consider an encoded query \(q\) and a set of encoded samples \(\left\{k_0, k_1, k_2, \ldots\right\}\) that are the keys of a dictionary. Assume that there is a single key (denoted as \(k_{+}\)) in the dictionary that \(q\) matches. The contrastive loss function is
\[ \mathcal{L}_q=-\log \frac{\exp \left(q \cdot k_{+} / \tau\right)}{\sum_{i=0}^K \exp \left(q \cdot k_i / \tau\right)} \] where:
- \(q\) is the query representation.
- \(k_{+}\) is the corresponding key (the one generated from the other augmented view of the query); we call it the positive sample.
- \(k_i\) are the keys stored in the dictionary; we call them the negative samples.
- \(\tau\) is the temperature; it controls how strict the loss function is with the negative samples.
This loss function (like most contrastive losses) tries to minimize the distance between \(q\) and \(k_{+}\) and maximize the distance between \(q\) and the negative keys. In other words, it pulls together views that have been augmented from the same sample and pushes apart views that come from different samples.
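Written as code, \(\mathcal{L}_q\) is just a cross-entropy over \(1+K\) logits where the positive key plays the role of the correct class. A sketch, following the structure of the pseudocode in the MoCo paper (the default \(\tau=0.07\) comes from the paper; the function name is mine):

```python
import torch
import torch.nn.functional as F

def info_nce(q, k_pos, queue, tau=0.07):
    # q: (N, C) normalized queries; k_pos: (N, C) normalized positive keys;
    # queue: (C, K) stored negative keys. The positive logit sits at index 0.
    l_pos = torch.einsum("nc,nc->n", q, k_pos).unsqueeze(1)  # (N, 1)
    l_neg = torch.einsum("nc,ck->nk", q, queue)              # (N, K)
    logits = torch.cat([l_pos, l_neg], dim=1) / tau          # (N, 1 + K)
    labels = torch.zeros(q.shape[0], dtype=torch.long)       # class 0 = positive
    return F.cross_entropy(logits, labels)
```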
Exponential moving average
The key encoder is not updated by backpropagation, as the gradients would have to flow all the way to every representation in the dictionary, which would be prohibitively expensive. So how do we update it?
A naive solution would be to copy the parameters of the query encoder into the key encoder (as in DQN), but this yields poor experimental results: the parameters change too rapidly, which reduces the consistency of the dictionary's representations.
A better solution is a momentum update (an exponential moving average, EMA). The parameters of the key encoder are updated by: \[ \theta_k \leftarrow m \theta_k+(1-m) \theta_q . \] Here \(m \in[0,1)\) is a momentum coefficient. Only the parameters \(\theta_q\) are updated by backpropagation.
Training process
We first uniformly sample a batch of images from the dataset and create two randomly augmented sets of views: the query views and the key views. The two sets are passed to the query encoder and the key encoder respectively, which output the query representations and the key representations. We then minimize the contrastive loss with respect to the query encoder's parameters, update the key encoder with the momentum rule above, and enqueue the new keys into the dictionary.
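Putting the pieces together, one training step might look like the sketch below, which mirrors the structure of the paper's Algorithm 1 pseudocode. Here `x_q` and `x_k` are the two pre-augmented views of the batch, `info_nce` and the queue are the ones sketched above, and \(m=0.999\) is the momentum value the paper reports working well:

```python
import torch
import torch.nn.functional as F

def moco_step(f_q, f_k, optimizer, x_q, x_k, queue, ptr, m=0.999, tau=0.07):
    q = F.normalize(f_q(x_q), dim=1)              # queries: gradients flow here
    with torch.no_grad():
        k = F.normalize(f_k(x_k), dim=1)          # keys: no gradient

    loss = info_nce(q, k, queue, tau)             # contrastive loss from above
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()                              # updates theta_q only

    with torch.no_grad():                         # momentum update of theta_k
        for p_k, p_q in zip(f_k.parameters(), f_q.parameters()):
            p_k.mul_(m).add_(p_q, alpha=1 - m)

    b = k.shape[0]                                # enqueue new keys, dequeue oldest
    queue[:, ptr:ptr + b] = k.T
    return loss.item(), (ptr + b) % queue.shape[1]
```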
BYOL
BYOL (short for Bootstrap Your Own Latent) builds on MoCo's momentum network concept, with three differences:
- It is not formulated as a dictionary look-up problem and has no dictionaries (queries, keys, queues, etc.).
- It adds an MLP to the online network as a prediction head, making the architecture asymmetric (the target network has no prediction head).
- Recall that MoCo uses a contrastive loss and therefore requires negative samples. BYOL instead uses an L2 error between the normalized prediction \(q(z)\) and the target \(z'\), removing the need for negative samples.
Figure 5: BYOL's architecture. At the end of training, everything but \(f_\theta\) is discarded, and \(y_\theta\) is used as the image representation.
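For concreteness, the projector \(g\) and predictor \(q\) in the paper are small MLPs of the same shape (linear, batch norm, ReLU, linear). The dimensions below match those reported for the ResNet-50 encoder, but treat this as a sketch:

```python
import torch.nn as nn

def mlp_head(in_dim, hidden=4096, out_dim=256):
    # Shared shape of BYOL's projection and prediction heads
    return nn.Sequential(
        nn.Linear(in_dim, hidden),
        nn.BatchNorm1d(hidden),
        nn.ReLU(inplace=True),
        nn.Linear(hidden, out_dim),
    )

projector = mlp_head(in_dim=2048)   # g: on both branches (2048 = ResNet-50 output)
predictor = mlp_head(in_dim=256)    # q: on the online branch only
```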
Training process
Given a set of images \(\mathcal{D}\) and two distributions of image augmentations \(\mathcal{T}\) and \(\mathcal{T}^{\prime}\), one training iteration proceeds as follows:
1. Sample an image \(x \sim \mathcal{D}\) uniformly from \(\mathcal{D}\) and produce two augmented views \(v \triangleq t(x)\) and \(v^{\prime} \triangleq t^{\prime}(x)\) by applying respectively image augmentations \(t \sim \mathcal{T}\) and \(t^{\prime} \sim \mathcal{T}^{\prime}\).
2. From the first augmented view \(v\), the online network outputs a representation \(y_\theta \triangleq f_\theta(v)\) and a projection \(z_\theta \triangleq g_\theta(y_\theta)\). From the second augmented view \(v^{\prime}\), the target network outputs \(y_{\xi}^{\prime} \triangleq f_{\xi}\left(v^{\prime}\right)\) and the target projection \(z_{\xi}^{\prime} \triangleq g_{\xi}\left(y_{\xi}^{\prime}\right)\).
3. Output a prediction \(q_\theta\left(z_\theta\right)\) of \(z_{\xi}^{\prime}\) and \(\ell_2\)-normalize both \(q_\theta\left(z_\theta\right)\) and \(z_{\xi}^{\prime}\) to \(\overline{q_\theta}\left(z_\theta\right) \triangleq q_\theta\left(z_\theta\right) /\left\|q_\theta\left(z_\theta\right)\right\|_2\) and \(\bar{z}_{\xi}^{\prime} \triangleq z_{\xi}^{\prime} /\left\|z_{\xi}^{\prime}\right\|_2\). Note that the predictor is applied only to the online branch, making the architecture asymmetric between the online and target pipelines.
4. Minimize a similarity loss between \(\overline{q_\theta}\left(z_\theta\right)\) and \(\operatorname{sg}\left(\bar z_{\xi}^{\prime}\right)\), where \(\theta\) are the trained weights, \(\xi\) is an exponential moving average of \(\theta\), and \(\operatorname{sg}\) means stop-gradient.
5. Update the target network with an EMA of the previous online networks, as discussed before: \[ \xi \leftarrow \tau \xi+(1-\tau) \theta . \]
6. Repeat steps 1–5.
In fact, step 4 above is a little more complex. We first define the following mean squared error loss function: \[ \begin{equation} \label{eq_BYOL_2} \mathcal{L}_{\theta, \xi} \triangleq\left\|\overline{q_\theta}\left(z_\theta\right)-\bar{z}_{\xi}^{\prime}\right\|_2^2=2-2 \frac{\left\langle q_\theta\left(z_\theta\right), z_{\xi}^{\prime}\right\rangle}{\left\|q_\theta\left(z_\theta\right)\right\|_2 \cdot\left\|z_{\xi}^{\prime}\right\|_2} . \end{equation} \]
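Because both vectors are \(\ell_2\)-normalized, this mean squared error reduces to \(2 - 2\cos(\cdot,\cdot)\), which makes the loss a one-liner (a sketch; the function name is mine):

```python
import torch.nn.functional as F

def byol_loss(p, z):
    # p: online prediction q_theta(z_theta); z: target projection z'_xi (detached)
    p = F.normalize(p, dim=1)
    z = F.normalize(z, dim=1)
    return (2 - 2 * (p * z).sum(dim=1)).mean()
```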
Then, we symmetrize the loss \(\mathcal{L}_{\theta, \xi}\) in \(\eqref{eq_BYOL_2}\) by separately feeding \(v^{\prime}\) to the online network and \(v\) to the target network to compute \(\widetilde{\mathcal{L}}_{\theta, \xi}\).
The final loss function is: \[ \mathcal{L}_{\theta, \xi}^{\text{BYOL}}=\mathcal{L}_{\theta, \xi}+\widetilde{\mathcal{L}}_{\theta, \xi} . \] We only optimize \(\theta\), not \(\xi\), as depicted by the stop-gradient in Figure 5. BYOL's optimization step is summarized as \[ \theta \leftarrow \operatorname{optimizer}\left(\theta, \nabla_\theta \mathcal{L}_{\theta, \xi}^{\text{BYOL}}, \eta\right), \] where \(\eta\) is the learning rate.
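Putting everything together, one symmetrized training step might look like the following sketch (the function and argument names are mine; \(\tau = 0.996\) is the paper's base target decay rate, which the paper anneals toward 1 over training):

```python
import torch

def byol_step(f_o, g_o, q_o, f_t, g_t, optimizer, x, t, t_prime, tau=0.996):
    v, v_prime = t(x), t_prime(x)                  # two augmented views
    p1 = q_o(g_o(f_o(v)))                          # online branch on v
    p2 = q_o(g_o(f_o(v_prime)))                    # ... and on v' (symmetrized term)
    with torch.no_grad():                          # sg(.): no gradient to the target
        z1 = g_t(f_t(v_prime))
        z2 = g_t(f_t(v))
    loss = byol_loss(p1, z1) + byol_loss(p2, z2)   # L + L-tilde
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()                               # updates theta only
    with torch.no_grad():                          # xi <- tau*xi + (1 - tau)*theta
        target_params = list(f_t.parameters()) + list(g_t.parameters())
        online_params = list(f_o.parameters()) + list(g_o.parameters())
        for p_t, p_o in zip(target_params, online_params):
            p_t.mul_(tau).add_(p_o, alpha=1 - tau)
    return loss.item()
```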