WIP: Support table logging for mlflow, too (#1506)
* WIP: Support table logging for mlflow, too

Create a `LogPredictionCallback` for each of "wandb" and "mlflow" if the
respective logger is enabled in the config.

In `log_prediction_callback_factory`, create a generic table and make it
backend-specific only at logging time, depending on whether the newly added
`logger` argument is set to "wandb" or "mlflow". A minimal sketch of this
pattern is shown below, after the changed-files summary.

See #1505

* chore: lint

* add additional clause for mlflow as it's optional

* Fix circular imports

---------

Co-authored-by: Dave Farago <[email protected]>
Co-authored-by: Wing Lian <[email protected]>
3 people authored Apr 9, 2024
1 parent 8fa0785 commit 057fa44
Showing 3 changed files with 51 additions and 24 deletions.
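The commit message above boils down to one pattern: the eval-prediction callback fills a backend-agnostic table (a dict mapping column names to lists of values), and only the final logging step branches on the `logger` string. The snippet below is a minimal, self-contained sketch of that idea for orientation before the per-file diffs; it is not the axolotl implementation, the helper name `log_eval_table` and the sample rows are invented, and logging assumes an active wandb or MLflow run.

# Minimal sketch (not the axolotl code) of "generic table, backend-specific logging".
from typing import Any, Dict, List

import pandas as pd


def log_eval_table(table_data: Dict[str, List[Any]], logger: str, name: str = "Eval") -> None:
    if logger == "wandb":
        import wandb

        # wandb.Table can wrap a pandas DataFrame directly.
        table = wandb.Table(dataframe=pd.DataFrame(table_data))
        wandb.log({f"{name} - Predictions vs Ground Truth": table})
    elif logger == "mlflow":
        import mlflow

        # mlflow.log_table (MLflow >= 2.3) accepts a dict of columns and stores
        # it as a JSON artifact on the active run.
        mlflow.log_table(data=table_data, artifact_file=f"{name}_predictions.json")


# The generic table is just columns -> equal-length lists of values.
example = {
    "id": [0, 1],
    "Prompt": ["2 + 2 =", "Capital of France?"],
    "Correct Completion": ["4", "Paris"],
}
# log_eval_table(example, logger="wandb")   # requires wandb.init(...) first
# log_eval_table(example, logger="mlflow")  # requires an active MLflow run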
src/axolotl/core/trainer_builder.py (16 changes: 11 additions & 5 deletions)
@@ -36,6 +36,7 @@
 from axolotl.loraplus import create_loraplus_optimizer
 from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
 from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
+from axolotl.utils import is_mlflow_available
 from axolotl.utils.callbacks import (
     EvalFirstStepCallback,
     GPUStatsCallback,
@@ -71,10 +72,6 @@
 LOG = logging.getLogger("axolotl.core.trainer_builder")


-def is_mlflow_available():
-    return importlib.util.find_spec("mlflow") is not None
-
-
 def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
     if isinstance(tag_names, str):
         tag_names = [tag_names]
@@ -943,7 +940,16 @@ def get_post_trainer_create_callbacks(self, trainer):
         callbacks = []
         if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
             LogPredictionCallback = log_prediction_callback_factory(
-                trainer, self.tokenizer
+                trainer, self.tokenizer, "wandb"
             )
             callbacks.append(LogPredictionCallback(self.cfg))
+        if (
+            self.cfg.use_mlflow
+            and is_mlflow_available()
+            and self.cfg.eval_table_size > 0
+        ):
+            LogPredictionCallback = log_prediction_callback_factory(
+                trainer, self.tokenizer, "mlflow"
+            )
+            callbacks.append(LogPredictionCallback(self.cfg))

src/axolotl/utils/__init__.py (8 changes: 8 additions & 0 deletions)
@@ -0,0 +1,8 @@
+"""
+Basic utils for Axolotl
+"""
+import importlib
+
+
+def is_mlflow_available():
+    return importlib.util.find_spec("mlflow") is not None
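For context on the helper just added: `importlib.util.find_spec` returns a module spec when the named package is importable and `None` when it is not, so `is_mlflow_available()` is a cheap availability probe that never actually imports mlflow. A quick illustration (the package names are arbitrary):

import importlib.util

print(importlib.util.find_spec("json") is not None)         # True: a stdlib module is importable
print(importlib.util.find_spec("no_such_package") is None)  # True: a missing top-level package yields None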
src/axolotl/utils/callbacks/__init__.py (51 changes: 32 additions & 19 deletions)
@@ -6,7 +6,7 @@
 import os
 from shutil import copyfile
 from tempfile import NamedTemporaryFile
-from typing import TYPE_CHECKING, Dict, List
+from typing import TYPE_CHECKING, Any, Dict, List

 import evaluate
 import numpy as np
@@ -27,7 +27,9 @@
 )
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy

+from axolotl.utils import is_mlflow_available
 from axolotl.utils.bench import log_gpu_memory_usage
+from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
 from axolotl.utils.distributed import (
     barrier,
     broadcast_dict,
@@ -540,7 +542,7 @@ def predict_with_generate():
     return CausalLMBenchEvalCallback


-def log_prediction_callback_factory(trainer: Trainer, tokenizer):
+def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
     class LogPredictionCallback(TrainerCallback):
         """Callback to log prediction values during each evaluation"""

@@ -597,15 +599,13 @@ def find_ranges(lst):
                 return ranges

             def log_table_from_dataloader(name: str, table_dataloader):
-                table = wandb.Table(  # type: ignore[attr-defined]
-                    columns=[
-                        "id",
-                        "Prompt",
-                        "Correct Completion",
-                        "Predicted Completion (model.generate)",
-                        "Predicted Completion (trainer.prediction_step)",
-                    ]
-                )
+                table_data: Dict[str, List[Any]] = {
+                    "id": [],
+                    "Prompt": [],
+                    "Correct Completion": [],
+                    "Predicted Completion (model.generate)": [],
+                    "Predicted Completion (trainer.prediction_step)": [],
+                }
                 row_index = 0

                 for batch in tqdm(table_dataloader):
@@ -709,16 +709,29 @@ def log_table_from_dataloader(name: str, table_dataloader):
                     ) in zip(
                         prompt_texts, completion_texts, predicted_texts, pred_step_texts
                     ):
-                        table.add_data(
-                            row_index,
-                            prompt_text,
-                            completion_text,
-                            prediction_text,
-                            pred_step_text,
-                        )
+                        table_data["id"].append(row_index)
+                        table_data["Prompt"].append(prompt_text)
+                        table_data["Correct Completion"].append(completion_text)
+                        table_data["Predicted Completion (model.generate)"].append(
+                            prediction_text
+                        )
+                        table_data[
+                            "Predicted Completion (trainer.prediction_step)"
+                        ].append(pred_step_text)
                         row_index += 1

-                wandb.run.log({f"{name} - Predictions vs Ground Truth": table})  # type: ignore[attr-defined]
+                if logger == "wandb":
+                    wandb.run.log({f"{name} - Predictions vs Ground Truth": pd.DataFrame(table_data)})  # type: ignore[attr-defined]
+                elif logger == "mlflow" and is_mlflow_available():
+                    import mlflow
+
+                    tracking_uri = AxolotlInputConfig(
+                        **self.cfg.to_dict()
+                    ).mlflow_tracking_uri
+                    mlflow.log_table(
+                        data=table_data,
+                        artifact_file="PredictionsVsGroundTruth.json",
+                        tracking_uri=tracking_uri,
+                    )

             if is_main_process():
                 log_table_from_dataloader("Eval", eval_dataloader)
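A brief usage note on the MLflow side (general MLflow behavior, not something this diff adds): `mlflow.log_table` stores the table as a JSON artifact on the active run, and recent MLflow versions can read it back as a pandas DataFrame with `mlflow.load_table`. The tracking URI and run id below are placeholders.

import mlflow

mlflow.set_tracking_uri("http://localhost:5000")  # placeholder tracking server

# Read the logged evaluation table back as a DataFrame.
df = mlflow.load_table(
    artifact_file="PredictionsVsGroundTruth.json",
    run_ids=["<run_id_of_the_eval_run>"],  # placeholder run id
)
print(df.head())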