Disentangled Representations in Reinforcement Learning
Real-world environments are diverse and unpredictable, so Reinforcement Learning (RL) agents need to be robust to environment changes and adapt quickly. However, RL agents that learn from image observations often fail to generalise to unseen changes in the environment, such as colours and object positions, because they tend to overfit to variations seen in training. An open problem in RL is to learn a more robust representation of images that generalises well. Disentanglement is a promising direction towards robust representations in unsupervised learning and has the potential to improve RL generalisation. In this blog post we will discuss disentangled representations for RL and how they have been used in recent work to improve generalisation. We will cover the following:
- What is disentanglement?
- Why disentanglement in RL?
- Challenges with disentanglement in RL
- What has been done so far?
- Open problems
What is disentanglement?
Disentanglement is an approach for learning low-dimensional representations of images that separate the distinct, informative factors of variation in an image, i.e. the unknown ground truth factors that generated it. For example, the 3D Shapes dataset [1], pictured below, is commonly used to test (un)supervised approaches to disentanglement.
The images are generated from six factors of variation: floor colour, wall colour, object colour, scale, shape and orientation. These factors, while fairly obvious to us as humans, are the unknown ground truth factors used to generate the images. The agent, which has no knowledge of these ground truth factors, must learn from the images alone a disentangled representation that separates these factors into distinct features of its lower-dimensional representation.
It is common to assume that the ground truth factors of variation are independent, e.g. object shape and colour vary independently of each other in the images. Disentanglement approaches therefore usually learn a lower-dimensional representation that captures the information in the image with independent features or groups of features, where a representation feature refers to one dimension of the representation vector. Individual approaches differ in their other assumptions and in their access to privileged information, such as some labels or image groupings.
Ultimately, the goal of disentanglement is to learn a representation that is more robust to changes in the images. If only one factor of variation changes, e.g. object colour in the example above, then only the corresponding feature (or group of features) in the representation changes. The agent still has a reliable representation of all other factors of variation, which improves generalisation when this representation is used for downstream tasks (such as image classification). A disentangled representation can also improve explainability, as semantically meaningful factors of variation are separated into distinct representation features.
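To make this concrete, here is a minimal toy sketch in Python with NumPy. It assumes an idealised "disentangled encoder" that maps each ground truth factor to its own feature, and contrasts it with a hypothetical "entangled encoder" that mixes all factors through a dense random matrix; neither is a learned model. Changing a single factor alters one feature of the first representation but every feature of the second.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical ground-truth factors for one image, scaled to [0, 1]:
# floor colour, wall colour, object colour, scale, shape, orientation.
factors = rng.uniform(size=6)

def disentangled_encoder(f):
    # Idealised: each feature is a scaled copy of exactly one factor.
    return 2.0 * f

# Entangled encoder: a dense random mixing matrix spreads every factor over all features.
mixing = rng.normal(size=(6, 6))

def entangled_encoder(f):
    return mixing @ f

changed = factors.copy()
changed[2] = 1.0 - changed[2]  # change only the object colour factor

def n_changed(encoder):
    """Number of representation features that differ after the single-factor change."""
    return int(np.sum(~np.isclose(encoder(factors), encoder(changed))))

print(n_changed(disentangled_encoder))  # 1
print(n_changed(entangled_encoder))     # typically all 6
```

In the entangled case every downstream consumer of the representation, such as a policy, sees a new input after one factor changes; in the disentangled case five of the six features are untouched.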
Why disentanglement in RL?
Generalisation is especially problematic for RL agents that learn from image pixel observations. A change in just one environment variable, such as the background colour, can change many pixels in the image, which can in turn change the entire representation of the image that the agent learned in training. A policy based on this representation may then no longer be optimal, or even able to solve the task, after the environment changes. For example, a change in lighting conditions can change the perceived colour of an object, but we do not want this to affect the agent’s ability to perform the task.
These visual changes can affect agent performance so significantly that re-training from scratch after a change in the environment is often necessary in practice. To avoid re-training from scratch, domain randomisation is a popular approach: the environment's factors of variation are randomly sampled during training to maximise the variation the agent sees. However, it is hard to know in advance which changes the agent will need to generalise to, and therefore which variations to include during training, which complicates the choice of randomisation strategy. Moreover, domain randomisation may not be possible when training RL agents in the real world.
A disentangled representation isolates a change in a factor of variation to a specific portion of the representation vector. For example, if an agent sees a new colour, then only the feature encoding that colour would be expected to change. The rest of the representation retains values that the agent is already familiar with and that the policy can still rely on. This makes learning disentangled representations a promising direction for improving generalisation in image-based RL.
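The pixel-level effect can be checked with a toy example. The sketch below uses a hypothetical renderer (not any real environment) that draws a 16x16 object on a 64x64 background, then counts how many pixels differ when only the background colour changes.

```python
import numpy as np

def render(background, obj_colour, size=64):
    """Hypothetical toy renderer: a solid RGB background with a 16x16 square object."""
    img = np.empty((size, size, 3))
    img[:] = background
    img[24:40, 24:40] = obj_colour
    return img

a = render(background=(0.2, 0.2, 0.2), obj_colour=(1.0, 0.0, 0.0))
b = render(background=(0.2, 0.6, 0.2), obj_colour=(1.0, 0.0, 0.0))  # only the background changes

# A pixel counts as changed if any of its three channels differs.
frac = np.mean(np.any(a != b, axis=-1))
print(f"{frac:.2%} of pixels changed")  # 93.75% of pixels changed
```

One environment variable changed, yet an encoder operating on raw pixels receives an input in which almost every value is different.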
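Domain randomisation can be sketched as sampling each environment factor afresh at the start of every episode. The factor names, ranges and the commented-out `env.reset` call below are hypothetical placeholders, not the API of any particular simulator.

```python
import random

# Hypothetical factor ranges used during training; names and values are illustrative only.
TRAIN_RANGES = {
    "floor_hue": (0.0, 0.5),
    "object_hue": (0.0, 0.5),
    "light_intensity": (0.5, 1.0),
}

def sample_episode_factors(ranges, rng):
    """Domain randomisation: draw each environment factor uniformly at random."""
    return {name: rng.uniform(low, high) for name, (low, high) in ranges.items()}

rng = random.Random(0)
for episode in range(3):
    factors = sample_episode_factors(TRAIN_RANGES, rng)
    # env.reset(**factors)  # hypothetical environment call, not a real library API
    print(episode, factors)
```

Any test-time value outside these ranges, such as an object hue of 0.8, is never seen during training, which is exactly why the ranges are hard to choose in advance.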
Challenges with disentanglement in RL
The disentanglement literature mainly focuses on unsupervised and semi-supervised learning problems, such as identifying objects in images, as in the 3D Shapes example above. Many early approaches extend the variational autoencoder (VAE) and assume independent and identically distributed (i.i.d.) factors of variation. But the factors of variation in RL images are not i.i.d.: at the very least, they evolve over time as the agent learns and acts in the environment. In real-world scenarios, independence can also be broken by spurious correlations in the images and by causal relationships between objects (e.g. objects colliding, or an agent learning to reach a goal position). We discuss these types of correlation in more detail as we introduce existing approaches in the next section.
Locatello et al. (2019) [2] proved that it is theoretically impossible to disentangle the ground truth factors of variation from i.i.d. data alone without further inductive biases, so many recent approaches use some form of supervision, such as labels or pairs of images. In RL, this supervision is hard to obtain: the data is collected as the agent explores the environment, so it cannot be labelled in advance. However, the non-i.i.d. nature of RL data can be turned from a problem into a solution that overcomes the impossibility result, which is exactly the approach taken by some of the works we describe below.
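A classic illustration of the intuition behind this result (not the proof itself) uses an isotropic Gaussian: rotating two independent standard normal factors yields a representation with exactly the same distribution, yet each rotated feature mixes both factors. From the observed distribution alone, nothing distinguishes the disentangled coordinates from the entangled ones. The sketch assumes NumPy and a 45 degree rotation chosen purely for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
z = rng.standard_normal((100_000, 2))  # two independent ground-truth factors

theta = np.pi / 4
R = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta),  np.cos(theta)]])
z_rot = z @ R.T  # each new feature is a mix of both factors

cov_orig = np.cov(z, rowvar=False)
cov_rot = np.cov(z_rot, rowvar=False)
print(np.round(cov_orig, 2))  # approximately the identity
print(np.round(cov_rot, 2))   # also approximately the identity: same distribution

# Yet rotated feature 0 depends on both original factors.
corr_0 = np.corrcoef(z_rot[:, 0], z[:, 0])[0, 1]  # approximately +0.71
corr_1 = np.corrcoef(z_rot[:, 0], z[:, 1])[0, 1]  # approximately -0.71
print(corr_0, corr_1)
```

An unsupervised objective that only looks at the distribution of i.i.d. samples scores both representations identically, which is why extra structure such as labels, paired images or, in RL, temporal data is needed.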
What has been done so far?
We outline three approaches that have been developed for learning disentangled representations in RL to improve generalisation to previously unseen changes in the environment. To the best of our knowledge, these are the only papers to focus on this topic so far. The last two of the three approaches are our own work.
Update: multi-view disentanglement has since been applied successfully to RL with multiple cameras in the RLC 2024 paper "Multi-view Disentanglement for Reinforcement Learning with Multiple Cameras".
Disentangled Representation Learning Agent (DARLA)
DisentAngled Representation Learning Agent (DARLA) [3] is a multi-stage RL agent that first "learns to see" by learning disentangled representations of images, then "learns to act" by learning a policy based on the disentangled representations.
In the learning-to-see phase of training, DARLA uses a \(\beta\)-VAE [4]. The \(\beta\)-VAE was developed to learn disentangled representations from i.i.d. images in an unsupervised way. It extends the VAE by introducing a hyperparameter \(\beta\) that weights the KL divergence term of the VAE loss, which pushes the approximate posterior towards a factorised (isotropic Gaussian) prior and so acts as an independence constraint on the features. Increasing \(\beta\) encourages independence between features, and thus a disentangled representation, at the expense of image reconstruction quality (the other VAE loss term). We won't go into the details of the \(\beta\)-VAE here, as many other papers and blog posts cover it in depth.
DARLA is trained in multiple stages to meet the i.i.d. requirement for learning disentangled representations with the \(\beta\)-VAE. First, the \(\beta\)-VAE is trained on data collected by a pre-trained policy designed to gather i.i.d. images. Then a standard RL agent, equipped with the pre-trained \(\beta\)-VAE, learns a policy by acting in the environment.
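For reference, here is a minimal NumPy sketch of the \(\beta\)-VAE loss to minimise, \(\text{reconstruction loss} + \beta \, D_{KL}(q(z|x) \,\|\, \mathcal{N}(0, I))\). It assumes a diagonal Gaussian encoder output (`mu`, `logvar`) and uses a squared-error reconstruction term as a stand-in for the decoder's negative log-likelihood. Setting `beta=1` recovers the standard VAE loss; the default `beta=4.0` is an arbitrary illustrative value, not one taken from DARLA.

```python
import numpy as np

def beta_vae_loss(x, x_recon, mu, logvar, beta=4.0):
    """Batch-averaged beta-VAE loss (to minimise), a sketch.

    x, x_recon: (batch, pixels) images and their decoder reconstructions.
    mu, logvar: (batch, latent) parameters of the diagonal Gaussian q(z|x).
    """
    # Squared error stands in for the negative log-likelihood of a Gaussian decoder.
    recon = np.sum((x - x_recon) ** 2, axis=1)
    # Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior.
    kl = -0.5 * np.sum(1.0 + logvar - mu ** 2 - np.exp(logvar), axis=1)
    return np.mean(recon + beta * kl)

x = np.ones((2, 3))
x_recon = np.zeros((2, 3))
print(beta_vae_loss(x, x_recon, np.zeros((2, 2)), np.zeros((2, 2))))  # 3.0: reconstruction only
```

When the posterior equals the prior (`mu=0`, `logvar=0`) the KL term vanishes; any deviation from the factorised prior is penalised in proportion to `beta`, which is the lever that trades reconstruction for independence.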
Temporal Disentanglement (TED)
More recently, we proposed Temporal Disentanglement (TED) [5], which avoids DARLA's need to pre-train an encoder on i.i.d. data. Instead, TED uses the non-i.i.d. data that is naturally available in RL, as observations evolve over time, to learn a disentangled representation. This allows the disentangled representation to be learned online, at the same time as the RL policy.
TED is a self-supervised auxiliary task that can be combined with existing RL algorithms. It encourages the image encoder to disentangle the temporal structure of observations at consecutive timesteps, such that a classifier can discriminate between temporal and non-temporal pairs of representations. The classifier compares each feature of the representation separately. This structure encourages disentanglement because the encoder must learn independent features that can be classified without knowledge of the other features.
The results of two experiments from the TED paper are shown below, with RAD [6] and SVEA [7] as the base RL algorithms to which the TED auxiliary task is added. The RL agent is trained on a limited set of colours up to the vertical dotted line on the x-axis, and is then tested on an unseen set of colours, with learning continuing to assess adaptation.
The results show that TED (blue) adapts faster to unseen environment changes than the base RL algorithm (orange) and the other baselines. In the cartpole experiment (left), TED also outperforms the privileged domain randomisation baseline (green), which assumes the test set of colours is known in advance and uses it during training. The code for TED, along with instructions, is available on GitHub.
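The pair construction and per-feature scoring can be sketched as follows. This is a simplified, hypothetical version for illustration only: the negative-pair sampling and the classifier form (a weighted sum of per-feature squared differences) are our assumptions, not the exact classifier from the TED paper [5]. In training, the binary cross-entropy loss would be backpropagated into both the classifier parameters and the image encoder that produced `z`.

```python
import numpy as np

def make_pairs(z, rng):
    """Build temporal (label 1) and non-temporal (label 0) pairs from one trajectory.

    z: (T, d) representations of consecutive observations.
    Negatives pair z_t with the successor of a randomly permuted timestep; a permutation
    can occasionally map t to itself, giving a mislabelled pair, which this sketch ignores.
    """
    pos = np.stack([z[:-1], z[1:]], axis=1)       # (T-1, 2, d): (z_t, z_{t+1})
    perm = rng.permutation(len(z) - 1)
    neg = np.stack([z[:-1], z[1:][perm]], axis=1)  # (T-1, 2, d): (z_t, z_{s+1})
    pairs = np.concatenate([pos, neg])
    labels = np.concatenate([np.ones(len(pos)), np.zeros(len(neg))])
    return pairs, labels

def per_feature_logits(pairs, w, b):
    """Hypothetical classifier: feature n contributes w[n] * (z_t[n] - z_s[n])**2 on its own,
    so no term mixes information from different features."""
    sq_diff = (pairs[:, 0] - pairs[:, 1]) ** 2  # (N, d)
    return sq_diff @ w + b                       # (N,)

def bce_with_logits(logits, labels):
    """Numerically stable binary cross-entropy, averaged over pairs."""
    return np.mean(np.maximum(logits, 0.0) - logits * labels
                   + np.log1p(np.exp(-np.abs(logits))))

rng = np.random.default_rng(0)
T, d = 50, 4
z = np.cumsum(0.1 * rng.standard_normal((T, d)), axis=0)  # smooth toy trajectory
pairs, labels = make_pairs(z, rng)
w, b = -np.ones(d), 0.5  # negative weights: a small step-to-step change reads as "temporal"
loss = bce_with_logits(per_feature_logits(pairs, w, b), labels)
print(loss)
```

Because each feature is scored independently and the scores are only summed at the end, the encoder is rewarded for features whose temporal behaviour is informative on its own.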
Conditional Mutual Information for Disentanglement (CMID)
Most disentanglement techniques assume independent factors of variation in order to learn independent features in the representation. They usually enforce independence by minimising, in some form, the mutual information between features of the representation. DARLA and TED above also make this assumption. However, real-world environments can contain spurious correlations between features that are unknown or unintended. Agents can then encode these misleading correlations into their latent representation, even when using existing disentanglement techniques, because correlated factors cannot be separated into distinct independent features in the representation. This prevents an agent from generalising if the correlation changes. For example, an autonomous driving agent trained in an environment where aggressive drivers often have green cars may encode this correlation, even though it does not hold in the real world.
We proposed Conditional Mutual Information for Disentanglement (CMID) [8], a new approach that relaxes the assumption of independence between factors of variation to conditional independence, in order to adjust for spurious correlations between features. Based on the causal graph of an MDP, we define a suitable conditioning set, the history of representations and actions, that renders the features in the representation conditionally independent. CMID uses this conditioning set to define an auxiliary task that minimises the conditional mutual information between features in the representation given the conditioning set. Intuitively, this means that once the conditioning set is known, the other features in the representation contain no further information about the \(n^{th}\) feature.
The CMID experiments introduce spurious correlations between colour and the morphology of the control object in the DeepMind Control Suite. An example is pictured below.
The two control objects (cartpole A and B in this example) require different optimal policies, but the control object is strongly correlated with colour, where \(\rho\) in the diagram indicates the probability of being in a given scenario during training and testing. An agent that encodes the spurious correlation in its representation, relying on colour information to select its actions, will fail when the correlation no longer holds. The results of a set of experiments from the CMID paper are shown below. The agent is trained under the training correlation; at the vertical dotted line on the x-axis the correlation is reversed, and learning continues to assess adaptation.
The results show that CMID improves generalisation under correlation shift compared to the base RL algorithm (SVEA in this case). CMID experiences little or no drop in performance when the correlation changes, and it even improves on SVEA's performance in the correlated training environment. The code for CMID, along with instructions, is available on GitHub.
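The property CMID relies on, that features can be dependent yet conditionally independent, can be checked exactly on a small discrete example. The sketch below (NumPy, with a hypothetical binary context `Z` that drives both `X` and `Y`) computes \(I(X;Y)\) and \(I(X;Y|Z)\) from a joint probability table. It only illustrates the quantities involved; CMID works with continuous learned features, and its training objective is not reproduced here.

```python
import numpy as np

def mutual_information(pxy):
    """I(X;Y) in nats from a joint probability table pxy[x, y]."""
    px = pxy.sum(axis=1, keepdims=True)
    py = pxy.sum(axis=0, keepdims=True)
    mask = pxy > 0
    return float(np.sum(pxy[mask] * np.log(pxy[mask] / (px @ py)[mask])))

def conditional_mutual_information(pxyz):
    """I(X;Y|Z) = sum_z p(z) * I(X;Y | Z=z) from a joint table pxyz[x, y, z]."""
    total = 0.0
    for z in range(pxyz.shape[2]):
        pz = pxyz[:, :, z].sum()
        if pz > 0:
            total += pz * mutual_information(pxyz[:, :, z] / pz)
    return total

# Toy confounded example (an assumption for illustration): a binary context Z drives
# both X and Y, each copying Z with probability 0.9, independently given Z.
p_copy = 0.9
pxyz = np.zeros((2, 2, 2))
for x in range(2):
    for y in range(2):
        for z in range(2):
            px_given_z = p_copy if x == z else 1 - p_copy
            py_given_z = p_copy if y == z else 1 - p_copy
            pxyz[x, y, z] = 0.5 * px_given_z * py_given_z

print(mutual_information(pxyz.sum(axis=2)))  # about 0.22 nats: X and Y look dependent
print(conditional_mutual_information(pxyz))  # essentially 0: independent once Z is known
```

An independence penalty on `X` and `Y` would fight a correlation that is fully explained by `Z`, whereas the conditional penalty is already satisfied; this is the motivation for conditioning on a suitable set rather than demanding plain independence.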
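A correlation shift like this can be sketched as a scenario sampler. Here `rho` is assumed to be the probability that the training pairing of morphology and colour holds (the diagram may parameterise it differently), and reversing the correlation at test time corresponds to sampling with `1 - rho`. The colours and the value of `rho_train` are illustrative placeholders.

```python
import random

# Illustrative pairing seen most often in training; the colours are placeholders.
USUAL_COLOUR = {"A": "red", "B": "blue"}
OTHER_COLOUR = {"A": "blue", "B": "red"}

def sample_scenario(rho, rng):
    """Return (morphology, colour); the usual pairing holds with probability rho."""
    morphology = rng.choice(["A", "B"])
    if rng.random() < rho:
        colour = USUAL_COLOUR[morphology]
    else:
        colour = OTHER_COLOUR[morphology]
    return morphology, colour

def usual_fraction(samples):
    """Fraction of samples in which the training pairing holds."""
    return sum(colour == USUAL_COLOUR[m] for m, colour in samples) / len(samples)

rng = random.Random(0)
rho_train = 0.95  # hypothetical value
train = [sample_scenario(rho_train, rng) for _ in range(10_000)]
test = [sample_scenario(1 - rho_train, rng) for _ in range(10_000)]
print(usual_fraction(train), usual_fraction(test))  # roughly 0.95 and 0.05
```

A policy that infers the morphology from colour is right about 95% of the time in this training distribution but only about 5% of the time after the shift, which is the failure mode described above.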
Open problems
Disentanglement is a hot topic in unsupervised learning, but progress in applying this work to RL has only recently begun, so many open problems remain. Some that might provide interesting directions for future research are outlined below.
- Both the TED and CMID work show that disentanglement in RL should account for correlations (across time in TED, or spurious correlations in CMID). However, there is no commonly used definition of disentanglement with correlated factors of variation, and there is currently no metric to measure disentanglement on correlated data.
- Factors of variation in RL environments can have causal relationships between them, such as objects colliding. While CMID considers spurious correlations (i.e. no causal relationships), disentanglement when causal relationships exist between the factors of variation remains an open problem.
- While some progress has been made towards disentanglement in single-agent RL, multi-agent RL remains an open problem, where views of the same objects from multiple agents could provide useful information for learning disentangled representations.
References
- Chris Burgess and Hyunjik Kim (2018). "3D shapes dataset". https://github.com/deepmind/3dshapes-dataset/.
- Francesco Locatello, Stefan Bauer, Mario Lucic, Gunnar Rätsch, Sylvain Gelly, Bernhard Schölkopf, and Olivier Bachem (2019). "Challenging Common Assumptions in the Unsupervised Learning of Disentangled Representations". In Proceedings of the 36th International Conference on Machine Learning (ICML 2019).
- Irina Higgins, Arka Pal, Andrei Rusu, Loïc Matthey, Christopher Burgess, Alexander Pritzel, Matthew Botvinick, Charles Blundell, and Alexander Lerchner (2017). "DARLA: Improving Zero-shot Transfer in Reinforcement Learning". In Proceedings of the 34th International Conference on Machine Learning (ICML 2017).
- Irina Higgins, Loïc Matthey, Arka Pal, Christopher P. Burgess, Xavier Glorot, Matthew M. Botvinick, Shakir Mohamed, and Alexander Lerchner (2017). "β-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework". In Proceedings of the International Conference on Learning Representations (ICLR 2017).
- Mhairi Dunion, Trevor McInroe, Kevin Sebastian Luck, Josiah P. Hanna, and Stefano V. Albrecht (2023). "Temporal Disentanglement of Representations for Improved Generalisation in Reinforcement Learning". In Proceedings of the International Conference on Learning Representations (ICLR 2023).
- Michael Laskin, Kimin Lee, Adam Stooke, Lerrel Pinto, Pieter Abbeel, and Aravind Srinivas (2020). "Reinforcement Learning with Augmented Data". In Proceedings of the 34th Conference on Neural Information Processing Systems (NeurIPS 2020).
- Nicklas Hansen, Hao Su, and Xiaolong Wang (2021). "Stabilizing Deep Q-Learning with ConvNets and Vision Transformers under Data Augmentation". In Proceedings of the 35th Conference on Neural Information Processing Systems (NeurIPS 2021).
- Mhairi Dunion, Trevor McInroe, Kevin Sebastian Luck, Josiah P. Hanna, and Stefano V. Albrecht (2023). "Conditional Mutual Information for Disentangled Representations in Reinforcement Learning". In Proceedings of the 37th Conference on Neural Information Processing Systems (NeurIPS 2023).