Wasserstein RBM - demo code

The Wasserstein RBM demo code is provided (1) to reproduce the main results of the NIPS 2016 paper:

G Montavon, KR Müller, M Cuturi. Wasserstein Training of Restricted Boltzmann Machines
Advances in Neural Information Processing Systems (NIPS), 2016

and (2) to serve as a basic implementation for further research and investigation.

This RBM implementation is written in Python and requires the numpy, scipy and matplotlib libraries. Because it operates in batch mode, training is slower than with standard RBM implementations that minimize the usual KL divergence. For large-scale applications, other implementations should therefore be considered.

The Wasserstein RBM minimizes the smooth Wasserstein distance between the empirical data distribution and the distribution of data generated by the RBM. Training runs a number of gradient descent iterations, and the gradient of the Wasserstein distance with respect to the RBM parameters is recomputed at each iteration.
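The smoothed Wasserstein distance can be illustrated with a minimal sketch. It assumes the entropy-regularized formulation of optimal transport solved by Sinkhorn iterations, the Hamming distance as ground cost between binary vectors, and uniform weights on both sample sets. The names `hamming_cost`, `smoothed_wasserstein` and `gamma` (standing in for the smoothing parameter) are hypothetical and are not taken from the demo code.

```python
import numpy as np
from scipy.special import logsumexp


def hamming_cost(X, Y):
    """Pairwise Hamming distances between binary rows of X (n, d) and Y (m, d)."""
    X, Y = np.asarray(X, float), np.asarray(Y, float)
    return X @ (1 - Y).T + (1 - X) @ Y.T


def smoothed_wasserstein(X, Y, gamma=1.0, n_iter=200):
    """Entropy-smoothed optimal transport between the uniform empirical
    distributions on the rows of X and Y (log-domain Sinkhorn iterations).

    Returns:
      cost  -- transport cost <P, C> of the smoothed optimal plan P
      alpha -- dual potential on the rows of X, centered to mean zero; up to an
               additive constant it is the gradient of the smoothed objective
               <P, C> - gamma * H(P) with respect to the weights of X
      P     -- the (n, m) smoothed transport plan
    """
    C = hamming_cost(X, Y)
    n, m = C.shape
    f, g = np.zeros(n), np.zeros(m)
    for _ in range(n_iter):
        # Alternately update the duals so rows sum to 1/n and columns to 1/m
        f = gamma * (-np.log(n) - logsumexp((g[None, :] - C) / gamma, axis=1))
        g = gamma * (-np.log(m) - logsumexp((f[:, None] - C) / gamma, axis=0))
    P = np.exp((f[:, None] + g[None, :] - C) / gamma)
    return float(np.sum(P * C)), f - f.mean(), P


# Usage: compare random "model" samples with random "data" (toy arrays)
rng = np.random.default_rng(0)
data = (rng.random((50, 16)) < 0.3).astype(float)
model = (rng.random((40, 16)) < 0.5).astype(float)
cost, alpha, P = smoothed_wasserstein(model, data, gamma=1.0)
```

Smaller `gamma` brings the result closer to the unsmoothed Wasserstein distance but slows the convergence of the Sinkhorn iterations; working in the log domain avoids the underflow that computing `exp(-C / gamma)` directly would cause for small `gamma`.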

The demo code is applied to downscaled MNIST digits and generates digits that are less noisy, but also less diverse, than those learned by a standard RBM. Training the model takes a few hours.

Figure panels: Data, Standard RBM, Wasserstein RBM

The demo code consists of the following four files:

and can be downloaded as a zip archive:

The code can be easily adapted to learn from any distribution of binary vectors, although some parameter tuning might be needed, e.g. for the learning rate, regularization strengths, and the Wasserstein smoothing parameter.
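Adapting the method to other binary data mainly means choosing the hyperparameters listed above. The sketch below shows one possible training step in which they appear as arguments: `lr` (learning rate), `weight_decay` (an L2 penalty on the weights, assumed here only as an example of regularization) and `gamma` (smoothing). The gradient used is a derivation made for illustration, not necessarily the estimator of the demo code: assuming p(x) ∝ exp(-F(x)) with the RBM free energy F, the chain rule gives grad W = -E_p[(alpha(x) - E_p[alpha]) grad F(x)], where alpha is the Sinkhorn dual potential, and the expectation is approximated with Gibbs samples from the model. All function names and the toy data are hypothetical.

```python
import numpy as np
from scipy.special import expit, logsumexp


def hamming_cost(X, Y):
    # Pairwise Hamming distances between binary rows of X (n, d) and Y (m, d)
    return X @ (1 - Y).T + (1 - X) @ Y.T


def sinkhorn_potential(X, Y, gamma, n_iter=100):
    # Centered dual potential on X of entropy-smoothed OT (uniform weights, log domain)
    C = hamming_cost(X, Y)
    n, m = C.shape
    f, g = np.zeros(n), np.zeros(m)
    for _ in range(n_iter):
        f = gamma * (-np.log(n) - logsumexp((g[None, :] - C) / gamma, axis=1))
        g = gamma * (-np.log(m) - logsumexp((f[:, None] - C) / gamma, axis=0))
    return f - f.mean()


def gibbs_sample(W, b, c, n, n_steps, rng):
    # Block Gibbs sampling of the visible units of a binary RBM from random init
    X = (rng.random((n, W.shape[0])) < 0.5).astype(float)
    for _ in range(n_steps):
        H = (rng.random((n, W.shape[1])) < expit(X @ W + c)).astype(float)
        X = (rng.random((n, W.shape[0])) < expit(H @ W.T + b)).astype(float)
    return X


def wasserstein_step(W, b, c, data, rng, lr=0.01, gamma=1.0, weight_decay=1e-4,
                     n_samples=100, n_gibbs=50):
    # One gradient descent step on the smoothed Wasserstein distance (sketch)
    X = gibbs_sample(W, b, c, n_samples, n_gibbs, rng)
    alpha = sinkhorn_potential(X, data, gamma)  # one value per model sample
    S = expit(X @ W + c)                        # hidden activation probabilities
    n = len(X)
    # Free energy gradients: dF/dW = -x s^T, dF/db = -x, dF/dc = -s,
    # so -E[alpha * dF] becomes the averages below
    gW = X.T @ (alpha[:, None] * S) / n + weight_decay * W
    gb = X.T @ alpha / n
    gc = S.T @ alpha / n
    return W - lr * gW, b - lr * gb, c - lr * gc


rng = np.random.default_rng(0)
# Toy binary data: noisy copies of two 16-bit prototypes
protos = np.array([[1] * 8 + [0] * 8, [0] * 8 + [1] * 8], dtype=float)
data = protos[rng.integers(0, 2, size=200)]
data = np.logical_xor(data, rng.random(data.shape) < 0.05).astype(float)

W = 0.01 * rng.standard_normal((16, 8))
b, c = np.zeros(16), np.zeros(8)
for _ in range(20):
    W, b, c = wasserstein_step(W, b, c, data, rng, lr=0.05, gamma=1.0)
```

In this sketch, the dual potential couples every model sample to every data point, so each step needs a full set of model samples and the full data set rather than a single minibatch, which is one reason such training runs in batch mode.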