Merge pull request #156 from stochasticai/dev
Release 0.1.0
sarthaklangde authored Apr 18, 2023
2 parents 3531155 + 5d97147 commit 25065ca
Showing 18 changed files with 1,100 additions and 327 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -33,9 +33,9 @@ With `xturing` you can,

<br>

## 🌟 INT4 fine-tuning with LLaMA LoRA
## 🌟 INT4 fine-tuning and generation with LLaMA LoRA

We are excited to announce the latest enhancement to our `xTuring` library: INT4 fine-tuning demo. With this update, you can fine-tune LLMs like LLaMA with LoRA architecture in INT4 precision with less than `6 GB` of VRAM. This breakthrough significantly reduces memory requirements and accelerates the fine-tuning process, allowing you to achieve state-of-the-art performance with fewer computational resources.
We are excited to announce the latest enhancement to our `xTuring` library: INT4 fine-tuning and generation integration. With this update, you can fine-tune LLMs like LLaMA with LoRA architecture in INT4 precision with less than `6 GB` of VRAM. This breakthrough significantly reduces memory requirements and accelerates the fine-tuning process, allowing you to achieve state-of-the-art performance with fewer computational resources.

More information about INT4 fine-tuning and benchmarks can be found in the [INT4 README](examples/int4_finetuning/README.md).

@@ -146,6 +146,7 @@ model = BaseModel.load("x/distilgpt2_lora_finetuned_alpaca")
- [x] OpenAI, Cohere and AI21 Studio model APIs for dataset generation
- [x] Added fine-tuned checkpoints for some models to the hub
- [x] INT4 LLaMA LoRA fine-tuning demo
- [x] INT4 LLaMA LoRA fine-tuning with INT4 generation
- [ ] Evaluation of LLM models
- [ ] Support for Stable Diffusion

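For orientation, here is the new INT4 workflow end to end, pieced together from the updated notebook further down in this diff (the dataset path, model key, and prompt are the ones used there, not additional recommendations):

```python
# Hedged sketch of the INT4 LoRA fine-tuning + generation flow shown in the
# notebook below; assumes `pip install xturing[int4] --upgrade` has been run.
from xturing.datasets.instruction_dataset import InstructionDataset
from xturing.models import BaseModel

# Alpaca-style instruction dataset, path as used in the notebook
instruction_dataset = InstructionDataset("../llama/alpaca_data")

# New INT4 LLaMA LoRA model key registered in this release
model = BaseModel.create("llama_lora_int4")

# Fine-tune, then generate with the same model object
model.finetune(dataset=instruction_dataset)
output = model.generate(texts=["Why LLM models are becoming so important?"])
print("Generated output by the model: {}".format(output))
```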
129 changes: 45 additions & 84 deletions examples/int4_finetuning/LLaMA_lora_int4.ipynb
@@ -31,119 +31,80 @@
},
"outputs": [],
"source": [
"!pip install xturing --upgrade"
"!pip install xturing --upgrade\n",
"!pip install xturing[int4] --upgrade"
]
},
{
"cell_type": "code",
"execution_count": 7,
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import subprocess\n",
"from pathlib import Path\n",
"\n",
"def pull_docker_image(image):\n",
" cmd = [\"docker\", \"pull\", image]\n",
" subprocess.run(cmd, check=True)\n",
"\n",
"\n",
"def run_docker_container(image, port_mapping, env_vars=None, gpus=None, volumes=None):\n",
" cmd = [\"docker\", \"container\", \"run\"]\n",
"\n",
" if env_vars is None:\n",
" env_vars = {}\n",
"\n",
" if volumes is None:\n",
" volumes = {}\n",
"\n",
" if gpus is not None:\n",
" cmd.extend([\"--gpus\", gpus])\n",
"\n",
" for key, value in env_vars.items():\n",
" cmd.extend([\"-e\", f\"{key}={value}\"])\n",
"\n",
" for local_path, container_path in volumes.items():\n",
" cmd.extend([\"-v\", f\"{str(Path(local_path).resolve())}:{container_path}\"])\n",
"\n",
" cmd.extend([\"-p\", port_mapping, image])\n",
"\n",
" subprocess.run(cmd)"
"## 2. Load model and dataset"
]
},
{
"cell_type": "markdown",
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
"collapsed": false,
"pycharm": {
"is_executing": true
}
},
"outputs": [],
"source": [
"## 2. Load and run docker image"
"from xturing.datasets.instruction_dataset import InstructionDataset\n",
"from xturing.models import BaseModel\n",
"\n",
"instruction_dataset = InstructionDataset(\"../llama/alpaca_data\")\n",
"# Initializes the model\n",
"model = BaseModel.create(\"llama_lora_int4\")"
]
},
{
"cell_type": "markdown",
"source": [
"## 3. Start the finetuning"
],
"metadata": {
"collapsed": false
},
"source": [
"1. Install Docker on your machine if you haven't already. You can follow the [official Docker documentation](https://docs.docker.com/engine/install/) for installation instructions.\n",
"2. Install NVIDIA Container Toolkit\n",
" ```bash\n",
" sudo apt-get install -y nvidia-docker2\n",
" ```\n",
"3. Run the Docker daemon\n",
" ```bash\n",
" sudo systemctl start docker\n",
" ```\n"
]
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"pycharm": {
"is_executing": true
}
},
"outputs": [],
"source": [
"image = \"public.ecr.aws/t8g5g2q5/xturing:int4_finetuning\"\n",
"port_mapping = \"5000:5000\"\n",
"env_vars = {\n",
" \"WANDB_MODE\": \"dryrun\",\n",
" \"MICRO_BATCH_SIZE\": \"1\", # change this to increase your micro batch size\n",
"}\n",
"# if you want to log results to wandb, set the following env var\n",
"# env_vars = {\n",
"# \"WANDB_API_KEY\": \"<your_wandb_api_key>\",\n",
"# \"WANDB_PROJECT\": \"your_project_name\",\n",
"# \"WANDB_ENTITY\": \"your_entity_name\",\n",
"# # Add more environment variables as needed\n",
"# }\n",
"volumes = {\n",
" # \"<where to save model>\": \"/model\",\n",
" \"../llama/alpaca_data\": \"/data\", # change this to your data path if you want\n",
"}\n",
"gpus = \"all\"\n",
"\n",
"pull_docker_image(image)\n",
"\n",
"run_docker_container(image, port_mapping, env_vars, gpus, volumes)"
]
"# Finetuned the model\n",
"model.finetune(dataset=instruction_dataset)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Alternately, you can run the example using CLI command:\n",
"\n",
"```bash\n",
"docker run -p 5000:5000 --gpus all -e WANDB_MODE=dryrun -e MICRO_BATCH_SIZE=1 -v /absolute/path/to/alpaca/data:/data public.ecr.aws/t8g5g2q5/xturing:int4_finetuning\n",
"```"
]
"## 4. Generate an output text with the fine-tuned model"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Once the model has been finetuned, you can start doing inferences\n",
"output = model.generate(texts=[\"Why LLM models are becoming so important?\"])\n",
"print(\"Generated output by the model: {}\".format(output))"
],
"metadata": {
"collapsed": false
}
}
],
"metadata": {
6 changes: 5 additions & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "xturing"
version = "0.0.10"
version = "0.1.0"
description = "Fine-tuning, evaluation and data generation for LLMs"

authors = [
@@ -64,6 +64,10 @@ dependencies = [
[project.scripts]
xturing = "xturing.cli:xturing"

[project.optional-dependencies]
int4 = [
"torch >= 2.0"
]

[project.urls]
homepage = "https://xturing.stochastic.ai/"
2 changes: 1 addition & 1 deletion src/xturing/__about__.py
@@ -1 +1 @@
__version__ = "0.0.10"
__version__ = "0.1.0"
7 changes: 7 additions & 0 deletions src/xturing/config/finetuning_config.yaml
@@ -33,6 +33,13 @@ llama_lora_int8:
batch_size: 8
max_length: 256

llama_lora_int4:
learning_rate: 1e-4
weight_decay: 0.01
num_train_epochs: 3
batch_size: 8
max_length: 256

gptj:
learning_rate: 5e-5
weight_decay: 0.01
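These defaults ship inside the package as YAML, so one quick way to confirm what a given install contains is to read the file directly; a minimal sketch, assuming PyYAML is available and the working directory is the repository root:

```python
# Minimal sketch (my own snippet, not part of the repo): confirm the new
# llama_lora_int4 fine-tuning defaults are present in the packaged config.
import yaml

with open("src/xturing/config/finetuning_config.yaml") as f:
    finetuning_defaults = yaml.safe_load(f)

# Per the hunk above, this should list learning_rate, weight_decay,
# num_train_epochs, batch_size and max_length.
print(sorted(finetuning_defaults["llama_lora_int4"]))
```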
7 changes: 7 additions & 0 deletions src/xturing/config/generation_config.yaml
@@ -27,6 +27,13 @@ llama_lora_int8:
max_new_tokens: 256
do_sample: false

# Contrastive search
llama_lora_int4:
penalty_alpha: 0.6
top_k: 4
max_new_tokens: 256
do_sample: false

# Contrastive search
gptj:
penalty_alpha: 0.6
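The `penalty_alpha`/`top_k` pair is what selects contrastive search in Hugging Face `generate()`. A sketch of the equivalent plain-`transformers` call, using `distilgpt2` only as a lightweight stand-in since the INT4 LLaMA wiring lives in the engine code later in this diff:

```python
# Sketch of what the llama_lora_int4 generation defaults above mean in plain
# Hugging Face terms; distilgpt2 is only a small stand-in model.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

inputs = tokenizer("Why LLM models are becoming so important?", return_tensors="pt")
# penalty_alpha > 0 together with top_k > 1 and do_sample=False selects
# contrastive search rather than greedy decoding.
output_ids = model.generate(
    **inputs,
    penalty_alpha=0.6,
    top_k=4,
    max_new_tokens=256,
    do_sample=False,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```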
2 changes: 2 additions & 0 deletions src/xturing/engines/__init__.py
@@ -25,6 +25,7 @@
LLamaInt8Engine,
LlamaLoraEngine,
LlamaLoraInt8Engine,
LlamaLoraInt4Engine,
)
from .opt_engine import OPTEngine, OPTInt8Engine, OPTLoraEngine, OPTLoraInt8Engine

@@ -42,6 +43,7 @@
BaseEngine.add_to_registry(LlamaLoraEngine.config_name, LlamaLoraEngine)
BaseEngine.add_to_registry(LLamaInt8Engine.config_name, LLamaInt8Engine)
BaseEngine.add_to_registry(LlamaLoraInt8Engine.config_name, LlamaLoraInt8Engine)
BaseEngine.add_to_registry(LlamaLoraInt4Engine.config_name, LlamaLoraInt4Engine)
BaseEngine.add_to_registry(GalacticaEngine.config_name, GalacticaEngine)
BaseEngine.add_to_registry(GalacticaInt8Engine.config_name, GalacticaInt8Engine)
BaseEngine.add_to_registry(GalacticaLoraEngine.config_name, GalacticaLoraEngine)
2 changes: 1 addition & 1 deletion src/xturing/engines/causal.py
@@ -146,7 +146,7 @@ def __init__(

lora_config = LoraConfig(
r=8,
lora_alpha=32,
lora_alpha=16,
target_modules=target_modules,
lora_dropout=0.05,
bias="none",
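For context, lowering `lora_alpha` from 32 to 16 halves the effective LoRA scaling factor (`lora_alpha / r`, here from 4 to 2). Whether this `LoraConfig` is peft's class or an xTuring-internal equivalent from `lora_engine` is not visible in this hunk, so the sketch below is illustrative only, written against peft, and `task_type` is my addition:

```python
# Illustrative sketch only: the same hyperparameters expressed with peft's
# LoraConfig. task_type is an assumption; the diff above does not show it.
from peft import LoraConfig

lora_config = LoraConfig(
    r=8,                                  # LoRA rank
    lora_alpha=16,                        # scaling numerator; scale = lora_alpha / r = 2
    target_modules=["q_proj", "v_proj"],  # attention projections, as used for LLaMA
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",                # assumption, not in the diff
)
```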
86 changes: 85 additions & 1 deletion src/xturing/engines/llama_engine.py
@@ -1,13 +1,16 @@
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import transformers

import torch
from torch import nn

from xturing.engines.causal import CausalEngine, CausalLoraEngine
from xturing.engines.llama_utils import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
from xturing.engines.lora_engine import prepare_model_for_int8_training

from xturing.engines.quant_utils import make_quant, autotune_warmup
from xturing.utils.hub import ModelHub

class LLamaEngine(CausalEngine):
config_name: str = "llama_engine"
@@ -98,3 +101,84 @@ def __init__(self, weights_path: Optional[Union[str, Path]] = None):
load_8bit=True,
target_modules=["q_proj", "v_proj"],
)

def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
if type(module) in layers:
return {name: module}
res = {}
for name1, child in module.named_children():
res.update(find_layers(
child, layers=layers, name=name + '.' + name1 if name != '' else name1
))
return res

class LlamaLoraInt4Engine(CausalLoraEngine):
config_name: str = "llama_lora_int4_engine"

def __init__(self, weights_path: Optional[Union[str, Path]] = None):
model_name = "decapoda-research/llama-7b-hf"

if weights_path is None:
weights_path = ModelHub().load("x/llama_lora_int4")

config = LlamaConfig.from_pretrained(model_name)

saved_kaiming_uniform_ = torch.nn.init.kaiming_uniform_
saved_uniform_ = torch.nn.init.uniform_
saved_normal_ = torch.nn.init.normal_

def noop(*args, **kwargs):
pass

torch.nn.init.kaiming_uniform_ = noop
torch.nn.init.uniform_ = noop
torch.nn.init.normal_ = noop

torch.set_default_dtype(torch.half)
transformers.modeling_utils._init_weights = False
torch.set_default_dtype(torch.half)
model = LlamaForCausalLM(config)
torch.set_default_dtype(torch.float)
model = model.eval()

layers = find_layers(model)

for name in ['lm_head']:
if name in layers:
del layers[name]

wbits = 4
groupsize = 128
warmup_autotune=True

make_quant(model, layers, wbits, groupsize)


model.load_state_dict(torch.load(weights_path / Path("pytorch_model.bin")), strict=False)

if warmup_autotune:
autotune_warmup(model)

model.seqlen = 2048

model.gptq = True

model.gradient_checkpointing_enable()
model.enable_input_require_grads()

tokenizer = LlamaTokenizer.from_pretrained(model_name, add_bos_token=False)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

super().__init__(
model=model,
tokenizer=tokenizer,
target_modules=[
"q_proj",
"v_proj",
]
)

torch.nn.init.kaiming_uniform_ = saved_kaiming_uniform_
torch.nn.init.uniform_ = saved_uniform_
torch.nn.init.normal_ = saved_normal_
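
The engine relies on `find_layers` to collect every `nn.Linear` to quantize, explicitly dropping `lm_head`. A toy illustration of what it returns, using a made-up module of my own (importing `llama_engine` may itself require the `int4` extra to be installed):

```python
# Toy example (the module below is mine, not from the repo) showing that
# find_layers returns a dict keyed by dotted module paths.
import torch.nn as nn

from xturing.engines.llama_engine import find_layers  # added in this commit


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = nn.Sequential(nn.Linear(4, 8), nn.ReLU())
        self.head = nn.Linear(8, 2)


print(find_layers(Toy()))
# {'block.0': Linear(in_features=4, out_features=8, bias=True),
#  'head': Linear(in_features=8, out_features=2, bias=True)}
```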