Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/fairchem/core/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,13 @@ def decorator(func):
@wraps(func)
def cls_method(self, *args, **kwargs):
f = func
if self.regress_forces and not getattr(self, "direct_forces", 0):
if hasattr(self, "regress_config"):
regress_forces = self.regress_config.forces
direct_forces = self.regress_config.direct_forces
else:
regress_forces = self.regress_forces
direct_forces = getattr(self, "direct_forces", 0)
if regress_forces and not direct_forces:
f = dec(func)
return f(self, *args, **kwargs)

Expand Down
11 changes: 3 additions & 8 deletions src/fairchem/core/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,13 +161,6 @@ def tasks(self) -> dict[str, Task]:
"""
return self._tasks

@property
def direct_forces(self) -> bool:
"""
Whether this model uses direct force prediction.
"""
return getattr(self.backbone, "direct_forces", False)

@property
def dataset_to_tasks(self) -> dict[str, list]:
"""
Expand All @@ -191,8 +184,10 @@ def _validate_task_compatibility(self, task: Task) -> None:
"""
derivative_properties = ("forces", "stress", "hessian")

backbone_regress_config = getattr(self.backbone, "regress_config", None)
if (
self.direct_forces
backbone_regress_config is not None
and backbone_regress_config.direct_forces
and task.inference_only
and task.property in derivative_properties
):
Expand Down
22 changes: 1 addition & 21 deletions src/fairchem/core/models/uma/escn_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,18 +512,6 @@ def __init__(
)
self.register_buffer("coefficient_index", coefficient_index, persistent=False)

@property # deprecate this
def direct_forces(self) -> bool:
return self.regress_config.direct_forces

@property # deprecate this
def regress_forces(self) -> bool:
return self.regress_config.forces

@property # deprecate this
def regress_stress(self) -> bool:
return self.regress_config.stress

def balance_channels(
self,
x_message_prime: torch.Tensor,
Expand Down Expand Up @@ -903,7 +891,7 @@ def get_default_untrained_tasks(
stress computation requires energy-conserving force computation.
"""
# Direct force models can't compute stress via autograd
if self.direct_forces:
if self.regress_config.direct_forces:
return []

tasks = []
Expand Down Expand Up @@ -1159,14 +1147,6 @@ def __init__(
backbone.force_block = None
self.regress_config = backbone.regress_config

@property
def regress_forces(self) -> bool:
return self.regress_config.forces

@property
def regress_stress(self) -> bool:
return self.regress_config.stress

@conditional_grad(torch.enable_grad())
def forward(
self, data: AtomicData, emb: dict[str, torch.Tensor]
Expand Down
16 changes: 0 additions & 16 deletions src/fairchem/core/models/uma/escn_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,14 +265,6 @@ def __init__(
self.merged_on_dataset = None
self.non_merged_dataset_names: list[str] = []

@property
def regress_forces(self) -> bool:
return self.regress_config.forces

@property
def regress_stress(self) -> bool:
return self.regress_config.stress

@staticmethod
def _build_expert_mapping(
dataset_names: list[str] | None,
Expand Down Expand Up @@ -424,14 +416,6 @@ def __init__(
# keep track if this head has been merged or not
self.merged_on_dataset = None

@property
def regress_forces(self) -> bool:
return self.regress_config.forces

@property
def regress_stress(self) -> bool:
return self.regress_config.stress

def merge_MOLE_model(self, data):
self.merged_on_dataset = data.dataset[0]
self.non_merged_dataset_names = [
Expand Down
12 changes: 6 additions & 6 deletions src/fairchem/core/units/mlip_unit/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,16 +196,12 @@ def __init__(
self.assert_on_nans = assert_on_nans
self._warned_upcast = False

if self.model.module.direct_forces:
if self.model.module.backbone.regress_config.direct_forces:
logging.warning(
"This is a direct-force model. Direct force predictions may lead to "
"discontinuities in the potential energy surface and energy conservation errors."
)

@property
def direct_forces(self) -> bool:
return self.model.module.direct_forces

@property
def dataset_to_tasks(self) -> dict[str, list]:
return self.model.module.dataset_to_tasks
Expand Down Expand Up @@ -478,7 +474,11 @@ def _run_inference(self, data: AtomicData, undo_refs: bool) -> dict:
"""
Execute model inference.
"""
inference_context = torch.no_grad() if self.direct_forces else nullcontext()
inference_context = (
torch.no_grad()
if self.model.module.backbone.regress_config.direct_forces
else nullcontext()
)
tf32_context = (
tf32_context_manager() if self.inference_settings.tf32 else nullcontext()
)
Expand Down
2 changes: 1 addition & 1 deletion tests/core/calculate/test_ase_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def test_calculator_setup(all_calculators):
datasets = list(calc.predictor.dataset_to_tasks.keys())

# all conservative UMA checkpoints should support E/F/S!
if not calc.predictor.direct_forces and (
if not calc.predictor.model.module.backbone.regress_config.direct_forces and (
len(datasets) > 1 or (calc.task_name != "omol" and calc.task_name != "odac")
):
print(len(datasets), calc.task_name)
Expand Down
7 changes: 0 additions & 7 deletions tests/core/models/test_inference_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,6 @@ def test_dataset_to_tasks_raises_before_setup(self, mock_hydra_model):
with pytest.raises(RuntimeError, match="setup_tasks"):
_ = mock_hydra_model.dataset_to_tasks

def test_direct_forces_property(self, mock_hydra_model):
"""Test direct_forces property delegates to backbone."""
assert mock_hydra_model.direct_forces is False

mock_hydra_model.backbone.direct_forces = True
assert mock_hydra_model.direct_forces is True


class TestBackboneInterface:
"""Tests for backbone interface method implementations."""
Expand Down
Loading