Tutorial de MCMC con numpyro
Contenido
22. Tutorial de MCMC con numpyro
¶
A continuación veremos como:
Definir un modelo probabilístico generativo
Obtener una traza con métodos de MCMC
Verificar la convergencia de la cadena
Utilizar el posterior para hacer inferencia predictiva
en base a la librería numpyro
.
import numpy as np
import scipy.stats
import jax.numpy as jnp
import jax.random as random
import numpyro
import numpyro.distributions as dists
import holoviews as hv
hv.extension('bokeh')
hv.opts.defaults(hv.opts.Curve(width=500),
hv.opts.Scatter(width=500),
hv.opts.Image(width=500),
hv.opts.Histogram(width=500))
numpyro.set_host_device_count(2)
print(f"Numpyro version: {numpyro.__version__}")
Numpyro version: 0.10.1
22.1. Regresión logística bayesiana¶
A continuación se genera un set de 30 datos bidimensionales con etiqueta binaria en base a los cuales entrenaremos un regresor logístico bayesiano.
Para generar números pseudo-aleatorios utilizamos el módulo numpyro.distributions
:
true_coef = jnp.array([-0.5, -1, 2])
N = 30
key = random.PRNGKey(1)
x = dists.Normal(0, 1).sample(key, sample_shape=(N, 2))
key, key_ = random.split(key)
y = dists.Bernoulli(logits=true_coef[0] + true_coef[1]*x[:, 0] + true_coef[2]*x[:, 1]).sample(key)
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Que graficamente sería:
p1 = hv.Scatter((x[y==0, 0], x[y==0, 1]), kdims=['x[:, 1]'], vdims=['x[:, 0]'], label='y==0').opts(size=10, marker='d')
p2 = hv.Scatter((x[y==1, 0], x[y==1, 1]), label='y==1').opts(size=10, marker='o')
hv.Overlay([p1, p2]).opts(legend_position='top')
Especificación del modelo
Consideremos los siguientes supuestos
Las variables independientes son determinísticas
Tenemos tres coeficientes: un parámetro para cada variable independiente más la intercepta
Los coeficientes del modelo \(\theta\) tienen un prior Gaussiano con media cero y desviación estándar igual a 10
La variable dependiente es binaria con verosimilitud Bernoulli
Matemáticamente, la relación entre las variables independientes y dependientes en la regresión logística es:
donde:
En numpyro
definimos este modelo como una función que utiliza las siguientes primitivas:
numpyro.sample
para declarar una variable aleatoria. Espera un nombre y una distribución.numpyro.deterministic
para declarar una variable determinística. Espera un nombre y un valor.numpyro.plate
para declarar independencia condicional, por ejemplo contra el dataset. Espera un nombre y un tamaño (size
).
En este caso el modelo sería:
def sigmoid(z):
return 1./(1. + jnp.exp(-z))
def model(x, y=None):
theta_prior = dists.Normal(loc=jnp.zeros(3), scale=10*jnp.ones(3))
theta = numpyro.sample("theta", theta_prior.to_event(1))
with numpyro.plate('data', size=len(x)):
logitp = theta[0] + x[:, 0]*theta[1] + x[:, 1]*theta[2]
p = numpyro.deterministic('p', value=sigmoid(logitp))
numpyro.sample("y", dists.BernoulliLogits(logits=logitp), obs=y)
return p
Nota
La diferencia entre una variable aleatoria observada (como la verosimilitud) y una latente (como el prior) es que para definir la primera utilizamos el argumento obs
.
Podemos verificar que las dimensiones de las variables sean las correctas con numpyro.handlers
:
seeded_model = numpyro.handlers.seed(model, random.PRNGKey(1234))
exec_trace = numpyro.handlers.trace(seeded_model).get_trace(x, y)
print(numpyro.util.format_shapes(exec_trace))
Trace Shapes:
Param Sites:
Sample Sites:
theta dist | 3
value | 3
data plate 30 |
y dist 30 |
value 30 |
Si queremos obtener muestras de las variables del modelo debemos utilizar el handler seed
:
for i in range(3): # Three arbitrary seeds
seeded_model = numpyro.handlers.seed(model, random.PRNGKey(i))
print(i, seeded_model(x))
0 [9.9704093e-01 9.9998522e-01 9.9999821e-01 1.0000000e+00 9.9705923e-01
9.5718265e-01 9.9999094e-01 9.9996459e-01 7.9175276e-01 9.9999988e-01
1.0000000e+00 1.0000000e+00 6.9852918e-01 9.9579394e-01 1.0000000e+00
9.3691295e-01 9.9999714e-01 9.9542749e-01 1.0000000e+00 9.9135768e-01
3.0541473e-06 8.7679946e-01 4.9094800e-02 2.6528134e-03 9.9182957e-01
9.9999988e-01 1.0000000e+00 9.9977142e-01 5.5105127e-03 9.9999714e-01]
1 [9.9966180e-01 1.9972560e-01 3.1399537e-02 5.0219068e-11 9.9472636e-01
9.9999988e-01 9.9996793e-01 1.4277863e-05 1.0000000e+00 9.9963140e-01
9.9912614e-01 5.3104283e-14 1.0000000e+00 1.0000000e+00 9.5890732e-07
9.2436033e-01 2.1963162e-02 9.9999976e-01 2.9915304e-07 1.0000000e+00
1.0000000e+00 1.0000000e+00 1.0000000e+00 1.0000000e+00 1.0000000e+00
9.9929976e-01 4.5712731e-09 1.4080463e-01 1.0000000e+00 4.5782326e-05]
2 [6.65232725e-03 1.95944961e-02 1.08659985e-02 8.70697945e-03
3.49684916e-02 4.65345453e-04 9.44657404e-06 9.46368337e-01
1.45537069e-06 9.89858677e-07 4.39241221e-09 6.51585042e-01
6.80097628e-07 2.60510689e-07 1.52326291e-04 7.55531132e-01
2.02591997e-02 1.02681304e-04 2.45893560e-02 2.58306000e-06
5.17141161e-05 8.80170846e-05 1.26643281e-05 1.28239594e-06
4.82327266e-07 1.38845428e-06 5.08832920e-04 2.25024223e-01
3.78528088e-02 4.86093163e-01]
Aplicación y diagnóstico de MCMC
En numpyro
MCMC se aplica utilizando:
numpyro.infer.MCMC(sampler, # Algoritmo de muestreo
num_warmup, # Número de muestras iniciales a descartar
num_samples, # Largo de la traza (luego de descartar num_warmup)
num_chains=1, # Número de cadenas
...
)
El cual retorna un objeto con los siguientes métodos:
run()
: Ejecuta el algoritmo y llena la traza. Espera una llave para el PRNG y los argumentos de la funciónmodel
.print_summary()
: Retorna una tabla con los momentos estadísticos de los parámetros y algunos diagnósticos.get_sample()
: Retorna la traza completa, es decir las muestras del posterior.
El argumento sampler
de MCMC espera una instancia de MCMCKernel
Ver también
Revisé la documentación para verificar los algoritmos de propuestas que están implementados actualmente.
En este caso utilizaremos No-U Turn (NUTS), el cual es el estado del arte para el caso de parámetros continuos.
key, key_ = random.split(key)
sampler = numpyro.infer.MCMC(sampler=numpyro.infer.NUTS(model),
num_samples=1000, num_warmup=100, thinning=1,
num_chains=2, progress_bar=True, jit_model_args=True)
sampler.run(key_, x, y)
Podemos inspeccionar el resultado utilizando:
sampler.print_summary()
mean std median 5.0% 95.0% n_eff r_hat
theta[0] 0.34 0.62 0.29 -0.71 1.35 1135.22 1.00
theta[1] -0.69 0.70 -0.68 -1.84 0.43 1601.65 1.00
theta[2] 2.70 0.96 2.60 1.13 4.18 1339.10 1.00
Number of divergences: 0
El estadístico de Gelman Rubin (\(\hat r\)) compara la varianza entre múltiples cadenas contra la varianza dentro de una misma cadena. Un valor cercano a uno significa que las cadenas llegaron a una distribución estacionaria muy similar.
La tabla también muestra que:
No hubieron divergencias durante el muestreo.
El número efectivo de muestras es mayor al 50% (1000 en este caso)
Todo lo anterior es signo de buena convergencia. Inspeccionemos ahora las trazas.
posterior_samples = sampler.get_samples()
posterior_samples.keys()
dict_keys(['p', 'theta'])
trace_plots = []
dist_plots = []
for d in range(3):
trace_plots.append(hv.Curve((posterior_samples['theta'][:, d]),
'Step', 'Theta', label=f'theta {d}').opts(alpha=0.75))
dist_plots.append(hv.Distribution(posterior_samples['theta'][:, d], 'Theta').opts(width=200))
hv.Overlay(trace_plots).opts(legend_position='top') << hv.Overlay(dist_plots)
Si la cadena de Markov ha llegado a su distribución estacionaria la cadena debería verse como ruido aleatorio en torno a un cierto valor, lo cual se cumple en este caso.
Si la cadena mostrara una tendencia (creciente o decreciente) significa que se requiere más pasos.
Si la cadena tiene saltos discontinuos (divergencias) el modelo podría tener errores en su especificación.
Para hacer una inspección cuantitativa de las trazas se pueden verificiar su función de autocorrelación. La autocorrelación de una cadena de Markov que ha convergido debería ir rápidamente a cero.
Ver también
En el siguiente link se muestra como interpretar el resultado de la autocorrelación de una traza en MCMC: https://phuijse.github.io/MonteCarloBook/lectures/mcmc2.html
Extraer muestras de la distribución posterior predictiva
En la práctica, para nuestro regresor logístico, lo más interesante no es el posterior de los parámetros \(p(\theta|\mathcal{D})\) sino el posterior de las predicciones, es decir la distribución de y para una nueva observación x dado que conocemos el conjunto de entrenamiento:
donde lo más difícil de obtener es justamente \(p(\theta| \mathcal{D})\), que afortunadamente ya tenemos gracias al MCMC.
Para estimar un posterior predictivo en numpyro utilizamos:
predictive = numpyro.infer.Predictive(model, posterior_samples=posterior_samples)
El cual retorna un objeto que podemos evalaur en nuevos datos:
x_test = jnp.array([[-3, 1], [3, -1], [0.5, 0.5]])
posterior_predictive_samples = predictive(random.PRNGKey(1), x_test)
display(posterior_predictive_samples.keys(),
posterior_predictive_samples['y'].shape)
dict_keys(['p', 'y'])
(2000, 3)
Nota
Debido a que la traza tiene 2000 muestras del posterior, entonces obtendremos 2000 valores para las predicciones de cada elemento en x_test
.
Con la distribución de las predicciones podemos obtener la clase más probable y también su incertidumbre:
for i, sample in enumerate(posterior_predictive_samples['y'].T):
print(f"Example {x_test[i]}, N0: {len(sample) - sum(sample)}, N1: {sum(sample)}, Mean: {jnp.mean(sample):0.4f} Std: {jnp.std(sample):0.4f}")
Example [-3. 1.], N0: 78, N1: 1922, Mean: 0.9610 Std: 0.1936
Example [ 3. -1.], N0: 1890, N1: 110, Mean: 0.0550 Std: 0.2280
Example [0.5 0.5], N0: 454, N1: 1546, Mean: 0.7730 Std: 0.4189
Realizemos una predicción en un rango mayor:
x_test = jnp.linspace(-3, 3, num=100)
X_test1, X_test2 = jnp.meshgrid(x_test, x_test)
X_test = jnp.vstack((X_test1.ravel(), X_test2.ravel())).T
posterior_predictive_samples = predictive(random.PRNGKey(1), X_test)
posterior_predictive_samples['y'].shape
(2000, 10000)
Así es como se ve una de las predicciones anteriores. Los puntos verdes corresponden al set de entrenamiento.
p1 = hv.Scatter((x[y==0, 0], x[y==0, 1]), label='y==0').opts(size=10, color='g', marker='d')
p2 = hv.Scatter((x[y==1, 0], x[y==1, 1]), label='y==1').opts(size=10, color='g', marker='o')
pred = hv.Image((x_test, x_test, posterior_predictive_samples['y'][0].reshape(X_test1.shape)),
kdims=['x[:, 0]', 'x[:, 1]']).opts(cmap='RdBu', colorbar=True)
hv.Overlay([pred, p1, p2]).opts(legend_position='top')
Esta corresponde a la moda del posterior predictivo:
pred = hv.Image((x_test, x_test, (jnp.sum(posterior_predictive_samples['y'], axis=0) > 1000).reshape(X_test1.shape)),
kdims=['x[:, 0]', 'x[:, 1]']).opts(cmap='RdBu', colorbar=True)
hv.Overlay([pred, p1, p2]).opts(legend_position='top')
La predicción promedio:
pred = hv.Image((x_test, x_test, jnp.mean(posterior_predictive_samples['y'], axis=0).reshape(X_test1.shape)),
kdims=['x[:, 0]', 'x[:, 1]']).opts(cmap='RdBu', colorbar=True)
hv.Overlay([pred, p1, p2]).opts(legend_position='top')
Y la desviación estándar de las predicciones:
pred = hv.Image((x_test, x_test, jnp.std(posterior_predictive_samples['y'], axis=0).reshape(X_test1.shape)),
kdims=['x[:, 0]', 'x[:, 1]']).opts(cmap='Reds', colorbar=True)
hv.Overlay([pred, p1, p2]).opts(legend_position='top')
22.2. Mezcla de Gaussianas Bayesiana¶
Ver también
Este ejemplo fue implementado originalmente en PyMC3: http://nbviewer.jupyter.org/github/CamDavidsonPilon/Probabilistic-Programming-and-Bayesian-Methods-for-Hackers/blob/master/Chapter3_MCMC/Ch3_IntroMCMC_PyMC3.ipynb
Los datos que utilizaremos en este ejemplo se generan a continuación:
mu_true = jnp.array([-3., 2.])
std_true = jnp.array([2., 0.75])
pi_true = jnp.array([0.4, 0.6])
N = 200
components = []
for mu, std, pi in zip(mu_true, std_true, pi_true):
key, key_ = random.split(key)
components.append(dists.Normal(mu, std).sample(key_, sample_shape=(int(pi*N),)))
data = jnp.concatenate(components)
x_plot = jnp.linspace(jnp.amin(data)*1.1, jnp.amax(data)*1.1, num=1000)
bins, edges = jnp.histogram(data, bins=20, density=True)
p = [hv.Histogram((edges, bins)).opts(alpha=0.5)]
for mu, std, pi in zip(mu_true, std_true, pi_true):
p.append(hv.Curve((x_plot, pi*jnp.exp(dists.Normal(mu, std).log_prob(x_plot)))))
hv.Overlay(p)
A continuación se muestra como implementar un modelo de mezcla de Gaussianas en NumPyro. Cabe destacar que en este caso tenemos una combinación de variables continuas y discretas.
Nota
Este ejemplo requiere instalar adicionalmente la librería funsor
def gmm_model(x):
pi = numpyro.sample('pis', dists.Dirichlet(np.array([0.5, 0.5])))
centers = numpyro.sample("mus", dists.Normal(loc=np.array([-1, 1]),
scale=np.array([10, 10])).to_event(1))
sds = numpyro.sample("sds", dists.HalfCauchy(scale=np.array([5, 5])).to_event(1))
with numpyro.plate('data', size=len(x), dim=-1):
z = numpyro.sample('z', dists.Categorical(pi))
numpyro.sample("obs", dists.Normal(loc=centers[z], scale=sds[z]), obs=x)
key, key_ = random.split(key)
seeded_model = numpyro.handlers.seed(gmm_model, key_)
exec_trace = numpyro.handlers.trace(seeded_model).get_trace(data)
print(numpyro.util.format_shapes(exec_trace))
Trace Shapes:
Param Sites:
Sample Sites:
pis dist | 2
value | 2
mus dist | 2
value | 2
sds dist | 2
value | 2
data plate 200 |
z dist 200 |
value 200 |
obs dist 200 |
value 200 |
Muestreo del posterior con MCMC:
key, key_ = random.split(key)
sampler = numpyro.infer.MCMC(sampler=numpyro.infer.NUTS(gmm_model),
num_samples=1000, num_warmup=100, thinning=1,
num_chains=1, progress_bar=True,
jit_model_args=True)
# Debido al problema de label switching sólo se usa una cadena:
with numpyro.validation_enabled():
sampler.run(key_, data)
/tmp/ipykernel_19368/651341955.py:10: FutureWarning: Some algorithms will automatically enumerate the discrete latent site z of your model. In the future, enumerated sites need to be marked with `infer={'enumerate': 'parallel'}`.
sampler.run(key_, data)
sample: 100%|█| 1100/1100 [00:06<00:00, 173.69i
Diagnósticos:
sampler.print_summary()
mean std median 5.0% 95.0% n_eff r_hat
mus[0] 1.85 0.11 1.86 1.68 2.03 348.28 1.00
mus[1] -3.34 0.30 -3.38 -3.81 -2.86 426.40 1.00
pis[0] 0.66 0.04 0.66 0.58 0.72 410.48 1.00
pis[1] 0.34 0.04 0.34 0.28 0.42 410.48 1.00
sds[0] 0.95 0.10 0.94 0.79 1.10 324.60 1.00
sds[1] 1.57 0.25 1.54 1.18 2.00 337.61 1.00
Number of divergences: 0
posterior_samples = sampler.get_samples()
posterior_samples.keys()
dict_keys(['mus', 'pis', 'sds'])
p = []
for param in posterior_samples.keys():
plot_traces, plot_dists = [], []
for k in range(2):
plot_traces.append(hv.Curve((posterior_samples[param][:, k]), 'Steps', param).opts(height=200))
plot_dists.append(hv.Distribution(posterior_samples[param][:, k], param).opts(width=200))
p.append(hv.Overlay(plot_traces).opts(legend_position='top') << hv.Overlay(plot_dists))
hv.Layout(p).cols(1)
Visualización de las primeras 100 muestras del posterior:
p = []
for mu, pi, sds in zip(posterior_samples['mus'][:100, :],
posterior_samples['pis'][:100, :],
posterior_samples['sds'][:100, :]):
pdf = [pi[k]*jnp.exp(dists.Normal(loc=mu[k], scale=sds[k]).log_prob(x_plot)) for k in range(2)]
p.append(hv.Curve((x_plot, jnp.sum(jnp.stack(pdf), axis=0))).opts(color='k', alpha=0.05))
hv.Overlay(p)
Visualización del posterior predictivo de la variable latente \(z\):
predictive = numpyro.infer.Predictive(gmm_model, posterior_samples=posterior_samples, infer_discrete=True)
x_test = np.linspace(-8, 4, num=100)
posterior_predictive_samples = predictive(random.PRNGKey(1), x_test)
display(posterior_predictive_samples.keys(),
posterior_predictive_samples['z'].shape)
dict_keys(['obs', 'z'])
(1000, 100)
z_mean = jnp.mean(posterior_predictive_samples['z'], axis=0)
z_err = jnp.std(posterior_predictive_samples['z'], axis=0)/np.sqrt(len(x_test))
hv.Curve((x_test, z_mean), 'data', 'z') * hv.Spread((x_test, z_mean, z_err))