diff --git a/fastxtend/callback/lr_finder.py b/fastxtend/callback/lr_finder.py
index a6c0e6f..60bd4ea 100644
--- a/fastxtend/callback/lr_finder.py
+++ b/fastxtend/callback/lr_finder.py
@@ -8,19 +8,50 @@
 # fastai - Apache License 2.0 - Copyright (c) 2023 fast.ai
 
 # %% ../../nbs/callback.lr_finder.ipynb 4
-from fastcore.xtras import is_listy
-from fastcore.foundation import patch, docs, Path
+from copy import deepcopy
+import tempfile
+from packaging.version import parse
+
+from fastcore.foundation import Path
 from fastcore.basics import tuplify
+
+import fastai
 from fastai.callback.schedule import ParamScheduler, SchedExp, SuggestionMethod
 from fastai.torch_core import tensor, get_random_states, set_random_states
 from fastai.learner import Learner, CancelFitException, CancelValidException
-from functools import partial
-from copy import deepcopy
-import torch
-import collections, tempfile
+
+from ..imports import *
 
 # %% ../../nbs/callback.lr_finder.ipynb 6
-@docs
+if parse(fastai.__version__) < parse('2.7.12'):
+    _torch_version = parse(torch.__version__)
+    _torch_113 = parse('1.13')
+    _torch_112 = parse('1.12')
+
+    @patch
+    def clone(self:TensorBase, *, memory_format=None):
+        cls = type(self)
+        return self.as_subclass(Tensor).clone(memory_format=memory_format).as_subclass(cls)
+
+    @patch
+    def new_empty(self:TensorBase, size, *, dtype=None, layout=None, device=None, pin_memory=False, requires_grad=False):
+        cls = type(self)
+        if _torch_version < _torch_113 and layout is None:
+            layout = torch.strided
+        if _torch_version < _torch_112:
+            return super(TensorBase, self).new_empty(size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory, requires_grad=requires_grad)
+        return self.as_subclass(Tensor).new_empty(size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory, requires_grad=requires_grad).as_subclass(cls)
+
+    @patch
+    def new_empty(self:TensorBase, *size, dtype=None, layout=None, device=None, pin_memory=False, requires_grad=False):
+        cls = type(self)
+        if _torch_version < _torch_113 and layout is None:
+            layout = torch.strided
+        if _torch_version < _torch_112:
+            return super(TensorBase, self).new_empty(*size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory, requires_grad=requires_grad)
+        return self.as_subclass(Tensor).new_empty(*size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory, requires_grad=requires_grad).as_subclass(cls)
+
+# %% ../../nbs/callback.lr_finder.ipynb 8
 class LRFinder(ParamScheduler):
     "Training with exponentially growing learning rate"
     def __init__(self, start_lr=1e-7, end_lr=10, num_it=100, stop_div=True, restore_state=True):
@@ -49,9 +80,12 @@ def before_batch(self):
     def after_batch(self):
         "Record hyper-parameters of this batch and potentially stop training"
         super().after_batch()
-        if self.smooth_loss < self.best_loss: self.best_loss = self.smooth_loss
-        if self.smooth_loss > 4*self.best_loss and self.stop_div: raise CancelFitException()
-        if self.train_iter >= self.num_it: raise CancelFitException()
+        if self.smooth_loss < self.best_loss:
+            self.best_loss = self.smooth_loss
+        if self.smooth_loss > 4*self.best_loss and self.stop_div:
+            raise CancelFitException()
+        if self.train_iter >= self.num_it:
+            raise CancelFitException()
 
     def before_validate(self):
         "Skip the validation part of training"
@@ -68,7 +102,7 @@ def after_fit(self):
             self.learn.dls = self.old_dls
             set_random_states(**self.states)
 
-# %% ../../nbs/callback.lr_finder.ipynb 15
+# %% ../../nbs/callback.lr_finder.ipynb 17
 @patch
 def lr_find(self:Learner, start_lr=1e-7, end_lr=10, num_it=100, stop_div=True, show_plot=True, suggest_funcs=(SuggestionMethod.Valley), restore_state=True):
     """
@@ -91,7 +125,7 @@ def lr_find(self:Learner, start_lr=1e-7, end_lr=10, num_it=100, stop_div=True, s
         nms.append(func.__name__ if not isinstance(func, partial) else func.func.__name__) # deal with partials
         _suggestions.append(func(lrs, losses, num_it))
 
-    SuggestedLRs = collections.namedtuple('SuggestedLRs', nms)
+    SuggestedLRs = namedtuple('SuggestedLRs', nms)
     lrs, pnts = [], []
     for lr, pnt in _suggestions:
         lrs.append(lr)
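
Note on the version gate above: it backports fastai 2.7.12's `TensorBase.clone`/`TensorBase.new_empty` fix to older fastai releases. Since `@patch` binds the last definition, the `*size` variant of `new_empty` is the one in effect, matching upstream fastai. `LRFinder`'s state restoration relies on this fix, as it deep-copies objects that can contain `TensorBase` tensors (note the `deepcopy` import). A minimal sketch of the behavior the patches guarantee, mirroring the notebook's hidden test cell (illustrative, not part of the patch):

```python
from copy import deepcopy

import torch
from fastai.torch_core import TensorBase

# With the clone()/new_empty() patches applied (or fastai >= 2.7.12),
# deepcopy preserves both the tensor subclass and any attached metadata.
x = TensorBase(torch.rand(4, 3, 16, 16))
x.test = 'test metadata'
y = deepcopy(x)
assert type(y) is TensorBase
assert hasattr(y, 'test') and y.test == x.test
```
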
diff --git a/nbs/callback.lr_finder.ipynb b/nbs/callback.lr_finder.ipynb
index 5fe58fc..b33d1f4 100644
--- a/nbs/callback.lr_finder.ipynb
+++ b/nbs/callback.lr_finder.ipynb
@@ -44,16 +44,19 @@
    "outputs": [],
    "source": [
     "#|export\n",
-    "from fastcore.xtras import is_listy\n",
-    "from fastcore.foundation import patch, docs, Path\n",
+    "from copy import deepcopy\n",
+    "import tempfile\n",
+    "from packaging.version import parse\n",
+    "\n",
+    "from fastcore.foundation import Path\n",
     "from fastcore.basics import tuplify\n",
+    "\n",
+    "import fastai\n",
     "from fastai.callback.schedule import ParamScheduler, SchedExp, SuggestionMethod\n",
     "from fastai.torch_core import tensor, get_random_states, set_random_states\n",
     "from fastai.learner import Learner, CancelFitException, CancelValidException\n",
-    "from functools import partial\n",
-    "from copy import deepcopy\n",
-    "import torch\n",
-    "import collections, tempfile"
+    "\n",
+    "from fastxtend.imports import *"
    ]
   },
   {
@@ -73,7 +76,55 @@
    "outputs": [],
    "source": [
     "#|export\n",
-    "@docs\n",
+    "if parse(fastai.__version__) < parse('2.7.12'):\n",
+    "    _torch_version = parse(torch.__version__)\n",
+    "    _torch_113 = parse('1.13')\n",
+    "    _torch_112 = parse('1.12')\n",
+    "\n",
+    "    @patch\n",
+    "    def clone(self:TensorBase, *, memory_format=None):\n",
+    "        cls = type(self)\n",
+    "        return self.as_subclass(Tensor).clone(memory_format=memory_format).as_subclass(cls)\n",
+    "\n",
+    "    @patch\n",
+    "    def new_empty(self:TensorBase, size, *, dtype=None, layout=None, device=None, pin_memory=False, requires_grad=False):\n",
+    "        cls = type(self)\n",
+    "        if _torch_version < _torch_113 and layout is None:\n",
+    "            layout = torch.strided\n",
+    "        if _torch_version < _torch_112:\n",
+    "            return super(TensorBase, self).new_empty(size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory, requires_grad=requires_grad)\n",
+    "        return self.as_subclass(Tensor).new_empty(size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory, requires_grad=requires_grad).as_subclass(cls)\n",
+    "\n",
+    "    @patch\n",
+    "    def new_empty(self:TensorBase, *size, dtype=None, layout=None, device=None, pin_memory=False, requires_grad=False):\n",
+    "        cls = type(self)\n",
+    "        if _torch_version < _torch_113 and layout is None:\n",
+    "            layout = torch.strided\n",
+    "        if _torch_version < _torch_112:\n",
+    "            return super(TensorBase, self).new_empty(*size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory, requires_grad=requires_grad)\n",
+    "        return self.as_subclass(Tensor).new_empty(*size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory, requires_grad=requires_grad).as_subclass(cls)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#|hide\n",
+    "x = TensorBase(torch.rand(4,3,16,16))\n",
+    "x.test = 'test metadata'\n",
+    "y = deepcopy(x)\n",
+    "assert hasattr(y, 'test') and y.test == x.test"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#|export\n",
     "class LRFinder(ParamScheduler):\n",
     "    \"Training with exponentially growing learning rate\"\n",
     "    def __init__(self, start_lr=1e-7, end_lr=10, num_it=100, stop_div=True, restore_state=True):\n",
@@ -102,9 +153,12 @@
     "    def after_batch(self):\n",
     "        \"Record hyper-parameters of this batch and potentially stop training\"\n",
     "        super().after_batch()\n",
-    "        if self.smooth_loss < self.best_loss: self.best_loss = self.smooth_loss\n",
-    "        if self.smooth_loss > 4*self.best_loss and self.stop_div: raise CancelFitException()\n",
-    "        if self.train_iter >= self.num_it: raise CancelFitException()\n",
+    "        if self.smooth_loss < self.best_loss:\n",
+    "            self.best_loss = self.smooth_loss\n",
+    "        if self.smooth_loss > 4*self.best_loss and self.stop_div:\n",
+    "            raise CancelFitException()\n",
+    "        if self.train_iter >= self.num_it:\n",
+    "            raise CancelFitException()\n",
     "\n",
     "    def before_validate(self):\n",
     "        \"Skip the validation part of training\"\n",
@@ -340,7 +394,7 @@
     "        nms.append(func.__name__ if not isinstance(func, partial) else func.func.__name__) # deal with partials\n",
     "        _suggestions.append(func(lrs, losses, num_it))\n",
     "\n",
-    "    SuggestedLRs = collections.namedtuple('SuggestedLRs', nms)\n",
+    "    SuggestedLRs = namedtuple('SuggestedLRs', nms)\n",
     "    lrs, pnts = [], []\n",
     "    for lr, pnt in _suggestions:\n",
     "        lrs.append(lr)\n",