
Allow multiple pattern matches in chain with PyTorchFileRecorder #1269

Merged · 1 commit · Feb 7, 2024

Conversation

@laggui (Member) commented on Feb 6, 2024

Pull Request Template

Checklist

  • Confirmed that the run-checks all script has been executed.

Changes

Changed the record remap function so that multiple patterns can be applied in sequence to the same key when using the PyTorchFileRecorder.

For example, I initially thought that providing multiple patterns like this would remap the keys on each match:

let load_args = LoadArgs::new("resnet18-f37072fd.pth".into())
    // Map *.downsample.0.* -> *.downsample.conv.*
    .with_key_remap("(.+)\\.downsample\\.0\\.(.+)", "$1.downsample.conv.$2")
    // Map *.downsample.1.* -> *.downsample.bn.*
    .with_key_remap("(.+)\\.downsample\\.1\\.(.+)", "$1.downsample.bn.$2")
    // Map layer[i].[j].* -> layer[i].blocks.[j].*
    .with_key_remap("(layer[1-4])\\.([0-9])\\.(.+)", "$1.blocks.$2.$3");

let record = PyTorchFileRecorder::<FullPrecisionSettings>::new()
    .load(load_args, &device)
    .map_err(|err| format!("Failed to load weights.\nError: {err}"))
    .unwrap();

but it does not; instead, remapping stops at the first pattern that matches, so only one rule is ever applied to a given key.

Stumbled upon this when trying to import the pre-trained ResNet-18 weights from torchvision for my implementation here.

In the residual blocks, the downsample layers sometimes require more than one portion of the key to be remapped, hence the need to chain patterns. For example, layer1.0.downsample.0.* should be remapped to layer1.blocks.0.downsample.conv.*, which decomposes into the patterns presented in the example above: one rule renames the downsample sub-module and another renames the block index.
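To illustrate the chained behavior this PR enables, here is a minimal, self-contained sketch (not Burn's actual implementation) where each rule is applied in sequence to the same key, using literal-substring stand-ins for the regex rules above. The `remap_chain` helper is hypothetical and exists only for this example:

```rust
// Hypothetical helper: apply every (from, to) rule in order, so several
// rules can each rewrite a different portion of the same key.
fn remap_chain(key: &str, rules: &[(&str, &str)]) -> String {
    rules
        .iter()
        .fold(key.to_string(), |k, (from, to)| k.replace(from, to))
}

fn main() {
    // Literal-substring stand-ins for the regex patterns in the PR description.
    let rules = [
        (".downsample.0.", ".downsample.conv."),
        (".downsample.1.", ".downsample.bn."),
        ("layer1.0.", "layer1.blocks.0."),
    ];

    let key = "layer1.0.downsample.0.weight";
    let remapped = remap_chain(key, &rules);

    // Both the first and third rules fire on the same key.
    assert_eq!(remapped, "layer1.blocks.0.downsample.conv.weight");
    println!("{key} -> {remapped}");
}
```

With the pre-PR "first match wins" behavior, only the first rule would fire, yielding `layer1.0.downsample.conv.weight` and leaving the block index unrenamed.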

@antimora (Collaborator) left a comment:

LGTM

@antimora antimora merged commit d2bdc46 into main Feb 7, 2024
13 checks passed
@antimora antimora deleted the fix/pytorch-remap branch February 7, 2024 02:09
2 participants