-
Notifications
You must be signed in to change notification settings - Fork 17
Expand file tree
/
Copy pathchemeleon_fingerprint.py
More file actions
63 lines (57 loc) · 2.42 KB
/
chemeleon_fingerprint.py
File metadata and controls
63 lines (57 loc) · 2.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# chemeleon_fingerprint.py
#
# This file contains the class CheMeleonFingerprint, which can be instantiated
# and called to generate CheMeleon learned embeddings for a list of SMILES
# strings and/or RDKit Mols. You may simply copy or download this file for
# direct use, or adapt the code for your own purposes. No other files are
# required, though you must `pip install 'chemprop>=2.2.0'` for it to run.
# The pretrained weights are downloaded on first use and cached at
# ~/.chemprop/chemeleon_mp.pt.
#
# Run `python chemeleon_fingerprint.py` for a quick usage demo; otherwise,
# `import` the CheMeleonFingerprint class into your own code and use it there
# (following the example at the bottom of this file) to generate your learned
# fingerprints.
from pathlib import Path
from urllib.request import urlretrieve
import numpy as np
import torch
from chemprop import featurizers, nn
from chemprop.data import BatchMolGraph
from chemprop.models import MPNN
from chemprop.nn import RegressionFFN
from rdkit.Chem import Mol, MolFromSmiles
class CheMeleonFingerprint:
    """Generate CheMeleon learned molecular embeddings ("fingerprints").

    On construction, the pretrained CheMeleon message-passing weights are
    downloaded once to ``~/.chemprop/chemeleon_mp.pt`` and wrapped in a
    Chemprop ``MPNN`` with mean aggregation. Calling the instance with a list
    of SMILES strings and/or RDKit ``Mol`` objects returns one embedding row
    per molecule as a NumPy array.
    """

    _CHECKPOINT_URL = "https://zenodo.org/records/15460715/files/chemeleon_mp.pt"

    def __init__(self, device: str | torch.device | None = None):
        """Load the CheMeleon weights (downloading them if needed) and build the model.

        Args:
            device: optional torch device to move the model to. When ``None``
                the model is left on the device it was created on.
        """
        self.featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
        agg = nn.MeanAggregation()

        ckpt_dir = Path.home() / ".chemprop"
        ckpt_dir.mkdir(parents=True, exist_ok=True)
        mp_path = ckpt_dir / "chemeleon_mp.pt"
        if not mp_path.exists():
            self._download(self._CHECKPOINT_URL, mp_path)

        # map_location="cpu" lets a checkpoint containing GPU tensors load on
        # CPU-only machines; the model is moved to `device` afterwards.
        chemeleon_mp = torch.load(mp_path, map_location="cpu", weights_only=True)
        mp = nn.BondMessagePassing(**chemeleon_mp["hyper_parameters"])
        mp.load_state_dict(chemeleon_mp["state_dict"])

        self.model = MPNN(
            message_passing=mp,
            agg=agg,
            # MPNN requires a predictor head, but only `fingerprint` is used here.
            predictor=RegressionFFN(input_dim=mp.output_dim),
        )
        self.model.eval()
        if device is not None:
            self.model.to(device=device)

    @staticmethod
    def _download(url: str, dest: Path) -> None:
        """Download ``url`` to ``dest`` atomically.

        The data is first written to a sibling ``.part`` file and renamed into
        place only after the transfer completes. An interrupted download
        therefore never leaves a truncated checkpoint at ``dest``, which would
        otherwise be picked up (and fail to load) on every later run.
        """
        tmp = dest.with_name(dest.name + ".part")
        try:
            urlretrieve(url, tmp)
            tmp.replace(dest)
        finally:
            # no-op after a successful rename; removes partial data otherwise
            tmp.unlink(missing_ok=True)

    @staticmethod
    def _to_mol(molecule: str | Mol) -> Mol:
        """Return ``molecule`` as an RDKit Mol, parsing it first if it is a SMILES string.

        Raises:
            ValueError: if a SMILES string cannot be parsed by RDKit.
        """
        if not isinstance(molecule, str):
            return molecule
        mol = MolFromSmiles(molecule)
        if mol is None:
            # MolFromSmiles signals failure by returning None; fail loudly here
            # instead of letting the featurizer crash on a None input.
            raise ValueError(f"Could not parse SMILES string: {molecule!r}")
        return mol

    def __call__(self, molecules: list[str | Mol]) -> np.ndarray:
        """Compute CheMeleon fingerprints for ``molecules``.

        Args:
            molecules: SMILES strings and/or RDKit ``Mol`` objects, in any mix.

        Returns:
            A NumPy array with one fingerprint row per input molecule, in input
            order. An empty input yields an array with zero rows.

        Raises:
            ValueError: if any SMILES string cannot be parsed.
        """
        mols = [self._to_mol(m) for m in molecules]
        if not mols:
            # BatchMolGraph cannot be built from zero graphs. Return an empty
            # batch with the fingerprint width instead (float32 assumed to
            # match the model weights).
            return np.empty(
                (0, self.model.message_passing.output_dim), dtype=np.float32
            )
        bmg = BatchMolGraph([self.featurizer(mol) for mol in mols])
        bmg.to(device=self.model.device)
        with torch.no_grad():
            return self.model.fingerprint(bmg).numpy(force=True)
if __name__ == "__main__":
    # Quick demo: embed two SMILES strings and one pre-built RDKit Mol, then
    # show the shape of the result (one row per molecule).
    chemeleon_fingerprint = CheMeleonFingerprint()
    fingerprints = chemeleon_fingerprint(["C", "CC", MolFromSmiles("CCC")])
    print(f"CheMeleon fingerprints shape: {fingerprints.shape}")