API for Forwards, Backwards, Central Finite Difference #127
I found the winning formula for the simple 1st-order FD backwards scheme:

```python
u = FiniteDifferences.from_grid(u, domain)
u.accuracy = 2
u_rhs = -c * gradient(u, stagger=[1])
```

which generates the coefficients:

```python
from jaxdf.conv import fd_coefficients_fornberg

grid_points = [1, 0]
x0 = 0.0
order = 1

stencil, nodes = fd_coefficients_fornberg(order, grid_points, x0)
```

which produces:
```python
# stencil, nodes
(array([-1., 1.]), array([0, 1]))
```

which is equivalent to the first-order difference (u[i+1] - u[i]) / Δx, where Δx is the grid spacing.

What I got wrong was the stagger. When I read stagger, my intuition was a staggered grid, not a staggered stencil, so I initially used the wrong stagger value. (See the short NumPy sketch after the proposed API below.)

What I found helped me understand was to expose all of the pieces when generating the FD kernel. For example:
```python
# generate nodes based on order, accuracy, method and stagger
nodes = get_fd_nodes(
    derivative=1,        # int
    accuracy=1,          # int
    method="central",    # str
    stagger=0,           # int
)
# get coefficients
coeffs = get_fd_coeffs(nodes, derivative=1)
# generate FD kernel (optional)
kernel = get_fd_kernel(coeffs, domain)
```
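As an illustration of the stencil-vs-grid distinction, here is a minimal NumPy sketch (not jaxdf code): the same `[-1, 1]` stencil gives a forward or a backward difference depending only on which grid node the result is attributed to.

```python
# Minimal NumPy sketch (not jaxdf code): the stencil [-1, 1] is a single kernel;
# whether it acts as a "forward" or "backward" difference depends on which node
# the result is assigned to, i.e. on the stagger of the stencil.
import numpy as np

dx = 0.1
x = np.arange(0.0, 1.0, dx)
u = np.sin(x)

diff = (u[1:] - u[:-1]) / dx   # stencil [-1, 1] applied along the grid

forward = diff    # read as du/dx at x[:-1]: uses nodes [0, 1]  (forward difference)
backward = diff   # read as du/dx at x[1:]:  uses nodes [-1, 0] (backward difference)
```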
Hi, and thanks for this! I will try to answer the main points, but let me know if there's something I'm missing.
I think you already figured this out, but it clearly means that there's a documentation page missing :) In general, one can do this:

```python
import jax.numpy as jnp
from jaxdf import FiniteDifferences, operators as jops  # imports assumed from context

u = FiniteDifferences(jnp.zeros((128,)), domain)        # Declare the field
u.accuracy = 4                                          # Set the accuracy order
params = jops.gradient.default_params(u, stagger=[1])   # Choose grid staggering and get the stencil
```

That returns the operator's parameters as a PyTree: a list whose first element is the array of stencil coefficients.
In general, every operator has the `default_params` function, and the returned parameters can be modified and passed back to the operator, for example:

```python
def gradient_with_modified_kernel(u, new_value):
    params = jops.gradient.default_params(u, stagger=[1])
    new_params = [params[0].at[4].set(new_value)]  # <-- modify using jax methods, and keep the same PyTree structure
    return gradient(u, params=new_params)          # <-- apply the operator with the modified parameters
```
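A possible usage sketch (the value `0.5` is arbitrary, purely for illustration):

```python
# Usage sketch: replace the 5th stencil coefficient with an arbitrary value and apply the operator.
du_dx = gradient_with_modified_kernel(u, new_value=0.5)
```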
The idea of the `params` PyTree is to hold the numerical values the operator uses, so that they can be inspected, modified, or differentiated. For static arguments, like the FD `method`, it probably makes more sense to expose them directly in the operator signature, something like:

```python
def gradient(
    u: FiniteDifferences,
    *,
    accuracy: int = 4,   # or, really, this should probably be `order`
    method: str = "central",
    stagger=[0],
):
    ...
```
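Just to make the `method` idea concrete, here is a hypothetical sketch (not existing jaxdf code) of how the scheme could map onto Fornberg grid points for a first derivative, reusing `fd_coefficients_fornberg` from the comment above:

```python
# Hypothetical sketch: map a `method` string to relative grid points for a
# first-derivative stencil, then get the weights from Fornberg's algorithm.
from jaxdf.conv import fd_coefficients_fornberg

def scheme_grid_points(method: str, accuracy: int) -> list:
    """Relative grid points for a 1st-derivative stencil (illustration only)."""
    if method == "forward":
        return list(range(0, accuracy + 1))
    if method == "backward":
        return list(range(-accuracy, 1))
    # central: assumes an even accuracy
    half = accuracy // 2
    return list(range(-half, half + 1))

stencil, nodes = fd_coefficients_fornberg(1, scheme_grid_points("backward", 1), 0.0)
# -> stencil ~ [-1., 1.] on nodes [-1, 0], i.e. the backward difference
```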
That should indeed be a staggered grid, but I always get myself confused with how kernels are applied in convolutions, correlations, etc. :) It probably makes sense to write a quick test checking the returned stencils against the ones from FiniteDiffX? Where do we place the staggered points?
I would like to be able to control the finite difference scheme used, i.e. forward, backward or central. Depending on the PDE, we normally use a custom scheme, e.g. advection → backward, diffusion → central.
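As a concrete illustration of why the scheme matters, here is a minimal sketch (not jaxdf code; it assumes `c > 0` and periodic boundaries via `jnp.roll`) of a first-order upwind (backward) step for advection next to a central step for diffusion:

```python
import jax.numpy as jnp

def upwind_advection_rhs(u, c, dx):
    # backward (upwind for c > 0) difference: (u_i - u_{i-1}) / dx
    return -c * (u - jnp.roll(u, 1)) / dx

def central_diffusion_rhs(u, nu, dx):
    # central second difference: (u_{i+1} - 2*u_i + u_{i-1}) / dx**2
    return nu * (jnp.roll(u, -1) - 2 * u + jnp.roll(u, 1)) / dx**2
```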
Working Demo
I have a working Colab notebook to get a feeling for what I mean. See it here.
Proposed Solution
I don't have a solution, but somewhere in the `param` PyTree I think it is important to specify this (just like the accuracy, order, stepsize, etc.).

Another possible solution: one could use the FiniteDiffX package as a backend for generating the coefficients and the kernel if one doesn't specify them. I recently contributed there to make it possible to specify the FD scheme.

Last solution: just create a custom operator that does exactly what I've described above. There is an example in the "custom equation of motion" section which does exactly what I want.