Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fast array extraction #7227

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open

Conversation

alex-hh
Copy link
Contributor

@alex-hh alex-hh commented Oct 14, 2024

Implements #7210 using method suggested in #7207 (comment)

import numpy as np
from datasets import Dataset, Features, Array3D
features=Features(**{"array0": Array3D((None, 10, 10), dtype="float32"), "array1": Array3D((None,10,10), dtype="float32")})
dataset = Dataset.from_dict({f"array{i}": [np.zeros((x,10,10), dtype=np.float32) for x in [2000,1000]*25] for i in range(2)}, features=features)

~0.02 s vs 0.9s on main

ds = dataset.to_iterable_dataset()
t0 = time.time()
for ex in ds:
    pass
t1 = time.time()

< 0.01 s vs 1.3 s on main

@lhoestq I can see this breaks a bunch of array-related tests but can update the test cases if you would support making this change?

I also added an Array1D feature which will always be decoded into a numpy array and likewise improves extraction performance:

from datasets import Dataset, Features, Array1D, Sequence, Value
array_features=Features(**{"array0": Array1D((None,), dtype="float32"), "array1": Array1D((None,), dtype="float32")})
sequence_features=Features(**{"array0": Sequence(feature=Value("float32"), length=-1), "array1": Sequence(feature=Value("float32"), length=-1)})
array_dataset = Dataset.from_dict({f"array{i}": [np.zeros((x,), dtype=np.float32) for x in [20000,10000]*25] for i in range(2)}, features=array_features)
sequence_dataset = Dataset.from_dict({f"array{i}": [np.zeros((x,), dtype=np.float32) for x in [20000,10000]*25] for i in range(2)}, features=sequence_features)


```python
t0 = time.time()
for ex in array_dataset.to_iterable_dataset():
    pass
t1 = time.time()

< 0.01 s

t0 = time.time()
for ex in sequence_dataset.to_iterable_dataset():
    pass
t1 = time.time()

~1.1s

And also added support for extracting structs of arrays as dicts of numpy arrays:

import numpy as np
from datasets import Dataset, Features, Array3D, Sequence
features=Features(struct={"array0": Array3D((None,10,10), dtype="float32"), "array1": Array3D((None,10,10), dtype="float32")}, _list=Sequence(feature=Array3D((None,10,10), dtype="float32")))
dataset = Dataset.from_dict({"struct": [{f"array{i}": np.zeros((x,10,10), dtype=np.float32) for i in range(2)} for x in [2000,1000]*25], "_list": [[np.zeros((x,10,10), dtype=np.float32) for i in range(2)] for x in [2000,1000]*25]}, features=features)
t0 = time.time()
for ex in dataset.to_iterable_dataset():
    pass
t1 = time.time()
assert isinstance(ex["struct"]["array0"], np.ndarray) and ex["struct"]["array0"].ndim == 3

~0.02 s and no exception vs ~7s with an exception on main

Copy link
Member

@lhoestq lhoestq left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice ! sure feel free to update the tests

Comment on lines +196 to +197
if pa.types.is_struct(pa_array.field(field.name).type):
batch[field.name] = extract_struct_array(pa_array.field(field.name))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can also check if it's a list or large_list type

Copy link
Contributor Author

@alex-hh alex-hh Oct 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked that lists of ArrayExtensionType features will call ArrayExtensionArray.to_pylist(), which didn't seem to be the case for struct, and is the main performance issue there

Not sure about large list?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cool ! maybe also check list of struct of ArrayExtensionType but no big deal, we can fix that rare case later (large list is also rare)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the list of struct case might require an ArrayExtensionScalar or something with an as_py method that returns a numpy object.

Seems like it could be useful but have no idea whether this is possible or how best to do it if so?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unless you know how to do this could we leave as issue?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe just add a TODO comment about it for now ?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@alex-hh alex-hh force-pushed the fast-array-extraction branch from 550d2f0 to a89ef52 Compare October 15, 2024 16:19
@alex-hh
Copy link
Contributor Author

alex-hh commented Oct 15, 2024

I've updated the most straightforward failing test cases - lmk if you agree with those.

Might need some help / pointers on the remaining new failing tests, which seem a little bit more subtle.

@alex-hh
Copy link
Contributor Author

alex-hh commented Oct 18, 2024

@lhoestq I've had a go at fixing a few more test cases but getting quite uncertain about the remaining ones (as well as about some of the array writing ones that I tried to fix in my last commit). There are still 27 failures vs 21 on main. I'm not completely sure in some cases what intended behaviour is and my understanding of the flow for typed writing is a bit vague.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants