import torch


def mmd(x_nd, y_nd, kernel):
    """
    Computes the empirical Maximum Mean Discrepancy (MMD) between two sets
    of points.

    Args:
        x_nd: First set of points. Presumably a tensor with one row per
            point (the ``_nd`` suffix suggests an ``(n, d)`` layout) --
            TODO confirm against callers.
        y_nd: Second set of points, presumably laid out like ``x_nd`` --
            TODO confirm.
        kernel: Kernel used to compare points. Its exact call signature is
            not visible from this definition -- verify before relying on it.

    NOTE(review): only the signature and this docstring are visible in the
    reviewed source; the return value is therefore not documented here.
    """