Skip to content

Commit

Permalink
add trainer back after save (#1416)
Browse files Browse the repository at this point in the history
Co-authored-by: Ziqin Xiong <88057852+haha@users.noreply.github.com>
  • Loading branch information
Ziqin Xiong and ziqin8 authored Sep 11, 2023
1 parent 1c6094d commit 56d9910
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion neuralprophet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,17 @@ def save(forecaster, path: str):
>>> save(forecaster, "test_save_model.np")
"""
# Remove the Lightning trainer since it does not serialise correcly with torch.save
for attr in ["trainer"]:
attrs_to_remove = ["trainer"]
removed_attrs = {}
for attr in attrs_to_remove:
removed_attrs[attr] = getattr(forecaster, attr)
setattr(forecaster, attr, None)
torch.save(forecaster, path)

# Restore the Lightning trainer
for attr in attrs_to_remove:
setattr(forecaster, attr, removed_attrs[attr])


def load(path: str):
"""retrieve a fitted model from a .np file that was saved by save.
Expand Down

0 comments on commit 56d9910

Please sign in to comment.