Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions scispacy/file_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,6 @@ def url_to_filename(url: str, etag: Optional[str] = None) -> str:
If `etag` is specified, append its hash to the url's, delimited
by a period.
"""

last_part = url.split("/")[-1]
url_bytes = url.encode("utf-8")
url_hash = sha256(url_bytes)
filename = url_hash.hexdigest()
Expand All @@ -67,7 +65,13 @@ def url_to_filename(url: str, etag: Optional[str] = None) -> str:
etag_hash = sha256(etag_bytes)
filename += "." + etag_hash.hexdigest()

filename += "." + last_part
# Only keep the file extension to stay within filesystem NAME_MAX
# limits (e.g. 143 bytes on eCryptfs).
url_path = urlparse(url).path
_, ext = os.path.splitext(os.path.basename(url_path))
if ext:
filename += ext

return filename


Expand Down Expand Up @@ -106,6 +110,19 @@ def http_get(url: str, temp_file: IO) -> None:
pbar.close()


def _find_legacy_cache_path(
url: str, etag: Optional[str], cache_dir: str
) -> Optional[str]:
"""Check for a cached file using the old naming scheme (full trailing URL component)."""
last_part = os.path.basename(urlparse(url).path)
filename = sha256(url.encode("utf-8")).hexdigest()
if etag:
filename += "." + sha256(etag.encode("utf-8")).hexdigest()
filename += "." + last_part
path = os.path.join(cache_dir, filename)
return path if os.path.exists(path) else None


def get_from_cache(url: str, cache_dir: Optional[str] = None) -> str:
"""
Given a URL, look for the corresponding dataset in the local cache.
Expand All @@ -131,6 +148,11 @@ def get_from_cache(url: str, cache_dir: Optional[str] = None) -> str:
cache_path = os.path.join(cache_dir, filename)

if not os.path.exists(cache_path):
# Check for files cached under the old naming scheme, which appended
# the full trailing URL component instead of just the extension.
legacy_path = _find_legacy_cache_path(url, etag, cache_dir)
if legacy_path is not None:
return legacy_path
# Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
with tempfile.NamedTemporaryFile() as temp_file: # type: IO
Expand Down
35 changes: 35 additions & 0 deletions tests/test_file_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,38 @@ def test_url_to_filename_with_etags_eliminates_quotes(self):
back_to_url, etag = filename_to_url(filename, cache_dir=self.TEST_DIR)
assert back_to_url == url
assert etag == "mytag"

def test_url_to_filename_stays_within_name_max(self):
    """Cache filenames must fit within eCryptfs's 143-byte filename limit."""
    name_max = 143
    long_url = "https://s3-us-west-2.amazonaws.com/bucket/" + "a" * 300 + "/file.npz"
    long_etag = "x" * 300

    # Long URL combined with a long etag; the extension is still kept.
    with_etag = url_to_filename(long_url, etag=long_etag)
    assert len(with_etag) <= name_max
    assert with_etag.endswith(".npz")

    # Long URL on its own.
    without_etag = url_to_filename(long_url)
    assert len(without_etag) <= name_max

def test_url_to_filename_no_extension(self):
    """A URL with no file extension still yields a valid filename."""
    sha256_hex_length = 64
    extensionless_url = "https://example.com/data/somefile"

    filename = url_to_filename(extensionless_url)

    # No etag and no extension: the name is only the sha256 hex digest.
    assert len(filename) == sha256_hex_length
    assert "." not in filename

def test_legacy_cache_files_still_found(self):
    """A file named with the old scheme is located by the legacy lookup."""
    from scispacy.file_cache import _find_legacy_cache_path

    # Old-style name: two hex digests followed by the full trailing URL
    # component (presumably the sha256 of the url and of the etag below).
    legacy_name = (
        "b6794c9b5101703824700fe53156f28b7c5c2ef432467c1399f30142e7db9977"
        ".700ccb3dacaae313fbd70ea50e5646377634d6f144ea63acaf30d8e7ecf1cc4e"
        ".model.bin"
    )
    legacy_path = os.path.join(self.TEST_DIR, legacy_name)
    pathlib.Path(legacy_path).touch()

    located = _find_legacy_cache_path(
        "https://example.com/data/model.bin", "some-etag", self.TEST_DIR
    )
    assert located == legacy_path