Skip to content

Commit f74d35a

Browse files
add DatasetDict (#7)
* add DatasetDict * remove python bounds * remove huggingface channel * pin pyarrow to 6.0.0 * use == instead of = in CondaPkg * relax numpy * relax pillow
1 parent c59197a commit f74d35a

8 files changed

Lines changed: 146 additions & 15 deletions

File tree

CondaPkg.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
channels = ["conda-forge", "huggingface"]
1+
channels = ["conda-forge"]
22

33
[deps]
44
datasets = ">=2.7, <3"
5-
numpy = ">=1.23, <2"
6-
pillow = ">=9.2, <10"
7-
python = ">=3.6, <4"
5+
numpy = ">=1.20, <2"
6+
pillow = ">=9.1, <10"
7+
pyarrow = "==6.0.0"

src/HuggingFaceDatasets.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ include("observation.jl")
1111
include("dataset.jl")
1212
export Dataset, set_transform!
1313

14+
include("datasetdict.jl")
15+
export DatasetDict
16+
1417
include("transforms.jl")
1518
export py2jl
1619

@@ -25,6 +28,8 @@ function load_dataset(args...; kws...)
2528
d = datasets.load_dataset(args...; kws...)
2629
if pyisinstance(d, datasets.Dataset)
2730
return Dataset(d)
31+
elseif pyisinstance(d, datasets.DatasetDict)
32+
return DatasetDict(d)
2833
else
2934
return d
3035
end

src/dataset.jl

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,41 @@
11
"""
2-
Dataset(dataset, transform = py2jl)
2+
Dataset(pydataset; transform = py2jl)
33
4-
A Julia wrapper around the python `datasets.Dataset` type.
5-
It is the return type of [`load_dataset`](@ref).
4+
A Julia wrapper around the objects of the python `datasets.Dataset` class.
65
76
The `transform` is applied after datasets' one.
8-
The [`py2jl`](@def) default converts python types to julia types.
7+
The [`py2jl`](@ref) default converts python types to julia types.
98
109
Provides:
1110
- 1-based indexing.
1211
- [`set_transform!`](@ref) julia method.
1312
- All python class' methods from `datasets.Dataset`.
13+
14+
See also [`load_dataset`](@ref) and [`DatasetDict`](@ref).
1415
"""
1516
mutable struct Dataset
1617
pyd::Py
1718
transform
18-
end
1919

20-
Dataset(pydataset::Py; transform = py2jl) = Dataset(pydataset, transform)
20+
function Dataset(pydataset::Py; transform = py2jl)
21+
@assert pyisinstance(pydataset, datasets.Dataset)
22+
return new(pydataset, transform)
23+
end
24+
end
2125

2226
function Base.getproperty(d::Dataset, s::Symbol)
2327
if s in fieldnames(Dataset)
2428
return getfield(d, s)
2529
else
2630
res = getproperty(getfield(d, :pyd), s)
2731
if pyisinstance(res, datasets.Dataset)
28-
return Dataset(res, d.transform)
32+
return Dataset(res; d.transform)
2933
else
30-
return res
34+
return res |> py2jl
3135
end
3236
end
3337
end
3438

35-
3639
# Length as reported by the wrapped python object.
function Base.length(d::Dataset)
    return length(getfield(d, :pyd))
end
3740

3841
# `d[:]` is the same as indexing with the full 1-based range `1:length(d)`.
function Base.getindex(d::Dataset, ::Colon)
    return d[1:length(d)]
end

src/datasetdict.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
"""
    DatasetDict(pydatasetdict::Py; transform = py2jl)

A `DatasetDict` is a dictionary of `Dataset`s.
It is a wrapper around a python `datasets.DatasetDict` object.

The `transform` is applied to each [`Dataset`](@ref).
The [`py2jl`](@ref) default converts python types to julia types.

Throws an `ArgumentError` if `pydatasetdict` is not an instance of
the python `datasets.DatasetDict` class.

See also [`load_dataset`](@ref) and [`Dataset`](@ref).
"""
mutable struct DatasetDict
    pyd::Py
    transform

    function DatasetDict(pydatasetdict::Py; transform = py2jl)
        # Validate with an explicit check rather than `@assert`: assertions are meant
        # for internal invariants and may be disabled at some optimization levels.
        if !pyisinstance(pydatasetdict, datasets.DatasetDict)
            throw(ArgumentError(
                "`pydatasetdict` must be an instance of the python `datasets.DatasetDict` class",
            ))
        end
        return new(pydatasetdict, transform)
    end
end
20+
21+
# Property access on a `DatasetDict`:
# - the wrapper's own fields are returned untouched;
# - any other name is looked up on the wrapped python object. A python `Dataset` or
#   `DatasetDict` result is re-wrapped with this wrapper's transform, everything
#   else is passed through `py2jl`.
function Base.getproperty(d::DatasetDict, s::Symbol)
    s in fieldnames(DatasetDict) && return getfield(d, s)

    pyres = getproperty(getfield(d, :pyd), s)
    transform = getfield(d, :transform)
    pyisinstance(pyres, datasets.Dataset) && return Dataset(pyres; transform)
    pyisinstance(pyres, datasets.DatasetDict) && return DatasetDict(pyres; transform)
    return py2jl(pyres)
end
35+
36+
# Length as reported by the wrapped python object.
function Base.length(d::DatasetDict)
    return length(getfield(d, :pyd))
end
37+
38+
# String indexing: look the key up in the wrapped python object and wrap the result
# as a `Dataset` that carries this dictionary's transform.
function Base.getindex(d::DatasetDict, i::AbstractString)
    pydataset = getfield(d, :pyd)[i]
    return Dataset(pydataset; transform = getfield(d, :transform))
end
42+
43+
# Replace the transform stored in `d`. Passing `nothing` stores `identity` instead.
# The value that was stored is returned.
function set_transform!(d::DatasetDict, transform)
    new_transform = transform === nothing ? identity : transform
    d.transform = new_transform
    return new_transform
end
50+

src/transforms.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,4 @@ function tojulia(x::Py)
1818
end
1919

2020
# Convert every element of the list with `py2jl` and collect the results.
tojulia(x::PyList) = [py2jl(item) for item in x]
21-
tojulia(x::PyDict) = Dict(py2jl(k) => py2jl(v) for (k, v) in pairs(x))
21+
tojulia(x::PyDict) = Dict(py2jl(k) => py2jl(v) for (k, v) in pairs(x))

test/datasets.jl renamed to test/dataset.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@
5454
d.set_transform(pytr)
5555
@test d[1]["label"] == -7
5656
end
57+
58+
@testset "getproperty returns julia types" begin
59+
@test d.num_rows isa Int
60+
end
5761
end
5862

5963

test/datasetdict.jl

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
2+
@testset "MNIST" begin
    # One dataset dictionary is loaded once and shared by all nested testsets.
    dd = load_dataset("mnist")

    @testset "load_dataset" begin
        @test dd isa DatasetDict
        @test length(dd) == 2
    end

    @testset "indexing with no transform" begin
        # Swap in `identity` for the duration of this testset; restored at the end.
        saved_transform = dd.transform
        set_transform!(dd, identity)

        # Integer indexing of the dictionary itself is expected to be a MethodError.
        @test_throws MethodError dd[1]
        @test dd["test"] isa Dataset
        testsplit = dd["test"]
        @test pyisinstance(testsplit[1], pytype(pydict()))
        @test testsplit[1]["image"] isa Py
        @test testsplit[1]["label"] isa Py
        @test pyisinstance(testsplit[1]["label"], pytype(pyint()))
        @test py2jl(testsplit[1]["label"]) == 7
        @test py2jl(testsplit[2]["label"]) == 2

        @test testsplit[1:2] isa Py
        @test testsplit[1:2]["image"] isa Py
        @test pyisinstance(testsplit[1:2]["image"], pytype(pylist()))
        @test testsplit[1:2]["label"] isa Py
        @test pyisinstance(testsplit[1:2]["label"], pytype(pylist()))

        set_transform!(dd, saved_transform)
    end

    @testset "indexing - py2jl" begin
        @test dd.transform === py2jl
        testsplit = dd["test"]

        # Single observation.
        obs = testsplit[1]
        @test obs isa Dict
        @test obs["label"] isa Int
        @test obs["label"] == 7
        @test obs["image"] isa Matrix{UInt8}
        @test size(obs["image"]) == (28, 28)

        # Range of observations.
        batch = testsplit[1:2]
        @test batch isa Dict
        @test batch["image"] isa Vector{Matrix{UInt8}}
        @test size(batch["image"]) == (2,)
        @test batch["label"] isa Vector{Int}
        @test size(batch["label"]) == (2,)
    end

    @testset "python transforms" begin
        # Python-side transform that negates the labels.
        @pyexec """
        def pytr(x):
            return {"label": [-l for l in x["label"]]}
        """ => pytr
        dd.set_transform(pytr)
        @test dd["test"][1]["label"] == -7
    end

    @testset "getproperty returns julia types" begin
        @test dd.num_rows isa Dict{String, Int}
        @test dd.num_rows == Dict("test" => 10000, "train" => 60000)
    end
end
65+

test/runtests.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,9 @@ using HuggingFaceDatasets, PythonCall, MLUtils
44
# using ImageShow, ImageInTerminal
55

66
@testset "dataset" begin
7-
include("datasets.jl")
7+
include("dataset.jl")
8+
end
9+
10+
@testset "datasetdict" begin
11+
include("datasetdict.jl")
812
end

0 commit comments

Comments
 (0)