Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hamiltonian Monte Carlo with Energy Conserving Subsampling #905

Merged
merged 100 commits into from
Feb 11, 2021
Merged
Changes from 1 commit
Commits
Show all changes
100 commits
Select commit Hold shift + click to select a range
4321595
start
LysSanzMoreta Sep 8, 2020
b8001a9
start hmcecs two
LysSanzMoreta Sep 8, 2020
26219ce
structuring
LysSanzMoreta Sep 14, 2020
9a326db
small fix
LysSanzMoreta Sep 14, 2020
44d2fc1
ADDED: verlet, new log density
LysSanzMoreta Sep 16, 2020
4eeb1f0
FIXED: initialization model parameters
LysSanzMoreta Sep 18, 2020
ca9dece
FIXED: Arguments potential function
LysSanzMoreta Sep 21, 2020
f01c027
FIXED: Arguments mess
LysSanzMoreta Sep 22, 2020
cc2f1b0
FIXED? shapes error
LysSanzMoreta Sep 25, 2020
c3de253
Sampling working
LysSanzMoreta Sep 28, 2020
b17d53d
Seems to be working
LysSanzMoreta Sep 28, 2020
40be6c3
Added: Plotting and save samples to example
LysSanzMoreta Sep 29, 2020
f58dbf7
ADDED: Assertion errors
LysSanzMoreta Sep 29, 2020
3a25523
working on more than 1 chain
LysSanzMoreta Sep 30, 2020
a075857
ADDED: more plotting
LysSanzMoreta Sep 30, 2020
b40f662
ADDED: More tests and proxies
LysSanzMoreta Sep 30, 2020
54bca12
Small state fix
LysSanzMoreta Oct 1, 2020
765b3d6
Fixed : Proxies and init
LysSanzMoreta Oct 6, 2020
1fc6740
Working examples
LysSanzMoreta Oct 7, 2020
3993afd
Maybe working
LysSanzMoreta Oct 19, 2020
9dd7025
Started adding Block-Poisson
LysSanzMoreta Nov 2, 2020
bdcd352
small stuff
LysSanzMoreta Nov 2, 2020
742162c
Started adding poisson stuff
LysSanzMoreta Nov 2, 2020
81b8956
Working on documentation and poisson
LysSanzMoreta Nov 4, 2020
ffbaa9e
Added: Poisson stuff (missing initialization)
LysSanzMoreta Nov 6, 2020
1e83bb7
BlockPoissonRunning
LysSanzMoreta Nov 9, 2020
3f918ed
FIXED: potential estimator
LysSanzMoreta Nov 10, 2020
9fc0f3f
FINISHED: Block-poisson
LysSanzMoreta Nov 10, 2020
24e1c1b
MISSING: Postprocessing
LysSanzMoreta Nov 12, 2020
4243abe
FIXED: sign
LysSanzMoreta Nov 13, 2020
f996d18
More debugging
LysSanzMoreta Nov 13, 2020
1095f19
Fixed style.
OlaRonning Dec 9, 2020
0c18026
HMCECS working, fixed problems with SVI MAP and factored code.
OlaRonning Dec 14, 2020
d04f651
Merge remote-tracking branch 'origin/hmcecs' into hmcecs
OlaRonning Dec 15, 2020
b8f8830
Added MNIST BNN example using flax.
OlaRonning Dec 15, 2020
65531c2
Working potential with algebraic effect handlers.
OlaRonning Jan 7, 2021
216c2cf
Potential estimator integrated with ECS class.
OlaRonning Jan 7, 2021
d6e6700
ECS wrapper working on toy example.
OlaRonning Jan 8, 2021
c59b317
cleaned code.
OlaRonning Jan 8, 2021
dafaa6e
renamed hmcecs_utils to ecs_utils and added todos.
OlaRonning Jan 8, 2021
243e7bc
debugging taylor expansion.
OlaRonning Jan 12, 2021
c4252bb
Updated comments with reference and added test for num_blocks={} (the…
OlaRonning Jan 13, 2021
20b7350
Added pystan
OlaRonning Jan 13, 2021
4d7e4ed
Added components for variational proxy.
OlaRonning Jan 15, 2021
f867af1
Added variational_proxy, todo: fix estimator.
OlaRonning Jan 15, 2021
89c8ffe
Integrated variational proxy into ecs.
OlaRonning Jan 16, 2021
1403dfe
checkpoint: before redoing estimator.
OlaRonning Jan 18, 2021
1c6af82
Variational proxy running!
OlaRonning Jan 18, 2021
2a8cc23
Fixed minor bugs and example of hmcecs with variational proxy on logi…
OlaRonning Jan 18, 2021
7c41cee
merging
OlaRonning Jan 26, 2021
6f1f222
merged upstream
OlaRonning Jan 26, 2021
dd0426c
Refactored taylor_estimator into taylor_proxy and a difference estima…
OlaRonning Jan 26, 2021
60e0912
Sketched variational proxy in hmc_gibbs.
OlaRonning Jan 27, 2021
85957dc
Variational proxy running.
OlaRonning Jan 27, 2021
c01738f
Examples.
OlaRonning Jan 29, 2021
2151895
Moved estimate_likelihood
OlaRonning Jan 29, 2021
82ef761
Merge remote-tracking branch 'origin/feature/ecs' into feature/ecs
OlaRonning Jan 29, 2021
e46cb40
Added two moons
OlaRonning Jan 31, 2021
5b8af6e
Merge remote-tracking branch 'origin/feature/ecs' into feature/ecs
OlaRonning Jan 31, 2021
41bda0c
add gibbs_state and fix bugs
fehiepsi Feb 1, 2021
ab7888e
Integrated taylor proxy and updated API.
OlaRonning Feb 1, 2021
a9d2c0e
Bugs fixed and taylor working!
OlaRonning Feb 1, 2021
e4bf263
Updated variational proxy to new API.
OlaRonning Feb 2, 2021
508e96a
Variational proxy running on breast cancer!
OlaRonning Feb 2, 2021
3761905
Working regression
OlaRonning Feb 2, 2021
a695046
Fixed problems in variational; todo rethink dummy_sample ([] doesn't …
OlaRonning Feb 2, 2021
896cd19
add covtype example
fehiepsi Feb 2, 2021
4e9a192
Merge remote-tracking branch 'origin/feature/ecs' into feature/ecs
OlaRonning Feb 2, 2021
f5e8894
fix some bugs to substitute empty subsample indices and add some FIXME
fehiepsi Feb 2, 2021
b571228
FIXED ELBO computation and changed the weight scheme in variational p…
OlaRonning Feb 3, 2021
7d9cd11
fixed proxy_sum and added equations.
OlaRonning Feb 3, 2021
5dbac85
VECS working with AutoNormal on BreastCancer.
OlaRonning Feb 3, 2021
bb783f8
Using Likelihood as weight.
OlaRonning Feb 3, 2021
7644a08
factored out VECS
OlaRonning Feb 4, 2021
a997acb
Added simple test case.
OlaRonning Feb 5, 2021
ad2f799
Merge branch 'master' of github.com:pyro-ppl/numpyro into feature/ecs
OlaRonning Feb 5, 2021
fb95035
Cleaned.
OlaRonning Feb 5, 2021
e1150ea
Removed old HMCECS logistic examples.
OlaRonning Feb 5, 2021
c0a1c4c
removed old autoguide
OlaRonning Feb 5, 2021
5583162
Fixed linting.
OlaRonning Feb 5, 2021
8e08d76
Merge branch 'feature/ecs' of github.com:aleatory-science/numpyro int…
OlaRonning Feb 5, 2021
48cd7ef
fixed lint.
OlaRonning Feb 5, 2021
ff28eb0
Remove Poisson, factored out pandas for loading HIGGs dataset, added …
OlaRonning Feb 7, 2021
5febf7e
Fixed _block_update refactor. Missing new test cases, 2 more TODOs.
OlaRonning Feb 7, 2021
2c26173
fixed isort
OlaRonning Feb 7, 2021
6359292
Fixed comments, some 3 TODOs left.
OlaRonning Feb 8, 2021
0b886d7
Conditioned gradient computation and moved to unconstraint sapce for …
OlaRonning Feb 8, 2021
469a1f2
Fixed test for HMCECS and bumped jaxlib version.
OlaRonning Feb 8, 2021
0fb8d01
Fixed test.
OlaRonning Feb 8, 2021
5a6f629
Fixed lint.
OlaRonning Feb 8, 2021
29f3708
Corrected taylor_proxy works in unconstraint space. Added docstring a…
OlaRonning Feb 9, 2021
2aef856
Flipped syntax for geq in setup.py
OlaRonning Feb 9, 2021
cc2669e
Made default device for covtype example cpu.
OlaRonning Feb 9, 2021
63323ec
Added taylor proxy test.
OlaRonning Feb 10, 2021
89a99a2
Added test for variance.
OlaRonning Feb 10, 2021
0131f3e
Fixed lint.
OlaRonning Feb 10, 2021
c697e91
Added all log_density computation to test_estimate_likelihood and ass…
OlaRonning Feb 10, 2021
400389f
Fixed typo and isort.
OlaRonning Feb 10, 2021
2e1dddb
isort not included in previous commit.
OlaRonning Feb 10, 2021
6474df3
Fixed shadowing log_prob.
OlaRonning Feb 11, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Bugs fixed and taylor working!
OlaRonning committed Feb 1, 2021
commit a9d2c0eca97b00019ac5e4feecc30e4eafc4f057
17 changes: 8 additions & 9 deletions examples/hmcecs/logistic_regression.py
Original file line number Diff line number Diff line change
@@ -78,7 +78,7 @@ def guide(feature, obs, subsample_size):
numpyro.sample('theta', dist.continuous.Normal(mean, .5))


def hmcecs_model( dataset, data, obs, subsample_size, proxy_name='taylor'):
def hmcecs_model(dataset, data, obs, subsample_size, proxy_name='taylor'):
model_args, model_kwargs = (data, obs, subsample_size), {}

svi_key, proxy_key, estimator_key, mcmc_key = random.split(random.PRNGKey(0), 4)
@@ -95,15 +95,14 @@ def hmcecs_model( dataset, data, obs, subsample_size, proxy_name='taylor'):
proxy_key, ref_key = random.split(proxy_key)
ref_params = _predictive(ref_key, guide, {}, (1,), return_sites='', parallel=True,
model_args=model_args, model_kwargs=model_kwargs)
proxy_fn = taylor_proxy(proxy_key, model, model_args, model_kwargs, ref_params)
ref_params.pop('mean')
proxy_fn = taylor_proxy(ref_params)

else:
proxy_fn = variational_proxy(proxy_key, model, model_args, model_kwargs, guide, params)
estimator = perturbed_method(estimator_key, model, model_args, model_kwargs, proxy_fn)
proxy_fn = variational_proxy(guide, params)

# Compute HMCECS

kernel = HMCECS(NUTS(model), proxy=estimator)
kernel = HMCECS(NUTS(model), proxy=proxy_fn)
mcmc = MCMC(kernel, 1000, 1000)
start = time()
mcmc.run(random.PRNGKey(3), data, obs, subsample_size, extra_fields=("hmc_state.accept_prob",
@@ -128,7 +127,7 @@ def hmc(dataset, data, obs):

if __name__ == '__main__':

load_data = {'higgs': higgs_data, 'breast': breast_cancer_data, 'copsac': copsac_data}
load_data = {'breast': breast_cancer_data, 'higgs': higgs_data, 'copsac': copsac_data}
subsample_sizes = {'higgs': 1300, 'copsac': 1000, 'breast': 75, }
data, obs = breast_cancer_data()

@@ -137,6 +136,6 @@ def hmc(dataset, data, obs):
if not os.path.exists(dir):
os.mkdir(dir)
data, obs = load_data[dataset]()
# hmcecs_model(dir, data, obs, subsample_sizes[dataset])
hmc(dir, data, obs)
hmcecs_model(dir, data, obs, subsample_sizes[dataset])
# hmc(dir, data, obs)
exit()
31 changes: 18 additions & 13 deletions numpyro/infer/hmc_gibbs.py
Original file line number Diff line number Diff line change
@@ -491,7 +491,9 @@ class HMCECS(HMCGibbs):

def __init__(self, inner_kernel, *, num_blocks=1, proxy=None, method='perturbed'):
super().__init__(inner_kernel, lambda *args: None, None)
assert method in ['perturbed']

assert method in {'perturbed'}
self.inner_kernel._model = _wrap_gibbs_state(self.inner_kernel._model)
self._num_blocks = num_blocks
self._proxy = proxy
self._method = method
@@ -519,11 +521,11 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
self._gibbs_sites = list(self._subsample_plate_sizes.keys())
if self._proxy is not None:
rng_key, proxy_key, method_key = random.split(rng_key, 3)
proxy_fn, gibbs_init, gibbs_update = self._proxy(rng_key,
self.model,
model_args,
model_kwargs,
num_blocks=self._num_blocks)
proxy_fn, gibbs_init, self._gibbs_update = self._proxy(rng_key,
self.model,
model_args,
model_kwargs,
num_blocks=self._num_blocks)
method = perturbed_method(method_key, self.model, model_args, model_kwargs, proxy_fn)
self.inner_kernel._model = estimate_likelihood(self.inner_kernel._model, method)

@@ -542,9 +544,9 @@ def sample(self, state, model_args, model_kwargs):
model_kwargs = {} if model_kwargs is None else model_kwargs.copy()
rng_key, rng_gibbs = random.split(state.rng_key)

def potential_fn(z_gibbs, z_hmc):
def potential_fn(z_gibbs, gibbs_state, z_hmc):
return self.inner_kernel._potential_fn_gen(
*model_args, _gibbs_sites=z_gibbs, **model_kwargs)(z_hmc)
*model_args, _gibbs_sites=z_gibbs, _gibbs_state=gibbs_state, **model_kwargs)(z_hmc)

z_gibbs = {k: v for k, v in state.z.items() if k not in state.hmc_state.z}
z_gibbs_new, gibbs_state_new = self._gibbs_update(rng_key, z_gibbs, state.gibbs_state)
@@ -559,9 +561,9 @@ def potential_fn(z_gibbs, z_hmc):

# TODO (very low priority): move this to the above cond, only compute grad when accepting
if self.inner_kernel._forward_mode_differentiation:
z_grad = jacfwd(partial(potential_fn, z_gibbs))(state.hmc_state.z)
z_grad = jacfwd(partial(potential_fn, z_gibbs, gibbs_state))(state.hmc_state.z)
else:
z_grad = grad(partial(potential_fn, z_gibbs))(state.hmc_state.z)
z_grad = grad(partial(potential_fn, z_gibbs, gibbs_state))(state.hmc_state.z)
hmc_state = state.hmc_state._replace(z_grad=z_grad, potential_energy=pe)

model_kwargs["_gibbs_sites"] = z_gibbs
@@ -585,7 +587,7 @@ def perturbed_method(rng_key, model, model_args, model_kwargs, proxy_fn):

def estimator(likelihoods, params, gibbs_state):
subsample_log_liks = defaultdict(float)
for (fn, value, name, subsample_dim, subsample_idx) in likelihoods.values():
for (fn, value, name, subsample_dim) in likelihoods.values():
subsample_log_liks[name] += _sum_all_except_at_dim(fn.log_prob(value), subsample_dim)

log_lik_sum = 0.
@@ -629,13 +631,16 @@ def log_likelihood(params_flat, subsample_indices=None):
substitute(substitute_fn=partial(_unconstrain_reparam, params)):
model(*model_args, **model_kwargs)

log_lik = defaultdict(float)
log_lik = {}
for site in tr.values():
if site["type"] == "sample" and site["is_observed"]:
for frame in site["cond_indep_stack"]:
if frame.name in subsample_plate_sizes:
if frame.name in log_lik:
log_lik[frame.name] += _sum_all_except_at_dim(
site["fn"].log_prob(site["value"]), frame.dim)
else:
log_lik[frame.name] = _sum_all_except_at_dim(
site["fn"].log_prob(site["value"]), frame.dim)
return log_lik

def log_likelihood_sum(params_flat, subsample_indices=None):