Skip to content

Commit 5226695

Browse files
atorero authored and copybara-github committed
Support output_class in to_py.
Please note that it changes the logic of traversal if `output_class` is provided: now during previsit we do not create object instances but rather store their classes and then create them when visiting objects/lists. PiperOrigin-RevId: 891851157 Change-Id: Idc03bef7a1a2473ae382e146494c827f92f8bc8d
1 parent a1885c0 commit 5226695

11 files changed

Lines changed: 1392 additions & 181 deletions

File tree

py/koladata/base/py_conversions/dataclasses_util.cc

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -183,33 +183,6 @@ DataClassesUtil::CreateClassInstanceKwargs(
183183
return py_instance;
184184
}
185185

186-
absl::StatusOr<arolla::python::PyObjectPtr>
187-
DataClassesUtil::CreateClassInstanceArgs(
188-
PyObjectPtr absl_nonnull py_class,
189-
absl::Span<const arolla::python::PyObjectPtr absl_nonnull> args) {
190-
PyObjectPtr py_tuple = PyObjectPtr::Own(PyTuple_New(args.size()));
191-
if (py_tuple == nullptr) {
192-
return arolla::python::StatusCausedByPyErr(
193-
absl::StatusCode::kInternal,
194-
absl::StrFormat("could not create a new tuple of size %d",
195-
args.size()));
196-
}
197-
for (size_t i = 0; i < args.size(); ++i) {
198-
PyTuple_SET_ITEM(py_tuple.get(), i, Py_NewRef(args[i].get()));
199-
}
200-
201-
PyObjectPtr py_instance =
202-
PyObjectPtr::Own(PyObject_CallOneArg(py_class.get(), py_tuple.get()));
203-
204-
if (py_instance == nullptr) {
205-
return arolla::python::StatusWithRawPyErr(
206-
absl::StatusCode::kInvalidArgument,
207-
"could not create a new instance of the class");
208-
}
209-
210-
return py_instance;
211-
}
212-
213186
absl::StatusOr<arolla::python::PyObjectPtr>
214187
DataClassesUtil::GetSimpleNamespaceClass() {
215188
RETURN_IF_ERROR(InitFns());

py/koladata/base/py_conversions/dataclasses_util.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,6 @@ class DataClassesUtil {
9292
absl::Span<const std::string> attr_names,
9393
absl::Span<const arolla::python::PyObjectPtr absl_nonnull> attr_values);
9494

95-
// Creates a class instance with the given args.
96-
// Returned value is owned by DataClassesUtil.
97-
absl::StatusOr<arolla::python::PyObjectPtr> CreateClassInstanceArgs(
98-
arolla::python::PyObjectPtr absl_nonnull py_class,
99-
absl::Span<const arolla::python::PyObjectPtr absl_nonnull> args);
100-
10195
// Returns a new reference to the SimpleNamespace class.
10296
absl::StatusOr<arolla::python::PyObjectPtr> GetSimpleNamespaceClass();
10397

py/koladata/base/py_conversions/dataclasses_util.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,13 +136,25 @@ def get_class_field_type(
136136
raise ValueError(f"field '{attr_name}' has unsupported type: {t}")
137137

138138
if origin_type := typing.get_origin(py_obj):
139-
# `list[Obj]`
139+
# `list[Obj]` or `tuple[Obj, ...]`
140140
if (
141141
attr_name == '__items__'
142142
and isinstance(origin_type, type)
143143
and issubclass(origin_type, typing.Sequence)
144144
):
145-
return typing.get_args(py_obj)[0]
145+
args = typing.get_args(py_obj)
146+
if len(args) == 0:
147+
raise ValueError(
148+
'expected list/tuple/Sequence; got instead:'
149+
f' {typing.get_origin(py_obj)}'
150+
)
151+
if origin_type is tuple:
152+
if len(args) != 2 or args[1] is not Ellipsis:
153+
raise ValueError(
154+
'only tuple[T, ...] is supported; got instead:'
155+
f' {typing.get_origin(py_obj)}'
156+
)
157+
return args[0]
146158

147159
if isinstance(origin_type, type) and issubclass(
148160
origin_type, typing.Mapping
@@ -155,9 +167,13 @@ def get_class_field_type(
155167
args = typing.get_args(py_obj)
156168
if len(args) != 2:
157169
raise ValueError(
158-
f'Expected dict; got instead: {typing.get_origin(py_obj)}'
170+
f'expected dict; got instead: {typing.get_origin(py_obj)}'
159171
)
160172
return args[1]
173+
if attr_name == '__keys__' or attr_name == '__values__':
174+
raise ValueError(f'expected dict class; got instead: {py_obj}')
175+
if attr_name == '__items__':
176+
raise ValueError(f'expected list class; got instead: {py_obj}')
161177
raise ValueError(
162178
f'unsupported GenericAlias {py_obj} for attribute {attr_name}'
163179
)
@@ -173,11 +189,30 @@ def has_optional_field(
173189
attr_name: str,
174190
type_hints_cache: dict[_Type, dict[str, _Type]],
175191
) -> bool:
176-
"""Returns whether the given attribute exists and is optional."""
192+
"""Returns whether the given attribute is present and is optional.
193+
194+
If the attribute is not present, returns False.
195+
If the attribute is present but is not optional, raises ValueError.
196+
197+
Args:
198+
py_class: The class to inspect.
199+
attr_name: The name of the attribute to inspect.
200+
type_hints_cache: A cache of type hints for dataclasses.
201+
202+
Raises:
203+
ValueError: If the attribute is present but is not optional.
204+
205+
Returns:
206+
True if the given attribute exists and is optional, False otherwise.
207+
"""
177208
if not dataclasses.is_dataclass(py_class):
178209
return False
179210
field = _get_field_type_annotation(py_class, attr_name, type_hints_cache)
180-
return _get_underlying_optional_type(field) is not None
211+
if field is None:
212+
return False
213+
if _get_underlying_optional_type(field) is None:
214+
raise ValueError(f'field cannot have missing values: {attr_name}')
215+
return True
181216

182217

183218
_simple_namespace_class = types.SimpleNamespace

py/koladata/base/py_conversions/dataclasses_util_test.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ class Obj2:
109109

110110
self.assertEqual(util.get_class_field_type(Obj2, 'a', False), Obj1)
111111
self.assertIsNone(util.get_class_field_type(Obj2, 'b', False))
112+
self.assertIsNone(util.get_class_field_type(Obj2, 'b', True))
112113

113114
def test_get_class_field_type_optional(self):
114115
util = testing_clib.DataClassesUtil()
@@ -121,18 +122,25 @@ class Obj2:
121122
self.assertEqual(util.get_class_field_type(Obj2, 'a', False), Obj1)
122123
self.assertEqual(util.get_class_field_type(Obj2, 'b', False), Obj1)
123124

124-
def test_get_class_field_list(self):
125+
def test_get_class_field_list_tuple(self):
125126
util = testing_clib.DataClassesUtil()
126127

127128
@dataclasses.dataclass
128129
class Obj2:
129130
a: list[Obj1]
131+
b: tuple[Obj1, ...]
130132

131133
list_type = util.get_class_field_type(Obj2, 'a', False)
134+
tuple_type = util.get_class_field_type(Obj2, 'b', False)
132135
self.assertEqual(list_type, list[Obj1])
136+
self.assertEqual(tuple_type, tuple[Obj1, ...])
133137
self.assertEqual(
134138
util.get_class_field_type(list_type, '__items__', False), Obj1
135139
)
140+
with self.assertRaisesRegex(ValueError, 'only tuple.T, .... is supported'):
141+
_ = util.get_class_field_type(tuple[Obj1, Obj1], '__items__', False)
142+
with self.assertRaisesRegex(ValueError, 'only tuple.T, .... is supported'):
143+
_ = util.get_class_field_type(tuple[Obj1], '__items__', False)
136144

137145
def test_get_class_field_dict(self):
138146
util = testing_clib.DataClassesUtil()
@@ -192,6 +200,21 @@ def test_get_class_field_type_errors(self):
192200
):
193201
_ = util.get_class_field_type(int, 'a', False)
194202

203+
with self.assertRaisesRegex(
204+
ValueError, 'expected dict class; got instead: list.int.'
205+
):
206+
_ = util.get_class_field_type(list[int], '__keys__', False)
207+
208+
with self.assertRaisesRegex(
209+
ValueError, 'expected dict class; got instead: list.int.'
210+
):
211+
_ = util.get_class_field_type(list[int], '__values__', False)
212+
213+
with self.assertRaisesRegex(
214+
ValueError, 'expected list class; got instead: dict.int. int.'
215+
):
216+
_ = util.get_class_field_type(dict[int, int], '__items__', False)
217+
195218
def test_has_optional_field(self):
196219
util = testing_clib.DataClassesUtil()
197220

@@ -205,17 +228,36 @@ class Obj2:
205228
bad_0: int | Any
206229
bad_1: int | float | None
207230
bad_2: None | int
231+
g: int = 1
208232

209233
self.assertTrue(util.has_optional_field(Obj2, 'a'))
210234
self.assertFalse(util.has_optional_field(Obj2, 'b'))
211235
self.assertTrue(util.has_optional_field(Obj2, 'c'))
212-
self.assertFalse(util.has_optional_field(Obj2, 'd'))
213-
self.assertFalse(util.has_optional_field(Obj2, 'e'))
214-
self.assertFalse(util.has_optional_field(Obj2, 'f'))
236+
with self.assertRaisesRegex(
237+
ValueError,
238+
'field cannot have missing values: d',
239+
):
240+
_ = util.has_optional_field(Obj2, 'd')
241+
with self.assertRaisesRegex(
242+
ValueError,
243+
'field cannot have missing values: e',
244+
):
245+
_ = util.has_optional_field(Obj2, 'e')
246+
with self.assertRaisesRegex(
247+
ValueError,
248+
'field cannot have missing values: f',
249+
):
250+
_ = util.has_optional_field(Obj2, 'f')
251+
with self.assertRaisesRegex(
252+
ValueError,
253+
'field cannot have missing values: g',
254+
):
255+
_ = util.has_optional_field(Obj2, 'g')
215256
self.assertFalse(util.has_optional_field(int, 'non_existent_field'))
216257
with self.assertRaisesRegex(
217258
ValueError,
218-
'only unions `SomeType | None` are supported ; got instead: int | Any',
259+
'only unions `SomeType | None` are supported ; got instead: int |'
260+
' typing.Any',
219261
):
220262
_ = util.has_optional_field(Obj2, 'bad_0')
221263

@@ -237,13 +279,6 @@ def test_create_class_instance_kwargs(self):
237279
obj = util.create_class_instance_kwargs(Obj1, ['a', 'b'], [123, 'abc'])
238280
self.assertEqual(obj, Obj1(a=123, b='abc'))
239281

240-
def test_create_class_instance_args(self):
241-
util = testing_clib.DataClassesUtil()
242-
obj = util.create_class_instance_args(list, [1, 2, 3])
243-
self.assertEqual(obj, list([1, 2, 3]))
244-
obj2 = util.create_class_instance_args(tuple, [1, 2, 3])
245-
self.assertEqual(obj2, tuple([1, 2, 3]))
246-
247282
def test_get_simple_namespace_class(self):
248283
util = testing_clib.DataClassesUtil()
249284
simple_namespace_class = util.get_simple_namespace_class()

py/koladata/base/py_conversions/testing_clib.cc

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -85,21 +85,6 @@ PYBIND11_MODULE(testing_clib, m) {
8585
attr_names_vec, attr_values_ptrs));
8686
return py::reinterpret_steal<py::object>(res.release());
8787
})
88-
.def("create_class_instance_args",
89-
[](DataClassesUtil& self, py::handle py_class,
90-
absl::Span<const py::handle> attr_values) -> py::object {
91-
std::vector<arolla::python::PyObjectPtr> attr_values_ptrs;
92-
for (const auto& attr_value : attr_values) {
93-
attr_values_ptrs.push_back(
94-
arolla::python::PyObjectPtr::NewRef(attr_value.ptr()));
95-
}
96-
97-
auto res = arolla::python::pybind11_unstatus_or(
98-
self.CreateClassInstanceArgs(
99-
arolla::python::PyObjectPtr::NewRef(py_class.ptr()),
100-
attr_values_ptrs));
101-
return py::reinterpret_steal<py::object>(res.release());
102-
})
10388
.def("get_simple_namespace_class",
10489
[](DataClassesUtil& self) -> py::object {
10590
arolla::python::PyObjectPtr py_obj =

0 commit comments

Comments
 (0)