Score Matching, Toy Datasets, and JAX
Posted on Mon 31 October 2022
Score matching is a technique for learning the gradient of the log-density of a data distribution, with applications in generative modeling and beyond. This post explores score matching on toy datasets.
The method was introduced in 2005 but really came back into the spotlight in 2019, when Song and Ermon used it to generate images whose quality soon rivaled that of Generative Adversarial Networks. Score matching is also a core component of diffusion models.
The goal is to approximate a data distribution $p_d$ with a model distribution $p_m$. In practice, one usually introduces an unnormalized parametric model $\tilde{p}_\theta$ and a partition function (normalization term) $Z_\theta = \int \tilde{p}_\theta(x)\,dx$ such that $p_m(x) = \tilde{p}_\theta(x) / Z_\theta$. Unfortunately, fitting the parameters of $p_m$ directly is intractable in high-dimensional spaces, because the partition function requires integrating over the entire input space for every set of parameters visited during training.
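To make the cost of the partition function concrete, here is a minimal sketch that estimates $Z$ by brute-force quadrature on a grid. The unnormalized density (a Gaussian bump), the grid bounds and the grid resolution are all assumptions chosen for illustration; they do not come from the rest of this post.

```python
import jax.numpy as jnp

def unnormalized_density(x):
    # Hypothetical unnormalized model: exp(-x^2 / 2), whose true Z is sqrt(2*pi).
    return jnp.exp(-0.5 * x**2)

# Riemann-sum approximation of Z in one dimension.
n_points = 2001
grid = jnp.linspace(-10.0, 10.0, n_points)
dx = grid[1] - grid[0]
Z = jnp.sum(unnormalized_density(grid)) * dx

# The same grid resolution in d dimensions needs n_points**d evaluations,
# and this has to be redone every time the parameters change.
evaluations_needed = {d: n_points**d for d in (1, 2, 3, 10)}

print(float(Z))
print(evaluations_needed)
```

In one dimension the sum is cheap, but the number of grid points grows exponentially with the dimension, which is why grid-based normalization is hopeless for images.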
Methods that circumvent this difficulty include kernel density estimation, which makes strong assumptions on the form of the density estimator, and variational autoencoders, which also make a normality assumption on the distribution of the latent variables.
Note that instead of learning the data distribution itself, we can try to learn its gradient. Score matching does exactly this: it introduces a so-called "score function" $s_m$ (with a set of parameters $\theta$, typically the weights of a neural network) which models the gradient of (the log of) the data distribution $p_d$.
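$$s_m(x;\theta) \triangleq \nabla_x \log p_m(x;\theta) = \nabla_x \log \tilde{p}_m(x;\theta),$$
where $\tilde{p}_m$ is the unnormalized density $\tilde{p}_m(x;\theta) = Z_\theta\, p_m(x;\theta)$. The second equality holds because $\log Z_\theta$ does not depend on $x$, so its gradient vanishes.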
Note that the presence of the log doesn't steer us away from the initial goal since
$$\nabla_x \log p(x) = \frac{1}{p(x)} \nabla_x p(x),$$
i.e., the gradient of the log of the density is the gradient of the density, weighted by the inverse density.
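Both properties of the score (it ignores the normalization constant, and it equals the density gradient divided by the density) are easy to check with automatic differentiation. The sketch below assumes a hypothetical unnormalized Gaussian model, for which the score is simply $-x$.

```python
import jax
import jax.numpy as jnp

def log_unnormalized(x):
    # Hypothetical unnormalized model: log of exp(-||x||^2 / 2).
    return -0.5 * jnp.sum(x**2)

def log_normalized(x):
    # Same model with its log partition function subtracted.
    d = x.shape[0]
    log_z = 0.5 * d * jnp.log(2.0 * jnp.pi)
    return log_unnormalized(x) - log_z

def density(x):
    return jnp.exp(log_normalized(x))

x = jnp.array([0.5, -1.0])

# The score is the same with or without the normalization constant.
score_unnormalized = jax.grad(log_unnormalized)(x)
score_normalized = jax.grad(log_normalized)(x)

# grad(log p) equals grad(p) / p.
ratio = jax.grad(density)(x) / density(x)

print(score_unnormalized, score_normalized, ratio)
```

All three vectors equal $-x$ here, up to floating-point error.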
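Optimizing the score function is equivalent to minimizing the Fisher divergence between our model $s_m$ and the real score function $s_d$ that we're trying to learn:
$$L(\theta) \triangleq \frac{1}{2}\,\mathbb{E}_{p_d}\left[\lVert s_m(x;\theta) - s_d(x)\rVert_2^2\right].$$
Hyvärinen (2005) showed that the Fisher divergence $L(\theta)$ is equal to a loss function $J$ defined as
$$J(\theta) \triangleq \mathbb{E}_{p_d}\left[\operatorname{tr}\left(\nabla_x s_m(x;\theta)\right) + \frac{1}{2}\lVert s_m(x;\theta)\rVert_2^2\right],$$
up to an additive constant that does not depend on $\theta$. In other words, at samples from the distribution $p_d$, the score function should have a Jacobian trace that is as negative as possible (i.e., the vector field converges towards the samples) and a norm that is as small as possible (i.e., the samples sit close to stationary points of the log-density). Unlike $L$, this loss does not involve the unknown $s_d$. It is still not tractable in high-dimensional spaces because of the trace term, but we can compute it in low dimensions. For larger dimensions, tricks (like Hutchinson's) are needed to evaluate the trace (more on this later).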
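The equivalence can be checked numerically in a case where the true score is known. The sketch below is an illustration under stated assumptions, not part of the original derivation: the data are drawn from a 1-D standard normal (true score $-x$) and the model score is a hypothetical one-parameter family $s_m(x;\theta) = -\theta x$. In that setting, $L(\theta) - J(\theta)$ should be close to the constant $1/2$ for every $\theta$.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (100_000,))  # samples from p_d = N(0, 1)

def model_score(theta, x):
    # Hypothetical one-parameter model: the score of N(0, 1/theta).
    return -theta * x

def fisher_divergence(theta):
    data_score = -x  # true score of N(0, 1), normally unknown
    return 0.5 * jnp.mean((model_score(theta, x) - data_score) ** 2)

def hyvarinen_loss(theta):
    # In 1-D, the Jacobian trace is just the derivative of the score w.r.t. x.
    ds_dx = jax.vmap(jax.grad(model_score, argnums=1), in_axes=(None, 0))(theta, x)
    return jnp.mean(ds_dx + 0.5 * model_score(theta, x) ** 2)

thetas = jnp.linspace(0.5, 1.5, 11)
fisher_values = jax.vmap(fisher_divergence)(thetas)
hyvarinen_values = jax.vmap(hyvarinen_loss)(thetas)
gap = fisher_values - hyvarinen_values

print(gap)                                   # roughly constant across theta
print(thetas[jnp.argmin(hyvarinen_values)])  # same minimizer as the Fisher divergence
```

Both losses are minimized at the same $\theta$, even though only the first one needs access to the true score.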
I'll be using the JAX deep learning framework, the latest cool kid on the neural networks block. Compared to PyTorch and TensorFlow, JAX is interesting because it has a functional spin and an interface very similar to NumPy. But let's get back to score matching.
!pip install --upgrade --quiet pip "jax[cpu]"
Illustrative example: the ring dataset
To give a concrete example of score matching let's use a synthetic data distribution centered around a ring, with radial normal noise and uniform angular distribution.
import numpy as np
from scipy.stats import norm
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
# Matplotlib parameters
plt.rcParams["figure.figsize"] = (10, 10)
plt.rcParams['image.cmap'] = 'RdBu'
bright_colormap = ListedColormap(['#FF0000', '#0000FF'])
class RingDistribution:
    """Two-dimensional probability distribution centered around a ring, with
    radial normal noise and uniform angular distribution.
    """

    def __init__(self, radius, std):
        self.mean = radius
        self.std = std

    def sample(self, n_samples, supervised=False, seed=None):
        """Return an array of samples from the distribution."""
        np.random.seed(seed)
        r = np.random.normal(loc=self.mean, scale=self.std, size=n_samples)
        theta = np.random.uniform(low=-np.pi, high=np.pi, size=n_samples)
        x1 = r*np.cos(theta)
        x2 = r*np.sin(theta)
        x = np.array([x1, x2]).T
        return x

    def pdf(self, x):
        """Gaussian density of the radius r = ||x|| (radial profile)."""
        r = np.sqrt((x**2).sum(axis=1))
        return norm.pdf(r, loc=self.mean, scale=self.std)

    def scores(self, x):
        """Gradient of the log of the radial profile, up to the constant
        positive factor 1/std**2 (which does not change the directions)."""
        r = np.sqrt((x**2).sum(axis=1))
        # Use the instance radius rather than the global RADIUS constant.
        scores = (x * (self.mean/r - 1).reshape(-1, 1))
        return scores
RADIUS = 1
STD = 0.1
distribution = RingDistribution(radius=RADIUS, std=STD)
X = distribution.sample(n_samples=10000)
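For this toy distribution, the score function can be computed analytically. Writing $r = \lVert x \rVert_2$ for the distance to the origin, $R$ for the ring radius and $\sigma$ for the radial standard deviation, it reads:
$$s(x) = \nabla_x \log p(x) = \nabla_x \log\left(\frac{1}{\sigma\sqrt{2\pi}} \exp\left(-\frac{(r-R)^2}{2\sigma^2}\right)\right) = \frac{1}{\sigma^2}\left(\frac{R}{r} - 1\right)x.$$
The scores method of the class above drops the constant factor $1/\sigma^2$, which rescales the vectors without changing their direction. The fact that the score function can be computed analytically allows us to visualize the score function that we're trying to learn, as a vector field.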
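The closed form can be cross-checked against automatic differentiation. This sketch assumes the radial Gaussian profile with radius 1 and standard deviation 0.1 (the values used in this post) and keeps the constant factor $1/\sigma^2$ explicitly; the function names are made up for the example.

```python
import jax
import jax.numpy as jnp

RADIUS, STD = 1.0, 0.1  # same values as the ring dataset above

def log_radial_density(x):
    # Log of a Gaussian in the distance to the origin.
    r = jnp.linalg.norm(x)
    return -0.5 * ((r - RADIUS) / STD) ** 2 - jnp.log(STD * jnp.sqrt(2.0 * jnp.pi))

def analytic_score(x):
    # Closed form (R/r - 1) x / sigma^2, including the constant factor.
    r = jnp.linalg.norm(x)
    return (RADIUS / r - 1.0) * x / STD**2

points = jnp.array([[0.5, 0.0], [0.0, 1.2], [-0.8, 0.9]])
autodiff_scores = jax.vmap(jax.grad(log_radial_density))(points)
analytic_scores = jax.vmap(analytic_score)(points)

print(autodiff_scores)
print(analytic_scores)
```

For the first point, which lies inside the ring at distance 0.5 from the origin, both versions give a score of $(50, 0)$: it points outwards, towards the ring.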
# Create a grid of points for plots.
size = RADIUS+5*STD
step = 0.1
x_grid, y_grid = np.meshgrid(
    np.arange(-size, size+step, step),
    np.arange(-size, size+step, step)
)
grid_points = np.c_[x_grid.ravel(), y_grid.ravel()]
f, ax = plt.subplots()
ax.scatter(X[:, 0], X[:, 1], edgecolors='k')
scores = distribution.scores(grid_points)
scores[int(len(scores)/2)] = [0., 0.] # Center singularity
ax.quiver(x_grid, y_grid, scores[:, 0], scores[:, 1])
ax.set_title("Samples (dots) and true score function (vectors)");
As expected, the score function is rotation-invariant and points to the ring, with smaller scores near the data distribution, and larger ones far from it.
The score is larger when moving away from the data distribution because of the log term in the score definition: since $s_m(x) = \nabla_x \log p(x) = \nabla_x p(x) / p(x)$, the score becomes large in regions of low probability. This property is useful for generative modeling, since the score points to the data distribution more forcefully in regions of the feature space that are "out of distribution".
Learning the score function with JAX
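To learn the score function introduced above, we need three ingredients: a deep learning model, the loss function $J(\theta) \triangleq \mathbb{E}_{p_d}\left[\operatorname{tr}\left(\nabla_x s_m(x;\theta)\right) + \frac{1}{2}\lVert s_m(x;\theta)\rVert_2^2\right]$ introduced above, and good old gradient descent, which will do the rest.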
In the following we define a few util functions in JAX then initialize a Multilayer Perceptron (MLP) to model the score function. The output of the MLP has the same dimension as the input since the output is a vector in the original feature space.
import jax
from jax import numpy as jnp
from jax import random
from jax import jacfwd
random_num_gen = jax.random.PRNGKey(0)
def dense_init(in_features, out_features, random_num_gen,
               kernel_init=jax.nn.initializers.lecun_normal(),
               bias_init=jax.nn.initializers.zeros):
    """Initialize the weights of a dense layer."""
    k1, k2 = random.split(random_num_gen)
    kernel = kernel_init(k1, (in_features, out_features))
    bias = bias_init(k2, (out_features,))
    return kernel, bias

def init_network_params(sizes, key):
    """Initialize all layers for a fully-connected neural network with sizes "sizes"."""
    keys = random.split(key, len(sizes))
    return [dense_init(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]
# Forward pass through the neural network
def compute_score(params, x):
    """MLP for score computation, the input and output have the same dimensions."""
    activations = x.T
    for w, b in params[:-1]:
        outputs = jnp.dot(w.T, activations) + b
        activations = jnp.tanh(outputs)
    final_w, final_b = params[-1]
    return jnp.dot(final_w.T, activations.T) + final_b

@jax.jit
def batched_compute_score(params, x_batched):
    """Compute the score for a batch."""
    f = lambda x: compute_score(params, x)
    return jax.vmap(f)(x_batched)
def exact_matching_loss(params, inputs):
    """Compute the exact score matching loss function written in
    the form presented by Hyvärinen (2005), for a single sample."""
    f = lambda x: compute_score(params, x)
    score = f(inputs)
    jacobian = jacfwd(f)(inputs)
    # tr(grad_x s) + 1/2 * ||s||_2^2 (squared norm, as in the definition of J)
    return jnp.trace(jacobian) + 0.5 * jnp.sum(score**2)

@jax.jit
def batched_loss_computation(params, x_batched):
    """Compute the loss for a batch."""
    f = lambda x: exact_matching_loss(params, x)
    return jnp.mean(jax.vmap(f)(x_batched))
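As a sanity check of the per-sample quantity $\operatorname{tr}(\nabla_x s) + \frac{1}{2}\lVert s\rVert_2^2$, here is a small standalone sketch that evaluates it for a score function known in closed form. The score $s(x) = -x$ (that of a standard normal) and the helper name exact_loss_single are assumptions made for this illustration; they are not the MLP defined above.

```python
import jax
import jax.numpy as jnp

def gaussian_score(x):
    # Score of a standard normal distribution: its Jacobian is minus the identity.
    return -x

def exact_loss_single(score_fn, x):
    # Hypothetical helper: Jacobian trace plus half the squared norm of the score.
    jacobian = jax.jacfwd(score_fn)(x)
    score = score_fn(x)
    return jnp.trace(jacobian) + 0.5 * jnp.sum(score**2)

x = jnp.array([1.0, 2.0])
loss = exact_loss_single(gaussian_score, x)
print(float(loss))  # trace = -2, half squared norm = 2.5, so the loss is 0.5
```

Getting a value that matches the hand computation is a cheap way to catch mistakes such as using the norm instead of the squared norm.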
Model training
Let's write a function to generate batches, then perform the usual gradient descent on the loss function defined above.
import numpy.random as npr
from jax.example_libraries.optimizers import adam
batch_size = 512
num_train = X.shape[0]
num_complete_batches, leftover = divmod(num_train, batch_size)
num_batches = num_complete_batches + bool(leftover)
np_random_num_gen = npr.RandomState(0)
def data_stream():
    perm = np_random_num_gen.permutation(num_train)
    for i in range(num_batches):
        batch_idx = perm[i * batch_size:(i + 1) * batch_size]
        yield X[batch_idx]
batches = data_stream()
@jax.jit
def update(step, params, x, opt_state):
    """Compute the gradient for a batch and update the parameters."""
    value, grads = jax.value_and_grad(batched_loss_computation)(params, x)
    # Adam's bias correction depends on the step count, so pass the real step.
    opt_state = opt_update(step, grads, opt_state)
    return get_params(opt_state), opt_state, value

# Model architecture
layers_sizes = [2, 64, 128, 64, 2]
params = init_network_params(layers_sizes, random_num_gen)

# Defining an optimizer in JAX
opt_init, opt_update, get_params = adam(step_size=1e-4)
opt_state = opt_init(params)

num_epochs = 1000
step = 0
print(f"Initial loss: {batched_loss_computation(params, X):.3f}")
for epoch in range(num_epochs):
    # A fresh generator (and a fresh permutation) for every epoch
    for x_batch in data_stream():
        params, opt_state, loss = update(step, params, x_batch, opt_state)
        step += 1
    if epoch % 100 == 0:
        full_loss = batched_loss_computation(params, X)
        print(f"End of epoch {epoch}: loss={full_loss:.3f}")
Note that the loss can become negative because it is only equal to the Fisher divergence (which is non-negative) up to an additive constant.
Let's look at the score function learned by the model.
f, ax = plt.subplots()
ax.scatter(X[:, 0], X[:, 1], edgecolors='k')  # no cmap: it is ignored without color values
scores = np.array(batched_compute_score(params, grid_points))
ax.quiver(x_grid, y_grid, scores[:, 0], scores[:, 1]);
Not perfect, but not too bad. If you were to generate samples by:
- starting from a random location of the feature space
- following the learned score function,
then you would more or less sample from the original distribution.
In practice, one can add a bit of noise to the process that I just described to avoid being stuck in stationary points of the score function. The corresponding sampling process is called Langevin sampling.
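The sampling procedure described above can be sketched as follows. To keep the example self-contained, it uses the analytic score of the ring (with radius 1, standard deviation 0.1 and the $1/\sigma^2$ factor included) instead of the trained network; the step size, number of steps and uniform initialization are assumptions picked for this toy problem, not tuned values.

```python
import jax
import jax.numpy as jnp

RADIUS, STD = 1.0, 0.1  # same values as the ring dataset above

def ring_score(x):
    # Analytic score of the radial Gaussian profile, for a batch of points.
    r = jnp.linalg.norm(x, axis=-1, keepdims=True)
    return (RADIUS / (r + 1e-8) - 1.0) * x / STD**2

def langevin_sample(key, n_samples=500, n_steps=1000, step_size=1e-3):
    # Start from random locations of the feature space.
    key, init_key = jax.random.split(key)
    x0 = jax.random.uniform(init_key, (n_samples, 2), minval=-1.5, maxval=1.5)

    def body(i, carry):
        x, key = carry
        key, noise_key = jax.random.split(key)
        noise = jax.random.normal(noise_key, x.shape)
        # Follow the score, plus Gaussian noise scaled by sqrt(step_size).
        x = x + 0.5 * step_size * ring_score(x) + jnp.sqrt(step_size) * noise
        return x, key

    x, _ = jax.lax.fori_loop(0, n_steps, body, (x0, key))
    return x

samples = langevin_sample(jax.random.PRNGKey(0))
radii = jnp.linalg.norm(samples, axis=1)
print(float(radii.mean()), float(radii.std()))
```

With these settings the samples end up concentrated around the ring. Swapping ring_score for the learned batched_compute_score (with params bound) would give samples from the trained model instead, with a step size adapted to the scale of the learned scores.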
Scaling with Sliced Score Matching
When trying the above learning process in higher dimensional spaces (e.g., for images), one faces a challenge posed by the trace term $\operatorname{tr}\left(\nabla_x s_m(x;\theta)\right)$ in the score matching loss function. A naive approach to computing the gradients of this term for the gradient descent requires as many backward passes as there are dimensions, which can be a lot. Fortunately there are several ways out of this problem. Sliced Score Matching is one of them.
Sliced score matching uses the Hutchinson trick to compute the trace approximately with a stochastic sampling technique. The idea is to sample random vectors $v$ with zero mean and identity covariance (e.g., standard normal or Rademacher vectors) and use the identity $\operatorname{tr}(A) = \mathbb{E}_{p_v}\left[v^T A v\right]$.
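Here is a minimal sketch of the Hutchinson estimator in JAX. The vector field (a tanh of a fixed linear map), the matrix W and the number of probes are hypothetical choices for the illustration. Each term $v^T A v$ is obtained with a single Jacobian-vector product, so the full Jacobian is never built.

```python
import jax
import jax.numpy as jnp

# Hypothetical smooth vector field standing in for a score network.
W = jnp.array([[0.5, -0.3, 0.2],
               [0.1, 0.8, -0.4],
               [-0.6, 0.2, 0.3]])

def vector_field(x):
    return jnp.tanh(W @ x)

def exact_trace(f, x):
    # Builds the full Jacobian: fine in 3 dimensions, costly in high dimensions.
    return jnp.trace(jax.jacfwd(f)(x))

def hutchinson_trace(f, x, key, n_probes=10_000):
    # Rademacher probes have zero mean and identity covariance.
    probes = jax.random.rademacher(key, (n_probes, x.shape[0])).astype(x.dtype)

    def quadratic_form(v):
        _, jacobian_times_v = jax.jvp(f, (x,), (v,))
        return jnp.dot(v, jacobian_times_v)

    return jnp.mean(jax.vmap(quadratic_form)(probes))

x = jnp.array([0.3, -0.7, 1.1])
exact = exact_trace(vector_field, x)
estimate = hutchinson_trace(vector_field, x, jax.random.PRNGKey(0))
print(float(exact), float(estimate))
```

The estimate is noisy for a small number of probes; in training, that noise is averaged out over datapoints and gradient steps, which is why a handful of probes per datapoint can be enough.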
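For each datapoint $x_i$ we draw $M$ random projection vectors $v_{ij}$, such that the (tractable) estimated loss function reads:
$$J(\theta) \simeq \frac{1}{N}\sum_{i=1}^{N}\left[\frac{1}{M}\sum_{j=1}^{M} v_{ij}^T \left(\nabla_x s_m(x_i;\theta)\right) v_{ij} + \frac{1}{2}\lVert s_m(x_i;\theta)\rVert_2^2\right].$$
More on this in a future post!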