|
| 1 | + |
| 2 | +@testset "MNIST" begin |
| 3 | + dd = load_dataset("mnist") |
| 4 | + |
| 5 | + @testset "load_dataset" begin |
| 6 | + @test dd isa DatasetDict |
| 7 | + @test length(dd) == 2 |
| 8 | + end |
| 9 | + |
| 10 | + @testset "indexing with no transform" begin |
| 11 | + tr = dd.transform |
| 12 | + set_transform!(dd, identity) |
| 13 | + |
| 14 | + @test_throws MethodError dd[1] |
| 15 | + @test dd["test"] isa Dataset |
| 16 | + d = dd["test"] |
| 17 | + @test pyisinstance(d[1], pytype(pydict())) |
| 18 | + @test d[1]["image"] isa Py |
| 19 | + @test d[1]["label"] isa Py |
| 20 | + @test pyisinstance(d[1]["label"], pytype(pyint())) |
| 21 | + @test py2jl(d[1]["label"]) == 7 |
| 22 | + @test py2jl(d[2]["label"]) == 2 |
| 23 | + |
| 24 | + @test d[1:2] isa Py |
| 25 | + @test d[1:2]["image"] isa Py |
| 26 | + @test pyisinstance(d[1:2]["image"], pytype(pylist())) |
| 27 | + @test d[1:2]["label"] isa Py |
| 28 | + @test pyisinstance(d[1:2]["label"], pytype(pylist())) |
| 29 | + |
| 30 | + set_transform!(dd, tr) |
| 31 | + end |
| 32 | + |
| 33 | + @testset "indexing - py2jl" begin |
| 34 | + @test dd.transform === py2jl |
| 35 | + d = dd["test"] |
| 36 | + sample = d[1] |
| 37 | + @test sample isa Dict |
| 38 | + @test sample["label"] isa Int |
| 39 | + @test sample["label"] == 7 |
| 40 | + @test sample["image"] isa Matrix{UInt8} |
| 41 | + @test size(sample["image"]) == (28, 28) |
| 42 | + |
| 43 | + sample = d[1:2] |
| 44 | + @test sample isa Dict |
| 45 | + @test sample["image"] isa Vector{Matrix{UInt8}} |
| 46 | + @test size(sample["image"]) == (2,) |
| 47 | + @test sample["label"] isa Vector{Int} |
| 48 | + @test size(sample["label"]) == (2,) |
| 49 | + end |
| 50 | + |
| 51 | + @testset "python transforms" begin |
| 52 | + @pyexec """ |
| 53 | + def pytr(x): |
| 54 | + return {"label": [-l for l in x["label"]]} |
| 55 | + """ => pytr |
| 56 | + dd.set_transform(pytr) |
| 57 | + @test dd["test"][1]["label"] == -7 |
| 58 | + end |
| 59 | + |
| 60 | + @testset "getproperty returns julia types" begin |
| 61 | + @test dd.num_rows isa Dict{String, Int} |
| 62 | + @test dd.num_rows == Dict("test" => 10000, "train" => 60000) |
| 63 | + end |
| 64 | +end |
| 65 | + |
0 commit comments