Fix module visitor and mapper trait definition in the book (#2609)
laggui authored Dec 13, 2024
1 parent 9d355ef commit 834ff44
Showing 1 changed file with 24 additions and 13 deletions.
37 changes: 24 additions & 13 deletions burn-book/src/building-blocks/module.md
@@ -89,18 +89,29 @@ You can implement your own mapper or visitor by implementing these simple traits
```rust, ignore
/// Module visitor trait.
pub trait ModuleVisitor<B: Backend> {
-    /// Visit a tensor in the module.
-    fn visit<const D: usize>(&mut self, id: ParamId, tensor: &Tensor<B, D>);
+    /// Visit a float tensor in the module.
+    fn visit_float<const D: usize>(&mut self, id: ParamId, tensor: &Tensor<B, D>);
+    /// Visit an int tensor in the module.
+    fn visit_int<const D: usize>(&mut self, id: ParamId, tensor: &Tensor<B, D, Int>);
+    /// Visit a bool tensor in the module.
+    fn visit_bool<const D: usize>(&mut self, id: ParamId, tensor: &Tensor<B, D, Bool>);
}
/// Module mapper trait.
pub trait ModuleMapper<B: Backend> {
-    /// Map a tensor in the module.
-    fn map<const D: usize>(&mut self, id: ParamId, tensor: Tensor<B, D>) -> Tensor<B, D>;
+    /// Map a float tensor in the module.
+    fn map_float<const D: usize>(&mut self, id: ParamId, tensor: Tensor<B, D>) -> Tensor<B, D>;
+    /// Map an int tensor in the module.
+    fn map_int<const D: usize>(&mut self, id: ParamId, tensor: Tensor<B, D, Int>) -> Tensor<B, D, Int>;
+    /// Map a bool tensor in the module.
+    fn map_bool<const D: usize>(&mut self, id: ParamId, tensor: Tensor<B, D, Bool>) -> Tensor<B, D, Bool>;
}
```

Note that the traits don't require all methods to be implemented, as each already has a default
implementation that performs no operation. If you're only interested in float tensors (as in the
majority of use cases), you can simply implement `map_float` or `visit_float`.
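
For illustration, here is a minimal sketch of what implementing these traits can look like. The
`ParamCounter` and `Scale` types, their fields, and the counting/scaling logic are hypothetical
examples, not part of the book or the library:

```rust, ignore
use burn::module::{ModuleMapper, ModuleVisitor, ParamId};
use burn::tensor::{backend::Backend, Tensor};

/// Counts the number of float parameter elements in a module (illustrative).
struct ParamCounter {
    count: usize,
}

impl<B: Backend> ModuleVisitor<B> for ParamCounter {
    fn visit_float<const D: usize>(&mut self, _id: ParamId, tensor: &Tensor<B, D>) {
        // Accumulate the element count of every float parameter tensor.
        self.count += tensor.shape().num_elements();
    }
    // `visit_int` and `visit_bool` keep their default no-op implementations.
}

/// Scales every float parameter by a constant factor (illustrative).
struct Scale {
    factor: f32,
}

impl<B: Backend> ModuleMapper<B> for Scale {
    fn map_float<const D: usize>(&mut self, _id: ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
        tensor.mul_scalar(self.factor)
    }
}

// Usage sketch (assuming `model` is a value of a type deriving `Module<B>`):
// let mut counter = ParamCounter { count: 0 };
// model.visit(&mut counter);
// let model = model.map(&mut Scale { factor: 0.5 });
```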

## Module Display

Burn provides a simple way to display the structure of a module and its configuration at a glance.
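
As a brief illustration (assuming `model` is a value of a type deriving `Module`), printing the
module is enough to get this view:

```rust, ignore
println!("{}", model);
```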
@@ -182,14 +193,14 @@ Burn comes with built-in modules that you can use to build your own modules.

### Convolutions

-| Burn API          | PyTorch Equivalent   |
-| ----------------- | -------------------- |
-| `Conv1d`          | `nn.Conv1d`          |
-| `Conv2d`          | `nn.Conv2d`          |
-| `Conv3d`          | `nn.Conv3d`          |
-| `ConvTranspose1d` | `nn.ConvTranspose1d` |
-| `ConvTranspose2d` | `nn.ConvTranspose2d` |
-| `ConvTranspose3d` | `nn.ConvTranspose3d` |
+| Burn API          | PyTorch Equivalent             |
+| ----------------- | ------------------------------ |
+| `Conv1d`          | `nn.Conv1d`                    |
+| `Conv2d`          | `nn.Conv2d`                    |
+| `Conv3d`          | `nn.Conv3d`                    |
+| `ConvTranspose1d` | `nn.ConvTranspose1d`           |
+| `ConvTranspose2d` | `nn.ConvTranspose2d`           |
+| `ConvTranspose3d` | `nn.ConvTranspose3d`           |
+| `DeformConv2d`    | `torchvision.ops.DeformConv2d` |

### Pooling
