diff --git a/src/pystac/item.py b/src/pystac/item.py index 107556048..7107de576 100644 --- a/src/pystac/item.py +++ b/src/pystac/item.py @@ -89,6 +89,13 @@ def __init__( self._collection: Collection | str | None = collection + @classmethod + def try_from(cls, data: dict[str, Any] | Item) -> Item: + if isinstance(data, Item): + return data + else: + return cls(**data) + @override @classmethod def from_dict( diff --git a/src/pystac/item_collection.py b/src/pystac/item_collection.py index f47e5c10e..1ec2e4e8d 100644 --- a/src/pystac/item_collection.py +++ b/src/pystac/item_collection.py @@ -1,10 +1,15 @@ from __future__ import annotations import copy +import warnings from collections.abc import Iterator from pathlib import Path from typing import Any, Literal, TypedDict, override +from typing_extensions import deprecated + +from .container import Container +from .errors import STACTypeError from .item import Item from .reader import DEFAULT_READER, Reader from .utils import make_absolute_href @@ -16,9 +21,24 @@ class T_ItemCollection(TypedDict): class ItemCollection: - def __init__(self, items: list[Item], **kwargs: Any): - self.items: list[Item] = items + def __init__(self, items: list[Item], root: Container | None = None, **kwargs: Any): + self.items: list[Item] = [Item.try_from(item) for item in items] + + if "root" in kwargs: + raise KeyError( + "root is not a valid key word argument to ``ItemCollection()``. " + "If root is required try instantiating the items directly and " + "providing root." + ) + + extra_fields = kwargs.pop("extra_fields", {}) self.extra_fields: dict[str, Any] = kwargs + if extra_fields: + warnings.warn( + "Pass extra_fields entries as kwargs, " + "instead of in extra_fields dictionary." + ) + self.extra_fields.update(extra_fields) def __len__(self) -> int: return len(self.items) @@ -29,6 +49,16 @@ def __getitem__(self, index: int) -> Item: def __iter__(self) -> Iterator[Item]: return iter(self.items) + def __contains__(self, __x: Item) -> bool: + return __x in self.items + + def __add__(self, other: Any) -> ItemCollection: + if not isinstance(other, ItemCollection): + return NotImplemented + + combined = [*self.items, *other.items] + return ItemCollection(items=combined) + @override def __repr__(self) -> str: return f"ItemCollection({self.items})" @@ -41,6 +71,14 @@ def to_dict(self) -> dict[str, Any]: **data, } + def clone(self) -> ItemCollection: + return copy.deepcopy(self) + + @deprecated("Try `ItemCollection.from_dict` and handle any exceptions instead") + @staticmethod + def is_item_collection(data: dict[str, Any]) -> bool: + return data.get("type", "") == "FeatureCollection" + @classmethod def try_from(cls, data: dict[str, Any] | ItemCollection) -> ItemCollection: if isinstance(data, ItemCollection): @@ -53,14 +91,20 @@ def from_dict( cls, data: dict[str, Any], preserve_dict: bool = True, + root: Container | None = None, ) -> ItemCollection: + if data.get("type", "") != "FeatureCollection": + raise STACTypeError(data, cls) + if preserve_dict: data = copy.deepcopy(data) items = data.get("features", []) extra_fields = {k: v for k, v in data.items() if k not in ("features", "type")} - return cls(items=[Item.from_dict(item) for item in items], **extra_fields) + return cls( + items=[Item.from_dict(item, root=root) for item in items], **extra_fields + ) @classmethod def from_file( diff --git a/tests/v1/test_item_collection.py b/tests/v1/test_item_collection.py index 28af1c9e6..4b4c73bcb 100644 --- a/tests/v1/test_item_collection.py +++ b/tests/v1/test_item_collection.py @@ -4,13 +4,17 @@ from typing import Any, cast import pytest +from unittest.mock import Mock import pystac from pystac import Item, StacIO from pystac.item_collection import ItemCollection +from pystac.reader import DEFAULT_READER from .utils import TestCases -from .utils.stac_io_mock import MockDefaultStacIO + + +pytestmark = pytest.mark.passing_v2 SIMPLE_ITEM = TestCases.get_path("data-files/examples/1.0.0-RC1/simple-item.json") CORE_ITEM = TestCases.get_path("data-files/examples/1.0.0-RC1/core-item.json") @@ -20,6 +24,16 @@ "data-files/item-collection/sample-item-collection.json" ) +class MockReader: + mock: Mock + + def __init__(self) -> None: + self.mock = Mock() + + def get_json(self, href: str) -> dict[str, Any]: + self.mock.read_json(href) + return DEFAULT_READER.get_json(href) + @pytest.fixture def item_collection_dict() -> dict[str, Any]: @@ -199,9 +213,8 @@ def test_from_dict_sets_root(item_collection_dict: dict[str, Any]) -> None: def test_to_dict_does_not_read_root_link_of_items() -> None: - with MockDefaultStacIO() as mock_stac_io: - item_collection = pystac.ItemCollection.from_file(ITEM_COLLECTION) - - item_collection.to_dict() + mock_reader = MockReader() + item_collection = ItemCollection.from_file(ITEM_COLLECTION, reader=mock_reader) + item_collection.to_dict() - assert mock_stac_io.mock.read_text.call_count == 1 + assert mock_reader.mock.read_json.call_count == 1