Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Analytic sampling for conditional posterior instances trained with MDNs. #458

Merged
merged 22 commits into from
Jul 19, 2021

Conversation

jnsbck
Copy link
Contributor

@jnsbck jnsbck commented Mar 29, 2021

Hey,

I added code to analytically sample and condition DirectPosterior instances trained on MDNs. Wrapping a DirectPosterior instance that has been trained using an MDN in this way, replaces the .log_prob() and .sample() methods with analytical ones. This means it should be compatible with the rest of sbi.

Example

After training the posterior...

inference = SNPE(prior, density_estimator='mdn')
density_estimator = inference.append_simulations().train()
posterior = inference.build_posterior(density_estimator)
posterior.set_default_x(x=x_o)

...it can be wrapped, ...

analytic_posterior = MDNPosterior(posterior)

...and conditioned.

condition = torch.tensor([[theta_1, theta_2, 'nan']])
cond_posterior = analytic_posterior.conditionalise(condition)

Variables of interest are represented by 'nan's.

Then cond_posterior can be sampled and evaluated as before.

cond_posterior.sample((1,))
cond_posterior.log_prob(theta)

I am sure this could somehow be integrated into the DirectPosterior class directly as well, since the MDNPosterior inherits from it.
Let me know of any issues, I'd be happy to help integrating it into sbi. :)

@michaeldeistler michaeldeistler self-assigned this Mar 29, 2021
Copy link
Contributor

@michaeldeistler michaeldeistler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot! I made a few preliminary comments. Nothing for you to do yet. Instead, we should re-write the MDN class in pyknos. This will avoid code repetition of sample and log_prob methods. I will take care of re-writing pyknos and will let you know. This might take a few days though.

sbi/utils/conditional_density.py Outdated Show resolved Hide resolved
sbi/utils/conditional_density.py Outdated Show resolved Hide resolved
sbi/utils/conditional_density.py Outdated Show resolved Hide resolved
sbi/utils/conditional_density.py Outdated Show resolved Hide resolved
sbi/utils/conditional_density.py Outdated Show resolved Hide resolved
@michaeldeistler
Copy link
Contributor

You can track progress here

Copy link
Contributor

@michaeldeistler michaeldeistler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okey, the PR is merged. We can now use .sample_mog() and .log_prob_mog() and replace basically the entire probability evaluations and sample functions by just calling these methods. Let me know in case I am missing something, and thanks again for taking over this endeavor!

sbi/utils/conditional_density.py Outdated Show resolved Hide resolved
sbi/utils/conditional_density.py Outdated Show resolved Hide resolved
sbi/utils/conditional_density.py Outdated Show resolved Hide resolved
sbi/utils/conditional_density.py Outdated Show resolved Hide resolved
@jnsbck
Copy link
Contributor Author

jnsbck commented Mar 30, 2021

So much for "This might take a few days though." 😄
Sure. I'll keep you posted on the progress.

log_factor = torch.log(self.leakage_correction(x=self.default_x))
return torch.log(torch.sum(pdf, axis=1)) - log_factor
self.net.eval() # leakage correction requires eval mode
log_factor = torch.log(self.leakage_correction(x=self.default_x))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the leakage correction does not work for the conditioned MDNs, because it uses samples from self i.e. the full posterior rather then from the new, conditioned posterior.

sbi/utils/conditional_density.py Outdated Show resolved Hide resolved
Copy link
Contributor Author

@jnsbck jnsbck left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

extract_and_transform_mog + __conditionalise can now be used to get new sets of mog logits, means etc. for an arbitrary condition which makes them compatible with sample_mog and log_prob_mog.

Next step: Introduce extract_and_transform_mog + __conditionalise into DirectPosterior.
Then insert them into sample_conditional together with sample_mog and sample_posterior_within_prior as well as into
log_prob_conditional together with log_prob_mog.

Copy link
Contributor Author

@jnsbck jnsbck left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have integrated everything into DirectPosterior as discussed and according to my tests it works.
Also... I just commited a bunch of files for review on accident. The only thing important for merging is the direct_posterior.py.

Looking forward to your remarks. Hope this round of review is the last :)

@codecov-commenter
Copy link

codecov-commenter commented Apr 18, 2021

Codecov Report

Merging #458 (f989443) into main (9ed18ca) will decrease coverage by 0.99%.
The diff coverage is 12.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #458      +/-   ##
==========================================
- Coverage   67.96%   66.96%   -1.00%     
==========================================
  Files          56       56              
  Lines        4117     4190      +73     
==========================================
+ Hits         2798     2806       +8     
- Misses       1319     1384      +65     
Flag Coverage Δ
unittests 66.96% <12.00%> (-1.00%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
sbi/utils/__init__.py 100.00% <ø> (ø)
sbi/utils/conditional_density.py 61.98% <8.16%> (-36.63%) ⬇️
sbi/inference/posteriors/direct_posterior.py 69.40% <19.23%> (-11.51%) ⬇️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 9ed18ca...f989443. Read the comment docs.

Copy link
Contributor Author

@jnsbck jnsbck left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code should be ready for final review.

Copy link
Contributor

@michaeldeistler michaeldeistler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's starting to look good, thanks! The main thing we still need are unit-tests. I suggest we write something akin to this. In your case, you would compare your samples from your conditional to samples from the ground truth via c2st

sbi/inference/posteriors/direct_posterior.py Outdated Show resolved Hide resolved
sbi/inference/posteriors/direct_posterior.py Outdated Show resolved Hide resolved
sbi/inference/posteriors/direct_posterior.py Outdated Show resolved Hide resolved
sbi/inference/posteriors/direct_posterior.py Show resolved Hide resolved
sbi/inference/posteriors/direct_posterior.py Outdated Show resolved Hide resolved
sbi/inference/posteriors/direct_posterior.py Outdated Show resolved Hide resolved
sbi/inference/posteriors/direct_posterior.py Outdated Show resolved Hide resolved
sbi/inference/posteriors/direct_posterior.py Outdated Show resolved Hide resolved
sbi/inference/posteriors/direct_posterior.py Outdated Show resolved Hide resolved
Copy link
Contributor Author

@jnsbck jnsbck left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, I finally got round to implementing everything. I hope it still plays nicely with the new version of sbi that has come out in the meantime. I have also added something resembling a unit test, but I am not sure how to integrate it properly. Help would be greatly appreciated. I have just plugged it into an extra file for now. Looking forward to your comments.

Best,

Jonas

Copy link
Contributor

@michaeldeistler michaeldeistler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot! Super minor things only, we can merge it after this!

tests/testing_analytic_mdn_conditioning.py Outdated Show resolved Hide resolved
tests/testing_analytic_mdn_conditioning.py Outdated Show resolved Hide resolved
sbi/utils/conditional_density.py Outdated Show resolved Hide resolved
sbi/inference/posteriors/direct_posterior.py Show resolved Hide resolved
sbi/inference/posteriors/direct_posterior.py Outdated Show resolved Hide resolved
@jnsbck
Copy link
Contributor Author

jnsbck commented Jul 6, 2021

I have just addressed your latest comments. Hope that it can be merged now. :)

Copy link
Contributor

@michaeldeistler michaeldeistler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is great and will be super useful, thanks a lot for all your efforts!

@michaeldeistler
Copy link
Contributor

Can be merged once merge conflicts are resolved.

@jnsbck
Copy link
Contributor Author

jnsbck commented Jul 18, 2021

I just resolved all conflicts, if it passes the tests, I guess you can merge it :)

@jnsbck jnsbck merged commit 00692fe into sbi-dev:main Jul 19, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants