Merged

Commits (32)
167dd2b - Modify execute_DAG with unitresult caching (IAlibay, Feb 16, 2026)
618a384 - [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 16, 2026)
6d5a652 - Add parents (IAlibay, Feb 16, 2026)
edf8b6f - Merge branch 'restart_execute' of github.com:OpenFreeEnergy/gufe into… (IAlibay, Feb 16, 2026)
6b2916c - Add fix (IAlibay, Feb 16, 2026)
6c05ac5 - Merge branch 'main' into restart_execute (atravitz, Mar 5, 2026)
16d472f - add test (atravitz, Mar 6, 2026)
f243bff - mock up a function to test dependency logic (atravitz, Mar 10, 2026)
2854d73 - update tests (atravitz, Mar 10, 2026)
8fd14f5 - cleaning up, adding comments (atravitz, Mar 11, 2026)
4fbb80b - back to constructing a dict now that I trust my logic (atravitz, Mar 11, 2026)
e4722e5 - make private (atravitz, Mar 11, 2026)
b94e329 - fix key check in results (atravitz, Mar 11, 2026)
ba337a1 - add additional check (atravitz, Mar 11, 2026)
f4c09cc - corrupt instead of delete test file to cover JSONError catching (atravitz, Mar 11, 2026)
e925018 - add news item (atravitz, Mar 11, 2026)
b7a49f6 - Apply suggestions from code review (atravitz, Mar 12, 2026)
c1114ab - add suggestions from mmh (atravitz, Mar 12, 2026)
37499bc - Merge branch 'main' of github.com:OpenFreeEnergy/gufe into restart_ex… (atravitz, Mar 13, 2026)
03af9d6 - Merge branch 'main' into restart_execute (atravitz, Mar 19, 2026)
2a2e27e - rename to meet the line limit (atravitz, Mar 19, 2026)
81ee236 - rename args for clarity (atravitz, Mar 20, 2026)
8a01796 - store by pu key instead of pur key (atravitz, Mar 20, 2026)
e9d454d - make names clearer (atravitz, Mar 20, 2026)
543981b - clean up redundant vars (atravitz, Mar 20, 2026)
d8cffd1 - catch warning in test (atravitz, Mar 20, 2026)
d38d25a - unit -> pu (atravitz, Mar 20, 2026)
bed6a68 - fail early (atravitz, Mar 20, 2026)
f83001d - cache_unitresults -> keep_cache (atravitz, Mar 20, 2026)
2627b08 - use more specific error and add to docstring (atravitz, Mar 20, 2026)
45e05ed - mypy fix (atravitz, Mar 20, 2026)
5982414 - update news item (atravitz, Mar 20, 2026)
77 changes: 68 additions & 9 deletions gufe/protocols/protocoldag.py
@@ -2,7 +2,10 @@
# For details, see https://github.com/OpenFreeEnergy/gufe

import shutil
import warnings
from collections import defaultdict
from collections.abc import Iterable
from json import JSONDecodeError
from pathlib import Path
from typing import Any

@@ -371,15 +374,38 @@ def _from_dict(cls, dct: dict):
        return cls(**dct)


def _get_valid_unit_results(
    protocoldag: ProtocolDAG, unit_results: Iterable[ProtocolUnitResult]
) -> dict[GufeKey, ProtocolUnitResult]:
    """Given a ProtocolDAG and a set of unit_results, determine which
    protocol_units of the DAG can be skipped during execution."""

    # NOTE: keying on ``source_key`` assumes gufe keys are stable; results
    # produced under a different gufe version may not be safe to resume from
    result_pu_key_to_pur: dict[GufeKey, ProtocolUnitResult] = {ur.source_key: ur for ur in unit_results}

    for unit in protocoldag.protocol_units:  # protocol_units is in DAG-dependency-order
        if unit.key not in result_pu_key_to_pur:  # this unit must be run (or re-run)
            # everything downstream of a unit that must be run also needs to be
            # re-run, so invalidate any cached results for those units
            for downstream_unit in nx.ancestors(protocoldag.graph, unit):
                result_pu_key_to_pur.pop(downstream_unit.key, None)

    return result_pu_key_to_pur


def execute_DAG(
    protocoldag: ProtocolDAG,
    *,
    shared_basedir: Path,
    scratch_basedir: Path,
    unitresults_basedir: Path | None = None,
    stderr_basedir: Path | None = None,
    stdout_basedir: Path | None = None,
    keep_shared: bool = False,
    keep_scratch: bool = False,
    keep_unitresults: bool = False,
    raise_error: bool = True,
    n_retries: int = 0,
) -> ProtocolDAGResult:
@@ -396,6 +422,11 @@
        :class:`ProtocolUnit` instances.
    scratch_basedir : Path
        Filesystem path to use for `ProtocolUnit` `scratch` space.
    unitresults_basedir : Path | None
        Filesystem path to use for `ProtocolUnitResult` archiving during
        execution. If ``None`` (default), results will not be archived,
        and it will not be possible to resume DAG execution from the last
        successfully finished `ProtocolUnit`.
    stderr_basedir : Path | None
        Filesystem path to use for `ProtocolUnit` `stderr` archiving.
    stdout_basedir : Path | None
@@ -406,6 +437,9 @@
    keep_scratch : bool
        If True, don't remove scratch directories for a `ProtocolUnit` after
        it is executed.
    keep_unitresults : bool
        If True, don't remove the unitresults directory, which contains
        the serialized `ProtocolUnitResult` for each executed `ProtocolUnit`.
    raise_error : bool
        If True (default), raise an exception if a `ProtocolUnit` fails;
        if False, any exceptions will be stored as `ProtocolUnitFailure`
@@ -422,33 +456,50 @@
    if n_retries < 0:
        raise ValueError("Must give positive number of retries")

    all_cached_results: list[ProtocolUnitResult] = []  # store all unitresults found in the cache
    if unitresults_basedir is not None:
        unitresults_path = unitresults_basedir / f"unitresults_{str(protocoldag.key)}"
        unitresults_path.mkdir(exist_ok=True, parents=True)

        for file in unitresults_path.rglob("*.json"):
            try:
                unit_result = ProtocolUnitResult.from_json(file)
                # TODO: any additional criteria to check here?
            except JSONDecodeError as e:
                warnings.warn(f"Unable to read file, skipping {file}: {e}")
            else:
                all_cached_results.append(unit_result)

    # determine which cached results are still valid; their units are skipped below
    results: dict[GufeKey, ProtocolUnitResult] = _get_valid_unit_results(protocoldag, all_cached_results)
    all_results = []  # successes AND failures
    shared_paths = []
    for unit in protocoldag.protocol_units:  # protocol_units is in DAG-dependency-order
        # if we already have a valid cached result for this unit, skip execution
        if unit.key in results:
            all_results.append(results[unit.key])
            continue
        # translate each `ProtocolUnit` in input into corresponding `ProtocolUnitResult`
        inputs = _pu_to_pur(unit.inputs, results)

        attempt = 0
        while attempt <= n_retries:
            shared = shared_basedir / f"shared_{str(unit.key)}_attempt_{attempt}"
            shared_paths.append(shared)
            shared.mkdir(exist_ok=True)

            scratch = scratch_basedir / f"scratch_{str(unit.key)}_attempt_{attempt}"
            scratch.mkdir(exist_ok=True)

            stderr = None
            if stderr_basedir:
                stderr = stderr_basedir / f"stderr_{str(unit.key)}_attempt_{attempt}"
                stderr.mkdir(exist_ok=True)

            stdout = None
            if stdout_basedir:
                stdout = stdout_basedir / f"stdout_{str(unit.key)}_attempt_{attempt}"
                stdout.mkdir(exist_ok=True)

            context = Context(shared=shared, scratch=scratch, stderr=stderr, stdout=stdout)

@@ -468,6 +519,10 @@
            if result.ok():
                # attach result to this `ProtocolUnit`
                results[unit.key] = result

                # serialize results if requested
                if unitresults_basedir is not None:
                    result.to_json(unitresults_path / f"{str(result.key)}.json")
                break
            attempt += 1

@@ -478,6 +533,9 @@
        for shared_path in shared_paths:
            shutil.rmtree(shared_path)

    if not keep_unitresults and unitresults_basedir is not None:
        shutil.rmtree(unitresults_path)

    return ProtocolDAGResult(
        name=protocoldag.name,
        protocol_units=protocoldag.protocol_units,
@@ -506,6 +564,7 @@ def _pu_to_pur(
        replaced with its corresponding `ProtocolUnitResult`.

    """

    if isinstance(inputs, dict):
        return {key: _pu_to_pur(value, mapping) for key, value in inputs.items()}
    elif isinstance(inputs, list):
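A minimal usage sketch of the resume behaviour added above. The `dag` object and all paths here are hypothetical; only the `execute_DAG` keywords introduced in this diff are assumed:

import pathlib

from gufe.protocols import execute_DAG

base = pathlib.Path("/tmp/campaign")  # hypothetical working directory
for d in ("shared", "scratch"):
    (base / d).mkdir(parents=True, exist_ok=True)  # execute_DAG expects these to exist

# first call: each finished ProtocolUnitResult is archived under
# <unitresults_basedir>/unitresults_<dag key>/<result key>.json
result = execute_DAG(
    dag,  # a previously constructed ProtocolDAG, assumed to exist
    shared_basedir=base / "shared",
    scratch_basedir=base / "scratch",
    unitresults_basedir=base / "unitresults",
    keep_unitresults=True,  # keep the archive so a later call can resume
)

# if execution is interrupted, repeating the same call reads the archived
# results back, validates them against the DAG, and re-runs only the units
# without a valid cached result (plus everything downstream of them)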
153 changes: 149 additions & 4 deletions gufe/tests/test_protocoldag.py
@@ -7,7 +7,7 @@
from openff.units import unit

import gufe
from gufe.protocols import execute_DAG, protocoldag


class WriterUnit(gufe.ProtocolUnit):
@@ -77,14 +77,18 @@ def writefile_dag():

@pytest.mark.parametrize("keep_shared", [False, True])
@pytest.mark.parametrize("keep_scratch", [False, True])
@pytest.mark.parametrize("keep_unitresults", [False, True])
@pytest.mark.parametrize("capture_stderr_stdout", [False, True])
def test_execute_dag(tmp_path, keep_shared, keep_scratch, keep_unitresults, writefile_dag, capture_stderr_stdout):
    shared = pathlib.Path(tmp_path / "shared")
    shared.mkdir(parents=True)

    scratch = pathlib.Path(tmp_path / "scratch")
    scratch.mkdir(parents=True)

    unit_results_cache = pathlib.Path(tmp_path / "unit_results_cache")
    unit_results_cache.mkdir(parents=True)

    stderr = None
    stdout = None
    if capture_stderr_stdout:
@@ -98,12 +102,13 @@
        writefile_dag,
        shared_basedir=shared,
        scratch_basedir=scratch,
        unitresults_basedir=unit_results_cache,
        stderr_basedir=stderr,
        stdout_basedir=stdout,
        keep_shared=keep_shared,
        keep_scratch=keep_scratch,
        keep_unitresults=keep_unitresults,
    )

    # check outputs are as expected
    # will have produced 4 files in scratch and shared directory
    for pu in writefile_dag.protocol_units:
@@ -114,6 +119,8 @@
            f"scratch_{str(pu.key)}_attempt_0",
            f"unit_{identity}_scratch.txt",
        )
        # TODO: add result key.json
        unit_result_file = os.path.join(unit_results_cache, f"unitresults_{str(writefile_dag.key)}")

        if capture_stderr_stdout:
            stderr_file = os.path.join(
@@ -126,6 +133,8 @@
                f"stdout_{str(pu.key)}_attempt_0",
                f"unit_{identity}_stdout",
            )
            # TODO: add result key.json
            unit_result_file = os.path.join(unit_results_cache, f"unitresults_{str(writefile_dag.key)}")

        # stderr and stdout are always removed since their
        # contents are included in the unit results
@@ -140,7 +149,10 @@
            assert os.path.exists(scratch_file)
        else:
            assert not os.path.exists(scratch_file)

        if keep_unitresults:
            assert os.path.exists(unit_result_file)
        else:
            assert not os.path.exists(unit_result_file)

    # check that our shared and scratch basedirs are left behind
    assert shared.exists()
    assert scratch.exists()

@@ -166,3 +178,136 @@ def test_protocoldag_missing_dependency_unit():
        protocol_units=dependent_units,  # Missing setup_unit!
        transformation_key=None,
    )


def test_execute_DAG_cached_unitresults(tmp_path):
    """Test that execute_DAG resumes from unitresults_basedir, re-running only
    a terminal node whose cached result is unreadable."""

    # Create a setup unit that other units depend on
    setup_unit = WriterUnit(identity=0, name="setup")

    # Create units that depend on the setup unit
    dependent_units = [WriterUnit(identity=i, setup=setup_unit, name=f"cycle_{i}") for i in range(1, 4)]

    dep_dag = gufe.ProtocolDAG(
        protocol_units=dependent_units + [setup_unit],
        transformation_key=None,
    )

    # run the full DAG once, archiving all unit results
    shared = pathlib.Path(tmp_path / "shared")
    shared.mkdir(parents=True)

    scratch = pathlib.Path(tmp_path / "scratch")
    scratch.mkdir(parents=True)

    unit_results_dir = pathlib.Path(tmp_path / "unitresults_cache")
    protocol_result = execute_DAG(
        dep_dag,
        shared_basedir=shared,
        scratch_basedir=scratch,
        unitresults_basedir=unit_results_dir,
        stderr_basedir=None,
        stdout_basedir=None,
        keep_shared=False,
        keep_scratch=False,
        keep_unitresults=True,
    )

    for pur in protocol_result.protocol_unit_results:
        assert os.path.exists(os.path.join(unit_results_dir, f"unitresults_{dep_dag.key}", f"{str(pur.key)}.json"))

    # choose a terminal result so that only one node is rerun
    pur_to_corrupt = protocol_result.terminal_protocol_unit_results[0]

    with open(
        os.path.join(unit_results_dir, f"unitresults_{dep_dag.key}", f"{str(pur_to_corrupt.key)}.json"), "a"
    ) as f:
        f.write("string that will break JSON.")

    protocol_result_rerun = execute_DAG(
        dep_dag,
        shared_basedir=shared,
        scratch_basedir=scratch,
        unitresults_basedir=unit_results_dir,
        stderr_basedir=None,
        stdout_basedir=None,
        keep_shared=False,
        keep_scratch=False,
        keep_unitresults=True,
    )

    assert protocol_result.protocol_units == protocol_result_rerun.protocol_units

    rerun_keys = {r.key for r in protocol_result_rerun.protocol_unit_results}
    original_keys = {r.key for r in protocol_result.protocol_unit_results}

    # only the corrupted terminal unit should have been re-run, so exactly one
    # result key from each run should be unique to that run
    assert len(rerun_keys.symmetric_difference(original_keys)) == 2
    assert len(rerun_keys.intersection(original_keys)) == len(protocol_result.protocol_unit_results) - 1

    assert protocol_result_rerun.graph.edges == protocol_result.graph.edges


def test_get_valid_unit_results(tmp_path):
    """
    Create a graph of dependencies that looks like this:
    A<-B, B<-C, B<-D, B<-E, D<-F, E<-F
    or read top-down:
        A
        B
      C D E
        F
    """

    unit_A = WriterUnit(identity="A", name="unit_A")
    unit_B = WriterUnit(identity="B", name="unit_B", needs=[unit_A])
    unit_C = WriterUnit(identity="C", name="unit_C", needs=[unit_B])
    unit_D = WriterUnit(identity="D", name="unit_D", needs=[unit_B])
    unit_E = WriterUnit(identity="E", name="unit_E", needs=[unit_B])
    unit_F = WriterUnit(identity="F", name="unit_F", needs=[unit_D, unit_E])

    all_protocol_units = {unit_A, unit_B, unit_C, unit_D, unit_E, unit_F}
    dep_dag = gufe.ProtocolDAG(
        protocol_units=all_protocol_units,
        transformation_key=None,
    )
    shared = pathlib.Path(tmp_path / "shared")
    shared.mkdir(parents=True)

    scratch = pathlib.Path(tmp_path / "scratch")
    scratch.mkdir(parents=True)

    # execute the full DAG once so we have a complete set of unit results
    unit_results_dir = pathlib.Path(tmp_path / "unitresults_cache")
    protocol_result = execute_DAG(
        dep_dag,
        shared_basedir=shared,
        scratch_basedir=scratch,
        unitresults_basedir=unit_results_dir,
        stderr_basedir=None,
        stdout_basedir=None,
        keep_shared=False,
        keep_scratch=False,
        keep_unitresults=True,
    )
    all_cached_unit_results = protocol_result.protocol_unit_results

    # cache is empty, so nothing should be skipped
    units_to_skip = protocoldag._get_valid_unit_results(protocoldag=dep_dag, unit_results=[])
    assert len(units_to_skip.keys()) == 0

    # all results are available, so everything should be skipped
    units_to_skip = protocoldag._get_valid_unit_results(protocoldag=dep_dag, unit_results=all_cached_unit_results)
    assert set(units_to_skip.keys()) == {u.key for u in all_protocol_units}

    # drop the top-most unit, so nothing should be skipped
    unit_results_drop_A = [u for u in all_cached_unit_results if u.name != "unit_A"]
    units_to_skip = protocoldag._get_valid_unit_results(protocoldag=dep_dag, unit_results=unit_results_drop_A)
    assert len(units_to_skip.keys()) == 0

    # drop terminal nodes, so everything *but* the terminal nodes can be skipped
    unit_results_drop_C_F = [u for u in all_cached_unit_results if u.name not in ("unit_C", "unit_F")]
    units_to_skip = protocoldag._get_valid_unit_results(protocoldag=dep_dag, unit_results=unit_results_drop_C_F)
    assert set(units_to_skip.keys()) == {u.key for u in (unit_A, unit_B, unit_D, unit_E)}
23 changes: 23 additions & 0 deletions news/resuming.rst
@@ -0,0 +1,23 @@
**Added:**

* ``gufe.protocols.protocoldag.execute_DAG`` can now resume DAG execution: pass a path for result caching as ``unitresults_basedir`` and set ``keep_unitresults=True``.

**Changed:**

* <news item>

**Deprecated:**

* <news item>

**Removed:**

* <news item>

**Fixed:**

* <news item>

**Security:**

* <news item>
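For reference, the invalidation rule that ``_get_valid_unit_results`` implements (and that the tests above exercise) can be sketched as follows; the DAG, units, and results named here are hypothetical, and the function is private API:

from gufe.protocols import protocoldag

# suppose dag contains units A <- B <- C (B needs A, C needs B), and cached
# results exist for A and C while B's is missing or corrupted; B must be
# re-run, so C, which is downstream of B, is invalidated as well
valid = protocoldag._get_valid_unit_results(
    protocoldag=dag, unit_results=[result_A, result_C]
)
assert set(valid) == {unit_A.key}  # only A can be skipped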