Skip to content

Commit 73b7087

Browse files
committed
added a test for incremental ATM simulations
1 parent af2e534 commit 73b7087

1 file changed

Lines changed: 42 additions & 12 deletions

File tree

tests/test_atm.py

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,25 +39,55 @@ def _test_production_xtc(tmp_path):
3939
os.path.join(tmp_path, "QB_A08_A07_completed"),
4040
)
4141

42-
with open(
43-
os.path.join(tmp_path, "QB_A08_A07_completed", "QB_A08_A07_asyncre.yaml"), "r"
44-
) as f:
45-
lines = yaml.safe_load(f)
46-
lines["XTC_TRAJECTORY"] = True
47-
with open(
48-
os.path.join(tmp_path, "QB_A08_A07_completed", "QB_A08_A07_asyncre.yaml"), "w"
49-
) as f:
50-
yaml.dump(lines, f)
51-
52-
rbfe_production(
53-
os.path.join(tmp_path, "QB_A08_A07_completed", "QB_A08_A07_asyncre.yaml")
42+
configfile = os.path.join(
43+
tmp_path, "QB_A08_A07_completed", "QB_A08_A07_asyncre.yaml"
5444
)
45+
with open(configfile, "r") as f:
46+
config = yaml.safe_load(f)
47+
config["XTC_TRAJECTORY"] = True
48+
with open(configfile, "w") as f:
49+
yaml.dump(config, f)
50+
51+
rbfe_production(configfile)
5552
for i in range(4):
5653
assert os.path.exists(
5754
os.path.join(tmp_path, "QB_A08_A07_completed", f"r{i}", "QB_A08_A07.xtc")
5855
)
5956

6057

58+
def _test_production_incremental(tmp_path):
59+
from atm.rbfe_production import rbfe_production
60+
import yaml
61+
62+
shutil.copytree(
63+
os.path.join(curr_dir, "QB_A08_A07_completed"),
64+
os.path.join(tmp_path, "QB_A08_A07_completed"),
65+
)
66+
67+
configfile = os.path.join(
68+
tmp_path, "QB_A08_A07_completed", "QB_A08_A07_asyncre.yaml"
69+
)
70+
with open(configfile, "r") as f:
71+
config = yaml.safe_load(f)
72+
config["MAX_SAMPLES"] = "+2"
73+
with open(configfile, "w") as f:
74+
yaml.dump(config, f)
75+
76+
rbfe_production(configfile)
77+
for i in range(4):
78+
assert os.path.exists(
79+
os.path.join(tmp_path, "QB_A08_A07_completed", f"r{i}", "QB_A08_A07.dcd")
80+
)
81+
with open(
82+
os.path.join(tmp_path, "QB_A08_A07_completed", "starting_sample"), "r"
83+
) as f:
84+
starting_sample = int(f.read().strip())
85+
assert starting_sample == 1
86+
with open(os.path.join(tmp_path, "QB_A08_A07_completed", "progress"), "r") as f:
87+
progress = float(f.read().strip())
88+
assert progress == 0.5
89+
90+
6191
def _test_uwham_analysis(tmp_path):
6292
from atm.uwham import calculate_uwham
6393

0 commit comments

Comments
 (0)