MoCo and BYOL
Sources:
- For MoCo: Momentum Contrast for Unsupervised Visual Representation Learning (He et al.)
- For BYOL: Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning (Grill et al.)
MoCo and BYOL are both famous self-supervised learning frameworks from the contrastive-learning line of work. They introduce some interesting designs:
- Using an online network together with a target network that receives no gradients, to mitigate the collapse problem in contrastive learning.
- Using the same architecture for the two networks while training only the online network. The target network is updated via an exponential moving average (EMA) of the online network's weights.
This design is heavily used in later research such as DINO and is worth learning about.
You may want to read my post on SimCLR to get a deeper understanding of contrastive learning.
Motivation of MoCo (and BYOL)
Contrastive and other self-supervised learning methods can suffer from the collapsed-representation problem, where the encoder maps every input to the same embedding. The authors of BYOL found that using a fixed, randomly initialized network as a target network and training another network to match its outputs can mitigate this problem.
To make this more concrete:
- Take a fixed, randomly initialized network and call it the target network (or momentum network).
- Take a trainable network and call it the online network.
- Pass an input image through the target and online networks to extract the target and predicted embeddings, respectively.
- Minimize the distance between the two embeddings, using a Euclidean-distance or cosine-similarity loss, while training only the online network.
Even though this approach does not result in a collapsed solution, it does not produce very useful representations, since we are relying on a random network for the targets.
However, we can apply a trick. The target network is not good at producing representations, but it is resilient to collapse; the online network is good at producing representations, but it is susceptible to collapse. As a result, we can:
1. Use a copy of the online network as our target network for some iterations.
2. During these iterations, run the learning process described above, but train only the online network.
3. After that, use an EMA update to move the target network's parameters toward the online network's.
4. Repeat steps 1-3.
This is the idea behind MoCo (and, later, BYOL).
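To make this concrete, here is a minimal PyTorch-style sketch of the online/target idea with an EMA update. It is not the actual MoCo or BYOL code; `make_encoder`, `augment`, and `loader` are placeholder names, and the momentum value is illustrative.

```python
import copy
import torch
import torch.nn.functional as F

online_net = make_encoder()              # hypothetical backbone constructor (trainable)
target_net = copy.deepcopy(online_net)   # same architecture, never trained directly
for p in target_net.parameters():
    p.requires_grad = False              # no backprop into the target

optimizer = torch.optim.SGD(online_net.parameters(), lr=0.03)
momentum = 0.99                          # EMA coefficient (illustrative value)

for images in loader:                    # `loader` yields batches of raw images
    view_a, view_b = augment(images), augment(images)   # two random views
    pred = online_net(view_a)            # online embedding
    with torch.no_grad():
        target = target_net(view_b)      # target embedding (no gradient)

    # pull the two embeddings together (cosine-similarity loss)
    loss = -F.cosine_similarity(pred, target, dim=-1).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # EMA update: slowly move the target toward the online network
    with torch.no_grad():
        for p_t, p_o in zip(target_net.parameters(), online_net.parameters()):
            p_t.mul_(momentum).add_(p_o, alpha=1.0 - momentum)
```

The important points are that gradients never flow into the target network and that the target changes only slowly, which keeps its outputs stable.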
MoCo
MoCo (short for Momentum Contrast) derives from the idea above. In its implementation, it adopts a dictionary look-up view of contrastive learning and aims to build large, dynamic dictionaries: contrastive learning can be thought of as training an encoder for a dictionary look-up task.
The "keys" (tokens) in the dictionary are sampled from data (e.g., images or patches) and are represented by an encoder network. Unsupervised learning trains encoders to perform dictionary look-up: an encoded “query” should be similar to its matching key and dissimilar to others. Learning is formulated as minimizing a contrastive loss.
- The dictionary is built as a queue: the dictionary size is limited by memory, yet it should cover the dataset as well as possible. An intuitive solution is to use a queue as the dictionary and update it repeatedly during training: at each learning step, the most recently encoded keys are enqueued (pushed) and the oldest ones are dequeued.
- We use a query encoder as the online network and a key encoder as the momentum (target) network. They share the same architecture, and each takes an image (or a patch) as input and outputs its representation.
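As a sketch of the queue mechanics only (the feature dimension, queue size, and function name below are illustrative, not MoCo's actual code):

```python
import torch
import torch.nn.functional as F

feature_dim, queue_size = 128, 65536                               # illustrative sizes
queue = F.normalize(torch.randn(feature_dim, queue_size), dim=0)   # encoded keys, column-wise
queue_ptr = 0                                                      # points at the oldest keys

def dequeue_and_enqueue(keys: torch.Tensor) -> None:
    """Push the newest batch of encoded keys and drop the oldest ones."""
    global queue_ptr
    batch_size = keys.shape[0]                        # keys: (batch, feature_dim)
    # overwrite the slice holding the oldest keys
    # (assumes queue_size is a multiple of batch_size, as in a typical setup)
    queue[:, queue_ptr:queue_ptr + batch_size] = keys.T
    queue_ptr = (queue_ptr + batch_size) % queue_size
```

With a queue, the dictionary size is decoupled from the mini-batch size, which is what lets MoCo use many negatives without a huge batch.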
Contrastive loss
Consider an encoded query $q$ and a set of encoded keys $\{k_0, k_1, k_2, \dots\}$ in the dictionary, among which a single key $k_+$ matches $q$. MoCo uses the InfoNCE contrastive loss:

$$\mathcal{L}_q = -\log \frac{\exp(q \cdot k_+ / \tau)}{\sum_{i=0}^{K} \exp(q \cdot k_i / \tau)}$$

Here $q$ is the query representation. $k_+$ is the corresponding key (the one generated from the other augmented view of the query); we call it the positive sample. The remaining keys $k_i$ are the ones stored in the dictionary; we call them the negative samples. $\tau$ is the temperature, and it indicates how strict the loss function is with the negative samples.
This loss function (like most contrastive loss functions) tries to minimize the distance between the query $q$ and its positive key $k_+$ while pushing $q$ away from the negative keys.
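A possible PyTorch implementation of this loss, assuming `q` and `k_pos` are already $\ell_2$-normalized query/positive-key batches and `queue` stores the negative keys column-wise (the function and argument names are mine, not MoCo's):

```python
import torch
import torch.nn.functional as F

def info_nce_loss(q, k_pos, queue, temperature=0.07):
    """InfoNCE loss: one positive key per query, negatives drawn from the queue.

    q:      (N, C) normalized query embeddings
    k_pos:  (N, C) normalized positive key embeddings
    queue:  (C, K) normalized negative keys stored in the dictionary
    """
    # positive logits: (N, 1), one per query
    l_pos = torch.einsum("nc,nc->n", q, k_pos).unsqueeze(-1)
    # negative logits: (N, K), similarity of each query to every queued key
    l_neg = torch.einsum("nc,ck->nk", q, queue)
    logits = torch.cat([l_pos, l_neg], dim=1) / temperature
    # the positive key sits at index 0 for every query
    labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
    return F.cross_entropy(logits, labels)
```

Treating the problem as a (K+1)-way classification task with the positive at index 0 is equivalent to the softmax form of the loss above.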
Exponential moving average
The key encoder is not updated by back-propagation: the gradient would have to propagate back to all of the samples whose representations sit in the dictionary, which would be intractable. So how do we update it?
A naive solution would be to directly copy the parameters of the query encoder into the key encoder (as with the target network in DQN), but this yields poor experimental results: the key encoder's parameters change too rapidly, which reduces the consistency of the dictionary's key representations.
Another solution is a momentum update (an exponential moving average, EMA). Denoting the key encoder's parameters by $\theta_k$ and the query encoder's parameters by $\theta_q$, the update is

$$\theta_k \leftarrow m\,\theta_k + (1 - m)\,\theta_q,$$

where $m \in [0, 1)$ is the momentum coefficient; MoCo uses a large value such as $m = 0.999$, so the key encoder evolves slowly and smoothly. Only $\theta_q$ is updated by back-propagation.
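In code, the momentum update is a simple per-parameter interpolation; a sketch (the encoder names are placeholders):

```python
import torch

@torch.no_grad()
def momentum_update(key_encoder, query_encoder, m=0.999):
    """theta_k <- m * theta_k + (1 - m) * theta_q; no gradients are involved."""
    for p_k, p_q in zip(key_encoder.parameters(), query_encoder.parameters()):
        p_k.mul_(m).add_(p_q, alpha=1.0 - m)
```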
Training process
We start by uniformly sampling a batch of images from the dataset and creating two randomly augmented sets of views: the query views and the key views. The two sets are passed through the query encoder and the key encoder respectively, which output the query representations and the key representations. The queries are then contrasted against their positive keys and the negatives stored in the queue using the loss above; the query encoder is updated by back-propagation, the key encoder by the momentum rule, and the newly encoded keys are enqueued while the oldest ones are dequeued.
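Putting the pieces together, one MoCo-style training step might look like the sketch below. It reuses the `info_nce_loss`, `momentum_update`, and `dequeue_and_enqueue` sketches from earlier; `query_encoder`, `key_encoder`, `augment`, `loader`, and `optimizer` are assumed to already exist, so this is an outline rather than MoCo's actual implementation.

```python
import torch
import torch.nn.functional as F

for images in loader:                                   # one mini-batch of images
    q_view, k_view = augment(images), augment(images)   # two random augmentations

    q = F.normalize(query_encoder(q_view), dim=1)       # queries (trained branch)
    with torch.no_grad():
        k = F.normalize(key_encoder(k_view), dim=1)     # keys (momentum branch)

    loss = info_nce_loss(q, k, queue)                   # contrast against the queue
    optimizer.zero_grad()
    loss.backward()                                     # gradients reach only query_encoder
    optimizer.step()

    momentum_update(key_encoder, query_encoder)         # EMA update of the key encoder
    dequeue_and_enqueue(k)                              # refresh the dictionary
```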
BYOL
BYOL (short for Bootstrap Your Own Latent) builds on the momentum-network concept of MoCo, with three main differences:
- It is not formulated as a dictionary look-up problem and has no dictionary (queries, keys, queues, etc.).
- It adds an MLP to the online network as a prediction head, making the architecture asymmetric (since the target network does not have a prediction head).
  - I think the idea of this prediction head is largely an engineering choice.
- Recall that MoCo uses a contrastive loss and therefore requires negative samples. BYOL instead uses an L2 error between the normalized prediction $\bar{q}_\theta(z_\theta)$ and the normalized target projection $\bar{z}'_\xi$, removing the need for negative samples.
Figure 5: BYOL's architecture. At the end of training, everything but the encoder $f_\theta$ is discarded, and $y_\theta$ is used as the image representation.
Overall, BYOL uses the online/target (student-teacher) architecture introduced by MoCo, but without negative samples: it simply pushes the two positive embeddings closer together.
The idea of BYOL is adopted in DINO.
Training process
1. Given a set of images $\mathcal{D}$, sample an image $x \sim \mathcal{D}$ uniformly and two image augmentations $t \sim \mathcal{T}$ and $t' \sim \mathcal{T}'$ from the two augmentation distributions $\mathcal{T}$ and $\mathcal{T}'$. BYOL produces two augmented views $v \triangleq t(x)$ and $v' \triangleq t'(x)$ of $x$.
2. From the first augmented view $v$, the online network outputs a representation $y_\theta \triangleq f_\theta(v)$ and a projection $z_\theta \triangleq g_\theta(y_\theta)$. The target network outputs $y'_\xi \triangleq f_\xi(v')$ and the target projection $z'_\xi \triangleq g_\xi(y'_\xi)$ from the second augmented view $v'$.
3. We then output a prediction $q_\theta(z_\theta)$ of $z'_\xi$ and $\ell_2$-normalize both $q_\theta(z_\theta)$ and $z'_\xi$ to $\bar{q}_\theta(z_\theta) \triangleq q_\theta(z_\theta)/\|q_\theta(z_\theta)\|_2$ and $\bar{z}'_\xi \triangleq z'_\xi/\|z'_\xi\|_2$. Note that this predictor is only applied to the online branch, making the architecture asymmetric between the online and target pipelines.
4. We minimize a similarity loss between $\bar{q}_\theta(z_\theta)$ and $\mathrm{sg}(\bar{z}'_\xi)$, where $\theta$ are the trained weights, $\xi$ is an exponential moving average of $\theta$, and $\mathrm{sg}$ means stop-gradient.
5. Update the target network parameters $\xi$ with an EMA of the online parameters $\theta$, as discussed before.
6. Repeat steps 2-5.
In fact, step 4 above is a little more involved. We first define the following mean squared error between the normalized prediction and the normalized target projection:

$$\mathcal{L}_{\theta,\xi} \triangleq \left\|\bar{q}_\theta(z_\theta) - \bar{z}'_\xi\right\|_2^2 = 2 - 2\cdot\frac{\langle q_\theta(z_\theta),\, z'_\xi\rangle}{\|q_\theta(z_\theta)\|_2 \cdot \|z'_\xi\|_2}.$$

Then, we symmetrize the loss by also feeding $v'$ to the online network and $v$ to the target network, which gives $\tilde{\mathcal{L}}_{\theta,\xi}$. The final loss function is:

$$\mathcal{L}^{\mathrm{BYOL}}_{\theta,\xi} = \mathcal{L}_{\theta,\xi} + \tilde{\mathcal{L}}_{\theta,\xi},$$

which is minimized with respect to $\theta$ only; the target parameters $\xi$ receive no gradient (stop-gradient).
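A sketch of this symmetrized objective and one training step in PyTorch. The module names (`online_encoder`, `projector`, `predictor`, `target_encoder`, `target_projector`) and the optimizer are placeholders, and the EMA update of the target parameters is the same mechanism described in the MoCo section.

```python
import torch
import torch.nn.functional as F

def byol_loss(p, z_target):
    """2 - 2 * cosine similarity between the prediction p and the target projection."""
    p = F.normalize(p, dim=-1)
    z_target = F.normalize(z_target, dim=-1)
    return (2 - 2 * (p * z_target).sum(dim=-1)).mean()

def byol_training_step(v1, v2):
    """One symmetrized BYOL step on two augmented views v1, v2 of the same images."""
    # online branch: encoder -> projector -> predictor
    p1 = predictor(projector(online_encoder(v1)))
    p2 = predictor(projector(online_encoder(v2)))
    # target branch: encoder -> projector, with stop-gradient
    with torch.no_grad():
        t1 = target_projector(target_encoder(v1))
        t2 = target_projector(target_encoder(v2))
    # symmetrized loss: each view is predicted from the other branch
    loss = byol_loss(p1, t2) + byol_loss(p2, t1)
    optimizer.zero_grad()
    loss.backward()              # only the online parameters receive gradients
    optimizer.step()
    # EMA update of target_encoder / target_projector would go here
    return loss.detach()
```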