Unwrap Transformers model by default after fit
warner-benjamin committed Oct 14, 2023
1 parent eb5696a commit 56f5e79
Showing 3 changed files with 130 additions and 53 deletions.
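
The commit makes `HuggingFaceCallback` restore `Learner.model` to the original Transformers model once training finishes. A minimal sketch of the resulting workflow, based on the diffs below — `train_dataloader`, `valid_dataloader`, and `hf_model` are assumed to be built as in the notebook, the checkpoint path is illustrative, and this is not a verbatim excerpt from the repository:

# A sketch of the new default behavior (not a verbatim excerpt from the repo).
# train_dataloader, valid_dataloader (HuggingFaceLoader instances), and
# hf_model are assumed to be set up as in the notebook below; import paths
# follow the module layout shown in the diff.
from fastai.data.core import DataLoaders
from fastai.learner import Learner
from fastxtend.text.huggingface import HuggingFaceCallback, HuggingFaceLoss

dls = DataLoaders(train_dataloader, valid_dataloader)
learn = Learner(dls, hf_model, loss_func=HuggingFaceLoss(), cbs=HuggingFaceCallback)
learn.fit(3, lr=8e-4)

# After fit, learn.model is the original Transformers model again (the
# HuggingFaceWrapper has been removed), so Transformers methods work directly.
learn.model.save_pretrained('model-checkpoint')  # path is illustrative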
2 changes: 2 additions & 0 deletions fastxtend/_modidx.py
@@ -1449,6 +1449,8 @@
'fastxtend/text/huggingface.py'),
'fastxtend.text.huggingface.HuggingFaceCallback.after_create': ( 'text.huggingface.html#huggingfacecallback.after_create',
'fastxtend/text/huggingface.py'),
'fastxtend.text.huggingface.HuggingFaceCallback.after_fit': ( 'text.huggingface.html#huggingfacecallback.after_fit',
'fastxtend/text/huggingface.py'),
'fastxtend.text.huggingface.HuggingFaceCallback.after_loss': ( 'text.huggingface.html#huggingfacecallback.after_loss',
'fastxtend/text/huggingface.py'),
'fastxtend.text.huggingface.HuggingFaceCallback.after_pred': ( 'text.huggingface.html#huggingfacecallback.after_pred',
13 changes: 11 additions & 2 deletions fastxtend/text/huggingface.py
@@ -61,9 +61,11 @@ class HuggingFaceCallback(Callback):
def __init__(self,
labels:str|None='labels', # Input batch labels key. Set to None if dataset doesn't contain labels
loss:str='loss', # Model output loss key
logits:str='logits', # Model output logits key
logits:str='logits', # Model output logits key,
unwrap:bool=True, # After training completes, unwrap the Transformers model
):
self._label_key, self._loss_key, self._logit_key = labels, loss, logits
self._label_key, self._loss_key = labels, loss
self._logit_key, self.unwrap = logits, unwrap

def after_create(self):
self._model_loss = isinstance(self.learn.loss_func, HuggingFaceLoss)
@@ -93,6 +95,13 @@ def after_loss(self):
else:
self.xb[0][self._label_key] = self.learn.yb[0]

def after_fit(self):
if self.unwrap:
if isinstance(self.learn.model, dynamo.OptimizedModule) and hasattr(self.learn, 'compiler'):
self.learn.compiler._reset_compiled()
if isinstance(self.model, HuggingFaceWrapper):
self.learn.model = self.learn.model.hf_model

# %% ../../nbs/text.huggingface.ipynb 11
class HuggingFaceLoader(_DataLoader):
"A minimal compatibility wrapper between a Hugging Face Dataloader and `Learner`"
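
Note the order of operations in the new `after_fit` above: a `torch.compile`d model (a `dynamo.OptimizedModule`) is reset via the learner's compiler callback before the `HuggingFaceWrapper` is removed. To keep the wrapper after training, pass `unwrap=False`; a short hedged sketch under the same assumptions (`dls` and `hf_model` are placeholders for objects built as in the notebook below):

# Opting out of the new default: keep the HuggingFaceWrapper after training.
# dls and hf_model are placeholders, not objects defined in this commit.
learn = Learner(dls, hf_model, loss_func=HuggingFaceLoss(),
                cbs=HuggingFaceCallback(unwrap=False))
learn.fit(1)

# learn.model is still a HuggingFaceWrapper; the underlying Transformers model
# remains reachable through its hf_model attribute (the one after_fit uses).
transformers_model = learn.model.hf_model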
168 changes: 117 additions & 51 deletions nbs/text.huggingface.ipynb
@@ -51,8 +51,8 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Hugging Face\n",
"> Basic compatability between fastai and Hugging Face Transformers models"
"# Hugging Face Transformers Compatibility\n",
"> Train Hugging Face Transformers models using fastai"
]
},
{
@@ -61,7 +61,13 @@
"source": [
"fastxtend provides basic compatibility for training Hugging Face [Transformers](https://huggingface.co/docs/transformers/index) models using the `fastai.learner.Learner`.\n",
"\n",
"For a complete Hugging Face integration with fastai, you should check out [blurr](https://ohmeow.github.io/blurr).\n",
":::{.callout-tip collapse=\"true\"}\n",
"#### Tip: Use blurr For Integrated Transformers with fastai\n",
"\n",
"[blurr](https://ohmeow.github.io/blurr) provides a complete Hugging Face Transformers integration with fastai, including working fastai datablocks, dataloaders, and other fastai methods.\n",
"\n",
"In contrast, fastxtend only provides basic `Learner` compatibility.\n",
":::\n",
"\n",
"To use fastxend's compatibility, setup the Hugging Face dataset, dataloader, and model per the [Transformers documentation](https://huggingface.co/docs/transformers/index), exchanging the PyTorch `Dataloader` for the `HuggingFaceLoader`. Then wrap the dataloaders in `fastai.data.core.DataLoaders` and create a `Learner` with the Hugging Face model, `HuggingFaceLoss`, and `HuggingFaceCallback`. This will automatically setup the compatibility and use the Hugging Face model's built in loss.\n",
"\n",
@@ -75,6 +81,7 @@
" drop_last=True, num_workers=num_cpus()\n",
")\n",
"\n",
"# defining the valid_dataloader cut for brevity\n",
"dls = DataLoaders(train_dataloader, valid_dataloader)\n",
"\n",
"hf_model = GPTForCausalLM(...)\n",
@@ -154,9 +161,11 @@
" def __init__(self,\n",
" labels:str|None='labels', # Input batch labels key. Set to None if dataset doesn't contain labels\n",
" loss:str='loss', # Model output loss key\n",
" logits:str='logits', # Model output logits key\n",
" logits:str='logits', # Model output logits key,\n",
" unwrap:bool=True, # After training completes, unwrap the Transformers model\n",
" ):\n",
" self._label_key, self._loss_key, self._logit_key = labels, loss, logits\n",
" self._label_key, self._loss_key = labels, loss\n",
" self._logit_key, self.unwrap = logits, unwrap\n",
"\n",
" def after_create(self):\n",
" self._model_loss = isinstance(self.learn.loss_func, HuggingFaceLoss)\n",
@@ -184,18 +193,29 @@
" self.learn.loss_grad = self._loss\n",
" self.learn.loss = self.learn.loss_grad.clone()\n",
" else:\n",
" self.xb[0][self._label_key] = self.learn.yb[0]"
" self.xb[0][self._label_key] = self.learn.yb[0]\n",
"\n",
" def after_fit(self):\n",
" if self.unwrap:\n",
" if isinstance(self.learn.model, dynamo.OptimizedModule) and hasattr(self.learn, 'compiler'):\n",
" self.learn.compiler._reset_compiled()\n",
" if isinstance(self.model, HuggingFaceWrapper):\n",
" self.learn.model = self.learn.model.hf_model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If `HuggingFaceLoss` is passed to `fastai.learner.Learner`, then <code>HuggingFaceCallback</code> will use the Hugging Face model's built in loss.\n",
"<code>HuggingFaceCallback</code> automatically wraps a Transformer model with the `HuggingFaceWrapper` for compatability with `fastai.learner.Learner`.\n",
"\n",
"If `HuggingFaceLoss` is passed to `Learner`, then <code>HuggingFaceCallback</code> will use the Hugging Face model's built in loss.\n",
"\n",
"If any other loss function is passed to `Learner`, <code>HuggingFaceCallback</code> will prevent the built-in loss from being calculated and will use the `Learner` loss function instead.\n",
"\n",
"If `labels=None`, then <code>HuggingFaceCallback</code> will not attempt to assign a fastai label from the Hugging Face input batch. The default fastai and fastxtend metrics will not work without targets."
"If `labels=None`, then <code>HuggingFaceCallback</code> will not attempt to assign a fastai label from the Hugging Face input batch. The default fastai and fastxtend metrics will not work without targets.\n",
"\n",
"After training, the <code>HuggingFaceCallback</code> will automatically unwrap model. Set `unwrap=False` to keep the model wrapped in <code>HuggingFaceWrapper</code>."
]
},
{
@@ -265,6 +285,7 @@
"outputs": [],
"source": [
"#|hide\n",
"#|cuda\n",
"import os\n",
"\n",
"from datasets import concatenate_datasets, load_dataset\n",
@@ -318,7 +339,6 @@
"outputs": [],
"source": [
"#|cuda\n",
"#|example\n",
"tokenizer = AutoTokenizer.from_pretrained('distilroberta-base')\n",
"model = AutoModelForSequenceClassification.from_pretrained('distilroberta-base', num_labels=2)"
]
@@ -337,7 +357,6 @@
"outputs": [],
"source": [
"#|cuda\n",
"#|example\n",
"imdb = load_dataset('imdb')\n",
"with less_random():\n",
" imdb['train'] = imdb['train'].shuffle().select(range(5000))\n",
@@ -348,7 +367,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Next we'll tokenize the data using Dataset's `map` method."
"Next, we'll tokenize the data using Dataset's `map` method."
]
},
{
@@ -358,7 +377,6 @@
"outputs": [],
"source": [
"#|cuda\n",
"#|example\n",
"def tokenize_data(batch, tokenizer):\n",
" return tokenizer(batch['text'], truncation=True)\n",
"\n",
@@ -391,16 +409,15 @@
"outputs": [],
"source": [
"#|cuda\n",
"#|example\n",
"with less_random():\n",
" train_dataloader = HuggingFaceLoader(\n",
" imdb['train'].with_format('torch'), batch_size=32,\n",
" imdb['train'].with_format('torch'), batch_size=16,\n",
" collate_fn=DataCollatorWithPadding(tokenizer), shuffle=True,\n",
" drop_last=True, num_workers=num_cpus()\n",
" )\n",
"\n",
" valid_dataloader = HuggingFaceLoader(\n",
" imdb['test'].with_format('torch'), batch_size=32,\n",
" imdb['test'].with_format('torch'), batch_size=16,\n",
" collate_fn=DataCollatorWithPadding(tokenizer), shuffle=False,\n",
" drop_last=False, num_workers=num_cpus()\n",
" )\n",
@@ -477,24 +494,24 @@
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.690667</td>\n",
" <td>0.685078</td>\n",
" <td>0.505000</td>\n",
" <td>00:34</td>\n",
" <td>0.691708</td>\n",
" <td>0.690203</td>\n",
" <td>0.492000</td>\n",
" <td>00:38</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.517118</td>\n",
" <td>0.380596</td>\n",
" <td>0.860000</td>\n",
" <td>00:34</td>\n",
" <td>0.510412</td>\n",
" <td>0.409681</td>\n",
" <td>0.854000</td>\n",
" <td>00:37</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.305611</td>\n",
" <td>0.297281</td>\n",
" <td>0.879000</td>\n",
" <td>00:34</td>\n",
" <td>0.282954</td>\n",
" <td>0.300484</td>\n",
" <td>0.873000</td>\n",
" <td>00:38</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
Expand All @@ -509,12 +526,12 @@
],
"source": [
"#|cuda\n",
"#|example\n",
"with less_random():\n",
" learn = Learner(dls, model, loss_func=HuggingFaceLoss(), metrics=Accuracy(),\n",
" opt_func=stableadam(foreach=True), cbs=HuggingFaceCallback()).to_bf16()\n",
" learn = Learner(dls, model, loss_func=HuggingFaceLoss(),\n",
" opt_func=stableadam(foreach=True),\n",
" metrics=Accuracy(), cbs=HuggingFaceCallback).to_bf16()\n",
"\n",
" learn.fit_flat_warmup(3, lr=1e-3, wd=1e-2)"
" learn.fit_flat_warmup(3, lr=8e-4, wd=1e-2)"
]
},
{
Expand All @@ -525,18 +542,23 @@
"source": [
"#|hide\n",
"#|cuda\n",
"#|example\n",
"model = None\n",
"free_gpu_memory(learn)"
"free_gpu_memory(learn, dls)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If we wanted to define our own loss, such as`nn.CrossEntropyLoss` with label smoothing, we could pass in any PyTorch compatible loss function to `Learner` and <code>HuggingFaceCallback</code> will automatically use it instead of DistilRoBERTa's internal loss function.\n",
"If we wanted to define our own loss, such as `nn.CrossEntropyLoss` with label smoothing, we could pass in any PyTorch compatible loss function to `Learner` and <code>HuggingFaceCallback</code> will automatically use it instead of DistilRoBERTa's internal loss function.\n",
"\n",
"In this example, we use fastxtend's `CompilerCallback` callback via the `Learner.compile` convenience method to accelerate training throughput using `torch.compile`. After compiling the model in the first epoch, this significantly speeds up training and reduces memory usage. An overall loss in this small example, but we'd want to use it if training on the entirety of IMDb."
"In this example, we use fastxtend's `CompilerCallback` via the `Learner.compile` convenience method to accelerate training throughput using `torch.compile`. After compiling the model in the first epoch, training speed is increased, and memory usage is reduced. In this small example it's an overall loss, but we'd want to compile DistilRoBERTa if training on the entirety of IMDb.\n",
"\n",
":::{.callout-warning collapse=\"true\"}\n",
"#### Warning: Dynamic Requires PyTorch 2.1\n",
"\n",
"Compiling the model with `compile(dynamic=True)` requires Pytorch 2.1. Dynamic shapes does not work in PyTorch 2.0.\n",
":::"
]
},
{
@@ -587,24 +609,24 @@
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.689795</td>\n",
" <td>0.685920</td>\n",
" <td>0.507000</td>\n",
" <td>01:23</td>\n",
" <td>0.686346</td>\n",
" <td>0.677865</td>\n",
" <td>0.658000</td>\n",
" <td>01:25</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.556394</td>\n",
" <td>0.444239</td>\n",
" <td>0.865000</td>\n",
" <td>00:23</td>\n",
" <td>0.423131</td>\n",
" <td>0.383354</td>\n",
" <td>0.886000</td>\n",
" <td>00:27</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.388424</td>\n",
" <td>0.383318</td>\n",
" <td>0.885000</td>\n",
" <td>00:23</td>\n",
" <td>0.355547</td>\n",
" <td>0.374400</td>\n",
" <td>0.887000</td>\n",
" <td>00:27</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
Expand All @@ -619,15 +641,59 @@
],
"source": [
"#|cuda\n",
"#|example\n",
"model = AutoModelForSequenceClassification.from_pretrained('distilroberta-base', num_labels=2)\n",
"\n",
"with less_random():\n",
" learn = Learner(dls, model, loss_func=nn.CrossEntropyLoss(label_smoothing=0.1),\n",
" metrics=Accuracy(), opt_func=stableadam(foreach=True),\n",
" cbs=HuggingFaceCallback()).to_bf16().compile(dynamic=True)\n",
" opt_func=stableadam(foreach=True), metrics=Accuracy(),\n",
" cbs=HuggingFaceCallback).to_bf16().compile(dynamic=True)\n",
"\n",
" learn.fit_flat_warmup(3, lr=1e-3, wd=1e-2)"
" learn.fit_flat_warmup(3, lr=8e-4, wd=1e-2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Saving the Model\n",
"\n",
"After training, the <code>HuggingFaceCallback</code> will automatically unwrap our model, leaving `Learner.model` as the original Transformers model. \n",
"\n",
"We use any Transformers method to save the model, such as `save_pretrained`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#|hide\n",
"#|cuda\n",
"import tempfile\n",
"temp_path = tempfile.TemporaryDirectory(dir=learn.path)\n",
"model_path = temp_path.name"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#|cuda\n",
"learn.model.save_pretrained(model_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#|hide\n",
"#|cuda\n",
"temp_path.cleanup()"
]
}
],
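
One way to read the combined change: even in the second notebook example, which trains with `to_bf16().compile(dynamic=True)`, the new `after_fit` first resets the compiled module and then removes the `HuggingFaceWrapper`, so `learn.model` ends up as the plain Transformers model. A hedged sanity-check sketch, reusing `learn` and `model_path` from the cells above:

# A hedged sanity check of the combined behavior, reusing learn and model_path
# from the notebook cells above: after fit, neither the torch.compile wrapper
# nor the HuggingFaceWrapper is left on learn.model.
import torch._dynamo as dynamo
from transformers import PreTrainedModel

assert not isinstance(learn.model, dynamo.OptimizedModule)
assert isinstance(learn.model, PreTrainedModel)
learn.model.save_pretrained(model_path)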
