Skip to content

Commit

Permalink
add DatasetDict (#7)
Browse files Browse the repository at this point in the history
* add DatasetDict

* remove python bounds

* remove huggingface channel

* pin pyarrow to 6.0.0

* use == instead of = in CondaPkg

* relax numpy

* relax pillow
  • Loading branch information
CarloLucibello authored Dec 23, 2022
1 parent c59197a commit f74d35a
Show file tree
Hide file tree
Showing 8 changed files with 146 additions and 15 deletions.
8 changes: 4 additions & 4 deletions CondaPkg.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
channels = ["conda-forge", "huggingface"]
channels = ["conda-forge"]

[deps]
datasets = ">=2.7, <3"
numpy = ">=1.23, <2"
pillow = ">=9.2, <10"
python = ">=3.6, <4"
numpy = ">=1.20, <2"
pillow = ">=9.1, <10"
pyarrow = "==6.0.0"
5 changes: 5 additions & 0 deletions src/HuggingFaceDatasets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ include("observation.jl")
include("dataset.jl")
export Dataset, set_transform!

include("datasetdict.jl")
export DatasetDict

include("transforms.jl")
export py2jl

Expand All @@ -25,6 +28,8 @@ function load_dataset(args...; kws...)
d = datasets.load_dataset(args...; kws...)
if pyisinstance(d, datasets.Dataset)
return Dataset(d)
elseif pyisinstance(d, datasets.DatasetDict)
return DatasetDict(d)
else
return d
end
Expand Down
21 changes: 12 additions & 9 deletions src/dataset.jl
Original file line number Diff line number Diff line change
@@ -1,38 +1,41 @@
"""
Dataset(dataset, transform = py2jl)
Dataset(pydataset; transform = py2jl)
A Julia wrapper around the python `datasets.Dataset` type.
It is the return type of [`load_dataset`](@ref).
A Julia wrapper around the objects of the python `datasets.Dataset` class.
The `transform` is applied after datasets' one.
The [`py2jl`](@def) default converts python types to julia types.
The [`py2jl`](@ref) default converts python types to julia types.
Provides:
- 1-based indexing.
- [`set_transform!`](@ref) julia method.
- All python class' methods from `datasets.Dataset`.
See also [`load_dataset`](@ref) and [`DatasetDict`](@ref).
"""
mutable struct Dataset
pyd::Py
transform
end

Dataset(pydataset::Py; transform = py2jl) = Dataset(pydataset, transform)
function Dataset(pydataset::Py; transform = py2jl)
@assert pyisinstance(pydataset, datasets.Dataset)
return new(pydataset, transform)
end
end

function Base.getproperty(d::Dataset, s::Symbol)
if s in fieldnames(Dataset)
return getfield(d, s)
else
res = getproperty(getfield(d, :pyd), s)
if pyisinstance(res, datasets.Dataset)
return Dataset(res, d.transform)
return Dataset(res; d.transform)
else
return res
return res |> py2jl
end
end
end


Base.length(d::Dataset) = length(d.pyd)

Base.getindex(d::Dataset, ::Colon) = d[1:length(d)]
Expand Down
50 changes: 50 additions & 0 deletions src/datasetdict.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""
DatasetDict(pydatasetdict::Py; transform = py2jl)
A `DatasetDict` is a dictionary of `Dataset`s. It is a wrapper around a `datasets.DatasetDict` object.
The `transform` is applied to each [`Dataset`](@ref).
The [`py2jl`](@ref) default converts python types to julia types.
See also [`load_dataset`](@ref) and [`Dataset`](@ref).
"""
mutable struct DatasetDict
pyd::Py
transform

function DatasetDict(pydatasetdict::Py; transform = py2jl)
@assert pyisinstance(pydatasetdict, datasets.DatasetDict)
return new(pydatasetdict, transform)
end
end

function Base.getproperty(d::DatasetDict, s::Symbol)
if s in fieldnames(DatasetDict)
return getfield(d, s)
else
res = getproperty(getfield(d, :pyd), s)
if pyisinstance(res, datasets.Dataset)
return Dataset(res; d.transform)
elseif pyisinstance(res, datasets.DatasetDict)
return DatasetDict(res; d.transform)
else
return res |> py2jl
end
end
end

Base.length(d::DatasetDict) = length(d.pyd)

function Base.getindex(d::DatasetDict, i::AbstractString)
x = d.pyd[i]
return Dataset(x; d.transform)
end

function set_transform!(d::DatasetDict, transform)
if transform === nothing
d.transform = identity
else
d.transform = transform
end
end

2 changes: 1 addition & 1 deletion src/transforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ function tojulia(x::Py)
end

tojulia(x::PyList) = [py2jl(x) for x in x]
tojulia(x::PyDict) = Dict(py2jl(k) => py2jl(v) for (k, v) in pairs(x))
tojulia(x::PyDict) = Dict(py2jl(k) => py2jl(v) for (k, v) in pairs(x))
4 changes: 4 additions & 0 deletions test/datasets.jl → test/dataset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@
d.set_transform(pytr)
@test d[1]["label"] == -7
end

@testset "getproperty returns julia types" begin
@test d.num_rows isa Int
end
end


Expand Down
65 changes: 65 additions & 0 deletions test/datasetdict.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@

@testset "MNIST" begin
dd = load_dataset("mnist")

@testset "load_dataset" begin
@test dd isa DatasetDict
@test length(dd) == 2
end

@testset "indexing with no transform" begin
tr = dd.transform
set_transform!(dd, identity)

@test_throws MethodError dd[1]
@test dd["test"] isa Dataset
d = dd["test"]
@test pyisinstance(d[1], pytype(pydict()))
@test d[1]["image"] isa Py
@test d[1]["label"] isa Py
@test pyisinstance(d[1]["label"], pytype(pyint()))
@test py2jl(d[1]["label"]) == 7
@test py2jl(d[2]["label"]) == 2

@test d[1:2] isa Py
@test d[1:2]["image"] isa Py
@test pyisinstance(d[1:2]["image"], pytype(pylist()))
@test d[1:2]["label"] isa Py
@test pyisinstance(d[1:2]["label"], pytype(pylist()))

set_transform!(dd, tr)
end

@testset "indexing - py2jl" begin
@test dd.transform === py2jl
d = dd["test"]
sample = d[1]
@test sample isa Dict
@test sample["label"] isa Int
@test sample["label"] == 7
@test sample["image"] isa Matrix{UInt8}
@test size(sample["image"]) == (28, 28)

sample = d[1:2]
@test sample isa Dict
@test sample["image"] isa Vector{Matrix{UInt8}}
@test size(sample["image"]) == (2,)
@test sample["label"] isa Vector{Int}
@test size(sample["label"]) == (2,)
end

@testset "python transforms" begin
@pyexec """
def pytr(x):
return {"label": [-l for l in x["label"]]}
""" => pytr
dd.set_transform(pytr)
@test dd["test"][1]["label"] == -7
end

@testset "getproperty returns julia types" begin
@test dd.num_rows isa Dict{String, Int}
@test dd.num_rows == Dict("test" => 10000, "train" => 60000)
end
end

6 changes: 5 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,9 @@ using HuggingFaceDatasets, PythonCall, MLUtils
# using ImageShow, ImageInTerminal

@testset "dataset" begin
include("datasets.jl")
include("dataset.jl")
end

@testset "datasetdict" begin
include("datasetdict.jl")
end

2 comments on commit f74d35a

@CarloLucibello
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/74572

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.0 -m "<description of version>" f74d35ad01d80726e3c0f63866c5a87d1b275892
git push origin v0.1.0

Please sign in to comment.