Standalone hash encoding network? #3

Open
pourion opened this issue Aug 16, 2023 · 2 comments

@pourion

pourion commented Aug 16, 2023

Hello,

Thank you for this amazing work! I am looking for a lot of the elements that you have used in your NeRF application; for example, jax-tcnn is a really useful library on its own. I am not familiar with "nix" and I use Docker for my development environment (https://github.com/JAX-DIPS/JAX-DIPS). Would it be possible to provide instructions for installing only this library in my project? I appreciate any help you could offer.

Thank you,
Pouria

@blurgyy
Owner

blurgyy commented Aug 16, 2023

Hi Pouria,

Thanks for the kind words; JAX-DIPS looks like a very useful library that I may use in the future.

I do not have much time to try out another environment at the moment, but I may be able to come up with a solution around next week.

On another note, if all you need is a usable standalone hashgrid encoder, you can directly copy the code of the HashGridEncoder class from models/encoders:

jaxngp/models/encoders.py (lines 16 to 256 in 1ef676a)

cell_vert_offsets = {
    2: jnp.asarray([
        [0., 0.],
        [0., 1.],
        [1., 0.],
        [1., 1.],
    ]),
    3: jnp.asarray([
        [0., 0., 0.],
        [0., 0., 1.],
        [0., 1., 0.],
        [0., 1., 1.],
        [1., 0., 0.],
        [1., 0., 1.],
        [1., 1., 0.],
        [1., 1., 1.],
    ]),
}

adjacent_offsets = {
    2: jnp.asarray([
        [0., 1.],
        [1., 0.],
        [0., -1.],
        [-1., 0.],
    ]),
    3: jnp.asarray([
        [0., 0., 1.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 0., -1.],
        [0., -1., 0.],
        [-1., 0., 0.],
    ]),
}


class Encoder(nn.Module): ...


# TODO: enforce types used in arrays
@empty_impl
class HashGridEncoder(Encoder):
    # Let's use the same notations as in the paper
    # Number of levels (16).
    L: int
    # Maximum entries per level (hash table size) (2**14 to 2**24).
    T: int
    # Number of feature dimensions per entry (2).
    F: int
    # Coarsest resolution (16).
    N_min: int
    # Finest resolution (512 to 524288).
    N_max: int

    tv_scale: float

    param_dtype: Dtype = jnp.float32

    @property
    def b(self) -> float:
        # Equation(3)
        # Essentially, it is $(n_max / n_min) ** (1/(L - 1))$
        return math.exp((math.log(self.N_max) - math.log(self.N_min)) / (self.L - 1))

    @nn.compact
    def __call__(self, pos: jax.Array, bound: float) -> jax.Array:
        dim = pos.shape[-1]

        # CAVEAT: hashgrid encoder is defined only in the unit cube [0, 1)^3
        pos = (pos + bound) / (2 * bound)

        scales, resolutions, first_hash_level, offsets = [], [], 0, [0]
        for i in range(self.L):
            scale = self.N_min * (self.b**i) - 1
            scales.append(scale)
            res = math.ceil(scale) + 1
            resolutions.append(res)
            n_entries = res ** dim
            if n_entries <= self.T:
                first_hash_level += 1
            else:
                n_entries = self.T
            offsets.append(offsets[-1] + n_entries)

        latents = self.param(
            "latent codes stored on grid vertices",
            # paper:
            #   We initialize the hash table entries using the uniform distribution U(−10^{−4}, 10^{−4})
            #   to provide a small amount of randomness while encouraging initial predictions close
            #   to zero.
            lambda key, shape, dtype: jran.uniform(key, shape, dtype, -1e-4, 1e-4),
            (offsets[-1], self.F),
            self.param_dtype,
        )

        @jax.vmap
        @jax.vmap
        def make_vert_pos(pos_scaled: jax.Array):
            # [dim]
            pos_floored = jnp.floor(pos_scaled)
            # [2**dim, dim]
            vert_pos = pos_floored[None, :] + cell_vert_offsets[dim]
            return vert_pos.astype(jnp.uint32)

        @jax.vmap
        @jax.vmap
        def make_adjacent_pos(pos_scaled: jax.Array):
            # [dim]
            pos_floored = jnp.floor(pos_scaled)
            # [dim * 2, dim]
            adjacent_pos = pos_floored[None, :] + adjacent_offsets[dim]
            return adjacent_pos.astype(jnp.uint32)

        @vmap_jaxfn_with(in_axes=(0, 0))
        @vmap_jaxfn_with(in_axes=(None, 0))
        def make_tiled_indices(res, vert_pos):
            """(first 2 axes `[L, n_points]` are vmapped away)

            Inputs:
                res `uint32` `[L]`: each hierarchy's resolution
                vert_pos `uint32` `[L, n_points, B, dim]`: integer positions of grid cell's
                                                           vertices, of each level

            Returns:
                indices `uint32` `[L, n_points, B]`: grid cell indices of the vertices
            """
            # [dim]
            if dim == 2:
                strides = jnp.stack([jnp.ones_like(res), res]).T
            elif dim == 3:
                strides = jnp.stack([jnp.ones_like(res), res, res ** 2]).T
            else:
                raise NotImplementedError("{} is only implemented for 2D and 3D data".format(__class__.__name__))
            # [2**dim]
            indices = jnp.sum(strides[None, :] * vert_pos, axis=-1)
            return indices

        @jax.vmap
        @jax.vmap
        def make_hash_indices(vert_pos):
            """(first 2 axes `[L, n_points]` are vmapped away)

            Inputs:
                vert_pos `uint32` `[L, n_points, B, dim]`: integer positions of grid cell's
                                                           vertices, of each level

            Returns:
                indices `uint32` `[L, n_points, B]`: grid cell indices of the vertices
            """
            # use primes as reported in the paper
            primes = jnp.asarray([1, 2_654_435_761, 805_459_861], dtype=jnp.uint32)
            # [2**dim]
            if dim == 2:
                indices = vert_pos[:, 0] ^ (vert_pos[:, 1] * primes[1])
            elif dim == 3:
                indices = vert_pos[:, 0] ^ (vert_pos[:, 1] * primes[1]) ^ (vert_pos[:, 2] * primes[2])
            else:
                raise NotImplementedError("{} is only implemented for 2D and 3D data".format(__class__.__name__))
            return indices

        def make_indices(vert_pos, resolutions, first_hash_level):
            if first_hash_level > 0:
                resolutions = jnp.asarray(resolutions, dtype=jnp.uint32)
                indices = make_tiled_indices(resolutions[:first_hash_level], vert_pos[:first_hash_level, ...])
            else:
                indices = jnp.empty(0, dtype=jnp.uint32)
            if first_hash_level < self.L:
                indices = jnp.concatenate([indices, make_hash_indices(vert_pos[first_hash_level:, ...])], axis=0)
            indices = jnp.mod(indices, self.T)
            indices += jnp.asarray(offsets[:-1], dtype=jnp.uint32)[:, None, None]
            return indices

        @jax.vmap
        @jax.vmap
        def lerp_weights(pos_scaled: jax.Array):
            """(first 2 axes `[L, n_points]` are vmapped away)

            Inputs:
                pos_scaled `float` `[L, n_points, dim]`: coordinates of query points, scaled to the
                                                         hierarchy in question

            Returns:
                weights `float` `[L, n_points, 2**dim]`: linear interpolation weights for each cell
                                                         vertex
            """
            # [dim]
            pos_offset, _ = jnp.modf(pos_scaled)
            # [2**dim, dim]
            widths = jnp.clip(
                # cell_vert_offsets: [2**dim, dim]
                (1 - cell_vert_offsets[dim]) + (2 * cell_vert_offsets[dim] - 1) * pos_offset[None, :],
                0,
                1,
            )
            # [2**dim]
            return jnp.prod(widths, axis=-1)

        # [L]
        scales = jnp.asarray(scales, dtype=jnp.float32)
        # [L, n_points, dim]
        pos_scaled = pos[None, :, :] * scales[:, None, None] + 0.5
        # [L, n_points, 2**dim, dim]
        vert_pos = make_vert_pos(pos_scaled)

        # [L, n_points, 2**dim]
        indices = make_indices(vert_pos, resolutions, first_hash_level)
        # [L, n_points, 2**dim, F]
        vert_latents = latents[indices]
        # [L, n_points, 2**dim]
        vert_weights = lerp_weights(pos_scaled)

        # [L, n_points, F]
        encodings = (vert_latents * vert_weights[..., None]).sum(axis=-2)
        # [n_points, L*F]
        encodings = encodings.transpose(1, 0, 2).reshape(-1, self.L * self.F)

        ## Total variation
        if self.tv_scale > 0:
            # [L, n_points, dim * 2, dim]
            adjacent_pos = make_adjacent_pos(pos_scaled)
            # [L, n_points, dim * 2]
            adjacent_indices = make_indices(adjacent_pos, resolutions, first_hash_level)
            # [L, n_points, dim * 2, F]
            adjacent_latents = latents[adjacent_indices]
            # [L, n_points, dim * 2, F]
            tv = self.tv_scale * jnp.square(adjacent_latents - vert_latents[:, :, :1, :])
            # [L, n_points]
            tv = tv.sum(axis=(-2, -1))
            tv = tv.mean()
        else:
            tv = 0

        return encodings, tv

It's a JAX/Flax implementation of the hashgrid encoder, and its access pattern is identical to that of the hashgrid encoder from the tiny-cuda-nn library. The core part is its __call__ method, which only depends on JAX, so you can adapt it even if you are not using Flax. All the NeRF experiments reported in jaxngp's README use this implementation during training and tiny-cuda-nn's hashgrid during inference. I created the jax-tcnn library only as an attempt to speed up rendering during inference (it turned out that the speed difference between this implementation and tiny-cuda-nn's hashgrid is pretty small).
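For reference, a minimal usage sketch could look roughly like the following (the hyper-parameter values, point count, and bound below are only illustrative assumptions, not the settings jaxngp actually ships with):

import jax.random as jran

# Illustrative hyper-parameters (assumptions, not jaxngp's defaults); tune them for your task.
enc = HashGridEncoder(
    L=16,          # number of levels
    T=2**19,       # hash table size per level
    F=2,           # feature dimensions per entry
    N_min=16,      # coarsest resolution
    N_max=2048,    # finest resolution
    tv_scale=0.0,  # disable the total-variation term
)

# Query points live in [-bound, bound]^dim; __call__ rescales them to the unit cube internally.
bound = 1.0
xyz = jran.uniform(jran.PRNGKey(0), (1024, 3), minval=-bound, maxval=bound)

# Standard Flax init/apply: the variables hold the per-level latent tables.
variables = enc.init(jran.PRNGKey(1), xyz, bound)
encodings, tv = enc.apply(variables, xyz, bound)  # encodings has shape [1024, L * F]

With tv_scale=0.0 the returned tv term is just 0 and can be ignored; otherwise it is the total-variation regularizer you can add to your loss.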

Please let me know if that works for you.

Regards,
blurgyy

@pourion
Author

pourion commented Aug 16, 2023

Hi Blurgyy,

Thank you so much for your quick response; this is indeed really helpful. I will follow your instructions and use the code you mentioned. Thank you for your help 🙏

I have been looking around GitHub for some time, and honestly there are so many pieces of your project that stand out! I will keep following your development and learning from you.

Sincerely,
Pouria
