Remembering previous partitions #263

Closed
zschutzman opened this issue Jan 11, 2019 · 6 comments

Comments

@zschutzman
Contributor

It would be useful for the chain to be able to "jump back" to a previous partition and continue walking, without the user having to hack together a solution. The immediate application is the restarting mechanism we use in the MCMC for optimization demo, where the proposal function is, in pseudocode,

remember the best thing you've seen so far
if num_steps % 150 == 0:
    next = best_so_far
else:
    next = proposal(current)

On the toy example we have there, doing 10 runs of 150 steps finds plans which are much more extreme (with respect to the definition of "best") than a single run of 1500 steps. We're interested in trying this on real data, so I've cooked up a solution: a Proposal class which remembers best_so_far and has a __call__ method that makes it look like a proposal() function. An instance of the class can be passed a single argument, the current partition, and it hides all the other bookkeeping from the chain. Here is a stub from my current implementation:

class RestartProposal:
    """
    A callable class wrapper around a proposal function which can restart from 
    a previously visited state.
    """
    def __init__(
        self, proposal, freq=1000, gap=False, criterion=None, compar=">"
    ):
        """
        :param proposal: A function handle; the underlying proposal function to perform the ordinary walk steps of the chain
        :param freq: An integer; the frequency with which the chain should go back to the saved state
        :param gap: A boolean; if gap == True, then restart only after making freq steps
            without saving a new state. If gap == False, restart every freq steps.
        :param criterion: A function handle; assigns a score to a partition
        :param compar: One of ">" or "<". If ">", saves a new assignment if
            criterion(new) > criterion(old); if "<", saves if criterion(new) < criterion(old)
        """

        ...
        ...
        
    def __call__(self, partition):
        """
        Callable class mimics a function from the perspective of the chain.
        First, evaluates the partition to see if it should be checkpointed.
        If it is time to restart, proposes from the state saved in
        self.checkpoint; otherwise proposes from the current partition.
        """
        self._eval_partition(partition)

        if self.counter == self.freq:
            # Time to restart: reset the counter and step from the saved state.
            self.counter = 0
            return self.proposal(self.checkpoint)
        else:
            self.counter += 1
            return self.proposal(partition)

The advantage here is that an instance of the class can be passed to the MarkovChain constructor as if it were a proposal function and the chain can run as usual. I'm wondering whether or not we think there's a better way to go about doing this. I'm imagining use cases where we might want to remember several different states, e.g. informing the next step based on some function of the previous k steps, and it would be good to have a level of abstraction here that lets us do that kind of thing.
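To make the stub above concrete, here is a minimal, self-contained sketch of how the elided pieces might be filled in. It is not the implementation from this issue: the class name `RestartProposalSketch`, the body of `_eval_partition`, and the way `gap` resets the counter are all assumptions, and plain integers stand in for partitions so that the example runs without GerryChain.

```python
import operator
import random


class RestartProposalSketch:
    """Hypothetical callable wrapper that periodically proposes from a saved state."""

    def __init__(self, proposal, freq=1000, gap=False, criterion=None, compar=">"):
        self.proposal = proposal
        self.freq = freq
        self.gap = gap
        self.criterion = criterion
        # Map the comparison string onto the matching operator.
        self.better = {">": operator.gt, "<": operator.lt}[compar]
        self.checkpoint = None
        self.counter = 0

    def _eval_partition(self, state):
        """Save the state if it beats the current checkpoint (assumed behavior)."""
        if self.checkpoint is None or self.better(
            self.criterion(state), self.criterion(self.checkpoint)
        ):
            self.checkpoint = state
            if self.gap:
                # In gap mode, the countdown restarts whenever a new best is saved.
                self.counter = 0

    def __call__(self, state):
        self._eval_partition(state)
        if self.counter >= self.freq:
            # Time to restart: step from the saved state instead of the current one.
            self.counter = 0
            return self.proposal(self.checkpoint)
        self.counter += 1
        return self.proposal(state)


# Toy usage: a random walk on the integers, where "best" means largest.
rng = random.Random(0)
walk = RestartProposalSketch(
    proposal=lambda x: x + rng.choice([-1, 1]),
    freq=150,
    criterion=lambda x: x,
)
state = 0
for _ in range(1500):
    state = walk(state)
```

Because the instance is just a callable taking the current state, it could be handed to anything that expects a proposal function; the caveat raised in the replies, that the proposal now carries hidden history, still applies.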

@maxhully
Contributor

On the theory side of things, in order for the MarkovChain to be a Markov chain, the proposal shouldn't depend on anything but the current state of the chain. So that makes me think that this would be better accomplished by substituting the MarkovChain function, or wrapping it with a new function/callable.

For the specific burst run use case, that approach might look like this:

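def burst_run(chain, number_of_bursts, criterion):
    # Track the best state seen so far together with its score.
    best_so_far = chain.initial_state
    best_score = criterion(best_so_far)
    for burst in range(number_of_bursts):
        for state in chain:
            score = criterion(state)
            # Compare scores to scores, not a score to a state.
            if score > best_score:
                best_so_far, best_score = state, score
            yield state
        yield best_so_far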
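A runnable variant of the same idea, as a rough sketch: each burst is a fresh chain started from the best state seen so far. The `make_chain` factory and the `toy_chain` random walk are hypothetical stand-ins introduced for illustration, not part of GerryChain; with a real chain, the factory would build a new chain whose initial state is the given one.

```python
import random


def toy_chain(initial_state, steps, rng):
    """Hypothetical stand-in for a chain: a simple random walk on the integers."""
    state = initial_state
    for _ in range(steps):
        yield state
        state += rng.choice([-1, 1])


def burst_run_sketch(make_chain, initial_state, number_of_bursts, criterion):
    """Run several short chains, each starting from the best state seen so far."""
    best_so_far = initial_state
    best_score = criterion(initial_state)
    for _ in range(number_of_bursts):
        # Each burst is a fresh chain seeded with the current best state.
        for state in make_chain(best_so_far):
            score = criterion(state)
            if score > best_score:
                best_so_far, best_score = state, score
            yield state


# Toy usage: 10 bursts of 150 steps, where "best" means largest.
rng = random.Random(0)
states = list(
    burst_run_sketch(
        make_chain=lambda start: toy_chain(start, steps=150, rng=rng),
        initial_state=0,
        number_of_bursts=10,
        criterion=lambda x: x,
    )
)
```

Keeping the restart logic in a wrapper like this leaves the proposal function itself dependent only on the current state.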

More generally, implementing some of these basic moves that users might want to make, beyond the basic Markov chain, seems like a great idea to me. Especially if we can make them straightforward to compose and combine.

@zschutzman
Contributor Author

From the theoretical side, the burst run proposal is a Markov chain on the state space (PLANS x PLANS x Z), since the next step is history-independent once I tell you where you are, what the thing you're holding onto is, and what value your counter is set to.

From the design side, the mechanics of a burst run don't feel too different from something like an M-H modification where you crank up the temperature if you think you might be stuck, or something that orients you in metagraph space so you favor walking away from places you've already been.

In both of these cases, it definitely doesn't make sense to hijack the proposal function, since both of these are better suited to modifications of the acceptance function. I'm not sure what the right design choice is here. It seems tough to strike the balance between making it so the chain can collect and access this information and not having so many little jigsaw pieces that need to be fit together in just the right way or nothing will work.

@maxhully
Contributor

One small change we could make that would help in this area would be to have the proposal function return the next Partition instead of the dictionary of flips. That would let you use whatever object you want as the states in the chain. For instance, it could be a tuple with the current plan, the plan you're holding onto, and the counter.
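As a sketch of what such a tuple-valued state could look like, the example below uses integers in place of plans. `BurstState` and `make_burst_proposal` are hypothetical names introduced here, not GerryChain API. The proposal returns the whole next state and reads nothing but the state it is given, which is the sense in which the walk is Markov on (PLANS x PLANS x Z).

```python
from typing import Callable, NamedTuple


class BurstState(NamedTuple):
    current: int  # stand-in for the current plan
    best: int     # the plan being held onto
    counter: int  # steps taken since the last restart


def make_burst_proposal(
    step: Callable[[int], int],
    criterion: Callable[[int], float],
    freq: int,
) -> Callable[[BurstState], BurstState]:
    """Build a proposal that maps one BurstState to the next."""

    def proposal(state: BurstState) -> BurstState:
        # Update the held plan if the current one scores higher.
        if criterion(state.current) > criterion(state.best):
            best = state.current
        else:
            best = state.best
        if state.counter == freq:
            # Restart: step from the held plan and reset the counter.
            return BurstState(step(best), best, 0)
        return BurstState(step(state.current), best, state.counter + 1)

    return proposal


# Toy usage with a deterministic downhill step, so the restart is easy to see.
proposal = make_burst_proposal(step=lambda x: x - 1, criterion=lambda x: x, freq=3)
state = BurstState(current=10, best=10, counter=0)
history = []
for _ in range(4):
    state = proposal(state)
    history.append(state)
```

All of the bookkeeping lives in the state itself, so the proposal stays a pure function and no callable with hidden memory is needed.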

@zschutzman
Contributor Author

I think that's a good idea, both in terms of making this work as well as modularity for future ideas.

Is the way to do this to move the state.merge call inside the proposal function and then return the resulting object to then be validated by the chain?

@maxhully
Contributor

I implemented that change in #266.

@pjrule
Contributor

pjrule commented Feb 7, 2023

This seems to be related in spirit to short bursts, which is implemented in #361.

@pjrule pjrule closed this as completed Feb 7, 2023