Skip to content

Commit 3aa8b80

Browse files
committed
fix: support torch 2.x (closes #167)
**`setup.py`** - Loosen torch pin from ==1.13.* to >=1.13,<3 - Loosen torchvision pin from ==0.14.* to >=0.14,<1 **`src/stepcount/models.py`** - Use keyword arg map_location= in torch.load() for torch 2.x compat
1 parent eba52d7 commit 3aa8b80

2 files changed

Lines changed: 3 additions & 3 deletions

File tree

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ def get_string(string, rel_path="src/stepcount/__init__.py"):
5959
"scikit-learn==1.1.1",
6060
"imbalanced-learn==0.9.1",
6161
"hmmlearn==0.3.*",
62-
"torch==1.13.*",
63-
"torchvision==0.14.*",
62+
"torch>=1.13,<3",
63+
"torchvision>=0.14,<1",
6464
"transforms3d==0.4.*",
6565
"numba==0.58.*",
6666
"matplotlib==3.7.*",

src/stepcount/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,7 @@ def fit(self, X, Y, groups=None):
441441
model.to(self.device)
442442

443443
sslmodel.train(model, train_loader, val_loader, self.device, class_weights, weights_path=self.weights_path)
444-
model.load_state_dict(torch.load(self.weights_path, self.device))
444+
model.load_state_dict(torch.load(self.weights_path, map_location=self.device))
445445

446446
if self.verbose:
447447
print('Training HMM')

0 commit comments

Comments
 (0)