Merge pull request #242 from FluxML/a2/docs-rehaul-1
Docs rehaul for v0.8
theabhirath authored Jun 6, 2023
2 parents 010d4bc + defba66 commit ec27452
Showing 25 changed files with 215 additions and 114 deletions.
32 changes: 0 additions & 32 deletions Artifacts.toml
@@ -1,35 +1,3 @@
[densenet121]
git-tree-sha1 = "e929cd4de9255f65fe85e9365a0223e5876e7fb0"
lazy = true

[[densenet121.download]]
sha256 = "0e9259a45097ded4d79a7b30d41ded7d799c7b08134ca65c84635b92100ce440"
url = "https://huggingface.co/FluxML/densenet121/resolve/442cd2e85d63b3aa7340deb786a02c2ccffa1590/densenet121.tar.gz"

[densenet161]
git-tree-sha1 = "d780a8eb8667bb2a840ed0685d937d311490aa1e"
lazy = true

[[densenet161.download]]
sha256 = "5eb8a7c69c353a7e5756dfc041850e59b34b97821789fea8f73f1eff2d7c10d8"
url = "https://huggingface.co/FluxML/densenet161/resolve/dd02e5e0562620d187b2fd9b3a0994b0e2d85054/densenet161.tar.gz"

[densenet169]
git-tree-sha1 = "b8cc68d9554f8274d0c3e0385e04d2ee41090831"
lazy = true

[[densenet169.download]]
sha256 = "8b9b182670ea6053f51db1542c133f183b4c0aa45f56e635598994a2f77fa39a"
url = "https://huggingface.co/FluxML/densenet169/resolve/9868bf1fb24592b6d7ee9aca70844eda859a3d3a/densenet169.tar.gz"

[densenet201]
git-tree-sha1 = "8d581b667da3e64a557a251876ebde6ee115a504"
lazy = true

[[densenet201.download]]
sha256 = "4d6c6bae58ff586a82c8a4941cfba1709d1b1aec013fa3b188ce60df3e23e2aa"
url = "https://huggingface.co/FluxML/densenet201/resolve/4a39454fa8039b8d0546d8e60e2c98aa58e7a421/densenet201.tar.gz"

[resnet101]
git-tree-sha1 = "68d563526ab34d3e5aa66b7d96278d2acde212f9"
lazy = true
2 changes: 1 addition & 1 deletion README.md
@@ -27,7 +27,7 @@ To contribute new models, see our [contributing docs](https://fluxml.ai/Metalhea
| [AlexNet](https://papers.nips.cc/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf) | [`AlexNet`](https://fluxml.ai/Metalhead.jl/dev/api/other/#Metalhead.AlexNet) | N |
| [ConvMixer](https://arxiv.org/abs/2201.09792) | [`ConvMixer`](https://fluxml.ai/Metalhead.jl/dev/api/hybrid/#Metalhead.ConvMixer) | N |
| [ConvNeXt](https://arxiv.org/abs/2201.03545) | [`ConvNeXt`](https://fluxml.ai/Metalhead.jl/dev/api/hybrid/#Metalhead.ConvNeXt) | N |
| [DenseNet](https://arxiv.org/abs/1608.06993) | [`DenseNet`](https://fluxml.ai/Metalhead.jl/dev/api/densenet/#Metalhead.DenseNet) | Y |
| [DenseNet](https://arxiv.org/abs/1608.06993) | [`DenseNet`](https://fluxml.ai/Metalhead.jl/dev/api/densenet/#Metalhead.DenseNet) | N |
| [EfficientNet](https://arxiv.org/abs/1905.11946) | [`EfficientNet`](https://fluxml.ai/Metalhead.jl/dev/api/efficientnet/#Metalhead.EfficientNet) | N |
| [EfficientNetv2](https://arxiv.org/abs/2104.00298) | [`EfficientNetv2`](https://fluxml.ai/Metalhead.jl/dev/api/efficientnet/#Metalhead.EfficientNetv2) | N |
| [gMLP](https://arxiv.org/abs/2105.08050) | [`gMLP`](https://fluxml.ai/Metalhead.jl/dev/api/mixers/#Metalhead.gMLP) | N |
3 changes: 2 additions & 1 deletion docs/make.jl
@@ -24,6 +24,7 @@ makedocs(; modules = [Metalhead, Artifacts, LazyArtifacts, Images, DataAugmentat
"api/resnet.md",
"api/densenet.md",
"api/efficientnet.md",
"api/mobilenet.md",
"api/inception.md",
"api/hybrid.md",
"api/others.md",
@@ -35,7 +36,7 @@ makedocs(; modules = [Metalhead, Artifacts, LazyArtifacts, Images, DataAugmentat
"api/vit.md",
],
"Layers" => "api/layers.md",
"Utilities" => "api/utilities.md",
"Model Utilities" => "api/utilities.md",
],
],
format = Documenter.HTML(; canonical = "https://fluxml.ai/Metalhead.jl/stable/",
10 changes: 4 additions & 6 deletions docs/src/api/efficientnet.md
@@ -1,10 +1,8 @@
# Efficient Networks
# EfficientNet family of models

This is the API reference for the EfficientNet family of models supported by Metalhead.jl.

```@docs
EfficientNet
EfficientNetv2
MobileNetv1
MobileNetv2
MobileNetv3
MNASNet
```
2 changes: 1 addition & 1 deletion docs/src/api/hybrid.md
@@ -14,4 +14,4 @@ ConvNeXt
```@docs
Metalhead.convmixer
Metalhead.convnext
```
2 changes: 1 addition & 1 deletion docs/src/api/inception.md
@@ -1,4 +1,4 @@
# Inception models
# Inception family of models

This is the API reference for the Inception family of models supported by Metalhead.jl.

92 changes: 87 additions & 5 deletions docs/src/api/layers.md
@@ -1,6 +1,6 @@
# Layers

Metalhead also defines a module called `Layers` which contains some more modern layers that are not available in Flux. To use the functions defined in the `Layers` module, you need to import it.
Metalhead also defines a module called `Layers`, which contains custom layers used to build the models in Metalhead. These layers are not available in Flux at present. To use the functions defined in the `Layers` module, you need to import it:

```julia
using Metalhead: Layers
@@ -10,9 +10,91 @@ This page contains the API reference for the `Layers` module.

!!! warning

The `Layers` module is still a work in progress. While we will endeavour to keep the API stable, we cannot guarantee that it will not change in the future. If you find any of the functions in this
module do not work as expected, please open an issue on GitHub.
The `Layers` module is still a work in progress. While we will endeavour to keep the API stable, we cannot guarantee that it will not change in the future. If you find any of the functions in this module do not work as expected, please open an issue on GitHub.

```@autodocs
Modules = [Metalhead.Layers]
## Convolution + BatchNorm layers

```@docs
Metalhead.Layers.conv_norm
Metalhead.Layers.basic_conv_bn
```
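
As a quick illustration, here is a minimal sketch of building a convolution + norm block with `conv_norm`. The argument order and the `pad` keyword shown here are assumptions on our part; the docstring above is the authoritative reference.

```julia
using Flux
using Metalhead: Layers

# conv_norm returns a vector of layers (a convolution followed by a
# normalisation layer), so we splat it into a Chain.
# Assumed signature: conv_norm(kernel_size, inplanes, outplanes, activation; kwargs...)
layers = Layers.conv_norm((3, 3), 3, 16, relu; pad = 1)
block = Chain(layers...)

x = rand(Float32, 32, 32, 3, 1)
size(block(x))  # (32, 32, 16, 1)
```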

## Convolution-related custom blocks

These blocks are designed to be used in convolutional neural networks. Most of these are used in the MobileNet and EfficientNet families of models, but they also feature in "fancier" versions of well-known models like ResNet (SE-ResNet).

```@docs
Metalhead.Layers.dwsep_conv_norm
Metalhead.Layers.mbconv
Metalhead.Layers.fused_mbconv
Metalhead.Layers.squeeze_excite
Metalhead.Layers.effective_squeeze_excite
```
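
As a hedged example, here is how one of these blocks might be constructed standalone; the positional channel argument to `squeeze_excite` is an assumption on our part, so consult the docstrings above for the authoritative signatures.

```julia
using Metalhead: Layers

# A squeeze-and-excitation block over 64 channels; channel-wise
# reweighting preserves the input shape. Assumed signature:
# squeeze_excite(inplanes; reduction = 16).
se = Layers.squeeze_excite(64)

x = rand(Float32, 16, 16, 64, 1)
size(se(x))  # (16, 16, 64, 1)
```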

## Normalisation, Dropout and Pooling layers

Metalhead provides various custom layers for normalisation, dropout and pooling, which are used to further customise various models.

### Normalisation layers

```@docs
Metalhead.Layers.ChannelLayerNorm
Metalhead.Layers.LayerNormV2
Metalhead.Layers.LayerScale
```
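
For instance, `LayerScale` is used in the ConvNeXt block later in this commit as `LayerScale(planes, 1.0f-6)`. A minimal sketch of constructing it standalone:

```julia
using Metalhead: Layers

# LayerScale learns a per-feature scaling of its input, initialised to a
# small value; it broadcasts over the first (feature) dimension, matching
# its placement after swapdims in the ConvNeXt block.
scale = Layers.LayerScale(64, 1.0f-6)

x = rand(Float32, 64, 10)  # features × batch
size(scale(x))  # (64, 10)
```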

### Dropout layers

```@docs
Metalhead.Layers.DropBlock
Metalhead.Layers.dropblock
Metalhead.Layers.StochasticDepth
```

### Pooling layers

```@docs
Metalhead.Layers.AdaptiveMeanMaxPool
```
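
This layer also appears in the ResNet how-to later in this commit as `Layers.AdaptiveMeanMaxPool((1, 1))`. A minimal sketch (how the mean- and max-pooled results are combined is documented in the docstring above):

```julia
using Metalhead: Layers

# Adaptively pools the spatial dimensions down to (1, 1), combining
# mean- and max-pooled features.
pool = Layers.AdaptiveMeanMaxPool((1, 1))

x = rand(Float32, 7, 7, 512, 1)
y = pool(x)  # spatial dimensions reduced to (1, 1)
```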

## Classifier creation

Metalhead provides a flexible function for creating the classifier of a neural network model, which the library uses extensively to create the classifier "head" for its networks.

```@docs
Metalhead.Layers.create_classifier
```

## Vision transformer-related layers

The `Layers` module contains specific layers that are used to build vision transformer (ViT)-inspired models:

```@docs
Metalhead.Layers.MultiHeadSelfAttention
Metalhead.Layers.ClassTokens
Metalhead.Layers.ViPosEmbedding
Metalhead.Layers.PatchEmbedding
```

## MLPMixer-related blocks

Apart from this, the `Layers` module also contains certain blocks used in MLPMixer-style models:

```@docs
Metalhead.Layers.gated_mlp_block
Metalhead.Layers.mlp_block
```
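
The ConvNeXt block later in this commit uses `mlp_block(planes, 4 * planes)`; a minimal sketch along those lines (the 3D input layout is an assumption on our part):

```julia
using Metalhead: Layers

# A feedforward block expanding 256 features to a hidden width of 1024
# and projecting back, mirroring mlp_block(planes, 4 * planes) in ConvNeXt.
block = Layers.mlp_block(256, 1024)

x = rand(Float32, 256, 196, 1)  # features × tokens × batch
size(block(x))  # (256, 196, 1)
```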

## Utilities for layers

These are miscellaneous utilities in the `Layers` module, used alongside other custom and built-in layers to make common operations in neural networks easier.

```@docs
Metalhead.Layers.inputscale
Metalhead.Layers.actadd
Metalhead.Layers.addact
Metalhead.Layers.cat_channels
Metalhead.Layers.flatten_chains
Metalhead.Layers.linear_scheduler
Metalhead.Layers.swapdims
```
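
Two of these utilities appear elsewhere in this commit: `swapdims` in the ConvNeXt block and `cat_channels` for DenseNet-style concatenation. A minimal sketch of both:

```julia
using Metalhead: Layers

# cat_channels concatenates arrays along the channel (third) dimension.
x = rand(Float32, 4, 4, 3, 1)
y = rand(Float32, 4, 4, 5, 1)
size(Layers.cat_channels(x, y))  # (4, 4, 8, 1)

# swapdims returns a callable that permutes dimensions, as in the
# ConvNeXt block's swapdims((3, 1, 2, 4)).
permute = Layers.swapdims((3, 1, 2, 4))
size(permute(x))  # (3, 4, 4, 1)
```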
2 changes: 1 addition & 1 deletion docs/src/api/mixers.md
@@ -23,4 +23,4 @@ Metalhead.mixerblock
Metalhead.resmixerblock
Metalhead.SpatialGatingUnit
Metalhead.spatialgatingblock
```
10 changes: 10 additions & 0 deletions docs/src/api/mobilenet.md
@@ -0,0 +1,10 @@
# MobileNet family of models

This is the API reference for the MobileNet family of models supported by Metalhead.jl.

```@docs
MobileNetv1
MobileNetv2
MobileNetv3
MNASNet
```
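
As a hedged example, constructing one of these models might look as follows; the `:small` configuration symbol and the `nclasses` keyword are assumptions on our part, so check the docstrings above for the authoritative options.

```julia
using Metalhead

# A MobileNetv3 in its small configuration with a custom number of
# classes (configuration symbol and keyword are assumptions).
model = MobileNetv3(:small; nclasses = 10)
```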
6 changes: 2 additions & 4 deletions docs/src/api/utilities.md
@@ -1,10 +1,8 @@
# Utilities
# Model utilities

Metalhead provides some utility functions for making it easier to work with the models inside the library or to build new ones. The API reference for these is documented below.

## `backbone` and `classifier`

```@docs
backbone
classifier
```
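
A minimal sketch of how these two functions split a model, using `ResNet` as in the quickstart:

```julia
using Metalhead

model = ResNet(18)
feat = backbone(model)    # the convolutional feature extractor
head = classifier(model)  # the fully connected classification head
```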
38 changes: 38 additions & 0 deletions docs/src/howto/resnet.md
@@ -19,3 +19,41 @@ model = ResNet(50; pretrain=true)
```

To check out more about using pretrained models, check out the [pretrained models guide](@ref pretrained).

## The mid-level function

Metalhead also provides a function for users looking to customise the ResNet family of models further. This function is named [`Metalhead.resnet`](@ref) and has a detailed docstring that describes all the various customisation options. You may want to open the above link in another tab, because we're going to be referring to it extensively to build a ResNet model of our liking.

First, let's take a peek at how we would write the vanilla ResNet-18 model using this function. At its core, a residual network is a convolutional network split into stages, where each stage contains a "residual" block repeated several times. The Metalhead.jl design reflects this. While there are many keyword arguments that we can configure, there are two required positional arguments: the block type and the number of times the block is repeated in each stage. For all other options, the default values work well. The original ResNet paper suggests using a "basic block" repeated twice per stage, so we can write the ResNet-18 model as follows:

```julia
resnet18 = Metalhead.resnet(Metalhead.basicblock, [2, 2, 2, 2])
```

What if we want to customise the number of output classes? That's easy; the model has several keyword arguments, one of which allows this. The docstring tells us that it is `nclasses`, and so we can write:

```julia
resnet18 = Metalhead.resnet(Metalhead.basicblock, [2, 2, 2, 2]; nclasses = 10)
```

Let's try customising this further. Say we want to make a ResNet-50-like model, but with [`StochasticDepth`](https://arxiv.org/abs/1603.09382) to provide even more regularisation, and also a custom pooling layer such as `AdaptiveMeanMaxPool`. Both of these options are provided by Metalhead out of the box, so we can write:

```julia
using Metalhead: Layers # AdaptiveMeanMaxPool is in the Layers module in Metalhead

custom_resnet = Metalhead.resnet(Metalhead.bottleneck, [3, 4, 6, 3];
pool_layer = Layers.AdaptiveMeanMaxPool((1, 1)),
stochastic_depth_prob = 0.2)
```

To make this a ResNeXt-like model, all we need to do is configure the cardinality and the
base width:

```julia
custom_resnet = Metalhead.resnet(Metalhead.bottleneck, [3, 4, 6, 3];
cardinality = 32, base_width = 4,
pool_layer = Layers.AdaptiveMeanMaxPool((1, 1)),
stochastic_depth_prob = 0.2)
```

And we have a custom model, built with minimal effort! The documentation for `Metalhead.resnet` has been written with extensive care and in as much detail as possible to make it easy to use. Still, if you find anything difficult to understand, feel free to open an issue and we will be happy to help you out and to improve the documentation where necessary.
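
As a sanity check, the customised model can be run on a dummy image batch; with the default of 1000 output classes, the forward pass should look like this:

```julia
# A dummy RGB image batch in WHCN order; the default classifier head
# produces 1000 logits per image.
x = rand(Float32, 224, 224, 3, 1)
size(custom_resnet(x))  # (1000, 1)
```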
22 changes: 16 additions & 6 deletions docs/src/tutorials/quickstart.md
@@ -2,9 +2,17 @@

Metalhead.jl is a library built on Flux.jl that provides a collection of image models, layers and utilities for deep learning in computer vision.

## Pre-trained models
## Model architectures and pre-trained models

In Metalhead.jl, camel-cased functions mimicking the naming style followed in the paper such as [`ResNet`](@ref) or [`ResNeXt`](@ref) are considered the "higher" level API for models. These are the functions that end-users who do not want to experiment much with model architectures should use. These models also support the option for loading pre-trained weights from ImageNet.
In Metalhead.jl, camel-cased functions such as [`ResNet`](@ref) or [`MobileNetv3`](@ref), which mimic the naming style of the corresponding papers, are considered the "higher" level API for models. These are the functions that end-users who do not want to experiment much with model architectures should use. To use one of these models, simply call its function:

```julia
using Metalhead

model = ResNet(18);
```

The API reference contains the documentation and options for each model function. These models also support the option for loading pre-trained weights from ImageNet.

!!! note

@@ -18,12 +26,14 @@ using Metalhead
model = ResNet(18; pretrain = true);
```

Refer to the pretraining guide for more details on how to use pre-trained models.
Refer to the [pretraining guide](@ref pretrained) for more details on how to use pre-trained models.

## More model configuration options

For users who want to use more options for model configuration, Metalhead provides a "mid-level" API for models. The model functions that are in lowercase such as [`resnet`](@ref) or [`mobilenetv3`](@ref) are the "lower" level API for models. These are the functions that end-users who want to experiment with model architectures should use. These models do not support the option for loading pre-trained weights from ImageNet out of the box.
For users who want to use more options for model configuration, Metalhead provides a "mid-level" API for models. These are the model functions that are in lowercase such as [`resnet`](@ref) or [`mobilenetv3`](@ref). End-users who want to experiment with model architectures should use these functions. These models do not support the option for loading pre-trained weights from ImageNet out of the box, although one can always load weights explicitly using the `loadmodel!` function from Flux.
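
For instance, here is a hedged sketch of copying weights from a pretrained high-level model into a mid-level one; that the two structures line up exactly, and that the pretrained model's layers are reachable via its `layers` field, are assumptions to verify:

```julia
using Flux, Metalhead

# Build a mid-level ResNet-18 and copy weights from the pretrained
# high-level model; Flux.loadmodel! requires matching structures.
pretrained = ResNet(18; pretrain = true)
custom = Metalhead.resnet(Metalhead.basicblock, [2, 2, 2, 2])
Flux.loadmodel!(custom, pretrained.layers)
```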

To use any of these models, check out the docstrings for the model functions (these are documented in the API reference). Note that these functions typically require more configuration options to be passed in, but offer a lot more flexibility in terms of model architecture. Metalhead defines sensible defaults for as many options as possible, making it easier for the user to pick and choose the specific options they want to customise.

To use any of these models, check out the docstrings for the model functions. Note that these functions typically require more configuration options to be passed in, but offer a lot more flexibility in terms of model architecture.
## Builders for the advanced user

For users who want the ability to customise their models as much as possible, Metalhead offers a powerful low-level interface. These are known as [**builders**](@ref builders) and allow the user to hack into the core of models and build them up to their liking. Most users will not need builders, since a large number of configuration options are already exposed at the mid-level API. However, for package developers and users who want to build customised versions of their own models, the low-level API provides the customisability required while still reducing user code.
22 changes: 7 additions & 15 deletions src/convnets/densenet.jl
@@ -124,13 +124,12 @@ Create a DenseNet model with specified configuration. Currently supported values
([reference](https://arxiv.org/abs/1608.06993)).
# Arguments
- `config`: the configuration of the model
- `pretrain`: whether to load the model with pre-trained weights for ImageNet.
- `growth_rate`: the output feature map growth probability of dense blocks (i.e. `k` in the ref)
- `reduction`: the factor by which the number of feature maps is scaled across each transition
- `inchannels`: the number of input channels
- `nclasses`: the number of output classes
!!! warning
@@ -150,8 +149,7 @@ function DenseNet(config::Int; pretrain::Bool = false, growth_rate::Int = 32,
nclasses)
model = DenseNet(layers)
if pretrain
artifact_name = string("densenet", config)
loadpretrain!(model, artifact_name) # see also HACK below
loadpretrain!(model, string("densenet", config))
end
return model
end
@@ -160,9 +158,3 @@ end

backbone(m::DenseNet) = m.layers[1]
classifier(m::DenseNet) = m.layers[2]

## HACK TO LOAD OLD WEIGHTS, remove when we have a new artifact
function Flux.loadmodel!(m::DenseNet, src)
Flux.loadmodel!(m.layers[1], src.layers[1])
Flux.loadmodel!(m.layers[2], src.layers[2])
end
4 changes: 4 additions & 0 deletions src/convnets/hybrid/convmixer.jl
@@ -58,6 +58,10 @@ Creates a ConvMixer model.
- `inchannels`: number of input channels
- `nclasses`: number of classes in the output
!!! warning
`ConvMixer` does not currently support pretrained weights.
See also [`Metalhead.convmixer`](@ref).
"""
struct ConvMixer
2 changes: 1 addition & 1 deletion src/convnets/hybrid/convnext.jl
@@ -14,7 +14,7 @@ function convnextblock(planes::Integer, stochastic_depth_prob = 0.0,
layerscale_init = 1.0f-6)
return SkipConnection(Chain(DepthwiseConv((7, 7), planes => planes; pad = 3),
swapdims((3, 1, 2, 4)),
LayerNorm(planes; ϵ = 1.0f-6),
LayerNorm(planes; eps = 1.0f-6),
mlp_block(planes, 4 * planes),
LayerScale(planes, layerscale_init),
swapdims((2, 3, 1, 4)),
1 change: 1 addition & 0 deletions src/convnets/resnets/core.jl
@@ -372,6 +372,7 @@ Wide ResNet, ResNeXt and Res2Net. For an _even_ more generic model API, see [`Me
- `reduction_factor`: The reduction factor used in the model.
- `connection`: This is a function that determines the residual connection in the model. For
`resnets`, either of [`Metalhead.addact`](@ref) or [`Metalhead.actadd`](@ref) is recommended.
These decide whether the residual connection is added before or after the activation function.
- `norm_layer`: The normalisation layer to be used in the model.
- `revnorm`: set to `true` to place the normalisation layers before the convolutions
- `attn_fn`: A callback that is used to determine the attention function to be used in the model.
6 changes: 3 additions & 3 deletions src/convnets/resnets/resnet.jl
@@ -25,9 +25,9 @@ function ResNet(depth::Integer; pretrain::Bool = false, inchannels::Integer = 3,
model = ResNet(layers)
if pretrain
artifact_name = "resnet$(depth)"
if depth ∈ [18, 34]
if depth in [18, 34]
artifact_name *= "-IMAGENET1K_V1"
elseif depth ∈ [50, 101, 152]
elseif depth in [50, 101, 152]
artifact_name *= "-IMAGENET1K_V2"
end
loadpretrain!(model, artifact_name)
@@ -69,7 +69,7 @@ function WideResNet(depth::Integer; pretrain::Bool = false, inchannels::Integer
model = WideResNet(layers)
if pretrain
artifact_name = "wideresnet$(depth)"
if depth ∈ [50, 101]
if depth in [50, 101]
artifact_name *= "-IMAGENET1K_V2"
end
loadpretrain!(model, artifact_name)

2 comments on commit ec27452

@theabhirath
Member Author


@JuliaRegistrator


Registration pull request created: JuliaRegistries/General/84999

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.8.0 -m "<description of version>" ec274521764667b7ef167af6b1c7495a29147129
git push origin v0.8.0
