From 642728bc34402edb51a2f05c2b2fc45f642e33f4 Mon Sep 17 00:00:00 2001 From: Tushar Date: Thu, 15 Jun 2023 14:12:00 +0000 Subject: [PATCH 01/18] feat: added hooks to remove unused imports and absolufy imports added the autoflake hook for unused imports and absolufy-imports to make all the imports absolute --- .pre-commit-config.yaml | 8 ++++++++ requirements-dev.txt | 2 ++ 2 files changed, 10 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2404b16..f1c9c39 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,3 +21,11 @@ repos: rev: v0.19.1 hooks: - id: gitlint + - repo: https://github.com/PyCQA/autoflake + rev: v2.1.1 + hooks: + - id: autoflake + - repo: https://github.com/MarcoGorelli/absolufy-imports + rev: v0.3.1 + hooks: + - id: absolufy-imports diff --git a/requirements-dev.txt b/requirements-dev.txt index 51f1982..96b437d 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,2 +1,4 @@ pre-commit pytest +autoflake +absolufy-imports From ad468e421606cc500c1cd59e4d182bc5d0a5d134 Mon Sep 17 00:00:00 2001 From: Tushar Date: Thu, 15 Jun 2023 14:29:37 +0000 Subject: [PATCH 02/18] fix: applied autoflake on all the files After applying, manually checked the soundness of the removal --- src/xturing/cli/chat.py | 1 - src/xturing/datasets/instruction_dataset.py | 1 - src/xturing/engines/generic_engine.py | 1 - src/xturing/engines/llama_engine.py | 2 +- src/xturing/engines/llama_utils/llama.py | 3 --- src/xturing/engines/opt_engine.py | 1 - src/xturing/models/causal.py | 3 +-- src/xturing/models/llama.py | 16 ++++++++++------ src/xturing/preprocessors/base.py | 1 - src/xturing/registry.py | 2 +- .../self_instruct/bootstrap_instructions.py | 1 - src/xturing/self_instruct/generate_instances.py | 1 - .../self_instruct/identify_if_classification.py | 1 - .../self_instruct/prepare_for_finetuning.py | 1 - src/xturing/self_instruct/prepare_seed_tasks.py | 2 -- src/xturing/trainers/lightning_trainer.py | 7 ++----- src/xturing/utils/text_splitter.py | 16 +++------------- 17 files changed, 18 insertions(+), 42 deletions(-) diff --git a/src/xturing/cli/chat.py b/src/xturing/cli/chat.py index 2588a5e..f3cfdea 100644 --- a/src/xturing/cli/chat.py +++ b/src/xturing/cli/chat.py @@ -1,4 +1,3 @@ -import time from pathlib import Path import click diff --git a/src/xturing/datasets/instruction_dataset.py b/src/xturing/datasets/instruction_dataset.py index 213a7fd..8c90f56 100644 --- a/src/xturing/datasets/instruction_dataset.py +++ b/src/xturing/datasets/instruction_dataset.py @@ -1,5 +1,4 @@ import json -import os from dataclasses import dataclass from pathlib import Path from typing import List, Optional, Union diff --git a/src/xturing/engines/generic_engine.py b/src/xturing/engines/generic_engine.py index 958725d..85c9767 100644 --- a/src/xturing/engines/generic_engine.py +++ b/src/xturing/engines/generic_engine.py @@ -1,4 +1,3 @@ -import os from pathlib import Path from typing import List, Optional, Union diff --git a/src/xturing/engines/llama_engine.py b/src/xturing/engines/llama_engine.py index 2456a51..48749b8 100644 --- a/src/xturing/engines/llama_engine.py +++ b/src/xturing/engines/llama_engine.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Optional, Union import torch import transformers diff --git a/src/xturing/engines/llama_utils/llama.py b/src/xturing/engines/llama_utils/llama.py index 57a89f9..6835116 100644 ---
a/src/xturing/engines/llama_utils/llama.py +++ b/src/xturing/engines/llama_utils/llama.py @@ -1,6 +1,5 @@ import math import os -from pathlib import Path from shutil import copyfile from typing import Any, Dict, List, Optional, Tuple, Union @@ -18,8 +17,6 @@ from transformers.modeling_utils import PreTrainedModel from transformers.tokenization_utils import PreTrainedTokenizer -from xturing.engines.causal import CausalEngine, CausalLoraEngine - # Tokenizer taken from transformers library: https://github.com/huggingface/transformers """Tokenization classes for LLaMA.""" diff --git a/src/xturing/engines/opt_engine.py b/src/xturing/engines/opt_engine.py index 49e5167..b6de1ae 100644 --- a/src/xturing/engines/opt_engine.py +++ b/src/xturing/engines/opt_engine.py @@ -1,4 +1,3 @@ -import os from pathlib import Path from typing import Optional, Union diff --git a/src/xturing/models/causal.py b/src/xturing/models/causal.py index 24d0c75..88b4f5b 100644 --- a/src/xturing/models/causal.py +++ b/src/xturing/models/causal.py @@ -1,6 +1,6 @@ import json from pathlib import Path -from typing import Iterable, List, Optional, Type, Union +from typing import Iterable, List, Optional, Union import torch from pytorch_lightning.loggers import Logger @@ -13,7 +13,6 @@ from xturing.datasets.instruction_dataset import InstructionDataset from xturing.datasets.text_dataset import TextDataset from xturing.engines.base import BaseEngine -from xturing.engines.causal import CausalLoraEngine from xturing.models import BaseModel from xturing.preprocessors.base import BasePreprocessor from xturing.trainers.base import BaseTrainer diff --git a/src/xturing/models/llama.py b/src/xturing/models/llama.py index 85278b4..902fde9 100644 --- a/src/xturing/models/llama.py +++ b/src/xturing/models/llama.py @@ -1,12 +1,15 @@ -from typing import Iterable, List, Optional, Union +from typing import Iterable, Optional, Union + from pytorch_lightning.loggers import Logger +from xturing.datasets.instruction_dataset import InstructionDataset +from xturing.datasets.text_dataset import TextDataset from xturing.engines.llama_engine import ( LLamaEngine, LLamaInt8Engine, LlamaLoraEngine, - LlamaLoraInt8Engine, LlamaLoraInt4Engine, + LlamaLoraInt8Engine, ) from xturing.models.causal import ( CausalInt8Model, @@ -15,8 +18,6 @@ CausalModel, ) from xturing.trainers.base import BaseTrainer -from xturing.datasets.instruction_dataset import InstructionDataset -from xturing.datasets.text_dataset import TextDataset from xturing.trainers.lightning_trainer import LightningTrainer @@ -51,8 +52,11 @@ def __init__(self, weights_path: Optional[str] = None): class LlamaLoraInt4(CausalLoraInt8Model): config_name: str = "llama_lora_int4" - def _make_trainer(self, dataset: Union[TextDataset, InstructionDataset], - logger: Union[Logger, Iterable[Logger], bool] = True): + def _make_trainer( + self, + dataset: Union[TextDataset, InstructionDataset], + logger: Union[Logger, Iterable[Logger], bool] = True, + ): return BaseTrainer.create( LightningTrainer.config_name, self.engine, diff --git a/src/xturing/preprocessors/base.py b/src/xturing/preprocessors/base.py index da6ef33..ed16f0b 100644 --- a/src/xturing/preprocessors/base.py +++ b/src/xturing/preprocessors/base.py @@ -1,4 +1,3 @@ -from xturing.models.stable_diffusion import StableDiffusion from xturing.preprocessors.instruction_collator import InstructionDataCollator from xturing.preprocessors.text_collator import TextDataCollator from xturing.registry import BaseParent diff --git a/src/xturing/registry.py 
b/src/xturing/registry.py index 8d0c584..4137658 100644 --- a/src/xturing/registry.py +++ b/src/xturing/registry.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any class BaseParent: diff --git a/src/xturing/self_instruct/bootstrap_instructions.py b/src/xturing/self_instruct/bootstrap_instructions.py index 099015d..53dc6d1 100644 --- a/src/xturing/self_instruct/bootstrap_instructions.py +++ b/src/xturing/self_instruct/bootstrap_instructions.py @@ -1,5 +1,4 @@ import json -import os import random import re import string diff --git a/src/xturing/self_instruct/generate_instances.py b/src/xturing/self_instruct/generate_instances.py index d478a44..b82d340 100644 --- a/src/xturing/self_instruct/generate_instances.py +++ b/src/xturing/self_instruct/generate_instances.py @@ -1,5 +1,4 @@ import json -import os import random from collections import OrderedDict from pathlib import Path diff --git a/src/xturing/self_instruct/identify_if_classification.py b/src/xturing/self_instruct/identify_if_classification.py index 1cf7974..f0b85b5 100644 --- a/src/xturing/self_instruct/identify_if_classification.py +++ b/src/xturing/self_instruct/identify_if_classification.py @@ -1,5 +1,4 @@ import json -import os import random from collections import OrderedDict from pathlib import Path diff --git a/src/xturing/self_instruct/prepare_for_finetuning.py b/src/xturing/self_instruct/prepare_for_finetuning.py index a249057..9e51e5a 100644 --- a/src/xturing/self_instruct/prepare_for_finetuning.py +++ b/src/xturing/self_instruct/prepare_for_finetuning.py @@ -1,5 +1,4 @@ import json -import os import random import re from pathlib import Path diff --git a/src/xturing/self_instruct/prepare_seed_tasks.py b/src/xturing/self_instruct/prepare_seed_tasks.py index 64d3a89..be646b2 100644 --- a/src/xturing/self_instruct/prepare_seed_tasks.py +++ b/src/xturing/self_instruct/prepare_seed_tasks.py @@ -1,7 +1,5 @@ -import ast import json import os -from typing import List from tqdm import tqdm diff --git a/src/xturing/trainers/lightning_trainer.py b/src/xturing/trainers/lightning_trainer.py index 71eedcd..af8c38f 100644 --- a/src/xturing/trainers/lightning_trainer.py +++ b/src/xturing/trainers/lightning_trainer.py @@ -1,16 +1,13 @@ import datetime -import os -import tempfile -import uuid from pathlib import Path -from typing import Iterable, Optional, Union, Type +from typing import Iterable, Optional, Union import pytorch_lightning as pl import torch from deepspeed.ops.adam import DeepSpeedCPUAdam from pytorch_lightning import callbacks -from pytorch_lightning.trainer.trainer import Trainer from pytorch_lightning.loggers import Logger +from pytorch_lightning.trainer.trainer import Trainer from xturing.config import DEFAULT_DEVICE, IS_INTERACTIVE from xturing.datasets.base import BaseDataset diff --git a/src/xturing/utils/text_splitter.py b/src/xturing/utils/text_splitter.py index 9a8bba8..247fd7e 100644 --- a/src/xturing/utils/text_splitter.py +++ b/src/xturing/utils/text_splitter.py @@ -2,19 +2,9 @@ """Functionality for splitting text.""" from __future__ import annotations -import copy import logging from abc import ABC, abstractmethod -from typing import ( - AbstractSet, - Any, - Callable, - Collection, - Iterable, - List, - Optional, - Union, -) +from typing import Any, Callable, Iterable, List, Optional logger = logging.getLogger() @@ -117,8 +107,8 @@ def _huggingface_tokenizer_length(text: str) -> int: def from_tiktoken_encoder( cls, encoding_name: str = "gpt2", - allowed_special = set(), - disallowed_special 
= set(), + allowed_special=set(), + disallowed_special=set(), **kwargs: Any, ) -> TextSplitter: """Text splitter that uses tiktoken encoder to count length.""" From cc723269c7fd00475b8604d6878e06a6ee3100f6 Mon Sep 17 00:00:00 2001 From: Tushar Date: Thu, 15 Jun 2023 14:33:34 +0000 Subject: [PATCH 03/18] feat: added args for handling errors without these args, there can be issues --- .pre-commit-config.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f1c9c39..688dbc6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -25,6 +25,10 @@ repos: rev: v2.1.1 hooks: - id: autoflake + args: + - "--in-place" + - "--remove-all-unused-imports" + - "--ignore-init-module-imports" - repo: https://github.com/MarcoGorelli/absolufy-imports rev: v0.3.1 hooks: From 4bb4d016db6688532fd01f98f2da773cde7e2746 Mon Sep 17 00:00:00 2001 From: Tushar Date: Thu, 15 Jun 2023 14:50:24 +0000 Subject: [PATCH 04/18] feat: testing the pre-commit hooks Changed the imports in models/falcon.py to relative to test --- .pre-commit-config.yaml | 6 ++---- src/xturing/engines/lora_engine/__init__.py | 6 +++++- src/xturing/engines/quant_utils/__init__.py | 2 +- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 688dbc6..64bee29 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -25,11 +25,9 @@ repos: rev: v2.1.1 hooks: - id: autoflake - args: - - "--in-place" - - "--remove-all-unused-imports" - - "--ignore-init-module-imports" + args: ["--in-place", "--remove-all-unused-imports", "--ignore-init-module-imports"] - repo: https://github.com/MarcoGorelli/absolufy-imports rev: v0.3.1 hooks: - id: absolufy-imports + args: ["--application-directories=src"] diff --git a/src/xturing/engines/lora_engine/__init__.py b/src/xturing/engines/lora_engine/__init__.py index f41b4c6..0266937 100644 --- a/src/xturing/engines/lora_engine/__init__.py +++ b/src/xturing/engines/lora_engine/__init__.py @@ -1 +1,5 @@ -from .lora import LoraConfig, LoraModel, prepare_model_for_int8_training +from xturing.engines.lora_engine.lora import ( + LoraConfig, + LoraModel, + prepare_model_for_int8_training, +) diff --git a/src/xturing/engines/quant_utils/__init__.py b/src/xturing/engines/quant_utils/__init__.py index 11d0eb5..d45d5f8 100644 --- a/src/xturing/engines/quant_utils/__init__.py +++ b/src/xturing/engines/quant_utils/__init__.py @@ -1 +1 @@ -from .quant import make_quant, autotune_warmup, QuantLinear \ No newline at end of file +from xturing.engines.quant_utils.quant import QuantLinear, autotune_warmup, make_quant From 7720c085c1166762edf72f29ee0372d832f0c4fb Mon Sep 17 00:00:00 2001 From: Tushar Date: Thu, 15 Jun 2023 15:03:18 +0000 Subject: [PATCH 05/18] fix: ran absolufy for all the files all the imports in the library are now absolute --- src/xturing/datasets/__init__.py | 11 +++--- src/xturing/engines/__init__.py | 37 +++++++++++++------ src/xturing/engines/llama_utils/__init__.py | 6 +++- src/xturing/model_apis/__init__.py | 16 ++++----- src/xturing/models/__init__.py | 40 ++++++++++++++------- src/xturing/models/distilgpt2.py | 3 +- src/xturing/models/gpt2.py | 8 +++-- src/xturing/preprocessors/__init__.py | 6 ++-- src/xturing/trainers/__init__.py | 4 +-- 9 files changed, 86 insertions(+), 45 deletions(-) diff --git a/src/xturing/datasets/__init__.py b/src/xturing/datasets/__init__.py index 73a80df..296795f 100644 --- a/src/xturing/datasets/__init__.py +++ 
b/src/xturing/datasets/__init__.py @@ -1,7 +1,10 @@ -from .base import BaseDataset -from .instruction_dataset import InstructionDataset, InstructionDatasetMeta -from .text2image_dataset import Text2ImageDataset -from .text_dataset import TextDataset, TextDatasetMeta +from xturing.datasets.base import BaseDataset +from xturing.datasets.instruction_dataset import ( + InstructionDataset, + InstructionDatasetMeta, +) +from xturing.datasets.text2image_dataset import Text2ImageDataset +from xturing.datasets.text_dataset import TextDataset, TextDatasetMeta BaseDataset.add_to_registry(TextDataset.config_name, TextDataset) BaseDataset.add_to_registry(InstructionDataset.config_name, InstructionDataset) diff --git a/src/xturing/engines/__init__.py b/src/xturing/engines/__init__.py index 56a9192..5691654 100644 --- a/src/xturing/engines/__init__.py +++ b/src/xturing/engines/__init__.py @@ -1,45 +1,60 @@ -from .base import BaseEngine -from .bloom_engine import ( +from xturing.engines.base import BaseEngine +from xturing.engines.bloom_engine import ( BloomEngine, BloomInt8Engine, BloomLoraEngine, BloomLoraInt8Engine, ) -from .cerebras_engine import ( +from xturing.engines.cerebras_engine import ( CerebrasEngine, CerebrasInt8Engine, CerebrasLoraEngine, CerebrasLoraInt8Engine, ) -from .distilgpt2_engine import DistilGPT2Engine, DistilGPT2LoraEngine -from .falcon_engine import ( +from xturing.engines.distilgpt2_engine import DistilGPT2Engine, DistilGPT2LoraEngine +from xturing.engines.falcon_engine import ( FalconEngine, FalconInt8Engine, FalconLoraEngine, FalconLoraInt8Engine, ) -from .galactica_engine import ( +from xturing.engines.galactica_engine import ( GalacticaEngine, GalacticaInt8Engine, GalacticaLoraEngine, GalacticaLoraInt8Engine, ) -from .generic_engine import ( +from xturing.engines.generic_engine import ( GenericEngine, GenericInt8Engine, GenericLoraEngine, GenericLoraInt8Engine, ) -from .gpt2_engine import GPT2Engine, GPT2Int8Engine, GPT2LoraEngine, GPT2LoraInt8Engine -from .gptj_engine import GPTJEngine, GPTJInt8Engine, GPTJLoraEngine, GPTJLoraInt8Engine -from .llama_engine import ( +from xturing.engines.gpt2_engine import ( + GPT2Engine, + GPT2Int8Engine, + GPT2LoraEngine, + GPT2LoraInt8Engine, +) +from xturing.engines.gptj_engine import ( + GPTJEngine, + GPTJInt8Engine, + GPTJLoraEngine, + GPTJLoraInt8Engine, +) +from xturing.engines.llama_engine import ( LLamaEngine, LLamaInt8Engine, LlamaLoraEngine, LlamaLoraInt4Engine, LlamaLoraInt8Engine, ) -from .opt_engine import OPTEngine, OPTInt8Engine, OPTLoraEngine, OPTLoraInt8Engine +from xturing.engines.opt_engine import ( + OPTEngine, + OPTInt8Engine, + OPTLoraEngine, + OPTLoraInt8Engine, +) BaseEngine.add_to_registry(DistilGPT2Engine.config_name, DistilGPT2Engine) BaseEngine.add_to_registry(DistilGPT2LoraEngine.config_name, DistilGPT2LoraEngine) diff --git a/src/xturing/engines/llama_utils/__init__.py b/src/xturing/engines/llama_utils/__init__.py index 4c9a60b..a6cf48d 100644 --- a/src/xturing/engines/llama_utils/__init__.py +++ b/src/xturing/engines/llama_utils/__init__.py @@ -1 +1,5 @@ -from .llama import LlamaConfig, LlamaForCausalLM, LlamaTokenizer +from xturing.engines.llama_utils.llama import ( + LlamaConfig, + LlamaForCausalLM, + LlamaTokenizer, +) diff --git a/src/xturing/model_apis/__init__.py b/src/xturing/model_apis/__init__.py index cee1ee9..eccca01 100644 --- a/src/xturing/model_apis/__init__.py +++ b/src/xturing/model_apis/__init__.py @@ -1,11 +1,11 @@ -from .ai21 import AI21TextGenerationAPI -from .ai21 import J2Grande as 
AI21J2Grande -from .base import BaseApi, TextGenerationAPI -from .cohere import CohereTextGenerationAPI -from .cohere import Medium as CohereMedium -from .openai import ChatGPT as OpenAIChatGPT -from .openai import Davinci as OpenAIDavinci -from .openai import OpenAITextGenerationAPI +from xturing.model_apis.ai21 import AI21TextGenerationAPI +from xturing.model_apis.ai21 import J2Grande as AI21J2Grande +from xturing.model_apis.base import BaseApi, TextGenerationAPI +from xturing.model_apis.cohere import CohereTextGenerationAPI +from xturing.model_apis.cohere import Medium as CohereMedium +from xturing.model_apis.openai import ChatGPT as OpenAIChatGPT +from xturing.model_apis.openai import Davinci as OpenAIDavinci +from xturing.model_apis.openai import OpenAITextGenerationAPI BaseApi.add_to_registry(OpenAITextGenerationAPI.config_name, OpenAITextGenerationAPI) BaseApi.add_to_registry(CohereTextGenerationAPI.config_name, CohereTextGenerationAPI) diff --git a/src/xturing/models/__init__.py b/src/xturing/models/__init__.py index 69299ac..7aa967f 100644 --- a/src/xturing/models/__init__.py +++ b/src/xturing/models/__init__.py @@ -1,20 +1,36 @@ -from .base import BaseModel -from .bloom import Bloom, BloomInt8, BloomLora, BloomLoraInt8 -from .cerebras import Cerebras, CerebrasInt8, CerebrasLora, CerebrasLoraInt8 -from .distilgpt2 import DistilGPT2, DistilGPT2Lora -from .falcon import Falcon, FalconInt8, FalconLora, FalconLoraInt8 -from .galactica import Galactica, GalacticaInt8, GalacticaLora, GalacticaLoraInt8 -from .generic import ( +from xturing.models.base import BaseModel +from xturing.models.bloom import Bloom, BloomInt8, BloomLora, BloomLoraInt8 +from xturing.models.cerebras import ( + Cerebras, + CerebrasInt8, + CerebrasLora, + CerebrasLoraInt8, +) +from xturing.models.distilgpt2 import DistilGPT2, DistilGPT2Lora +from xturing.models.falcon import Falcon, FalconInt8, FalconLora, FalconLoraInt8 +from xturing.models.galactica import ( + Galactica, + GalacticaInt8, + GalacticaLora, + GalacticaLoraInt8, +) +from xturing.models.generic import ( GenericInt8Model, GenericLoraInt8Model, GenericLoraModel, GenericModel, ) -from .gpt2 import GPT2, GPT2Int8, GPT2Lora, GPT2LoraInt8 -from .gptj import GPTJ, GPTJInt8, GPTJLora, GPTJLoraInt8 -from .llama import Llama, LlamaInt8, LlamaLora, LlamaLoraInt4, LlamaLoraInt8 -from .opt import OPT, OPTInt8, OPTLora, OPTLoraInt8 -from .stable_diffusion import StableDiffusion +from xturing.models.gpt2 import GPT2, GPT2Int8, GPT2Lora, GPT2LoraInt8 +from xturing.models.gptj import GPTJ, GPTJInt8, GPTJLora, GPTJLoraInt8 +from xturing.models.llama import ( + Llama, + LlamaInt8, + LlamaLora, + LlamaLoraInt4, + LlamaLoraInt8, +) +from xturing.models.opt import OPT, OPTInt8, OPTLora, OPTLoraInt8 +from xturing.models.stable_diffusion import StableDiffusion BaseModel.add_to_registry(DistilGPT2.config_name, DistilGPT2) BaseModel.add_to_registry(DistilGPT2Lora.config_name, DistilGPT2Lora) diff --git a/src/xturing/models/distilgpt2.py b/src/xturing/models/distilgpt2.py index 5a1a356..a0d8c38 100644 --- a/src/xturing/models/distilgpt2.py +++ b/src/xturing/models/distilgpt2.py @@ -1,8 +1,7 @@ from typing import Optional from xturing.engines.distilgpt2_engine import DistilGPT2Engine, DistilGPT2LoraEngine - -from .causal import CausalLoraModel, CausalModel +from xturing.models.causal import CausalLoraModel, CausalModel class DistilGPT2(CausalModel): diff --git a/src/xturing/models/gpt2.py b/src/xturing/models/gpt2.py index 4f7c753..2aa99b4 100644 --- a/src/xturing/models/gpt2.py 
+++ b/src/xturing/models/gpt2.py @@ -6,8 +6,12 @@ GPT2LoraEngine, GPT2LoraInt8Engine, ) - -from .causal import CausalInt8Model, CausalLoraInt8Model, CausalLoraModel, CausalModel +from xturing.models.causal import ( + CausalInt8Model, + CausalLoraInt8Model, + CausalLoraModel, + CausalModel, +) class GPT2(CausalModel): diff --git a/src/xturing/preprocessors/__init__.py b/src/xturing/preprocessors/__init__.py index 473b519..d978454 100644 --- a/src/xturing/preprocessors/__init__.py +++ b/src/xturing/preprocessors/__init__.py @@ -1,3 +1,3 @@ -from .base import BasePreprocessor -from .instruction_collator import InstructionDataCollator -from .text_collator import TextDataCollator +from xturing.preprocessors.base import BasePreprocessor +from xturing.preprocessors.instruction_collator import InstructionDataCollator +from xturing.preprocessors.text_collator import TextDataCollator diff --git a/src/xturing/trainers/__init__.py b/src/xturing/trainers/__init__.py index a36f345..9ccb5e9 100644 --- a/src/xturing/trainers/__init__.py +++ b/src/xturing/trainers/__init__.py @@ -1,2 +1,2 @@ -from .base import BaseTrainer -from .lightning_trainer import LightningTrainer +from xturing.trainers.base import BaseTrainer +from xturing.trainers.lightning_trainer import LightningTrainer From 6c3a0a540a3c3381737bfcbf58f967714dab8d4c Mon Sep 17 00:00:00 2001 From: Tushar Date: Mon, 10 Jul 2023 09:01:57 +0000 Subject: [PATCH 06/18] feat: evaluation of causal models Added the flow for evaluting causal models on any dataset. Testing in progress --- src/xturing/evalutaion/api.py | 267 +++++++++++++++ src/xturing/evalutaion/base.py | 150 +++++++++ src/xturing/evalutaion/data.py | 195 +++++++++++ src/xturing/evalutaion/eval.py | 151 +++++++++ src/xturing/evalutaion/evaluate.py | 197 ++++++++++++ src/xturing/evalutaion/match.py | 42 +++ src/xturing/evalutaion/metrics.py | 76 +++++ src/xturing/evalutaion/models.py | 268 +++++++++++++++ src/xturing/evalutaion/prompt.py | 119 +++++++ src/xturing/evalutaion/record.py | 501 +++++++++++++++++++++++++++++ src/xturing/evalutaion/registry.py | 179 +++++++++++ src/xturing/evalutaion/utils.py | 129 ++++++++ src/xturing/models/causal.py | 94 +++++- src/xturing/utils/metrics.py | 54 ++++ src/xturing/utils/prompt.py | 65 ++++ src/xturing/utils/utils.py | 16 + 16 files changed, 2497 insertions(+), 6 deletions(-) create mode 100644 src/xturing/evalutaion/api.py create mode 100644 src/xturing/evalutaion/base.py create mode 100644 src/xturing/evalutaion/data.py create mode 100644 src/xturing/evalutaion/eval.py create mode 100644 src/xturing/evalutaion/evaluate.py create mode 100644 src/xturing/evalutaion/match.py create mode 100644 src/xturing/evalutaion/metrics.py create mode 100644 src/xturing/evalutaion/models.py create mode 100644 src/xturing/evalutaion/prompt.py create mode 100644 src/xturing/evalutaion/record.py create mode 100644 src/xturing/evalutaion/registry.py create mode 100644 src/xturing/evalutaion/utils.py create mode 100644 src/xturing/utils/metrics.py create mode 100644 src/xturing/utils/prompt.py diff --git a/src/xturing/evalutaion/api.py b/src/xturing/evalutaion/api.py new file mode 100644 index 0000000..dd24c66 --- /dev/null +++ b/src/xturing/evalutaion/api.py @@ -0,0 +1,267 @@ +""" +This file provides common interfaces and utilities used by eval creators to +sample from models and process the results. 
+""" + +import logging +from typing import Callable, Dict, List, Optional, Tuple, Union + +from .base import ModelSpec +from .prompt import ( + ChatCompletionPrompt, + CompletionPrompt, + OpenAICreateChatPrompt, + OpenAICreatePrompt, + Prompt, +) +from .record import record_match, record_sampling +from .utils import ( + openai_chat_completion_create_retrying, + openai_completion_create_retrying, +) + +logger = logging.getLogger(__name__) + + +def completion_query( + model_spec: ModelSpec, + prompt: Union[OpenAICreatePrompt, OpenAICreateChatPrompt, Prompt], + **kwargs, +) -> Tuple[dict, Union[OpenAICreatePrompt, OpenAICreateChatPrompt], dict]: + """ + Query the API for a completion. + + ARGS + ==== + `model_spec`: `ModelSpec` containing model details to use in the query. + This should be the dict returned by `registry.get_model()`. + If `model_spec` is not provided, we use the default model that was + intialized at the beginning of the run. + `prompt`: Either a `Prompt` object or a raw prompt that will get wrapped in + the approriate `Prompt` class. + `kwargs`: Other arguments passed to the API. + + RETURNS + ======= + The result of the API call. + The prompt that was fed into the API call as a str. + A dict containing metadata about the query. + """ + if not isinstance(prompt, Prompt): + assert ( + isinstance(prompt, str) + or ( + isinstance(prompt, list) + and all(isinstance(token, int) for token in prompt) + ) + or ( + isinstance(prompt, list) + and all(isinstance(token, str) for token in prompt) + ) + or ( + isinstance(prompt, list) + and all(isinstance(msg, dict) for msg in prompt) + ) + ), f"Got type {type(prompt)}, with val {type(prompt[0])} for prompt, expected str or list[int] or list[str] or list[dict[str, str]]" + + if model_spec.is_chat: + prompt = ChatCompletionPrompt( + raw_prompt=prompt, + ) + else: + prompt = CompletionPrompt( + raw_prompt=prompt, + ) + + openai_create_prompt: Union[ + OpenAICreatePrompt, OpenAICreateChatPrompt + ] = prompt.to_openai_create_prompt() + + if model_spec.is_chat: + result = openai_chat_completion_create_retrying( + model=model_spec.model, + api_key=model_spec.api_key, + messages=openai_create_prompt, + **{**kwargs, **model_spec.extra_options}, + ) + else: + result = openai_completion_create_retrying( + model=model_spec.model, + api_key=model_spec.api_key, + prompt=openai_create_prompt, + **{**kwargs, **model_spec.extra_options}, + ) + + metadata = {} + if result: + metadata["completion_id"] = result.get("id", None) + metadata["model"] = result.get("model", None) + + if model_spec.is_chat: + for choice in result["choices"]: + choice["text"] = choice["message"]["content"] + + return result, openai_create_prompt, metadata + + +def check_sampled_text( + model_spec: ModelSpec, + prompt: Union[OpenAICreatePrompt, OpenAICreateChatPrompt, Prompt], + expected: Union[str, List[str], Tuple[str]], + *, + options: Optional[List[str]] = None, + separator: Callable[[str], bool] = None, +) -> Optional[str]: + """ + Generates a completion using the prompt, checks whether the completion is + one of the expected completions, and then records the result. + + ARGS + ==== + `model_spec`: See `completion_query`. + `prompt`: See `completion_query`. + `options`: The list of canonical options, defaults to `expected` if None. + The completion will be converted to one of these options. + `expected`: The desired completion or the list of desired completions. + `separator`: A callable which check the character sampled after the option + to see if it is a valid separator. 
+ + RETURNS + ======= + The option that was picked, i.e., matched the completion, or None. + """ + if isinstance(expected, tuple): + expected = list(expected) + elif not isinstance(expected, list): + expected = [expected] + if options is None: + options = expected + + result, actual_prompt, metadata = completion_query( + prompt=prompt, + temperature=0.0, + model_spec=model_spec, + ) + choice = result["choices"][0] + + sampled = choice["text"].strip() if model_spec.strip_completion else choice["text"] + + picked = None + for option in options: + if not sampled.startswith(option): + continue + if ( + separator is not None + and len(sampled) > len(option) + and not separator(sampled[len(option)]) + ): + continue + picked = option + break + + result = { + "prompt": actual_prompt, + "sampled": sampled, + "options": options, + "picked": picked, + } + match = picked in expected + result["expected"] = expected + result["match"] = match + result["metadata"] = metadata + record_sampling(**result) + record_match(match, expected=expected, picked=picked, sampled=sampled) + return picked + + +def sample_freeform( + model_spec: ModelSpec, + prompt: Union[OpenAICreatePrompt, OpenAICreateChatPrompt, Prompt], + *, + temperature: float = 1.0, + top_p: float = 0.9, + max_tokens: int = 512, + stop: Optional[str] = None, + n_samples: int = None, + return_logprobs: bool = False, + **kwargs, +) -> Union[str, List[str], dict]: + """ + Samples a freeform response from the specified model, records the sampling, + and returns the sampled text. + + ARGS + ==== + `model_spec`: See `completion_query`. + `prompt`: See `completion_query`. + `temperature`: Passed to `openai.Completion.create`. + `top_p`: Passed to `openai.Completion.create`. + `max_tokens`: Passed to `openai.Completion.create`. + `stop`: Passed to `openai.Completion.create`. + `n_samples`: The number of samples to generate (1 if None). + `return_logprobs`: If True, returns the tokens and corresponding logprobs + in addition to the sampled text. + `kwargs`: See `completion_query`. + + RETURNS + ======= + If `return_logprobs` is True, returns a dict with the sampled text, tokens, + and corresponding logprobs. If `n_samples` is None, the outer list is + removed from all values. + Otherwise, returns the sampled text, or a list of sampled texts if + `n_samples` is not None. 
+ """ + response, actual_prompt, metadata = completion_query( + prompt=prompt, + temperature=temperature, + top_p=top_p, + max_tokens=max_tokens, + stop=stop, + n=(1 if n_samples is None else n_samples), + model_spec=model_spec, + headers={}, + **kwargs, + ) + sampled = [choice["text"] for choice in response["choices"]] + if n_samples is None: + sampled = sampled[0] + record_sampling(prompt=actual_prompt, sampled=sampled, metadata=metadata) + + if return_logprobs: + assert not model_spec.is_chat, "logprobs only works for non-chat models" + assert not kwargs.get("logprobs") is None + + def _maybe_tokens(logprobs: Optional[dict]) -> Optional[List[str]]: + return logprobs["tokens"] if logprobs is not None else None + + def _maybe_logprobs(logprobs: Optional[dict]) -> Optional[List[float]]: + return logprobs["token_logprobs"] if logprobs is not None else None + + def _maybe_top_logprobs( + logprobs: Optional[dict], + ) -> Optional[List[Dict[str, float]]]: + return ( + [dict(x) for x in logprobs["top_logprobs"]] + if logprobs is not None + else None + ) + + tokens = [_maybe_tokens(choice["logprobs"]) for choice in response["choices"]] + logprobs = [ + _maybe_logprobs(choice["logprobs"]) for choice in response["choices"] + ] + top_logprobs = [ + _maybe_top_logprobs(choice["logprobs"]) for choice in response["choices"] + ] + if n_samples is None: + tokens = tokens[0] + logprobs = logprobs[0] + top_logprobs = top_logprobs[0] + return { + "text": sampled, + "tokens": tokens, + "logprobs": logprobs, + "top_logprobs": top_logprobs, + } + + return sampled diff --git a/src/xturing/evalutaion/base.py b/src/xturing/evalutaion/base.py new file mode 100644 index 0000000..97bb60e --- /dev/null +++ b/src/xturing/evalutaion/base.py @@ -0,0 +1,150 @@ +""" +This file defines the base specifications for models, evals, and runs. Running +evals and most development work should not require familiarity with this file. +""" +import base64 +import datetime +import os +from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Sequence + +if TYPE_CHECKING: + from dataclasses import dataclass +else: + from pydantic.dataclasses import dataclass + + +@dataclass +# class ModelSpec: +# """ +# Specification for a model. +# """ + +# name: str +# model: Optional[str] = None + +# is_chat: bool = False + +# encoding: Optional[str] = None +# organization: Optional[str] = None +# api_key: Optional[str] = None +# extra_options: Optional[Mapping[str, Any]] = None +# headers: Optional[Mapping[str, Any]] = None +# strip_completion: bool = True +# n_ctx: Optional[int] = None +# format: Optional[str] = None +# key: Optional[str] = None +# group: Optional[str] = None + +# def __post_init__(self): +# if self.extra_options is None: +# self.extra_options = {} +# if self.headers is None: +# self.headers = {} + +# if self.model is None: +# raise ValueError(f"Must specify a model") + + +@dataclass +# class BaseEvalSpec: +# """ +# Specification for a base eval. +# """ + +# id: Optional[str] = None +# metrics: Optional[Sequence[str]] = None +# description: Optional[str] = None +# disclaimer: Optional[str] = None + +# """ +# True if higher values are better, False if lower values are better. +# This should really be part of a metric, but it's easier to put it here. +# """ +# higher_is_better: bool = True + +# key: Optional[str] = None +# group: Optional[str] = None + + +@dataclass +class EvalSpec: + """ + Specification for an eval. 
+ """ + + cls: str + args: Optional[Dict[str, Any]] = None + key: Optional[str] = None + group: Optional[str] = None + + +@dataclass +class EvalSetSpec: + """ + Specification for an eval set. + """ + + evals: Sequence[str] + key: Optional[str] = None + group: Optional[str] = None + + +@dataclass +# class ModelSpecs: +# completions_: Optional[Sequence[ModelSpec]] = None +# embedding_: Optional[ModelSpec] = None +# ranking_: Optional[ModelSpec] = None + +# @property +# def embedding(self) -> ModelSpec: +# if self.embedding_ is None: +# raise ValueError("Embedding model was not specified") +# return self.embedding_ + +# @property +# def ranking(self) -> ModelSpec: +# if self.ranking_ is None: +# raise ValueError("Ranking model was not specified") +# return self.ranking_ + +# @property +# def completion(self) -> ModelSpec: +# if self.completions_ is None: +# raise ValueError("Completion model was not specified") +# return self.completions_[0] + +# @property +# def completions(self) -> Sequence[ModelSpec]: +# if self.completions_ is None: +# raise ValueError("Completion model was not specified") +# return self.completions_ + +# @property +# def names(self) -> Dict[str, Sequence[str]]: +# dict = {} +# if self.completions_ is not None: +# dict["completions"] = [model.name for model in self.completions_] +# if self.embedding_ is not None: +# dict["embedding"] = [self.embedding_.name] +# if self.ranking_ is not None: +# dict["ranking"] = [self.ranking_.name] +# return dict + + +@dataclass +class RunSpec: + model_name: str + model_names: Dict[str, Sequence[str]] + eval_name: str + base_eval: str + split: str + run_config: Dict[str, Any] + created_by: str + run_id: str = None + created_at: str = None + + def __post_init__(self): + now = datetime.datetime.utcnow() + rand_suffix = base64.b32encode(os.urandom(5)).decode("ascii") + self.run_id = now.strftime("%y%m%d%H%M%S") + rand_suffix + self.created_at = str(now) diff --git a/src/xturing/evalutaion/data.py b/src/xturing/evalutaion/data.py new file mode 100644 index 0000000..5def56a --- /dev/null +++ b/src/xturing/evalutaion/data.py @@ -0,0 +1,195 @@ +""" +This file defines utilities for working with data and files of various types. +""" +import csv +import dataclasses +import gzip +import itertools +import json +import logging +import os +import urllib + +# from collections.abc import Iterator +from functools import partial +from typing import Any, Dict, Iterator, List, Sequence, Union + +import blobfile as bf +import lz4.frame +import pydantic +import pyzstd + +logger = logging.getLogger(__name__) + + +def gzip_open(filename: str, mode: str = "rb", openhook: Any = open) -> gzip.GzipFile: + """Wrap the given openhook in gzip.""" + if mode and "b" not in mode: + mode += "b" + + return gzip.GzipFile(fileobj=openhook(filename, mode), mode=mode) + + +def lz4_open( + filename: str, mode: str = "rb", openhook: Any = open +) -> lz4.frame.LZ4FrameFile: + if mode and "b" not in mode: + mode += "b" + + return lz4.frame.LZ4FrameFile(openhook(filename, mode), mode=mode) + + +def zstd_open(filename: str, mode: str = "rb", openhook: Any = open) -> pyzstd.ZstdFile: + if mode and "b" not in mode: + mode += "b" + + return pyzstd.ZstdFile(openhook(filename, mode), mode=mode) + + +def open_by_file_pattern(filename: str, mode: str = "r", **kwargs: Any) -> Any: + """Can read/write to files on gcs/local with or without gzipping. If file + is stored on gcs, streams with blobfile. Otherwise use vanilla python open. 
If + filename endswith gz, then zip/unzip contents on the fly (note that gcs paths and + gzip are compatible)""" + open_fn = partial(bf.BlobFile, **kwargs) + try: + if filename.endswith(".gz"): + return gzip_open(filename, openhook=open_fn, mode=mode) + elif filename.endswith(".lz4"): + return lz4_open(filename, openhook=open_fn, mode=mode) + elif filename.endswith(".zst"): + return zstd_open(filename, openhook=open_fn, mode=mode) + else: + scheme = urllib.parse.urlparse(filename).scheme + if scheme == "" or scheme == "file": + return open_fn( + os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "registry", + "data", + filename, + ), + mode=mode, + ) + else: + return open_fn(filename, mode=mode) + except Exception as e: + raise RuntimeError(f"Failed to open: {filename}") from e + + +def _get_jsonl_file(path): + logger.info(f"Fetching {path}") + with open_by_file_pattern(path, mode="r") as f: + return list(map(json.loads, f.readlines())) + + +def _get_json_file(path): + logger.info(f"Fetching {path}") + with open_by_file_pattern(path, mode="r") as f: + return json.loads(f.read()) + + +def _stream_jsonl_file(path) -> Iterator: + logger.info(f"Streaming {path}") + with bf.BlobFile(path, "r", streaming=True) as f: + for line in f: + yield json.loads(line) + + +def get_lines(path) -> List[dict]: + """ + Get a list of lines from a file. + """ + with open_by_file_pattern(path, mode="r") as f: + return f.readlines() + + +def get_jsonl(path: str) -> List[dict]: + """ + Extract json lines from the given path. + If the path is a directory, look in subpaths recursively. + + Return all lines from all jsonl files as a single list. + """ + if bf.isdir(path): + result = [] + for filename in bf.listdir(path): + if filename.endswith(".jsonl"): + result += get_jsonl(os.path.join(path, filename)) + return result + return _get_jsonl_file(path) + + +def get_jsonls(paths: Sequence[str], line_limit=None) -> List[dict]: + return list(iter_jsonls(paths, line_limit)) + + +def get_json(path) -> dict: + if bf.isdir(path): + raise ValueError("Path is a directory, only files are supported") + return _get_json_file(path) + + +def iter_jsonls(paths: Union[str, List[str]], line_limit=None) -> Iterator[dict]: + """ + For each path in the input, iterate over the jsonl files in that path. + Look in subdirectories recursively. + + Use an iterator to conserve memory. 
+ """ + if type(paths) == str: + paths = [paths] + + def _iter(): + for path in paths: + if bf.isdir(path): + for filename in bf.listdir(path): + if filename.endswith(".jsonl"): + yield from iter_jsonls([os.path.join(path, filename)]) + else: + yield from _stream_jsonl_file(path) + + return itertools.islice(_iter(), line_limit) + + +def get_csv(path, fieldnames=None): + with bf.BlobFile(path, "r", cache_dir="/tmp/bf_cache", streaming=False) as f: + reader = csv.DictReader(f, fieldnames=fieldnames) + return [row for row in reader] + + +def _to_py_types(o: Any) -> Any: + if isinstance(o, dict): + return {k: _to_py_types(v) for k, v in o.items()} + if isinstance(o, list): + return [_to_py_types(v) for v in o] + + if dataclasses.is_dataclass(o): + return _to_py_types(dataclasses.asdict(o)) + + # pydantic data classes + if isinstance(o, pydantic.BaseModel): + return json.loads(o.json()) + + return o + + +class EnhancedJSONEncoder(json.JSONEncoder): + def default(self, o: Any) -> str: + return _to_py_types(o) + + +def jsondumps(o: Any, ensure_ascii: bool = False, **kwargs: Any) -> str: + return json.dumps(o, cls=EnhancedJSONEncoder, ensure_ascii=ensure_ascii, **kwargs) + + +def jsondump(o: Any, fp: Any, ensure_ascii: bool = False, **kwargs: Any) -> None: + json.dump(o, fp, cls=EnhancedJSONEncoder, ensure_ascii=ensure_ascii, **kwargs) + + +def jsonloads(s: str, **kwargs: Any) -> Any: + return json.loads(s, **kwargs) + + +def jsonload(fp: Any, **kwargs: Any) -> Any: + return json.load(fp, **kwargs) diff --git a/src/xturing/evalutaion/eval.py b/src/xturing/evalutaion/eval.py new file mode 100644 index 0000000..7c46e1b --- /dev/null +++ b/src/xturing/evalutaion/eval.py @@ -0,0 +1,151 @@ +""" +This file defines the base class for evals. +""" +import abc +import asyncio +import concurrent.futures +import logging +import os +import random +from multiprocessing.pool import ThreadPool +from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple + +from tqdm import tqdm + +from .base import ModelSpec, ModelSpecs +from .record import RecorderBase +from .registry import Registry + +logger = logging.getLogger(__name__) + + +SHUFFLE_SEED = 123 +_MAX_SAMPLES = None + + +def _index_samples(samples: List[Any]) -> List[Tuple[Any, int]]: + """Shuffle `samples` and pair each sample with its index.""" + indices = list(range(len(samples))) + random.Random(SHUFFLE_SEED).shuffle(indices) + if _MAX_SAMPLES is not None: + indices = indices[:_MAX_SAMPLES] + logger.info(f"Evaluating {len(indices)} samples") + work_items = [(samples[i], i) for i in indices] + return work_items + + +def set_max_samples(max_samples: int): + global _MAX_SAMPLES + _MAX_SAMPLES = max_samples + + +class Eval(abc.ABC): + """ + Evaluation classes generally should override two methods: + `eval_sample`: Takes in a test sample and a random number generator and + records the metrics of interest. + `run`: Takes in a recorder and runs the evaluation. Generally, most `run` + methods will follow this same pattern: loading the data, calling + `eval_all_samples`, and aggregating the recorded results. + """ + + def __init__( + self, + model_specs, + seed: int = 20220722, + name: str = "no_name_eval.default", + registry: Optional[Registry] = None, + ): + splits = name.split(".") + if len(splits) < 2: + raise ValueError( + f"Eval name must at least have .. 
Got name {name}" + ) + + self.model_specs = model_specs + self.seed = seed + self.name = name + self.registry = registry or Registry() + + def eval_sample(self, sample: Any, rng: random.Random): + raise NotImplementedError() + + @abc.abstractmethod + def run(self, recorder: RecorderBase) -> Dict[str, float]: + """Run the evaluation with the corresponding recorder.""" + raise NotImplementedError() + + async def async_eval_all_samples( + self, + eval_fn: Callable[[Tuple[Any, int]], Awaitable[Tuple[int, Any]]], + samples: List[Any], + concurrency: int = 32, + show_progress: bool = True, + ): + work_items = _index_samples(samples) + semaphore = asyncio.Semaphore(concurrency) + + async def eval_fn_with_semaphore(args): + async with semaphore: + return await eval_fn(args) + + futures = [ + asyncio.ensure_future(eval_fn_with_semaphore(args)) for args in work_items + ] + + for future in tqdm( + asyncio.as_completed(futures), total=len(samples), disable=not show_progress + ): + await future + + def eval_all_samples( + self, + recorder: RecorderBase, + samples, + show_progress=True, + ): + """ + Evaluate all provided samples in parallel. + """ + work_items = _index_samples(samples) + threads = int(os.environ.get("EVALS_THREADS", "10")) + show_progress = bool(os.environ.get("EVALS_SHOW_EVAL_PROGRESS", show_progress)) + timeout = float(os.environ.get("EVALS_THREAD_TIMEOUT", "40")) + + def eval_sample(args): + """ + Evaluate a single sample. + """ + sample, idx = args + base_name, split = self.name.split(".")[0:2] + sample_id = f"{base_name}.{split}.{idx}" + with recorder.as_default_recorder(sample_id): + recorder.record_raw(sample) + seed = f"{sample_id}:{self.seed}".encode("utf-8") + rng = random.Random(seed) + return idx, self.eval_sample(sample, rng) + + def worker_thread(args): + """ + Worker thread for evaluating a single sample. 
+ """ + while True: + executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + future = executor.submit(eval_sample, args=args) + try: + result = future.result(timeout=timeout) + return result + except concurrent.futures.TimeoutError as e: + executor.shutdown(wait=False) + + with ThreadPool(threads) as pool: + if os.environ.get("EVALS_SEQUENTIAL", "0") in {"1", "true", "yes"}: + logger.info(f"Running in sequential mode!") + iter = map(eval_sample, work_items) + else: + logger.info(f"Running in threaded mode with {threads} threads!") + iter = pool.imap_unordered(worker_thread, work_items) + idx_and_result = list( + tqdm(iter, total=len(work_items), disable=not show_progress) + ) + return [r for _, r in sorted(idx_and_result)] diff --git a/src/xturing/evalutaion/evaluate.py b/src/xturing/evalutaion/evaluate.py new file mode 100644 index 0000000..4b27253 --- /dev/null +++ b/src/xturing/evalutaion/evaluate.py @@ -0,0 +1,197 @@ +import argparse +import logging +import shlex +import sys +from functools import cached_property +from typing import Any, Mapping, Optional + +import openai + +from .base import EvalSpec, RunSpec +from .record import DummyRecorder, LocalRecorder, Recorder +from .registry import Registry, registry + +logger = logging.getLogger(__name__) + + +def _purple(str): + return f"\033[1;35m{str}\033[0m" + + +def run_evaluation(args): + if args.debug: + logging.getLogger().setLevel(logging.DEBUG) + + model = args.model + + run_config = { + "model": model, + "eval": args.eval_spec, + "seed": args.seed, + } + + model_name = model.config_name + eval_name = args.eval_spec.key + + run_spec = RunSpec( + model_name=model_name, + eval_name=eval_name, + base_eval=eval_name.split(".")[0], + split=eval_name.split(".")[1], + run_config=run_config, + created_by=args.user, + run_id="something", + ) + if args.record_path is None: + record_path = f"/tmp/evallogs/{run_spec.run_id}_{args.model}_{args.eval}.jsonl" + else: + record_path = args.record_path + + # Recording progress + if args.dry_run: + recorder = DummyRecorder(run_spec=run_spec, log=args.dry_run_logging) + elif args.local_run: + recorder = LocalRecorder(record_path, run_spec=run_spec) + else: + recorder = Recorder(record_path, run_spec=run_spec) + + api_extra_options = {} + if not args.cache: + api_extra_options["cache_level"] = 0 + + run_url = f"{run_spec.run_id}" + logger.info(_purple(f"Run started: {run_url}")) + + def parse_extra_eval_params(param_str: Optional[str]) -> Mapping[str, Any]: + """Parse a string of the form "key1=value1,key2=value2" into a dict.""" + if not param_str: + return {} + + def to_number(x): + try: + return int(x) + except: + pass + try: + return float(x) + except: + pass + return x + + str_dict = dict(kv.split("=") for kv in param_str.split(",")) + return {k: to_number(v) for k, v in str_dict.items()} + + extra_eval_params = parse_extra_eval_params(args.extra_eval_params) + + eval_class = registry.get_class(args.eval_spec) + eval = eval_class( + model_specs=model, + seed=args.seed, + name=eval_name, + registry=registry, + **extra_eval_params, + ) + result = eval.run(recorder) + recorder.record_final_report(result) + + if not (args.dry_run or args.local_run): + logger.info(_purple(f"Run completed: {run_url}")) + + logger.info("Final report:") + for key, value in result.items(): + logger.info(f"{key}: {value}") + return run_spec.run_id + + +def evaluate( + model: str, + eval: str, + embedding_model: str = "", + ranking_model: str = "", + extra_eval_params: str = "", + max_samples: Optional[int] = None, 
+ cache: bool = True, + visible: Optional[bool] = None, + seed: int = 20220722, + user: str = "", + record_path: Optional[str] = None, + log_to_file: Optional[str] = None, + debug: bool = False, + local_run: bool = True, + dry_run: bool = False, + dry_run_logging: bool = True, +) -> Any: + parser = argparse.ArgumentParser(description="Run evals through the API") + parser.add_argument("model", type=str, help="Name of a completion model.") + parser.add_argument("eval", type=str, help="Name of an eval. See registry.") + parser.add_argument("--embedding_model", type=str, default="") + parser.add_argument("--ranking_model", type=str, default="") + parser.add_argument("--extra_eval_params", type=str, default="") + parser.add_argument("--max_samples", type=int, default=None) + parser.add_argument("--cache", action=argparse.BooleanOptionalAction, default=True) + parser.add_argument( + "--visible", action=argparse.BooleanOptionalAction, default=None + ) + parser.add_argument("--seed", type=int, default=20220722) + parser.add_argument("--user", type=str, default="") + parser.add_argument("--record_path", type=str, default=None) + parser.add_argument( + "--log_to_file", type=str, default=None, help="Log to a file instead of stdout" + ) + parser.add_argument("--debug", action=argparse.BooleanOptionalAction, default=False) + parser.add_argument( + "--local-run", action=argparse.BooleanOptionalAction, default=True + ) + parser.add_argument( + "--dry-run", action=argparse.BooleanOptionalAction, default=False + ) + parser.add_argument( + "--dry-run-logging", action=argparse.BooleanOptionalAction, default=True + ) + + args = argparse.Namespace( + model=model, + eval=eval, + embedding_model=embedding_model, + ranking_model=ranking_model, + extra_eval_params=extra_eval_params, + max_samples=max_samples, + cache=cache, + visible=visible, + seed=seed, + user=user, + record_path=record_path, + log_to_file=log_to_file, + debug=debug, + local_run=local_run, + dry_run=dry_run, + dry_run_logging=dry_run_logging, + ) + + # args_parsed = parser.parse_args() + + # Running evaluation code + logging.basicConfig( + format="[%(asctime)s] [%(filename)s:%(lineno)d] %(message)s", + level=logging.INFO, + filename=args.log_to_file if args.log_to_file else None, + ) + + logging.getLogger("openai").setLevel(logging.WARN) + if hasattr(openai.error, "set_display_cause"): + openai.error.set_display_cause() + + run_evaluation(args) + + +#################################### +# EXAMPLE USAGE: + +# evaluate( +# model_name="davinci", +# eval="test", +# embedding_model="", +# ranking_model="", +# extra_eval_params="", +# max_samples=None, +# ) diff --git a/src/xturing/evalutaion/match.py b/src/xturing/evalutaion/match.py new file mode 100644 index 0000000..7dda0a1 --- /dev/null +++ b/src/xturing/evalutaion/match.py @@ -0,0 +1,42 @@ +from typing import Any + +from .data import get_jsonl +from .eval import Eval +from .metrics import get_accuracy +from .models import check_sampled_text +from .prompt import is_chat_prompt + + +class Match(Eval): + def __init__( + self, + model_specs, + samples_jsonl: str, + *args, + max_tokens: int = 500, + num_few_shot: int = 0, + few_shot_jsonl: str = None, + **kwargs, + ): + super().__init__(model_specs, *args, **kwargs) + self.max_tokens = max_tokens + self.samples_jsonl = samples_jsonl + + def eval_sample(self, sample: Any, *_): + prompt = sample["input"] + if self.num_few_shot > 0: + assert is_chat_prompt(sample["input"]), "few shot requires chat prompt" + prompt = sample["input"][:-1] + for s in 
self.few_shot[: self.num_few_shot]: + prompt += s["sample"] + prompt += sample["input"][-1:] + + return check_sampled_text(self.model_spec, prompt, expected=sample["ideal"]) + + def run(self, recorder): + samples = get_jsonl(self.samples_jsonl) + self.eval_all_samples(recorder, samples) + events = recorder.get_events("match") + return { + "accuracy": get_accuracy(events), + } diff --git a/src/xturing/evalutaion/metrics.py b/src/xturing/evalutaion/metrics.py new file mode 100644 index 0000000..6f46144 --- /dev/null +++ b/src/xturing/evalutaion/metrics.py @@ -0,0 +1,76 @@ +""" +This file defines various common metrics of interest. +""" +import random +from typing import Optional, Sequence, Set + +import numpy as np + +from .record import Event + + +def get_accuracy(events: Sequence[Event]) -> float: + num_correct = 0 + num_total = 0 + for event in events: + num_total += 1 + num_correct += int(event.data["correct"]) + if num_total == 0: + return float("nan") + else: + return num_correct / num_total + + +def get_bootstrap_accuracy_std(events: Sequence[Event], num_samples: int = 1000): + vals = [m.data["correct"] for m in events] + return np.std([np.mean(random.sample(vals, len(vals) // 2)) for _ in range(1000)]) + + +def get_confusion_matrix( + matches: Sequence[Event], class_labels: Optional[Set] = None +) -> np.ndarray: + labels = set() + for match in matches: + labels.add(match.data["expected"]) + if class_labels is None: + labels = {label: i for i, label in enumerate(sorted(labels))} + else: + assert labels.issubset(class_labels) + labels = {label: i for i, label in enumerate(class_labels)} + result = np.zeros((len(labels), len(labels) + 1), dtype=int) + for match in matches: + i = labels[match.data["expected"]] + j = labels.get(match.data["picked"], len(labels)) + result[i, j] += 1 + return result + + +def compute_matthew_corr(confusion_matrix): + assert confusion_matrix.shape == (2, 3), f"Got shape: {confusion_matrix.shape}" + r = confusion_matrix[:, :2] + r[:, 0] += confusion_matrix[:, 2] + return (r[1, 1] * r[0, 0] - r[1, 0] * r[0, 1]) / np.sqrt( + r[1, :].sum() * r[0, :].sum() * r[:, 0].sum() * r[:, 1].sum() + ) + + +def compute_precision(confusion_matrix, idx=0): + return confusion_matrix[idx, idx] / confusion_matrix[:, idx].sum() + + +def compute_recall(confusion_matrix, idx=0): + return confusion_matrix[idx, idx] / confusion_matrix[idx, :].sum() + + +def compute_f_score(confusion_matrix, idx=0, beta=1.0): + precision = compute_precision(confusion_matrix, idx=idx) + recall = compute_recall(confusion_matrix, idx=idx) + return (1 + beta**2) * (precision * recall) / (beta**2 * precision + recall) + + +def compute_averaged_f_score(confusion_matrix, beta=1.0, average="macro"): + assert average in ["macro"] + f_scores = [] + for i in range(confusion_matrix.shape[0]): + f_scores.append(compute_f_score(confusion_matrix, idx=i, beta=beta)) + return np.array(f_scores).mean() diff --git a/src/xturing/evalutaion/models.py b/src/xturing/evalutaion/models.py new file mode 100644 index 0000000..d17f75d --- /dev/null +++ b/src/xturing/evalutaion/models.py @@ -0,0 +1,268 @@ +# import os + +# from dotenv import load_dotenv, find_dotenv + +""" +This file provides common interfaces and utilities used by eval creators to +sample from models and process the results. 
+""" + +import logging +import os +from typing import Callable, List, Optional, Tuple, Union + +from transformers import AutoModelForCausalLM, AutoTokenizer + +from xturing.datasets.instruction_dataset import InstructionDataset +from xturing.models.base import BaseModel + +from .base import ModelSpec +from .prompt import ( + ChatCompletionPrompt, + CompletionPrompt, + OpenAICreateChatPrompt, + OpenAICreatePrompt, + Prompt, +) +from .record import record_match, record_sampling + +logger = logging.getLogger(__name__) + +# # load openai key +# load_dotenv(find_dotenv()) +# OPENAI_KEY = os.environ["OPENAI_KEY"] + +# HELPER FUNCTIONS + + +def chat_prompt_to_text(prompt): + if type(prompt) == str: + return prompt + else: + return " ".join([message["content"] for message in prompt]) + + +def load_model(model_name): + if not os.path.exists(f"./{model_name}"): + print(f"LOADING MODEL: {model_name}") + model = BaseModel.create(model_name) + model.save(f"./{model_name}") + + return BaseModel.load(f"./{model_name}") + + +def completion_query( + model_spec: ModelSpec, + prompt: Union[OpenAICreatePrompt, OpenAICreateChatPrompt, Prompt], + **kwargs, +) -> Tuple[dict, Union[OpenAICreatePrompt, OpenAICreateChatPrompt], dict]: + """ + Query the API for a completion. + + ARGS + ==== + `model_spec`: `ModelSpec` containing model details to use in the query. + This should be the dict returned by `registry.get_model()`. + If `model_spec` is not provided, we use the default model that was + intialized at the beginning of the run. + `prompt`: Either a `Prompt` object or a raw prompt that will get wrapped in + the approriate `Prompt` class. + `kwargs`: Other arguments passed to the API. + + RETURNS + ======= + The result of the API call. + The prompt that was fed into the API call as a str. + A dict containing metadata about the query. + """ + + # parse prompt + + # Initialize model + # TODO: pass kwargs to model! + + # model = AutoModelForCausalLM.from_pretrained(model_spec.name) + + # huggingface_models = ["gpt2"] + + # if model_spec.name in huggingface_models: + # model = AutoModelForCausalLM.from_pretrained(model_spec.name) + # else: + # model = BaseModel.load(model_spec.name) + # tokenizer = AutoTokenizer.from_pretrained(model_spec.name, return_tensors="pt") + + # TODO: is concatenating the contents a good solution to transform chat-style inputs to one string? + + # inputs = tokenizer(actual_prompt, return_tensors="pt").input_ids + + # Run completion + # outputs = model.generate( + # input_ids=inputs, return_dict_in_generate=True, output_scores=True, **kwargs + # ) + + actual_prompt = chat_prompt_to_text(prompt) + + # TODO add config + + model = load_model(model_spec.name) + + text_out = model.generate(texts=[actual_prompt]) + + # parse results + result = { + "text": text_out, + "tokens": None, + "logprobs": None, + } + # TODO: change metadata based on model + metadata = {"model": model_spec.name} + + return result, actual_prompt, metadata + + +def check_sampled_text( + model_spec: ModelSpec, + prompt: Union[OpenAICreatePrompt, OpenAICreateChatPrompt, Prompt], + expected: Union[str, List[str], Tuple[str]], + *, + options: Optional[List[str]] = None, + separator: Callable[[str], bool] = None, +) -> Optional[str]: + """ + Generates a completion using the prompt, checks whether the completion is + one of the expected completions, and then records the result. + + ARGS + ==== + `model_spec`: See `completion_query`. + `prompt`: See `completion_query`. 
+ `options`: The list of canonical options, defaults to `expected` if None. + The completion will be converted to one of these options. + `expected`: The desired completion or the list of desired completions. + `separator`: A callable which check the character sampled after the option + to see if it is a valid separator. + + RETURNS + ======= + The option that was picked, i.e., matched the completion, or None. + """ + if isinstance(expected, tuple): + expected = list(expected) + elif not isinstance(expected, list): + expected = [expected] + if options is None: + options = expected + + result, actual_prompt, metadata = completion_query( + prompt=prompt, + model_spec=model_spec, + ) + + choice = result["text"][0] + + # TODO: check what result is supposed to look like [from OPENAI API] + sampled = choice.strip() if model_spec.strip_completion else choice + + picked = None + for option in options: + if not sampled.startswith(option): + continue + if ( + separator is not None + and len(sampled) > len(option) + and not separator(sampled[len(option)]) + ): + continue + picked = option + break + + result = { + "prompt": actual_prompt, + "sampled": sampled, + "options": options, + "picked": picked, + } + match = picked in expected + result["expected"] = expected + result["match"] = match + result["metadata"] = metadata + print("result", result) + record_sampling(**result) + record_match(match, expected=expected, picked=picked, sampled=sampled) + return picked + + +def sample_freeform( + model_spec: ModelSpec, + prompt: Union[OpenAICreatePrompt, OpenAICreateChatPrompt, Prompt], + *, + temperature: float = 1.0, + top_p: float = 0.9, + max_tokens: int = 512, + stop: Optional[str] = None, + n_samples: int = None, + return_logprobs: bool = False, + **kwargs, +) -> Union[str, List[str], dict]: + """ + Samples a freeform response from the specified model, records the sampling, + and returns the sampled text. + + ARGS + ==== + `model_spec`: See `completion_query`. + `prompt`: See `completion_query`. + `temperature`: Passed to `openai.Completion.create`. + `top_p`: Passed to `openai.Completion.create`. + `max_tokens`: Passed to `openai.Completion.create`. + `stop`: Passed to `openai.Completion.create`. + `n_samples`: The number of samples to generate (1 if None). + `return_logprobs`: If True, returns the tokens and corresponding logprobs + in addition to the sampled text. + `kwargs`: See `completion_query`. + + RETURNS + ======= + If `return_logprobs` is True, returns a dict with the sampled text, tokens, + and corresponding logprobs. If `n_samples` is None, the outer list is + removed from all values. + Otherwise, returns the sampled text, or a list of sampled texts if + `n_samples` is not None. 
+ """ + + # TODO: add kwargs to completion query (see api.py for reference) + result, actual_prompt, metadata = completion_query( + prompt=prompt, + model_spec=model_spec, + do_sample=True, + num_return_sequences=n_samples if n_samples else 1, + max_new_tokens=max_tokens, + top_p=top_p, + ) + + if n_samples is None: + sampled = result["text"][0] + else: + sampled = result["text"] + + record_sampling(prompt=actual_prompt, sampled=sampled, metadata=metadata) + + if return_logprobs: + # assert not model_spec.is_chat, "logprobs only works for non-chat models" + # assert not kwargs.get("logprobs") is None + + tokens = result["tokens"] + logprobs = result["logprobs"] + top_logprobs = logprobs # TODO: check how to get top logprobs, for now I return all logprobs + if n_samples is None: + tokens = tokens[0] + logprobs = logprobs[0] + top_logprobs = top_logprobs[0] + return { + "text": sampled, + "tokens": tokens, + "logprobs": logprobs, + "top_logprobs": top_logprobs, + } + + return sampled diff --git a/src/xturing/evalutaion/prompt.py b/src/xturing/evalutaion/prompt.py new file mode 100644 index 0000000..438ebdb --- /dev/null +++ b/src/xturing/evalutaion/prompt.py @@ -0,0 +1,119 @@ +""" +This file defines the classes for how to manage prompts for different types of +models, i.e., "chat models" vs. "non chat models". +""" +import logging +import threading +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Dict, List, Union + +logger = logging.getLogger(__name__) +ENCODER_LOCK = threading.Lock() + +# This is an approximation to the type accepted as the `prompt` field to `openai.Completion.create` calls +OpenAICreatePrompt = Union[str, List[str], List[int], List[List[int]]] + +# This is the type accepted as the `prompt` field to `openai.ChatCompletion.create` calls +OpenAIChatMessage = Dict[ + str, str +] # A message is a dictionary with "role" and "content" keys +OpenAICreateChatPrompt = List[OpenAIChatMessage] # A chat log is a list of messages + + +def chat_prompt_to_text_prompt(prompt: OpenAICreateChatPrompt) -> str: + """ + Render a chat prompt as a text prompt. User and assistant messages are separated by newlines + and prefixed with "User: " and "Assistant: ", respectively, unless there is only one message. + System messages have no prefix. + """ + assert is_chat_prompt(prompt), f"Expected a chat prompt, got {prompt}" + chat_to_prefixes = { + # roles + "system": "", + # names + "example_user": "User: ", + "example_assistant": "Assistant: ", + } + + # For a single message, be it system, user, or assistant, just return the message + if len(prompt) == 1: + return prompt[0]["content"] + + text = "" + for msg in prompt: + role = msg["name"] if "name" in msg else msg["role"] + prefix = chat_to_prefixes.get(role, role.capitalize() + ": ") + content = msg["content"] + text += f"{prefix}{content}\n" + text += "Assistant: " + return text.lstrip() + + +def text_prompt_to_chat_prompt(prompt: str) -> OpenAICreateChatPrompt: + assert isinstance(prompt, str), f"Expected a text prompt, got {prompt}" + return [ + {"role": "system", "content": prompt}, + ] + + +@dataclass +class Prompt(ABC): + """ + A `Prompt` encapsulates everything required to present the `raw_prompt` in different formats, + e.g., a normal unadorned format vs. a chat format. 
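# Illustrative sketch, not part of the patch: the two conversion helpers in this
# prompt module in use, assuming the module above is fully importable.
chat_log = [
    {"role": "system", "content": "Be terse."},
    {"role": "user", "content": "Name the capital of France."},
]
print(chat_prompt_to_text_prompt(chat_log))
# Be terse.
# User: Name the capital of France.
# Assistant:
print(text_prompt_to_chat_prompt("Summarize the following text."))
# [{'role': 'system', 'content': 'Summarize the following text.'}]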
+ """ + + @abstractmethod + def to_openai_create_prompt(self): + """ + Return the actual data to be passed as the `prompt` field to either `openai.ChatCompletion.create`, + if the model is a chat model, or `openai.Completion.create` otherwise. + See the above types to see what each API call is able to handle. + """ + + +def is_chat_prompt(prompt: Prompt) -> bool: + return isinstance(prompt, list) and all(isinstance(msg, dict) for msg in prompt) + + +@dataclass +class CompletionPrompt(Prompt): + """ + A `Prompt` object that wraps prompts to be compatible with non chat models, which use `openai.Completion.create`. + """ + + raw_prompt: Union[OpenAICreatePrompt, OpenAICreateChatPrompt] + + def _render_chat_prompt_as_text( + self, prompt: OpenAICreateChatPrompt + ) -> OpenAICreatePrompt: + return chat_prompt_to_text_prompt(prompt) + + def to_openai_create_prompt(self) -> OpenAICreatePrompt: + if is_chat_prompt(self.raw_prompt): + return self._render_chat_prompt_as_text(self.raw_prompt) + return self.raw_prompt + + +@dataclass +class ChatCompletionPrompt(Prompt): + """ + A `Prompt` object that wraps prompts to be compatible with chat models, which use `openai.ChatCompletion.create`. + + The format expected by chat models is a list of messages, where each message is a dict with "role" and "content" keys. + """ + + raw_prompt: Union[OpenAICreatePrompt, OpenAICreateChatPrompt] + + def _render_text_as_chat_prompt(self, prompt: str) -> OpenAICreateChatPrompt: + """ + Render a text string as a chat prompt. The default option we adopt here is to simply take the full prompt + and treat it as a system message. + """ + return text_prompt_to_chat_prompt(prompt) + + def to_openai_create_prompt(self) -> OpenAICreateChatPrompt: + if is_chat_prompt(self.raw_prompt): + return self.raw_prompt + return self._render_text_as_chat_prompt(self.raw_prompt) diff --git a/src/xturing/evalutaion/record.py b/src/xturing/evalutaion/record.py new file mode 100644 index 0000000..8a07e74 --- /dev/null +++ b/src/xturing/evalutaion/record.py @@ -0,0 +1,501 @@ +""" +This file defines the recorder classes which log eval results in different ways, +such as to a local JSON file or to a remote Snowflake database. + +If you would like to implement a custom recorder, you can see how the +`LocalRecorder` and `Recorder` classes inherit from the `RecorderBase` class and +override certain methods. +""" +import atexit +import contextlib +import dataclasses +import logging +import threading +import time +from contextvars import ContextVar +from datetime import datetime, timezone +from typing import Any, List, Optional, Sequence + +import blobfile as bf +import evals +from evals.base import RunSpec +from evals.data import jsondumps +from evals.utils.misc import t +from evals.utils.snowflake import SnowflakeConnection + +logger = logging.getLogger(__name__) + +MIN_FLUSH_EVENTS = 100 +MAX_SNOWFLAKE_BYTES = 16 * 10**6 +MIN_FLUSH_SECONDS = 10 + +_default_recorder: ContextVar[Optional["RecorderBase"]] = ContextVar( + "default_recorder", default=None +) + + +def default_recorder() -> Optional["RecorderBase"]: + return _default_recorder.get() + + +@dataclasses.dataclass +class Event: + run_id: str + event_id: int + sample_id: Optional[str] + type: str + data: dict + created_by: str + created_at: str + + +class RecorderBase: + """ + The standard events for which recording methods are provided are: + - `match`: A match or non match, as specified by the `correct` bool, between + the `expected` and `picked` results. 
+ - `embedding`: An embedding of the `prompt` of type `embedding_type`. + - `sampling`: What was `sampled` from the model given the input `prompt`. + - `cond_logp`: The conditional log probability, as `logp`, of the + `completion` from the model given the input `prompt`. + - `pick_option`: The option `picked` by the model out of the valid `options` + given the input `prompt`. + - `raw`: A raw sample specified by the `data`. + - `metrics`: A set of metrics specified by the `kwargs`. + - `error`: An `error` along with an accompanying `msg`. + - `extra`: Any extra `data` of interest to be recorded. + For these events, helper methods are defined at the bottom of this file. + More generally, you can record any event by calling `record_event` with the + event `type` and `data`. + Finally, you can also record a final report using `record_final_report`. + """ + + def __init__( + self, + run_spec: evals.base.RunSpec, + ) -> None: + self._sample_id: ContextVar[Optional[int]] = ContextVar( + "_sample_id", default=None + ) + self.run_spec = run_spec + self._events: List[Event] = [] + self._last_flush_time = time.time() + self._flushes_done = 0 + self._written_events = 0 + self._flushes_started = 0 + self._event_lock = threading.Lock() + atexit.register(self.flush_events) + + @contextlib.contextmanager + def as_default_recorder(self, sample_id: str): + sample_id_token = self._sample_id.set(sample_id) + default_recorder_token = _default_recorder.set(self) + yield + _default_recorder.reset(default_recorder_token) + self._sample_id.reset(sample_id_token) + + def current_sample_id(self) -> Optional[str]: + return self._sample_id.get() + + def get_events(self, type: str) -> Sequence[Event]: + with self._event_lock: + return [event for event in self._events if event.type == type] + + def get_metrics(self): + return list(map(lambda x: x.data, self.get_events("metrics"))) + + def get_scores(self, key: str): + return list(map(lambda e: e.data[key], self.get_events("metrics"))) + + def _create_event(self, type, data=None, sample_id=None): + if sample_id is None: + sample_id = self.current_sample_id() + if sample_id is None: + raise ValueError( + "No sample_id set! Either pass it in or use as_default_recorder!" + ) + + return Event( + run_id=self.run_spec.run_id, + event_id=len(self._events), + type=type, + sample_id=sample_id, + data=data, + created_by=self.run_spec.created_by, + created_at=str(datetime.now(timezone.utc)), + ) + + def _flush_events_internal(self, events_to_write: Sequence[Event]): + pass + + def flush_events(self): + with self._event_lock: + if len(self._events) == self._written_events: + return + events_to_write = self._events[self._written_events :] + self._written_events = len(self._events) + self._flushes_started += 1 + self._flush_events_internal(events_to_write) + + def record_event(self, type, data=None, sample_id=None): + if sample_id is None: + sample_id = self.current_sample_id() + if sample_id is None: + raise ValueError( + "No sample_id set! Either pass it in or use as_default_recorder!" 
+ ) + + with self._event_lock: + event = Event( + run_id=self.run_spec.run_id, + event_id=len(self._events), + type=type, + sample_id=sample_id, + data=data, + created_by=self.run_spec.created_by, + created_at=str(datetime.now(timezone.utc)), + ) + self._events.append(event) + if ( + self._flushes_done < self._flushes_started + or len(self._events) < self._written_events + MIN_FLUSH_EVENTS + or time.time() < self._last_flush_time + MIN_FLUSH_SECONDS + ): + return + events_to_write = self._events[self._written_events :] + self._written_events = len(self._events) + self._flushes_started += 1 + self._flush_events_internal(events_to_write) + + def record_match( + self, correct: bool, *, expected=None, picked=None, sample_id=None, **extra + ): + assert isinstance( + correct, bool + ), f"correct must be a bool, but was a {type(correct)}: {correct}" + + if isinstance(expected, list) and len(expected) == 1: + expected = expected[0] + data = { + "correct": bool(correct), + "expected": expected, + "picked": picked, + **extra, + } + self.record_event("match", data, sample_id=sample_id) + + def record_embedding(self, prompt, embedding_type, sample_id=None, **extra): + data = { + "prompt": prompt, + "embedding_type": embedding_type, + **extra, + } + self.record_event("embedding", data, sample_id=sample_id) + + def record_sampling(self, prompt, sampled, sample_id=None, **extra): + data = { + "prompt": prompt, + "sampled": sampled, + **extra, + } + self.record_event("sampling", data, sample_id=sample_id) + + def record_cond_logp(self, prompt, completion, logp, sample_id=None, **extra): + data = { + "prompt": prompt, + "completion": completion, + "logp": logp, + **extra, + } + self.record_event("cond_logp", data, sample_id=sample_id) + + def record_pick_option(self, prompt, options, picked, sample_id=None, **extra): + data = { + "prompt": prompt, + "options": options, + "picked": picked, + **extra, + } + self.record_event("pick_option", data, sample_id=sample_id) + + def record_raw(self, data): + self.record_event("raw_sample", data) + + def record_metrics(self, **kwargs): + self.record_event("metrics", kwargs) + + def record_error(self, msg: str, error: Exception, **kwargs): + data = { + "type": type(error).__name__, + "message": str(error), + } + data.update(kwargs) + self.record_event("error", data) + + def record_extra(self, data, sample_id=None): + self.record_event("extra", data, sample_id=sample_id) + + def record_final_report(self, final_report: Any): + logging.info(f"Final report: {final_report}. Not writing anywhere.") + + +def _green(str): + return f"\033[1;32m{str}\033[0m" + + +def _red(str): + return f"\033[1;31m{str}\033[0m" + + +class DummyRecorder(RecorderBase): + """ + A "recorder" which only logs certain events to the console. + Can be used by passing `--dry-run` when invoking `oaieval`. 
+ """ + + def __init__(self, run_spec: RunSpec, log: bool = True): + super().__init__(run_spec) + self.log = log + + def record_event(self, type, data, sample_id=None): + from evals.registry import registry + + if self.run_spec is None: + return + + base_eval_spec = registry.get_base_eval(self.run_spec.base_eval) + if base_eval_spec and len(base_eval_spec.metrics) >= 1: + primary_metric = base_eval_spec.metrics[0] + else: + primary_metric = "accuracy" + + with self._event_lock: + event = self._create_event(type, data) + self._events.append(event) + + msg = f"Not recording event: {event}" + + if type == "match": + accuracy_good = ( + primary_metric == "accuracy" or primary_metric.startswith("pass@") + ) and (data.get("correct", False) or data.get("accuracy", 0) > 0.5) + f1_score_good = ( + primary_metric == "f1_score" and data.get("f1_score", 0) > 0.5 + ) + if accuracy_good or f1_score_good: + msg = _green(msg) + else: + msg = _red(msg) + + if self.log: + logging.info(msg) + + +class LocalRecorder(RecorderBase): + """ + A recorder which logs events to the specified JSON file. + This is the default recorder used by `oaieval`. + """ + + def __init__(self, log_path: Optional[str], run_spec: RunSpec): + super().__init__(run_spec) + self.event_file_path = log_path + if log_path is not None: + with bf.BlobFile(log_path, "wb") as f: + f.write( + (jsondumps({"spec": dataclasses.asdict(run_spec)}) + "\n").encode( + "utf-8" + ) + ) + + def _flush_events_internal(self, events_to_write: Sequence[Event]): + start = time.time() + try: + lines = [jsondumps(event) + "\n" for event in events_to_write] + except TypeError as e: + logger.error(f"Failed to serialize events: {events_to_write}") + raise e + + with bf.BlobFile(self.event_file_path, "ab") as f: + f.write(b"".join([l.encode("utf-8") for l in lines])) + + logger.info( + f"Logged {len(lines)} rows of events to {self.event_file_path}: insert_time={t(time.time()-start)}" + ) + + self._last_flush_time = time.time() + self._flushes_done += 1 + + def record_final_report(self, final_report: Any): + with bf.BlobFile(self.event_file_path, "ab") as f: + f.write((jsondumps({"final_report": final_report}) + "\n").encode("utf-8")) + + logging.info(f"Final report: {final_report}. Logged to {self.event_file_path}") + + +class Recorder(RecorderBase): + """ + A recorder which logs events to Snowflake. + Can be used by passing `--no-local-run` when invoking `oaieval`. 
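# Illustrative sketch, not part of the patch: recording a few events with the
# LocalRecorder defined above. The RunSpec construction follows the dataclass shape
# used elsewhere in this patch series; all field values here are placeholders.
spec = RunSpec(
    model_name="opt",
    model_names={"completions": ["opt"]},
    eval_name="arithmetic.dev.v0",
    base_eval="arithmetic",
    split="dev",
    run_config={},
    created_by="example-user",
)
recorder = LocalRecorder("/tmp/eval_events.jsonl", run_spec=spec)
with recorder.as_default_recorder(sample_id="sample-0"):
    recorder.record_sampling(prompt="2 + 2 =", sampled="4")
    recorder.record_match(True, expected="4", picked="4")
recorder.flush_events()  # appends the buffered events as JSON lines to the log file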
+ """ + + def __init__( + self, + log_path: Optional[str], + run_spec: evals.base.RunSpec, + snowflake_connection: Optional[SnowflakeConnection] = None, + ) -> None: + super().__init__(run_spec) + self.event_file_path = log_path + self._writing_lock = threading.Lock() + + if snowflake_connection is None: + snowflake_connection = SnowflakeConnection() + self._conn = snowflake_connection + + if log_path is not None: + with bf.BlobFile(log_path, "wb") as f: + f.write( + (jsondumps({"spec": dataclasses.asdict(run_spec)}) + "\n").encode( + "utf-8" + ) + ) + + query = """ + INSERT ALL INTO runs (run_id, model_name, eval_name, base_eval, split, run_config, settings, created_by, created_at) + VALUES (%(run_id)s, %(model_name)s, %(eval_name)s, %(base_eval)s, %(split)s, run_config, settings, %(created_by)s, %(created_at)s) + SELECT PARSE_JSON(%(run_config)s) AS run_config, PARSE_JSON(%(settings)s) AS settings + """ + self._conn.robust_query( + command=query, + params={ + "run_id": run_spec.run_id, + "model_name": jsondumps(run_spec.model_names), + "eval_name": run_spec.eval_name, + "base_eval": run_spec.base_eval, + "split": run_spec.split, + "run_config": jsondumps(run_spec.run_config), + "settings": jsondumps(run_spec.run_config.get("initial_settings", {})), + "created_by": run_spec.created_by, + "created_at": run_spec.created_at, + }, + ) + atexit.register(self.flush_events) + + def _flush_events_internal(self, events_to_write: Sequence[Event]): + with self._writing_lock: + try: + lines = [jsondumps(event) + "\n" for event in events_to_write] + except TypeError as e: + logger.error(f"Failed to serialize events: {events_to_write}") + raise e + idx_l = 0 + while idx_l < len(events_to_write): + total_bytes = 0 + idx_r = idx_l + while ( + idx_r < len(events_to_write) + and total_bytes + len(lines[idx_r]) < MAX_SNOWFLAKE_BYTES + ): + total_bytes += len(lines[idx_r]) + idx_r += 1 + assert idx_r > idx_l + start = time.time() + buffer = [ + ( + event.run_id, + event.event_id, + event.sample_id, + event.type, + jsondumps(event.data), + event.created_by, + event.created_at, + ) + for event in events_to_write[idx_l:idx_r] + ] + query = """ + INSERT INTO events (run_id, event_id, sample_id, type, data, created_by, created_at) + SELECT Column1 AS run_id, Column2 as event_id, Column3 AS sample_id, Column4 AS type, PARSE_JSON(Column5) AS data, Column6 AS created_by, Column7 AS created_at + FROM VALUES(%s, %s, %s, %s, %s, %s, %s) + """ + self._conn.robust_query(command=query, seqparams=buffer, many=True) + logger.info( + f"Logged {len(buffer)} rows of events to Snowflake: insert_time={t(time.time()-start)}" + ) + idx_l = idx_r + + with bf.BlobFile(self.event_file_path, "ab") as f: + f.write(b"".join([l.encode("utf-8") for l in lines])) + self._last_flush_time = time.time() + self._flushes_done += 1 + + def record_final_report(self, final_report: Any): + with self._writing_lock: + with bf.BlobFile(self.event_file_path, "ab") as f: + f.write( + (jsondumps({"final_report": final_report}) + "\n").encode("utf-8") + ) + query = """ + UPDATE runs + SET final_report = PARSE_JSON(%(final_report)s) + WHERE run_id = %(run_id)s + """ + self._conn.robust_query( + command=query, + params={ + "run_id": self.run_spec.run_id, + "final_report": jsondumps(final_report), + }, + ) + + def record_event(self, type, data=None, sample_id=None): + # try to serialize data so we fail early! 
+ _ = jsondumps(data) + return super().record_event(type, data, sample_id) + + +######################################################################### +### Helper methods which use the thread local global default recorder ### +######################################################################### + + +def current_sample_id() -> str: + return default_recorder().current_sample_id + + +def record_match(correct: bool, *, expected=None, picked=None, **extra): + return default_recorder().record_match( + correct, expected=expected, picked=picked, **extra + ) + + +def record_embedding(prompt, embedding_type, **extra): + return default_recorder().record_embedding(prompt, embedding_type, **extra) + + +def record_sampling(prompt, sampled, **extra): + return default_recorder().record_sampling(prompt, sampled, **extra) + + +def record_cond_logp(prompt, completion, logp, **extra): + return default_recorder().record_cond_logp(prompt, completion, logp, **extra) + + +def record_pick_option(prompt, options, picked, **extra): + return default_recorder().record_pick_option(prompt, options, picked, **extra) + + +def record_raw(data): + return default_recorder().record_raw(data) + + +def record_metrics(**extra): + return default_recorder().record_metrics(**extra) + + +def record_error(msg: str, error: Exception = None, **extra): + return default_recorder().record_error(msg, error, **extra) + + +def record_extra(data): + return default_recorder().record_extra(data) diff --git a/src/xturing/evalutaion/registry.py b/src/xturing/evalutaion/registry.py new file mode 100644 index 0000000..e0ed543 --- /dev/null +++ b/src/xturing/evalutaion/registry.py @@ -0,0 +1,179 @@ +""" +Functions to handle registration of evals. To add a new eval to the registry, +add an entry in one of the YAML files in the `../registry` dir. +By convention, every eval name should start with {base_eval}.{split}. +""" + +import difflib +import functools +import logging +import os +import re +from functools import partial +from pathlib import Path +from typing import Any, Dict, Iterator, List, Sequence, Type, Union + +import yaml + +from .base import BaseEvalSpec, EvalSetSpec, EvalSpec +from .utils import make_object + +logger = logging.getLogger(__name__) + +DEFAULT_PATHS = [ + Path(__file__).parents[0].resolve() / "registry", + Path.home() / ".evals", +] + + +class Registry: + def __init__(self, registry_paths: Sequence[Union[str, Path]] = DEFAULT_PATHS): + self._registry_paths = [ + Path(p) if isinstance(p, str) else p for p in registry_paths + ] + + def make_callable(self, spec): + return partial(make_object(spec.cls).create_and_run, **(spec.args or {})) + + def get_class(self, spec: dict) -> Any: + return make_object(spec.cls, **(spec.args if spec.args else {})) + + def _dereference(self, name: str, d: dict, object: str, type: Type) -> dict: + if not name in d: + return None + + def get_alias(): + if isinstance(d[name], str): + return d[name] + if isinstance(d[name], dict) and "id" in d[name]: + return d[name]["id"] + return None + + logger.debug(f"Looking for {name}") + while True: + alias = get_alias() + + if alias is None: + break + name = alias + + spec = d[name] + + try: + return type(**spec) + except TypeError as e: + raise TypeError(f"Error while processing {object} {name}: {e}") + + def get_modelgraded_spec(self, name: str) -> Dict[str, Any]: + assert name in self._modelgraded_specs, ( + f"Modelgraded spec {name} not found. 
" + f"Closest matches: {difflib.get_close_matches(name, self._modelgraded_specs.keys(), n=5)}" + ) + return self._modelgraded_specs[name] + + def get_eval(self, name: str) -> EvalSpec: + return self._dereference(name, self._evals, "eval", EvalSpec) + + def get_eval_set(self, name: str) -> EvalSetSpec: + return self._dereference(name, self._eval_sets, "eval set", EvalSetSpec) + + def get_evals(self, patterns: Sequence[str]) -> Iterator[EvalSpec]: + # valid patterns: hello, hello.dev*, hello.dev.*-v1 + def get_regexp(pattern): + pattern = pattern.replace(".", "\\.") + pattern = pattern.replace("*", ".*") + return re.compile(f"^{pattern}$") + + regexps = list(map(get_regexp, patterns)) + for name in self._evals: + # if any regexps match, return the name + if any(map(lambda regexp: regexp.match(name), regexps)): + yield self.get_eval(name) + + def get_base_evals(self) -> List[BaseEvalSpec]: + base_evals = [] + for name, spec in self._evals.items(): + if name.count(".") == 0: + base_evals.append(self.get_base_eval(name)) + return base_evals + + def get_base_eval(self, name: str) -> BaseEvalSpec: + if not name in self._evals: + return None + + spec_or_alias = self._evals[name] + if isinstance(spec_or_alias, dict): + spec = spec_or_alias + try: + return BaseEvalSpec(**spec) + except TypeError as e: + raise TypeError(f"Error while processing base eval {name}: {e}") + + alias = spec_or_alias + return BaseEvalSpec(id=alias) + + def _process_file(self, registry, path): + with open(path, "r") as f: + d = yaml.safe_load(f) + + if d is None: + # no entries in the file + return + + for name, spec in d.items(): + assert name not in registry, f"duplicate entry: {name} from {path}" + if isinstance(spec, dict): + if "key" in spec: + raise ValueError( + f"key is a reserved keyword, but was used in {name} from {path}" + ) + if "group" in spec: + raise ValueError( + f"group is a reserved keyword, but was used in {name} from {path}" + ) + if "cls" in spec: + raise ValueError( + f"cls is a reserved keyword, but was used in {name} from {path}" + ) + + spec["key"] = name + spec["group"] = str(os.path.basename(path).split(".")[0]) + if "class" in spec: + spec["cls"] = spec["class"] + del spec["class"] + registry[name] = spec + + def _process_directory(self, registry, path): + files = Path(path).glob("*.yaml") + for file in files: + self._process_file(registry, file) + + def _load_registry(self, paths): + """Load registry from a list of paths. + + Each path or yaml specifies a dictionary of name -> spec. + """ + registry = {} + for path in paths: + logging.info(f"Loading registry from {path}") + if os.path.exists(path): + if os.path.isdir(path): + self._process_directory(registry, path) + else: + self._process_file(registry, path) + return registry + + @functools.cached_property + def _eval_sets(self): + return self._load_registry([p / "eval_sets" for p in self._registry_paths]) + + @functools.cached_property + def _evals(self): + return self._load_registry([p / "evals" for p in self._registry_paths]) + + @functools.cached_property + def _modelgraded_specs(self): + return self._load_registry([p / "modelgraded" for p in self._registry_paths]) + + +registry = Registry() diff --git a/src/xturing/evalutaion/utils.py b/src/xturing/evalutaion/utils.py new file mode 100644 index 0000000..11f7411 --- /dev/null +++ b/src/xturing/evalutaion/utils.py @@ -0,0 +1,129 @@ +# misc.py +""" +This file defines miscellanous utilities. 
+""" +import functools +import importlib +from typing import Any + + +def t(duration: float) -> str: + if duration is None: + return "n/a" + if duration < 1: + return f"{(1000*duration):0.3f}ms" + elif duration < 60: + return f"{duration:0.3f}s" + else: + return f"{duration//60}min{int(duration%60)}s" + + +def make_object(object_ref: Any, *args: Any, **kwargs: Any) -> Any: + modname, qualname_separator, qualname = object_ref.partition(":") + obj = importlib.import_module(modname) + if qualname_separator: + for attr in qualname.split("."): + obj = getattr(obj, attr) + return functools.partial(obj, *args, **kwargs) + + +# api_utils.py +""" +This file defines various helper functions for interacting with the OpenAI API. +""" +import logging + +import backoff +import openai + + +def generate_dummy_chat_completion(): + return { + "id": "dummy-id", + "object": "chat.completion", + "created": 12345, + "model": "dummy-chat", + "usage": {"prompt_tokens": 56, "completion_tokens": 6, "total_tokens": 62}, + "choices": [ + { + "message": { + "role": "assistant", + "content": "This is a dummy response.", + }, + "finish_reason": "stop", + "index": 0, + } + ], + } + + +def generate_dummy_completion(): + return { + "id": "dummy-id", + "object": "text_completion", + "created": 12345, + "model": "dummy-completion", + "choices": [ + { + "text": "This is a dummy response.", + "index": 0, + "logprobs": None, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 5, "completion_tokens": 6, "total_tokens": 11}, + } + + +@backoff.on_exception( + wait_gen=backoff.expo, + exception=( + openai.error.ServiceUnavailableError, + openai.error.APIError, + openai.error.RateLimitError, + openai.error.APIConnectionError, + openai.error.Timeout, + ), + max_value=60, + factor=1.5, +) +def openai_completion_create_retrying(*args, **kwargs): + """ + Helper function for creating a completion. + `args` and `kwargs` match what is accepted by `openai.Completion.create`. + """ + if kwargs["model"] == "dummy-completion": + return generate_dummy_completion() + + result = openai.Completion.create(*args, **kwargs) + if "error" in result: + logging.warning(result) + raise openai.error.APIError(result["error"]) + return result + + +@backoff.on_exception( + wait_gen=backoff.expo, + exception=( + openai.error.ServiceUnavailableError, + openai.error.APIError, + openai.error.RateLimitError, + openai.error.APIConnectionError, + openai.error.Timeout, + ), + max_value=60, + factor=1.5, +) +def openai_chat_completion_create_retrying(*args, **kwargs): + """ + Helper function for creating a chat completion. + `args` and `kwargs` match what is accepted by `openai.ChatCompletion.create`. 
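# Illustrative sketches, not part of the patch, for the helpers defined above.
#
# make_object resolves a "module:attr" reference and returns a functools.partial:
dump_pretty = make_object("json:dumps", indent=2)
print(dump_pretty({"a": 1}))  # pretty-printed JSON with a 2-space indent
#
# The "dummy-completion" model name short-circuits the OpenAI call and returns a
# canned response, which is handy for wiring tests that need no API key:
resp = openai_completion_create_retrying(model="dummy-completion", prompt="ping")
print(resp["choices"][0]["text"])  # "This is a dummy response."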
+ """ + if kwargs["model"] == "dummy-chat": + return generate_dummy_chat_completion() + + result = openai.ChatCompletion.create(*args, **kwargs) + if "error" in result: + logging.warning(result) + raise openai.error.APIError(result["error"]) + return result diff --git a/src/xturing/models/causal.py b/src/xturing/models/causal.py index 40457d2..9b6c6d3 100644 --- a/src/xturing/models/causal.py +++ b/src/xturing/models/causal.py @@ -1,6 +1,6 @@ import json from pathlib import Path -from typing import Iterable, List, Optional, Type, Union +from typing import Iterable, List, Optional, Tuple, Type, Union import torch from pytorch_lightning.loggers import Logger @@ -19,7 +19,16 @@ from xturing.trainers.base import BaseTrainer from xturing.trainers.lightning_trainer import LightningTrainer from xturing.utils.logging import configure_logger -from xturing.utils.utils import _filter_args +from xturing.utils.metrics import get_accuracy +from xturing.utils.prompt import ( + OpenAIChatMessage, + OpenAICreateChatPrompt, + OpenAICreatePrompt, + Prompt, + chat_prompt_to_text, + is_chat_prompt, +) +from xturing.utils.utils import _filter_args, _index_samples logger = configure_logger(__name__) @@ -112,9 +121,6 @@ def finetune( trainer = self._make_trainer(dataset, logger) trainer.fit() - def evaluate(self, dataset: Union[TextDataset, InstructionDataset]): - pass - def _generate_from_iterable( self, data_iterator: Iterable, do_tokenization=False, show_tqdm_bar=True ): @@ -206,6 +212,82 @@ def save(self, path: Union[str, Path]): self.engine.save(path) self._save_config(path=path) + def completion_query( + self, prompt: Union[OpenAICreatePrompt, OpenAICreateChatPrompt, Prompt] + ): + actual_prompt = chat_prompt_to_text(prompt) + + text_out = self.model.generate(texts=[actual_prompt]) + + # parse results + # result = { + # "text": text_out, + # "tokens": None, + # "logprobs": None, + # } + + return text_out, actual_prompt + + def check_sampled_text( + self, + prompt: Union[OpenAICreatePrompt, OpenAICreateChatPrompt, Prompt], + expected: Union[str, List[str], Tuple[str]], + *, + options: Optional[List[str]] = None, + ) -> Optional[str]: + if isinstance(expected, tuple): + expected = list(expected) + elif not isinstance(expected, list): + expected = [expected] + if options is None: + options = expected + + output, actual_prompt = self.completion_query(prompt=prompt) + + choice = output[0] + + picked = sampled = choice.strip() + + result = { + "prompt": actual_prompt, + "sampled": sampled, + "options": options, + "picked": picked, + } + result["expected"] = expected + result["match"] = picked in expected + return result + + def eval_sample(self, sample, *args): + prompt = sample["input"] + return self.check_sampled_text(prompt, expected=sample["ideal"]) + + def eval_all_samples( + self, + samples, + show_progress=True, + ): + """ + Evaluate all provided samples in parallel. 
+ """ + work_items = _index_samples(samples, logger) + show_progress = show_progress + + def eval_sample(args): + sample, idx = args + return idx, self.eval_sample(sample) + + logger.info(f"Running in sequential mode!") + iter = map(eval_sample, work_items) + idx_and_result = list( + tqdm(iter, total=len(work_items), disable=not show_progress) + ) + return [r for _, r in sorted(idx_and_result)] + + def evaluate(self, dataset: Union[TextDataset, InstructionDataset]): + outputs = self.eval_all_samples(dataset) + return get_accuracy(outputs) + class CausalInt8Model(CausalModel): def __init__( @@ -275,7 +357,7 @@ def __init__( target_modules=target_modules, **kwargs, ) - + class CausalLoraKbitModel(CausalLoraModel): def __init__(self, engine: str, weights_path: Optional[str] = None): diff --git a/src/xturing/utils/metrics.py b/src/xturing/utils/metrics.py new file mode 100644 index 0000000..d5c5b8b --- /dev/null +++ b/src/xturing/utils/metrics.py @@ -0,0 +1,54 @@ +from typing import Dict, Optional, Sequence, Set + +import numpy as np + + +def get_accuracy(outputs) -> float: + num_correct = 0 + num_total = 0 + for output in outputs: + num_total += 1 + num_correct += int(output["match"]) + if num_total == 0: + return float("nan") + else: + return num_correct / num_total + + +def get_confusion_matrix(outputs: Sequence[Dict], class_labels: Optional[Set] = None): + labels = set() + for r in outputs: + labels.add(r["expected"]) + if class_labels is None: + labels = {label: i for i, label in enumerate(sorted(labels))} + else: + assert labels.issubset(class_labels) + labels = {label: i for i, label in enumerate(class_labels)} + result = np.zeros((len(labels), len(labels) + 1), dtype=int) + for r in outputs: + i = labels[r["expected"]] + j = labels.get(r["picked"], len(labels)) + result[i, j] += 1 + return result + + +def compute_precision(confusion_matrix, idx=0): + return confusion_matrix[idx, idx] / confusion_matrix[:, idx].sum() + + +def compute_recall(confusion_matrix, idx=0): + return confusion_matrix[idx, idx] / confusion_matrix[idx, :].sum() + + +def compute_f_score(confusion_matrix, idx=0, beta=1.0): + precision = compute_precision(confusion_matrix, idx=idx) + recall = compute_recall(confusion_matrix, idx=idx) + return (1 + beta**2) * (precision * recall) / (beta**2 * precision + recall) + + +def compute_averaged_f_score(confusion_matrix, beta=1.0, average="macro"): + assert average in ["macro"] + f_scores = [] + for i in range(confusion_matrix.shape[0]): + f_scores.append(compute_f_score(confusion_matrix, idx=i, beta=beta)) + return np.array(f_scores).mean() diff --git a/src/xturing/utils/prompt.py b/src/xturing/utils/prompt.py new file mode 100644 index 0000000..429663c --- /dev/null +++ b/src/xturing/utils/prompt.py @@ -0,0 +1,65 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Dict, List, Union + +OpenAICreatePrompt = Union[str, List[str], List[int], List[List[int]]] +OpenAIChatMessage = Dict[ + str, str +] # A message is a dictionary with "role" and "content" keys +OpenAICreateChatPrompt = List[OpenAIChatMessage] # A chat log is a list of messages + + +@dataclass +class Prompt(ABC): + """ + A `Prompt` encapsulates everything required to present the `raw_prompt` in different formats, + e.g., a normal unadorned format vs. a chat format. 
+ """ + + @abstractmethod + def to_openai_create_prompt(self): + """ + Return the actual data to be passed as the `prompt` field to either `openai.ChatCompletion.create`, + if the model is a chat model, or `openai.Completion.create` otherwise. + See the above types to see what each API call is able to handle. + """ + + +def chat_prompt_to_text_prompt(prompt: OpenAICreateChatPrompt) -> str: + """ + Render a chat prompt as a text prompt. User and assistant messages are separated by newlines + and prefixed with "User: " and "Assistant: ", respectively, unless there is only one message. + System messages have no prefix. + """ + assert is_chat_prompt(prompt), f"Expected a chat prompt, got {prompt}" + chat_to_prefixes = { + # roles + "system": "", + # names + "example_user": "User: ", + "example_assistant": "Assistant: ", + } + + # For a single message, be it system, user, or assistant, just return the message + if len(prompt) == 1: + return prompt[0]["content"] + + text = "" + for msg in prompt: + role = msg["name"] if "name" in msg else msg["role"] + prefix = chat_to_prefixes.get(role, role.capitalize() + ": ") + content = msg["content"] + text += f"{prefix}{content}\n" + text += "Assistant: " + return text.lstrip() + + +def text_prompt_to_chat_prompt(prompt: str) -> OpenAICreateChatPrompt: + assert isinstance(prompt, str), f"Expected a text prompt, got {prompt}" + return [ + {"role": "system", "content": prompt}, + ] + + +def is_chat_prompt(prompt: Prompt) -> bool: + return isinstance(prompt, list) and all(isinstance(msg, dict) for msg in prompt) diff --git a/src/xturing/utils/utils.py b/src/xturing/utils/utils.py index da09721..a1bc22a 100644 --- a/src/xturing/utils/utils.py +++ b/src/xturing/utils/utils.py @@ -1,14 +1,19 @@ import contextlib import io +import logging import os import random import string import sys import tempfile from pathlib import Path +from typing import Any, List import yaml +SHUFFLE_SEED = 123 +_MAX_SAMPLES = None + def read_yamls(config_path): conf = {} @@ -134,3 +139,14 @@ def _filter_args(arguments: dict): for key in to_delete: del arguments[key] return arguments + + +def _index_samples(samples: List[Any], logger: logging.Logger): + """Shuffle `samples` and pair each sample with its index.""" + indices = list(range(len(samples))) + random.Random(SHUFFLE_SEED).shuffle(indices) + if _MAX_SAMPLES is not None: + indices = indices[:_MAX_SAMPLES] + logger.info(f"Evaluating {len(indices)} samples") + work_items = [(samples[i], i) for i in indices] + return work_items From f306cb9457b3fd7eec5ee1c98a0a7c2d88bd231d Mon Sep 17 00:00:00 2001 From: Tushar Date: Tue, 11 Jul 2023 11:00:51 +0000 Subject: [PATCH 07/18] fix: fixed .evaluate errors The BaseModel.evalute(dataset) is working just fine --- src/xturing/models/causal.py | 11 ++++++----- src/xturing/utils/prompt.py | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/xturing/models/causal.py b/src/xturing/models/causal.py index 9b6c6d3..90b6462 100644 --- a/src/xturing/models/causal.py +++ b/src/xturing/models/causal.py @@ -215,9 +215,10 @@ def save(self, path: Union[str, Path]): def completion_query( self, prompt: Union[OpenAICreatePrompt, OpenAICreateChatPrompt, Prompt] ): - actual_prompt = chat_prompt_to_text(prompt) - - text_out = self.model.generate(texts=[actual_prompt]) + # actual_prompt = chat_prompt_to_text(prompt) + actual_prompt = prompt + logger.info(prompt) + text_out = self.generate(texts=[actual_prompt]) # parse results # result = { @@ -259,8 +260,8 @@ def check_sampled_text( return 
result def eval_sample(self, sample, *args): - prompt = sample["input"] - return self.check_sampled_text(prompt, expected=sample["ideal"]) + prompt = f"{sample.get('instruction', '')} {sample.get('text', ' ')}".strip() + return self.check_sampled_text(prompt, expected=sample["target"]) def eval_all_samples( self, diff --git a/src/xturing/utils/prompt.py b/src/xturing/utils/prompt.py index 429663c..9840e08 100644 --- a/src/xturing/utils/prompt.py +++ b/src/xturing/utils/prompt.py @@ -25,7 +25,7 @@ def to_openai_create_prompt(self): """ -def chat_prompt_to_text_prompt(prompt: OpenAICreateChatPrompt) -> str: +def chat_prompt_to_text(prompt: OpenAICreateChatPrompt) -> str: """ Render a chat prompt as a text prompt. User and assistant messages are separated by newlines and prefixed with "User: " and "Assistant: ", respectively, unless there is only one message. From e962622f903a2197a6ea75bceda277509042db13 Mon Sep 17 00:00:00 2001 From: Tushar Date: Thu, 13 Jul 2023 14:31:56 +0000 Subject: [PATCH 08/18] feat: added perplexity evaluation fixed length perplexity implemented --- examples/opt/opt_evaluate.py | 10 +++++++++ src/xturing/models/causal.py | 39 +++++++++++++++++++++++++++++++++--- 2 files changed, 46 insertions(+), 3 deletions(-) create mode 100644 examples/opt/opt_evaluate.py diff --git a/examples/opt/opt_evaluate.py b/examples/opt/opt_evaluate.py new file mode 100644 index 0000000..f180e50 --- /dev/null +++ b/examples/opt/opt_evaluate.py @@ -0,0 +1,10 @@ +from xturing.datasets.instruction_dataset import InstructionDataset +from xturing.models import BaseModel + +instruction_dataset = InstructionDataset("../examples/llama/alpaca_data") +# Initializes the model +model = BaseModel.create("opt") +# Call the evaluate function +perplexity = model.evaluate(instruction_dataset) + +print(perplexity) diff --git a/src/xturing/models/causal.py b/src/xturing/models/causal.py index 90b6462..e8e4cee 100644 --- a/src/xturing/models/causal.py +++ b/src/xturing/models/causal.py @@ -3,9 +3,11 @@ from typing import Iterable, List, Optional, Tuple, Type, Union import torch +import torch.nn.functional as F from pytorch_lightning.loggers import Logger from torch.utils.data import DataLoader from tqdm import tqdm +from transformers import BatchEncoding from xturing.config import DEFAULT_DEVICE, assert_not_cpu_int8 from xturing.config.config_data_classes import FinetuningConfig, GenerationConfig @@ -30,6 +32,8 @@ ) from xturing.utils.utils import _filter_args, _index_samples +TokenSequence = Union[List[int], torch.LongTensor, torch.Tensor, BatchEncoding] + logger = configure_logger(__name__) @@ -212,6 +216,25 @@ def save(self, path: Union[str, Path]): self.engine.save(path) self._save_config(path=path) + def _loglikelihood_tokens( + self, + data_iterator: Iterable, + disable_tqdm: Optional[bool] = False, + ) -> List[Tuple[float, bool]]: + results = [] + for chunk in tqdm(data_iterator, disable=disable_tqdm): + del input_tokens["label_masks"], input_tokens["targets"] + input_tokens = chunk.to(DEFAULT_DEVICE) + outputs = self._model_call(inputs=input_tokens, labels=input_tokens) + results.append(outputs.loss) + return results + + def _model_call( + self, inputs: TokenSequence, labels: Optional[TokenSequence] = None + ) -> TokenSequence: + self.engine.model = self.engine.model.to(DEFAULT_DEVICE) + return self.engine.model(**inputs, labels=labels["input_ids"]) + def completion_query( self, prompt: Union[OpenAICreatePrompt, OpenAICreateChatPrompt, Prompt] ): @@ -271,7 +294,7 @@ def eval_all_samples( """ Evaluate 
all provided samples in parallel. """ - work_items = _index_samples(samples, logger) + work_items = _index_samples([samples[i] for i in range(10)], logger) show_progress = show_progress def eval_sample(args): @@ -286,8 +309,18 @@ def eval_sample(args): return [r for _, r in sorted(idx_and_result)] def evaluate(self, dataset: Union[TextDataset, InstructionDataset]): - outputs = self.eval_all_samples(dataset) - return get_accuracy(outputs) + # outputs = self.eval_all_samples(dataset) + collate_fn = self._make_collate_fn(dataset) + dataloader = DataLoader( + dataset, + batch_size=1, + shuffle=False, + drop_last=False, + collate_fn=collate_fn, + ) + results = self._loglikelihood_tokens(dataloader) + return torch.exp(torch.stack(results).mean()) + # return get_accuracy(outputs) class CausalInt8Model(CausalModel): From f06bf1d141c41c91e7ea9ac2c3acccaba15e4523 Mon Sep 17 00:00:00 2001 From: Tushar Date: Mon, 17 Jul 2023 11:57:50 +0000 Subject: [PATCH 09/18] feat: added batch_size to evaluate function Also, removed the evaluation unnecessary files --- examples/opt/opt_evaluate.py | 2 +- src/xturing/evalutaion/api.py | 267 --------------- src/xturing/evalutaion/base.py | 150 --------- src/xturing/evalutaion/data.py | 195 ----------- src/xturing/evalutaion/eval.py | 151 --------- src/xturing/evalutaion/evaluate.py | 197 ------------ src/xturing/evalutaion/match.py | 42 --- src/xturing/evalutaion/metrics.py | 76 ----- src/xturing/evalutaion/models.py | 268 --------------- src/xturing/evalutaion/prompt.py | 119 ------- src/xturing/evalutaion/record.py | 501 ----------------------------- src/xturing/evalutaion/registry.py | 179 ----------- src/xturing/evalutaion/utils.py | 129 -------- src/xturing/models/causal.py | 14 +- 14 files changed, 10 insertions(+), 2280 deletions(-) delete mode 100644 src/xturing/evalutaion/api.py delete mode 100644 src/xturing/evalutaion/base.py delete mode 100644 src/xturing/evalutaion/data.py delete mode 100644 src/xturing/evalutaion/eval.py delete mode 100644 src/xturing/evalutaion/evaluate.py delete mode 100644 src/xturing/evalutaion/match.py delete mode 100644 src/xturing/evalutaion/metrics.py delete mode 100644 src/xturing/evalutaion/models.py delete mode 100644 src/xturing/evalutaion/prompt.py delete mode 100644 src/xturing/evalutaion/record.py delete mode 100644 src/xturing/evalutaion/registry.py delete mode 100644 src/xturing/evalutaion/utils.py diff --git a/examples/opt/opt_evaluate.py b/examples/opt/opt_evaluate.py index f180e50..7ede302 100644 --- a/examples/opt/opt_evaluate.py +++ b/examples/opt/opt_evaluate.py @@ -5,6 +5,6 @@ # Initializes the model model = BaseModel.create("opt") # Call the evaluate function -perplexity = model.evaluate(instruction_dataset) +perplexity = model.evaluate(instruction_dataset, batch_size=5) print(perplexity) diff --git a/src/xturing/evalutaion/api.py b/src/xturing/evalutaion/api.py deleted file mode 100644 index dd24c66..0000000 --- a/src/xturing/evalutaion/api.py +++ /dev/null @@ -1,267 +0,0 @@ -""" -This file provides common interfaces and utilities used by eval creators to -sample from models and process the results. 
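# Illustrative sketch, not part of the patch: the evaluate() method added above
# reports fixed-length perplexity as the exponential of the mean per-batch
# cross-entropy loss collected by _loglikelihood_tokens.
import torch

batch_losses = torch.tensor([2.1, 1.9, 2.3])  # mean negative log-likelihood per batch, in nats
perplexity = torch.exp(batch_losses.mean())   # exp(2.1) ~= 8.17
print(perplexity.item())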
-""" - -import logging -from typing import Callable, Dict, List, Optional, Tuple, Union - -from .base import ModelSpec -from .prompt import ( - ChatCompletionPrompt, - CompletionPrompt, - OpenAICreateChatPrompt, - OpenAICreatePrompt, - Prompt, -) -from .record import record_match, record_sampling -from .utils import ( - openai_chat_completion_create_retrying, - openai_completion_create_retrying, -) - -logger = logging.getLogger(__name__) - - -def completion_query( - model_spec: ModelSpec, - prompt: Union[OpenAICreatePrompt, OpenAICreateChatPrompt, Prompt], - **kwargs, -) -> Tuple[dict, Union[OpenAICreatePrompt, OpenAICreateChatPrompt], dict]: - """ - Query the API for a completion. - - ARGS - ==== - `model_spec`: `ModelSpec` containing model details to use in the query. - This should be the dict returned by `registry.get_model()`. - If `model_spec` is not provided, we use the default model that was - intialized at the beginning of the run. - `prompt`: Either a `Prompt` object or a raw prompt that will get wrapped in - the approriate `Prompt` class. - `kwargs`: Other arguments passed to the API. - - RETURNS - ======= - The result of the API call. - The prompt that was fed into the API call as a str. - A dict containing metadata about the query. - """ - if not isinstance(prompt, Prompt): - assert ( - isinstance(prompt, str) - or ( - isinstance(prompt, list) - and all(isinstance(token, int) for token in prompt) - ) - or ( - isinstance(prompt, list) - and all(isinstance(token, str) for token in prompt) - ) - or ( - isinstance(prompt, list) - and all(isinstance(msg, dict) for msg in prompt) - ) - ), f"Got type {type(prompt)}, with val {type(prompt[0])} for prompt, expected str or list[int] or list[str] or list[dict[str, str]]" - - if model_spec.is_chat: - prompt = ChatCompletionPrompt( - raw_prompt=prompt, - ) - else: - prompt = CompletionPrompt( - raw_prompt=prompt, - ) - - openai_create_prompt: Union[ - OpenAICreatePrompt, OpenAICreateChatPrompt - ] = prompt.to_openai_create_prompt() - - if model_spec.is_chat: - result = openai_chat_completion_create_retrying( - model=model_spec.model, - api_key=model_spec.api_key, - messages=openai_create_prompt, - **{**kwargs, **model_spec.extra_options}, - ) - else: - result = openai_completion_create_retrying( - model=model_spec.model, - api_key=model_spec.api_key, - prompt=openai_create_prompt, - **{**kwargs, **model_spec.extra_options}, - ) - - metadata = {} - if result: - metadata["completion_id"] = result.get("id", None) - metadata["model"] = result.get("model", None) - - if model_spec.is_chat: - for choice in result["choices"]: - choice["text"] = choice["message"]["content"] - - return result, openai_create_prompt, metadata - - -def check_sampled_text( - model_spec: ModelSpec, - prompt: Union[OpenAICreatePrompt, OpenAICreateChatPrompt, Prompt], - expected: Union[str, List[str], Tuple[str]], - *, - options: Optional[List[str]] = None, - separator: Callable[[str], bool] = None, -) -> Optional[str]: - """ - Generates a completion using the prompt, checks whether the completion is - one of the expected completions, and then records the result. - - ARGS - ==== - `model_spec`: See `completion_query`. - `prompt`: See `completion_query`. - `options`: The list of canonical options, defaults to `expected` if None. - The completion will be converted to one of these options. - `expected`: The desired completion or the list of desired completions. - `separator`: A callable which check the character sampled after the option - to see if it is a valid separator. 
- - RETURNS - ======= - The option that was picked, i.e., matched the completion, or None. - """ - if isinstance(expected, tuple): - expected = list(expected) - elif not isinstance(expected, list): - expected = [expected] - if options is None: - options = expected - - result, actual_prompt, metadata = completion_query( - prompt=prompt, - temperature=0.0, - model_spec=model_spec, - ) - choice = result["choices"][0] - - sampled = choice["text"].strip() if model_spec.strip_completion else choice["text"] - - picked = None - for option in options: - if not sampled.startswith(option): - continue - if ( - separator is not None - and len(sampled) > len(option) - and not separator(sampled[len(option)]) - ): - continue - picked = option - break - - result = { - "prompt": actual_prompt, - "sampled": sampled, - "options": options, - "picked": picked, - } - match = picked in expected - result["expected"] = expected - result["match"] = match - result["metadata"] = metadata - record_sampling(**result) - record_match(match, expected=expected, picked=picked, sampled=sampled) - return picked - - -def sample_freeform( - model_spec: ModelSpec, - prompt: Union[OpenAICreatePrompt, OpenAICreateChatPrompt, Prompt], - *, - temperature: float = 1.0, - top_p: float = 0.9, - max_tokens: int = 512, - stop: Optional[str] = None, - n_samples: int = None, - return_logprobs: bool = False, - **kwargs, -) -> Union[str, List[str], dict]: - """ - Samples a freeform response from the specified model, records the sampling, - and returns the sampled text. - - ARGS - ==== - `model_spec`: See `completion_query`. - `prompt`: See `completion_query`. - `temperature`: Passed to `openai.Completion.create`. - `top_p`: Passed to `openai.Completion.create`. - `max_tokens`: Passed to `openai.Completion.create`. - `stop`: Passed to `openai.Completion.create`. - `n_samples`: The number of samples to generate (1 if None). - `return_logprobs`: If True, returns the tokens and corresponding logprobs - in addition to the sampled text. - `kwargs`: See `completion_query`. - - RETURNS - ======= - If `return_logprobs` is True, returns a dict with the sampled text, tokens, - and corresponding logprobs. If `n_samples` is None, the outer list is - removed from all values. - Otherwise, returns the sampled text, or a list of sampled texts if - `n_samples` is not None. 
- """ - response, actual_prompt, metadata = completion_query( - prompt=prompt, - temperature=temperature, - top_p=top_p, - max_tokens=max_tokens, - stop=stop, - n=(1 if n_samples is None else n_samples), - model_spec=model_spec, - headers={}, - **kwargs, - ) - sampled = [choice["text"] for choice in response["choices"]] - if n_samples is None: - sampled = sampled[0] - record_sampling(prompt=actual_prompt, sampled=sampled, metadata=metadata) - - if return_logprobs: - assert not model_spec.is_chat, "logprobs only works for non-chat models" - assert not kwargs.get("logprobs") is None - - def _maybe_tokens(logprobs: Optional[dict]) -> Optional[List[str]]: - return logprobs["tokens"] if logprobs is not None else None - - def _maybe_logprobs(logprobs: Optional[dict]) -> Optional[List[float]]: - return logprobs["token_logprobs"] if logprobs is not None else None - - def _maybe_top_logprobs( - logprobs: Optional[dict], - ) -> Optional[List[Dict[str, float]]]: - return ( - [dict(x) for x in logprobs["top_logprobs"]] - if logprobs is not None - else None - ) - - tokens = [_maybe_tokens(choice["logprobs"]) for choice in response["choices"]] - logprobs = [ - _maybe_logprobs(choice["logprobs"]) for choice in response["choices"] - ] - top_logprobs = [ - _maybe_top_logprobs(choice["logprobs"]) for choice in response["choices"] - ] - if n_samples is None: - tokens = tokens[0] - logprobs = logprobs[0] - top_logprobs = top_logprobs[0] - return { - "text": sampled, - "tokens": tokens, - "logprobs": logprobs, - "top_logprobs": top_logprobs, - } - - return sampled diff --git a/src/xturing/evalutaion/base.py b/src/xturing/evalutaion/base.py deleted file mode 100644 index 97bb60e..0000000 --- a/src/xturing/evalutaion/base.py +++ /dev/null @@ -1,150 +0,0 @@ -""" -This file defines the base specifications for models, evals, and runs. Running -evals and most development work should not require familiarity with this file. -""" -import base64 -import datetime -import os -from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Sequence - -if TYPE_CHECKING: - from dataclasses import dataclass -else: - from pydantic.dataclasses import dataclass - - -@dataclass -# class ModelSpec: -# """ -# Specification for a model. -# """ - -# name: str -# model: Optional[str] = None - -# is_chat: bool = False - -# encoding: Optional[str] = None -# organization: Optional[str] = None -# api_key: Optional[str] = None -# extra_options: Optional[Mapping[str, Any]] = None -# headers: Optional[Mapping[str, Any]] = None -# strip_completion: bool = True -# n_ctx: Optional[int] = None -# format: Optional[str] = None -# key: Optional[str] = None -# group: Optional[str] = None - -# def __post_init__(self): -# if self.extra_options is None: -# self.extra_options = {} -# if self.headers is None: -# self.headers = {} - -# if self.model is None: -# raise ValueError(f"Must specify a model") - - -@dataclass -# class BaseEvalSpec: -# """ -# Specification for a base eval. -# """ - -# id: Optional[str] = None -# metrics: Optional[Sequence[str]] = None -# description: Optional[str] = None -# disclaimer: Optional[str] = None - -# """ -# True if higher values are better, False if lower values are better. -# This should really be part of a metric, but it's easier to put it here. -# """ -# higher_is_better: bool = True - -# key: Optional[str] = None -# group: Optional[str] = None - - -@dataclass -class EvalSpec: - """ - Specification for an eval. 
- """ - - cls: str - args: Optional[Dict[str, Any]] = None - key: Optional[str] = None - group: Optional[str] = None - - -@dataclass -class EvalSetSpec: - """ - Specification for an eval set. - """ - - evals: Sequence[str] - key: Optional[str] = None - group: Optional[str] = None - - -@dataclass -# class ModelSpecs: -# completions_: Optional[Sequence[ModelSpec]] = None -# embedding_: Optional[ModelSpec] = None -# ranking_: Optional[ModelSpec] = None - -# @property -# def embedding(self) -> ModelSpec: -# if self.embedding_ is None: -# raise ValueError("Embedding model was not specified") -# return self.embedding_ - -# @property -# def ranking(self) -> ModelSpec: -# if self.ranking_ is None: -# raise ValueError("Ranking model was not specified") -# return self.ranking_ - -# @property -# def completion(self) -> ModelSpec: -# if self.completions_ is None: -# raise ValueError("Completion model was not specified") -# return self.completions_[0] - -# @property -# def completions(self) -> Sequence[ModelSpec]: -# if self.completions_ is None: -# raise ValueError("Completion model was not specified") -# return self.completions_ - -# @property -# def names(self) -> Dict[str, Sequence[str]]: -# dict = {} -# if self.completions_ is not None: -# dict["completions"] = [model.name for model in self.completions_] -# if self.embedding_ is not None: -# dict["embedding"] = [self.embedding_.name] -# if self.ranking_ is not None: -# dict["ranking"] = [self.ranking_.name] -# return dict - - -@dataclass -class RunSpec: - model_name: str - model_names: Dict[str, Sequence[str]] - eval_name: str - base_eval: str - split: str - run_config: Dict[str, Any] - created_by: str - run_id: str = None - created_at: str = None - - def __post_init__(self): - now = datetime.datetime.utcnow() - rand_suffix = base64.b32encode(os.urandom(5)).decode("ascii") - self.run_id = now.strftime("%y%m%d%H%M%S") + rand_suffix - self.created_at = str(now) diff --git a/src/xturing/evalutaion/data.py b/src/xturing/evalutaion/data.py deleted file mode 100644 index 5def56a..0000000 --- a/src/xturing/evalutaion/data.py +++ /dev/null @@ -1,195 +0,0 @@ -""" -This file defines utilities for working with data and files of various types. -""" -import csv -import dataclasses -import gzip -import itertools -import json -import logging -import os -import urllib - -# from collections.abc import Iterator -from functools import partial -from typing import Any, Dict, Iterator, List, Sequence, Union - -import blobfile as bf -import lz4.frame -import pydantic -import pyzstd - -logger = logging.getLogger(__name__) - - -def gzip_open(filename: str, mode: str = "rb", openhook: Any = open) -> gzip.GzipFile: - """Wrap the given openhook in gzip.""" - if mode and "b" not in mode: - mode += "b" - - return gzip.GzipFile(fileobj=openhook(filename, mode), mode=mode) - - -def lz4_open( - filename: str, mode: str = "rb", openhook: Any = open -) -> lz4.frame.LZ4FrameFile: - if mode and "b" not in mode: - mode += "b" - - return lz4.frame.LZ4FrameFile(openhook(filename, mode), mode=mode) - - -def zstd_open(filename: str, mode: str = "rb", openhook: Any = open) -> pyzstd.ZstdFile: - if mode and "b" not in mode: - mode += "b" - - return pyzstd.ZstdFile(openhook(filename, mode), mode=mode) - - -def open_by_file_pattern(filename: str, mode: str = "r", **kwargs: Any) -> Any: - """Can read/write to files on gcs/local with or without gzipping. If file - is stored on gcs, streams with blobfile. Otherwise use vanilla python open. 
If - filename endswith gz, then zip/unzip contents on the fly (note that gcs paths and - gzip are compatible)""" - open_fn = partial(bf.BlobFile, **kwargs) - try: - if filename.endswith(".gz"): - return gzip_open(filename, openhook=open_fn, mode=mode) - elif filename.endswith(".lz4"): - return lz4_open(filename, openhook=open_fn, mode=mode) - elif filename.endswith(".zst"): - return zstd_open(filename, openhook=open_fn, mode=mode) - else: - scheme = urllib.parse.urlparse(filename).scheme - if scheme == "" or scheme == "file": - return open_fn( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "registry", - "data", - filename, - ), - mode=mode, - ) - else: - return open_fn(filename, mode=mode) - except Exception as e: - raise RuntimeError(f"Failed to open: {filename}") from e - - -def _get_jsonl_file(path): - logger.info(f"Fetching {path}") - with open_by_file_pattern(path, mode="r") as f: - return list(map(json.loads, f.readlines())) - - -def _get_json_file(path): - logger.info(f"Fetching {path}") - with open_by_file_pattern(path, mode="r") as f: - return json.loads(f.read()) - - -def _stream_jsonl_file(path) -> Iterator: - logger.info(f"Streaming {path}") - with bf.BlobFile(path, "r", streaming=True) as f: - for line in f: - yield json.loads(line) - - -def get_lines(path) -> List[dict]: - """ - Get a list of lines from a file. - """ - with open_by_file_pattern(path, mode="r") as f: - return f.readlines() - - -def get_jsonl(path: str) -> List[dict]: - """ - Extract json lines from the given path. - If the path is a directory, look in subpaths recursively. - - Return all lines from all jsonl files as a single list. - """ - if bf.isdir(path): - result = [] - for filename in bf.listdir(path): - if filename.endswith(".jsonl"): - result += get_jsonl(os.path.join(path, filename)) - return result - return _get_jsonl_file(path) - - -def get_jsonls(paths: Sequence[str], line_limit=None) -> List[dict]: - return list(iter_jsonls(paths, line_limit)) - - -def get_json(path) -> dict: - if bf.isdir(path): - raise ValueError("Path is a directory, only files are supported") - return _get_json_file(path) - - -def iter_jsonls(paths: Union[str, List[str]], line_limit=None) -> Iterator[dict]: - """ - For each path in the input, iterate over the jsonl files in that path. - Look in subdirectories recursively. - - Use an iterator to conserve memory. 
- """ - if type(paths) == str: - paths = [paths] - - def _iter(): - for path in paths: - if bf.isdir(path): - for filename in bf.listdir(path): - if filename.endswith(".jsonl"): - yield from iter_jsonls([os.path.join(path, filename)]) - else: - yield from _stream_jsonl_file(path) - - return itertools.islice(_iter(), line_limit) - - -def get_csv(path, fieldnames=None): - with bf.BlobFile(path, "r", cache_dir="/tmp/bf_cache", streaming=False) as f: - reader = csv.DictReader(f, fieldnames=fieldnames) - return [row for row in reader] - - -def _to_py_types(o: Any) -> Any: - if isinstance(o, dict): - return {k: _to_py_types(v) for k, v in o.items()} - if isinstance(o, list): - return [_to_py_types(v) for v in o] - - if dataclasses.is_dataclass(o): - return _to_py_types(dataclasses.asdict(o)) - - # pydantic data classes - if isinstance(o, pydantic.BaseModel): - return json.loads(o.json()) - - return o - - -class EnhancedJSONEncoder(json.JSONEncoder): - def default(self, o: Any) -> str: - return _to_py_types(o) - - -def jsondumps(o: Any, ensure_ascii: bool = False, **kwargs: Any) -> str: - return json.dumps(o, cls=EnhancedJSONEncoder, ensure_ascii=ensure_ascii, **kwargs) - - -def jsondump(o: Any, fp: Any, ensure_ascii: bool = False, **kwargs: Any) -> None: - json.dump(o, fp, cls=EnhancedJSONEncoder, ensure_ascii=ensure_ascii, **kwargs) - - -def jsonloads(s: str, **kwargs: Any) -> Any: - return json.loads(s, **kwargs) - - -def jsonload(fp: Any, **kwargs: Any) -> Any: - return json.load(fp, **kwargs) diff --git a/src/xturing/evalutaion/eval.py b/src/xturing/evalutaion/eval.py deleted file mode 100644 index 7c46e1b..0000000 --- a/src/xturing/evalutaion/eval.py +++ /dev/null @@ -1,151 +0,0 @@ -""" -This file defines the base class for evals. -""" -import abc -import asyncio -import concurrent.futures -import logging -import os -import random -from multiprocessing.pool import ThreadPool -from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple - -from tqdm import tqdm - -from .base import ModelSpec, ModelSpecs -from .record import RecorderBase -from .registry import Registry - -logger = logging.getLogger(__name__) - - -SHUFFLE_SEED = 123 -_MAX_SAMPLES = None - - -def _index_samples(samples: List[Any]) -> List[Tuple[Any, int]]: - """Shuffle `samples` and pair each sample with its index.""" - indices = list(range(len(samples))) - random.Random(SHUFFLE_SEED).shuffle(indices) - if _MAX_SAMPLES is not None: - indices = indices[:_MAX_SAMPLES] - logger.info(f"Evaluating {len(indices)} samples") - work_items = [(samples[i], i) for i in indices] - return work_items - - -def set_max_samples(max_samples: int): - global _MAX_SAMPLES - _MAX_SAMPLES = max_samples - - -class Eval(abc.ABC): - """ - Evaluation classes generally should override two methods: - `eval_sample`: Takes in a test sample and a random number generator and - records the metrics of interest. - `run`: Takes in a recorder and runs the evaluation. Generally, most `run` - methods will follow this same pattern: loading the data, calling - `eval_all_samples`, and aggregating the recorded results. - """ - - def __init__( - self, - model_specs, - seed: int = 20220722, - name: str = "no_name_eval.default", - registry: Optional[Registry] = None, - ): - splits = name.split(".") - if len(splits) < 2: - raise ValueError( - f"Eval name must at least have .. 
Got name {name}" - ) - - self.model_specs = model_specs - self.seed = seed - self.name = name - self.registry = registry or Registry() - - def eval_sample(self, sample: Any, rng: random.Random): - raise NotImplementedError() - - @abc.abstractmethod - def run(self, recorder: RecorderBase) -> Dict[str, float]: - """Run the evaluation with the corresponding recorder.""" - raise NotImplementedError() - - async def async_eval_all_samples( - self, - eval_fn: Callable[[Tuple[Any, int]], Awaitable[Tuple[int, Any]]], - samples: List[Any], - concurrency: int = 32, - show_progress: bool = True, - ): - work_items = _index_samples(samples) - semaphore = asyncio.Semaphore(concurrency) - - async def eval_fn_with_semaphore(args): - async with semaphore: - return await eval_fn(args) - - futures = [ - asyncio.ensure_future(eval_fn_with_semaphore(args)) for args in work_items - ] - - for future in tqdm( - asyncio.as_completed(futures), total=len(samples), disable=not show_progress - ): - await future - - def eval_all_samples( - self, - recorder: RecorderBase, - samples, - show_progress=True, - ): - """ - Evaluate all provided samples in parallel. - """ - work_items = _index_samples(samples) - threads = int(os.environ.get("EVALS_THREADS", "10")) - show_progress = bool(os.environ.get("EVALS_SHOW_EVAL_PROGRESS", show_progress)) - timeout = float(os.environ.get("EVALS_THREAD_TIMEOUT", "40")) - - def eval_sample(args): - """ - Evaluate a single sample. - """ - sample, idx = args - base_name, split = self.name.split(".")[0:2] - sample_id = f"{base_name}.{split}.{idx}" - with recorder.as_default_recorder(sample_id): - recorder.record_raw(sample) - seed = f"{sample_id}:{self.seed}".encode("utf-8") - rng = random.Random(seed) - return idx, self.eval_sample(sample, rng) - - def worker_thread(args): - """ - Worker thread for evaluating a single sample. 
- """ - while True: - executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) - future = executor.submit(eval_sample, args=args) - try: - result = future.result(timeout=timeout) - return result - except concurrent.futures.TimeoutError as e: - executor.shutdown(wait=False) - - with ThreadPool(threads) as pool: - if os.environ.get("EVALS_SEQUENTIAL", "0") in {"1", "true", "yes"}: - logger.info(f"Running in sequential mode!") - iter = map(eval_sample, work_items) - else: - logger.info(f"Running in threaded mode with {threads} threads!") - iter = pool.imap_unordered(worker_thread, work_items) - idx_and_result = list( - tqdm(iter, total=len(work_items), disable=not show_progress) - ) - return [r for _, r in sorted(idx_and_result)] diff --git a/src/xturing/evalutaion/evaluate.py b/src/xturing/evalutaion/evaluate.py deleted file mode 100644 index 4b27253..0000000 --- a/src/xturing/evalutaion/evaluate.py +++ /dev/null @@ -1,197 +0,0 @@ -import argparse -import logging -import shlex -import sys -from functools import cached_property -from typing import Any, Mapping, Optional - -import openai - -from .base import EvalSpec, RunSpec -from .record import DummyRecorder, LocalRecorder, Recorder -from .registry import Registry, registry - -logger = logging.getLogger(__name__) - - -def _purple(str): - return f"\033[1;35m{str}\033[0m" - - -def run_evaluation(args): - if args.debug: - logging.getLogger().setLevel(logging.DEBUG) - - model = args.model - - run_config = { - "model": model, - "eval": args.eval_spec, - "seed": args.seed, - } - - model_name = model.config_name - eval_name = args.eval_spec.key - - run_spec = RunSpec( - model_name=model_name, - eval_name=eval_name, - base_eval=eval_name.split(".")[0], - split=eval_name.split(".")[1], - run_config=run_config, - created_by=args.user, - run_id="something", - ) - if args.record_path is None: - record_path = f"/tmp/evallogs/{run_spec.run_id}_{args.model}_{args.eval}.jsonl" - else: - record_path = args.record_path - - # Recording progress - if args.dry_run: - recorder = DummyRecorder(run_spec=run_spec, log=args.dry_run_logging) - elif args.local_run: - recorder = LocalRecorder(record_path, run_spec=run_spec) - else: - recorder = Recorder(record_path, run_spec=run_spec) - - api_extra_options = {} - if not args.cache: - api_extra_options["cache_level"] = 0 - - run_url = f"{run_spec.run_id}" - logger.info(_purple(f"Run started: {run_url}")) - - def parse_extra_eval_params(param_str: Optional[str]) -> Mapping[str, Any]: - """Parse a string of the form "key1=value1,key2=value2" into a dict.""" - if not param_str: - return {} - - def to_number(x): - try: - return int(x) - except: - pass - try: - return float(x) - except: - pass - return x - - str_dict = dict(kv.split("=") for kv in param_str.split(",")) - return {k: to_number(v) for k, v in str_dict.items()} - - extra_eval_params = parse_extra_eval_params(args.extra_eval_params) - - eval_class = registry.get_class(args.eval_spec) - eval = eval_class( - model_specs=model, - seed=args.seed, - name=eval_name, - registry=registry, - **extra_eval_params, - ) - result = eval.run(recorder) - recorder.record_final_report(result) - - if not (args.dry_run or args.local_run): - logger.info(_purple(f"Run completed: {run_url}")) - - logger.info("Final report:") - for key, value in result.items(): - logger.info(f"{key}: {value}") - return run_spec.run_id - - -def evaluate( - model: str, - eval: str, - embedding_model: str = "", - ranking_model: str = "", - extra_eval_params: str = "", - max_samples: Optional[int] = 
None, - cache: bool = True, - visible: Optional[bool] = None, - seed: int = 20220722, - user: str = "", - record_path: Optional[str] = None, - log_to_file: Optional[str] = None, - debug: bool = False, - local_run: bool = True, - dry_run: bool = False, - dry_run_logging: bool = True, -) -> Any: - parser = argparse.ArgumentParser(description="Run evals through the API") - parser.add_argument("model", type=str, help="Name of a completion model.") - parser.add_argument("eval", type=str, help="Name of an eval. See registry.") - parser.add_argument("--embedding_model", type=str, default="") - parser.add_argument("--ranking_model", type=str, default="") - parser.add_argument("--extra_eval_params", type=str, default="") - parser.add_argument("--max_samples", type=int, default=None) - parser.add_argument("--cache", action=argparse.BooleanOptionalAction, default=True) - parser.add_argument( - "--visible", action=argparse.BooleanOptionalAction, default=None - ) - parser.add_argument("--seed", type=int, default=20220722) - parser.add_argument("--user", type=str, default="") - parser.add_argument("--record_path", type=str, default=None) - parser.add_argument( - "--log_to_file", type=str, default=None, help="Log to a file instead of stdout" - ) - parser.add_argument("--debug", action=argparse.BooleanOptionalAction, default=False) - parser.add_argument( - "--local-run", action=argparse.BooleanOptionalAction, default=True - ) - parser.add_argument( - "--dry-run", action=argparse.BooleanOptionalAction, default=False - ) - parser.add_argument( - "--dry-run-logging", action=argparse.BooleanOptionalAction, default=True - ) - - args = argparse.Namespace( - model=model, - eval=eval, - embedding_model=embedding_model, - ranking_model=ranking_model, - extra_eval_params=extra_eval_params, - max_samples=max_samples, - cache=cache, - visible=visible, - seed=seed, - user=user, - record_path=record_path, - log_to_file=log_to_file, - debug=debug, - local_run=local_run, - dry_run=dry_run, - dry_run_logging=dry_run_logging, - ) - - # args_parsed = parser.parse_args() - - # Running evaluation code - logging.basicConfig( - format="[%(asctime)s] [%(filename)s:%(lineno)d] %(message)s", - level=logging.INFO, - filename=args.log_to_file if args.log_to_file else None, - ) - - logging.getLogger("openai").setLevel(logging.WARN) - if hasattr(openai.error, "set_display_cause"): - openai.error.set_display_cause() - - run_evaluation(args) - - -#################################### -# EXAMPLE USAGE: - -# evaluate( -# model_name="davinci", -# eval="test", -# embedding_model="", -# ranking_model="", -# extra_eval_params="", -# max_samples=None, -# ) diff --git a/src/xturing/evalutaion/match.py b/src/xturing/evalutaion/match.py deleted file mode 100644 index 7dda0a1..0000000 --- a/src/xturing/evalutaion/match.py +++ /dev/null @@ -1,42 +0,0 @@ -from typing import Any - -from .data import get_jsonl -from .eval import Eval -from .metrics import get_accuracy -from .models import check_sampled_text -from .prompt import is_chat_prompt - - -class Match(Eval): - def __init__( - self, - model_specs, - samples_jsonl: str, - *args, - max_tokens: int = 500, - num_few_shot: int = 0, - few_shot_jsonl: str = None, - **kwargs, - ): - super().__init__(model_specs, *args, **kwargs) - self.max_tokens = max_tokens - self.samples_jsonl = samples_jsonl - - def eval_sample(self, sample: Any, *_): - prompt = sample["input"] - if self.num_few_shot > 0: - assert is_chat_prompt(sample["input"]), "few shot requires chat prompt" - prompt = sample["input"][:-1] - for 
s in self.few_shot[: self.num_few_shot]: - prompt += s["sample"] - prompt += sample["input"][-1:] - - return check_sampled_text(self.model_spec, prompt, expected=sample["ideal"]) - - def run(self, recorder): - samples = get_jsonl(self.samples_jsonl) - self.eval_all_samples(recorder, samples) - events = recorder.get_events("match") - return { - "accuracy": get_accuracy(events), - } diff --git a/src/xturing/evalutaion/metrics.py b/src/xturing/evalutaion/metrics.py deleted file mode 100644 index 6f46144..0000000 --- a/src/xturing/evalutaion/metrics.py +++ /dev/null @@ -1,76 +0,0 @@ -""" -This file defines various common metrics of interest. -""" -import random -from typing import Optional, Sequence, Set - -import numpy as np - -from .record import Event - - -def get_accuracy(events: Sequence[Event]) -> float: - num_correct = 0 - num_total = 0 - for event in events: - num_total += 1 - num_correct += int(event.data["correct"]) - if num_total == 0: - return float("nan") - else: - return num_correct / num_total - - -def get_bootstrap_accuracy_std(events: Sequence[Event], num_samples: int = 1000): - vals = [m.data["correct"] for m in events] - return np.std([np.mean(random.sample(vals, len(vals) // 2)) for _ in range(1000)]) - - -def get_confusion_matrix( - matches: Sequence[Event], class_labels: Optional[Set] = None -) -> np.ndarray: - labels = set() - for match in matches: - labels.add(match.data["expected"]) - if class_labels is None: - labels = {label: i for i, label in enumerate(sorted(labels))} - else: - assert labels.issubset(class_labels) - labels = {label: i for i, label in enumerate(class_labels)} - result = np.zeros((len(labels), len(labels) + 1), dtype=int) - for match in matches: - i = labels[match.data["expected"]] - j = labels.get(match.data["picked"], len(labels)) - result[i, j] += 1 - return result - - -def compute_matthew_corr(confusion_matrix): - assert confusion_matrix.shape == (2, 3), f"Got shape: {confusion_matrix.shape}" - r = confusion_matrix[:, :2] - r[:, 0] += confusion_matrix[:, 2] - return (r[1, 1] * r[0, 0] - r[1, 0] * r[0, 1]) / np.sqrt( - r[1, :].sum() * r[0, :].sum() * r[:, 0].sum() * r[:, 1].sum() - ) - - -def compute_precision(confusion_matrix, idx=0): - return confusion_matrix[idx, idx] / confusion_matrix[:, idx].sum() - - -def compute_recall(confusion_matrix, idx=0): - return confusion_matrix[idx, idx] / confusion_matrix[idx, :].sum() - - -def compute_f_score(confusion_matrix, idx=0, beta=1.0): - precision = compute_precision(confusion_matrix, idx=idx) - recall = compute_recall(confusion_matrix, idx=idx) - return (1 + beta**2) * (precision * recall) / (beta**2 * precision + recall) - - -def compute_averaged_f_score(confusion_matrix, beta=1.0, average="macro"): - assert average in ["macro"] - f_scores = [] - for i in range(confusion_matrix.shape[0]): - f_scores.append(compute_f_score(confusion_matrix, idx=i, beta=beta)) - return np.array(f_scores).mean() diff --git a/src/xturing/evalutaion/models.py b/src/xturing/evalutaion/models.py deleted file mode 100644 index d17f75d..0000000 --- a/src/xturing/evalutaion/models.py +++ /dev/null @@ -1,268 +0,0 @@ -# import os - -# from dotenv import load_dotenv, find_dotenv - -""" -This file provides common interfaces and utilities used by eval creators to -sample from models and process the results. 
-""" - -import logging -import os -from typing import Callable, List, Optional, Tuple, Union - -from transformers import AutoModelForCausalLM, AutoTokenizer - -from xturing.datasets.instruction_dataset import InstructionDataset -from xturing.models.base import BaseModel - -from .base import ModelSpec -from .prompt import ( - ChatCompletionPrompt, - CompletionPrompt, - OpenAICreateChatPrompt, - OpenAICreatePrompt, - Prompt, -) -from .record import record_match, record_sampling - -logger = logging.getLogger(__name__) - -# # load openai key -# load_dotenv(find_dotenv()) -# OPENAI_KEY = os.environ["OPENAI_KEY"] - -# HELPER FUNCTIONS - - -def chat_prompt_to_text(prompt): - if type(prompt) == str: - return prompt - else: - return " ".join([message["content"] for message in prompt]) - - -def load_model(model_name): - if not os.path.exists(f"./{model_name}"): - print(f"LOADING MODEL: {model_name}") - model = BaseModel.create(model_name) - model.save(f"./{model_name}") - - return BaseModel.load(f"./{model_name}") - - -def completion_query( - model_spec: ModelSpec, - prompt: Union[OpenAICreatePrompt, OpenAICreateChatPrompt, Prompt], - **kwargs, -) -> Tuple[dict, Union[OpenAICreatePrompt, OpenAICreateChatPrompt], dict]: - """ - Query the API for a completion. - - ARGS - ==== - `model_spec`: `ModelSpec` containing model details to use in the query. - This should be the dict returned by `registry.get_model()`. - If `model_spec` is not provided, we use the default model that was - intialized at the beginning of the run. - `prompt`: Either a `Prompt` object or a raw prompt that will get wrapped in - the approriate `Prompt` class. - `kwargs`: Other arguments passed to the API. - - RETURNS - ======= - The result of the API call. - The prompt that was fed into the API call as a str. - A dict containing metadata about the query. - """ - - # parse prompt - - # Initialize model - # TODO: pass kwargs to model! - - # model = AutoModelForCausalLM.from_pretrained(model_spec.name) - - # huggingface_models = ["gpt2"] - - # if model_spec.name in huggingface_models: - # model = AutoModelForCausalLM.from_pretrained(model_spec.name) - # else: - # model = BaseModel.load(model_spec.name) - # tokenizer = AutoTokenizer.from_pretrained(model_spec.name, return_tensors="pt") - - # TODO: is concatenating the contents a good solution to transform chat-style inputs to one string? - - # inputs = tokenizer(actual_prompt, return_tensors="pt").input_ids - - # Run completion - # outputs = model.generate( - # input_ids=inputs, return_dict_in_generate=True, output_scores=True, **kwargs - # ) - - actual_prompt = chat_prompt_to_text(prompt) - - # TODO add config - - model = load_model(model_spec.name) - - text_out = model.generate(texts=[actual_prompt]) - - # parse results - result = { - "text": text_out, - "tokens": None, - "logprobs": None, - } - # TODO: change metadata based on model - metadata = {"model": model_spec.name} - - return result, actual_prompt, metadata - - -def check_sampled_text( - model_spec: ModelSpec, - prompt: Union[OpenAICreatePrompt, OpenAICreateChatPrompt, Prompt], - expected: Union[str, List[str], Tuple[str]], - *, - options: Optional[List[str]] = None, - separator: Callable[[str], bool] = None, -) -> Optional[str]: - """ - Generates a completion using the prompt, checks whether the completion is - one of the expected completions, and then records the result. - - ARGS - ==== - `model_spec`: See `completion_query`. - `prompt`: See `completion_query`. 
- `options`: The list of canonical options, defaults to `expected` if None. - The completion will be converted to one of these options. - `expected`: The desired completion or the list of desired completions. - `separator`: A callable which check the character sampled after the option - to see if it is a valid separator. - - RETURNS - ======= - The option that was picked, i.e., matched the completion, or None. - """ - if isinstance(expected, tuple): - expected = list(expected) - elif not isinstance(expected, list): - expected = [expected] - if options is None: - options = expected - - result, actual_prompt, metadata = completion_query( - prompt=prompt, - model_spec=model_spec, - ) - - choice = result["text"][0] - - # TODO: check what result is supposed to look like [from OPENAI API] - sampled = choice.strip() if model_spec.strip_completion else choice - - picked = None - for option in options: - if not sampled.startswith(option): - continue - if ( - separator is not None - and len(sampled) > len(option) - and not separator(sampled[len(option)]) - ): - continue - picked = option - break - - result = { - "prompt": actual_prompt, - "sampled": sampled, - "options": options, - "picked": picked, - } - match = picked in expected - result["expected"] = expected - result["match"] = match - result["metadata"] = metadata - print("result", result) - record_sampling(**result) - record_match(match, expected=expected, picked=picked, sampled=sampled) - return picked - - -def sample_freeform( - model_spec: ModelSpec, - prompt: Union[OpenAICreatePrompt, OpenAICreateChatPrompt, Prompt], - *, - temperature: float = 1.0, - top_p: float = 0.9, - max_tokens: int = 512, - stop: Optional[str] = None, - n_samples: int = None, - return_logprobs: bool = False, - **kwargs, -) -> Union[str, List[str], dict]: - """ - Samples a freeform response from the specified model, records the sampling, - and returns the sampled text. - - ARGS - ==== - `model_spec`: See `completion_query`. - `prompt`: See `completion_query`. - `temperature`: Passed to `openai.Completion.create`. - `top_p`: Passed to `openai.Completion.create`. - `max_tokens`: Passed to `openai.Completion.create`. - `stop`: Passed to `openai.Completion.create`. - `n_samples`: The number of samples to generate (1 if None). - `return_logprobs`: If True, returns the tokens and corresponding logprobs - in addition to the sampled text. - `kwargs`: See `completion_query`. - - RETURNS - ======= - If `return_logprobs` is True, returns a dict with the sampled text, tokens, - and corresponding logprobs. If `n_samples` is None, the outer list is - removed from all values. - Otherwise, returns the sampled text, or a list of sampled texts if - `n_samples` is not None. 
- """ - - # TODO: add kwargs to completion query (see api.py for reference) - result, actual_prompt, metadata = completion_query( - prompt=prompt, - model_spec=model_spec, - do_sample=True, - num_return_sequences=n_samples if n_samples else 1, - max_new_tokens=max_tokens, - top_p=top_p, - ) - - if n_samples is None: - sampled = result["text"][0] - else: - sampled = result["text"] - - record_sampling(prompt=actual_prompt, sampled=sampled, metadata=metadata) - - if return_logprobs: - # assert not model_spec.is_chat, "logprobs only works for non-chat models" - # assert not kwargs.get("logprobs") is None - - tokens = result["tokens"] - logprobs = result["logprobs"] - top_logprobs = logprobs # TODO: check how to get top logprobs, for now I return all logprobs - if n_samples is None: - tokens = tokens[0] - logprobs = logprobs[0] - top_logprobs = top_logprobs[0] - return { - "text": sampled, - "tokens": tokens, - "logprobs": logprobs, - "top_logprobs": top_logprobs, - } - - return sampled diff --git a/src/xturing/evalutaion/prompt.py b/src/xturing/evalutaion/prompt.py deleted file mode 100644 index 438ebdb..0000000 --- a/src/xturing/evalutaion/prompt.py +++ /dev/null @@ -1,119 +0,0 @@ -""" -This file defines the classes for how to manage prompts for different types of -models, i.e., "chat models" vs. "non chat models". -""" -import logging -import threading -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import Dict, List, Union - -logger = logging.getLogger(__name__) -ENCODER_LOCK = threading.Lock() - -# This is an approximation to the type accepted as the `prompt` field to `openai.Completion.create` calls -OpenAICreatePrompt = Union[str, List[str], List[int], List[List[int]]] - -# This is the type accepted as the `prompt` field to `openai.ChatCompletion.create` calls -OpenAIChatMessage = Dict[ - str, str -] # A message is a dictionary with "role" and "content" keys -OpenAICreateChatPrompt = List[OpenAIChatMessage] # A chat log is a list of messages - - -def chat_prompt_to_text_prompt(prompt: OpenAICreateChatPrompt) -> str: - """ - Render a chat prompt as a text prompt. User and assistant messages are separated by newlines - and prefixed with "User: " and "Assistant: ", respectively, unless there is only one message. - System messages have no prefix. - """ - assert is_chat_prompt(prompt), f"Expected a chat prompt, got {prompt}" - chat_to_prefixes = { - # roles - "system": "", - # names - "example_user": "User: ", - "example_assistant": "Assistant: ", - } - - # For a single message, be it system, user, or assistant, just return the message - if len(prompt) == 1: - return prompt[0]["content"] - - text = "" - for msg in prompt: - role = msg["name"] if "name" in msg else msg["role"] - prefix = chat_to_prefixes.get(role, role.capitalize() + ": ") - content = msg["content"] - text += f"{prefix}{content}\n" - text += "Assistant: " - return text.lstrip() - - -def text_prompt_to_chat_prompt(prompt: str) -> OpenAICreateChatPrompt: - assert isinstance(prompt, str), f"Expected a text prompt, got {prompt}" - return [ - {"role": "system", "content": prompt}, - ] - - -@dataclass -class Prompt(ABC): - """ - A `Prompt` encapsulates everything required to present the `raw_prompt` in different formats, - e.g., a normal unadorned format vs. a chat format. 
- """ - - @abstractmethod - def to_openai_create_prompt(self): - """ - Return the actual data to be passed as the `prompt` field to either `openai.ChatCompletion.create`, - if the model is a chat model, or `openai.Completion.create` otherwise. - See the above types to see what each API call is able to handle. - """ - - -def is_chat_prompt(prompt: Prompt) -> bool: - return isinstance(prompt, list) and all(isinstance(msg, dict) for msg in prompt) - - -@dataclass -class CompletionPrompt(Prompt): - """ - A `Prompt` object that wraps prompts to be compatible with non chat models, which use `openai.Completion.create`. - """ - - raw_prompt: Union[OpenAICreatePrompt, OpenAICreateChatPrompt] - - def _render_chat_prompt_as_text( - self, prompt: OpenAICreateChatPrompt - ) -> OpenAICreatePrompt: - return chat_prompt_to_text_prompt(prompt) - - def to_openai_create_prompt(self) -> OpenAICreatePrompt: - if is_chat_prompt(self.raw_prompt): - return self._render_chat_prompt_as_text(self.raw_prompt) - return self.raw_prompt - - -@dataclass -class ChatCompletionPrompt(Prompt): - """ - A `Prompt` object that wraps prompts to be compatible with chat models, which use `openai.ChatCompletion.create`. - - The format expected by chat models is a list of messages, where each message is a dict with "role" and "content" keys. - """ - - raw_prompt: Union[OpenAICreatePrompt, OpenAICreateChatPrompt] - - def _render_text_as_chat_prompt(self, prompt: str) -> OpenAICreateChatPrompt: - """ - Render a text string as a chat prompt. The default option we adopt here is to simply take the full prompt - and treat it as a system message. - """ - return text_prompt_to_chat_prompt(prompt) - - def to_openai_create_prompt(self) -> OpenAICreateChatPrompt: - if is_chat_prompt(self.raw_prompt): - return self.raw_prompt - return self._render_text_as_chat_prompt(self.raw_prompt) diff --git a/src/xturing/evalutaion/record.py b/src/xturing/evalutaion/record.py deleted file mode 100644 index 8a07e74..0000000 --- a/src/xturing/evalutaion/record.py +++ /dev/null @@ -1,501 +0,0 @@ -""" -This file defines the recorder classes which log eval results in different ways, -such as to a local JSON file or to a remote Snowflake database. - -If you would like to implement a custom recorder, you can see how the -`LocalRecorder` and `Recorder` classes inherit from the `RecorderBase` class and -override certain methods. -""" -import atexit -import contextlib -import dataclasses -import logging -import threading -import time -from contextvars import ContextVar -from datetime import datetime, timezone -from typing import Any, List, Optional, Sequence - -import blobfile as bf -import evals -from evals.base import RunSpec -from evals.data import jsondumps -from evals.utils.misc import t -from evals.utils.snowflake import SnowflakeConnection - -logger = logging.getLogger(__name__) - -MIN_FLUSH_EVENTS = 100 -MAX_SNOWFLAKE_BYTES = 16 * 10**6 -MIN_FLUSH_SECONDS = 10 - -_default_recorder: ContextVar[Optional["RecorderBase"]] = ContextVar( - "default_recorder", default=None -) - - -def default_recorder() -> Optional["RecorderBase"]: - return _default_recorder.get() - - -@dataclasses.dataclass -class Event: - run_id: str - event_id: int - sample_id: Optional[str] - type: str - data: dict - created_by: str - created_at: str - - -class RecorderBase: - """ - The standard events for which recording methods are provided are: - - `match`: A match or non match, as specified by the `correct` bool, between - the `expected` and `picked` results. 
- - `embedding`: An embedding of the `prompt` of type `embedding_type`. - - `sampling`: What was `sampled` from the model given the input `prompt`. - - `cond_logp`: The conditional log probability, as `logp`, of the - `completion` from the model given the input `prompt`. - - `pick_option`: The option `picked` by the model out of the valid `options` - given the input `prompt`. - - `raw`: A raw sample specified by the `data`. - - `metrics`: A set of metrics specified by the `kwargs`. - - `error`: An `error` along with an accompanying `msg`. - - `extra`: Any extra `data` of interest to be recorded. - For these events, helper methods are defined at the bottom of this file. - More generally, you can record any event by calling `record_event` with the - event `type` and `data`. - Finally, you can also record a final report using `record_final_report`. - """ - - def __init__( - self, - run_spec: evals.base.RunSpec, - ) -> None: - self._sample_id: ContextVar[Optional[int]] = ContextVar( - "_sample_id", default=None - ) - self.run_spec = run_spec - self._events: List[Event] = [] - self._last_flush_time = time.time() - self._flushes_done = 0 - self._written_events = 0 - self._flushes_started = 0 - self._event_lock = threading.Lock() - atexit.register(self.flush_events) - - @contextlib.contextmanager - def as_default_recorder(self, sample_id: str): - sample_id_token = self._sample_id.set(sample_id) - default_recorder_token = _default_recorder.set(self) - yield - _default_recorder.reset(default_recorder_token) - self._sample_id.reset(sample_id_token) - - def current_sample_id(self) -> Optional[str]: - return self._sample_id.get() - - def get_events(self, type: str) -> Sequence[Event]: - with self._event_lock: - return [event for event in self._events if event.type == type] - - def get_metrics(self): - return list(map(lambda x: x.data, self.get_events("metrics"))) - - def get_scores(self, key: str): - return list(map(lambda e: e.data[key], self.get_events("metrics"))) - - def _create_event(self, type, data=None, sample_id=None): - if sample_id is None: - sample_id = self.current_sample_id() - if sample_id is None: - raise ValueError( - "No sample_id set! Either pass it in or use as_default_recorder!" - ) - - return Event( - run_id=self.run_spec.run_id, - event_id=len(self._events), - type=type, - sample_id=sample_id, - data=data, - created_by=self.run_spec.created_by, - created_at=str(datetime.now(timezone.utc)), - ) - - def _flush_events_internal(self, events_to_write: Sequence[Event]): - pass - - def flush_events(self): - with self._event_lock: - if len(self._events) == self._written_events: - return - events_to_write = self._events[self._written_events :] - self._written_events = len(self._events) - self._flushes_started += 1 - self._flush_events_internal(events_to_write) - - def record_event(self, type, data=None, sample_id=None): - if sample_id is None: - sample_id = self.current_sample_id() - if sample_id is None: - raise ValueError( - "No sample_id set! Either pass it in or use as_default_recorder!" 
- ) - - with self._event_lock: - event = Event( - run_id=self.run_spec.run_id, - event_id=len(self._events), - type=type, - sample_id=sample_id, - data=data, - created_by=self.run_spec.created_by, - created_at=str(datetime.now(timezone.utc)), - ) - self._events.append(event) - if ( - self._flushes_done < self._flushes_started - or len(self._events) < self._written_events + MIN_FLUSH_EVENTS - or time.time() < self._last_flush_time + MIN_FLUSH_SECONDS - ): - return - events_to_write = self._events[self._written_events :] - self._written_events = len(self._events) - self._flushes_started += 1 - self._flush_events_internal(events_to_write) - - def record_match( - self, correct: bool, *, expected=None, picked=None, sample_id=None, **extra - ): - assert isinstance( - correct, bool - ), f"correct must be a bool, but was a {type(correct)}: {correct}" - - if isinstance(expected, list) and len(expected) == 1: - expected = expected[0] - data = { - "correct": bool(correct), - "expected": expected, - "picked": picked, - **extra, - } - self.record_event("match", data, sample_id=sample_id) - - def record_embedding(self, prompt, embedding_type, sample_id=None, **extra): - data = { - "prompt": prompt, - "embedding_type": embedding_type, - **extra, - } - self.record_event("embedding", data, sample_id=sample_id) - - def record_sampling(self, prompt, sampled, sample_id=None, **extra): - data = { - "prompt": prompt, - "sampled": sampled, - **extra, - } - self.record_event("sampling", data, sample_id=sample_id) - - def record_cond_logp(self, prompt, completion, logp, sample_id=None, **extra): - data = { - "prompt": prompt, - "completion": completion, - "logp": logp, - **extra, - } - self.record_event("cond_logp", data, sample_id=sample_id) - - def record_pick_option(self, prompt, options, picked, sample_id=None, **extra): - data = { - "prompt": prompt, - "options": options, - "picked": picked, - **extra, - } - self.record_event("pick_option", data, sample_id=sample_id) - - def record_raw(self, data): - self.record_event("raw_sample", data) - - def record_metrics(self, **kwargs): - self.record_event("metrics", kwargs) - - def record_error(self, msg: str, error: Exception, **kwargs): - data = { - "type": type(error).__name__, - "message": str(error), - } - data.update(kwargs) - self.record_event("error", data) - - def record_extra(self, data, sample_id=None): - self.record_event("extra", data, sample_id=sample_id) - - def record_final_report(self, final_report: Any): - logging.info(f"Final report: {final_report}. Not writing anywhere.") - - -def _green(str): - return f"\033[1;32m{str}\033[0m" - - -def _red(str): - return f"\033[1;31m{str}\033[0m" - - -class DummyRecorder(RecorderBase): - """ - A "recorder" which only logs certain events to the console. - Can be used by passing `--dry-run` when invoking `oaieval`. 
- """ - - def __init__(self, run_spec: RunSpec, log: bool = True): - super().__init__(run_spec) - self.log = log - - def record_event(self, type, data, sample_id=None): - from evals.registry import registry - - if self.run_spec is None: - return - - base_eval_spec = registry.get_base_eval(self.run_spec.base_eval) - if base_eval_spec and len(base_eval_spec.metrics) >= 1: - primary_metric = base_eval_spec.metrics[0] - else: - primary_metric = "accuracy" - - with self._event_lock: - event = self._create_event(type, data) - self._events.append(event) - - msg = f"Not recording event: {event}" - - if type == "match": - accuracy_good = ( - primary_metric == "accuracy" or primary_metric.startswith("pass@") - ) and (data.get("correct", False) or data.get("accuracy", 0) > 0.5) - f1_score_good = ( - primary_metric == "f1_score" and data.get("f1_score", 0) > 0.5 - ) - if accuracy_good or f1_score_good: - msg = _green(msg) - else: - msg = _red(msg) - - if self.log: - logging.info(msg) - - -class LocalRecorder(RecorderBase): - """ - A recorder which logs events to the specified JSON file. - This is the default recorder used by `oaieval`. - """ - - def __init__(self, log_path: Optional[str], run_spec: RunSpec): - super().__init__(run_spec) - self.event_file_path = log_path - if log_path is not None: - with bf.BlobFile(log_path, "wb") as f: - f.write( - (jsondumps({"spec": dataclasses.asdict(run_spec)}) + "\n").encode( - "utf-8" - ) - ) - - def _flush_events_internal(self, events_to_write: Sequence[Event]): - start = time.time() - try: - lines = [jsondumps(event) + "\n" for event in events_to_write] - except TypeError as e: - logger.error(f"Failed to serialize events: {events_to_write}") - raise e - - with bf.BlobFile(self.event_file_path, "ab") as f: - f.write(b"".join([l.encode("utf-8") for l in lines])) - - logger.info( - f"Logged {len(lines)} rows of events to {self.event_file_path}: insert_time={t(time.time()-start)}" - ) - - self._last_flush_time = time.time() - self._flushes_done += 1 - - def record_final_report(self, final_report: Any): - with bf.BlobFile(self.event_file_path, "ab") as f: - f.write((jsondumps({"final_report": final_report}) + "\n").encode("utf-8")) - - logging.info(f"Final report: {final_report}. Logged to {self.event_file_path}") - - -class Recorder(RecorderBase): - """ - A recorder which logs events to Snowflake. - Can be used by passing `--no-local-run` when invoking `oaieval`. 
- """ - - def __init__( - self, - log_path: Optional[str], - run_spec: evals.base.RunSpec, - snowflake_connection: Optional[SnowflakeConnection] = None, - ) -> None: - super().__init__(run_spec) - self.event_file_path = log_path - self._writing_lock = threading.Lock() - - if snowflake_connection is None: - snowflake_connection = SnowflakeConnection() - self._conn = snowflake_connection - - if log_path is not None: - with bf.BlobFile(log_path, "wb") as f: - f.write( - (jsondumps({"spec": dataclasses.asdict(run_spec)}) + "\n").encode( - "utf-8" - ) - ) - - query = """ - INSERT ALL INTO runs (run_id, model_name, eval_name, base_eval, split, run_config, settings, created_by, created_at) - VALUES (%(run_id)s, %(model_name)s, %(eval_name)s, %(base_eval)s, %(split)s, run_config, settings, %(created_by)s, %(created_at)s) - SELECT PARSE_JSON(%(run_config)s) AS run_config, PARSE_JSON(%(settings)s) AS settings - """ - self._conn.robust_query( - command=query, - params={ - "run_id": run_spec.run_id, - "model_name": jsondumps(run_spec.model_names), - "eval_name": run_spec.eval_name, - "base_eval": run_spec.base_eval, - "split": run_spec.split, - "run_config": jsondumps(run_spec.run_config), - "settings": jsondumps(run_spec.run_config.get("initial_settings", {})), - "created_by": run_spec.created_by, - "created_at": run_spec.created_at, - }, - ) - atexit.register(self.flush_events) - - def _flush_events_internal(self, events_to_write: Sequence[Event]): - with self._writing_lock: - try: - lines = [jsondumps(event) + "\n" for event in events_to_write] - except TypeError as e: - logger.error(f"Failed to serialize events: {events_to_write}") - raise e - idx_l = 0 - while idx_l < len(events_to_write): - total_bytes = 0 - idx_r = idx_l - while ( - idx_r < len(events_to_write) - and total_bytes + len(lines[idx_r]) < MAX_SNOWFLAKE_BYTES - ): - total_bytes += len(lines[idx_r]) - idx_r += 1 - assert idx_r > idx_l - start = time.time() - buffer = [ - ( - event.run_id, - event.event_id, - event.sample_id, - event.type, - jsondumps(event.data), - event.created_by, - event.created_at, - ) - for event in events_to_write[idx_l:idx_r] - ] - query = """ - INSERT INTO events (run_id, event_id, sample_id, type, data, created_by, created_at) - SELECT Column1 AS run_id, Column2 as event_id, Column3 AS sample_id, Column4 AS type, PARSE_JSON(Column5) AS data, Column6 AS created_by, Column7 AS created_at - FROM VALUES(%s, %s, %s, %s, %s, %s, %s) - """ - self._conn.robust_query(command=query, seqparams=buffer, many=True) - logger.info( - f"Logged {len(buffer)} rows of events to Snowflake: insert_time={t(time.time()-start)}" - ) - idx_l = idx_r - - with bf.BlobFile(self.event_file_path, "ab") as f: - f.write(b"".join([l.encode("utf-8") for l in lines])) - self._last_flush_time = time.time() - self._flushes_done += 1 - - def record_final_report(self, final_report: Any): - with self._writing_lock: - with bf.BlobFile(self.event_file_path, "ab") as f: - f.write( - (jsondumps({"final_report": final_report}) + "\n").encode("utf-8") - ) - query = """ - UPDATE runs - SET final_report = PARSE_JSON(%(final_report)s) - WHERE run_id = %(run_id)s - """ - self._conn.robust_query( - command=query, - params={ - "run_id": self.run_spec.run_id, - "final_report": jsondumps(final_report), - }, - ) - - def record_event(self, type, data=None, sample_id=None): - # try to serialize data so we fail early! 
- _ = jsondumps(data) - return super().record_event(type, data, sample_id) - - -######################################################################### -### Helper methods which use the thread local global default recorder ### -######################################################################### - - -def current_sample_id() -> str: - return default_recorder().current_sample_id - - -def record_match(correct: bool, *, expected=None, picked=None, **extra): - return default_recorder().record_match( - correct, expected=expected, picked=picked, **extra - ) - - -def record_embedding(prompt, embedding_type, **extra): - return default_recorder().record_embedding(prompt, embedding_type, **extra) - - -def record_sampling(prompt, sampled, **extra): - return default_recorder().record_sampling(prompt, sampled, **extra) - - -def record_cond_logp(prompt, completion, logp, **extra): - return default_recorder().record_cond_logp(prompt, completion, logp, **extra) - - -def record_pick_option(prompt, options, picked, **extra): - return default_recorder().record_pick_option(prompt, options, picked, **extra) - - -def record_raw(data): - return default_recorder().record_raw(data) - - -def record_metrics(**extra): - return default_recorder().record_metrics(**extra) - - -def record_error(msg: str, error: Exception = None, **extra): - return default_recorder().record_error(msg, error, **extra) - - -def record_extra(data): - return default_recorder().record_extra(data) diff --git a/src/xturing/evalutaion/registry.py b/src/xturing/evalutaion/registry.py deleted file mode 100644 index e0ed543..0000000 --- a/src/xturing/evalutaion/registry.py +++ /dev/null @@ -1,179 +0,0 @@ -""" -Functions to handle registration of evals. To add a new eval to the registry, -add an entry in one of the YAML files in the `../registry` dir. -By convention, every eval name should start with {base_eval}.{split}. -""" - -import difflib -import functools -import logging -import os -import re -from functools import partial -from pathlib import Path -from typing import Any, Dict, Iterator, List, Sequence, Type, Union - -import yaml - -from .base import BaseEvalSpec, EvalSetSpec, EvalSpec -from .utils import make_object - -logger = logging.getLogger(__name__) - -DEFAULT_PATHS = [ - Path(__file__).parents[0].resolve() / "registry", - Path.home() / ".evals", -] - - -class Registry: - def __init__(self, registry_paths: Sequence[Union[str, Path]] = DEFAULT_PATHS): - self._registry_paths = [ - Path(p) if isinstance(p, str) else p for p in registry_paths - ] - - def make_callable(self, spec): - return partial(make_object(spec.cls).create_and_run, **(spec.args or {})) - - def get_class(self, spec: dict) -> Any: - return make_object(spec.cls, **(spec.args if spec.args else {})) - - def _dereference(self, name: str, d: dict, object: str, type: Type) -> dict: - if not name in d: - return None - - def get_alias(): - if isinstance(d[name], str): - return d[name] - if isinstance(d[name], dict) and "id" in d[name]: - return d[name]["id"] - return None - - logger.debug(f"Looking for {name}") - while True: - alias = get_alias() - - if alias is None: - break - name = alias - - spec = d[name] - - try: - return type(**spec) - except TypeError as e: - raise TypeError(f"Error while processing {object} {name}: {e}") - - def get_modelgraded_spec(self, name: str) -> Dict[str, Any]: - assert name in self._modelgraded_specs, ( - f"Modelgraded spec {name} not found. 
" - f"Closest matches: {difflib.get_close_matches(name, self._modelgraded_specs.keys(), n=5)}" - ) - return self._modelgraded_specs[name] - - def get_eval(self, name: str) -> EvalSpec: - return self._dereference(name, self._evals, "eval", EvalSpec) - - def get_eval_set(self, name: str) -> EvalSetSpec: - return self._dereference(name, self._eval_sets, "eval set", EvalSetSpec) - - def get_evals(self, patterns: Sequence[str]) -> Iterator[EvalSpec]: - # valid patterns: hello, hello.dev*, hello.dev.*-v1 - def get_regexp(pattern): - pattern = pattern.replace(".", "\\.") - pattern = pattern.replace("*", ".*") - return re.compile(f"^{pattern}$") - - regexps = list(map(get_regexp, patterns)) - for name in self._evals: - # if any regexps match, return the name - if any(map(lambda regexp: regexp.match(name), regexps)): - yield self.get_eval(name) - - def get_base_evals(self) -> List[BaseEvalSpec]: - base_evals = [] - for name, spec in self._evals.items(): - if name.count(".") == 0: - base_evals.append(self.get_base_eval(name)) - return base_evals - - def get_base_eval(self, name: str) -> BaseEvalSpec: - if not name in self._evals: - return None - - spec_or_alias = self._evals[name] - if isinstance(spec_or_alias, dict): - spec = spec_or_alias - try: - return BaseEvalSpec(**spec) - except TypeError as e: - raise TypeError(f"Error while processing base eval {name}: {e}") - - alias = spec_or_alias - return BaseEvalSpec(id=alias) - - def _process_file(self, registry, path): - with open(path, "r") as f: - d = yaml.safe_load(f) - - if d is None: - # no entries in the file - return - - for name, spec in d.items(): - assert name not in registry, f"duplicate entry: {name} from {path}" - if isinstance(spec, dict): - if "key" in spec: - raise ValueError( - f"key is a reserved keyword, but was used in {name} from {path}" - ) - if "group" in spec: - raise ValueError( - f"group is a reserved keyword, but was used in {name} from {path}" - ) - if "cls" in spec: - raise ValueError( - f"cls is a reserved keyword, but was used in {name} from {path}" - ) - - spec["key"] = name - spec["group"] = str(os.path.basename(path).split(".")[0]) - if "class" in spec: - spec["cls"] = spec["class"] - del spec["class"] - registry[name] = spec - - def _process_directory(self, registry, path): - files = Path(path).glob("*.yaml") - for file in files: - self._process_file(registry, file) - - def _load_registry(self, paths): - """Load registry from a list of paths. - - Each path or yaml specifies a dictionary of name -> spec. - """ - registry = {} - for path in paths: - logging.info(f"Loading registry from {path}") - if os.path.exists(path): - if os.path.isdir(path): - self._process_directory(registry, path) - else: - self._process_file(registry, path) - return registry - - @functools.cached_property - def _eval_sets(self): - return self._load_registry([p / "eval_sets" for p in self._registry_paths]) - - @functools.cached_property - def _evals(self): - return self._load_registry([p / "evals" for p in self._registry_paths]) - - @functools.cached_property - def _modelgraded_specs(self): - return self._load_registry([p / "modelgraded" for p in self._registry_paths]) - - -registry = Registry() diff --git a/src/xturing/evalutaion/utils.py b/src/xturing/evalutaion/utils.py deleted file mode 100644 index 11f7411..0000000 --- a/src/xturing/evalutaion/utils.py +++ /dev/null @@ -1,129 +0,0 @@ -# misc.py -""" -This file defines miscellanous utilities. 
-""" -import functools -import importlib -from typing import Any - - -def t(duration: float) -> str: - if duration is None: - return "n/a" - if duration < 1: - return f"{(1000*duration):0.3f}ms" - elif duration < 60: - return f"{duration:0.3f}s" - else: - return f"{duration//60}min{int(duration%60)}s" - - -def make_object(object_ref: Any, *args: Any, **kwargs: Any) -> Any: - modname, qualname_separator, qualname = object_ref.partition(":") - obj = importlib.import_module(modname) - if qualname_separator: - for attr in qualname.split("."): - obj = getattr(obj, attr) - return functools.partial(obj, *args, **kwargs) - - -# api_utils.py -""" -This file defines various helper functions for interacting with the OpenAI API. -""" -import logging - -import backoff -import openai - - -def generate_dummy_chat_completion(): - return { - "id": "dummy-id", - "object": "chat.completion", - "created": 12345, - "model": "dummy-chat", - "usage": {"prompt_tokens": 56, "completion_tokens": 6, "total_tokens": 62}, - "choices": [ - { - "message": { - "role": "assistant", - "content": "This is a dummy response.", - }, - "finish_reason": "stop", - "index": 0, - } - ], - } - - -def generate_dummy_completion(): - return { - "id": "dummy-id", - "object": "text_completion", - "created": 12345, - "model": "dummy-completion", - "choices": [ - { - "text": "This is a dummy response.", - "index": 0, - "logprobs": None, - "finish_reason": "stop", - } - ], - "usage": {"prompt_tokens": 5, "completion_tokens": 6, "total_tokens": 11}, - } - - -@backoff.on_exception( - wait_gen=backoff.expo, - exception=( - openai.error.ServiceUnavailableError, - openai.error.APIError, - openai.error.RateLimitError, - openai.error.APIConnectionError, - openai.error.Timeout, - ), - max_value=60, - factor=1.5, -) -def openai_completion_create_retrying(*args, **kwargs): - """ - Helper function for creating a completion. - `args` and `kwargs` match what is accepted by `openai.Completion.create`. - """ - if kwargs["model"] == "dummy-completion": - return generate_dummy_completion() - - result = openai.Completion.create(*args, **kwargs) - if "error" in result: - logging.warning(result) - raise openai.error.APIError(result["error"]) - return result - - -@backoff.on_exception( - wait_gen=backoff.expo, - exception=( - openai.error.ServiceUnavailableError, - openai.error.APIError, - openai.error.RateLimitError, - openai.error.APIConnectionError, - openai.error.Timeout, - ), - max_value=60, - factor=1.5, -) -def openai_chat_completion_create_retrying(*args, **kwargs): - """ - Helper function for creating a chat completion. - `args` and `kwargs` match what is accepted by `openai.ChatCompletion.create`. 
- """ - if kwargs["model"] == "dummy-chat": - return generate_dummy_chat_completion() - - result = openai.ChatCompletion.create(*args, **kwargs) - if "error" in result: - logging.warning(result) - raise openai.error.APIError(result["error"]) - return result diff --git a/src/xturing/models/causal.py b/src/xturing/models/causal.py index e8e4cee..5d20008 100644 --- a/src/xturing/models/causal.py +++ b/src/xturing/models/causal.py @@ -223,8 +223,8 @@ def _loglikelihood_tokens( ) -> List[Tuple[float, bool]]: results = [] for chunk in tqdm(data_iterator, disable=disable_tqdm): - del input_tokens["label_masks"], input_tokens["targets"] input_tokens = chunk.to(DEFAULT_DEVICE) + del input_tokens["label_masks"], input_tokens["targets"] outputs = self._model_call(inputs=input_tokens, labels=input_tokens) results.append(outputs.loss) return results @@ -308,19 +308,23 @@ def eval_sample(args): ) return [r for _, r in sorted(idx_and_result)] - def evaluate(self, dataset: Union[TextDataset, InstructionDataset]): + def evaluate( + self, + dataset: Union[TextDataset, InstructionDataset], + batch_size: Optional[int] = 1, + ): # outputs = self.eval_all_samples(dataset) + # return get_accuracy(outputs) collate_fn = self._make_collate_fn(dataset) dataloader = DataLoader( dataset, - batch_size=1, + batch_size=batch_size, shuffle=False, drop_last=False, collate_fn=collate_fn, ) results = self._loglikelihood_tokens(dataloader) - return torch.exp(torch.stack(results).mean()) - # return get_accuracy(outputs) + return torch.exp(torch.stack(results).sum() / len(dataset)) class CausalInt8Model(CausalModel): From 0f0248c28f780a840a5907e2eeb7adae7bca7a50 Mon Sep 17 00:00:00 2001 From: Tushar Date: Mon, 17 Jul 2023 12:31:55 +0000 Subject: [PATCH 10/18] fix: loading llama in cpu the dtype was hardcoded to fp16 in llama engine file which is now being rendered dynamically --- src/xturing/engines/causal.py | 5 ++- src/xturing/engines/llama_engine.py | 54 ++++------------------------- 2 files changed, 11 insertions(+), 48 deletions(-) diff --git a/src/xturing/engines/causal.py b/src/xturing/engines/causal.py index 7944463..6003231 100644 --- a/src/xturing/engines/causal.py +++ b/src/xturing/engines/causal.py @@ -35,7 +35,9 @@ def __init__( **kwargs, ): self.model_name = model_name - + print(weights_path) + print(model) + print(model_name) if weights_path is not None: assert Path( weights_path @@ -60,6 +62,7 @@ def __init__( self.model = model self.tokenizer = tokenizer elif model_name is not None: + print("here") if load_8bit: device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} self.model = AutoModelForCausalLM.from_pretrained( diff --git a/src/xturing/engines/llama_engine.py b/src/xturing/engines/llama_engine.py index 35669e6..d74ccc0 100644 --- a/src/xturing/engines/llama_engine.py +++ b/src/xturing/engines/llama_engine.py @@ -1,20 +1,13 @@ -import argparse import os from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Optional, Union -import torch -import transformers from torch import nn -from xturing.config.config_data_classes import FinetuningConfig, GenerationConfig -from xturing.config.read_config import load_config, read_yaml +from xturing.config import DEFAULT_DTYPE from xturing.engines.causal import CausalEngine, CausalLoraEngine, CausalLoraKbitEngine -from xturing.engines.llama_utils import LlamaConfig, LlamaForCausalLM, LlamaTokenizer +from xturing.engines.llama_utils import LlamaForCausalLM, LlamaTokenizer from xturing.engines.lora_engine import 
prepare_model_for_int8_training -from xturing.engines.quant_utils import autotune_warmup, make_quant -from xturing.engines.quant_utils.lrec import get_c4, prepare_models, train_model -from xturing.utils.hub import ModelHub class LLamaEngine(CausalEngine): @@ -22,7 +15,7 @@ class LLamaEngine(CausalEngine): def __init__(self, weights_path: Optional[Union[str, Path]] = None): model_name = "aleksickx/llama-7b-hf" - model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16) + model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype=DEFAULT_DTYPE) tokenizer = LlamaTokenizer.from_pretrained(model_name, add_bos_token=False) tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token_id = tokenizer.eos_token_id @@ -41,7 +34,7 @@ def __init__(self, weights_path: Optional[Union[str, Path]] = None): model_name = "aleksickx/llama-7b-hf" model = LlamaForCausalLM.from_pretrained( model_name, - torch_dtype=torch.float16, + torch_dtype=DEFAULT_DTYPE, ) tokenizer = LlamaTokenizer.from_pretrained(model_name, add_bos_token=False) tokenizer.pad_token = tokenizer.eos_token @@ -63,7 +56,7 @@ def __init__(self, weights_path: Optional[Union[str, Path]] = None): device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} model = LlamaForCausalLM.from_pretrained( model_name, - torch_dtype=torch.float16, + torch_dtype=DEFAULT_DTYPE, load_in_8bit=True, device_map=device_map, ) @@ -89,7 +82,7 @@ def __init__(self, weights_path: Optional[Union[str, Path]] = None): device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} model = LlamaForCausalLM.from_pretrained( model_name, - torch_dtype=torch.float16, + torch_dtype=DEFAULT_DTYPE, load_in_8bit=True, device_map=device_map, ) @@ -126,39 +119,6 @@ class LlamaLoraKbitEngine(CausalLoraKbitEngine): def __init__(self, weights_path: Optional[Union[str, Path]] = None): model_name = "decapoda-research/llama-7b-hf" - # lrec_config = { - # "base_model": model_name, - # "intq_checkpoint": str( - # Path(__file__).parent / "llama7b-2bit-128g.pt" - # ), ## how to do this - # "wbits": wbits, - # "lora_target_modules": [ - # "q_proj", - # "v_proj", - # "k_proj", - # "o_proj", - # "up_proj", - # "down_proj", - # "gate_proj", - # ], - # # "n_samples": 100, - # # "train_cache_dir": "./train_cache/", - # # "val_cache_dir": "./val_cache/", - # # "ckpt_dir": "./ckpts/", - # # "save_dir": "./save/", - # } - - # # Finetuning config - # yml_content = read_yaml( - # Path(__file__).parent.parent / "config" / "finetuning_config.yaml", - # ) - # lrec_config.update(yml_content["defaults"]) - # lrec_config.update(yml_content[self.config_name.replace("_engine", "")]) - - # model, fp_model = prepare_models(argparse.Namespace(**lrec_config)) - - # # The model before applying LoRA - # self.base_model = fp_model tokenizer = LlamaTokenizer.from_pretrained(model_name, add_bos_token=False) tokenizer.pad_token = tokenizer.eos_token From ec27e7e0bb515c15419f7181df2923a160b03cde Mon Sep 17 00:00:00 2001 From: Tushar Date: Tue, 18 Jul 2023 14:27:47 +0000 Subject: [PATCH 11/18] feat: batch in model generation added batch parameter to generate function and altered its functionality --- src/xturing/engines/causal.py | 4 ---- src/xturing/models/causal.py | 13 +++++++------ 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/src/xturing/engines/causal.py b/src/xturing/engines/causal.py index 6003231..64d068e 100644 --- a/src/xturing/engines/causal.py +++ b/src/xturing/engines/causal.py @@ -35,9 +35,6 @@ def __init__( **kwargs, ): self.model_name = model_name - print(weights_path) 
- print(model) - print(model_name) if weights_path is not None: assert Path( weights_path @@ -62,7 +59,6 @@ def __init__( self.model = model self.tokenizer = tokenizer elif model_name is not None: - print("here") if load_8bit: device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} self.model = AutoModelForCausalLM.from_pretrained( diff --git a/src/xturing/models/causal.py b/src/xturing/models/causal.py index 5d20008..5168fdf 100644 --- a/src/xturing/models/causal.py +++ b/src/xturing/models/causal.py @@ -135,7 +135,7 @@ def _generate_from_iterable( else: enumeration = enumerate(data_iterator) - for i, batch in enumeration: + for _, batch in enumeration: if do_tokenization: inputs = self.engine.tokenizer(batch, return_tensors="pt") input_ids = inputs.input_ids.to(DEFAULT_DEVICE) @@ -148,11 +148,11 @@ def _generate_from_iterable( input_ids=input_ids, **self.generation_args.dict() ) - output = self.engine.tokenizer.decode( - output[0][len_input:], skip_special_tokens=True + output = self.engine.tokenizer.batch_decode( + torch.stack([output[i][len_input:] for i in range(output.shape[0])]), + skip_special_tokens=True, ) - outputs.append(output) - + outputs.extend(output) return outputs def generate( @@ -160,6 +160,7 @@ def generate( *, texts: Optional[Union[List[str], str]] = None, dataset: Optional[Union[TextDataset, InstructionDataset]] = None, + batch_size: Optional[int] = 1, ): self.engine.model.eval() self.engine.model = self.engine.model.to(DEFAULT_DEVICE) @@ -179,7 +180,7 @@ def generate( collate_fn = self._make_collate_fn(dataset) dataloader = DataLoader( dataset, - batch_size=1, + batch_size=batch_size, shuffle=False, drop_last=False, collate_fn=collate_fn, From 069da0576a36d59ba4c0ceadc96549382dd98fdf Mon Sep 17 00:00:00 2001 From: Tushar Date: Thu, 20 Jul 2023 11:18:38 +0000 Subject: [PATCH 12/18] docs: updated models supported edited intro.md in docs/ --- docs/docs/intro.md | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/docs/docs/intro.md b/docs/docs/intro.md index 9d880d8..ce83d7b 100644 --- a/docs/docs/intro.md +++ b/docs/docs/intro.md @@ -39,13 +39,15 @@ You can quickly get started with xTuring by following the [Quickstart](/quicksta | Model | Examples | | --- | --- | -| LLaMA | [LLaMA 7B fine-tuning on Alpaca dataset with/without LoRA and with/without INT8](https://github.com/stochasticai/xturing/tree/main/examples/llama) | +| Bloom | [Bloom fine-tuning on Alpaca dataset with/without LoRA and with/without INT8](https://github.com/stochasticai/xturing/tree/main/examples/bloom) | +| Cerebras-GPT | [Cerebras-GPT fine-tuning on Alpaca dataset with/without LoRA and with/without INT8](https://github.com/stochasticai/xturing/tree/main/examples/cerebras) | +| Falcon | [Falcon 7B fine-tuning on Alpaca dataset with/without LoRA and with/without INT8](https://github.com/stochasticai/xturing/tree/main/examples/falcon) | +| Galactica | [Galactica fine-tuning on Alpaca dataset with/without LoRA and with/without INT8](https://github.com/stochasticai/xturing/tree/main/examples/galactica) | +| Generic Wrapper | [Any large language model fine-tuning on Alpaca dataset with/without LoRA and with/without INT8](https://github.com/stochasticai/xturing/tree/main/examples/generic) | | GPT-J | [GPT-J 6B LoRA fine-tuning with/without INT8 ](https://github.com/stochasticai/xturing/tree/main/examples/gptj) | | GPT-2 | [GPT-2 fine-tuning on Alpaca dataset with/without LoRA and with/without INT8](https://github.com/stochasticai/xturing/tree/main/examples/gpt2) | +| LLaMA | 
[LLaMA 7B fine-tuning on Alpaca dataset with/without LoRA and with/without INT8](https://github.com/stochasticai/xturing/tree/main/examples/llama) | | OPT | [OPT fine-tuning on Alpaca dataset with/without LoRA and with/without INT8](https://github.com/stochasticai/xturing/tree/main/examples/opt) | -| Cerebras-GPT | [Cerebras-GPT fine-tuning on Alpaca dataset with/without LoRA and with/without INT8](https://github.com/stochasticai/xturing/tree/main/examples/cerebras) | -| Galactica | [Galactica fine-tuning on Alpaca dataset with/without LoRA and with/without INT8](https://github.com/stochasticai/xturing/tree/main/examples/galactica) | -| Bloom | [Bloom fine-tuning on Alpaca dataset with/without LoRA and with/without INT8](https://github.com/stochasticai/xturing/tree/main/examples/bloom) | xTuring is licensed under [Apache 2.0](https://github.com/stochasticai/xturing/blob/main/LICENSE) From 91d41bdffd58559af52f8923ca2a581853ccfb02 Mon Sep 17 00:00:00 2001 From: Tushar Date: Fri, 21 Jul 2023 13:35:29 +0000 Subject: [PATCH 13/18] docs: Updated the readme with roadmap and whats new Updated README.md for the next release --- README.md | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 6d1a383..687030e 100644 --- a/README.md +++ b/README.md @@ -35,10 +35,14 @@ With `xTuring` you can, ## 🌟 What's new? We are excited to announce the latest enhancements to our `xTuring` library: -1. __`Falcon LLM` integration__ - You can use and fine-tune the _`Falcon-7B`_ model in different configurations: _off-the-shelf_, _off-the-shelf with INT8 precision_, _LoRA fine-tuning_, and _LoRA fine-tuning with INT8 precision_. -2. __`GenericModel` wrapper__ - This new integration allows you to test and fine-tune any new model on `xTuring` without waiting for it to be integrated using class _`GenericModel`_. +1. __`LLaMA 2` integration__ - You can use and fine-tune the _`LLaMA 2`_ model in different configurations: _off-the-shelf_, _off-the-shelf with INT8 precision_, _LoRA fine-tuning_, _LoRA fine-tuning with INT8 precision_ and _LoRA fine-tuning with INT4 precision_ using the `GenericModel` wrapper. +2. __`Evaluation`__ - Now you can evaluate any `Causal Language Model` on any dataset. The metrics currently supported is `perplexity`. +3. __`INT4` Precision__ - You can now use and fine-tune any LLM with `INT4 Precision` using `GenericKbitModel`. +4. __CPU inference__ - Now you can use just your CPU for inference of any LLM. _CAUTION : The inference will be very slow as CPUs are extremely slow for the amount of computation needed for inference_. +5. __Batch integration__ - Now you play around with `batch_size` in `.generate()` and `.evaluate()` functions. This will lead to faster results with `batch_size>1`. + +You can check the [Llama LoRA INT4 working example](examples/int4_finetuning/LLaMA_lora_int4.ipynb) file to see how it works. -You can check the [Falcon LoRA INT8 working example](examples/falcon/falcon_lora_int8.py) repository to see how it works. Also, you can check the [GenericModel working example](examples/generic/generic_model.py) repository to see how it works.
@@ -170,8 +174,8 @@ model = BaseModel.load("x/distilgpt2_lora_finetuned_alpaca") - [x] INT4 LLaMA LoRA fine-tuning with INT4 generation - [x] Support for a `Generic model` wrapper - [x] Support for `Falcon-7B` model -- [ ] Evaluation of LLM models -- [ ] INT4 low-precision fine-tuning support +- [x] INT4 low-precision fine-tuning support +- [x] Evaluation of LLM models - [ ] INT3, INT2, INT1 low-precision fine-tuning support - [ ] Support for Stable Diffusion From c573fb92b72b856da477bb5718ae9d23830cd6b7 Mon Sep 17 00:00:00 2001 From: Tushar Date: Tue, 25 Jul 2023 11:53:40 +0000 Subject: [PATCH 14/18] feat: Added llama2 class Added llama2 class, updated the supported models' section in documentation, Updated README.md and fixed the int4 example --- .pre-commit-config.yaml | 2 +- README.md | 60 ++++++++++++++++++- docs/docs/intro.md | 1 + examples/evaluation/evaluation.py | 15 +++++ .../int4_finetuning/LLaMA_lora_int4.ipynb | 7 +-- examples/llama2/llama2.py | 21 +++++++ src/xturing/config/finetuning_config.yaml | 7 +++ src/xturing/config/generation_config.yaml | 7 +++ src/xturing/engines/__init__.py | 2 + src/xturing/engines/llama2_engine.py | 18 ++++++ src/xturing/models/__init__.py | 2 + src/xturing/models/llama2.py | 11 ++++ 12 files changed, 146 insertions(+), 7 deletions(-) create mode 100644 examples/evaluation/evaluation.py create mode 100644 examples/llama2/llama2.py create mode 100644 src/xturing/engines/llama2_engine.py create mode 100644 src/xturing/models/llama2.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 64bee29..bae902a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,4 +30,4 @@ repos: rev: v0.3.1 hooks: - id: absolufy-imports - args: ["--application-directories=src"] + args: ["--application-directories=.:src"] diff --git a/README.md b/README.md index 687030e..638056a 100644 --- a/README.md +++ b/README.md @@ -35,11 +35,67 @@ With `xTuring` you can, ## 🌟 What's new? We are excited to announce the latest enhancements to our `xTuring` library: -1. __`LLaMA 2` integration__ - You can use and fine-tune the _`LLaMA 2`_ model in different configurations: _off-the-shelf_, _off-the-shelf with INT8 precision_, _LoRA fine-tuning_, _LoRA fine-tuning with INT8 precision_ and _LoRA fine-tuning with INT4 precision_ using the `GenericModel` wrapper. -2. __`Evaluation`__ - Now you can evaluate any `Causal Language Model` on any dataset. The metrics currently supported is `perplexity`. +1. __`LLaMA 2` integration__ - You can use and fine-tune the _`LLaMA 2`_ model in different configurations: _off-the-shelf_, _off-the-shelf with INT8 precision_, _LoRA fine-tuning_, _LoRA fine-tuning with INT8 precision_ and _LoRA fine-tuning with INT4 precision_ using the `GenericModel` wrapper and/or you can use the `Llama2` class from `xturing.models` to test and finetune the model. +```python +from xturing.models import Llama2 +model = Llama2() + +## or +from xturing.models import BaseModel +model = BaseModel.create('llama2') + +``` +2. __`Evaluation`__ - Now you can evaluate any `Causal Language Model` on any dataset. The metrics currently supported is [`perplexity`](https://towardsdatascience.com/perplexity-in-language-models-87a196019a94). 
+```python +# Make the necessary imports +from xturing.datasets import InstructionDataset +from xturing.models import BaseModel + +# Load the desired dataset +dataset = InstructionDataset('../llama/alpaca_data') + +# Load the desired model +model = BaseModel.create('gpt2') + +# Run the Evaluation of the model on the dataset +result = model.evaluate(dataset) + +# Print the result +print(f"Perplexity of the evalution: {result}") + +``` 3. __`INT4` Precision__ - You can now use and fine-tune any LLM with `INT4 Precision` using `GenericKbitModel`. +```python +# Make the necessary imports +from xturing.datasets import InstructionDataset +from xturing.models import GenericKbitModel + +# Load the desired dataset +dataset = InstructionDataset('../llama/alpaca_data') + +# Load the desired model for INT4 bit fine-tuning +model = GenericKbitModel('tiiuae/falcon-7b') + +# Run the fine-tuning +model.finetune(dataset) +``` 4. __CPU inference__ - Now you can use just your CPU for inference of any LLM. _CAUTION : The inference will be very slow as CPUs are extremely slow for the amount of computation needed for inference_. 5. __Batch integration__ - Now you play around with `batch_size` in `.generate()` and `.evaluate()` functions. This will lead to faster results with `batch_size>1`. +```python +# Make the necessary imports +from xturing.datasets import InstructionDataset +from xturing.models import GenericKbitModel + +# Load the desired dataset +dataset = InstructionDataset('../llama/alpaca_data') + +# Load the desired model for INT4 bit fine-tuning +model = GenericKbitModel('tiiuae/falcon-7b') + +# Generate outputs on desired prompts +outputs = model.generate(dataset = dataset, batch_size=10) + +``` You can check the [Llama LoRA INT4 working example](examples/int4_finetuning/LLaMA_lora_int4.ipynb) file to see how it works. 
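For completeness, the perplexity evaluation introduced above can be combined with the new batching support, which the README describes but does not demonstrate for `.evaluate()`. The snippet below is a minimal sketch rather than part of any patch in this series: it assumes `.evaluate()` accepts the same `batch_size` keyword that `.generate()` gained in PATCH 11, and that an Alpaca-format dataset is available at the hypothetical `../llama/alpaca_data` path used in the other examples.

```python
# Minimal sketch: batched perplexity evaluation.
# Assumes .evaluate() accepts a batch_size keyword (as the README text states)
# and that ../llama/alpaca_data is an Alpaca-format dataset (hypothetical path).
from xturing.datasets import InstructionDataset
from xturing.models import BaseModel

dataset = InstructionDataset("../llama/alpaca_data")
model = BaseModel.create("gpt2")

# batch_size > 1 trades memory for throughput during evaluation.
perplexity = model.evaluate(dataset, batch_size=8)
print(f"Perplexity of the evaluation: {perplexity}")
```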
diff --git a/docs/docs/intro.md b/docs/docs/intro.md index ce83d7b..2acba92 100644 --- a/docs/docs/intro.md +++ b/docs/docs/intro.md @@ -47,6 +47,7 @@ You can quickly get started with xTuring by following the [Quickstart](/quicksta | GPT-J | [GPT-J 6B LoRA fine-tuning with/without INT8 ](https://github.com/stochasticai/xturing/tree/main/examples/gptj) | | GPT-2 | [GPT-2 fine-tuning on Alpaca dataset with/without LoRA and with/without INT8](https://github.com/stochasticai/xturing/tree/main/examples/gpt2) | | LLaMA | [LLaMA 7B fine-tuning on Alpaca dataset with/without LoRA and with/without INT8](https://github.com/stochasticai/xturing/tree/main/examples/llama) | +| LLaMA 2 | [LLaMA 2 7B fine-tuning on Alpaca dataset with/without LoRA and with/without INT8](https://github.com/stochasticai/xturing/tree/main/examples/llama2) | | OPT | [OPT fine-tuning on Alpaca dataset with/without LoRA and with/without INT8](https://github.com/stochasticai/xturing/tree/main/examples/opt) | xTuring is licensed under [Apache 2.0](https://github.com/stochasticai/xturing/blob/main/LICENSE) diff --git a/examples/evaluation/evaluation.py b/examples/evaluation/evaluation.py new file mode 100644 index 0000000..9d1862b --- /dev/null +++ b/examples/evaluation/evaluation.py @@ -0,0 +1,15 @@ +# Make the necessary imports +from xturing.datasets import InstructionDataset +from xturing.models import BaseModel + +# Load the desired dataset +dataset = InstructionDataset("../llama/alpaca_data") + +# Load the desired model +model = BaseModel.create("gpt2") + +# Run the Evaluation of the model on the dataset +result = model.evaluate(dataset) + +# Print the result +print(f"Perplexity of the evalution: {result}") diff --git a/examples/int4_finetuning/LLaMA_lora_int4.ipynb b/examples/int4_finetuning/LLaMA_lora_int4.ipynb index f331692..ffa29ff 100644 --- a/examples/int4_finetuning/LLaMA_lora_int4.ipynb +++ b/examples/int4_finetuning/LLaMA_lora_int4.ipynb @@ -31,8 +31,7 @@ }, "outputs": [], "source": [ - "!pip install xturing --upgrade\n", - "!pip install xturing[int4] --upgrade" + "!pip install xturing --upgrade" ] }, { @@ -56,7 +55,7 @@ "outputs": [], "source": [ "from xturing.datasets.instruction_dataset import InstructionDataset\n", - "from xturing.models import BaseModel\n", + "from xturing.models import GenericLoraKbitModel\n", "from pytorch_lightning.loggers import WandbLogger\n", "\n", "# Initializes WandB integration \n", @@ -64,7 +63,7 @@ "\n", "instruction_dataset = InstructionDataset(\"../llama/alpaca_data\")\n", "# Initializes the model\n", - "model = BaseModel.create(\"llama_lora_int4\")" + "model = GenericLoraKbitModel('aleksickx/llama-7b-hf')" ] }, { diff --git a/examples/llama2/llama2.py b/examples/llama2/llama2.py new file mode 100644 index 0000000..30df48f --- /dev/null +++ b/examples/llama2/llama2.py @@ -0,0 +1,21 @@ +# Make the necessary imports +from xturing.models import Llama2 + +# Load the model +model = Llama2() +# Generate ouputs from the model +outputs = model.generate(texts=["How are you?"]) +# Print the generated outputs +print(outputs) + +## or + +# Make the necessary imports +from xturing.models import BaseModel + +# Load the model +model = BaseModel.create("llama2") +# Generate ouputs from the model +outputs = model.generate(texts=["How are you?"]) +# Print the generated outputs +print(outputs) diff --git a/src/xturing/config/finetuning_config.yaml b/src/xturing/config/finetuning_config.yaml index a3c5d5d..47d491b 100644 --- a/src/xturing/config/finetuning_config.yaml +++ 
b/src/xturing/config/finetuning_config.yaml @@ -193,6 +193,7 @@ llama: num_train_epochs: 3 optimizer_name: cpu_adam + llama_lora: learning_rate: 1e-4 weight_decay: 0.01 @@ -227,6 +228,12 @@ llama_lora_kbit: intra_save_freq: 200 groupsize: 128 +llama2: + learning_rate: 5e-5 + weight_decay: 0.01 + num_train_epochs: 3 + optimizer_name: cpu_adam + opt: learning_rate: 5e-5 weight_decay: 0.01 diff --git a/src/xturing/config/generation_config.yaml b/src/xturing/config/generation_config.yaml index 8163bb7..38dea3c 100644 --- a/src/xturing/config/generation_config.yaml +++ b/src/xturing/config/generation_config.yaml @@ -191,6 +191,13 @@ llama_lora_kbit: max_new_tokens: 256 do_sample: false +# Contrastive search +llama2: + penalty_alpha: 0.6 + top_k: 4 + max_new_tokens: 256 + do_sample: false + # Contrastive search opt: penalty_alpha: 0.6 diff --git a/src/xturing/engines/__init__.py b/src/xturing/engines/__init__.py index b034200..701db90 100644 --- a/src/xturing/engines/__init__.py +++ b/src/xturing/engines/__init__.py @@ -44,6 +44,7 @@ GPTJLoraEngine, GPTJLoraInt8Engine, ) +from xturing.engines.llama2_engine import LLama2Engine from xturing.engines.llama_engine import ( LLamaEngine, LLamaInt8Engine, @@ -95,6 +96,7 @@ BaseEngine.add_to_registry(LlamaLoraEngine.config_name, LlamaLoraEngine) BaseEngine.add_to_registry(LlamaLoraInt8Engine.config_name, LlamaLoraInt8Engine) BaseEngine.add_to_registry(LlamaLoraKbitEngine.config_name, LlamaLoraKbitEngine) +BaseEngine.add_to_registry(LLama2Engine.config_name, LLama2Engine) BaseEngine.add_to_registry(OPTEngine.config_name, OPTEngine) BaseEngine.add_to_registry(OPTInt8Engine.config_name, OPTInt8Engine) BaseEngine.add_to_registry(OPTLoraEngine.config_name, OPTLoraEngine) diff --git a/src/xturing/engines/llama2_engine.py b/src/xturing/engines/llama2_engine.py new file mode 100644 index 0000000..851966d --- /dev/null +++ b/src/xturing/engines/llama2_engine.py @@ -0,0 +1,18 @@ +from pathlib import Path +from typing import Optional, Union + +from xturing.engines.causal import CausalEngine + + +class LLama2Engine(CausalEngine): + config_name: str = "llama2_engine" + + def __init__(self, weights_path: Optional[Union[str, Path]] = None): + super().__init__( + model_name="daryl149/llama-2-7b-chat-hf", + weights_path=weights_path, + trust_remote_code=True, + ) + + self.tokenizer.pad_token = self.tokenizer.eos_token + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id diff --git a/src/xturing/models/__init__.py b/src/xturing/models/__init__.py index 496e22a..5be4f12 100644 --- a/src/xturing/models/__init__.py +++ b/src/xturing/models/__init__.py @@ -36,6 +36,7 @@ LlamaLoraInt8, LlamaLoraKbit, ) +from xturing.models.llama2 import Llama2 from xturing.models.opt import OPT, OPTInt8, OPTLora, OPTLoraInt8 from xturing.models.stable_diffusion import StableDiffusion @@ -76,6 +77,7 @@ BaseModel.add_to_registry(LlamaLora.config_name, LlamaLora) BaseModel.add_to_registry(LlamaLoraInt8.config_name, LlamaLoraInt8) BaseModel.add_to_registry(LlamaLoraKbit.config_name, LlamaLoraKbit) +BaseModel.add_to_registry(Llama2.config_name, Llama2) BaseModel.add_to_registry(OPT.config_name, OPT) BaseModel.add_to_registry(OPTInt8.config_name, OPTInt8) BaseModel.add_to_registry(OPTLora.config_name, OPTLora) diff --git a/src/xturing/models/llama2.py b/src/xturing/models/llama2.py new file mode 100644 index 0000000..7de0b35 --- /dev/null +++ b/src/xturing/models/llama2.py @@ -0,0 +1,11 @@ +from typing import Optional + +from xturing.engines.llama_engine import LLama2Engine +from 
xturing.models.causal import CausalModel + + +class Llama2(CausalModel): + config_name: str = "llama2" + + def __init__(self, weights_path: Optional[str] = None): + super().__init__(LLama2Engine.config_name, weights_path) From e6ecabc2bb007830561afd79da9d54db24146648 Mon Sep 17 00:00:00 2001 From: Tushar Date: Tue, 25 Jul 2023 12:22:29 +0000 Subject: [PATCH 15/18] feat: added all variations all variations added --- src/xturing/config/finetuning_config.yaml | 24 ++++++++ src/xturing/config/generation_config.yaml | 22 +++++++ src/xturing/engines/llama2_engine.py | 70 ++++++++++++++++++++++- src/xturing/models/llama2.py | 44 +++++++++++++- 4 files changed, 157 insertions(+), 3 deletions(-) diff --git a/src/xturing/config/finetuning_config.yaml b/src/xturing/config/finetuning_config.yaml index 47d491b..3f12670 100644 --- a/src/xturing/config/finetuning_config.yaml +++ b/src/xturing/config/finetuning_config.yaml @@ -234,6 +234,30 @@ llama2: num_train_epochs: 3 optimizer_name: cpu_adam +llama2_lora: + learning_rate: 5e-5 + weight_decay: 0.01 + num_train_epochs: 3 + optimizer_name: cpu_adam + +llama2_lora_int8: + learning_rate: 5e-5 + weight_decay: 0.01 + num_train_epochs: 3 + optimizer_name: cpu_adam + +llama2_int8: + learning_rate: 5e-5 + weight_decay: 0.01 + num_train_epochs: 3 + optimizer_name: cpu_adam + +llama2_lora_kbit: + learning_rate: 5e-5 + weight_decay: 0.01 + num_train_epochs: 3 + optimizer_name: cpu_adam + opt: learning_rate: 5e-5 weight_decay: 0.01 diff --git a/src/xturing/config/generation_config.yaml b/src/xturing/config/generation_config.yaml index 38dea3c..f844cf0 100644 --- a/src/xturing/config/generation_config.yaml +++ b/src/xturing/config/generation_config.yaml @@ -198,6 +198,28 @@ llama2: max_new_tokens: 256 do_sample: false +# Contrastive search +llama2_lora: + penalty_alpha: 0.6 + top_k: 4 + max_new_tokens: 256 + do_sample: false + +# Greedy search +llama2_int8: + max_new_tokens: 256 + do_sample: false + +# Greedy search +llama2_lora_int8: + max_new_tokens: 256 + do_sample: false + +# Greedy search +llama2_lora_kbit: + max_new_tokens: 256 + do_sample: false + # Contrastive search opt: penalty_alpha: 0.6 diff --git a/src/xturing/engines/llama2_engine.py b/src/xturing/engines/llama2_engine.py index 851966d..78dea48 100644 --- a/src/xturing/engines/llama2_engine.py +++ b/src/xturing/engines/llama2_engine.py @@ -1,7 +1,7 @@ from pathlib import Path from typing import Optional, Union -from xturing.engines.causal import CausalEngine +from xturing.engines.causal import CausalEngine, CausalLoraEngine, CausalLoraKbitEngine class LLama2Engine(CausalEngine): @@ -16,3 +16,71 @@ def __init__(self, weights_path: Optional[Union[str, Path]] = None): self.tokenizer.pad_token = self.tokenizer.eos_token self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + + +class LLama2LoraEngine(CausalLoraEngine): + config_name: str = "llama2_lora_engine" + + def __init__(self, weights_path: Optional[Union[str, Path]] = None): + super().__init__( + model_name="daryl149/llama-2-7b-chat-hf", + weights_path=weights_path, + target_modules=[ + "q_proj", + "v_proj", + ], + trust_remote_code=True, + ) + + self.tokenizer.pad_token = self.tokenizer.eos_token + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + + +class LLama2Int8Engine(CausalEngine): + config_name: str = "llama2_int8_engine" + + def __init__(self, weights_path: Optional[Union[str, Path]] = None): + super().__init__( + model_name="daryl149/llama-2-7b-chat-hf", + weights_path=weights_path, + load_8bit=True, + trust_remote_code=True, 
+ ) + self.tokenizer.pad_token = self.tokenizer.eos_token + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + + +class LLama2LoraInt8Engine(CausalLoraEngine): + config_name: str = "llama2_lora_int8_engine" + + def __init__(self, weights_path: Optional[Union[str, Path]] = None): + super().__init__( + model_name="daryl149/llama-2-7b-chat-hf", + weights_path=weights_path, + load_8bit=True, + target_modules=[ + "q_proj", + "v_proj", + ], + trust_remote_code=True, + ) + + self.tokenizer.pad_token = self.tokenizer.eos_token + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + + +class LLama2LoraKbitEngine(CausalLoraKbitEngine): + config_name: str = "llama2_lora_kbit_engine" + + def __init__(self, weights_path: Optional[Union[str, Path]] = None): + model_name = "daryl149/llama-2-7b-chat-hf" + super().__init__( + model_name=model_name, + weights_path=None, + target_modules=["q_proj", "v_proj"], + trust_remote_code=True, + load_4bit=True, + ) + + self.tokenizer.pad_token = self.tokenizer.eos_token + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id diff --git a/src/xturing/models/llama2.py b/src/xturing/models/llama2.py index 7de0b35..202dda1 100644 --- a/src/xturing/models/llama2.py +++ b/src/xturing/models/llama2.py @@ -1,7 +1,19 @@ from typing import Optional -from xturing.engines.llama_engine import LLama2Engine -from xturing.models.causal import CausalModel +from xturing.engines.llama2_engine import ( + LLama2Engine, + LLama2Int8Engine, + LLama2LoraEngine, + LLama2LoraInt8Engine, + LLama2LoraKbitEngine, +) +from xturing.models.causal import ( + CausalInt8Model, + CausalLoraInt8Model, + CausalLoraKbitModel, + CausalLoraModel, + CausalModel, +) class Llama2(CausalModel): @@ -9,3 +21,31 @@ class Llama2(CausalModel): def __init__(self, weights_path: Optional[str] = None): super().__init__(LLama2Engine.config_name, weights_path) + + +class Llama2Lora(CausalLoraModel): + config_name: str = "llama2_lora" + + def __init__(self, weights_path: Optional[str] = None): + super().__init__(LLama2LoraEngine.config_name, weights_path) + + +class Llama2Int8(CausalInt8Model): + config_name: str = "llama2_int8" + + def __init__(self, weights_path: Optional[str] = None): + super().__init__(LLama2Int8Engine.config_name, weights_path) + + +class Llama2LoraInt8(CausalLoraInt8Model): + config_name: str = "llama2_lora_int8" + + def __init__(self, weights_path: Optional[str] = None): + super().__init__(LLama2LoraInt8Engine.config_name, weights_path) + + +class Llama2LoraKbit(CausalLoraKbitModel): + config_name: str = "llama2_lora_kbit" + + def __init__(self, weights_path: Optional[str] = None): + super().__init__(LLama2LoraKbitEngine.config_name, weights_path) From 842ddd190a4f0387455b80f0799c3a05fc102097 Mon Sep 17 00:00:00 2001 From: Roman Ageev <112644287+StochasticRomanAgeev@users.noreply.github.com> Date: Wed, 26 Jul 2023 19:45:16 +0800 Subject: [PATCH 16/18] feat: update README.md Clarification on CPU's --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 638056a..f85fd8b 100644 --- a/README.md +++ b/README.md @@ -79,7 +79,7 @@ model = GenericKbitModel('tiiuae/falcon-7b') # Run the fine-tuning model.finetune(dataset) ``` -4. __CPU inference__ - Now you can use just your CPU for inference of any LLM. _CAUTION : The inference will be very slow as CPUs are extremely slow for the amount of computation needed for inference_. +4. __CPU inference__ - Now you can use just your CPU for inference of any LLM. 
_CAUTION : The inference process may be sluggish because CPUs lack the required computational capacity for efficient inference_. 5. __Batch integration__ - Now you play around with `batch_size` in `.generate()` and `.evaluate()` functions. This will lead to faster results with `batch_size>1`. ```python # Make the necessary imports From 8cba9d66a293e138bc067d90a27f0e15df259cce Mon Sep 17 00:00:00 2001 From: Roman Ageev <112644287+StochasticRomanAgeev@users.noreply.github.com> Date: Wed, 26 Jul 2023 19:47:12 +0800 Subject: [PATCH 17/18] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index f85fd8b..d61e38a 100644 --- a/README.md +++ b/README.md @@ -80,7 +80,7 @@ model = GenericKbitModel('tiiuae/falcon-7b') model.finetune(dataset) ``` 4. __CPU inference__ - Now you can use just your CPU for inference of any LLM. _CAUTION : The inference process may be sluggish because CPUs lack the required computational capacity for efficient inference_. -5. __Batch integration__ - Now you play around with `batch_size` in `.generate()` and `.evaluate()` functions. This will lead to faster results with `batch_size>1`. +5. __Batch integration__ - By tweaking the 'batch_size' in the .generate() and .evaluate() functions, you can expedite results. Using a 'batch_size' greater than 1 typically enhances processing efficiency. ```python # Make the necessary imports from xturing.datasets import InstructionDataset From 45871574e6f02a085018675f40791d78b12c055a Mon Sep 17 00:00:00 2001 From: Roman Ageev <112644287+StochasticRomanAgeev@users.noreply.github.com> Date: Wed, 26 Jul 2023 19:50:52 +0800 Subject: [PATCH 18/18] feat: update README.md Added refactored text --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index d61e38a..465a431 100644 --- a/README.md +++ b/README.md @@ -97,9 +97,9 @@ outputs = model.generate(dataset = dataset, batch_size=10) ``` -You can check the [Llama LoRA INT4 working example](examples/int4_finetuning/LLaMA_lora_int4.ipynb) file to see how it works. +An exploration of the [Llama LoRA INT4 working example](examples/int4_finetuning/LLaMA_lora_int4.ipynb) is recommended for an understanding of its application. -Also, you can check the [GenericModel working example](examples/generic/generic_model.py) repository to see how it works. +For an extended insight, consider examining the [GenericModel working example](examples/generic/generic_model.py) available in the repository.
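PATCH 15 introduces LoRA, INT8, and INT4 (k-bit) variants of the `Llama2` model but ships no usage example for them. The snippet below is a minimal sketch, not part of any patch above: it imports the variant class directly from `xturing.models.llama2` as defined in PATCH 15 (registration of these variants under `BaseModel` is not shown in this series) and assumes an Alpaca-format dataset at the hypothetical `../llama/alpaca_data` path used in the earlier examples.

```python
# Minimal sketch: fine-tuning and generating with the LLaMA 2 LoRA INT8 variant
# added in PATCH 15. The class is imported directly; whether it is also
# registered with BaseModel is not shown in this patch series.
from xturing.datasets import InstructionDataset
from xturing.models.llama2 import Llama2LoraInt8

# Hypothetical Alpaca-format dataset path, matching the earlier examples.
dataset = InstructionDataset("../llama/alpaca_data")

# Loads daryl149/llama-2-7b-chat-hf in 8-bit with LoRA adapters on q_proj/v_proj,
# per LLama2LoraInt8Engine in src/xturing/engines/llama2_engine.py.
model = Llama2LoraInt8()

model.finetune(dataset)
outputs = model.generate(texts=["How are you?"])
print(outputs)
```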