diff --git a/burn-book/src/building-blocks/module.md b/burn-book/src/building-blocks/module.md index 8397bc56ff..128a99a7b6 100644 --- a/burn-book/src/building-blocks/module.md +++ b/burn-book/src/building-blocks/module.md @@ -89,18 +89,29 @@ You can implement your own mapper or visitor by implementing these simple traits ```rust, ignore /// Module visitor trait. pub trait ModuleVisitor { - /// Visit a tensor in the module. - fn visit(&mut self, id: ParamId, tensor: &Tensor); + /// Visit a float tensor in the module. + fn visit_float(&mut self, id: ParamId, tensor: &Tensor); + /// Visit an int tensor in the module. + fn visit_int(&mut self, id: ParamId, tensor: &Tensor); + /// Visit a bool tensor in the module. + fn visit_bool(&mut self, id: ParamId, tensor: &Tensor); } /// Module mapper trait. pub trait ModuleMapper { - /// Map a tensor in the module. - fn map(&mut self, id: ParamId, tensor: Tensor) -> - Tensor; + /// Map a float tensor in the module. + fn map_float(&mut self, id: ParamId, tensor: Tensor) -> Tensor; + /// Map an int tensor in the module. + fn map_int(&mut self, id: ParamId, tensor: Tensor) -> Tensor; + /// Map a bool tensor in the module. + fn map_bool(&mut self, id: ParamId, tensor: Tensor) -> Tensor; } ``` +Note that the trait doesn't require all methods to be implemented as they are already defined to +perform no operation. If you're only interested in float tensors (like the majority of use cases), +then you can simply implement `map_float` or `visit_float`. + ## Module Display Burn provides a simple way to display the structure of a module and its configuration at a glance. @@ -182,14 +193,14 @@ Burn comes with built-in modules that you can use to build your own modules. ### Convolutions -| Burn API | PyTorch Equivalent | -| ----------------- | -------------------- | -| `Conv1d` | `nn.Conv1d` | -| `Conv2d` | `nn.Conv2d` | -| `Conv3d` | `nn.Conv3d` | -| `ConvTranspose1d` | `nn.ConvTranspose1d` | -| `ConvTranspose2d` | `nn.ConvTranspose2d` | -| `ConvTranspose3d` | `nn.ConvTranspose3d` | +| Burn API | PyTorch Equivalent | +| ----------------- | ------------------------------ | +| `Conv1d` | `nn.Conv1d` | +| `Conv2d` | `nn.Conv2d` | +| `Conv3d` | `nn.Conv3d` | +| `ConvTranspose1d` | `nn.ConvTranspose1d` | +| `ConvTranspose2d` | `nn.ConvTranspose2d` | +| `ConvTranspose3d` | `nn.ConvTranspose3d` | | `DeformConv2d` | `torchvision.ops.DeformConv2d` | ### Pooling