Add pre-commit hook to stop linting errors from being pushed. (#2632)
Summary:
## Motivation

I am so used to relying on `pre-commit` that I have forgotten to run manual linting when making PRs to `botorch`. This adds a simple `pre-commit` config to automate linting.

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes

Pull Request resolved: #2632

Test Plan:
This PR should not affect any existing behavior. There are tools for running pre-commit in CI that the maintainers could adopt: https://pre-commit.ci/

## Related Issues:
omnilib/ufmt#251: the pre-commit hook only seems to work with v2.3.0 of ufmt. This issue tracks the error stemming from ruff_api.

Reviewed By: saitcakmak

Differential Revision: D66169817

Pulled By: Balandat

fbshipit-source-id: a1034d8a882749ee7ff08d4cc92188957073b3c8
CompRhys authored and facebook-github-bot committed Nov 22, 2024
1 parent 5d37606 commit 3f2e2c7
Showing 13 changed files with 165 additions and 48 deletions.
45 changes: 12 additions & 33 deletions .github/workflows/lint.yml
@@ -10,39 +10,18 @@ on:

 jobs:

-  ufmt:
-    name: Code formatting and sorting with ufmt
+  lint:
     runs-on: ubuntu-latest
-    strategy:
-      fail-fast: false
     steps:
-      - uses: actions/checkout@v4
-      - name: Set up Python
-        uses: actions/setup-python@v5
-        with:
-          python-version: "3.10"
-      - name: Install dependencies
-        run: |
-          # pin dependencies to match Meta-internal versions
-          pip install -r requirements-fmt.txt
-      - name: ufmt
-        run: |
-          ufmt diff .
+      - uses: actions/checkout@v4

-  flake8:
-    name: Lint with flake8
-    runs-on: ubuntu-latest
-    strategy:
-      fail-fast: false
-    steps:
-      - uses: actions/checkout@v4
-      - name: Set up Python
-        uses: actions/setup-python@v5
-        with:
-          python-version: "3.10"
-      - name: Install dependencies
-        run: |
-          pip install flake8 flake8-docstrings
-      - name: Flake8
-        run: |
-          flake8
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.10"
+
+      - name: Install dependencies
+        run: pip install pre-commit
+
+      - name: Run pre-commit
+        run: pre-commit run --all-files --show-diff-on-failure
29 changes: 29 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,29 @@
repos:
  - repo: local
    hooks:
      - id: check-requirements-versions
        name: Check pre-commit formatting versions
        entry: python scripts/check_pre_commit_reqs.py
        language: python
        always_run: true
        pass_filenames: false
        additional_dependencies:
          - PyYAML

  - repo: https://github.com/omnilib/ufmt
    rev: v2.8.0
    hooks:
      - id: ufmt
        additional_dependencies:
          - black==24.4.2
          - usort==1.0.8.post1
          - ruff-api==0.1.0
          - stdlibs==2024.1.28
        args: [format]

  - repo: https://github.com/pycqa/flake8
    rev: 7.0.0
    hooks:
      - id: flake8
        additional_dependencies:
          - flake8-docstrings
7 changes: 7 additions & 0 deletions CONTRIBUTING.md
@@ -39,6 +39,13 @@ flake8 .

 from the repository root.

+#### Pre-commit hooks
+
+Contributors can use [pre-commit](https://pre-commit.com/) to run `ufmt` and
+`flake8` as part of the commit process. To install the hooks, install `pre-commit`
+via `pip install pre-commit` and run `pre-commit install` from the repository
+root.
+
 #### Docstring formatting

 BoTorch uses
2 changes: 1 addition & 1 deletion botorch/acquisition/analytic.py
@@ -982,7 +982,7 @@ def _log_ei_helper(u: Tensor) -> Tensor:
     if not (u.dtype == torch.float32 or u.dtype == torch.float64):
         raise TypeError(
             f"LogExpectedImprovement only supports torch.float32 and torch.float64 "
-            f"dtypes, but received {u.dtype = }."
+            f"dtypes, but received {u.dtype=}."
         )
     # The function has two branching decisions. The first is u < bound, and in this
     # case, just taking the logarithm of the naive _ei_helper implementation works.
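Aside on the f-string change above and the similar ones below: Python's `=` specifier in f-strings echoes the expression source text verbatim, including any whitespace around the `=`, so dropping the spaces tightens the rendered message. A minimal illustration (not part of the diff):

    import torch

    u = torch.tensor(1.0, dtype=torch.float16)
    print(f"{u.dtype = }")  # prints: u.dtype = torch.float16
    print(f"{u.dtype=}")    # prints: u.dtype=torch.float16

This is also why the expected error strings in the test diffs further down change from `x.dtype = torch.float16` to `x.dtype=torch.float16`.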
4 changes: 2 additions & 2 deletions botorch/acquisition/logei.py
@@ -540,7 +540,7 @@ def check_tau(tau: FloatOrTensor, name: str) -> FloatOrTensor:
     """Checks the validity of the tau arguments of the functions below, and returns
     `tau` if it is valid."""
     if isinstance(tau, Tensor) and tau.numel() != 1:
-        raise ValueError(name + f" is not a scalar: {tau.numel() = }.")
+        raise ValueError(f"{name} is not a scalar: {tau.numel()=}.")
     if not (tau > 0):
-        raise ValueError(name + f" is non-positive: {tau = }.")
+        raise ValueError(f"{name} is non-positive: {tau=}.")
     return tau
@@ -557,7 +557,7 @@ def _split_hvkg_fantasy_points(
     """
     if n_f * num_pareto > X.size(-2):
         raise ValueError(
-            f"`n_f*num_pareto` ({n_f*num_pareto}) must be less than"
+            f"`n_f*num_pareto` ({n_f * num_pareto}) must be less than"
             f" the `q`-batch dimension of `X` ({X.size(-2)})."
         )
     split_sizes = [X.size(-2) - n_f * num_pareto, n_f * num_pareto]
4 changes: 2 additions & 2 deletions botorch/optim/optimize.py
@@ -248,7 +248,7 @@ def _optimize_acqf_sequential_q(
             if base_X_pending is not None
             else candidates
         )
-        logger.info(f"Generated sequential candidate {i+1} of {opt_inputs.q}")
+        logger.info(f"Generated sequential candidate {i + 1} of {opt_inputs.q}")
     opt_inputs.acq_function.set_X_pending(base_X_pending)
     return candidates, torch.stack(acq_value_list)

@@ -325,7 +325,7 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
         opt_warnings += ws
         batch_candidates_list.append(batch_candidates_curr)
         batch_acq_values_list.append(batch_acq_values_curr)
-        logger.info(f"Generated candidate batch {i+1} of {len(batched_ics)}.")
+        logger.info(f"Generated candidate batch {i + 1} of {len(batched_ics)}.")

     batch_candidates = torch.cat(batch_candidates_list)
     has_scalars = batch_acq_values_list[0].ndim == 0
8 changes: 5 additions & 3 deletions botorch/posteriors/posterior.py
@@ -10,7 +10,7 @@

 from __future__ import annotations

-from abc import ABC, abstractmethod, abstractproperty
+from abc import ABC, abstractmethod

 import torch
 from torch import Tensor

@@ -77,12 +77,14 @@ def sample(self, sample_shape: torch.Size | None = None) -> Tensor:
         with torch.no_grad():
             return self.rsample(sample_shape=sample_shape)

-    @abstractproperty
+    @property
+    @abstractmethod
     def device(self) -> torch.device:
         r"""The torch device of the distribution."""
         pass  # pragma: no cover

-    @abstractproperty
+    @property
+    @abstractmethod
     def dtype(self) -> torch.dtype:
         r"""The torch dtype of the distribution."""
         pass  # pragma: no cover
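Context for the decorator swap above: `abc.abstractproperty` has been deprecated since Python 3.3, and stacking `@property` on top of `@abstractmethod` is the documented replacement with the same behavior. A minimal sketch (illustrative class names, not from the diff):

    from abc import ABC, abstractmethod

    class Base(ABC):
        @property
        @abstractmethod
        def dtype(self) -> str:
            """The dtype of the object."""

    class Concrete(Base):
        @property
        def dtype(self) -> str:
            return "torch.float64"

    print(Concrete().dtype)  # torch.float64
    try:
        Base()  # still abstract, cannot be instantiated
    except TypeError as err:
        print(err)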
4 changes: 2 additions & 2 deletions botorch/utils/probability/utils.py
@@ -161,7 +161,7 @@ def log_ndtr(x: Tensor) -> Tensor:
     if not (x.dtype == torch.float32 or x.dtype == torch.float64):
         raise TypeError(
             f"log_Phi only supports torch.float32 and torch.float64 "
-            f"dtypes, but received {x.dtype = }."
+            f"dtypes, but received {x.dtype=}."
         )
     neg_inv_sqrt_2, log_2 = get_constants_like((_neg_inv_sqrt_2, _log_2), x)
     return log_erfc(neg_inv_sqrt_2 * x) - log_2

@@ -181,7 +181,7 @@ def log_erfc(x: Tensor) -> Tensor:
     if not (x.dtype == torch.float32 or x.dtype == torch.float64):
         raise TypeError(
             f"log_erfc only supports torch.float32 and torch.float64 "
-            f"dtypes, but received {x.dtype = }."
+            f"dtypes, but received {x.dtype=}."
         )
     is_pos = x > 0
     x_pos = x.masked_fill(~is_pos, 0)
5 changes: 3 additions & 2 deletions botorch/utils/testing.py
@@ -8,7 +8,7 @@

 import math
 import warnings
-from abc import abstractproperty
+from abc import abstractmethod
 from collections import OrderedDict
 from collections.abc import Sequence
 from itertools import product

@@ -138,7 +138,8 @@ def test_forward_and_evaluate_true(self):
         )
         self.assertEqual(res.shape, batch_shape + tail_shape)

-    @abstractproperty
+    @property
+    @abstractmethod
     def functions(self) -> Sequence[BaseTestProblem]:
         # The functions that should be tested. Typically defined as a class
         # attribute on the test case subclassing this class.
99 changes: 99 additions & 0 deletions scripts/check_pre_commit_reqs.py
@@ -0,0 +1,99 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import sys
from pathlib import Path

import yaml


def parse_requirements(filepath):
    """Parse requirements file and return a dict of package versions."""
    versions = {}
    with open(filepath) as f:
        for line in f:
            line = line.strip()
            if line and not line.startswith("#"):
                # Handle different requirement formats
                if "==" in line:
                    pkg, version = line.split("==")
                    versions[pkg.strip().lower()] = version.strip()
    return versions


def parse_precommit_config(filepath):
    """Parse pre-commit config and extract ufmt repo rev and hook dependencies."""
    with open(filepath) as f:
        config = yaml.safe_load(f)

    versions = {}
    for repo in config["repos"]:
        if "https://github.com/omnilib/ufmt" in repo.get("repo", ""):
            # Get ufmt version from rev - assumes fixed format: vX.Y.Z
            versions["ufmt"] = repo.get("rev", "").replace("v", "")

            # Get dependency versions
            for hook in repo["hooks"]:
                if hook["id"] == "ufmt":
                    for dep in hook.get("additional_dependencies", []):
                        if "==" in dep:
                            pkg, version = dep.split("==")
                            versions[pkg.strip().lower()] = version.strip()
                    break
    return versions


def main():
    # Find the pre-commit config and requirements files
    config_file = Path(".pre-commit-config.yaml")
    requirements_file = Path("requirements-fmt.txt")

    if not config_file.exists():
        print(f"Error: Could not find {config_file}")
        sys.exit(1)

    if not requirements_file.exists():
        print(f"Error: Could not find {requirements_file}")
        sys.exit(1)

    # Parse both files
    req_versions = parse_requirements(requirements_file)
    config_versions = parse_precommit_config(config_file)

    # Check versions
    mismatches = []
    for pkg, req_ver in req_versions.items():
        config_ver = config_versions.get(pkg, None)

        if req_ver != config_ver:
            found_version_str = f"{pkg}: {requirements_file} has {req_ver},"
            if pkg == "ufmt":
                mismatches.append(
                    f"{found_version_str} pre-commit config rev has v{config_ver}"
                )
            else:
                mismatches.append(
                    f"{found_version_str} pre-commit config has {config_ver}"
                )

    # Report results
    if mismatches:
        msg_str = "".join("\n\t" + msg for msg in mismatches)
        print(
            f"Version mismatches found:{msg_str}"
            "\nPlease update the versions in `.pre-commit-config.yaml` to be "
            "consistent with those in `requirements-fmt.txt` (source of truth)."
            "\nNote: all versions must be pinned exactly ('==X.Y.Z') in both files."
        )
        sys.exit(1)
    else:
        print("All versions match!")
        sys.exit(0)


if __name__ == "__main__":
    main()
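To make the contract of this check concrete, a hypothetical example (the dict literals below are made-up stand-ins for what `parse_requirements` and `parse_precommit_config` return):

    req_versions = {"black": "24.4.2", "usort": "1.0.8.post1", "ufmt": "2.8.0"}
    config_versions = {"black": "24.4.2", "usort": "1.0.8.post1", "ufmt": "2.8.0"}

    # Every pin in requirements-fmt.txt must appear with an identical version
    # in .pre-commit-config.yaml; any mismatch is printed and the hook exits 1.
    assert all(
        config_versions.get(pkg) == ver for pkg, ver in req_versions.items()
    )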
@@ -436,7 +436,7 @@ def test_split_hvkg_fantasy_points(self):
         n_f = 100
         num_pareto = 3
         msg = (
-            rf".*\({n_f*num_pareto}\) must be less than"
+            rf".*\({n_f * num_pareto}\) must be less than"
             rf" the `q`-batch dimension of `X` \({X.size(-2)}\)\."
         )
         with self.assertRaisesRegex(ValueError, msg):
2 changes: 1 addition & 1 deletion test/utils/probability/test_utils.py
@@ -320,7 +320,7 @@ def test_gaussian_probabilities(self) -> None:

         float16_msg = (
             "only supports torch.float32 and torch.float64 dtypes, but received "
-            "x.dtype = torch.float16."
+            "x.dtype=torch.float16."
         )
         with self.assertRaisesRegex(TypeError, expected_regex=float16_msg):
             log_erfc(torch.tensor(1.0, dtype=torch.float16, device=self.device))
