Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow StateIndex to be passed dynamically #843

Merged
merged 1 commit into from
Sep 27, 2024

Conversation

NeilGirdhar
Copy link
Contributor

Fixes #842

@NeilGirdhar
Copy link
Contributor Author

NeilGirdhar commented Sep 11, 2024

If this is acceptable, I'm happy to add some methods if you want to make the code more polished:

class StateIndex:
  def initial_value(self):
    return self.init[0]

  def initial_value_deleted(self):
     return self.init == ()

Copy link
Owner

@patrick-kidger patrick-kidger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As you suggest in #842, I think it'd be better to replace () with a sentinel value that is a empty pytree, just to help emphasise what's going on.

I don't think you need to wrap the initial value into a length-1 tuple? Just the original value should work, I think.

@NeilGirdhar NeilGirdhar changed the title Make StateIndex a PyTree Allow StateIndex to be passed dynamically Sep 11, 2024
@NeilGirdhar NeilGirdhar force-pushed the stateindex branch 4 times, most recently from 7db8934 to 572b72f Compare September 12, 2024 16:19
@NeilGirdhar
Copy link
Contributor Author

Errors fixed, but I have some questions about the code. I don't understand why StateIndex.marker's sentinel can't be changed to None. Is it because of replacement with eqx.combine, etc.? Also, in State.set, why can't we assert that item.marker is an integer? I guess I don't understand what the code is doing.

@patrick-kidger
Copy link
Owner

Indeed, I think .marker being an object() might be unnecessary. I think there is a potential footgun here if such a 'raw' StateIndex is passed to State though -- with object() then we'd at least get a unique dictionary key, but with None then they'd all overwrite each other. Probably the appropriate solution is indeedto explicitly raise an error if such a 'raw' StateIndex is passed to State -- i.e. your suggestion of asserting that it is an integer.

@NeilGirdhar
Copy link
Contributor Author

then we'd at least get a unique dictionary key, but with None then they'd all overwrite each other.

Interesting, okay. You may want to consider adding a comment if that's behavior that you're counting on in for some use cases. (Sorry I've been struggling with COVID all week, and my brain's a bit slower than usual.)

i.e. your suggestion of asserting that it is an integer.

I tried that, but couldn't get it to pass the tests.

Anyway, I do love how the interface of State protects users from setting twice on the same state, or trying to use an expired state. Nice user-facing design.

@patrick-kidger
Copy link
Owner

patrick-kidger commented Sep 12, 2024

You may want to consider adding a comment if that's behavior that you're counting on in for some use cases.

Honestly, I'm not sure that'd be good behaviour to rely on... ! It's certainly not the usual path. I'd be happy to change that without considering it a compatibility break.

Lmk where you land on all of this + when you want a review of this PR. (Once it's passing tests.)

@NeilGirdhar
Copy link
Contributor Author

Honestly, I'm not sure that'd be good behaviour to rely on... !

Great! If I have time, I'll take a look at this again.

Lmk where you land on all of this + when you want a review of this PR. (Once it's passing tests.)

Hmmm, it passes the tests on my machine on Python 3.11 (the failing test), and 3.12. I'm not sure how to debug this. Do you have any insight into this by any chance?

@patrick-kidger
Copy link
Owner

It does seem a bit weird!
I do note that JAX recently did a new release, which has apparently since been yanked. Possibly something to do with that new release?

@NeilGirdhar
Copy link
Contributor Author

NeilGirdhar commented Sep 13, 2024

I do note that JAX recently did a new release, which has apparently since been yanked. Possibly something to do with that new release?

I'm not sure, but I tested with the previous release Jax 0.4.31, whereas the test ran with the new release. I'll re-run the job.

But the error

Closure-converted function called with different dynamic arguments to the example arguments provided.

is related to Equinox, right? I'm not sure how closure conversion works since I haven't used it yet.

@NeilGirdhar
Copy link
Contributor Author

(Looks like this passes now.)

@NeilGirdhar
Copy link
Contributor Author

@patrick-kidger Do you have time to take a look at this?

@patrick-kidger
Copy link
Owner

Yup, I do! Have been otherwise engaged this past week. I expect to have a look at this in the next couple of days :)

@NeilGirdhar
Copy link
Contributor Author

No worries, take your time :)

@patrick-kidger patrick-kidger merged commit f687b9f into patrick-kidger:main Sep 27, 2024
2 checks passed
@patrick-kidger
Copy link
Owner

Okay, LGTM -- merged! Thank you for the contribution :)

@NeilGirdhar NeilGirdhar deleted the stateindex branch September 27, 2024 15:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

StateIndex is a Module, but not a PyTree
2 participants