diff --git a/README.md b/README.md index b6a4adb..1e604c8 100644 --- a/README.md +++ b/README.md @@ -33,9 +33,9 @@ With `xturing` you can,
-## 🌟 INT4 fine-tuning with LLaMA LoRA +## 🌟 INT4 fine-tuning and generation with LLaMA LoRA -We are excited to announce the latest enhancement to our `xTuring` library: INT4 fine-tuning demo. With this update, you can fine-tune LLMs like LLaMA with LoRA architecture in INT4 precision with less than `6 GB` of VRAM. This breakthrough significantly reduces memory requirements and accelerates the fine-tuning process, allowing you to achieve state-of-the-art performance with less computational resources. +We are excited to announce the latest enhancement to our `xTuring` library: INT4 fine-tuning and generation integration. With this update, you can fine-tune LLMs like LLaMA with LoRA architecture in INT4 precision with less than `6 GB` of VRAM. This breakthrough significantly reduces memory requirements and accelerates the fine-tuning process, allowing you to achieve state-of-the-art performance with less computational resources. More information about INT4 fine-tuning and benchmarks can be found in the [INT4 README](examples/int4_finetuning/README.md). @@ -146,6 +146,7 @@ model = BaseModel.load("x/distilgpt2_lora_finetuned_alpaca") - [x] OpenAI, Cohere and AI21 Studio model APIs for dataset generation - [x] Added fine-tuned checkpoints for some models to the hub - [x] INT4 LLaMA LoRA fine-tuning demo +- [x] INT4 LLaMA LoRA fine-tuning with INT4 generation - [ ] Evaluation of LLM models - [ ] Support for Stable Diffusion diff --git a/examples/int4_finetuning/LLaMA_lora_int4.ipynb b/examples/int4_finetuning/LLaMA_lora_int4.ipynb index d503a39..8859907 100644 --- a/examples/int4_finetuning/LLaMA_lora_int4.ipynb +++ b/examples/int4_finetuning/LLaMA_lora_int4.ipynb @@ -31,119 +31,80 @@ }, "outputs": [], "source": [ - "!pip install xturing --upgrade" + "!pip install xturing --upgrade\n", + "!pip install xturing[int4] --upgrade" ] }, { - "cell_type": "code", - "execution_count": 7, + "cell_type": "markdown", "metadata": { "collapsed": false }, - "outputs": [], "source": [ - "import subprocess\n", - "from pathlib import Path\n", - "\n", - "def pull_docker_image(image):\n", - " cmd = [\"docker\", \"pull\", image]\n", - " subprocess.run(cmd, check=True)\n", - "\n", - "\n", - "def run_docker_container(image, port_mapping, env_vars=None, gpus=None, volumes=None):\n", - " cmd = [\"docker\", \"container\", \"run\"]\n", - "\n", - " if env_vars is None:\n", - " env_vars = {}\n", - "\n", - " if volumes is None:\n", - " volumes = {}\n", - "\n", - " if gpus is not None:\n", - " cmd.extend([\"--gpus\", gpus])\n", - "\n", - " for key, value in env_vars.items():\n", - " cmd.extend([\"-e\", f\"{key}={value}\"])\n", - "\n", - " for local_path, container_path in volumes.items():\n", - " cmd.extend([\"-v\", f\"{str(Path(local_path).resolve())}:{container_path}\"])\n", - "\n", - " cmd.extend([\"-p\", port_mapping, image])\n", - "\n", - " subprocess.run(cmd)" + "## 2. Load model and dataset" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": { - "collapsed": false + "collapsed": false, + "pycharm": { + "is_executing": true + } }, + "outputs": [], "source": [ - "## 2. Load and run docker image" + "from xturing.datasets.instruction_dataset import InstructionDataset\n", + "from xturing.models import BaseModel\n", + "\n", + "instruction_dataset = InstructionDataset(\"../llama/alpaca_data\")\n", + "# Initializes the model\n", + "model = BaseModel.create(\"llama_lora_int4\")" ] }, { "cell_type": "markdown", + "source": [ + "## 3. 
Start the finetuning" + ], "metadata": { "collapsed": false - }, - "source": [ - "1. Install Docker on your machine if you haven't already. You can follow the [official Docker documentation](https://docs.docker.com/engine/install/) for installation instructions.\n", - "2. Install NVIDIA Container Toolkit\n", - " ```bash\n", - " sudo apt-get install -y nvidia-docker2\n", - " ```\n", - "3. Run the Docker daemon\n", - " ```bash\n", - " sudo systemctl start docker\n", - " ```\n" - ] + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "collapsed": false, - "pycharm": { - "is_executing": true - } - }, "outputs": [], "source": [ - "image = \"public.ecr.aws/t8g5g2q5/xturing:int4_finetuning\"\n", - "port_mapping = \"5000:5000\"\n", - "env_vars = {\n", - " \"WANDB_MODE\": \"dryrun\",\n", - " \"MICRO_BATCH_SIZE\": \"1\", # change this to increase your micro batch size\n", - "}\n", - "# if you want to log results to wandb, set the following env var\n", - "# env_vars = {\n", - "# \"WANDB_API_KEY\": \"\",\n", - "# \"WANDB_PROJECT\": \"your_project_name\",\n", - "# \"WANDB_ENTITY\": \"your_entity_name\",\n", - "# # Add more environment variables as needed\n", - "# }\n", - "volumes = {\n", - " # \"\": \"/model\",\n", - " \"../llama/alpaca_data\": \"/data\", # change this to your data path if you want\n", - "}\n", - "gpus = \"all\"\n", - "\n", - "pull_docker_image(image)\n", - "\n", - "run_docker_container(image, port_mapping, env_vars, gpus, volumes)" - ] + "# Finetuned the model\n", + "model.finetune(dataset=instruction_dataset)" + ], + "metadata": { + "collapsed": false + } }, { "cell_type": "markdown", - "metadata": {}, "source": [ - "## Alternately, you can run the example using CLI command:\n", - "\n", - "```bash\n", - "docker run -p 5000:5000 --gpus all -e WANDB_MODE=dryrun -e MICRO_BATCH_SIZE=1 -v /absolute/path/to/alpaca/data:/data public.ecr.aws/t8g5g2q5/xturing:int4_finetuning\n", - "```" - ] + "## 4. 
Generate an output text with the fine-tuned model" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "# Once the model has been finetuned, you can start doing inferences\n", + "output = model.generate(texts=[\"Why LLM models are becoming so important?\"])\n", + "print(\"Generated output by the model: {}\".format(output))" + ], + "metadata": { + "collapsed": false + } } ], "metadata": { diff --git a/pyproject.toml b/pyproject.toml index acc8eef..f870e6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "xturing" -version = "0.0.10" +version = "0.1.0" description = "Fine-tuning, evaluation and data generation for LLMs" authors = [ @@ -64,6 +64,10 @@ dependencies = [ [project.scripts] xturing = "xturing.cli:xturing" +[project.optional-dependencies] +int4 = [ + "torch >= 2.0" +] [project.urls] homepage = "https://xturing.stochastic.ai/" diff --git a/src/xturing/__about__.py b/src/xturing/__about__.py index 9b36b86..3dc1f76 100644 --- a/src/xturing/__about__.py +++ b/src/xturing/__about__.py @@ -1 +1 @@ -__version__ = "0.0.10" +__version__ = "0.1.0" diff --git a/src/xturing/config/finetuning_config.yaml b/src/xturing/config/finetuning_config.yaml index 3b1cde3..9a8d260 100644 --- a/src/xturing/config/finetuning_config.yaml +++ b/src/xturing/config/finetuning_config.yaml @@ -33,6 +33,13 @@ llama_lora_int8: batch_size: 8 max_length: 256 +llama_lora_int4: + learning_rate: 1e-4 + weight_decay: 0.01 + num_train_epochs: 3 + batch_size: 8 + max_length: 256 + gptj: learning_rate: 5e-5 weight_decay: 0.01 diff --git a/src/xturing/config/generation_config.yaml b/src/xturing/config/generation_config.yaml index 12ca88d..7279f6a 100644 --- a/src/xturing/config/generation_config.yaml +++ b/src/xturing/config/generation_config.yaml @@ -27,6 +27,13 @@ llama_lora_int8: max_new_tokens: 256 do_sample: false +# Contrastive search +llama_lora_int4: + penalty_alpha: 0.6 + top_k: 4 + max_new_tokens: 256 + do_sample: false + # Contrastive search gptj: penalty_alpha: 0.6 diff --git a/src/xturing/engines/__init__.py b/src/xturing/engines/__init__.py index 5ce012d..8291147 100644 --- a/src/xturing/engines/__init__.py +++ b/src/xturing/engines/__init__.py @@ -25,6 +25,7 @@ LLamaInt8Engine, LlamaLoraEngine, LlamaLoraInt8Engine, + LlamaLoraInt4Engine, ) from .opt_engine import OPTEngine, OPTInt8Engine, OPTLoraEngine, OPTLoraInt8Engine @@ -42,6 +43,7 @@ BaseEngine.add_to_registry(LlamaLoraEngine.config_name, LlamaLoraEngine) BaseEngine.add_to_registry(LLamaInt8Engine.config_name, LLamaInt8Engine) BaseEngine.add_to_registry(LlamaLoraInt8Engine.config_name, LlamaLoraInt8Engine) +BaseEngine.add_to_registry(LlamaLoraInt4Engine.config_name, LlamaLoraInt4Engine) BaseEngine.add_to_registry(GalacticaEngine.config_name, GalacticaEngine) BaseEngine.add_to_registry(GalacticaInt8Engine.config_name, GalacticaInt8Engine) BaseEngine.add_to_registry(GalacticaLoraEngine.config_name, GalacticaLoraEngine) diff --git a/src/xturing/engines/causal.py b/src/xturing/engines/causal.py index 3598f44..86dd1cb 100644 --- a/src/xturing/engines/causal.py +++ b/src/xturing/engines/causal.py @@ -146,7 +146,7 @@ def __init__( lora_config = LoraConfig( r=8, - lora_alpha=32, + lora_alpha=16, target_modules=target_modules, lora_dropout=0.05, bias="none", diff --git a/src/xturing/engines/llama_engine.py b/src/xturing/engines/llama_engine.py index 46d9441..af0dd6a 100644 --- a/src/xturing/engines/llama_engine.py +++ 
b/src/xturing/engines/llama_engine.py @@ -1,13 +1,16 @@ import os from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union +import transformers import torch +from torch import nn from xturing.engines.causal import CausalEngine, CausalLoraEngine from xturing.engines.llama_utils import LlamaConfig, LlamaForCausalLM, LlamaTokenizer from xturing.engines.lora_engine import prepare_model_for_int8_training - +from xturing.engines.quant_utils import make_quant, autotune_warmup +from xturing.utils.hub import ModelHub class LLamaEngine(CausalEngine): config_name: str = "llama_engine" @@ -98,3 +101,84 @@ def __init__(self, weights_path: Optional[Union[str, Path]] = None): load_8bit=True, target_modules=["q_proj", "v_proj"], ) + +def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): + if type(module) in layers: + return {name: module} + res = {} + for name1, child in module.named_children(): + res.update(find_layers( + child, layers=layers, name=name + '.' + name1 if name != '' else name1 + )) + return res + +class LlamaLoraInt4Engine(CausalLoraEngine): + config_name: str = "llama_lora_int4_engine" + + def __init__(self, weights_path: Optional[Union[str, Path]] = None): + model_name = "decapoda-research/llama-7b-hf" + + if weights_path is None: + weights_path = ModelHub().load("x/llama_lora_int4") + + config = LlamaConfig.from_pretrained(model_name) + + saved_kaiming_uniform_ = torch.nn.init.kaiming_uniform_ + saved_uniform_ = torch.nn.init.uniform_ + saved_normal_ = torch.nn.init.normal_ + + def noop(*args, **kwargs): + pass + + torch.nn.init.kaiming_uniform_ = noop + torch.nn.init.uniform_ = noop + torch.nn.init.normal_ = noop + + torch.set_default_dtype(torch.half) + transformers.modeling_utils._init_weights = False + torch.set_default_dtype(torch.half) + model = LlamaForCausalLM(config) + torch.set_default_dtype(torch.float) + model = model.eval() + + layers = find_layers(model) + + for name in ['lm_head']: + if name in layers: + del layers[name] + + wbits = 4 + groupsize = 128 + warmup_autotune=True + + make_quant(model, layers, wbits, groupsize) + + + model.load_state_dict(torch.load(weights_path / Path("pytorch_model.bin")), strict=False) + + if warmup_autotune: + autotune_warmup(model) + + model.seqlen = 2048 + + model.gptq = True + + model.gradient_checkpointing_enable() + model.enable_input_require_grads() + + tokenizer = LlamaTokenizer.from_pretrained(model_name, add_bos_token=False) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.pad_token_id = tokenizer.eos_token_id + + super().__init__( + model=model, + tokenizer=tokenizer, + target_modules=[ + "q_proj", + "v_proj", + ] + ) + + torch.nn.init.kaiming_uniform_ = saved_kaiming_uniform_ + torch.nn.init.uniform_ = saved_uniform_ + torch.nn.init.normal_ = saved_normal_ diff --git a/src/xturing/engines/lora_engine/lora.py b/src/xturing/engines/lora_engine/lora.py index fa8539c..8e81faf 100644 --- a/src/xturing/engines/lora_engine/lora.py +++ b/src/xturing/engines/lora_engine/lora.py @@ -21,6 +21,7 @@ from dataclasses import asdict, dataclass, field from enum import Enum from typing import List, Optional, Union +import enum import torch import torch.nn as nn @@ -44,6 +45,17 @@ def is_bnb_available(): def transpose(weight, fan_in_fan_out): return weight.T if fan_in_fan_out else weight +def is_gptq_available(): + return importlib.util.find_spec("xturing.engines.quant_utils") is not None + +if is_gptq_available(): + from ..quant_utils import QuantLinear + +class PeftType(str, enum.Enum): + PROMPT_TUNING 
= "PROMPT_TUNING" + P_TUNING = "P_TUNING" + PREFIX_TUNING = "PREFIX_TUNING" + LORA = "LORA" WEIGHTS_NAME = "adapter_model.bin" CONFIG_NAME = "adapter_config.json" @@ -52,19 +64,18 @@ def transpose(weight, fan_in_fan_out): @dataclass class LoraConfig: """ - This is the configuration class to store the configuration of a [`LoraModel`]. - + This is the configuration class to store the configuration of a [`~peft.Lora`]. Args: - r (`int`): Lora attention dimension. + r (`int`): Lora attention dimension target_modules (`Union[List[str],str]`): The names of the modules to apply Lora to. lora_alpha (`float`): The alpha parameter for Lora scaling. lora_dropout (`float`): The dropout probability for Lora layers. merge_weights (`bool`): Whether to merge the weights of the Lora layers with the base transformer model in `eval` mode. - fan_in_fan_out (`bool`): Set this to True if the layer to replace stores weight like (`fan_in`, `fan_out`). - enable_lora ( `List[bool]`): Used with [`lora.MergedLinear`]. - bias (`str`): Bias type for Lora. Can be `none`, `all` or `lora_only`. - modules_to_save (`List[str]`): List of modules apart from Lora layers to be set as trainable + fan_in_fan_out (`bool`): Set this to True if the layer to replace stores weight like (fan_in, fan_out) + enable_lora ( `List[bool]`): Used with `lora.MergedLinear`. + bias (`str`): Bias type for Lora. Can be 'none', 'all' or 'lora_only' + modules_to_save (`List[str]`):List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. """ @@ -79,22 +90,14 @@ class LoraConfig: lora_alpha: int = field(default=None, metadata={"help": "Lora alpha"}) lora_dropout: float = field(default=None, metadata={"help": "Lora dropout"}) merge_weights: bool = field( - default=False, - metadata={"help": "Merge weights of the original model and the Lora model"}, + default=False, metadata={"help": "Merge weights of the original model and the Lora model"} ) fan_in_fan_out: bool = field( default=False, - metadata={ - "help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)" - }, - ) - enable_lora: Optional[List[bool]] = field( - default=None, metadata={"help": "Used with `lora.MergedLinear`."} - ) - bias: str = field( - default="none", - metadata={"help": "Bias type for Lora. Can be 'none', 'all' or 'lora_only'"}, + metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"}, ) + enable_lora: Optional[List[bool]] = field(default=None, metadata={"help": "Used with `lora.MergedLinear`."}) + bias: str = field(default="none", metadata={"help": "Bias type for Lora. Can be 'none', 'all' or 'lora_only'"}) modules_to_save: Optional[List[str]] = field( default=None, metadata={ @@ -110,6 +113,7 @@ class LoraConfig: inference_mode: bool = field( default=False, metadata={"help": "Whether to use inference mode"} ) + peft_type: PeftType = PeftType.LORA base_model_name_or_path: str = field( default=None, metadata={"help": "The name of the base model to use."} @@ -201,35 +205,19 @@ def from_json_file(cls, path_json_file, **kwargs): class LoraModel(torch.nn.Module): """ Creates Low Rank Adapter (Lora) model from a pretrained transformers model. - Args: - model ([`~transformers.PreTrainedModel`]): The model to be adapted. + model ([`transformers.PreTrainedModel`]): The model to be adapted. config ([`LoraConfig`]): The configuration of the Lora model. - Returns: `torch.nn.Module`: The Lora model. 
- - Example: - - ```py - >>> from transformers import AutoModelForSeq2SeqLM, LoraConfig - >>> from peft import LoraModel, LoraConfig - - >>> config = LoraConfig( - ... peft_type="LORA", - ... task_type="SEQ_2_SEQ_LM", - ... r=8, - ... lora_alpha=32, - ... target_modules=["q", "v"], - ... lora_dropout=0.01, - ... ) - - >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") - >>> lora_model = LoraModel(config, model) - ``` - + Example:: + >>> from transformers import AutoModelForSeq2SeqLM, LoraConfig >>> from peft import LoraModel, LoraConfig >>> + config = LoraConfig( + peft_type="LORA", task_type="SEQ_2_SEQ_LM", r=8, lora_alpha=32, target_modules=["q", "v"], + lora_dropout=0.01, ) + >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") >>> lora_model = LoraModel(config, model) **Attributes**: - - **model** ([`~transformers.PreTrainedModel`]) -- The model to be adapted. + - **model** ([`transformers.PreTrainedModel`]) -- The model to be adapted. - **peft_config** ([`LoraConfig`]): The configuration of the Lora model. """ @@ -243,6 +231,8 @@ def __init__(self, config, model): def _find_and_replace(self): loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False) + is_gtq_quantized = getattr(self.model, "gptq", False) # Step 1: Check if the model is GTQ quantized + if loaded_in_8bit and not is_bnb_available(): raise ImportError( "To use Lora with 8-bit quantization, please install the `bitsandbytes` package. " @@ -255,21 +245,15 @@ def _find_and_replace(self): "lora_alpha": self.peft_config.lora_alpha, "lora_dropout": self.peft_config.lora_dropout, "fan_in_fan_out": self.peft_config.fan_in_fan_out, - "merge_weights": ( - self.peft_config.merge_weights or self.peft_config.inference_mode - ) + "merge_weights": (self.peft_config.merge_weights or self.peft_config.inference_mode) and not is_hf_device_map_available, - "init_lora_weights": self.peft_config.init_lora_weights, } key_list = [key for key, _ in self.model.named_modules()] for key in key_list: if isinstance(self.peft_config.target_modules, str): target_module_found = re.fullmatch(self.peft_config.target_modules, key) else: - target_module_found = any( - key.endswith(target_key) - for target_key in self.peft_config.target_modules - ) + target_module_found = any(key.endswith(target_key) for target_key in self.peft_config.target_modules) if target_module_found: if not is_target_modules_in_base_model: is_target_modules_in_base_model = True @@ -285,45 +269,49 @@ def _find_and_replace(self): } ) if self.peft_config.enable_lora is None: - new_module = Linear8bitLt( - target.in_features, target.out_features, bias=bias, **kwargs - ) + new_module = Linear8bitLt(target.in_features, target.out_features, bias=bias, **kwargs) else: kwargs.update({"enable_lora": self.peft_config.enable_lora}) - new_module = MergedLinear8bitLt( - target.in_features, target.out_features, bias=bias, **kwargs - ) - elif ( - isinstance(target, torch.nn.Linear) - and self.peft_config.enable_lora is None - ): - new_module = Linear( - target.in_features, target.out_features, bias=bias, **kwargs + new_module = MergedLinear8bitLt(target.in_features, target.out_features, bias=bias, **kwargs) + elif is_gptq_available() and isinstance(target, QuantLinear): + kwargs.update( + { + "bits": target.bits, + "groupsize": target.groupsize, + } ) + if self.peft_config.enable_lora is None: + new_module = LinearqbitLt(target.infeatures, target.outfeatures, bias=bias, **kwargs) + new_module.scales = target.scales + new_module.qzeros = target.qzeros + new_module.g_idx = 
target.g_idx + if target.bias: + new_module.bias = target.bias + else: + kwargs.update({"enable_lora": self.peft_config.enable_lora}) + new_module = MergedLinearqbitLt(target.infeatures, target.outfeatures, bias=bias, **kwargs) + new_module.scales = target.scales + new_module.qzeros = target.qzeros + new_module.g_idx = target.g_idx + if target.bias: + new_module.bias = target.bias + elif isinstance(target, torch.nn.Linear) and self.peft_config.enable_lora is None: + new_module = Linear(target.in_features, target.out_features, bias=bias, **kwargs) elif self.peft_config.enable_lora is not None: kwargs.update({"enable_lora": self.peft_config.enable_lora}) if isinstance(target, Conv1D): in_features, out_features = ( - target.weight.ds_shape - if hasattr(target.weight, "ds_shape") - else target.weight.shape + target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape ) else: - in_features, out_features = ( - target.in_features, - target.out_features, - ) + in_features, out_features = target.in_features, target.out_features if kwargs["fan_in_fan_out"]: warnings.warn( "fan_in_fan_out is set to True but the target module is not a Conv1D. " "Setting fan_in_fan_out to False." ) - kwargs[ - "fan_in_fan_out" - ] = self.peft_config.fan_in_fan_out = False - new_module = MergedLinear( - in_features, out_features, bias=bias, **kwargs - ) + kwargs["fan_in_fan_out"] = self.peft_config.fan_in_fan_out = False + new_module = MergedLinear(in_features, out_features, bias=bias, **kwargs) self._replace_module(parent, target_name, new_module, target) if not is_target_modules_in_base_model: raise ValueError( @@ -339,17 +327,30 @@ def _get_submodules(self, key): def _replace_module(self, parent_module, child_name, new_module, old_module): setattr(parent_module, child_name, new_module) - new_module.weight = old_module.weight - if old_module.bias is not None: - new_module.bias = old_module.bias - if getattr(old_module, "state", None) is not None: - new_module.state = old_module.state - new_module.to(old_module.weight.device) - - # dispatch to correct device - for name, module in new_module.named_modules(): - if "lora_" in name: - module.to(old_module.weight.device) + if is_gptq_available() and isinstance(old_module, QuantLinear): + new_module.qweight = old_module.qweight + if old_module.bias is not None: + new_module.bias = old_module.bias + if getattr(old_module, "state", None) is not None: + new_module.state = old_module.state + new_module.to(old_module.qweight.device) + + # dispatch to correct device + for name, module in new_module.named_modules(): + if "lora_" in name: + module.to(old_module.qweight.device) + else: + new_module.weight = old_module.weight + if old_module.bias is not None: + new_module.bias = old_module.bias + if getattr(old_module, "state", None) is not None: + new_module.state = old_module.state + new_module.to(old_module.weight.device) + + # dispatch to correct device + for name, module in new_module.named_modules(): + if "lora_" in name: + module.to(old_module.weight.device) def __getattr__(self, name: str): """Forward missing attributes to the wrapped module.""" @@ -363,10 +364,7 @@ def modules_to_save(self): return None def get_peft_config_as_dict(self, inference: bool = False): - config = { - k: v.value if isinstance(v, Enum) else v - for k, v in asdict(self.peft_config).items() - } + config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(self.peft_config).items()} if inference: config["inference_mode"] = True return config @@ -382,44 +380,6 @@ def 
enable_adapter_layers(self): def disable_adapter_layers(self): self._set_adapter_layers(enabled=False) - def merge_and_unload(self): - r""" - This method merges the LoRa layers into the base model. This is needed if someone wants to use the base model - as a standalone model. - """ - if self.config.model_type == "gpt2": - raise ValueError("GPT2 models are not supported for merging LORA layers") - - if getattr(self.model, "is_loaded_in_8bit", False): - raise ValueError( - "Cannot merge LORA layers when the model is loaded in 8-bit mode" - ) - - key_list = [key for key, _ in self.model.named_modules() if "lora" not in key] - for key in key_list: - parent, target, target_name = self._get_submodules(key) - if isinstance(target, LoraLayer): - bias = target.bias is not None - new_module = torch.nn.Linear( - target.in_features, target.out_features, bias=bias - ) - - # manually merge if not merged - if not target.merged: - # merge weights per: https://arxiv.org/pdf/2106.09685.pdf / page 4 - if target.r > 0: - target.weight.data += ( - transpose( - target.lora_B.weight @ target.lora_A.weight, - target.fan_in_fan_out, - ) - * target.scaling - ).to(target.weight.dtype) - target.merged = True - - self._replace_module(parent, target_name, new_module, target) - return self.model - def print_trainable_parameters(self): """ Prints the number of trainable parameters in the model. @@ -551,16 +511,8 @@ def __init__( merge_weights: bool = True, **kwargs, ): - init_lora_weights = kwargs.pop("init_lora_weights", True) - nn.Linear.__init__(self, in_features, out_features, **kwargs) - LoraLayer.__init__( - self, - r=r, - lora_alpha=lora_alpha, - lora_dropout=lora_dropout, - merge_weights=merge_weights, - ) + LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights) self.fan_in_fan_out = fan_in_fan_out # Actual trainable parameters @@ -570,8 +522,7 @@ def __init__( self.scaling = self.lora_alpha / self.r # Freezing the pre-trained weight matrix self.weight.requires_grad = False - if init_lora_weights: - self.reset_parameters() + self.reset_parameters() if fan_in_fan_out: self.weight.data = self.weight.data.T @@ -590,20 +541,14 @@ def train(self, mode: bool = True): # Merge the weights and mark it if self.r > 0: self.weight.data += ( - transpose( - self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out - ) - * self.scaling + transpose(self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out) * self.scaling ) self.merged = True elif self.merge_weights and self.merged: # Make sure that the weights are not merged if self.r > 0: self.weight.data -= ( - transpose( - self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out - ) - * self.scaling + transpose(self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out) * self.scaling ) self.merged = False @@ -616,27 +561,19 @@ def forward(self, x: torch.Tensor): if self.disable_adapters: if self.r > 0 and self.merged: self.weight.data -= ( - transpose( - self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out - ) - * self.scaling + transpose(self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out) * self.scaling ) self.merged = False - return F.linear( - x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias - ) + return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias) elif self.r > 0 and not self.merged: - result = F.linear( - x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias - ) + result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), 
bias=self.bias) if self.r > 0: - result += self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling + loraoutput = self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling + result = result + loraoutput return result else: - return F.linear( - x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias - ) + return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias) class MergedLinear(nn.Linear, LoraLayer): @@ -653,16 +590,8 @@ def __init__( merge_weights: bool = True, **kwargs, ): - init_lora_weights = kwargs.pop("init_lora_weights", True) - nn.Linear.__init__(self, in_features, out_features, **kwargs) - LoraLayer.__init__( - self, - r=r, - lora_alpha=lora_alpha, - lora_dropout=lora_dropout, - merge_weights=merge_weights, - ) + LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights) if out_features % len(enable_lora) != 0: raise ValueError("The length of enable_lora must divide out_features") self.enable_lora = enable_lora @@ -681,14 +610,10 @@ def __init__( # Freezing the pre-trained weight matrix self.weight.requires_grad = False # Compute the indices - self.lora_ind = self.weight.new_zeros( - (out_features,), dtype=torch.bool - ).view(len(enable_lora), -1) + self.lora_ind = self.weight.new_zeros((out_features,), dtype=torch.bool).view(len(enable_lora), -1) self.lora_ind[enable_lora, :] = True self.lora_ind = self.lora_ind.view(-1) - - if init_lora_weights: - self.reset_parameters() + self.reset_parameters() if fan_in_fan_out: self.weight.data = self.weight.data.T @@ -702,9 +627,7 @@ def reset_parameters(self): def zero_pad(self, x): result = x.new_zeros((*x.shape[:-1], self.out_features)) result = result.view(-1, self.out_features) - result[:, self.lora_ind] = x.reshape( - -1, self.out_features // len(self.enable_lora) * sum(self.enable_lora) - ) + result[:, self.lora_ind] = x.reshape(-1, self.out_features // len(self.enable_lora) * sum(self.enable_lora)) return result.view((*x.shape[:-1], self.out_features)) def train(self, mode: bool = True): @@ -723,9 +646,7 @@ def train(self, mode: bool = True): .squeeze(0) .transpose(-2, -1) ) - self.weight.data += transpose( - self.zero_pad(delta_w * self.scaling), not self.fan_in_fan_out - ) + self.weight.data += transpose(self.zero_pad(delta_w * self.scaling), not self.fan_in_fan_out) self.merged = True elif self.merge_weights and self.merged: # Make sure that the weights are not merged @@ -739,9 +660,7 @@ def train(self, mode: bool = True): .squeeze(0) .transpose(-2, -1) ) - self.weight.data -= transpose( - self.zero_pad(delta_w * self.scaling), not self.fan_in_fan_out - ) + self.weight.data -= transpose(self.zero_pad(delta_w * self.scaling), not self.fan_in_fan_out) self.merged = False def eval(self): @@ -761,21 +680,13 @@ def forward(self, x: torch.Tensor): .squeeze(0) .transpose(-2, -1) ) - self.weight.data -= transpose( - self.zero_pad(delta_w * self.scaling), not self.fan_in_fan_out - ) + self.weight.data -= transpose(self.zero_pad(delta_w * self.scaling), not self.fan_in_fan_out) self.merged = False - return F.linear( - x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias - ) + return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias) elif self.merged: - return F.linear( - x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias - ) + return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias) else: - result = F.linear( - x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias - ) + 
result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias) if self.r > 0: after_A = self.lora_A(self.lora_dropout(x)) after_B = self.lora_B(after_A.transpose(-2, -1)).transpose(-2, -1) @@ -802,19 +713,11 @@ def __init__( out_features, bias=kwargs.get("bias", True), has_fp16_weights=kwargs.get("has_fp16_weights", True), - memory_efficient_backward=kwargs.get( - "memory_efficient_backward", False - ), + memory_efficient_backward=kwargs.get("memory_efficient_backward", False), threshold=kwargs.get("threshold", 0.0), index=kwargs.get("index", None), ) - LoraLayer.__init__( - self, - r=r, - lora_alpha=lora_alpha, - lora_dropout=lora_dropout, - merge_weights=False, - ) + LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False) # Actual trainable parameters if r > 0: self.lora_A = nn.Linear(in_features, r, bias=False) @@ -841,17 +744,10 @@ def forward(self, x: torch.Tensor): if x.dtype != torch.float32: x = x.float() - output = ( - self.lora_B(self.lora_A(self.lora_dropout(x))).to( - expected_dtype - ) - * self.scaling - ) + output = self.lora_B(self.lora_A(self.lora_dropout(x))).to(expected_dtype) * self.scaling result += output else: - output = ( - self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling - ) + output = self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling result += output return result @@ -873,19 +769,11 @@ def __init__( out_features, bias=kwargs.get("bias", True), has_fp16_weights=kwargs.get("has_fp16_weights", True), - memory_efficient_backward=kwargs.get( - "memory_efficient_backward", False - ), + memory_efficient_backward=kwargs.get("memory_efficient_backward", False), threshold=kwargs.get("threshold", 0.0), index=kwargs.get("index", None), ) - LoraLayer.__init__( - self, - r=r, - lora_alpha=lora_alpha, - lora_dropout=lora_dropout, - merge_weights=False, - ) + LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False) if out_features % len(enable_lora) != 0: raise ValueError("The length of enable_lora must divide out_features") self.enable_lora = enable_lora @@ -903,9 +791,7 @@ def __init__( # Freezing the pre-trained weight matrix self.weight.requires_grad = False # Compute the indices - self.lora_ind = self.weight.new_zeros( - (out_features,), dtype=torch.bool - ).view(len(enable_lora), -1) + self.lora_ind = self.weight.new_zeros((out_features,), dtype=torch.bool).view(len(enable_lora), -1) self.lora_ind[enable_lora, :] = True self.lora_ind = self.lora_ind.view(-1) self.reset_parameters() @@ -944,6 +830,133 @@ def forward(self, x: torch.Tensor): result += output return result +if is_gptq_available(): + class LinearqbitLt(QuantLinear, LoraLayer): + # Lora implemented in a dense layer + def __init__( + self, + in_features, + out_features, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + **kwargs, + ): + + QuantLinear.__init__( + self, + kwargs.get('bits', 4), + kwargs.get('groupsize', 128), + in_features, + out_features, + kwargs.get('bias', False), + ) + + LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False) + # Actual trainable parameters + if r > 0: + self.lora_A = nn.Linear(in_features, r, bias=False) + self.lora_B = nn.Linear(r, out_features, bias=False) + self.scaling = self.lora_alpha / self.r + # Freezing the pre-trained weight matrix + self.qweight.requires_grad = False + self.scales.requires_grad = False + self.qzeros.requires_grad = False + self.reset_parameters() + + def 
reset_parameters(self): + if hasattr(self, "lora_A"): + # initialize A the same way as the default for nn.Linear and B to zero + # nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5)) + self.lora_A.weight = torch.nn.Parameter(torch.nn.init.kaiming_uniform(self.lora_A.weight, a=math.sqrt(5))) + nn.init.zeros_(self.lora_B.weight) + + def forward(self, x: torch.Tensor): + # x = x.detach() + custom_layer_output = super().forward(x) + + dtype = custom_layer_output.dtype + x = x.float() + lora_output = self.lora_B(self.lora_A(self.lora_dropout(x))).to(dtype) * self.scaling + result = custom_layer_output + lora_output + return result + + class MergedLinearqbitLt(QuantLinear, LoraLayer): + # Lora implemented in a dense layer + def __init__( + self, + in_features: int, + out_features: int, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + enable_lora: List[bool] = [False], + **kwargs, + ): + QuantLinear.__init__( + self, + kwargs.get('bits', 4), + kwargs.get('groupsize', 128), + in_features, + out_features, + ) + LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False) + if out_features % len(enable_lora) != 0: + raise ValueError("The length of enable_lora must divide out_features") + self.enable_lora = enable_lora + # Actual trainable parameters + if r > 0 and any(enable_lora): + self.lora_A = nn.Linear(in_features, r * sum(enable_lora), bias=False) + self.lora_B = nn.Conv1d( + r * sum(enable_lora), + out_features // len(enable_lora) * sum(enable_lora), + kernel_size=1, + groups=2, + bias=False, + ) + self.scaling = self.lora_alpha / self.r + # Freezing the pre-trained weight matrix + self.qweight.requires_grad = False + # Compute the indices + self.lora_ind = self.weight.new_zeros((out_features,), dtype=torch.bool).view(len(enable_lora), -1) + self.lora_ind[enable_lora, :] = True + self.lora_ind = self.lora_ind.view(-1) + self.reset_parameters() + + def reset_parameters(self): + if hasattr(self, "lora_A"): + # initialize A the same way as the default for nn.Linear and B to zero + nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B.weight) + + def zero_pad(self, x): + result = x.new_zeros((*x.shape[:-1], self.out_features)) + result = result.view(-1, self.out_features) + result[:, self.lora_ind] = x.reshape( + -1, self.out_features // len(self.enable_lora) * sum(self.enable_lora) + ) + return result.view((*x.shape[:-1], self.out_features)) + + def forward(self, x: torch.Tensor): + result = super().forward(x)#.detach() + if self.disable_adapters: + return result + elif self.r > 0: + if not torch.is_autocast_enabled(): + expected_dtype = result.dtype + if x.dtype != torch.float32: + x = x.float() + after_A = self.lora_A(self.lora_dropout(x)) + after_B = self.lora_B(after_A.transpose(-2, -1)).transpose(-2, -1) + output = self.zero_pad(after_B).to(expected_dtype) * self.scaling + result += output + else: + after_A = self.lora_A(self.lora_dropout(x)) + after_B = self.lora_B(after_A.transpose(-2, -1)).transpose(-2, -1) + output = self.zero_pad(after_B) * self.scaling + result += output + return result + def prepare_model_for_int8_training( model, @@ -955,7 +968,6 @@ def prepare_model_for_int8_training( This method wrapps the entire protocol for preparing a model before running a training. 
This includes: 1- Cast the layernorm in fp32 2- making output embedding layer require grads 3- Add the upcasting of the lm head to fp32 - Args: model, (`transformers.PreTrainedModel`): The loaded model from `transformers` @@ -995,7 +1007,6 @@ class CastOutputToFloat(torch.nn.Sequential): r""" Manually cast to the expected dtype of the lm_head as sometimes there is a final layer norm that is casted in fp32 - """ def forward(self, x): diff --git a/src/xturing/engines/quant_utils/__init__.py b/src/xturing/engines/quant_utils/__init__.py new file mode 100644 index 0000000..11d0eb5 --- /dev/null +++ b/src/xturing/engines/quant_utils/__init__.py @@ -0,0 +1 @@ +from .quant import make_quant, autotune_warmup, QuantLinear \ No newline at end of file diff --git a/src/xturing/engines/quant_utils/custom_autotune.py b/src/xturing/engines/quant_utils/custom_autotune.py new file mode 100644 index 0000000..dfc8c6f --- /dev/null +++ b/src/xturing/engines/quant_utils/custom_autotune.py @@ -0,0 +1,168 @@ +#https://github.com/fpgaminer/GPTQ-triton +""" +Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100. +""" + +import builtins +import math +import time +from typing import Dict + +import triton + + +class Autotuner(triton.KernelInterface): + def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None, nearest_power_of_two: bool = False): + ''' + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs. + 'nearest_power_of_two'(optional): whether to round key arguments to the nearest power of two when caching tuning results + ''' + if not configs: + self.configs = [triton.Config({}, num_warps=4, num_stages=2)] + else: + self.configs = configs + self.key_idx = [arg_names.index(k) for k in key] + self.nearest_power_of_two = nearest_power_of_two + self.cache = {} + # hook to reset all required tensor to zeros before relaunching a kernel + self.hook = lambda args: 0 + if reset_to_zero is not None: + self.reset_idx = [arg_names.index(k) for k in reset_to_zero] + + def _hook(args): + for i in self.reset_idx: + args[i].zero_() + self.hook = _hook + self.arg_names = arg_names + # prune configs + if prune_configs_by: + perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k'] + if 'early_config_prune' in prune_configs_by: + early_config_prune = prune_configs_by['early_config_prune'] + else: + perf_model, top_k, early_config_prune = None, None, None + self.perf_model, self.configs_top_k = perf_model, top_k + self.early_config_prune = early_config_prune + self.fn = fn + + def _bench(self, *args, config, **meta): + # check for conflicts, i.e. meta-parameters both provided + # as kwargs and by the autotuner + conflicts = meta.keys() & config.kwargs.keys() + if conflicts: + raise ValueError( + f"Conflicting meta-parameters: {', '.join(conflicts)}." + " Make sure that you don't re-define auto-tuned symbols." 
+ ) + # augment meta-parameters with tunable ones + current = dict(meta, **config.kwargs) + + def kernel_call(): + if config.pre_hook: + config.pre_hook(self.nargs) + self.hook(args) + self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current) + try: + # In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses + # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default + return triton.testing.do_bench(kernel_call, rep=40) + except triton.compiler.OutOfResources: + return float('inf') + + def run(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + if len(self.configs) > 1: + key = tuple(args[i] for i in self.key_idx) + + # This reduces the amount of autotuning by rounding the keys to the nearest power of two + # In my testing this gives decent results, and greatly reduces the amount of tuning required + if self.nearest_power_of_two: + key = tuple([2 ** int(math.log2(x) + 0.5) for x in key]) + + if key not in self.cache: + # prune configs + pruned_configs = self.prune_configs(kwargs) + bench_start = time.time() + timings = {config: self._bench(*args, config=config, **kwargs) + for config in pruned_configs} + timings = {k:v for k,v in timings.items() if v != float('inf')} + bench_end = time.time() + self.bench_time = bench_end - bench_start + self.cache[key] = builtins.min(timings, key=timings.get) + self.hook(args) + self.configs_timings = timings + config = self.cache[key] + else: + config = self.configs[0] + self.best_config = config + if config.pre_hook is not None: + config.pre_hook(self.nargs) + return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) + + def prune_configs(self, kwargs): + pruned_configs = self.configs + if self.early_config_prune: + pruned_configs = self.early_config_prune(self.configs, self.nargs) + if self.perf_model: + top_k = self.configs_top_k + if isinstance(top_k, float) and top_k <= 1.0: + top_k = int(len(self.configs) * top_k) + if len(pruned_configs) > top_k: + est_timing = { + config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, + num_warps=config.num_warps) + for config in pruned_configs + } + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] + return pruned_configs + + def warmup(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + for config in self.prune_configs(kwargs): + self.fn.warmup( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + **kwargs, + **config.kwargs, + ) + self.nargs = None + + +def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False): + """ + Decorator for auto-tuning a :code:`triton.jit`'d function. + .. highlight:: python + .. code-block:: python + @triton.autotune(configs=[ + triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4), + triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8), + ], + key=['x_size'] # the two above configs will be evaluated anytime + # the value of x_size changes + ) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] + :note: When all the configurations are evaluated, the kernel will run multiple time. + This means that whatever value the kernel updates will be updated multiple times. 
+ To avoid this undesired behavior, you can use the `reset_to_zero` argument, which + reset the value of the provided tensor to `zero` before running any configuration. + :param configs: a list of :code:`triton.Config` objects + :type configs: list[triton.Config] + :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. + :type key: list[str] + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs. + :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. + :type reset_to_zero: list[str] + """ + def decorator(fn): + return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, nearest_power_of_two) + + return decorator diff --git a/src/xturing/engines/quant_utils/quant.py b/src/xturing/engines/quant_utils/quant.py new file mode 100644 index 0000000..448e344 --- /dev/null +++ b/src/xturing/engines/quant_utils/quant.py @@ -0,0 +1,496 @@ +import numpy as np +import torch +import torch.nn as nn +from torch.cuda.amp import custom_bwd, custom_fwd +import math + +def quantize(x, scale, zero, maxq): + if maxq < 0: + return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero + q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) + return scale * (q - zero) + +class Quantizer(nn.Module): + + def __init__(self, shape=1): + super(Quantizer, self).__init__() + self.register_buffer('maxq', torch.tensor(0)) + self.register_buffer('scale', torch.zeros(shape)) + self.register_buffer('zero', torch.zeros(shape)) + + def configure( + self, + bits, perchannel=False, sym=True, + mse=False, norm=2.4, grid=100, maxshrink=.8, + trits=False + ): + + self.maxq = torch.tensor(2 ** bits - 1) + self.perchannel = perchannel + self.sym = sym + self.mse = mse + self.norm = norm + self.grid = grid + self.maxshrink = maxshrink + if trits: + self.maxq = torch.tensor(-1) + + def find_params(self, x, weight=False): + dev = x.device + self.maxq = self.maxq.to(dev) + + shape = x.shape + if self.perchannel: + if weight: + x = x.flatten(1) + else: + if len(shape) == 4: + x = x.permute([1, 0, 2, 3]) + x = x.flatten(1) + if len(shape) == 3: + x = x.reshape((-1, shape[-1])).t() + if len(shape) == 2: + x = x.t() + else: + x = x.flatten().unsqueeze(0) + + tmp = torch.zeros(x.shape[0], device=dev) + xmin = torch.minimum(x.min(1)[0], tmp) + xmax = torch.maximum(x.max(1)[0], tmp) + + if self.sym: + xmax = torch.maximum(torch.abs(xmin), xmax) + tmp = xmin < 0 + if torch.any(tmp): + xmin[tmp] = -xmax[tmp] + tmp = (xmin == 0) & (xmax == 0) + xmin[tmp] = -1 + xmax[tmp] = +1 + + if self.maxq < 0: + self.scale = xmax + self.zero = xmin + else: + self.scale = (xmax - xmin) / self.maxq + if self.sym: + self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) + else: + self.zero = torch.round(-xmin / self.scale) + + if self.mse: + best = torch.full([x.shape[0]], float('inf'), device=dev) + for i in range(int(self.maxshrink * self.grid)): + p = 1 - i / self.grid + xmin1 = p * xmin + xmax1 = p * xmax + scale1 = (xmax1 - xmin1) / self.maxq + zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero + q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), 
self.maxq) + q -= x + q.abs_() + q.pow_(self.norm) + err = torch.sum(q, 1) + tmp = err < best + if torch.any(tmp): + best[tmp] = err[tmp] + self.scale[tmp] = scale1[tmp] + self.zero[tmp] = zero1[tmp] + if not self.perchannel: + if weight: + tmp = shape[0] + else: + tmp = shape[1] if len(shape) != 3 else shape[2] + self.scale = self.scale.repeat(tmp) + self.zero = self.zero.repeat(tmp) + + if weight: + shape = [-1] + [1] * (len(shape) - 1) + self.scale = self.scale.reshape(shape) + self.zero = self.zero.reshape(shape) + return + if len(shape) == 4: + self.scale = self.scale.reshape((1, -1, 1, 1)) + self.zero = self.zero.reshape((1, -1, 1, 1)) + if len(shape) == 3: + self.scale = self.scale.reshape((1, 1, -1)) + self.zero = self.zero.reshape((1, 1, -1)) + if len(shape) == 2: + self.scale = self.scale.unsqueeze(0) + self.zero = self.zero.unsqueeze(0) + + def quantize(self, x): + if self.ready(): + return quantize(x, self.scale, self.zero, self.maxq) + return x + + def enabled(self): + return self.maxq > 0 + + def ready(self): + return torch.all(self.scale != 0) + +try: + import triton + import triton.language as tl + from . import custom_autotune + + # code based https://github.com/fpgaminer/GPTQ-triton + @custom_autotune.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + # These provided a benefit on a 3090 + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + ], + key=['M', 'N'], + nearest_power_of_two=True, + ) + + @triton.jit + def matmul_248_kernel(a_ptr, b_ptr, c_ptr, + scales_ptr, zeros_ptr, g_ptr, + M, N, K, bits, maxq, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_scales, stride_zeros, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr): + """ + Compute the matrix multiplication C = A x B. 
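Because the bit layout is easy to misread, here is a small plain-Python sketch (illustrative values only) of the packing scheme that `QuantLinear.pack()` further below writes and that this kernel unpacks with shifts and masks:

```python
# Sketch only: eight 4-bit quantized values packed least-significant-first into one
# 32-bit word, matching the loop in QuantLinear.pack(); the numbers are illustrative.
bits = 4
vals = [3, 7, 0, 15, 9, 2, 11, 5]            # eight 4-bit integers

packed = 0
for j, v in enumerate(vals):
    packed |= v << (bits * j)                # one int32 word of qweight

unpacked = [(packed >> (bits * j)) & 0xF for j in range(32 // bits)]
assert unpacked == vals

# Dequantization as in the kernel: pack() stores (zero - 1), the kernel re-adds 1,
# then each element becomes (q - zero) * scale for its group.
scale, stored_zero = 0.01, 6                 # illustrative group parameters
zero = stored_zero + 1
weights = [(q - zero) * scale for q in unpacked]
```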
+ A is of shape (M, K) float16 + B is of shape (K//8, N) int32 + C is of shape (M, N) float16 + scales is of shape (G, N) float16 + zeros is of shape (G, N) float16 + g_ptr is of shape (K) int32 + """ + infearure_per_bits = 32 // bits + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + a_mask = (offs_am[:, None] < M) + # b_ptrs is set up such that it repeats elements along the K axis 8 times + b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + g_ptrs = g_ptr + offs_k + # shifter is used to extract the N bits of each element in the 32-bit word from B + scales_ptrs = scales_ptr + offs_bn[None, :] + zeros_ptrs = zeros_ptr + (offs_bn[None, :]// infearure_per_bits) + + shifter = (offs_k % infearure_per_bits) * bits + zeros_shifter = (offs_bn % infearure_per_bits) * bits + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, num_pid_k): + g_idx = tl.load(g_ptrs) + + # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop + scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + + zeros = (zeros >> zeros_shifter[None, :]) & maxq + zeros = (zeros + 1) + + a = tl.load(a_ptrs, mask=a_mask, other=0.) 
# (BLOCK_SIZE_M, BLOCK_SIZE_K) + b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated + + # Now we need to unpack b (which is N-bit values) into 32-bit values + b = (b >> shifter[:, None]) & maxq # Extract the N-bit values + b = (b - zeros) * scales # Scale and shift + + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K + b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk + g_ptrs += BLOCK_SIZE_K + + c = accumulator.to(tl.float16) + c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] + c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + # code based https://github.com/fpgaminer/GPTQ-triton + @custom_autotune.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 256, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 128, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + # These provided a benefit on a 3090 + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_N': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_N': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 64, 'BLOCK_SIZE_N': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + ], + key=['M', 'K'], + nearest_power_of_two=True, + ) + + @triton.jit + def trans_matmul_248_kernel(a_ptr, b_ptr, c_ptr, + scales_ptr, zeros_ptr, g_ptr, + M, N, K, bits, maxq, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_scales, stride_zeros, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr): + """ + Compute the matrix multiplication C = A x B. 
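This second kernel serves the backward pass of `QuantLinearFunction` defined further below, where the gradient with respect to the input needs the transposed dequantized weight. A dense, reference-only sketch of the two computations (names and shapes are illustrative):

```python
# Reference-only sketch, not part of the patch: what the two Triton kernels compute
# once the 4-bit weight is dequantized to a dense tensor w_deq of shape
# (infeatures, outfeatures).
import torch

def matmul248_reference(a: torch.Tensor, w_deq: torch.Tensor) -> torch.Tensor:
    # forward: C = A @ W, used by QuantLinearFunction.forward via matmul248
    return a @ w_deq

def transpose_matmul248_reference(grad_c: torch.Tensor, w_deq: torch.Tensor) -> torch.Tensor:
    # backward: dL/dA = dL/dC @ W^T, used by QuantLinearFunction.backward via transpose_matmul248
    return grad_c @ w_deq.t()
```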
+ A is of shape (M, N) float16 + B is of shape (K//8, N) int32 + C is of shape (M, K) float16 + scales is of shape (G, N) float16 + zeros is of shape (G, N) float16 + g_ptr is of shape (K) int32 + """ + infearure_per_bits = 32 // bits + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_k + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_k = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offs_n = tl.arange(0, BLOCK_SIZE_N) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N) + a_mask = (offs_am[:, None] < M) + # b_ptrs is set up such that it repeats elements along the K axis 8 times + b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + g_ptrs = g_ptr + offs_bk + g_idx = tl.load(g_ptrs) + + # shifter is used to extract the N bits of each element in the 32-bit word from B + scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales + zeros_ptrs = zeros_ptr + (offs_n[None, :]// infearure_per_bits) + g_idx[:, None] * stride_zeros + + shifter = (offs_bk % infearure_per_bits) * bits + zeros_shifter = (offs_n % infearure_per_bits) * bits + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) + + for k in range(0, num_pid_n): + # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop + scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + + zeros = (zeros >> zeros_shifter[None, :]) & maxq + zeros = (zeros + 1) + + a = tl.load(a_ptrs, mask=a_mask, other=0.) 
+            b = tl.load(b_ptrs)   # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
+
+            # Now we need to unpack b (which is N-bit values) into 32-bit values
+            b = (b >> shifter[:, None]) & maxq  # Extract the N-bit values
+            b = (b - zeros) * scales  # Scale and shift
+            b = tl.trans(b)
+
+            accumulator += tl.dot(a, b)
+            a_ptrs += BLOCK_SIZE_N
+            b_ptrs += BLOCK_SIZE_N
+            scales_ptrs += BLOCK_SIZE_N
+            zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)
+
+        c = accumulator.to(tl.float16)
+        c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]
+        c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
+        tl.store(c_ptrs, c, mask=c_mask)
+except:
+    print('triton not installed.')
+
+def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
+    output = torch.empty((input.shape[0], qweight.shape[1]), device='cuda', dtype=torch.float16)
+    grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']),)
+    matmul_248_kernel[grid](input, qweight, output,
+                            scales, qzeros, g_idx,
+                            input.shape[0], qweight.shape[1], input.shape[1], bits, maxq,
+                            input.stride(0), input.stride(1),
+                            qweight.stride(0), qweight.stride(1),
+                            output.stride(0), output.stride(1),
+                            scales.stride(0), qzeros.stride(0))
+    return output
+
+def transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
+    output_dim = (qweight.shape[0] * 32) // bits
+    output = torch.empty((input.shape[0], output_dim), device='cuda', dtype=torch.float16)
+    grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_dim, META['BLOCK_SIZE_K']),)
+    trans_matmul_248_kernel[grid](input, qweight, output,
+                                  scales, qzeros, g_idx,
+                                  input.shape[0], qweight.shape[1], output_dim, bits, maxq,
+                                  input.stride(0), input.stride(1),
+                                  qweight.stride(0), qweight.stride(1),
+                                  output.stride(0), output.stride(1),
+                                  scales.stride(0), qzeros.stride(0))
+    return output
+
+class QuantLinearFunction(torch.autograd.Function):
+    @staticmethod
+    @custom_fwd(cast_inputs=torch.float16)
+    def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
+        output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq)
+        ctx.save_for_backward(qweight, scales, qzeros, g_idx)
+        ctx.bits, ctx.maxq = bits, maxq
+        return output
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, grad_output):
+        qweight, scales, qzeros, g_idx = ctx.saved_tensors
+        bits, maxq = ctx.bits, ctx.maxq
+        grad_input = None
+
+        if ctx.needs_input_grad[0]:
+            grad_input = transpose_matmul248(grad_output, qweight, scales, qzeros, g_idx, bits, maxq)
+        return grad_input, None, None, None, None, None, None
+
+class QuantLinear(nn.Module):
+    def __init__(self, bits, groupsize, infeatures, outfeatures, bias):
+        super().__init__()
+        if bits not in [2, 4, 8]:
+            raise NotImplementedError("Only 2,4,8 bits are supported.")
+        self.infeatures = infeatures
+        self.outfeatures = outfeatures
+        self.bits = bits
+        self.maxq = 2 ** self.bits - 1
+        self.groupsize = groupsize if groupsize != -1 else infeatures
+
+        self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
+        self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32))
+        self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
+        self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))
+        if bias:
+            self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
+        else:
+            self.bias = None
+
+    def pack(self, linear, scales, zeros, g_idx=None):
+        self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
+
+        scales = scales.t().contiguous()
+        zeros = zeros.t().contiguous()
+        scale_zeros = zeros * scales
+        self.scales = scales.clone().half()
+        if linear.bias is not None:
+            self.bias = linear.bias.clone().half()
+
+        intweight = []
+        for idx in range(self.infeatures):
+            intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:, None])
+        intweight = torch.cat(intweight, dim=1)
+        intweight = intweight.t().contiguous()
+        intweight = intweight.numpy().astype(np.uint32)
+        qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32)
+        i = 0
+        row = 0
+        while row < qweight.shape[0]:
+            if self.bits in [2, 4, 8]:
+                for j in range(i, i + (32 // self.bits)):
+                    qweight[row] |= intweight[j] << (self.bits * (j - i))
+                i += 32 // self.bits
+                row += 1
+            else:
+                raise NotImplementedError("Only 2,4,8 bits are supported.")
+
+        qweight = qweight.astype(np.int32)
+        self.qweight = torch.from_numpy(qweight)
+
+        zeros -= 1
+        zeros = zeros.numpy().astype(np.uint32)
+        qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
+        i = 0
+        col = 0
+        while col < qzeros.shape[1]:
+            if self.bits in [2, 4, 8]:
+                for j in range(i, i + (32 // self.bits)):
+                    qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
+                i += 32 // self.bits
+                col += 1
+            else:
+                raise NotImplementedError("Only 2,4,8 bits are supported.")
+
+        qzeros = qzeros.astype(np.int32)
+        self.qzeros = torch.from_numpy(qzeros)
+
+    def forward(self, x):
+        out_shape = x.shape[:-1] + (self.outfeatures, )
+        out = QuantLinearFunction.apply(x.reshape(-1, x.shape[-1]), self.qweight, self.scales,
+                                        self.qzeros, self.g_idx, self.bits, self.maxq)
+        out = out + self.bias if self.bias is not None else out
+        return out.reshape(out_shape)
+
+def autotune_warmup(model, transpose=False):
+    """
+    Pre-tunes the quantized kernel
+    """
+    from tqdm import tqdm
+
+    n_values = {}
+
+    for _, m in model.named_modules():
+        if not isinstance(m, QuantLinear):
+            continue
+
+        k = m.infeatures
+        n = m.outfeatures
+
+        if n not in n_values:
+            n_values[n] = (k, m.qweight.cuda(), m.scales.cuda(), m.qzeros.cuda(), m.g_idx.cuda(), m.bits, m.maxq)
+
+    print(f'Found {len(n_values)} unique N values.')
+
+    print('Warming up autotune cache ...')
+    for m in tqdm(range(0, 12)):
+        m = 2 ** m  # [1, 2048]
+        for n, (k, qweight, scales, qzeros, g_idx, bits, maxq) in n_values.items():
+            a = torch.randn(m, k, dtype=torch.float16, device='cuda')
+            matmul248(a, qweight, scales, qzeros, g_idx, bits, maxq)
+            if transpose:
+                a = torch.randn(m, n, dtype=torch.float16, device='cuda')
+                transpose_matmul248(a, qweight, scales, qzeros, g_idx, bits, maxq)
+    del n_values
+
+def make_quant(module, names, bits, groupsize, name=''):
+    if isinstance(module, QuantLinear):
+        return
+    for attr in dir(module):
+        tmp = getattr(module, attr)
+        name1 = name + '.' + attr if name != '' else attr
+        if name1 in names:
+            delattr(module, attr)
+            setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None))
+    for name1, child in module.named_children():
+        make_quant(child, names, bits, groupsize, name + '.'
+ name1 if name != '' else name1) diff --git a/src/xturing/models/__init__.py b/src/xturing/models/__init__.py index 7c7f3e1..512cd7b 100644 --- a/src/xturing/models/__init__.py +++ b/src/xturing/models/__init__.py @@ -5,7 +5,7 @@ from .galactica import Galactica, GalacticaInt8, GalacticaLora, GalacticaLoraInt8 from .gpt2 import GPT2, GPT2Int8, GPT2Lora, GPT2LoraInt8 from .gptj import GPTJ, GPTJInt8, GPTJLora, GPTJLoraInt8 -from .llama import Llama, LlamaInt8, LlamaLora, LlamaLoraInt8 +from .llama import Llama, LlamaInt8, LlamaLora, LlamaLoraInt8, LlamaLoraInt4 from .opt import OPT, OPTInt8, OPTLora, OPTLoraInt8 from .stable_diffusion import StableDiffusion @@ -23,6 +23,7 @@ BaseModel.add_to_registry(LlamaLora.config_name, LlamaLora) BaseModel.add_to_registry(LlamaInt8.config_name, LlamaInt8) BaseModel.add_to_registry(LlamaLoraInt8.config_name, LlamaLoraInt8) +BaseModel.add_to_registry(LlamaLoraInt4.config_name, LlamaLoraInt4) BaseModel.add_to_registry(Galactica.config_name, Galactica) BaseModel.add_to_registry(GalacticaLora.config_name, GalacticaLora) BaseModel.add_to_registry(GalacticaInt8.config_name, GalacticaInt8) diff --git a/src/xturing/models/llama.py b/src/xturing/models/llama.py index 20f98cb..36e6a62 100644 --- a/src/xturing/models/llama.py +++ b/src/xturing/models/llama.py @@ -5,6 +5,7 @@ LLamaInt8Engine, LlamaLoraEngine, LlamaLoraInt8Engine, + LlamaLoraInt4Engine, ) from xturing.models.causal import ( CausalInt8Model, @@ -12,6 +13,10 @@ CausalLoraModel, CausalModel, ) +from xturing.trainers.base import BaseTrainer +from xturing.datasets.instruction_dataset import InstructionDataset +from xturing.datasets.text_dataset import TextDataset +from xturing.trainers.lightning_trainer import LightningTrainer class Llama(CausalModel): @@ -40,3 +45,25 @@ class LlamaLoraInt8(CausalLoraInt8Model): def __init__(self, weights_path: Optional[str] = None): super().__init__(LlamaLoraInt8Engine.config_name, weights_path) + + +class LlamaLoraInt4(CausalLoraInt8Model): + config_name: str = "llama_lora_int4" + + def _make_trainer(self, dataset: Union[TextDataset, InstructionDataset]): + return BaseTrainer.create( + LightningTrainer.config_name, + self.engine, + dataset, + self._make_collate_fn(dataset), + int(self.finetuning_args.num_train_epochs), + int(self.finetuning_args.batch_size), + float(self.finetuning_args.learning_rate), + self.finetuning_args.optimizer_name, + True, + True, + lora_type=32, + ) + + def __init__(self, weights_path: Optional[str] = None): + super().__init__(LlamaLoraInt4Engine.config_name, weights_path) diff --git a/src/xturing/trainers/lightning_trainer.py b/src/xturing/trainers/lightning_trainer.py index 4035580..7950753 100644 --- a/src/xturing/trainers/lightning_trainer.py +++ b/src/xturing/trainers/lightning_trainer.py @@ -100,6 +100,7 @@ def __init__( use_lora: bool = False, use_deepspeed: bool = False, max_training_time_in_secs: Optional[int] = None, + lora_type: int = 16, ): self.lightning_model = TuringLightningModule( model_engine=model_engine, @@ -129,6 +130,8 @@ def __init__( duration=datetime.timedelta(seconds=max_training_time_in_secs) ) ) + model_engine.model.train() + model_engine.model.print_trainable_parameters() if DEFAULT_DEVICE.type == "cpu": self.trainer = Trainer( @@ -167,7 +170,7 @@ def __init__( num_nodes=1, accelerator="gpu", strategy=strategy, - precision=16, + precision=lora_type, max_epochs=max_epochs, callbacks=training_callbacks, enable_checkpointing=True, diff --git a/src/xturing/utils/hub.py b/src/xturing/utils/hub.py index 600987b..116a76c 
100644 --- a/src/xturing/utils/hub.py +++ b/src/xturing/utils/hub.py @@ -88,6 +88,7 @@ class ModelHub(Hub): "distilgpt2_lora_finetuned_alpaca" ), "llama_lora_finetuned_alpaca": make_model_url("llama_lora_finetuned_alpaca"), + "llama_lora_int4": make_model_url("llama_lora_int4"), } def __init__(self): diff --git a/src/xturing/utils/text_splitter.py b/src/xturing/utils/text_splitter.py index 653aaf3..9a8bba8 100644 --- a/src/xturing/utils/text_splitter.py +++ b/src/xturing/utils/text_splitter.py @@ -12,7 +12,6 @@ Collection, Iterable, List, - Literal, Optional, Union, ) @@ -118,8 +117,8 @@ def _huggingface_tokenizer_length(text: str) -> int: def from_tiktoken_encoder( cls, encoding_name: str = "gpt2", - allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), - disallowed_special: Union[Literal["all"], Collection[str]] = "all", + allowed_special = set(), + disallowed_special = set(), **kwargs: Any, ) -> TextSplitter: """Text splitter that uses tiktoken encoder to count length."""
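
A note on the packing layout used by the kernels and `QuantLinear.pack` earlier in this patch: each 32-bit word of `qweight` holds `32 // bits` consecutive quantized values along the in-features axis, and the kernels recover them with the `shifter` / `& maxq` shift-and-mask. The short, standalone Python sketch below (illustrative only, not part of the patch) packs one word the same way and unpacks it again, which is a convenient CPU-side sanity check of the layout.

bits = 4                   # one of the 2/4/8-bit cases QuantLinear supports
maxq = (1 << bits) - 1     # same mask the kernels apply after shifting
per_word = 32 // bits      # values packed into each int32 word

# one packed word's worth of quantized values, each in [0, maxq]
values = [3, 15, 0, 7, 9, 12, 1, 5]
assert len(values) == per_word and all(0 <= v <= maxq for v in values)

# pack: the same shift-accumulate loop as QuantLinear.pack
qword = 0
for j, v in enumerate(values):
    qword |= v << (bits * j)

# unpack: the same `(word >> shifter) & maxq` that the Triton kernels use
unpacked = [(qword >> (bits * j)) & maxq for j in range(per_word)]
assert unpacked == values

The zero points in `qzeros` use the same per-word layout along the out-features axis, which is why the kernels derive a separate `zeros_shifter` from the output-feature offsets.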
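
To make the buffer shapes registered in `QuantLinear.__init__` concrete, here is a rough size estimate for a hypothetical 4096 x 4096 projection with 4-bit weights and `groupsize=128` (the dimensions are an example, not taken from the patch):

import math

bits, groupsize = 4, 128
infeatures = outfeatures = 4096          # example dimensions only

qweight_shape = (infeatures // 32 * bits, outfeatures)                        # int32
qzeros_shape = (math.ceil(infeatures / groupsize), outfeatures // 32 * bits)  # int32
scales_shape = (math.ceil(infeatures / groupsize), outfeatures)               # float16

packed_bytes = 4 * (qweight_shape[0] * qweight_shape[1] + qzeros_shape[0] * qzeros_shape[1]) \
               + 2 * scales_shape[0] * scales_shape[1]
fp16_bytes = 2 * infeatures * outfeatures
print(f"{packed_bytes / 2**20:.1f} MiB packed vs {fp16_bytes / 2**20:.1f} MiB in float16")
# -> 8.3 MiB packed vs 32.0 MiB in float16

That is roughly the 4x reduction expected from 4-bit weights, with a small overhead for one float16 scale and one packed zero point per group of 128 input features.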
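
Finally, a sketch of how `make_quant` is intended to be wired in: it walks a module tree and swaps the named `nn.Linear` attributes for empty `QuantLinear` shells, which are then filled either by `pack` (after a GPTQ-style quantization pass) or by loading already-quantized weights. The toy module and the import path below are assumptions made for illustration; the patch itself only adds the functions.

import torch.nn as nn

# assumed location of the code added in this patch; adjust to the real module path
from xturing.engines.quant_utils import QuantLinear, make_quant

class ToyBlock(nn.Module):
    def __init__(self, hidden=128):
        super().__init__()
        self.q_proj = nn.Linear(hidden, hidden, bias=False)
        self.up_proj = nn.Linear(hidden, 4 * hidden, bias=False)

    def forward(self, x):
        return self.up_proj(self.q_proj(x))

block = ToyBlock()

# dotted names of the Linear layers to replace
names = {name for name, module in block.named_modules() if isinstance(module, nn.Linear)}

# swap them for 4-bit QuantLinear shells; the packed weights still have to be
# produced by pack() or loaded from a quantized checkpoint before use
make_quant(block, names, bits=4, groupsize=128)

assert isinstance(block.q_proj, QuantLinear) and isinstance(block.up_proj, QuantLinear)

Once the quantized weights are on the GPU, `autotune_warmup(model)` can be called once so that Triton's autotuner does not pay its configuration search during the first real forward (and, with `transpose=True`, backward) passes.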