Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/pystac/item.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,13 @@ def __init__(

self._collection: Collection | str | None = collection

@classmethod
def try_from(cls, data: dict[str, Any] | Item) -> Item:
if isinstance(data, Item):
return data
else:
return cls(**data)

@override
@classmethod
def from_dict(
Expand Down
50 changes: 47 additions & 3 deletions src/pystac/item_collection.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from __future__ import annotations

import copy
import warnings
from collections.abc import Iterator
from pathlib import Path
from typing import Any, Literal, TypedDict, override

from typing_extensions import deprecated

from .container import Container
from .errors import STACTypeError
from .item import Item
from .reader import DEFAULT_READER, Reader
from .utils import make_absolute_href
Expand All @@ -16,9 +21,24 @@ class T_ItemCollection(TypedDict):


class ItemCollection:
def __init__(self, items: list[Item], **kwargs: Any):
self.items: list[Item] = items
def __init__(self, items: list[Item], root: Container | None = None, **kwargs: Any):
self.items: list[Item] = [Item.try_from(item) for item in items]

if "root" in kwargs:
raise KeyError(
"root is not a valid key word argument to ``ItemCollection()``. "
"If root is required try instantiating the items directly and "
"providing root."
)

extra_fields = kwargs.pop("extra_fields", {})
self.extra_fields: dict[str, Any] = kwargs
if extra_fields:
warnings.warn(
"Pass extra_fields entries as kwargs, "
"instead of in extra_fields dictionary."
)
self.extra_fields.update(extra_fields)

def __len__(self) -> int:
return len(self.items)
Expand All @@ -29,6 +49,16 @@ def __getitem__(self, index: int) -> Item:
def __iter__(self) -> Iterator[Item]:
return iter(self.items)

def __contains__(self, __x: Item) -> bool:
Comment thread
jsignell marked this conversation as resolved.
return __x in self.items

def __add__(self, other: Any) -> ItemCollection:
if not isinstance(other, ItemCollection):
return NotImplemented
Comment thread
gadomski marked this conversation as resolved.

combined = [*self.items, *other.items]
return ItemCollection(items=combined)

@override
def __repr__(self) -> str:
return f"ItemCollection({self.items})"
Expand All @@ -41,6 +71,14 @@ def to_dict(self) -> dict[str, Any]:
**data,
}

def clone(self) -> ItemCollection:
return copy.deepcopy(self)

@deprecated("Try `ItemCollection.from_dict` and handle any exceptions instead")
@staticmethod
def is_item_collection(data: dict[str, Any]) -> bool:
return data.get("type", "") == "FeatureCollection"

@classmethod
def try_from(cls, data: dict[str, Any] | ItemCollection) -> ItemCollection:
if isinstance(data, ItemCollection):
Expand All @@ -53,14 +91,20 @@ def from_dict(
cls,
data: dict[str, Any],
preserve_dict: bool = True,
root: Container | None = None,
) -> ItemCollection:
if data.get("type", "") != "FeatureCollection":
raise STACTypeError(data, cls)

if preserve_dict:
data = copy.deepcopy(data)

items = data.get("features", [])
extra_fields = {k: v for k, v in data.items() if k not in ("features", "type")}

return cls(items=[Item.from_dict(item) for item in items], **extra_fields)
return cls(
items=[Item.from_dict(item, root=root) for item in items], **extra_fields
)

@classmethod
def from_file(
Expand Down
25 changes: 19 additions & 6 deletions tests/v1/test_item_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,17 @@
from typing import Any, cast

import pytest
from unittest.mock import Mock

import pystac
from pystac import Item, StacIO
from pystac.item_collection import ItemCollection
from pystac.reader import DEFAULT_READER

from .utils import TestCases
from .utils.stac_io_mock import MockDefaultStacIO


pytestmark = pytest.mark.passing_v2

SIMPLE_ITEM = TestCases.get_path("data-files/examples/1.0.0-RC1/simple-item.json")
CORE_ITEM = TestCases.get_path("data-files/examples/1.0.0-RC1/core-item.json")
Expand All @@ -20,6 +24,16 @@
"data-files/item-collection/sample-item-collection.json"
)

class MockReader:
mock: Mock

def __init__(self) -> None:
self.mock = Mock()

def get_json(self, href: str) -> dict[str, Any]:
self.mock.read_json(href)
return DEFAULT_READER.get_json(href)


@pytest.fixture
def item_collection_dict() -> dict[str, Any]:
Expand Down Expand Up @@ -199,9 +213,8 @@ def test_from_dict_sets_root(item_collection_dict: dict[str, Any]) -> None:


def test_to_dict_does_not_read_root_link_of_items() -> None:
with MockDefaultStacIO() as mock_stac_io:
item_collection = pystac.ItemCollection.from_file(ITEM_COLLECTION)

item_collection.to_dict()
mock_reader = MockReader()
item_collection = ItemCollection.from_file(ITEM_COLLECTION, reader=mock_reader)
item_collection.to_dict()

assert mock_stac_io.mock.read_text.call_count == 1
assert mock_reader.mock.read_json.call_count == 1