Patches for Tensor.__deepcopy__ for current fastai
warner-benjamin committed Mar 1, 2023
1 parent 9528029 commit 44d1cb6
Showing 2 changed files with 111 additions and 23 deletions.
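
Context for this commit: on fastai releases before 2.7.12, `copy.deepcopy` of a `TensorBase` (or one of its subclasses) could fail or silently drop the subclass and its attached metadata, because PyTorch's `Tensor.__deepcopy__` rebuilds the copy via `clone`/`new_empty`. The sketch below mirrors the hidden test added in the notebook and shows the behaviour the patches are meant to guarantee; it assumes fastai < 2.7.12 with `fastxtend.callback.lr_finder` imported so the patches are applied.

import torch
from copy import deepcopy
from fastai.torch_core import TensorBase

x = TensorBase(torch.rand(4, 3, 16, 16))
x.test = 'test metadata'        # fastai tensor subclasses can carry extra attributes
y = deepcopy(x)                 # __deepcopy__ rebuilds the tensor via clone()/new_empty()
assert type(y) is TensorBase    # the subclass survives the copy
assert y.test == x.test         # so does the attached metadata
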
58 changes: 46 additions & 12 deletions fastxtend/callback/lr_finder.py
@@ -8,19 +8,50 @@
# fastai - Apache License 2.0 - Copyright (c) 2023 fast.ai

# %% ../../nbs/callback.lr_finder.ipynb 4
from fastcore.xtras import is_listy
from fastcore.foundation import patch, docs, Path
from copy import deepcopy
import tempfile
from packaging.version import parse

from fastcore.foundation import Path
from fastcore.basics import tuplify

import fastai
from fastai.callback.schedule import ParamScheduler, SchedExp, SuggestionMethod
from fastai.torch_core import tensor, get_random_states, set_random_states
from fastai.learner import Learner, CancelFitException, CancelValidException
from functools import partial
from copy import deepcopy
import torch
import collections, tempfile

from ..imports import *

# %% ../../nbs/callback.lr_finder.ipynb 6
@docs
if parse(fastai.__version__) < parse('2.7.12'):
    _torch_version = parse(torch.__version__)
    _torch_113 = parse('1.13')
    _torch_112 = parse('1.12')

    @patch
    def clone(self:TensorBase, *, memory_format=None):
        cls = type(self)
        return self.as_subclass(Tensor).clone(memory_format=memory_format).as_subclass(cls)

    @patch
    def new_empty(self:TensorBase, size, *, dtype=None, layout=None, device=None, pin_memory=False, requires_grad=False):
        cls = type(self)
        if _torch_version < _torch_113 and layout is None:
            layout = torch.strided
        if _torch_version < _torch_112:
            return super(TensorBase, self).new_empty(size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory, requires_grad=requires_grad)
        return self.as_subclass(Tensor).new_empty(size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory, requires_grad=requires_grad).as_subclass(cls)

    @patch
    def new_empty(self:TensorBase, *size, dtype=None, layout=None, device=None, pin_memory=False, requires_grad=False):
        cls = type(self)
        if _torch_version < _torch_113 and layout is None:
            layout = torch.strided
        if _torch_version < _torch_112:
            return super(TensorBase, self).new_empty(*size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory, requires_grad=requires_grad)
        return self.as_subclass(Tensor).new_empty(*size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory, requires_grad=requires_grad).as_subclass(cls)
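
# Illustrative sketch, not part of the patch itself: with the block above active (fastai < 2.7.12),
# clone and new_empty round-trip through plain Tensor and cast back to the calling subclass.
t = TensorBase(torch.ones(2, 2))
assert type(t.clone()) is TensorBase             # clone keeps the subclass
assert type(t.new_empty((3, 3))) is TensorBase   # so does new_empty (size-tuple or *size form)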

# %% ../../nbs/callback.lr_finder.ipynb 8
class LRFinder(ParamScheduler):
"Training with exponentially growing learning rate"
def __init__(self, start_lr=1e-7, end_lr=10, num_it=100, stop_div=True, restore_state=True):
@@ -49,9 +80,12 @@ def before_batch(self):
    def after_batch(self):
        "Record hyper-parameters of this batch and potentially stop training"
        super().after_batch()
        if self.smooth_loss < self.best_loss: self.best_loss = self.smooth_loss
        if self.smooth_loss > 4*self.best_loss and self.stop_div: raise CancelFitException()
        if self.train_iter >= self.num_it: raise CancelFitException()
        if self.smooth_loss < self.best_loss:
            self.best_loss = self.smooth_loss
        if self.smooth_loss > 4*self.best_loss and self.stop_div:
            raise CancelFitException()
        if self.train_iter >= self.num_it:
            raise CancelFitException()

    def before_validate(self):
        "Skip the validation part of training"
@@ -68,7 +102,7 @@ def after_fit(self):
            self.learn.dls = self.old_dls
            set_random_states(**self.states)

# %% ../../nbs/callback.lr_finder.ipynb 15
# %% ../../nbs/callback.lr_finder.ipynb 17
@patch
def lr_find(self:Learner, start_lr=1e-7, end_lr=10, num_it=100, stop_div=True, show_plot=True, suggest_funcs=(SuggestionMethod.Valley), restore_state=True):
"""
@@ -91,7 +125,7 @@ def lr_find(self:Learner, start_lr=1e-7, end_lr=10, num_it=100, stop_div=True, s
            nms.append(func.__name__ if not isinstance(func, partial) else func.func.__name__) # deal with partials
            _suggestions.append(func(lrs, losses, num_it))

        SuggestedLRs = collections.namedtuple('SuggestedLRs', nms)
        SuggestedLRs = namedtuple('SuggestedLRs', nms)
        lrs, pnts = [], []
        for lr, pnt in _suggestions:
            lrs.append(lr)
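
For reference, a typical call of the patched `lr_find` might look like the sketch below (usage sketch only: it assumes an existing fastai/fastxtend `Learner` named `learn`; `SuggestionMethod.Slide` is just another of fastai's built-in suggestion functions).

from fastai.callback.schedule import SuggestionMethod

# restore_state=True (the default) puts the DataLoaders and random state back after the search
suggested = learn.lr_find(suggest_funcs=(SuggestionMethod.Valley, SuggestionMethod.Slide))
print(suggested.valley, suggested.slide)   # one field per suggestion function, named after it
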
76 changes: 65 additions & 11 deletions nbs/callback.lr_finder.ipynb
@@ -44,16 +44,19 @@
"outputs": [],
"source": [
"#|export\n",
"from fastcore.xtras import is_listy\n",
"from fastcore.foundation import patch, docs, Path\n",
"from copy import deepcopy\n",
"import tempfile\n",
"from packaging.version import parse\n",
"\n",
"from fastcore.foundation import Path\n",
"from fastcore.basics import tuplify\n",
"\n",
"import fastai\n",
"from fastai.callback.schedule import ParamScheduler, SchedExp, SuggestionMethod\n",
"from fastai.torch_core import tensor, get_random_states, set_random_states\n",
"from fastai.learner import Learner, CancelFitException, CancelValidException\n",
"from functools import partial\n",
"from copy import deepcopy\n",
"import torch\n",
"import collections, tempfile"
"\n",
"from fastxtend.imports import *"
]
},
{
@@ -73,7 +76,55 @@
"outputs": [],
"source": [
"#|export\n",
"@docs\n",
"if parse(fastai.__version__) < parse('2.7.12'):\n",
" _torch_version = parse(torch.__version__)\n",
" _torch_113 = parse('1.13')\n",
" _torch_112 = parse('1.12')\n",
"\n",
" @patch\n",
" def clone(self:TensorBase, *, memory_format=None):\n",
" cls = type(self)\n",
" return self.as_subclass(Tensor).clone(memory_format=memory_format).as_subclass(cls)\n",
"\n",
" @patch\n",
" def new_empty(self:TensorBase, size, *, dtype=None, layout=None, device=None, pin_memory=False, requires_grad=False):\n",
" cls = type(self)\n",
" if _torch_version < _torch_113 and layout is None:\n",
" layout = torch.strided\n",
" if _torch_version < _torch_112:\n",
" return super(TensorBase, self).new_empty(size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory, requires_grad=requires_grad)\n",
" return self.as_subclass(Tensor).new_empty(size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory, requires_grad=requires_grad).as_subclass(cls)\n",
"\n",
" @patch\n",
" def new_empty(self:TensorBase, *size, dtype=None, layout=None, device=None, pin_memory=False, requires_grad=False):\n",
" cls = type(self)\n",
" if _torch_version < _torch_113 and layout is None:\n",
" layout = torch.strided\n",
" if _torch_version < _torch_112:\n",
" return super(TensorBase, self).new_empty(*size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory, requires_grad=requires_grad)\n",
" return self.as_subclass(Tensor).new_empty(*size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory, requires_grad=requires_grad).as_subclass(cls)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#|hide\n",
"x = TensorBase(torch.rand(4,3,16,16))\n",
"x.test = 'test metadata'\n",
"y = deepcopy(x)\n",
"assert hasattr(y, 'test') and y.test == x.test"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#|export\n",
"class LRFinder(ParamScheduler):\n",
" \"Training with exponentially growing learning rate\"\n",
" def __init__(self, start_lr=1e-7, end_lr=10, num_it=100, stop_div=True, restore_state=True):\n",
@@ -102,9 +153,12 @@
"    def after_batch(self):\n",
"        \"Record hyper-parameters of this batch and potentially stop training\"\n",
"        super().after_batch()\n",
"        if self.smooth_loss < self.best_loss: self.best_loss = self.smooth_loss\n",
"        if self.smooth_loss > 4*self.best_loss and self.stop_div: raise CancelFitException()\n",
"        if self.train_iter >= self.num_it: raise CancelFitException()\n",
"        if self.smooth_loss < self.best_loss:\n",
"            self.best_loss = self.smooth_loss\n",
"        if self.smooth_loss > 4*self.best_loss and self.stop_div:\n",
"            raise CancelFitException()\n",
"        if self.train_iter >= self.num_it:\n",
"            raise CancelFitException()\n",
"\n",
" def before_validate(self):\n",
" \"Skip the validation part of training\"\n",
@@ -340,7 +394,7 @@
" nms.append(func.__name__ if not isinstance(func, partial) else func.func.__name__) # deal with partials\n",
" _suggestions.append(func(lrs, losses, num_it))\n",
"\n",
" SuggestedLRs = collections.namedtuple('SuggestedLRs', nms)\n",
" SuggestedLRs = namedtuple('SuggestedLRs', nms)\n",
" lrs, pnts = [], []\n",
" for lr, pnt in _suggestions:\n",
" lrs.append(lr)\n",
