Variational Predictive Routing with Nested Subjective Timescales

Hierarchical generative model for sequential data

*Equal contribution



Abstract

Discovery and learning of an underlying spatiotemporal hierarchy in sequential data is an important topic for machine learning. Despite this, little work has been done to explore hierarchical generative models that can flexibly adapt their layerwise representations in response to datasets with different temporal dynamics. Here, we present Variational Predictive Routing (VPR), a neural probabilistic inference system that organizes latent representations of video features in a temporal hierarchy based on their rates of change, thus modeling continuous data as a hierarchical renewal process. By employing an event detection mechanism that relies solely on the system’s latent representations (without the need for a separate model), VPR is able to dynamically adjust its internal state following changes in the observed features, promoting an optimal organisation of representations across the levels of the model’s latent hierarchy.


Architecture

Building blocks of VPR

VPR consists of computational blocks stacked together in a hierarchy. The figure on the right shows the internals of a single block, with its key variables and channels of information flow. Each block in level \(n\) consists of three deterministic variables \((x^n_t, c^n_t, d^n_t)\), which represent the three channels of communication between blocks: bottom-up (encoding), top-down (decoding), and temporal (transitioning), respectively. These variables parametrise a random variable \(s^n_t\) that contains the learned representations at a given level of the hierarchy. The figure below shows an example of a three-level VPR model unrolled over five timesteps, demonstrating how communication between the blocks is realised over time.
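As a rough illustration of how the three deterministic channels could parametrise \(s^n_t\), the sketch below uses a diagonal Gaussian. It is a toy under stated assumptions, not the authors’ implementation: the class `VPRBlock`, its random per-channel weights (stand-ins for learned networks), and the choice that the prior reads only the top-down and temporal channels while the posterior also reads the bottom-up channel are all hypothetical.

```python
import math
import random
from dataclasses import dataclass
from typing import Dict, List


def softplus(v: float) -> float:
    # Numerically stable softplus keeps standard deviations strictly positive
    return math.log1p(math.exp(-abs(v))) + max(v, 0.0)


@dataclass
class DiagGaussian:
    mean: List[float]
    std: List[float]

    def sample(self, rng: random.Random) -> List[float]:
        # Reparameterised sample: mean + std * eps with eps ~ N(0, 1)
        return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(self.mean, self.std)]


class VPRBlock:
    """Hypothetical block of level n; toy weights stand in for learned networks."""

    def __init__(self, dim: int, seed: int = 0):
        rng = random.Random(seed)
        self.dim = dim
        # One weight vector per channel: bottom-up x, top-down c, temporal d
        self.w = {k: [rng.uniform(-1.0, 1.0) for _ in range(dim)] for k in ("x", "c", "d")}

    def _params(self, channels: Dict[str, List[float]]) -> DiagGaussian:
        mean, std = [], []
        for i in range(self.dim):
            total = sum(self.w[k][i] * v[i] for k, v in channels.items())
            mean.append(total)
            std.append(softplus(total - 1.0))
        return DiagGaussian(mean, std)

    def prior(self, c: List[float], d: List[float]) -> DiagGaussian:
        # Assumed: prior over s^n_t uses only the top-down and temporal channels
        return self._params({"c": c, "d": d})

    def posterior(self, x: List[float], c: List[float], d: List[float]) -> DiagGaussian:
        # Assumed: posterior additionally conditions on the bottom-up encoding x
        return self._params({"x": x, "c": c, "d": d})


block = VPRBlock(dim=4, seed=1)
q = block.posterior([0.5] * 4, [0.1] * 4, [-0.2] * 4)
s = q.sample(random.Random(0))  # one draw of s^n_t
```

A diagonal Gaussian with a softplus-constrained standard deviation is a common choice for such latent variables; the real parametrisation in VPR may differ.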



Example of a three-level VPR model unrolled over five timesteps.


Event boundary detection: enforcing temporal hierarchy

You may have noticed that the unrolled model updates the states of different levels at different times and at different rates. How does the model decide when to update?

At the heart of VPR is a mechanism for detecting predictable and unpredictable changes in the observable features over time, and thus determining the boundaries of events. At every level of the model, event detection controls the structure of the unrolled model over time by allowing or blocking the propagation of bottom-up information to the higher levels of the hierarchy.

The event detection system serves two primary functions in our model:

  1. First, detecting layerwise events in a sequence of observations is used to trigger an update of a block’s state by inferring its new posterior state. Matching block updates with detectable changes in layerwise features (event boundaries) prompts VPR to represent spatiotemporal features in the levels that most closely mimic their rate of change over time. Similarly, learning to transition between states only when they signify a change in the features of the data means that VPR learns to make time-agnostic (or jumpy) transitions from one event boundary to another.
  2. Second, the detection mechanism is used for blocking bottom-up communication and thus stopping the propagation of new information to the deeper levels of the hierarchy. This is meant to encourage the model to better organise its spatiotemporal representations by enforcing a temporal hierarchy onto its generative process.
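The two functions above can be sketched as a gated bottom-up pass over the levels at a single timestep. This is a minimal sketch, not VPR’s exact criterion, which this page does not state: it assumes (hypothetically) that an event is flagged when the KL divergence between the posterior a level would infer from the new bottom-up signal and its current state exceeds a fixed `threshold`. The names `route_bottom_up`, `infer` and `toy_infer` are illustrative.

```python
import math
from typing import Callable, List, Tuple

# A diagonal Gaussian represented as (mean, std) lists
Gaussian = Tuple[List[float], List[float]]


def kl_diag_gauss(q: Gaussian, p: Gaussian) -> float:
    """KL(q || p) between two diagonal Gaussians, summed over dimensions."""
    total = 0.0
    for mq, sq, mp, sp in zip(q[0], q[1], p[0], p[1]):
        total += math.log(sp / sq) + (sq ** 2 + (mq - mp) ** 2) / (2 * sp ** 2) - 0.5
    return total


def route_bottom_up(states: List[Gaussian], observation: List[float],
                    infer: Callable[[int, List[float]], Gaussian],
                    threshold: float) -> int:
    """One timestep of gated bottom-up routing; returns how many levels were updated.

    Hypothetical event criterion: KL(candidate posterior || current state) > threshold.
    """
    signal = observation
    for n in range(len(states)):
        candidate = infer(n, signal)  # posterior level n would adopt given this signal
        if kl_diag_gauss(candidate, states[n]) <= threshold:
            return n  # no event: keep level n's state and block propagation upward
        states[n] = candidate  # event boundary: update the block's state
        signal = candidate[0]  # the new state's mean becomes the next bottom-up signal
    return len(states)


# Toy usage: level 0 copies the signal, level 1 keeps only its rounded value
def toy_infer(n: int, sig: List[float]) -> Gaussian:
    mean = list(sig) if n == 0 else [float(round(v)) for v in sig]
    return mean, [0.1] * len(sig)


states = [([0.0], [0.1]), ([0.0], [0.1])]
print(route_bottom_up(states, [0.3], toy_infer, threshold=1.0))  # → 1
print(route_bottom_up(states, [1.4], toy_infer, threshold=1.0))  # → 2
```

In this toy, a small change in the input triggers an event only at level 0, while a change in the rounded (slower) feature also reaches and updates level 1, mirroring how the bottom-up signal stops at the first level that detects no event.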


A short demonstration of how VPR works on the Moving Ball dataset. Notice how the bottom-up signal propagates upward only when the ball’s colour changes (i.e. when an event boundary has been detected).


Experiments and results

Datasets: 3DSD and Moving Ball. 3D Shapes Dynamic (3DSD) includes three colour features that change with different periodicities. Moving Ball contains a moving ball that changes colour when it bounces off a wall or at a random time.
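To make the Moving Ball dynamics concrete, here is a toy one-dimensional analogue. It is hypothetical (the real dataset consists of rendered video frames), and `moving_ball_1d`, `width`, `n_colours` and `p_random_change` are illustrative names and parameters, not values from the paper. With random changes disabled, colour changes coincide exactly with wall bounces, which are the event boundaries a colour-representing level should detect.

```python
import random
from typing import List, Tuple


def moving_ball_1d(steps: int, width: int = 10, n_colours: int = 3,
                   p_random_change: float = 0.05, seed: int = 0) -> List[Tuple[int, int]]:
    """Toy 1-D Moving Ball: one (position, colour) pair per timestep.

    Requires width >= 2 and n_colours >= 2.
    """
    rng = random.Random(seed)
    pos, vel, colour = 0, 1, 0
    frames = []
    for _ in range(steps):
        frames.append((pos, colour))
        nxt = pos + vel
        bounced = nxt < 0 or nxt >= width
        if bounced:
            vel = -vel  # reflect off the wall
            nxt = pos + vel
        pos = nxt
        if bounced or rng.random() < p_random_change:
            # Colour changes on a wall bounce or at a random time, always to a different colour
            colour = (colour + rng.randrange(1, n_colours)) % n_colours
    return frames


frames = moving_ball_1d(50, p_random_change=0.0)
change_times = [t for t in range(1, len(frames)) if frames[t][1] != frames[t - 1][1]]
print(change_times)  # → [10, 19, 28, 37, 46]
```

Position changes at every step while colour changes only at sparse event boundaries, which is the kind of rate difference VPR is meant to map onto separate levels.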




Layerwise rollouts using VPR. GT denotes the ground-truth sequence, L1 denotes rollouts made using level 1, and so on. To produce layerwise rollouts, the model predicts the next state \(s^n_{\tau+1}\) in the relevant level \(n\) and decodes it while the states in all other levels are held fixed. The rollouts illustrate the model’s ability to learn disentangled representations and to produce accurate, feature-specific jumpy rollouts. For Moving Ball, VPR learns to represent the ball’s position and colour in two separate levels (L1 and L2, respectively). Similar behaviour can be observed using a three-level VPR on the 3DSD dataset: VPR learns to represent the three temporal factors of variation across the three levels of its latent hierarchy and produces jumpy rollouts that correctly predict the changes in the corresponding features.
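The layerwise rollout procedure described above can be sketched as follows, assuming deterministic per-level transition functions (for example, taking the mean of each level's predicted next state) and a decoder that reads the states of all levels. `layerwise_rollout`, `transitions` and `decode` are hypothetical names, and the toy transitions are stand-ins for VPR's learned networks.

```python
from typing import Callable, List, Sequence

State = List[float]


def layerwise_rollout(states: Sequence[State], level: int, steps: int,
                      transitions: Sequence[Callable[[State], State]],
                      decode: Callable[[List[State]], object]) -> List[object]:
    """Advance only `level` for `steps` jumpy transitions; all other levels stay fixed."""
    current = [list(s) for s in states]  # copy so the caller's states are not modified
    frames = []
    for _ in range(steps):
        # One step = one transition to the next event boundary at this level
        current[level] = transitions[level](current[level])
        frames.append(decode(current))
    return frames


# Toy usage in the spirit of Moving Ball: level 0 ~ position, level 1 ~ colour index
transitions = [lambda s: [s[0] + 1.0], lambda s: [(s[0] + 1.0) % 3]]


def decode(ss: List[State]) -> tuple:
    return (ss[0][0], int(ss[1][0]))  # toy "frame" = (position, colour)


print(layerwise_rollout([[0.0], [0.0]], 0, 3, transitions, decode))  # → [(1.0, 0), (2.0, 0), (3.0, 0)]
print(layerwise_rollout([[0.0], [0.0]], 1, 3, transitions, decode))  # → [(0.0, 1), (0.0, 2), (0.0, 0)]
```

Rolling out level 0 changes only the position, and rolling out level 1 changes only the colour, which is the disentangled, feature-specific behaviour the rollouts are meant to reveal.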



Random samples taken from the different levels of VPR. The model is able to generate diverse images with respect to the spatiotemporal features represented in the sampled level, while keeping all other features fixed.


Cite us!

@inproceedings{VPR2022,
  title={Variational Predictive Routing with Nested Subjective Timescales},
  author={Alexey Zakharov and Qinghai Guo and Zafeirios Fountas},
  booktitle={International Conference on Learning Representations},
  year={2022},
  url={https://openreview.net/forum?id=JxFgJbZ-wft}
}