[refactor] warn vllm returning logprobs with prefix caching
huyiwen committed Jun 21, 2024
1 parent f323353 commit 2ed7e1c
Showing 5 changed files with 32 additions and 44 deletions.
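At a glance, this commit turns `prefix_caching` into a tri-state option: `None` means "use the backend default", while `True`/`False` record an explicit choice, and the vllm backend now warns and switches prefix caching off whenever prompt logprobs are requested (vllm issue #3914). Below is a minimal sketch of the resulting resolution logic; the helper name is illustrative and not part of the repository:

```python
from typing import Optional

def resolve_prefix_caching(prefix_caching: Optional[bool], model_backend: str) -> bool:
    """Illustrative only: how the tri-state flag behaves after this commit."""
    if prefix_caching is not None:
        # An explicit --prefix_caching / --no_prefix_caching always wins.
        return prefix_caching
    # None defers to the backend default: huggingface enables prefix caching
    # (see dataset.py below), while vllm keeps the engine default of False
    # (see the updated tests).
    return model_backend == "huggingface"
```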
6 changes: 3 additions & 3 deletions tests/utilization/utils/test_parse_arguments.py
@@ -9,11 +9,11 @@
def test_default_vllm():
    model_args, dataset_args, evaluation_args = parse_argument(['-m', 'a-random-fake-model', '-d', 'nq', 'quac'])
    assert model_args.model_backend == "vllm"
    assert model_args.prefix_caching is False
    assert model_args.prefix_caching is None  # vllm default is False


def test_no_prefix_caching():
    # currently vllm doesn't support returning logprob for prefix caching
    # batch size is 1, so prefix caching is not used
    model_args, dataset_args, evaluation_args = parse_argument([
        '-m', 'a-random-fake-model', '-d', 'nq', 'mmlu', '-b', '1'
    ])
@@ -27,7 +27,7 @@ def test_default_prefix_caching():
        '-m', 'a-random-fake-model', '-d', 'nq', 'mmlu', '-b', '16'
    ])
    assert model_args.model_backend == "huggingface"
    assert model_args.prefix_caching is True
    assert model_args.prefix_caching is None  # huggingface default is True


def test_default_no_efficient():
2 changes: 2 additions & 0 deletions utilization/dataset/dataset.py
@@ -148,6 +148,8 @@ def __init__(
        self.ranking_type = args.ranking_type
        self.model_type = model.model_type
        self.prefix_caching = model.args.prefix_caching
        if self.prefix_caching is None:
            self.prefix_caching = True
        self.instance_format = "{source}{target}"
        if args.instruction:
            self.instruction = args.instruction
6 changes: 0 additions & 6 deletions utilization/load_dataset.py
@@ -369,12 +369,6 @@ def load_datasets(
    args.auto_batch_size = False
    logger.info("Setting batch_size to -1, since vllm can automatically plan the optimal batch and order.")

    if model.args.prefix_caching and model.model_backend != "huggingface":
        logger.warning(
            "Prefix caching is only available for HuggingFaceModel. Automatically set prefix_caching to False"
        )
        model.args.prefix_caching = False

    # get all the dataset classes
    datasets = []
    for d in args.dataset_names:
32 changes: 21 additions & 11 deletions utilization/model/vllm_model.py
@@ -48,15 +48,8 @@ def __init__(self, args: "ModelArguments", **kwargs):
        self.args = args

        logger.info(f"Trying to load {args.model_name_or_path} using vllm...")
        self.vllm_version = version.parse(vllm.__version__)
        if args.prefix_caching:
            if self.is_legacy_vllm():
                logger.warning(
                    f"vllm version ({vllm.__version__}) is lower than 0.4.0, prefix_caching is not supported."
                )
            else:
                kwargs["enable_prefix_caching"] = True
                self.use_cache = True
        if args.prefix_caching is not None:
            kwargs["enable_prefix_caching"] = args.prefix_caching

        self.model = LLM(
            model=args.model_name_or_path,
@@ -77,10 +70,21 @@ def __init__(self, args: "ModelArguments", **kwargs):
        )
        self.tokenizer.chat_template = args.chat_template

    def is_legacy_vllm(self):
        return self.vllm_version < version.parse("0.4.0")
    @property
    def use_cache(self):
        return self.model.llm_engine.cache_config.enable_prefix_caching

    @use_cache.setter
    def use_cache(self, value):
        self.model.llm_engine.cache_config.enable_prefix_caching = value

    def set_ppl_args(self, **extra_model_args):
        if self.use_cache:
            logger.warning(
                "Prefix caching is enabled for vllm, but returning logprobs with prefix caching enabled is a known vllm issue, so prefix caching will be disabled. See https://github.com/vllm-project/vllm/issues/3914 for details."
            )
            self.use_cache = False

        self.ppl_kwargs = SamplingParams(max_tokens=1, prompt_logprobs=0)
        if len(extra_model_args) > 0:
            logger.warning(f"Unused generation arguments: {extra_model_args}")
@@ -144,6 +148,12 @@ def generation(self, batched_inputs: List[Conversation]) -> List[str]:
        return [c.get_generation_results() for c in batched_inputs]

    def set_prob_args(self, **extra_model_args):
        if self.use_cache:
            logger.warning(
                "Prefix caching is enabled for vllm, but returning logprobs with prefix caching enabled is a known vllm issue, so prefix caching will be disabled. See https://github.com/vllm-project/vllm/issues/3914 for details."
            )
            self.use_cache = False

        self.prob_kwargs = SamplingParams(max_tokens=1, temperature=0)
        self.candidate_ids = extra_model_args.pop("candidate_ids", None)

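For context, the new `use_cache` property simply proxies vllm's `llm_engine.cache_config.enable_prefix_caching`, which is what lets `set_ppl_args` and `set_prob_args` switch prefix caching off right before logprob scoring. A rough, self-contained sketch of that pattern follows; the function name and logger setup are illustrative, not from the repository:

```python
import logging

logger = logging.getLogger(__name__)

def disable_prefix_caching_for_logprobs(model) -> None:
    """Mirror of what set_ppl_args/set_prob_args now do for a VllmModel-like object."""
    if model.use_cache:  # reads llm_engine.cache_config.enable_prefix_caching
        logger.warning(
            "Disabling vllm prefix caching before logprob scoring "
            "(https://github.com/vllm-project/vllm/issues/3914)."
        )
        model.use_cache = False  # the setter writes the flag back to cache_config
```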
30 changes: 6 additions & 24 deletions utilization/utils/arguments.py
@@ -82,7 +82,7 @@ class ModelArguments(ModelBackendMixin):
default="auto",
help="The device map for model and data",
)
prefix_caching: bool = HfArg(
prefix_caching: Optional[bool] = HfArg(
default=None,
help="Whether to cache prefix in get_ppl mode",
)
@@ -369,13 +369,6 @@ def __post_init__(self):

        if self.is_vllm_model():
            self.vllm_gpu_memory_utilization = 0.9
            if self.prefix_caching is None:
                # prefix_caching is still experimental
                self.prefix_caching = False

        elif self.is_huggingface_model():
            if self.prefix_caching is None:
                self.prefix_caching = True

        # argparse encodes string with unicode_escape, decode it to normal string, e.g., "\\n" -> "\n"
        if self.stop is not None:
@@ -626,16 +619,13 @@ def check_args(model_args: ModelArguments, dataset_args: DatasetArguments, evalu
        d not in DEFAULT_VLLM_DATASETS for d in dataset_args.dataset_names
    ):
        model_args.model_backend = "huggingface"
        if not model_args.passed_in_commandline("prefix_caching"):
            model_args.prefix_caching = True

    model_args.seed = int(evaluation_args.seed)

    if dataset_args.batch_size == 1 and model_args.prefix_caching:
        if model_args.is_local_model():
            logger.warning(
                "Prefix caching is not supported for batch_size=1, automatically set prefix_caching to False."
            )
    if dataset_args.batch_size == 1 and model_args.prefix_caching is None and model_args.is_huggingface_model():
        logger.warning(
            "Prefix caching is not supported for batch_size=1, automatically set prefix_caching to False."
        )
        model_args.prefix_caching = False

# check models
@@ -646,14 +636,6 @@ def check_args(model_args: ModelArguments, dataset_args: DatasetArguments, evalu
f"chat/completions endpoint model {model_args.model_name_or_path} doesn't support batch_size > 1, automatically set batch_size to 1."
)

# vllm has its own prefix caching mechanism
if model_args.prefix_caching and "expandable_segments" not in os.environ.get(
"PYTORCH_CUDA_ALLOC_CONF", ""
) and model_args.is_huggingface_model():
logger.warning(
f"Prefix caching might results in cuda memory fragmentation, which can be mitigated by setting `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`. See https://pytorch.org/docs/stable/notes/cuda.html#environment-variables for details."
)

# check dataset
if "vicuna_bench" in dataset_args.dataset_names and model_args.openai_api_key is None:
raise ValueError(
@@ -675,7 +657,7 @@ def check_args(model_args: ModelArguments, dataset_args: DatasetArguments, evalu
"Instruction does not include any variable, so the input remains unchanged across the insatnces. Try to use f-string or jinja2 format to include variables like `{source}` or `{problem}`. See dataset documentation for details."
)

if evaluation_args.dry_run and model_args.prefix_caching:
if evaluation_args.dry_run:
model_args.prefix_caching = False

args_ignored = set()
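Declaring the field as `Optional[bool]` with a `None` default is what removes the need for the old `passed_in_commandline("prefix_caching")` bookkeeping: "not passed" and "explicitly disabled" are now distinct values. A small plain-argparse sketch of the same idea (this is not the project's `HfArg` machinery):

```python
import argparse

parser = argparse.ArgumentParser()
# None (the default) means "let the backend decide"; the two flags record an
# explicit user decision as True or False.
parser.add_argument("--prefix_caching", dest="prefix_caching",
                    action="store_true", default=None)
parser.add_argument("--no_prefix_caching", dest="prefix_caching",
                    action="store_false", default=None)

assert parser.parse_args([]).prefix_caching is None
assert parser.parse_args(["--prefix_caching"]).prefix_caching is True
assert parser.parse_args(["--no_prefix_caching"]).prefix_caching is False
```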
