Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow (auto) saving and loading a RunManager (e.g., for external monitoring) #156

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
69 changes: 63 additions & 6 deletions adaptive_scheduler/_server_support/run_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable

import cloudpickle
import pandas as pd

from adaptive_scheduler.utils import (
Expand All @@ -17,6 +18,7 @@
_at_least_adaptive_version,
_remove_or_move_files,
_time_between,
fname_to_learner,
fname_to_learner_fname,
load_dataframes,
load_parallel,
Expand Down Expand Up @@ -159,7 +161,7 @@ class RunManager(BaseManager):

"""

def __init__(
def __init__( # noqa: PLR0915
self,
scheduler: BaseScheduler,
learners: list[adaptive.BaseLearner],
Expand Down Expand Up @@ -187,6 +189,7 @@ def __init__(
max_log_lines: int = 500,
max_fails_per_job: int = 50,
max_simultaneous_jobs: int = 100,
store_fname: str | Path | None = None,
initializers: list[Callable[[], None]] | None = None,
) -> None:
super().__init__()
Expand All @@ -213,6 +216,7 @@ def __init__(
self.max_log_lines = max_log_lines
self.max_fails_per_job = max_fails_per_job
self.max_simultaneous_jobs = max_simultaneous_jobs
self.store_fname = store_fname
self.initializers = initializers
# Track job start times, (job_name, start_time) -> request_time
self._job_start_time_dict: dict[tuple[str, str], str] = {}
Expand All @@ -232,11 +236,7 @@ def __init__(
self.start_time: float | None = None
self.end_time: float | None = None
self._start_one_by_one_task: (
tuple[
asyncio.Future,
list[asyncio.Task],
]
| None
tuple[asyncio.Future, list[asyncio.Task]] | None
) = None

# Set on init
Expand Down Expand Up @@ -318,6 +318,8 @@ def start(self, wait_for: RunManager | None = None) -> RunManager: # type: igno
self._start_one_by_one_task = start_one_by_one(wait_for, self)
else:
super().start()
if self.store_fname is not None:
self.save()
return self

async def _manage(self) -> None:
Expand Down Expand Up @@ -485,6 +487,61 @@ def load_dataframes(self) -> pd.DataFrame:
raise ValueError(msg)
return load_dataframes(self.fnames, format=self.dataframe_format) # type: ignore[return-value]

def save(
    self,
    store_fname: str | Path | None = None,
    *,
    overwrite: bool = True,
) -> None:
    """Store the `RunManager` to a file.

    Parameters
    ----------
    store_fname : str or Path or None
        The filename to store the `RunManager` to. If None, use the
        `store_fname` attribute.
    overwrite : bool, default: True
        If True, overwrite the file if it already exists.

    Raises
    ------
    ValueError
        If no ``store_fname`` is given and the ``store_fname`` attribute
        is also None.
    FileExistsError
        If the file already exists and ``overwrite`` is False.

    """
    if store_fname is None:
        store_fname = self.store_fname
    if store_fname is None:
        msg = "No `store_fname` given and no `store_fname` attribute is set."
        raise ValueError(msg)
    store_fname = Path(store_fname)
    # Fail fast before doing any serialization work.
    if store_fname.exists() and not overwrite:
        msg = f"{store_fname} already exists."
        raise FileExistsError(msg)
    # Exclude attributes that are recreated on construction/start, or that
    # can be reconstructed from the stored filenames (the learners).
    keys = self.__dict__.keys() - {
        "ioloop",  # set in super().start()
        "task",  # set in super().start()
        "learners",  # we can load them from the filenames
        # below are created in __init__
        "job_names",
        "database_manager",
        "job_manager",
        "kill_manager",
    }
    # Private (underscore-prefixed) attributes hold runtime-only state.
    to_save = {k: self.__dict__[k] for k in keys if not k.startswith("_")}
    with store_fname.open("wb") as f:
        cloudpickle.dump(to_save, f)

@classmethod
def load(cls: type[RunManager], store_fname: str | Path) -> RunManager:
    """Load a `RunManager` from a file."""
    path = Path(store_fname)
    with path.open("rb") as f:
        state = cloudpickle.load(f)
    # The learners themselves were not pickled (see `save`); rebuild
    # them from their data filenames.
    state["learners"] = [fname_to_learner(fname) for fname in state["fnames"]]
    # Never wipe the existing database when resuming a stored run.
    state["overwrite_db"] = False
    # `start_time`/`end_time` are plain attributes, not constructor
    # arguments, so pop them out and restore them after construction.
    saved_start = state.pop("start_time")
    saved_end = state.pop("end_time")
    manager = cls(**state)
    manager.start_time = saved_start
    manager.end_time = saved_end
    return manager

def remove_existing_data(
self,
*,
Expand Down
20 changes: 20 additions & 0 deletions tests/test_run_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,26 @@ def test_run_manager_cleanup(
assert not rm.move_old_logs_to.exists()


@pytest.mark.asyncio()
async def test_run_manager_save_load(
    mock_scheduler: MockScheduler,
    learners: list[adaptive.Learner1D]
    | list[adaptive.BalancingLearner]
    | list[adaptive.SequenceLearner],
    fnames: list[str] | list[Path],
    tmp_path: Path,
) -> None:
    """Test saving a `RunManager` to disk and loading it back."""
    with temporary_working_directory(tmp_path):
        rm = RunManager(mock_scheduler, learners, fnames)
        rm.start()
        await asyncio.sleep(0.1)
        fn = "run_manager.cloudpickle"
        rm.save(fn)
        # Round-trip: the restored manager should reference the same
        # learner data files as the original.
        rm2 = RunManager.load(fn)
        assert rm2.fnames == rm.fnames


def test_run_manager_parse_log_files(
mock_scheduler: MockScheduler,
learners: list[adaptive.Learner1D]
Expand Down
Loading