1 of 21

Unsupervised State Representation Learning in Atari

Ankesh Anand*, Evan Racah*, Sherjil Ozair*,

Yoshua Bengio, Marc-Alexandre Côté, R Devon Hjelm

2 of 21

Key Points

  • A state representation learning algorithm to learn high-level concepts in a scene
    • without labels or rewards
    • without modelling pixels directly.
  • The Atari Annotated RAM Interface (AtariARI): A benchmark to systematically evaluate state representations.

3 of 21

State Representation Learning

Goal: Encode high-dimensional observations into a latent space that captures the underlying generative factors of the environment

  • Allow agents to learn to act in environments with fewer interactions.

  • Effectively transfer knowledge across different tasks in the environment

4 of 21

Supervised -> Self-supervised / Unsupervised Learning

  • Can’t rely on direct supervision given the high-dimensional nature of these problems.
    • The marginal cost of acquiring labels in RL is much higher.

  • The underlying data has a much richer structure than what sparse labels / rewards could provide.
    • Sparse signals -> sample inefficiency.

  • Leads to task-specific policies, rather than knowledge that could be repurposed.

  • Human learning is largely unsupervised.
    • The Scientist in the Crib: What Early Learning Tells Us About the Mind (Alison Gopnik, Andrew N. Meltzoff and Patricia K. Kuhl, 1999)
    • The Development of Embodied Cognition: Six Lessons from Babies (Linda Smith and Michael Gasser, 2005)

5 of 21

Illustrative Example

Representation learning in humans does not seem to operate in pixel space.

[Figure: a one-dollar bill drawn from memory vs. drawn from reference (Epstein et al., 2016)]

6 of 21

Contrastive Unsupervised Representation Learning

  • Goal: Learn an embedding function f such that:

    • x and x⁺ are similar data points [Positive].
    • x⁻ is a random data point (and thus presumably dissimilar to x) [Negative].

  • If we use multiple negative samples, we get a lower bound on Mutual Information (MI).

  • To maximize MI, we can compute gradients of this lower bound on MI w.r.t. the parametric encoder f (see the loss sketch below).

Arora et al. (CURL) ICML ’19
Poole et al. (MI Bounds) ICML ’19
Hjelm et al. (Deep InfoMax) ICLR ’19
van den Oord et al. (CPC) 2018
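A minimal sketch of an InfoNCE-style loss with in-batch negatives, matching the multi-negative setup above; the function name, dot-product score, and temperature value are illustrative assumptions rather than the exact objective from any of the cited papers.

```python
import torch
import torch.nn.functional as F

def infonce_loss(z_anchor, z_pos, temperature=0.1):
    """InfoNCE-style contrastive loss with in-batch negatives.

    z_anchor, z_pos: (B, D) embeddings of paired (similar) observations.
    Every non-matching row of z_pos serves as a negative for each anchor.
    """
    logits = z_anchor @ z_pos.t() / temperature            # (B, B) similarity scores
    targets = torch.arange(z_anchor.size(0), device=z_anchor.device)
    # Diagonal entries are the positive pairs; cross-entropy pushes them up
    # relative to the negatives, which lower-bounds the mutual information.
    return F.cross_entropy(logits, targets)
```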

7 of 21

Nature of RL environments

  • Temporal Structure in Data.
    • Not i.i.d. sampled.

  • Local consistency. Objects don’t move drastically over single time-steps.

  • Prior work has argued that without auxiliary variables (such as time), recovering underlying latent variables is generally not possible.
    • Hyvarinen et al. (2014)
    • Locatello et al. (2019)

Can we exploit the inherent temporal structure to learn representations?

8 of 21

The contrastive task

  • The negative sample is drawn at random from the episode.

  • In practice, we use multiple negative samples.

  • Standard CNN architecture from Mnih et al. (2015) [DQN] (see the encoder sketch below).

Temporal InfoMax
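A hedged sketch of the pieces referenced on this slide: a DQN-style convolutional encoder and a simple routine for drawing a temporal positive pair plus a random in-episode negative. The 84x84 stacked-frame input, embedding dimension, and function names are assumptions for illustration, not the paper's exact pipeline.

```python
import random
import torch.nn as nn

class DQNEncoder(nn.Module):
    """DQN-style convnet mapping stacked frames to a global embedding.

    Layer sizes follow the standard DQN architecture; the 84x84 input and
    256-dim output are assumptions for this sketch.
    """
    def __init__(self, in_channels=4, feature_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, feature_dim),   # 7x7 spatial map for 84x84 input
        )

    def forward(self, x):
        return self.net(x)

def sample_temporal_pair(episode):
    """Positive pair: consecutive observations from an episode.
    Negative: a frame drawn at random from the same episode
    (single-negative version; the real setup batches many negatives)."""
    t = random.randrange(len(episode) - 1)
    x_neg = episode[random.randrange(len(episode))]
    return episode[t], episode[t + 1], x_neg
```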

9 of 21

Temporal InfoMax is not enough

  • The encoder can “cheat” and focus on just one factor of variation that’s easy to predict (like the clock).

  • We incorporate a spatial prior to incentivize the encoder to focus on all factors of variation.

Ozair et al. (2019)

10 of 21

Spatio-Temporal DeepInfoMax (ST-DIM)

  • Maximize the temporal MI spatially across each local feature map.

  • Each feature map has a receptive field corresponding to the size of the full observation (a sketch of the spatio-temporal objective follows below).
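A rough sketch of the local spatio-temporal contrastive term. The full ST-DIM objective combines global-local and local-local terms with a learned (bilinear) score function; here a plain dot product and in-batch negatives are used purely for illustration.

```python
import torch
import torch.nn.functional as F

def local_temporal_nce(f_t, f_tp1, temperature=1.0):
    """Contrast local feature vectors at the same spatial location across
    consecutive time steps, using in-batch negatives.

    f_t, f_tp1: (B, C, H, W) intermediate conv feature maps at times t and t+1.
    """
    B, C, H, W = f_t.shape
    loss = 0.0
    for i in range(H):
        for j in range(W):
            a = f_t[:, :, i, j]                       # (B, C) local features at t
            p = f_tp1[:, :, i, j]                     # same location at t+1
            logits = a @ p.t() / temperature          # (B, B) scores
            targets = torch.arange(B, device=a.device)
            loss = loss + F.cross_entropy(logits, targets)
    return loss / (H * W)
```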

11 of 21

Evaluating Representations

  • Evaluating representations is hard.

  • Performance on a single downstream task (e.g. control with a single reward function, next-frame prediction)
    • Might not measure all useful things a representation should capture

  • More principled approach: measure the ability of a representation to capture all high-level factors.

12 of 21

Atari Annotated RAM Interface (AtariARI)

  • Interface for evaluating state representations in Atari games
  • A Gym wrapper exposes 308 semantic labels in total, across 22 Atari games (usage sketch below)
  • Evaluation using linear probing.
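A usage sketch, assuming the publicly released atariari package and the classic Gym step API; the game and the random action are arbitrary choices.

```python
import gym
from atariari.benchmark.wrapper import AtariARIWrapper

# Wrap a standard Gym Atari environment; the wrapper adds ground-truth
# state-variable labels (decoded from RAM) to the info dict at each step.
env = AtariARIWrapper(gym.make("PongNoFrameskip-v4"))
obs = env.reset()
obs, reward, done, info = env.step(env.action_space.sample())
print(info["labels"])   # e.g. agent/ball coordinates, score, lives, ...
```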

13 of 21

Categorization of State Variables

  • Agent Localization
  • Small Object Localization
  • Other Localization
  • Score/Clock/Lives/Display
  • Miscellaneous

Examples: facing direction, brick existence (binary)

14 of 21

State Variable Breakdown

  • 22 Total Games

  • 308 Total State Variables

15 of 21

Evaluation Using Probing

  • We focus on measuring “explicitness”
    • the extent to which the true state can be recovered from the learned representation using a linear transformation.

Alain et al. (2017)

  1. Freeze the encoder’s weights.
  2. Train a linear classifier on top of each representation (see the probe sketch below).
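A minimal probing sketch under these assumptions: a pretrained encoder that is kept frozen, a 256-dimensional embedding, and a DataLoader yielding (observation, discretized state-variable) pairs. Names and hyperparameters are illustrative.

```python
import torch
import torch.nn as nn

def train_probe(encoder, loader, num_values, feature_dim=256, epochs=10):
    """Fit one linear classifier on top of a frozen encoder for a single
    state variable taking `num_values` discrete values."""
    encoder.eval()
    for p in encoder.parameters():
        p.requires_grad_(False)            # freeze: no encoder updates
    probe = nn.Linear(feature_dim, num_values)   # single linear layer, no hidden units
    opt = torch.optim.Adam(probe.parameters(), lr=3e-4)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(epochs):
        for obs, label in loader:          # label: discretized RAM state variable
            with torch.no_grad():
                z = encoder(obs)           # representation, gradients stop here
            loss = loss_fn(probe(z), label)
            opt.zero_grad()
            loss.backward()
            opt.step()
    return probe
```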

16 of 21

Training Details

  • Unsupervised training for 100K frames before probing
  • Probe train/val/test splits -> 50K frames
  • Two data collection modes:
    • Random policy
    • Expert PPO policy
  • Prune state variables whose values hardly vary (entropy pruning; see the sketch below)
  • A separate linear classifier for each state variable
  • F1 score to account for label imbalance
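An illustrative sketch of the entropy-pruning step mentioned above: state variables whose empirical label distribution is nearly constant are dropped before probing. The threshold value and function name are assumptions.

```python
import numpy as np

def prune_low_entropy(labels_by_variable, min_entropy_bits=0.6):
    """Keep only state variables whose label distribution has enough entropy.

    labels_by_variable: dict mapping variable name -> array of observed values.
    """
    kept = {}
    for name, values in labels_by_variable.items():
        _, counts = np.unique(values, return_counts=True)
        probs = counts / counts.sum()
        entropy = -(probs * np.log2(probs)).sum()   # empirical entropy in bits
        if entropy >= min_entropy_bits:
            kept[name] = values
    return kept
```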

17 of 21

Baselines

  • Majority Classifier
  • Random-CNN
  • VAE
  • Pixel-Prediction
  • CPC (Contrastive Predictive Coding)
  • Fully Supervised (Upper Bound)

18 of 21

Results

19 of 21

Categorical Breakdown

  • ST-DIM
    • excels at capturing small objects (tough for reconstruction-based methods).
    • is robust to easy-to-exploit features.

20 of 21

Easy-to-Exploit Features

  • Contrastive losses can fail to capture all salient factors

  • Especially if one factor is very easy/predictable
    • E.g. clock in Boxing

  • Good litmus test: can the representation capture more than just the clock?

21 of 21

Future Directions

  • Learning “Abstract” world models.

  • Sample efficient downstream RL models.

  • Do contrastive representations generalize better?

  • What can we get with large scale unsupervised training?