ebm

About

This repository implements an energy-based model (EBM) trained using noise-contrastive estimation. The implementation uses Flax NNX.
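As a rough illustration of noise-contrastive estimation, the sketch below trains a classifier to tell data samples from noise samples, with the logit given by the difference between the model's and the noise distribution's log-densities. It is a minimal sketch in plain JAX, not the code from this repository: the names `energy`, `noise_log_prob`, and `nce_loss`, the quadratic toy energy, and the standard normal noise are all assumptions chosen to keep the example tractable. This repository instead builds its noise distribution from an exponential moving average of the model weights.

```python
import jax
import jax.numpy as jnp


def energy(params, x):
    # Hypothetical toy energy: a quadratic bowl with a learnable center.
    # "log_z" is a learnable stand-in for the log-normalizer.
    return 0.5 * jnp.sum((x - params["mu"]) ** 2, axis=-1) + params["log_z"]


def noise_log_prob(x):
    # Standard normal log-density, used here only because it is tractable.
    d = x.shape[-1]
    return -0.5 * jnp.sum(x**2, axis=-1) - 0.5 * d * jnp.log(2.0 * jnp.pi)


def nce_loss(params, x_data, x_noise):
    # Logit of "x is real data": log p_model(x) - log p_noise(x),
    # with log p_model(x) = -energy(x).
    logit_data = -energy(params, x_data) - noise_log_prob(x_data)
    logit_noise = -energy(params, x_noise) - noise_log_prob(x_noise)
    # Binary cross-entropy with data labeled 1 and noise labeled 0.
    loss_data = jax.nn.softplus(-logit_data)  # -log sigmoid(logit)
    loss_noise = jax.nn.softplus(logit_noise)  # -log(1 - sigmoid(logit))
    return jnp.mean(loss_data) + jnp.mean(loss_noise)


# Usage on made-up data: "real" samples centered at 2, noise centered at 0.
key_data, key_noise = jax.random.split(jax.random.PRNGKey(0))
x_data = 2.0 + jax.random.normal(key_data, (256, 2))
x_noise = jax.random.normal(key_noise, (256, 2))
params = {"mu": jnp.zeros(2), "log_z": jnp.array(0.0)}
loss, grads = jax.value_and_grad(nce_loss)(params, x_data, x_noise)
```

The gradients in `grads` could be passed to any optimizer. If the noise is too easy to tell apart from the data, the classifier saturates quickly and the gradients carry little information about the data's structure.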

Samples from the learned distribution are drawn with a conventional Langevin sampler. For training, the noise distribution is not a simple distribution such as a Gaussian; it is constructed from an exponential moving average (EMA) of the model's weights. This is not particularly efficient, since samples have to be generated from the EBM at every training step. In return, the noise distribution is of high quality, so the classification task is not too easy and the model actually learns the structure of the data. To keep the noise distribution from matching the distribution of the EBM exactly, the Langevin sampler is run for only a few steps.
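The two ingredients described above can be sketched in plain JAX: an unadjusted Langevin sampler and an exponential moving average over the parameters. This is a minimal sketch under stated assumptions, not the repository's implementation: `energy`, `langevin_sample`, and `ema_update` are hypothetical names, the quadratic energy is a toy stand-in for a neural network, and the values of `decay`, `n_steps`, and `step_size` are picked arbitrarily.

```python
import jax
import jax.numpy as jnp


def energy(params, x):
    # Hypothetical toy energy for a single point x of shape (d,).
    return 0.5 * jnp.sum((x - params["mu"]) ** 2)


def langevin_sample(key, params, x_init, n_steps, step_size):
    # Unadjusted Langevin dynamics on a batch of chains:
    # x <- x - step_size * grad E(x) + sqrt(2 * step_size) * noise
    grad_fn = jax.vmap(jax.grad(energy, argnums=1), in_axes=(None, 0))

    def step(x, step_key):
        noise = jax.random.normal(step_key, x.shape)
        x = x - step_size * grad_fn(params, x) + jnp.sqrt(2.0 * step_size) * noise
        return x, None

    x, _ = jax.lax.scan(step, x_init, jax.random.split(key, n_steps))
    return x


def ema_update(ema_params, params, decay):
    # Move each EMA leaf a fraction (1 - decay) toward the current parameters.
    return jax.tree_util.tree_map(
        lambda e, p: decay * e + (1.0 - decay) * p, ema_params, params
    )


# Usage: update the EMA weights, then draw noise samples with a short chain.
key_init, key_chain = jax.random.split(jax.random.PRNGKey(0))
params = {"mu": jnp.array([2.0, -1.0])}
ema_params = {"mu": jnp.zeros(2)}
ema_params = ema_update(ema_params, params, decay=0.9)
x_init = jax.random.normal(key_init, (128, 2))
x_noise = langevin_sample(key_chain, ema_params, x_init, n_steps=10, step_size=0.01)
```

Because the chain is short and uses the lagging EMA weights, `x_noise` only approximates samples from the current model, which is the intended effect of sampling with few steps.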

Example

You can find a minimal experiment in experiments/eight_gaussians, which trains a generative model to sample from a Gaussian mixture with eight components (eight Gaussians). To run the example, first download the latest release and install all dependencies via:

wget -qO- https://github.com/dirmeier/ebm/archive/refs/tags/<TAG>.tar.gz | tar zxvf -
uv sync --all-groups

To train a model and make visualizations, call:

cd experiments/eight_gaussians
python main.py

Below are the synthetic samples from an EBM trained using the hyperparameters defined in experiments/eight_gaussians/config.py.

To make sure that the noise model is not "too good", we also show the noise model samples after each training step. The samples cover the data distribution well, but the noise model is clearly a "weaker" version of the EBM.

Installation

To install the latest GitHub release, run the following on the command line:

pip install git+https://github.com/dirmeier/ebm@<TAG>

Author

Simon Dirmeier simd23 @ pm dot me
