World Models
Sources:
- World Models 2018 paper
- Training doc
- Mastering Diverse Domains through World Models, aka DreamerV3.
- A Generalist Agent.
- Learning General World Models in a Handful of Reward-Free Deployments.
Interesting world models
- Mastering Diverse Domains through World Models, aka DreamerV3.
- A Generalist Agent increases generality by implementing an algorithm that generates multi-modal world models. It tokenizes each input modality separately into a flat sequence of tokens (tokenization), then uses a transformer sequence model to process all tokens uniformly.
- Learning General World Models in a Handful of Reward-Free Deployments, aka CASCADE, adopts a reward-free problem setting instead of task-specific reward functions in order to gather diverse, highly informative data, thus improving generality.
- The reason is that traditional RL methods collect only a few transitions with a single exploration policy, so when deployed at scale they are likely to produce a homogeneous dataset, which does not optimally improve the model.
- Inspired by Bayesian active learning.
- DayDreamer: World Models for Physical Robot Learning:
Agent model
Specifically, at each time step \(t\), the action \(a_t\) taken by the agent depends on \(z_t\) and \(h_t\), where
- \(z_t\) is the embedding of the sensory data at time \(t\).
- \(h_t\) is the hidden state of the world model at time \(t\).
Our agent consists of three components: Vision (V), Memory (M), and Controller (C).
VAE (V) Model
MDN-RNN (M) Model
Controller (C) Model
C is a simple single-layer linear model that maps \(z_t\) and \(h_t\) directly to action \(a_t\) at each time step: \[ a_t = W_c[z_t \quad h_t] + b_c \]
In this linear model, \(W_c\) and \(b_c\) are the weight matrix and bias vector that map the concatenated input vector \([z_t \quad h_t]\) to the output action vector \(a_t\).
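A minimal NumPy sketch of this controller follows. The sizes are an assumption on our part: \(z_t \in \mathbb{R}^{32}\), \(h_t \in \mathbb{R}^{256}\), and a 3-dimensional action, which we believe matches the paper's CarRacing agent. With those sizes the parameter count is \((32 + 256) \times 3 + 3 = 867\), the same as the Controller entry in the parameter table under Agent training. The class name `LinearController` is hypothetical.

```python
import numpy as np

# Assumed sizes (CarRacing-style setup): latent z_t, RNN hidden h_t, action a_t.
Z_DIM, H_DIM, A_DIM = 32, 256, 3


class LinearController:
    """a_t = W_c [z_t, h_t] + b_c: a single linear layer, no hidden units."""

    def __init__(self, rng: np.random.Generator):
        self.W_c = 0.1 * rng.standard_normal((A_DIM, Z_DIM + H_DIM))
        self.b_c = np.zeros(A_DIM)

    def num_params(self) -> int:
        # (32 + 256) * 3 weights + 3 biases
        return self.W_c.size + self.b_c.size

    def action(self, z: np.ndarray, h: np.ndarray) -> np.ndarray:
        x = np.concatenate([z, h])  # [z_t  h_t], shape (Z_DIM + H_DIM,)
        return self.W_c @ x + self.b_c
```

Keeping C this small is a deliberate design choice: it leaves almost all capacity in V and M, and makes C cheap to optimize without gradients.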
Agent inference
The following flow diagram illustrates how V, M, and C interact with the environment:
Below is the pseudocode for how our agent model is used in the OpenAI Gym [28] environment. Running this function on a given controller C will return the cumulative reward during a rollout of the environment.
```python
def rollout(controller):
    ...
```
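Below is a runnable sketch of the full rollout loop as we understand the paper's pseudocode. Assumptions: the paper treats `env`, `vae`, and `rnn` as globals and calls `controller.action([z, h])`, whereas here they are passed in explicitly and the controller is a plain callable; `ToyEnv`, `ToyVAE`, and `ToyRNN` are hypothetical stand-ins so the loop runs without Gym or trained models. A real Gym environment's `step` returns extra values (such as `info`), so the unpacking would need adjusting.

```python
import numpy as np


class ToyEnv:
    """Hypothetical environment: fixed-length episode, reward 1.0 per step."""

    def __init__(self, horizon=5):
        self.horizon = horizon
        self.t = 0

    def reset(self):
        self.t = 0
        return np.zeros(4)                       # fake observation

    def step(self, action):
        self.t += 1
        obs = np.full(4, float(self.t))
        return obs, 1.0, self.t >= self.horizon  # (obs, reward, done)


class ToyVAE:
    """Hypothetical encoder: uses the first two observation entries as z_t."""

    def encode(self, obs):
        return obs[:2]


class ToyRNN:
    """Hypothetical memory with a 3-dimensional hidden state h_t."""

    def initial_state(self):
        return np.zeros(3)

    def forward(self, a, z, h):
        return np.tanh(h + a.sum() + z.sum())    # fake h_{t+1}


def rollout(controller, env, vae, rnn):
    """Run one episode and return its cumulative reward."""
    obs = env.reset()
    h = rnn.initial_state()
    done = False
    cumulative_reward = 0.0
    while not done:
        z = vae.encode(obs)                 # V: observation -> z_t
        a = controller(z, h)                # C: (z_t, h_t) -> a_t
        obs, reward, done = env.step(a)     # environment transition
        cumulative_reward += reward
        h = rnn.forward(a, z, h)            # M: (a_t, z_t, h_t) -> h_{t+1}
    return cumulative_reward
```

For example, `rollout(lambda z, h: np.zeros(2), ToyEnv(horizon=5), ToyVAE(), ToyRNN())` collects a reward of 1.0 on each of 5 steps and returns 5.0.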
The preceding figure is the architecture of the world model presented in the paper Recurrent World Models Facilitate Policy Evolution.
It has three components:
- Vision Model \(V\). At time step \(t\), it encodes the input visual information into an embedding \(z_t\).
- Implementation: Variational Autoencoder (VAE).
- Memory RNN \(M\). At time step \(t\), it outputs the probability distribution \(P(z_{t+1}|a_t, z_t, h_t)\) of the next latent vector \(z_{t+1}\) given the current and past information.
- Implementation: MDN-RNN, i.e. an RNN with a mixture density network (MDN) output layer.
- Actually, it outputs \(h_{t+1}\) rather than \(z_{t+1}\), since the former suffices to represent the latter.
- A very simple Controller \(C\). At time step \(t\), it maps \(z_t\) and \(h_t\) directly to action \(a_t\).
- Implementation: a single linear layer.
World model = \(V + M\).
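The World Models paper also handles the following details:
- The real world is not deterministic; surprises can happen at any moment, so the world model cannot always predict correctly. For example, I may drive for the first 2 hours without seeing a single roadblock, and then suddenly a rabbit comes straight at me, which is information never seen before. To let the world model simulate the randomness of the real world, the paper introduces a temperature parameter \(\tau\) that adjusts the randomness of the world model's predictions. Moreover, the world model's output is not a deterministic future but a probability distribution over the future.
- The agent may exploit loopholes in the world model's rules (after all, the world model is very simple), which leads to overfitting. I did not understand how the paper deals with this problem, though.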
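One way to apply the temperature \(\tau\) to the MDN-RNN's output is sketched below, assuming M outputs, for each latent dimension, the logits, means, and log standard deviations of a Gaussian mixture: divide the logits by \(\tau\) and scale the sampled noise by \(\sqrt{\tau}\). Treat this as an illustration rather than the paper's exact code; `sample_next_z` is a hypothetical name. The paper also reports that training the controller in a higher-\(\tau\) dream makes it harder for the agent to exploit imperfections of the world model.

```python
import numpy as np


def sample_next_z(logit_pi, mu, log_sigma, tau, rng):
    """Sample z_{t+1} from a per-dimension Gaussian mixture at temperature tau.

    logit_pi, mu, log_sigma: arrays of shape (z_dim, K), i.e. K mixture
    components for each latent dimension (an assumed MDN output layout).
    """
    # Temperature on the mixture weights: tau > 1 flattens, tau < 1 sharpens.
    scaled = logit_pi / tau
    scaled = scaled - scaled.max(axis=1, keepdims=True)  # numerical stability
    pi = np.exp(scaled)
    pi /= pi.sum(axis=1, keepdims=True)

    z_dim, k_components = pi.shape
    rows = np.arange(z_dim)
    # Choose one mixture component per latent dimension.
    k = np.array([rng.choice(k_components, p=pi[i]) for i in rows])
    mean = mu[rows, k]
    std = np.exp(log_sigma[rows, k])
    # Temperature on the Gaussian noise: std scaled by sqrt(tau).
    eps = rng.standard_normal(z_dim)
    return mean + std * np.sqrt(tau) * eps
```

As \(\tau \to 0\) the sample collapses to the mean of the most likely component; larger \(\tau\) gives a noisier, less predictable dream environment.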
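DreamerV3 shows that world models can be highly general: the original World Models was applied only to 2D autonomous navigation and shooting games, whereas DreamerV3 works in both 2D and 3D worlds. DreamerV3 also shows that world models have favorable scaling properties.
However, I don't understand DreamerV3's architecture: what is the actor-critic for?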
Agent training
- Collect 10,000 rollouts from a random policy.
- Train the VAE (V) on the image frames of this dataset.
- Use the pretrained VAE to encode each frame into \(z_t\), then train the MDN-RNN (M) to model \(P(z_{t+1}|a_t, z_t, h_t)\).
- In this experiment, the world model (V and M) has no knowledge of the actual reward signals from the environment. Its task is simply to compress and predict the sequence of observed image frames.
- Evolve the Controller (C) to maximize the expected cumulative reward of a rollout.
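The paper evolves C with CMA-ES. The sketch below swaps in a much simpler evolution strategy (an isotropic Gaussian with a fixed step size whose mean moves to the average of the top-scoring candidates each generation), so it is not CMA-ES. `simple_es` and `fitness` are hypothetical; in the real setup `fitness` would build a controller from the parameter vector and return its cumulative reward averaged over several rollouts.

```python
import numpy as np


def simple_es(fitness, dim, pop_size=32, elite_frac=0.25, sigma=0.5,
              generations=50, seed=0):
    """Maximize `fitness` with a basic elite-averaging evolution strategy.

    A simplified stand-in for CMA-ES: fixed isotropic step size sigma,
    no covariance adaptation.
    """
    rng = np.random.default_rng(seed)
    mean = np.zeros(dim)
    n_elite = max(1, int(pop_size * elite_frac))
    for _ in range(generations):
        # Sample a population of candidate controller parameter vectors.
        pop = mean + sigma * rng.standard_normal((pop_size, dim))
        scores = np.array([fitness(p) for p in pop])
        # Move the mean to the average of the best-scoring candidates.
        elite = pop[np.argsort(scores)[-n_elite:]]
        mean = elite.mean(axis=0)
    return mean
```

Only the cumulative reward is needed, never gradients through the environment or the world model.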
| Model | Parameter Count |
|---|---|
| VAE | 4,348,547 |
| MDN-RNN | 422,368 |
| Controller | 867 |
Interesting things
An interesting connection to the neuroscience literature is the work on hippocampal replay that examines how the brain replays recent experiences when an animal rests or sleeps. Replaying recent experiences plays an important role in memory consolidation [68] — where hippocampus-dependent memories become independent of the hippocampus over a period of time [67]. As Foster [68] puts it, replay is “less like dreaming and more like thought”.
DreamerV3
DreamerV3 consists of three neural networks (the world model, the critic, and the actor) that are trained concurrently from replayed experience without sharing gradients.
DreamerV2 achieves comparable or higher performance on most games except for Video Pinball. We hypothesize that the reconstruction loss of the world model does not encourage learning a meaningful latent representation because the most important object in the game, the ball, occupies only a single pixel.
DreamerV3:
Gato
Tokenization
- Text is encoded via SentencePiece (Kudo & Richardson, 2018) with 32000 subwords into the integer range [0, 32000).
- Images are first transformed into sequences of non-overlapping 16 × 16 patches in raster order, as done in ViT (Dosovitskiy et al., 2020). Each pixel in the image patches is then normalized to the range [−1, 1] and divided by the square root of the patch size (i.e. \(\sqrt{16} = 4\)).
- Discrete values, e.g. Atari button presses, are flattened into sequences of integers in row-major order. The tokenized result is a sequence of integers within the range of [0, 1024).
- Continuous values, e.g. proprioceptive inputs or joint torques, are first flattened into sequences of floating point values in row-major order. The values are mu-law encoded to the range [−1, 1] if not already there (see Figure 14 for details), then discretized to 1024 uniform bins. The discrete integers are then shifted to the range of [32000, 33024).
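The continuous-value pipeline can be sketched in NumPy as follows. The mu-law constants \(\mu = 100\) and \(M = 256\) are what we recall from the Gato paper's Figure 14, so treat them as assumptions, as is the reading of "if not already there" as a single check over the whole flattened array. The function names are hypothetical.

```python
import numpy as np

TEXT_VOCAB = 32000   # text tokens occupy [0, 32000)
NUM_BINS = 1024      # continuous values use 1024 uniform bins


def mu_law(x, mu=100.0, M=256.0):
    """F(x) = sgn(x) * log(|x| * mu + 1) / log(M * mu + 1); constants assumed."""
    return np.sign(x) * np.log(np.abs(x) * mu + 1.0) / np.log(M * mu + 1.0)


def tokenize_continuous(values):
    """Flatten row-major, mu-law encode if needed, bin into 1024, shift by 32000."""
    flat = np.asarray(values, dtype=np.float64).reshape(-1)  # C (row-major) order
    if np.any(np.abs(flat) > 1.0):        # assumed: one check per array
        flat = mu_law(flat)
    flat = np.clip(flat, -1.0, 1.0)       # values with |x| > M still exceed 1
    # Map [-1, 1] uniformly onto bin indices 0 .. NUM_BINS - 1.
    bins = np.floor((flat + 1.0) / 2.0 * NUM_BINS).astype(np.int64)
    bins = np.minimum(bins, NUM_BINS - 1)  # x = 1.0 would otherwise give 1024
    return bins + TEXT_VOCAB               # range [32000, 33024)
```

For example, `tokenize_continuous([[0.0, 1.0], [-1.0, 0.5]])` skips mu-law (all values already lie in [−1, 1]) and yields the tokens `[32512, 33023, 32000, 32768]`.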
Training
Given a sequence of tokens \(s_{1: L}\) and parameters \(\theta\), we model the data using the chain rule of probability: \[ \log p_\theta\left(s_1, \ldots, s_L\right)=\sum_{l=1}^L \log p_\theta\left(s_l \mid s_1, \ldots, s_{l-1}\right), \]
Let \(b\) index a training batch of sequences \(\mathcal{B}\). We define a masking function \(m\) such that \(m(b, l)=1\) if the token at index \(l\) is either from text or from the logged action of an agent, and 0 otherwise. The training loss for a batch \(\mathcal{B}\) can then be written as \[ \mathcal{L}(\theta, \mathcal{B})=-\sum_{b=1}^{|\mathcal{B}|} \sum_{l=1}^L m(b, l) \log p_\theta\left(s_l^{(b)} \mid s_1^{(b)}, \ldots, s_{l-1}^{(b)}\right) \]
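A NumPy sketch of this masked loss follows. We assume the logits are already aligned so that `logits[b, l]` is the model's prediction for \(s_l^{(b)}\) given \(s_1^{(b)}, \ldots, s_{l-1}^{(b)}\); in a real transformer this alignment comes from causal attention plus shifting targets by one position. `masked_nll` is a hypothetical name.

```python
import numpy as np


def log_softmax(logits):
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def masked_nll(logits, tokens, mask):
    """L(theta, B) = -sum_b sum_l m(b, l) * log p(s_l | s_1..s_{l-1}).

    logits: (B, L, V) next-token logits, aligned with `tokens` (assumed).
    tokens: (B, L) integer targets s_l^{(b)}.
    mask:   (B, L) with 1 for text or logged-action tokens, 0 otherwise.
    """
    logp = log_softmax(logits)
    B, L = tokens.shape
    # Gather log p of the observed token at every (b, l) position.
    token_logp = logp[np.arange(B)[:, None], np.arange(L)[None, :], tokens]
    return -(mask * token_logp).sum()
```

Observation tokens (such as image patches) still condition later predictions through the context, but with \(m(b, l) = 0\) they contribute nothing to the loss.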
As described above, Gato's network architecture has two main components:
- the parameterized embedding function which transforms tokens to token embeddings, and
- the sequence model which outputs a distribution over the next discrete token. While any general sequence model can work for next token prediction, we chose a transformer (Vaswani et al., 2017) for simplicity and scalability.
Because distinct tasks within a domain can share identical embodiments, observation formats and action specifications, the model sometimes needs further context to disambiguate tasks. Rather than providing e.g. one-hot task identifiers, we instead take inspiration from (Sanh et al., 2022; Wei et al., 2021; Brown et al., 2020) and use prompt conditioning. During training, for 25% of the sequences in each batch, a prompt sequence is prepended, coming from an episode generated by the same source agent on the same task. Half of the prompt sequences are from the end of the episode, acting as a form of goal conditioning for many domains; and the other half are uniformly sampled from the episode. During evaluation, the agent can be prompted using a successful demonstration of the desired task, which we do by default in all control results that we present here.
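The prompt-conditioning scheme can be sketched as below: with probability 0.25 a prompt is prepended, taken from the end of another episode half the time and from a uniformly sampled window otherwise. Assumptions: `build_training_sequence` and its arguments are hypothetical, prompts have a fixed length `prompt_len`, and the sketch ignores how the combined sequence is truncated to the model's context length.

```python
import numpy as np


def build_training_sequence(target, prompt_episode, prompt_len, rng,
                            prompt_prob=0.25):
    """Return `target`, prefixed with a prompt with probability `prompt_prob`.

    target:         1-D token array to train on.
    prompt_episode: 1-D token array from another episode by the same source
                    agent on the same task (assumed len >= prompt_len).
    """
    if rng.random() >= prompt_prob:
        return target                                 # no prompt (~75% of the time)
    n = len(prompt_episode)
    if rng.random() < 0.5:
        start = n - prompt_len                        # end of episode: goal-like
    else:
        start = int(rng.integers(0, n - prompt_len + 1))  # uniform window
    prompt = prompt_episode[start:start + prompt_len]
    return np.concatenate([prompt, target])
```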
Running
CASCADE
Problem Statement
Reinforcement learning (RL) considers training an agent to solve a Markov Decision Process (MDP), represented as a tuple \(\mathcal{M}=\{\mathcal{S}, \mathcal{A}, P, R, \rho, \gamma\}\), where
- \(\mathcal{S}\) and \(\mathcal{A}\) are the sets of states and actions respectively, with \(s \in \mathcal{S}\) and \(a \in \mathcal{A}\),
- \(P\left(s^{\prime} \mid s, a\right)\) is a probability distribution over next states given a previous state and action,
- \(R\left(s, a, s^{\prime}\right) \rightarrow r\) is a reward function mapping a transition to a scalar reward,
- \(\rho\) is an initial state distribution, and
- \(\gamma\) is a discount factor.
A policy \(\pi\) acting in the environment produces a trajectory \(\tau=\left\{s_1, a_1, \ldots, s_H, a_H\right\}\) for an episode with horizon \(H\).
Since actions in the trajectory are sampled from a policy, we can then define the RL problem as finding a policy \(\pi\) that maximizes expected returns in the environment, i.e. \[ \pi^{\star}=\arg \max _\pi \mathbb{E}_{\tau \sim \pi}[R(\tau)] . \]
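The objective can be made concrete with a Monte Carlo estimate: roll out \(\pi\) several times and average \(R(\tau)\). The text does not spell out \(R(\tau)\); here we assume it is the discounted sum of per-step rewards using \(\gamma\), and we represent \(\rho\), \(P\), and \(R\) by hypothetical callables (`initial_state`, `env_step`).

```python
import numpy as np


def discounted_return(rewards, gamma):
    """R(tau) = sum_{t=1}^{H} gamma^(t-1) * r_t (assumed definition)."""
    return sum(gamma ** t * r for t, r in enumerate(rewards))


def estimate_expected_return(policy, env_step, initial_state, horizon, gamma,
                             n_episodes, rng):
    """Monte Carlo estimate of E_{tau ~ pi}[R(tau)]."""
    returns = []
    for _ in range(n_episodes):
        s = initial_state(rng)            # s_1 ~ rho
        rewards = []
        for _ in range(horizon):
            a = policy(s, rng)            # a_t ~ pi(. | s_t)
            s, r = env_step(s, a, rng)    # s_{t+1} ~ P, r_t = R(s_t, a_t, s_{t+1})
            rewards.append(r)
        returns.append(discounted_return(rewards, gamma))
    return float(np.mean(returns))
```

Maximizing this estimate over \(\pi\) corresponds to the \(\arg\max\) above.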
We seek to learn policies that can transfer to any MDP within a family of MDPs. This can be formalized as a Contextual MDP [51], where observations, dynamics and rewards can vary given a context. In this paper we consider settings where only the reward varies, thus, if the test-time context is unknown at training time we must collect data that sufficiently covers the space of possible reward functions.
Finally, to facilitate scalability, we operate in the deployment efficient paradigm [67], whereby policy learning and exploration are completely separate, and during a given deployment, we gather a large quantity of data without further policy retraining (cf. online approaches like DER [112], which take multiple gradient steps per exploration timestep in the real environment). Taken together, we consider the reward-free deployment efficiency problem.
This differs from previous work as follows: 1) unlike previous deployment efficiency work, our exploration is task agnostic; 2) unlike previous reward-free RL work, we cannot update our exploration policy \(\pi_{\mathrm{EXP}}\) during deployment. Thus, the focus of our work is on how to train \(\pi_{\mathrm{EXP}}\) offline such that it gathers heterogeneous and informative data which facilitate zero-shot transfer to unknown tasks.
In this paper we make use of model-based RL (MBRL), where the goal is to learn a model of the environment (or world model) and then use it to subsequently train policies to solve downstream tasks. To do this, the world model needs to approximate both \(P\) and \(R\).
Typically, the model will be a neural network, parameterized by \(\psi\), hence we denote the approximate dynamics and reward functions as \(P_\psi\) and \(R_\psi\), which produce a new "imaginary" MDP, \(\mathcal{M}_\psi=\left(\mathcal{S}, \mathcal{A}, P_\psi, R_\psi, \rho\right)\).
We focus on Dyna-style MBRL [104], whereby we train a policy \(\pi_\theta\) (parameterized by \(\theta\)) with model-free RL solely using "imagined" transitions inside \(\mathcal{M}_\psi\). Furthermore, we can train the policy on a single GPU with parallelized rollouts since the simulator is a neural network [54]. The general form of all methods in this paper is shown in Algorithm 1, with the key difference being step 5: We aim to update \(\pi_{\mathrm{EXP}}\) in the new imaginary MDP \(\mathcal{M}_\psi\) such that it continues to collect a large, diverse quantity of reward-free data. Note that \(\pi_{\mathrm{EXP}}\) need not be a single policy, but could also refer to a collection of policies that we can deploy (either in parallel or in series), such that \(\pi \in \pi_{\mathrm{EXP}}\).
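The deployment loop described in this section can be sketched structurally. This is not a transcription of the paper's Algorithm 1: `deploy`, `train_world_model`, and `update_exploration` are hypothetical callables standing in for real components, and the sketch only encodes the constraints stated in the text: \(\pi_{\mathrm{EXP}}\) stays fixed during a deployment, the world model is refit on all data gathered so far, and \(\pi_{\mathrm{EXP}}\) is updated offline inside \(\mathcal{M}_\psi\) between deployments.

```python
def reward_free_deployments(deploy, train_world_model, update_exploration,
                            pi_exp, num_deployments):
    """Alternate fixed-policy deployments with offline updates (structural sketch).

    deploy(pi_exp) -> list of reward-free transitions from one deployment
    train_world_model(dataset) -> world model M_psi (approximating P and R)
    update_exploration(pi_exp, model) -> new pi_exp, trained only in M_psi
    """
    dataset = []
    model = None
    for _ in range(num_deployments):
        # Deploy pi_EXP (one policy or a collection) without updating it.
        dataset.extend(deploy(pi_exp))
        # Refit the world model on all data gathered so far.
        model = train_world_model(dataset)
        # Offline, in imagination: update pi_EXP inside M_psi.
        pi_exp = update_exploration(pi_exp, model)
    return model, pi_exp, dataset
```

Per the text, the methods differ mainly in the update step, which is `update_exploration` here.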
Experiments
Problems
Reinforcement learning has enabled computers to solve individual tasks through interaction, such as surpassing humans in the games of Go and Dota.
Conclusions
DreamerV3:
World models carry the potential for substantial transfer between tasks. Therefore, we see training larger models to solve multiple tasks across overlapping domains as a promising direction for future investigations.
DreamerV3 performs well across a wide range of domains with fixed hyperparameters, outperforming specialized algorithms. These domains include continuous and discrete actions, visual and low-dimensional inputs, and 2D (Atari games) and 3D (DMLab and Minecraft) worlds.
We observe favorable scaling properties of DreamerV3.