mlnext.score.kl_divergence

mlnext.score.kl_divergence(mean: ndarray, log_var: ndarray, prior_mean: float = 0.0, prior_std: float = 1.0) → ndarray

Calculates the Kullback-Leibler (KL) divergence KL(q || p) between a normal distribution q, parameterized by mean and log_var, and a Gaussian prior p with mean prior_mean and standard deviation prior_std.
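For reference, the standard closed form of the KL divergence between two univariate normal distributions is written out below in terms of the function's parameters, applied element-wise. The page does not state whether mlnext evaluates exactly this expression or reduces the result over any axis, so read it as background rather than as the library's implementation.

```latex
\mathrm{KL}(q \,\|\, p)
  = \log\frac{\sigma_p}{\sigma_q}
  + \frac{\sigma_q^2 + (\mu_q - \mu_p)^2}{2\sigma_p^2}
  - \frac{1}{2},
\qquad
\mu_q = \texttt{mean},\;
\sigma_q^2 = e^{\texttt{log\_var}},\;
\mu_p = \texttt{prior\_mean},\;
\sigma_p = \texttt{prior\_std}.

\text{For the defaults } \mu_p = 0,\ \sigma_p = 1:\quad
\mathrm{KL}(q \,\|\, p)
  = -\tfrac{1}{2}\left(1 + \texttt{log\_var} - \texttt{mean}^2 - e^{\texttt{log\_var}}\right).
```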

Parameters:
  • mean (np.ndarray) – Mean of q.

  • log_var (np.ndarray) – Log variance (log σ²) of q.

  • prior_mean (float) – Mean of the prior p. Defaults to 0.0.

  • prior_std (float) – Standard deviation of the prior p. Defaults to 1.0.

Returns:

The KL divergence KL(q || p) between the two normal distributions.

Return type:

np.ndarray
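The following is a minimal NumPy sketch of the closed-form KL divergence that this function's parameters describe. It is an illustration under stated assumptions, not mlnext's source: the helper name kl_divergence_sketch is hypothetical, and it assumes the result is computed element-wise with no reduction over any axis.

```python
import numpy as np


def kl_divergence_sketch(mean, log_var, prior_mean=0.0, prior_std=1.0):
    """Element-wise KL(q || p) for q = N(mean, exp(log_var)), p = N(prior_mean, prior_std**2).

    Hypothetical helper for illustration only; not the mlnext implementation.
    No reduction is applied, so the output has the broadcast shape of the inputs.
    """
    mean = np.asarray(mean, dtype=float)
    log_var = np.asarray(log_var, dtype=float)
    var_q = np.exp(log_var)   # sigma_q^2 recovered from the log variance
    var_p = prior_std ** 2    # sigma_p^2 of the prior
    return (
        np.log(prior_std) - 0.5 * log_var                      # log(sigma_p / sigma_q)
        + (var_q + (mean - prior_mean) ** 2) / (2.0 * var_p)   # variance and mean mismatch
        - 0.5
    )


# q identical to the standard normal prior gives 0.0; shifting its mean by 1 gives 0.5.
print(kl_divergence_sketch(np.array([0.0, 1.0]), np.array([0.0, 0.0])))
```

If the inputs are shaped (batch, latent_dim), summing the element-wise result over the last axis (`.sum(axis=-1)`) would give one divergence per sample; whether mlnext applies such a reduction itself is not documented here.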