diff --git a/README.md b/README.md
index b6a4adb..1e604c8 100644
--- a/README.md
+++ b/README.md
@@ -33,9 +33,9 @@ With `xturing` you can,
-## 🌟 INT4 fine-tuning with LLaMA LoRA
+## 🌟 INT4 fine-tuning and generation with LLaMA LoRA
-We are excited to announce the latest enhancement to our `xTuring` library: INT4 fine-tuning demo. With this update, you can fine-tune LLMs like LLaMA with LoRA architecture in INT4 precision with less than `6 GB` of VRAM. This breakthrough significantly reduces memory requirements and accelerates the fine-tuning process, allowing you to achieve state-of-the-art performance with less computational resources.
+We are excited to announce the latest enhancement to our `xTuring` library: INT4 fine-tuning and generation integration. With this update, you can fine-tune LLMs like LLaMA with LoRA architecture in INT4 precision with less than `6 GB` of VRAM. This breakthrough significantly reduces memory requirements and accelerates the fine-tuning process, allowing you to achieve state-of-the-art performance with fewer computational resources.
More information about INT4 fine-tuning and benchmarks can be found in the [INT4 README](examples/int4_finetuning/README.md).
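+
+A minimal sketch of the INT4 fine-tuning and generation workflow, mirroring the notebook in `examples/int4_finetuning` (install the extra with `pip install xturing[int4]`; the dataset path below is a placeholder, so point it at your own Alpaca-style instruction data):
+
+```python
+from xturing.datasets.instruction_dataset import InstructionDataset
+from xturing.models import BaseModel
+
+instruction_dataset = InstructionDataset("./alpaca_data")  # placeholder path
+model = BaseModel.create("llama_lora_int4")
+model.finetune(dataset=instruction_dataset)
+output = model.generate(texts=["Why are LLMs becoming so important?"])
+print(output)
+```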
@@ -146,6 +146,7 @@ model = BaseModel.load("x/distilgpt2_lora_finetuned_alpaca")
- [x] OpenAI, Cohere and AI21 Studio model APIs for dataset generation
- [x] Added fine-tuned checkpoints for some models to the hub
- [x] INT4 LLaMA LoRA fine-tuning demo
+- [x] INT4 LLaMA LoRA fine-tuning with INT4 generation
- [ ] Evaluation of LLM models
- [ ] Support for Stable Diffusion
diff --git a/examples/int4_finetuning/LLaMA_lora_int4.ipynb b/examples/int4_finetuning/LLaMA_lora_int4.ipynb
index d503a39..8859907 100644
--- a/examples/int4_finetuning/LLaMA_lora_int4.ipynb
+++ b/examples/int4_finetuning/LLaMA_lora_int4.ipynb
@@ -31,119 +31,80 @@
},
"outputs": [],
"source": [
- "!pip install xturing --upgrade"
+ "!pip install xturing --upgrade\n",
+ "!pip install xturing[int4] --upgrade"
]
},
{
- "cell_type": "code",
- "execution_count": 7,
+ "cell_type": "markdown",
"metadata": {
"collapsed": false
},
- "outputs": [],
"source": [
- "import subprocess\n",
- "from pathlib import Path\n",
- "\n",
- "def pull_docker_image(image):\n",
- " cmd = [\"docker\", \"pull\", image]\n",
- " subprocess.run(cmd, check=True)\n",
- "\n",
- "\n",
- "def run_docker_container(image, port_mapping, env_vars=None, gpus=None, volumes=None):\n",
- " cmd = [\"docker\", \"container\", \"run\"]\n",
- "\n",
- " if env_vars is None:\n",
- " env_vars = {}\n",
- "\n",
- " if volumes is None:\n",
- " volumes = {}\n",
- "\n",
- " if gpus is not None:\n",
- " cmd.extend([\"--gpus\", gpus])\n",
- "\n",
- " for key, value in env_vars.items():\n",
- " cmd.extend([\"-e\", f\"{key}={value}\"])\n",
- "\n",
- " for local_path, container_path in volumes.items():\n",
- " cmd.extend([\"-v\", f\"{str(Path(local_path).resolve())}:{container_path}\"])\n",
- "\n",
- " cmd.extend([\"-p\", port_mapping, image])\n",
- "\n",
- " subprocess.run(cmd)"
+ "## 2. Load model and dataset"
]
},
{
- "cell_type": "markdown",
+ "cell_type": "code",
+ "execution_count": null,
"metadata": {
- "collapsed": false
+ "collapsed": false,
+ "pycharm": {
+ "is_executing": true
+ }
},
+ "outputs": [],
"source": [
- "## 2. Load and run docker image"
+ "from xturing.datasets.instruction_dataset import InstructionDataset\n",
+ "from xturing.models import BaseModel\n",
+ "\n",
+ "instruction_dataset = InstructionDataset(\"../llama/alpaca_data\")\n",
+ "# Initializes the model\n",
+ "model = BaseModel.create(\"llama_lora_int4\")"
]
},
{
"cell_type": "markdown",
+ "source": [
+ "## 3. Start the finetuning"
+ ],
"metadata": {
"collapsed": false
- },
- "source": [
- "1. Install Docker on your machine if you haven't already. You can follow the [official Docker documentation](https://docs.docker.com/engine/install/) for installation instructions.\n",
- "2. Install NVIDIA Container Toolkit\n",
- " ```bash\n",
- " sudo apt-get install -y nvidia-docker2\n",
- " ```\n",
- "3. Run the Docker daemon\n",
- " ```bash\n",
- " sudo systemctl start docker\n",
- " ```\n"
- ]
+ }
},
{
"cell_type": "code",
"execution_count": null,
- "metadata": {
- "collapsed": false,
- "pycharm": {
- "is_executing": true
- }
- },
"outputs": [],
"source": [
- "image = \"public.ecr.aws/t8g5g2q5/xturing:int4_finetuning\"\n",
- "port_mapping = \"5000:5000\"\n",
- "env_vars = {\n",
- " \"WANDB_MODE\": \"dryrun\",\n",
- " \"MICRO_BATCH_SIZE\": \"1\", # change this to increase your micro batch size\n",
- "}\n",
- "# if you want to log results to wandb, set the following env var\n",
- "# env_vars = {\n",
- "# \"WANDB_API_KEY\": \"\",\n",
- "# \"WANDB_PROJECT\": \"your_project_name\",\n",
- "# \"WANDB_ENTITY\": \"your_entity_name\",\n",
- "# # Add more environment variables as needed\n",
- "# }\n",
- "volumes = {\n",
- " # \"\": \"/model\",\n",
- " \"../llama/alpaca_data\": \"/data\", # change this to your data path if you want\n",
- "}\n",
- "gpus = \"all\"\n",
- "\n",
- "pull_docker_image(image)\n",
- "\n",
- "run_docker_container(image, port_mapping, env_vars, gpus, volumes)"
- ]
+ "# Finetuned the model\n",
+ "model.finetune(dataset=instruction_dataset)"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
},
{
"cell_type": "markdown",
- "metadata": {},
"source": [
- "## Alternately, you can run the example using CLI command:\n",
- "\n",
- "```bash\n",
- "docker run -p 5000:5000 --gpus all -e WANDB_MODE=dryrun -e MICRO_BATCH_SIZE=1 -v /absolute/path/to/alpaca/data:/data public.ecr.aws/t8g5g2q5/xturing:int4_finetuning\n",
- "```"
- ]
+ "## 4. Generate an output text with the fine-tuned model"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "outputs": [],
+ "source": [
+ "# Once the model has been finetuned, you can start doing inferences\n",
+ "output = model.generate(texts=[\"Why LLM models are becoming so important?\"])\n",
+ "print(\"Generated output by the model: {}\".format(output))"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
}
],
"metadata": {
diff --git a/pyproject.toml b/pyproject.toml
index acc8eef..f870e6b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "xturing"
-version = "0.0.10"
+version = "0.1.0"
description = "Fine-tuning, evaluation and data generation for LLMs"
authors = [
@@ -64,6 +64,10 @@ dependencies = [
[project.scripts]
xturing = "xturing.cli:xturing"
+[project.optional-dependencies]
+int4 = [
+ "torch >= 2.0"
+]
[project.urls]
homepage = "https://xturing.stochastic.ai/"
diff --git a/src/xturing/__about__.py b/src/xturing/__about__.py
index 9b36b86..3dc1f76 100644
--- a/src/xturing/__about__.py
+++ b/src/xturing/__about__.py
@@ -1 +1 @@
-__version__ = "0.0.10"
+__version__ = "0.1.0"
diff --git a/src/xturing/config/finetuning_config.yaml b/src/xturing/config/finetuning_config.yaml
index 3b1cde3..9a8d260 100644
--- a/src/xturing/config/finetuning_config.yaml
+++ b/src/xturing/config/finetuning_config.yaml
@@ -33,6 +33,13 @@ llama_lora_int8:
batch_size: 8
max_length: 256
+llama_lora_int4:
+ learning_rate: 1e-4
+ weight_decay: 0.01
+ num_train_epochs: 3
+ batch_size: 8
+ max_length: 256
+
gptj:
learning_rate: 5e-5
weight_decay: 0.01
diff --git a/src/xturing/config/generation_config.yaml b/src/xturing/config/generation_config.yaml
index 12ca88d..7279f6a 100644
--- a/src/xturing/config/generation_config.yaml
+++ b/src/xturing/config/generation_config.yaml
@@ -27,6 +27,13 @@ llama_lora_int8:
max_new_tokens: 256
do_sample: false
+# Contrastive search
+llama_lora_int4:
+ penalty_alpha: 0.6
+ top_k: 4
+ max_new_tokens: 256
+ do_sample: false
+
# Contrastive search
gptj:
penalty_alpha: 0.6
diff --git a/src/xturing/engines/__init__.py b/src/xturing/engines/__init__.py
index 5ce012d..8291147 100644
--- a/src/xturing/engines/__init__.py
+++ b/src/xturing/engines/__init__.py
@@ -25,6 +25,7 @@
LLamaInt8Engine,
LlamaLoraEngine,
LlamaLoraInt8Engine,
+ LlamaLoraInt4Engine,
)
from .opt_engine import OPTEngine, OPTInt8Engine, OPTLoraEngine, OPTLoraInt8Engine
@@ -42,6 +43,7 @@
BaseEngine.add_to_registry(LlamaLoraEngine.config_name, LlamaLoraEngine)
BaseEngine.add_to_registry(LLamaInt8Engine.config_name, LLamaInt8Engine)
BaseEngine.add_to_registry(LlamaLoraInt8Engine.config_name, LlamaLoraInt8Engine)
+BaseEngine.add_to_registry(LlamaLoraInt4Engine.config_name, LlamaLoraInt4Engine)
BaseEngine.add_to_registry(GalacticaEngine.config_name, GalacticaEngine)
BaseEngine.add_to_registry(GalacticaInt8Engine.config_name, GalacticaInt8Engine)
BaseEngine.add_to_registry(GalacticaLoraEngine.config_name, GalacticaLoraEngine)
diff --git a/src/xturing/engines/causal.py b/src/xturing/engines/causal.py
index 3598f44..86dd1cb 100644
--- a/src/xturing/engines/causal.py
+++ b/src/xturing/engines/causal.py
@@ -146,7 +146,7 @@ def __init__(
lora_config = LoraConfig(
r=8,
- lora_alpha=32,
+ lora_alpha=16,
target_modules=target_modules,
lora_dropout=0.05,
bias="none",
diff --git a/src/xturing/engines/llama_engine.py b/src/xturing/engines/llama_engine.py
index 46d9441..af0dd6a 100644
--- a/src/xturing/engines/llama_engine.py
+++ b/src/xturing/engines/llama_engine.py
@@ -1,13 +1,16 @@
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
+import transformers
import torch
+from torch import nn
from xturing.engines.causal import CausalEngine, CausalLoraEngine
from xturing.engines.llama_utils import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
from xturing.engines.lora_engine import prepare_model_for_int8_training
-
+from xturing.engines.quant_utils import make_quant, autotune_warmup
+from xturing.utils.hub import ModelHub
class LLamaEngine(CausalEngine):
config_name: str = "llama_engine"
@@ -98,3 +101,84 @@ def __init__(self, weights_path: Optional[Union[str, Path]] = None):
load_8bit=True,
target_modules=["q_proj", "v_proj"],
)
+
+def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
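+ # Recursively collect a {qualified_name: module} mapping of every submodule of
+ # `module` whose type appears in `layers` (Conv2d and Linear by default).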
+ if type(module) in layers:
+ return {name: module}
+ res = {}
+ for name1, child in module.named_children():
+ res.update(find_layers(
+ child, layers=layers, name=name + '.' + name1 if name != '' else name1
+ ))
+ return res
+
+class LlamaLoraInt4Engine(CausalLoraEngine):
+ config_name: str = "llama_lora_int4_engine"
+
+ def __init__(self, weights_path: Optional[Union[str, Path]] = None):
+ model_name = "decapoda-research/llama-7b-hf"
+
+ if weights_path is None:
+ weights_path = ModelHub().load("x/llama_lora_int4")
+
+ config = LlamaConfig.from_pretrained(model_name)
+
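+ # Temporarily replace torch's weight-init functions with no-ops and default to
+ # fp16 so that constructing LlamaForCausalLM below skips the usual random fp32
+ # initialization; the 4-bit checkpoint loaded further down supplies the actual
+ # weights. The original init functions are restored at the end of __init__.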
+ saved_kaiming_uniform_ = torch.nn.init.kaiming_uniform_
+ saved_uniform_ = torch.nn.init.uniform_
+ saved_normal_ = torch.nn.init.normal_
+
+ def noop(*args, **kwargs):
+ pass
+
+ torch.nn.init.kaiming_uniform_ = noop
+ torch.nn.init.uniform_ = noop
+ torch.nn.init.normal_ = noop
+
+ torch.set_default_dtype(torch.half)
+ transformers.modeling_utils._init_weights = False
+ model = LlamaForCausalLM(config)
+ torch.set_default_dtype(torch.float)
+ model = model.eval()
+
+ layers = find_layers(model)
+
+ for name in ['lm_head']:
+ if name in layers:
+ del layers[name]
+
+ wbits = 4
+ groupsize = 128
+ warmup_autotune = True
+
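+ # Swap every collected nn.Linear (lm_head was excluded above) for a 4-bit
+ # QuantLinear with quantization group size 128, matching the packed checkpoint
+ # loaded below.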
+ make_quant(model, layers, wbits, groupsize)
+
+
+ model.load_state_dict(torch.load(weights_path / Path("pytorch_model.bin")), strict=False)
+
+ if warmup_autotune:
+ autotune_warmup(model)
+
+ model.seqlen = 2048
+
+ model.gptq = True
+
+ model.gradient_checkpointing_enable()
+ model.enable_input_require_grads()
+
+ tokenizer = LlamaTokenizer.from_pretrained(model_name, add_bos_token=False)
+ tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.pad_token_id = tokenizer.eos_token_id
+
+ super().__init__(
+ model=model,
+ tokenizer=tokenizer,
+ target_modules=[
+ "q_proj",
+ "v_proj",
+ ]
+ )
+
+ torch.nn.init.kaiming_uniform_ = saved_kaiming_uniform_
+ torch.nn.init.uniform_ = saved_uniform_
+ torch.nn.init.normal_ = saved_normal_
diff --git a/src/xturing/engines/lora_engine/lora.py b/src/xturing/engines/lora_engine/lora.py
index fa8539c..8e81faf 100644
--- a/src/xturing/engines/lora_engine/lora.py
+++ b/src/xturing/engines/lora_engine/lora.py
@@ -21,6 +21,7 @@
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import List, Optional, Union
+import enum
import torch
import torch.nn as nn
@@ -44,6 +45,17 @@ def is_bnb_available():
def transpose(weight, fan_in_fan_out):
return weight.T if fan_in_fan_out else weight
+def is_gptq_available():
+ return importlib.util.find_spec("xturing.engines.quant_utils") is not None
+
+if is_gptq_available():
+ from ..quant_utils import QuantLinear
+
+class PeftType(str, enum.Enum):
+ PROMPT_TUNING = "PROMPT_TUNING"
+ P_TUNING = "P_TUNING"
+ PREFIX_TUNING = "PREFIX_TUNING"
+ LORA = "LORA"
WEIGHTS_NAME = "adapter_model.bin"
CONFIG_NAME = "adapter_config.json"
@@ -52,19 +64,18 @@ def transpose(weight, fan_in_fan_out):
@dataclass
class LoraConfig:
"""
- This is the configuration class to store the configuration of a [`LoraModel`].
-
+ This is the configuration class to store the configuration of a [`~peft.Lora`].
Args:
- r (`int`): Lora attention dimension.
+ r (`int`): Lora attention dimension
target_modules (`Union[List[str],str]`): The names of the modules to apply Lora to.
lora_alpha (`float`): The alpha parameter for Lora scaling.
lora_dropout (`float`): The dropout probability for Lora layers.
merge_weights (`bool`):
Whether to merge the weights of the Lora layers with the base transformer model in `eval` mode.
- fan_in_fan_out (`bool`): Set this to True if the layer to replace stores weight like (`fan_in`, `fan_out`).
- enable_lora ( `List[bool]`): Used with [`lora.MergedLinear`].
- bias (`str`): Bias type for Lora. Can be `none`, `all` or `lora_only`.
- modules_to_save (`List[str]`): List of modules apart from Lora layers to be set as trainable
+ fan_in_fan_out (`bool`): Set this to True if the layer to replace stores weight like (fan_in, fan_out)
+ enable_lora ( `List[bool]`): Used with `lora.MergedLinear`.
+ bias (`str`): Bias type for Lora. Can be 'none', 'all' or 'lora_only'
+ modules_to_save (`List[str]`):List of modules apart from LoRA layers to be set as trainable
and saved in the final checkpoint.
"""
@@ -79,22 +90,14 @@ class LoraConfig:
lora_alpha: int = field(default=None, metadata={"help": "Lora alpha"})
lora_dropout: float = field(default=None, metadata={"help": "Lora dropout"})
merge_weights: bool = field(
- default=False,
- metadata={"help": "Merge weights of the original model and the Lora model"},
+ default=False, metadata={"help": "Merge weights of the original model and the Lora model"}
)
fan_in_fan_out: bool = field(
default=False,
- metadata={
- "help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"
- },
- )
- enable_lora: Optional[List[bool]] = field(
- default=None, metadata={"help": "Used with `lora.MergedLinear`."}
- )
- bias: str = field(
- default="none",
- metadata={"help": "Bias type for Lora. Can be 'none', 'all' or 'lora_only'"},
+ metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"},
)
+ enable_lora: Optional[List[bool]] = field(default=None, metadata={"help": "Used with `lora.MergedLinear`."})
+ bias: str = field(default="none", metadata={"help": "Bias type for Lora. Can be 'none', 'all' or 'lora_only'"})
modules_to_save: Optional[List[str]] = field(
default=None,
metadata={
@@ -110,6 +113,7 @@ class LoraConfig:
inference_mode: bool = field(
default=False, metadata={"help": "Whether to use inference mode"}
)
+ peft_type: PeftType = PeftType.LORA
base_model_name_or_path: str = field(
default=None, metadata={"help": "The name of the base model to use."}
@@ -201,35 +205,19 @@ def from_json_file(cls, path_json_file, **kwargs):
class LoraModel(torch.nn.Module):
"""
Creates Low Rank Adapter (Lora) model from a pretrained transformers model.
-
Args:
- model ([`~transformers.PreTrainedModel`]): The model to be adapted.
+ model ([`transformers.PreTrainedModel`]): The model to be adapted.
config ([`LoraConfig`]): The configuration of the Lora model.
-
Returns:
`torch.nn.Module`: The Lora model.
-
- Example:
-
- ```py
- >>> from transformers import AutoModelForSeq2SeqLM, LoraConfig
- >>> from peft import LoraModel, LoraConfig
-
- >>> config = LoraConfig(
- ... peft_type="LORA",
- ... task_type="SEQ_2_SEQ_LM",
- ... r=8,
- ... lora_alpha=32,
- ... target_modules=["q", "v"],
- ... lora_dropout=0.01,
- ... )
-
- >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
- >>> lora_model = LoraModel(config, model)
- ```
-
+ Example::
+
+ >>> from transformers import AutoModelForSeq2SeqLM
+ >>> from peft import LoraModel, LoraConfig
+ >>> config = LoraConfig(
+ ...     peft_type="LORA",
+ ...     task_type="SEQ_2_SEQ_LM",
+ ...     r=8,
+ ...     lora_alpha=32,
+ ...     target_modules=["q", "v"],
+ ...     lora_dropout=0.01,
+ ... )
+ >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
+ >>> lora_model = LoraModel(config, model)
**Attributes**:
- - **model** ([`~transformers.PreTrainedModel`]) -- The model to be adapted.
+ - **model** ([`transformers.PreTrainedModel`]) -- The model to be adapted.
- **peft_config** ([`LoraConfig`]): The configuration of the Lora model.
"""
@@ -243,6 +231,8 @@ def __init__(self, config, model):
def _find_and_replace(self):
loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False)
+ is_gtq_quantized = getattr(self.model, "gptq", False)  # check whether the model is GPTQ-quantized
+
if loaded_in_8bit and not is_bnb_available():
raise ImportError(
"To use Lora with 8-bit quantization, please install the `bitsandbytes` package. "
@@ -255,21 +245,15 @@ def _find_and_replace(self):
"lora_alpha": self.peft_config.lora_alpha,
"lora_dropout": self.peft_config.lora_dropout,
"fan_in_fan_out": self.peft_config.fan_in_fan_out,
- "merge_weights": (
- self.peft_config.merge_weights or self.peft_config.inference_mode
- )
+ "merge_weights": (self.peft_config.merge_weights or self.peft_config.inference_mode)
and not is_hf_device_map_available,
- "init_lora_weights": self.peft_config.init_lora_weights,
}
key_list = [key for key, _ in self.model.named_modules()]
for key in key_list:
if isinstance(self.peft_config.target_modules, str):
target_module_found = re.fullmatch(self.peft_config.target_modules, key)
else:
- target_module_found = any(
- key.endswith(target_key)
- for target_key in self.peft_config.target_modules
- )
+ target_module_found = any(key.endswith(target_key) for target_key in self.peft_config.target_modules)
if target_module_found:
if not is_target_modules_in_base_model:
is_target_modules_in_base_model = True
@@ -285,45 +269,49 @@ def _find_and_replace(self):
}
)
if self.peft_config.enable_lora is None:
- new_module = Linear8bitLt(
- target.in_features, target.out_features, bias=bias, **kwargs
- )
+ new_module = Linear8bitLt(target.in_features, target.out_features, bias=bias, **kwargs)
else:
kwargs.update({"enable_lora": self.peft_config.enable_lora})
- new_module = MergedLinear8bitLt(
- target.in_features, target.out_features, bias=bias, **kwargs
- )
- elif (
- isinstance(target, torch.nn.Linear)
- and self.peft_config.enable_lora is None
- ):
- new_module = Linear(
- target.in_features, target.out_features, bias=bias, **kwargs
+ new_module = MergedLinear8bitLt(target.in_features, target.out_features, bias=bias, **kwargs)
+ elif is_gptq_available() and isinstance(target, QuantLinear):
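+ # For GPTQ-quantized layers, pass the quantization settings (bits, groupsize)
+ # to the LoRA replacement and copy the packed metadata (scales, qzeros, g_idx)
+ # over after construction.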
+ kwargs.update(
+ {
+ "bits": target.bits,
+ "groupsize": target.groupsize,
+ }
)
+ if self.peft_config.enable_lora is None:
+ new_module = LinearqbitLt(target.infeatures, target.outfeatures, bias=bias, **kwargs)
+ new_module.scales = target.scales
+ new_module.qzeros = target.qzeros
+ new_module.g_idx = target.g_idx
+ if target.bias is not None:
+ new_module.bias = target.bias
+ else:
+ kwargs.update({"enable_lora": self.peft_config.enable_lora})
+ new_module = MergedLinearqbitLt(target.infeatures, target.outfeatures, bias=bias, **kwargs)
+ new_module.scales = target.scales
+ new_module.qzeros = target.qzeros
+ new_module.g_idx = target.g_idx
+ if target.bias is not None:
+ new_module.bias = target.bias
+ elif isinstance(target, torch.nn.Linear) and self.peft_config.enable_lora is None:
+ new_module = Linear(target.in_features, target.out_features, bias=bias, **kwargs)
elif self.peft_config.enable_lora is not None:
kwargs.update({"enable_lora": self.peft_config.enable_lora})
if isinstance(target, Conv1D):
in_features, out_features = (
- target.weight.ds_shape
- if hasattr(target.weight, "ds_shape")
- else target.weight.shape
+ target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape
)
else:
- in_features, out_features = (
- target.in_features,
- target.out_features,
- )
+ in_features, out_features = target.in_features, target.out_features
if kwargs["fan_in_fan_out"]:
warnings.warn(
"fan_in_fan_out is set to True but the target module is not a Conv1D. "
"Setting fan_in_fan_out to False."
)
- kwargs[
- "fan_in_fan_out"
- ] = self.peft_config.fan_in_fan_out = False
- new_module = MergedLinear(
- in_features, out_features, bias=bias, **kwargs
- )
+ kwargs["fan_in_fan_out"] = self.peft_config.fan_in_fan_out = False
+ new_module = MergedLinear(in_features, out_features, bias=bias, **kwargs)
self._replace_module(parent, target_name, new_module, target)
if not is_target_modules_in_base_model:
raise ValueError(
@@ -339,17 +327,30 @@ def _get_submodules(self, key):
def _replace_module(self, parent_module, child_name, new_module, old_module):
setattr(parent_module, child_name, new_module)
- new_module.weight = old_module.weight
- if old_module.bias is not None:
- new_module.bias = old_module.bias
- if getattr(old_module, "state", None) is not None:
- new_module.state = old_module.state
- new_module.to(old_module.weight.device)
-
- # dispatch to correct device
- for name, module in new_module.named_modules():
- if "lora_" in name:
- module.to(old_module.weight.device)
+ if is_gptq_available() and isinstance(old_module, QuantLinear):
+ new_module.qweight = old_module.qweight
+ if old_module.bias is not None:
+ new_module.bias = old_module.bias
+ if getattr(old_module, "state", None) is not None:
+ new_module.state = old_module.state
+ new_module.to(old_module.qweight.device)
+
+ # dispatch to correct device
+ for name, module in new_module.named_modules():
+ if "lora_" in name:
+ module.to(old_module.qweight.device)
+ else:
+ new_module.weight = old_module.weight
+ if old_module.bias is not None:
+ new_module.bias = old_module.bias
+ if getattr(old_module, "state", None) is not None:
+ new_module.state = old_module.state
+ new_module.to(old_module.weight.device)
+
+ # dispatch to correct device
+ for name, module in new_module.named_modules():
+ if "lora_" in name:
+ module.to(old_module.weight.device)
def __getattr__(self, name: str):
"""Forward missing attributes to the wrapped module."""
@@ -363,10 +364,7 @@ def modules_to_save(self):
return None
def get_peft_config_as_dict(self, inference: bool = False):
- config = {
- k: v.value if isinstance(v, Enum) else v
- for k, v in asdict(self.peft_config).items()
- }
+ config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(self.peft_config).items()}
if inference:
config["inference_mode"] = True
return config
@@ -382,44 +380,6 @@ def enable_adapter_layers(self):
def disable_adapter_layers(self):
self._set_adapter_layers(enabled=False)
- def merge_and_unload(self):
- r"""
- This method merges the LoRa layers into the base model. This is needed if someone wants to use the base model
- as a standalone model.
- """
- if self.config.model_type == "gpt2":
- raise ValueError("GPT2 models are not supported for merging LORA layers")
-
- if getattr(self.model, "is_loaded_in_8bit", False):
- raise ValueError(
- "Cannot merge LORA layers when the model is loaded in 8-bit mode"
- )
-
- key_list = [key for key, _ in self.model.named_modules() if "lora" not in key]
- for key in key_list:
- parent, target, target_name = self._get_submodules(key)
- if isinstance(target, LoraLayer):
- bias = target.bias is not None
- new_module = torch.nn.Linear(
- target.in_features, target.out_features, bias=bias
- )
-
- # manually merge if not merged
- if not target.merged:
- # merge weights per: https://arxiv.org/pdf/2106.09685.pdf / page 4
- if target.r > 0:
- target.weight.data += (
- transpose(
- target.lora_B.weight @ target.lora_A.weight,
- target.fan_in_fan_out,
- )
- * target.scaling
- ).to(target.weight.dtype)
- target.merged = True
-
- self._replace_module(parent, target_name, new_module, target)
- return self.model
-
def print_trainable_parameters(self):
"""
Prints the number of trainable parameters in the model.
@@ -551,16 +511,8 @@ def __init__(
merge_weights: bool = True,
**kwargs,
):
- init_lora_weights = kwargs.pop("init_lora_weights", True)
-
nn.Linear.__init__(self, in_features, out_features, **kwargs)
- LoraLayer.__init__(
- self,
- r=r,
- lora_alpha=lora_alpha,
- lora_dropout=lora_dropout,
- merge_weights=merge_weights,
- )
+ LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
self.fan_in_fan_out = fan_in_fan_out
# Actual trainable parameters
@@ -570,8 +522,7 @@ def __init__(
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
- if init_lora_weights:
- self.reset_parameters()
+ self.reset_parameters()
if fan_in_fan_out:
self.weight.data = self.weight.data.T
@@ -590,20 +541,14 @@ def train(self, mode: bool = True):
# Merge the weights and mark it
if self.r > 0:
self.weight.data += (
- transpose(
- self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out
- )
- * self.scaling
+ transpose(self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out) * self.scaling
)
self.merged = True
elif self.merge_weights and self.merged:
# Make sure that the weights are not merged
if self.r > 0:
self.weight.data -= (
- transpose(
- self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out
- )
- * self.scaling
+ transpose(self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out) * self.scaling
)
self.merged = False
@@ -616,27 +561,19 @@ def forward(self, x: torch.Tensor):
if self.disable_adapters:
if self.r > 0 and self.merged:
self.weight.data -= (
- transpose(
- self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out
- )
- * self.scaling
+ transpose(self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out) * self.scaling
)
self.merged = False
- return F.linear(
- x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias
- )
+ return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
elif self.r > 0 and not self.merged:
- result = F.linear(
- x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias
- )
+ result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
if self.r > 0:
- result += self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
+ lora_output = self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
+ result = result + lora_output
return result
else:
- return F.linear(
- x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias
- )
+ return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
class MergedLinear(nn.Linear, LoraLayer):
@@ -653,16 +590,8 @@ def __init__(
merge_weights: bool = True,
**kwargs,
):
- init_lora_weights = kwargs.pop("init_lora_weights", True)
-
nn.Linear.__init__(self, in_features, out_features, **kwargs)
- LoraLayer.__init__(
- self,
- r=r,
- lora_alpha=lora_alpha,
- lora_dropout=lora_dropout,
- merge_weights=merge_weights,
- )
+ LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
if out_features % len(enable_lora) != 0:
raise ValueError("The length of enable_lora must divide out_features")
self.enable_lora = enable_lora
@@ -681,14 +610,10 @@ def __init__(
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
# Compute the indices
- self.lora_ind = self.weight.new_zeros(
- (out_features,), dtype=torch.bool
- ).view(len(enable_lora), -1)
+ self.lora_ind = self.weight.new_zeros((out_features,), dtype=torch.bool).view(len(enable_lora), -1)
self.lora_ind[enable_lora, :] = True
self.lora_ind = self.lora_ind.view(-1)
-
- if init_lora_weights:
- self.reset_parameters()
+ self.reset_parameters()
if fan_in_fan_out:
self.weight.data = self.weight.data.T
@@ -702,9 +627,7 @@ def reset_parameters(self):
def zero_pad(self, x):
result = x.new_zeros((*x.shape[:-1], self.out_features))
result = result.view(-1, self.out_features)
- result[:, self.lora_ind] = x.reshape(
- -1, self.out_features // len(self.enable_lora) * sum(self.enable_lora)
- )
+ result[:, self.lora_ind] = x.reshape(-1, self.out_features // len(self.enable_lora) * sum(self.enable_lora))
return result.view((*x.shape[:-1], self.out_features))
def train(self, mode: bool = True):
@@ -723,9 +646,7 @@ def train(self, mode: bool = True):
.squeeze(0)
.transpose(-2, -1)
)
- self.weight.data += transpose(
- self.zero_pad(delta_w * self.scaling), not self.fan_in_fan_out
- )
+ self.weight.data += transpose(self.zero_pad(delta_w * self.scaling), not self.fan_in_fan_out)
self.merged = True
elif self.merge_weights and self.merged:
# Make sure that the weights are not merged
@@ -739,9 +660,7 @@ def train(self, mode: bool = True):
.squeeze(0)
.transpose(-2, -1)
)
- self.weight.data -= transpose(
- self.zero_pad(delta_w * self.scaling), not self.fan_in_fan_out
- )
+ self.weight.data -= transpose(self.zero_pad(delta_w * self.scaling), not self.fan_in_fan_out)
self.merged = False
def eval(self):
@@ -761,21 +680,13 @@ def forward(self, x: torch.Tensor):
.squeeze(0)
.transpose(-2, -1)
)
- self.weight.data -= transpose(
- self.zero_pad(delta_w * self.scaling), not self.fan_in_fan_out
- )
+ self.weight.data -= transpose(self.zero_pad(delta_w * self.scaling), not self.fan_in_fan_out)
self.merged = False
- return F.linear(
- x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias
- )
+ return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
elif self.merged:
- return F.linear(
- x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias
- )
+ return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
else:
- result = F.linear(
- x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias
- )
+ result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
if self.r > 0:
after_A = self.lora_A(self.lora_dropout(x))
after_B = self.lora_B(after_A.transpose(-2, -1)).transpose(-2, -1)
@@ -802,19 +713,11 @@ def __init__(
out_features,
bias=kwargs.get("bias", True),
has_fp16_weights=kwargs.get("has_fp16_weights", True),
- memory_efficient_backward=kwargs.get(
- "memory_efficient_backward", False
- ),
+ memory_efficient_backward=kwargs.get("memory_efficient_backward", False),
threshold=kwargs.get("threshold", 0.0),
index=kwargs.get("index", None),
)
- LoraLayer.__init__(
- self,
- r=r,
- lora_alpha=lora_alpha,
- lora_dropout=lora_dropout,
- merge_weights=False,
- )
+ LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
# Actual trainable parameters
if r > 0:
self.lora_A = nn.Linear(in_features, r, bias=False)
@@ -841,17 +744,10 @@ def forward(self, x: torch.Tensor):
if x.dtype != torch.float32:
x = x.float()
- output = (
- self.lora_B(self.lora_A(self.lora_dropout(x))).to(
- expected_dtype
- )
- * self.scaling
- )
+ output = self.lora_B(self.lora_A(self.lora_dropout(x))).to(expected_dtype) * self.scaling
result += output
else:
- output = (
- self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
- )
+ output = self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
result += output
return result
@@ -873,19 +769,11 @@ def __init__(
out_features,
bias=kwargs.get("bias", True),
has_fp16_weights=kwargs.get("has_fp16_weights", True),
- memory_efficient_backward=kwargs.get(
- "memory_efficient_backward", False
- ),
+ memory_efficient_backward=kwargs.get("memory_efficient_backward", False),
threshold=kwargs.get("threshold", 0.0),
index=kwargs.get("index", None),
)
- LoraLayer.__init__(
- self,
- r=r,
- lora_alpha=lora_alpha,
- lora_dropout=lora_dropout,
- merge_weights=False,
- )
+ LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
if out_features % len(enable_lora) != 0:
raise ValueError("The length of enable_lora must divide out_features")
self.enable_lora = enable_lora
@@ -903,9 +791,7 @@ def __init__(
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
# Compute the indices
- self.lora_ind = self.weight.new_zeros(
- (out_features,), dtype=torch.bool
- ).view(len(enable_lora), -1)
+ self.lora_ind = self.weight.new_zeros((out_features,), dtype=torch.bool).view(len(enable_lora), -1)
self.lora_ind[enable_lora, :] = True
self.lora_ind = self.lora_ind.view(-1)
self.reset_parameters()
@@ -944,6 +830,133 @@ def forward(self, x: torch.Tensor):
result += output
return result
+if is_gptq_available():
+ class LinearqbitLt(QuantLinear, LoraLayer):
+ # Lora implemented in a dense layer
+ def __init__(
+ self,
+ in_features,
+ out_features,
+ r: int = 0,
+ lora_alpha: int = 1,
+ lora_dropout: float = 0.0,
+ **kwargs,
+ ):
+
+ QuantLinear.__init__(
+ self,
+ kwargs.get('bits', 4),
+ kwargs.get('groupsize', 128),
+ in_features,
+ out_features,
+ kwargs.get('bias', False),
+ )
+
+ LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
+ # Actual trainable parameters
+ if r > 0:
+ self.lora_A = nn.Linear(in_features, r, bias=False)
+ self.lora_B = nn.Linear(r, out_features, bias=False)
+ self.scaling = self.lora_alpha / self.r
+ # Freezing the pre-trained weight matrix
+ self.qweight.requires_grad = False
+ self.scales.requires_grad = False
+ self.qzeros.requires_grad = False
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ if hasattr(self, "lora_A"):
+ # initialize A the same way as the default for nn.Linear and B to zero
+ # nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
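+ # The deprecated non-in-place alias below appears intentional: LlamaLoraInt4Engine
+ # temporarily patches torch.nn.init.kaiming_uniform_ to a no-op while building the
+ # base model, and the alias (bound at import time) still performs the real init.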
+ self.lora_A.weight = torch.nn.Parameter(torch.nn.init.kaiming_uniform(self.lora_A.weight, a=math.sqrt(5)))
+ nn.init.zeros_(self.lora_B.weight)
+
+ def forward(self, x: torch.Tensor):
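+ # Run the quantized base layer, then add the LoRA update (computed in fp32 and
+ # cast back to the base output's dtype).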
+ # x = x.detach()
+ custom_layer_output = super().forward(x)
+
+ dtype = custom_layer_output.dtype
+ x = x.float()
+ lora_output = self.lora_B(self.lora_A(self.lora_dropout(x))).to(dtype) * self.scaling
+ result = custom_layer_output + lora_output
+ return result
+
+ class MergedLinearqbitLt(QuantLinear, LoraLayer):
+ # Lora implemented in a dense layer
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ r: int = 0,
+ lora_alpha: int = 1,
+ lora_dropout: float = 0.0,
+ enable_lora: List[bool] = [False],
+ **kwargs,
+ ):
+ QuantLinear.__init__(
+ self,
+ kwargs.get('bits', 4),
+ kwargs.get('groupsize', 128),
+ in_features,
+ out_features,
+ )
+ LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
+ if out_features % len(enable_lora) != 0:
+ raise ValueError("The length of enable_lora must divide out_features")
+ self.enable_lora = enable_lora
+ # Actual trainable parameters
+ if r > 0 and any(enable_lora):
+ self.lora_A = nn.Linear(in_features, r * sum(enable_lora), bias=False)
+ self.lora_B = nn.Conv1d(
+ r * sum(enable_lora),
+ out_features // len(enable_lora) * sum(enable_lora),
+ kernel_size=1,
+ groups=2,
+ bias=False,
+ )
+ self.scaling = self.lora_alpha / self.r
+ # Freezing the pre-trained weight matrix
+ self.qweight.requires_grad = False
+ # Compute the indices
+ self.lora_ind = self.qweight.new_zeros((out_features,), dtype=torch.bool).view(len(enable_lora), -1)
+ self.lora_ind[enable_lora, :] = True
+ self.lora_ind = self.lora_ind.view(-1)
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ if hasattr(self, "lora_A"):
+ # initialize A the same way as the default for nn.Linear and B to zero
+ nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
+ nn.init.zeros_(self.lora_B.weight)
+
+ def zero_pad(self, x):
+ result = x.new_zeros((*x.shape[:-1], self.out_features))
+ result = result.view(-1, self.out_features)
+ result[:, self.lora_ind] = x.reshape(
+ -1, self.out_features // len(self.enable_lora) * sum(self.enable_lora)
+ )
+ return result.view((*x.shape[:-1], self.out_features))
+
+ def forward(self, x: torch.Tensor):
+ result = super().forward(x)#.detach()
+ if self.disable_adapters:
+ return result
+ elif self.r > 0:
+ if not torch.is_autocast_enabled():
+ expected_dtype = result.dtype
+ if x.dtype != torch.float32:
+ x = x.float()
+ after_A = self.lora_A(self.lora_dropout(x))
+ after_B = self.lora_B(after_A.transpose(-2, -1)).transpose(-2, -1)
+ output = self.zero_pad(after_B).to(expected_dtype) * self.scaling
+ result += output
+ else:
+ after_A = self.lora_A(self.lora_dropout(x))
+ after_B = self.lora_B(after_A.transpose(-2, -1)).transpose(-2, -1)
+ output = self.zero_pad(after_B) * self.scaling
+ result += output
+ return result
+
def prepare_model_for_int8_training(
model,
@@ -955,7 +968,6 @@ def prepare_model_for_int8_training(
This method wrapps the entire protocol for preparing a model before running a training. This includes:
1- Cast the layernorm in fp32 2- making output embedding layer require grads 3- Add the upcasting of the lm
head to fp32
-
Args:
model, (`transformers.PreTrainedModel`):
The loaded model from `transformers`
@@ -995,7 +1007,6 @@ class CastOutputToFloat(torch.nn.Sequential):
r"""
Manually cast to the expected dtype of the lm_head as sometimes there is a final layer norm that is casted
in fp32
-
"""
def forward(self, x):
diff --git a/src/xturing/engines/quant_utils/__init__.py b/src/xturing/engines/quant_utils/__init__.py
new file mode 100644
index 0000000..11d0eb5
--- /dev/null
+++ b/src/xturing/engines/quant_utils/__init__.py
@@ -0,0 +1 @@
+from .quant import make_quant, autotune_warmup, QuantLinear
\ No newline at end of file
diff --git a/src/xturing/engines/quant_utils/custom_autotune.py b/src/xturing/engines/quant_utils/custom_autotune.py
new file mode 100644
index 0000000..dfc8c6f
--- /dev/null
+++ b/src/xturing/engines/quant_utils/custom_autotune.py
@@ -0,0 +1,168 @@
+#https://github.com/fpgaminer/GPTQ-triton
+"""
+Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100.
+"""
+
+import builtins
+import math
+import time
+from typing import Dict
+
+import triton
+
+
+class Autotuner(triton.KernelInterface):
+ def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None, nearest_power_of_two: bool = False):
+ '''
+ :param prune_configs_by: a dict of functions that are used to prune configs, fields:
+ 'perf_model': performance model used to predicate running time with different configs, returns running time
+ 'top_k': number of configs to bench
+ 'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs.
+ 'nearest_power_of_two'(optional): whether to round key arguments to the nearest power of two when caching tuning results
+ '''
+ if not configs:
+ self.configs = [triton.Config({}, num_warps=4, num_stages=2)]
+ else:
+ self.configs = configs
+ self.key_idx = [arg_names.index(k) for k in key]
+ self.nearest_power_of_two = nearest_power_of_two
+ self.cache = {}
+ # hook to reset all required tensor to zeros before relaunching a kernel
+ self.hook = lambda args: 0
+ if reset_to_zero is not None:
+ self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
+
+ def _hook(args):
+ for i in self.reset_idx:
+ args[i].zero_()
+ self.hook = _hook
+ self.arg_names = arg_names
+ # prune configs
+ if prune_configs_by:
+ perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
+ if 'early_config_prune' in prune_configs_by:
+ early_config_prune = prune_configs_by['early_config_prune']
+ else:
+ perf_model, top_k, early_config_prune = None, None, None
+ self.perf_model, self.configs_top_k = perf_model, top_k
+ self.early_config_prune = early_config_prune
+ self.fn = fn
+
+ def _bench(self, *args, config, **meta):
+ # check for conflicts, i.e. meta-parameters both provided
+ # as kwargs and by the autotuner
+ conflicts = meta.keys() & config.kwargs.keys()
+ if conflicts:
+ raise ValueError(
+ f"Conflicting meta-parameters: {', '.join(conflicts)}."
+ " Make sure that you don't re-define auto-tuned symbols."
+ )
+ # augment meta-parameters with tunable ones
+ current = dict(meta, **config.kwargs)
+
+ def kernel_call():
+ if config.pre_hook:
+ config.pre_hook(self.nargs)
+ self.hook(args)
+ self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
+ try:
+ # In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses
+ # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default
+ return triton.testing.do_bench(kernel_call, rep=40)
+ except triton.compiler.OutOfResources:
+ return float('inf')
+
+ def run(self, *args, **kwargs):
+ self.nargs = dict(zip(self.arg_names, args))
+ if len(self.configs) > 1:
+ key = tuple(args[i] for i in self.key_idx)
+
+ # This reduces the amount of autotuning by rounding the keys to the nearest power of two
+ # In my testing this gives decent results, and greatly reduces the amount of tuning required
+ if self.nearest_power_of_two:
+ key = tuple([2 ** int(math.log2(x) + 0.5) for x in key])
+
+ if key not in self.cache:
+ # prune configs
+ pruned_configs = self.prune_configs(kwargs)
+ bench_start = time.time()
+ timings = {config: self._bench(*args, config=config, **kwargs)
+ for config in pruned_configs}
+ timings = {k:v for k,v in timings.items() if v != float('inf')}
+ bench_end = time.time()
+ self.bench_time = bench_end - bench_start
+ self.cache[key] = builtins.min(timings, key=timings.get)
+ self.hook(args)
+ self.configs_timings = timings
+ config = self.cache[key]
+ else:
+ config = self.configs[0]
+ self.best_config = config
+ if config.pre_hook is not None:
+ config.pre_hook(self.nargs)
+ return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
+
+ def prune_configs(self, kwargs):
+ pruned_configs = self.configs
+ if self.early_config_prune:
+ pruned_configs = self.early_config_prune(self.configs, self.nargs)
+ if self.perf_model:
+ top_k = self.configs_top_k
+ if isinstance(top_k, float) and top_k <= 1.0:
+ top_k = int(len(self.configs) * top_k)
+ if len(pruned_configs) > top_k:
+ est_timing = {
+ config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages,
+ num_warps=config.num_warps)
+ for config in pruned_configs
+ }
+ pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
+ return pruned_configs
+
+ def warmup(self, *args, **kwargs):
+ self.nargs = dict(zip(self.arg_names, args))
+ for config in self.prune_configs(kwargs):
+ self.fn.warmup(
+ *args,
+ num_warps=config.num_warps,
+ num_stages=config.num_stages,
+ **kwargs,
+ **config.kwargs,
+ )
+ self.nargs = None
+
+
+def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False):
+ """
+ Decorator for auto-tuning a :code:`triton.jit`'d function.
+ .. highlight:: python
+ .. code-block:: python
+ @triton.autotune(configs=[
+ triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
+ triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
+ ],
+ key=['x_size'] # the two above configs will be evaluated anytime
+ # the value of x_size changes
+ )
+ @triton.jit
+ def kernel(x_ptr, x_size, **META):
+ BLOCK_SIZE = META['BLOCK_SIZE']
+ :note: When all the configurations are evaluated, the kernel will run multiple time.
+ This means that whatever value the kernel updates will be updated multiple times.
+ To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
+ reset the value of the provided tensor to `zero` before running any configuration.
+ :param configs: a list of :code:`triton.Config` objects
+ :type configs: list[triton.Config]
+ :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
+ :type key: list[str]
+ :param prune_configs_by: a dict of functions that are used to prune configs, fields:
+ 'perf_model': performance model used to predicate running time with different configs, returns running time
+ 'top_k': number of configs to bench
+ 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs.
+ :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
+ :type reset_to_zero: list[str]
+ """
+ def decorator(fn):
+ return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, nearest_power_of_two)
+
+ return decorator
diff --git a/src/xturing/engines/quant_utils/quant.py b/src/xturing/engines/quant_utils/quant.py
new file mode 100644
index 0000000..448e344
--- /dev/null
+++ b/src/xturing/engines/quant_utils/quant.py
@@ -0,0 +1,496 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.cuda.amp import custom_bwd, custom_fwd
+import math
+
+def quantize(x, scale, zero, maxq):
+ if maxq < 0:
+ return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
+ q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
+ return scale * (q - zero)
+
+class Quantizer(nn.Module):
+
+ def __init__(self, shape=1):
+ super(Quantizer, self).__init__()
+ self.register_buffer('maxq', torch.tensor(0))
+ self.register_buffer('scale', torch.zeros(shape))
+ self.register_buffer('zero', torch.zeros(shape))
+
+ def configure(
+ self,
+ bits, perchannel=False, sym=True,
+ mse=False, norm=2.4, grid=100, maxshrink=.8,
+ trits=False
+ ):
+
+ self.maxq = torch.tensor(2 ** bits - 1)
+ self.perchannel = perchannel
+ self.sym = sym
+ self.mse = mse
+ self.norm = norm
+ self.grid = grid
+ self.maxshrink = maxshrink
+ if trits:
+ self.maxq = torch.tensor(-1)
+
+ def find_params(self, x, weight=False):
+ dev = x.device
+ self.maxq = self.maxq.to(dev)
+
+ shape = x.shape
+ if self.perchannel:
+ if weight:
+ x = x.flatten(1)
+ else:
+ if len(shape) == 4:
+ x = x.permute([1, 0, 2, 3])
+ x = x.flatten(1)
+ if len(shape) == 3:
+ x = x.reshape((-1, shape[-1])).t()
+ if len(shape) == 2:
+ x = x.t()
+ else:
+ x = x.flatten().unsqueeze(0)
+
+ tmp = torch.zeros(x.shape[0], device=dev)
+ xmin = torch.minimum(x.min(1)[0], tmp)
+ xmax = torch.maximum(x.max(1)[0], tmp)
+
+ if self.sym:
+ xmax = torch.maximum(torch.abs(xmin), xmax)
+ tmp = xmin < 0
+ if torch.any(tmp):
+ xmin[tmp] = -xmax[tmp]
+ tmp = (xmin == 0) & (xmax == 0)
+ xmin[tmp] = -1
+ xmax[tmp] = +1
+
+ if self.maxq < 0:
+ self.scale = xmax
+ self.zero = xmin
+ else:
+ self.scale = (xmax - xmin) / self.maxq
+ if self.sym:
+ self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
+ else:
+ self.zero = torch.round(-xmin / self.scale)
+
+ if self.mse:
+ best = torch.full([x.shape[0]], float('inf'), device=dev)
+ for i in range(int(self.maxshrink * self.grid)):
+ p = 1 - i / self.grid
+ xmin1 = p * xmin
+ xmax1 = p * xmax
+ scale1 = (xmax1 - xmin1) / self.maxq
+ zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
+ q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
+ q -= x
+ q.abs_()
+ q.pow_(self.norm)
+ err = torch.sum(q, 1)
+ tmp = err < best
+ if torch.any(tmp):
+ best[tmp] = err[tmp]
+ self.scale[tmp] = scale1[tmp]
+ self.zero[tmp] = zero1[tmp]
+ if not self.perchannel:
+ if weight:
+ tmp = shape[0]
+ else:
+ tmp = shape[1] if len(shape) != 3 else shape[2]
+ self.scale = self.scale.repeat(tmp)
+ self.zero = self.zero.repeat(tmp)
+
+ if weight:
+ shape = [-1] + [1] * (len(shape) - 1)
+ self.scale = self.scale.reshape(shape)
+ self.zero = self.zero.reshape(shape)
+ return
+ if len(shape) == 4:
+ self.scale = self.scale.reshape((1, -1, 1, 1))
+ self.zero = self.zero.reshape((1, -1, 1, 1))
+ if len(shape) == 3:
+ self.scale = self.scale.reshape((1, 1, -1))
+ self.zero = self.zero.reshape((1, 1, -1))
+ if len(shape) == 2:
+ self.scale = self.scale.unsqueeze(0)
+ self.zero = self.zero.unsqueeze(0)
+
+ def quantize(self, x):
+ if self.ready():
+ return quantize(x, self.scale, self.zero, self.maxq)
+ return x
+
+ def enabled(self):
+ return self.maxq > 0
+
+ def ready(self):
+ return torch.all(self.scale != 0)
+
+try:
+ import triton
+ import triton.language as tl
+ from . import custom_autotune
+
+ # code based https://github.com/fpgaminer/GPTQ-triton
+ @custom_autotune.autotune(
+ configs=[
+ triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
+ # These provided a benefit on a 3090
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
+ ],
+ key=['M', 'N'],
+ nearest_power_of_two=True,
+ )
+
+ @triton.jit
+ def matmul_248_kernel(a_ptr, b_ptr, c_ptr,
+ scales_ptr, zeros_ptr, g_ptr,
+ M, N, K, bits, maxq,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ stride_scales, stride_zeros,
+ BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
+ GROUP_SIZE_M: tl.constexpr):
+ """
+ Compute the matrix multiplication C = A x B.
+ A is of shape (M, K) float16
+ B is of shape (K//8, N) int32
+ C is of shape (M, N) float16
+ scales is of shape (G, N) float16
+ zeros is of shape (G, N) float16
+ g_ptr is of shape (K) int32
+ """
+ infearure_per_bits = 32 // bits
+
+ pid = tl.program_id(axis=0)
+ num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+ num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
+ group_id = pid // num_pid_in_group
+ first_pid_m = group_id * GROUP_SIZE_M
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+ pid_m = first_pid_m + (pid % group_size_m)
+ pid_n = (pid % num_pid_in_group) // group_size_m
+
+ offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+ offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
+ a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
+ a_mask = (offs_am[:, None] < M)
+ # b_ptrs is set up such that it repeats elements along the K axis 8 times
+ b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
+ g_ptrs = g_ptr + offs_k
+ # shifter is used to extract the N bits of each element in the 32-bit word from B
+ scales_ptrs = scales_ptr + offs_bn[None, :]
+ zeros_ptrs = zeros_ptr + (offs_bn[None, :]// infearure_per_bits)
+
+ shifter = (offs_k % infearure_per_bits) * bits
+ zeros_shifter = (offs_bn % infearure_per_bits) * bits
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+ for k in range(0, num_pid_k):
+ g_idx = tl.load(g_ptrs)
+
+ # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
+ scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
+ zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
+
+ zeros = (zeros >> zeros_shifter[None, :]) & maxq
+ zeros = (zeros + 1)
+
+ a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
+ b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
+
+ # Now we need to unpack b (which is N-bit values) into 32-bit values
+ b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
+ b = (b - zeros) * scales # Scale and shift
+
+ accumulator += tl.dot(a, b)
+ a_ptrs += BLOCK_SIZE_K
+ b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
+ g_ptrs += BLOCK_SIZE_K
+
+ c = accumulator.to(tl.float16)
+ c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
+ c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
+ tl.store(c_ptrs, accumulator, mask=c_mask)
+
+ # code based https://github.com/fpgaminer/GPTQ-triton
+ @custom_autotune.autotune(
+ configs=[
+ triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
+ # These provided a benefit on a 3090
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_N': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
+ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
+ ],
+ key=['M', 'K'],
+ nearest_power_of_two=True,
+ )
+
+ @triton.jit
+ def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr,
+ scales_ptr, zeros_ptr, g_ptr,
+ M, N, K, bits, maxq,
+ stride_am, stride_ak,
+ stride_bk, stride_bn,
+ stride_cm, stride_cn,
+ stride_scales, stride_zeros,
+ BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
+ GROUP_SIZE_M: tl.constexpr):
+ """
+ Compute the matrix multiplication C = A x B.
+ A is of shape (M, N) float16
+ B is of shape (K//8, N) int32
+ C is of shape (M, K) float16
+ scales is of shape (G, N) float16
+ zeros is of shape (G, N) float16
+ g_ptr is of shape (K) int32
+ """
+ infearure_per_bits = 32 // bits
+
+ pid = tl.program_id(axis=0)
+ num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+ num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+ num_pid_in_group = GROUP_SIZE_M * num_pid_k
+ group_id = pid // num_pid_in_group
+ first_pid_m = group_id * GROUP_SIZE_M
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+ pid_m = first_pid_m + (pid % group_size_m)
+ pid_k = (pid % num_pid_in_group) // group_size_m
+
+ offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+ offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
+ offs_n = tl.arange(0, BLOCK_SIZE_N)
+ a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
+ a_mask = (offs_am[:, None] < M)
+ # b_ptrs is set up such that it repeats elements along the K axis 8 times
+ b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
+ g_ptrs = g_ptr + offs_bk
+ g_idx = tl.load(g_ptrs)
+
+        # shifter extracts the packed `bits`-bit weight values from the 32-bit words of B;
+        # zeros_shifter does the same for the packed zero points
+        scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales
+        zeros_ptrs = zeros_ptr + (offs_n[None, :] // infeature_per_bits) + g_idx[:, None] * stride_zeros
+
+        shifter = (offs_bk % infeature_per_bits) * bits
+        zeros_shifter = (offs_n % infeature_per_bits) * bits
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
+
+ for k in range(0, num_pid_n):
+ # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
+ scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
+ zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
+
+ zeros = (zeros >> zeros_shifter[None, :]) & maxq
+ zeros = (zeros + 1)
+
+ a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
+ b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
+
+            # Unpack b: each int32 word holds 32//bits quantized values
+            b = (b >> shifter[:, None]) & maxq  # Extract the `bits`-bit values
+ b = (b - zeros) * scales # Scale and shift
+ b = tl.trans(b)
+
+ accumulator += tl.dot(a, b)
+ a_ptrs += BLOCK_SIZE_N
+ b_ptrs += BLOCK_SIZE_N
+ scales_ptrs += BLOCK_SIZE_N
+            zeros_ptrs += (BLOCK_SIZE_N // infeature_per_bits)
+
+ c = accumulator.to(tl.float16)
+ c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]
+ c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
+        tl.store(c_ptrs, c, mask=c_mask)
+except:
+    print('Triton is not installed.')
+
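+# Host-side launchers: allocate the fp16 output on the GPU, derive the launch
+# grid from the autotuned block sizes, and dispatch the Triton kernels above.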
+def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
+ output = torch.empty((input.shape[0], qweight.shape[1]), device='cuda', dtype=torch.float16)
+ grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),)
+ matmul_248_kernel[grid](input, qweight, output,
+ scales, qzeros, g_idx,
+ input.shape[0], qweight.shape[1], input.shape[1], bits, maxq,
+ input.stride(0), input.stride(1),
+ qweight.stride(0), qweight.stride(1),
+ output.stride(0), output.stride(1),
+ scales.stride(0), qzeros.stride(0))
+ return output
+
+def transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
+ output_dim = (qweight.shape[0] * 32) // bits
+ output = torch.empty((input.shape[0], output_dim), device='cuda', dtype=torch.float16)
+ grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_dim, META['BLOCK_SIZE_K']),)
+ trans_matmul_248_kernel[grid](input, qweight, output,
+ scales, qzeros, g_idx,
+ input.shape[0], qweight.shape[1], output_dim, bits, maxq,
+ input.stride(0), input.stride(1),
+ qweight.stride(0), qweight.stride(1),
+ output.stride(0), output.stride(1),
+ scales.stride(0), qzeros.stride(0))
+ return output
+
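+# Autograd wrapper around the fused kernels: forward() runs the fused
+# dequantize-and-matmul kernel, while backward() produces grad_input by
+# multiplying grad_output with the transposed packed weight; the packed weight,
+# scales and zero points receive no gradient.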
+class QuantLinearFunction(torch.autograd.Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float16)
+ def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
+ output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq)
+ ctx.save_for_backward(qweight, scales, qzeros, g_idx)
+        ctx.bits, ctx.maxq = bits, maxq
+ return output
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, grad_output):
+ qweight, scales, qzeros, g_idx = ctx.saved_tensors
+ bits, maxq = ctx.bits, ctx.maxq
+ grad_input = None
+
+ if ctx.needs_input_grad[0]:
+ grad_input = transpose_matmul248(grad_output, qweight, scales, qzeros, g_idx, bits, maxq)
+ return grad_input, None, None, None, None, None, None
+
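+# Drop-in replacement for nn.Linear that keeps the weight packed in int32 words
+# together with per-group scales and zero points. pack() quantizes an existing
+# fp16 nn.Linear; forward() dequantizes on the fly inside the Triton kernel.
+# Minimal usage sketch (hypothetical sizes; scales/zeros/g_idx are assumed to
+# come from a GPTQ-style quantizer):
+#   qlinear = QuantLinear(bits=4, groupsize=128, infeatures=4096, outfeatures=4096, bias=True)
+#   qlinear.pack(linear, scales, zeros, g_idx)   # CPU-side packing
+#   y = qlinear.cuda()(x)                        # x: (batch, 4096) float16 CUDA tensor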
+class QuantLinear(nn.Module):
+ def __init__(self, bits, groupsize, infeatures, outfeatures, bias):
+ super().__init__()
+ if bits not in [2,4,8]:
+ raise NotImplementedError("Only 2,4,8 bits are supported.")
+ self.infeatures = infeatures
+ self.outfeatures = outfeatures
+ self.bits = bits
+ self.maxq = 2 ** self.bits - 1
+ self.groupsize = groupsize if groupsize != -1 else infeatures
+
+ self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
+ self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32))
+ self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
+ self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype = torch.int32))
+ if bias:
+ self.register_buffer('bias', torch.zeros((outfeatures),dtype=torch.float16))
+ else:
+ self.bias = None
+
+ def pack(self, linear, scales, zeros, g_idx = None):
+ self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
+
+ scales = scales.t().contiguous()
+ zeros = zeros.t().contiguous()
+ scale_zeros = zeros * scales
+ self.scales = scales.clone().half()
+ if linear.bias is not None:
+ self.bias = linear.bias.clone().half()
+
+ intweight = []
+ for idx in range(self.infeatures):
+ intweight.append(torch.round((linear.weight.data[:,idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:,None])
+ intweight = torch.cat(intweight,dim=1)
+ intweight = intweight.t().contiguous()
+ intweight = intweight.numpy().astype(np.uint32)
+ qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32)
+ i = 0
+ row = 0
+ while row < qweight.shape[0]:
+ if self.bits in [2,4,8]:
+ for j in range(i, i + (32//self.bits)):
+ qweight[row] |= intweight[j] << (self.bits * (j - i))
+ i += 32//self.bits
+ row += 1
+ else:
+ raise NotImplementedError("Only 2,4,8 bits are supported.")
+
+ qweight = qweight.astype(np.int32)
+ self.qweight = torch.from_numpy(qweight)
+
+        zeros -= 1
+ zeros = zeros.numpy().astype(np.uint32)
+ qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
+ i = 0
+ col = 0
+ while col < qzeros.shape[1]:
+ if self.bits in [2,4,8]:
+ for j in range(i, i + (32//self.bits)):
+ qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
+ i += 32//self.bits
+ col += 1
+ else:
+ raise NotImplementedError("Only 2,4,8 bits are supported.")
+
+ qzeros = qzeros.astype(np.int32)
+ self.qzeros = torch.from_numpy(qzeros)
+
+ def forward(self, x):
+ out_shape = x.shape[:-1] + (self.outfeatures, )
+ out = QuantLinearFunction.apply(x.reshape(-1,x.shape[-1]), self.qweight, self.scales,
+ self.qzeros, self.g_idx, self.bits, self.maxq)
+ out = out + self.bias if self.bias is not None else out
+ return out.reshape(out_shape)
+
+def autotune_warmup(model, transpose=False):
+    """
+    Pre-tunes the quantized kernels: runs matmul248 (and transpose_matmul248 when
+    transpose=True) for power-of-two batch sizes from 1 to 2048, once per distinct
+    output-feature size found among the model's QuantLinear layers, so Triton's
+    autotune cache is warm before the first real forward/backward pass.
+    """
+ from tqdm import tqdm
+
+ n_values = {}
+
+ for _, m in model.named_modules():
+ if not isinstance(m, QuantLinear):
+ continue
+
+ k = m.infeatures
+ n = m.outfeatures
+
+ if n not in n_values:
+ n_values[n] = (k, m.qweight.cuda(), m.scales.cuda(), m.qzeros.cuda(), m.g_idx.cuda(), m.bits, m.maxq)
+
+ print(f'Found {len(n_values)} unique N values.')
+
+ print('Warming up autotune cache ...')
+ for m in tqdm(range(0, 12)):
+ m = 2 ** m # [1, 2048]
+ for n, (k, qweight, scales, qzeros, g_idx, bits, maxq) in n_values.items():
+ a = torch.randn(m, k, dtype=torch.float16, device='cuda')
+ matmul248(a, qweight, scales, qzeros, g_idx, bits, maxq)
+ if transpose:
+ a = torch.randn(m, n, dtype=torch.float16, device='cuda')
+ transpose_matmul248(a, qweight, scales, qzeros, g_idx, bits, maxq)
+ del n_values
+
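+# Recursively replaces every submodule of `module` whose qualified name appears
+# in `names` with a QuantLinear of the same shape, so a float checkpoint's
+# architecture can host packed INT2/4/8 weights.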
+def make_quant(module, names, bits, groupsize, name=''):
+ if isinstance(module, QuantLinear):
+ return
+ for attr in dir(module):
+ tmp = getattr(module, attr)
+ name1 = name + '.' + attr if name != '' else attr
+ if name1 in names:
+ delattr(module, attr)
+ setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None))
+ for name1, child in module.named_children():
+ make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
diff --git a/src/xturing/models/__init__.py b/src/xturing/models/__init__.py
index 7c7f3e1..512cd7b 100644
--- a/src/xturing/models/__init__.py
+++ b/src/xturing/models/__init__.py
@@ -5,7 +5,7 @@
from .galactica import Galactica, GalacticaInt8, GalacticaLora, GalacticaLoraInt8
from .gpt2 import GPT2, GPT2Int8, GPT2Lora, GPT2LoraInt8
from .gptj import GPTJ, GPTJInt8, GPTJLora, GPTJLoraInt8
-from .llama import Llama, LlamaInt8, LlamaLora, LlamaLoraInt8
+from .llama import Llama, LlamaInt8, LlamaLora, LlamaLoraInt8, LlamaLoraInt4
from .opt import OPT, OPTInt8, OPTLora, OPTLoraInt8
from .stable_diffusion import StableDiffusion
@@ -23,6 +23,7 @@
BaseModel.add_to_registry(LlamaLora.config_name, LlamaLora)
BaseModel.add_to_registry(LlamaInt8.config_name, LlamaInt8)
BaseModel.add_to_registry(LlamaLoraInt8.config_name, LlamaLoraInt8)
+BaseModel.add_to_registry(LlamaLoraInt4.config_name, LlamaLoraInt4)
BaseModel.add_to_registry(Galactica.config_name, Galactica)
BaseModel.add_to_registry(GalacticaLora.config_name, GalacticaLora)
BaseModel.add_to_registry(GalacticaInt8.config_name, GalacticaInt8)
diff --git a/src/xturing/models/llama.py b/src/xturing/models/llama.py
index 20f98cb..36e6a62 100644
--- a/src/xturing/models/llama.py
+++ b/src/xturing/models/llama.py
@@ -5,6 +5,7 @@
LLamaInt8Engine,
LlamaLoraEngine,
LlamaLoraInt8Engine,
+ LlamaLoraInt4Engine,
)
from xturing.models.causal import (
CausalInt8Model,
@@ -12,6 +13,10 @@
CausalLoraModel,
CausalModel,
)
+from xturing.trainers.base import BaseTrainer
+from xturing.datasets.instruction_dataset import InstructionDataset
+from xturing.datasets.text_dataset import TextDataset
+from xturing.trainers.lightning_trainer import LightningTrainer
class Llama(CausalModel):
@@ -40,3 +45,25 @@ class LlamaLoraInt8(CausalLoraInt8Model):
def __init__(self, weights_path: Optional[str] = None):
super().__init__(LlamaLoraInt8Engine.config_name, weights_path)
+
+
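+# INT4 LoRA fine-tuning entry point, registered under "llama_lora_int4". It
+# reuses the INT8 LoRA model plumbing but builds a LightningTrainer with
+# lora_type=32 (i.e. 32-bit trainer precision) on top of LlamaLoraInt4Engine.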
+class LlamaLoraInt4(CausalLoraInt8Model):
+ config_name: str = "llama_lora_int4"
+
+ def _make_trainer(self, dataset: Union[TextDataset, InstructionDataset]):
+ return BaseTrainer.create(
+ LightningTrainer.config_name,
+ self.engine,
+ dataset,
+ self._make_collate_fn(dataset),
+ int(self.finetuning_args.num_train_epochs),
+ int(self.finetuning_args.batch_size),
+ float(self.finetuning_args.learning_rate),
+ self.finetuning_args.optimizer_name,
+ True,
+ True,
+            lora_type=32,  # 32-bit trainer precision (see LightningTrainer's lora_type)
+ )
+
+ def __init__(self, weights_path: Optional[str] = None):
+ super().__init__(LlamaLoraInt4Engine.config_name, weights_path)
diff --git a/src/xturing/trainers/lightning_trainer.py b/src/xturing/trainers/lightning_trainer.py
index 4035580..7950753 100644
--- a/src/xturing/trainers/lightning_trainer.py
+++ b/src/xturing/trainers/lightning_trainer.py
@@ -100,6 +100,7 @@ def __init__(
use_lora: bool = False,
use_deepspeed: bool = False,
max_training_time_in_secs: Optional[int] = None,
+        lora_type: int = 16,  # forwarded to Lightning's precision argument (e.g. 16 or 32)
):
self.lightning_model = TuringLightningModule(
model_engine=model_engine,
@@ -129,6 +130,8 @@ def __init__(
duration=datetime.timedelta(seconds=max_training_time_in_secs)
)
)
+ model_engine.model.train()
+ model_engine.model.print_trainable_parameters()
if DEFAULT_DEVICE.type == "cpu":
self.trainer = Trainer(
@@ -167,7 +170,7 @@ def __init__(
num_nodes=1,
accelerator="gpu",
strategy=strategy,
- precision=16,
+ precision=lora_type,
max_epochs=max_epochs,
callbacks=training_callbacks,
enable_checkpointing=True,
diff --git a/src/xturing/utils/hub.py b/src/xturing/utils/hub.py
index 600987b..116a76c 100644
--- a/src/xturing/utils/hub.py
+++ b/src/xturing/utils/hub.py
@@ -88,6 +88,7 @@ class ModelHub(Hub):
"distilgpt2_lora_finetuned_alpaca"
),
"llama_lora_finetuned_alpaca": make_model_url("llama_lora_finetuned_alpaca"),
+ "llama_lora_int4": make_model_url("llama_lora_int4"),
}
def __init__(self):
diff --git a/src/xturing/utils/text_splitter.py b/src/xturing/utils/text_splitter.py
index 653aaf3..9a8bba8 100644
--- a/src/xturing/utils/text_splitter.py
+++ b/src/xturing/utils/text_splitter.py
@@ -12,7 +12,6 @@
Collection,
Iterable,
List,
- Literal,
Optional,
Union,
)
@@ -118,8 +117,8 @@ def _huggingface_tokenizer_length(text: str) -> int:
def from_tiktoken_encoder(
cls,
encoding_name: str = "gpt2",
- allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
- disallowed_special: Union[Literal["all"], Collection[str]] = "all",
+        allowed_special=set(),
+        disallowed_special=set(),
**kwargs: Any,
) -> TextSplitter:
"""Text splitter that uses tiktoken encoder to count length."""