mlnext.score.kl_divergence

mlnext.score.kl_divergence(mean: ndarray, log_var: ndarray, prior_mean: float = 0.0, prior_std: float = 1.0) → ndarray

Calculates the Kullback-Leibler (KL) divergence KL(q || p) between a normal distribution q, parameterized by mean and log_var, and a Gaussian prior p with mean prior_mean and standard deviation prior_std.
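For reference, the KL divergence between two univariate Gaussians has the standard closed form below. We assume that kl_divergence evaluates this expression element-wise. That assumption is consistent with the example output further down, and the first element is worked out here as a check.

```latex
\mathrm{KL}(q \,\|\, p)
  = \log\frac{\sigma_p}{\sigma_q}
  + \frac{\sigma_q^2 + (\mu_q - \mu_p)^2}{2\sigma_p^2}
  - \frac{1}{2},
\qquad
\mu_q = \texttt{mean},\;
\sigma_q^2 = \exp(\texttt{log\_var}),\;
\mu_p = \texttt{prior\_mean},\;
\sigma_p = \texttt{prior\_std}
% First example element: mean = 1, sigma_q = 0.1, standard normal prior
\log\frac{1}{0.1} + \frac{0.1^2 + 1^2}{2} - \frac{1}{2} = \ln 10 + 0.005 \approx 2.30758509
```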

Parameters:
  • mean (np.ndarray) – Mean of q.

  • log_var (np.ndarray) – Log variance of q.

  • prior_mean (float) – Mean of the prior p. Defaults to 0.0.

  • prior_std (float) – Standard deviation of the prior p. Defaults to 1.0.

Returns:

The KL divergence KL(q || p), computed element-wise.

Return type:

np.ndarray

Example

>>> import numpy as np
>>> from mlnext import kl_divergence
>>> kl_divergence(
...     mean=np.array([1, 0.8, 0.12]),
...     log_var=np.log(np.array([0.1, 0.2, 0.8])**2),
... )
array([2.30758509, 1.44943791, 0.05034355])
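The following cross-check is a minimal NumPy sketch of the standard closed-form KL divergence between two univariate Gaussians. It assumes q = N(mean, exp(log_var)) and p = N(prior_mean, prior_std²). The function name kl_divergence_sketch is hypothetical, and the code is not the mlnext source. With the inputs from the example above, it reproduces the documented output.

```python
import numpy as np


def kl_divergence_sketch(
    mean: np.ndarray,
    log_var: np.ndarray,
    prior_mean: float = 0.0,
    prior_std: float = 1.0,
) -> np.ndarray:
    """Element-wise KL(q || p) with q = N(mean, exp(log_var)) and
    p = N(prior_mean, prior_std**2).

    Hypothetical illustration only; not the mlnext implementation.
    """
    mean = np.asarray(mean, dtype=float)
    log_var = np.asarray(log_var, dtype=float)
    var = np.exp(log_var)          # variance of q
    prior_var = prior_std ** 2     # variance of p
    return (
        np.log(prior_std) - 0.5 * log_var                        # log(sigma_p / sigma_q)
        + (var + (mean - prior_mean) ** 2) / (2.0 * prior_var)   # variance and mean-shift term
        - 0.5
    )


kld = kl_divergence_sketch(
    mean=np.array([1, 0.8, 0.12]),
    log_var=np.log(np.array([0.1, 0.2, 0.8]) ** 2),
)
print(kld)  # matches the doctest values above: 2.30758509, 1.44943791, 0.05034355
```

To pass a standard deviation instead of a log variance, convert it with `np.log(std ** 2)`, as the example does.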