The Wasserstein RBM demo code is provided (1) to reproduce the main results of the NIPS 2016 paper:
G Montavon, KR Müller, M Cuturi. Wasserstein Training of Restricted Boltzmann Machines
Advances in Neural Information Processing Systems (NIPS), 2016
and (2) to serve as a basic implementation for further research and investigation.
This RBM implementation is written in Python and requires the matplotlib libraries. Because it operates in batch mode, training is slower than with standard RBM implementations that minimize the usual KL divergence. For large-scale applications, other implementations should therefore be considered.
The Wasserstein RBM minimizes the smooth Wasserstein distance between the empirical data distribution and the distribution of samples generated by the RBM. Training runs a number of gradient descent iterations, and the gradient of the Wasserstein distance is recomputed at each iteration.
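As a rough illustration of the kind of quantity being minimized, the sketch below computes an entropy-smoothed optimal transport cost between two sets of binary vectors using Sinkhorn iterations. It assumes uniform weights on the samples, a Hamming ground cost, and a smoothing parameter `gamma`; the names `hamming_cost` and `smooth_wasserstein` are hypothetical and are not taken from modules.py. It returns only the transport term <P, C> of the regularized plan, so it may differ from the exact objective used in the paper.

```python
import numpy as np

def hamming_cost(X, Y):
    """Pairwise Hamming distances between rows of binary matrices X (n, d) and Y (m, d)."""
    X = X.astype(float)
    Y = Y.astype(float)
    # For binary vectors: |x - y|_1 = sum(x) + sum(y) - 2 * x.y
    return X.sum(axis=1)[:, None] + Y.sum(axis=1)[None, :] - 2.0 * X @ Y.T

def smooth_wasserstein(X, Y, gamma=1.0, n_iter=200):
    """Entropy-smoothed transport cost between uniform empirical distributions on X and Y.

    Assumption: Hamming ground cost and uniform sample weights (illustrative only).
    Returns the transport cost <P, C> of the regularized plan, and the plan P.
    """
    C = hamming_cost(X, Y)
    n, m = C.shape
    a = np.full(n, 1.0 / n)          # weights of the samples in X
    b = np.full(m, 1.0 / m)          # weights of the samples in Y
    K = np.exp(-C / gamma)           # Gibbs kernel; smaller gamma means less smoothing
    u = np.ones(n)
    v = np.ones(m)
    for _ in range(n_iter):
        # Alternately rescale rows and columns to match the marginals a and b
        u = a / (K @ v)
        v = b / (K.T @ u)
    P = u[:, None] * K * v[None, :]  # regularized transport plan
    return float((P * C).sum()), P
```

Comparing a sample set with itself should give a much smaller cost than comparing it with its bitwise complement; with smaller `gamma` the plan becomes sharper but Sinkhorn may need more iterations to converge.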
The demo code applies to downscaled MNIST digits and generates digits that are less noisy, but also less diverse, than those generated by a standard RBM. Training the model takes a few hours.
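The exact preprocessing performed by data.py is not described here. As one hypothetical way to obtain small binary digits, the sketch below downscales grayscale images by block averaging and then thresholds them. The function name `downscale_binarize`, the factor of 2, and the threshold of 0.5 are all assumptions chosen for illustration.

```python
import numpy as np

def downscale_binarize(images, factor=2, threshold=0.5):
    """Downscale (n, H, W) images in [0, 1] by block averaging, then binarize.

    Assumptions: H and W are divisible by `factor`; pixels whose block mean
    exceeds `threshold` become 1. Returns an (n, (H//factor)*(W//factor)) uint8 array.
    """
    n, H, W = images.shape
    # Split each image into factor x factor blocks and average within each block
    blocks = images.reshape(n, H // factor, factor, W // factor, factor)
    pooled = blocks.mean(axis=(2, 4))
    # Threshold and flatten each image into a binary vector
    return (pooled > threshold).astype(np.uint8).reshape(n, -1)
```

With 28x28 inputs and `factor=2`, each digit becomes a binary vector of length 196.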
[Figure: sample digits. Panels: Data, Standard RBM, Wasserstein RBM]
The demo code consists of the following four files:
data.py: loads the dataset.
modules.py: implements the Wasserstein distance computation and the RBM.
train.py: trains a Wasserstein RBM and stores it.
test.py: visualizes the final state of the persistent chains of the trained Wasserstein RBM.
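Persistent chains are sampler states that are carried over from one training iteration to the next instead of being reinitialized. The sketch below shows a standard block Gibbs sweep for a binary RBM that advances such chains; it is not the code in modules.py, and the names `gibbs_step` and `run_persistent_chains` and the parameter layout (`W` of shape (n_vis, n_hid), bias vectors `b_vis` and `b_hid`) are assumptions for illustration.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gibbs_step(V, W, b_vis, b_hid, rng):
    """One block Gibbs sweep v -> h -> v for a binary RBM.

    V: (n_chains, n_vis) binary states; W: (n_vis, n_hid).
    Returns the new visible states and their conditional probabilities.
    """
    p_h = sigmoid(V @ W + b_hid)                        # p(h = 1 | v)
    H = (rng.random(p_h.shape) < p_h).astype(float)     # sample hidden units
    p_v = sigmoid(H @ W.T + b_vis)                      # p(v = 1 | h)
    V_new = (rng.random(p_v.shape) < p_v).astype(float) # sample visible units
    return V_new, p_v

def run_persistent_chains(V, W, b_vis, b_hid, n_steps, rng):
    """Advance the chains by n_steps sweeps; the returned V can be reused later."""
    p_v = None
    for _ in range(n_steps):
        V, p_v = gibbs_step(V, W, b_vis, b_hid, rng)
    return V, p_v
```

Because the returned states are fed back in at the next call, the chains keep mixing across iterations; their final states are what a visualization script would display.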
and can be downloaded as a zip archive:
The code can be easily adapted to learn from any distribution of binary vectors, although some parameter tuning might be needed, e.g. for the learning rate, regularization strengths, and the Wasserstein smoothing parameter.