Skip to content

Commit 7604fec

Browse files
committed
Set pythonic default model dir
1 parent 8f3e96e commit 7604fec

2 files changed

Lines changed: 3 additions & 3 deletions

File tree

trackastra/model/model_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def from_pretrained(
172172
Args:
173173
name: Name of pretrained model (e.g. "general_2d").
174174
device: Device to run model on ("cuda", "mps", "cpu", "automatic" or None).
175-
download_dir: Directory to download model to (defaults to ~/.cache/trackastra).
175+
download_dir: Directory to download model. Default handled by platformdirs.
176176
177177
Returns:
178178
Trackastra model instance.

trackastra/model/pretrained.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
import shutil
33
import tempfile
44
import zipfile
5-
from importlib.resources import files
65
from pathlib import Path
76

87
import requests
8+
from platformdirs import user_data_dir
99
from tqdm import tqdm
1010

1111
logger = logging.getLogger(__name__)
@@ -60,7 +60,7 @@ def download(url: str, fname: Path):
6060
def download_pretrained(name: str, download_dir: Path | None = None):
6161
# TODO make safe, introduce versioning
6262
if download_dir is None:
63-
download_dir = files("trackastra").joinpath(".models")
63+
download_dir = Path(user_data_dir("trackastra")) / "models"
6464
else:
6565
download_dir = Path(download_dir)
6666

0 commit comments

Comments
 (0)