Skip to content

Commit e79a66b

Browse files
committed
Fixes issue in which dataloader did not accept subscripted generics as the output type
Note we only allow dict[str, ...], nothing else. Also updates tests to work with 3.8 -- we actually remove a few as 3.8 is past EOL now.
1 parent 2c8c3d7 commit e79a66b

File tree

3 files changed

+46
-10
lines changed

3 files changed

+46
-10
lines changed

hamilton/function_modifiers/adapters.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
ParametrizedDependency,
1818
UpstreamDependency,
1919
)
20+
from hamilton.htypes import custom_subclass_check
2021
from hamilton.io.data_adapters import AdapterCommon, DataLoader, DataSaver
2122
from hamilton.node import DependencyType
2223
from hamilton.registry import LOADER_REGISTRY, SAVER_REGISTRY
@@ -748,10 +749,17 @@ def validate(self, fn: Callable):
748749
)
749750
# check that the second is a dict
750751
second_arg = typing_inspect.get_args(return_annotation)[1]
751-
if not (second_arg == dict or second_arg == Dict):
752+
if not (custom_subclass_check(second_arg, dict)):
752753
raise InvalidDecoratorException(
753754
f"Function: {fn.__qualname__} must return a tuple of type (SOME_TYPE, dict)."
754755
)
756+
second_arg_params = typing_inspect.get_args(second_arg)
757+
if (
758+
len(second_arg_params) > 0 and not second_arg_params[0] == str
759+
): # metadata must have string keys
760+
raise InvalidDecoratorException(
761+
f"Function: {fn.__qualname__} must return a tuple of type (SOME_TYPE, dict[str, ...]). Instead got (SOME_TYPE, dict[{second_arg_params[0]}, ...]"
762+
)
755763

756764
def generate_nodes(self, fn: Callable, config) -> List[node.Node]:
757765
"""Generates two nodes. We have to add tags appropriately.

tests/function_modifiers/test_adapters.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -697,19 +697,23 @@ def fn(data1: dict, data2: dict) -> dict:
697697
import sys
698698

699699
if sys.version_info >= (3, 9):
700-
dl_type = tuple[int, dict]
701-
ds_type = dict
700+
dict_ = dict
701+
tuple_ = tuple
702702
else:
703-
dl_type = Tuple[int, Dict]
704-
ds_type = Dict
703+
dict_ = Dict
704+
tuple_ = Tuple
705705

706706

707707
# Mock functions for dataloader & datasaver testing
708-
def correct_dl_function(foo: int) -> dl_type:
708+
def correct_dl_function(foo: int) -> tuple_[int, dict_]:
709709
return 1, {}
710710

711711

712-
def correct_ds_function(data: float) -> ds_type:
712+
def correct_dl_function_with_subscripts(foo: int) -> tuple_[Dict[str, int], Dict[str, str]]:
713+
return {"a": 1}, {"b": "c"}
714+
715+
716+
def correct_ds_function(data: float) -> dict_:
713717
return {}
714718

715719

@@ -721,19 +725,24 @@ def non_tuple_return_function() -> int:
721725
return 1
722726

723727

724-
def incorrect_tuple_length_function() -> Tuple[int]:
728+
def incorrect_tuple_length_function() -> tuple_[int]:
725729
return (1,)
726730

727731

728-
def incorrect_second_element_function() -> Tuple[int, list]:
732+
def incorrect_second_element_function() -> tuple_[int, list]:
729733
return 1, []
730734

731735

736+
def incorrect_dict_subscript() -> tuple_[int, Dict[int, str]]:
737+
return 1, {1: "a"}
738+
739+
732740
incorrect_funcs = [
733741
no_return_annotation_function,
734742
non_tuple_return_function,
735743
incorrect_tuple_length_function,
736744
incorrect_second_element_function,
745+
incorrect_dict_subscript,
737746
]
738747

739748

@@ -744,6 +753,10 @@ def test_dl_validate_incorrect_functions(func):
744753
dl.validate(func)
745754

746755

756+
@pytest.mark.skipif(
757+
sys.version_info < (3, 9, 0),
758+
reason="dataloader not guarenteed to work with subscripted tuples on 3.8",
759+
)
747760
def test_dl_validate_with_correct_function():
748761
dl = dataloader()
749762
try:
@@ -753,6 +766,15 @@ def test_dl_validate_with_correct_function():
753766
pytest.fail("validate() raised InvalidDecoratorException unexpectedly!")
754767

755768

769+
def test_dl_validate_with_subscripts():
770+
dl = dataloader()
771+
try:
772+
dl.validate(correct_dl_function_with_subscripts)
773+
except InvalidDecoratorException:
774+
# i.e. fail the test if there's an error
775+
pytest.fail("validate() raised InvalidDecoratorException unexpectedly!")
776+
777+
756778
def test_ds_validate_with_correct_function():
757779
dl = datasaver()
758780
try:

tests/resources/nodes_with_future_annotation.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
from __future__ import annotations
22

3+
import sys
4+
from typing import List, Tuple
5+
36
from hamilton.function_modifiers import dataloader
47
from hamilton.htypes import Collect, Parallelizable
58

69
"""Tests future annotations with common node types"""
710

11+
tuple_ = Tuple if sys.version_info < (3, 9, 0) else tuple
12+
list_ = List if sys.version_info < (3, 9, 0) else list
13+
814

915
def parallelized() -> Parallelizable[int]:
1016
yield 1
@@ -21,6 +27,6 @@ def collected(standard: Collect[int]) -> int:
2127

2228

2329
@dataloader()
24-
def sample_dataloader() -> tuple[list[str], dict]:
30+
def sample_dataloader() -> tuple_[list_[str], dict]:
2531
"""Grouping here as the rest test annotations"""
2632
return ["a", "b", "c"], {}

0 commit comments

Comments
 (0)