Metrics usage
This example shows how to use metrics to measure the quality of generated samples. Each metric takes two samples (real and generated) and estimates the discrepancy between their distributions. The smaller the metric values, the better.
Generate two samples
Let's generate two samples from multivariate normal distributions that share the same covariance matrix but whose means are separated by a given distance.
import matplotlib.pyplot as plt
import numpy as np
from probaforms import metrics
def gen_two_samples(dist, N, seed=None):
    """Draw two 2-D Gaussian samples whose means are ``dist`` apart.

    Both samples share the covariance matrix ``[[1, 0.7], [0.7, 1]]``. The
    first is centred at the origin; the second is shifted by
    ``dist / sqrt(2)`` along each axis, so the Euclidean distance between
    the two means equals ``dist``.

    Args:
        dist: Distance between the means of the two distributions.
        N: Number of points to draw for each sample.
        seed: Optional seed for reproducible output. When ``None`` (the
            default) the global NumPy random state is used, so existing
            callers are unaffected.

    Returns:
        A tuple ``(X, Y)`` of arrays, each of shape ``(N, 2)``.
    """
    sigma = np.array([[1, 0.7], [0.7, 1]])
    mu_x = np.array([0, 0])
    mu_y = mu_x + dist / np.sqrt(2)
    # RandomState (rather than default_rng) keeps a seeded call on the same
    # stream as np.random.seed(seed) followed by an unseeded call.
    rng = np.random if seed is None else np.random.RandomState(seed)
    X = rng.multivariate_normal(mu_x, sigma, N)
    Y = rng.multivariate_normal(mu_y, sigma, N)
    return X, Y
# Generate two samples of 1000 points each, with means a distance of 2 apart.
dist = 2
X, Y = gen_two_samples(dist, N=1000)
def plot_samples(X, Y, dist=0):
    """Scatter-plot the two samples on a single figure and show it.

    Args:
        X: First sample; columns 0 and 1 are used as the plot coordinates.
        Y: Second sample; columns 0 and 1 are used as the plot coordinates.
        dist: Distance value displayed in the figure title.
    """
    plt.figure(figsize=(6, 4))
    plt.scatter(X[:, 0], X[:, 1], label='X', alpha=0.5, color='C0')
    plt.scatter(Y[:, 0], Y[:, 1], label='Y', alpha=0.5, color='C1')
    # "%g" keeps integer-valued distances compact ("2", "10") while still
    # showing fractional ones ("0.5"); "%.f" rounded those to whole numbers.
    plt.title("Distance = %g" % dist)
    plt.legend()
    plt.tight_layout()
    plt.show()
# Show both samples on one scatter plot, titled with the distance used.
plot_samples(X, Y, dist)
Compute discrepancies between the samples
def get_metrics(X, Y):
    """Print every discrepancy metric computed between two samples.

    Each listed function of the imported ``metrics`` module is called as
    ``func(X, Y)``; its result is unpacked into two numbers and printed as
    ``<label> = <first> +- <second>`` with six decimal places.

    NOTE(review): the two numbers are presumably an estimate and its
    uncertainty -- confirm against the library documentation.

    Args:
        X: First sample, passed unchanged to each metric function.
        Y: Second sample, passed unchanged to each metric function.
    """
    report = (
        ("Frechet Distance", "frechet_distance"),
        ("Kolmogorov-Smirnov", "kolmogorov_smirnov_1d"),
        ("Cramer-von Mises", "cramer_von_mises_1d"),
        ("Anderson-Darling", "anderson_darling_1d"),
        ("ROC AUC", "roc_auc_score_1d"),
        ("Kullback-Leibler KDE", "kullback_leibler_1d_kde"),
        ("Jensen-Shannon KDE", "jensen_shannon_1d_kde"),
        ("Maximum Mean Discrepancy", "maximum_mean_discrepancy"),
    )
    for label, func_name in report:
        # Resolve each function only when it is needed, so the metrics are
        # evaluated and printed one at a time in the listed order.
        value, error = getattr(metrics, func_name)(X, Y)
        print("%s = %.6f +- %.6f" % (label, value, error))
# Print all discrepancy metrics for the current pair of samples.
get_metrics(X, Y)
Additional experiments with other distances
# Repeat the experiment for two more distances between the means:
# 10 and then 0.
for dist in (10., 0.):
    X, Y = gen_two_samples(dist, N=1000)
    plot_samples(X, Y, dist)
    get_metrics(X, Y)