# Quantizing probability distributions by minimizing the Maximum Mean Discrepancy (MMD)

import torch


def mmd(x_nd, y_nd, kernel):
    """
    Computes the Emperical Maximum Mean Discrepancy (MMD) between two sets of points.
    """