From 7a564de2005b6853992bf827ebe32f2e648d8648 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Thu, 18 Apr 2024 00:38:18 -0500 Subject: [PATCH 1/7] =?UTF-8?q?=E2=9C=A8=20Implement=20add=5Fannotations?= =?UTF-8?q?=20codemod?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bump_pydantic/codemods/add_annotations.py | 156 ++++++++++++++++++++++ 1 file changed, 156 insertions(+) create mode 100644 bump_pydantic/codemods/add_annotations.py diff --git a/bump_pydantic/codemods/add_annotations.py b/bump_pydantic/codemods/add_annotations.py new file mode 100644 index 0000000..634f24b --- /dev/null +++ b/bump_pydantic/codemods/add_annotations.py @@ -0,0 +1,156 @@ +from __future__ import annotations + +import libcst as cst +import libcst.matchers as m +from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand +from libcst.metadata import FullyQualifiedNameProvider, QualifiedName + +from bump_pydantic.codemods.class_def_visitor import ClassDefVisitor + +COMMENT = "# TODO[pydantic]: add type annotation" + + +class AddAnnotationsCommand(VisitorBasedCodemodCommand): + """This codemod adds a type annotation or TODO comment to pydantic fields without + a type annotation. + + Example:: + # Before + ```py + from pydantic import BaseModel, Field + + class Foo(BaseModel): + name: str + is_sale = True + tags = ["tag1", "tag2"] + price = 10.5 + description = "Some item" + active = Field(default=True) + ready = Field(True) + age = Field(10, title="Age") + ``` + + # After + ```py + from pydantic import BaseModel, Field + + class Foo(BaseModel): + name: str + is_sale: bool = True + # TODO[pydantic]: add type annotation + tags = ["tag1", "tag2"] + price: float = 10.5 + description: str = "Some item" + active: bool = Field(default=True) + ready: bool = Field(True) + age: int = Field(10, title="Age") + ``` + """ + + METADATA_DEPENDENCIES = (FullyQualifiedNameProvider,) + + def __init__(self, context: CodemodContext) -> None: + super().__init__(context) + + self.inside_base_model = False + self.base_model_fields: set[cst.Assign | cst.AnnAssign | cst.SimpleStatementLine] = set() + self.statement: cst.SimpleStatementLine | None = None + self.needs_comment = False + self.has_comment = False + self.in_field = False + + def visit_ClassDef(self, node: cst.ClassDef) -> None: + fqn_set = self.get_metadata(FullyQualifiedNameProvider, node) + + if not fqn_set: + return None + + fqn: QualifiedName = next(iter(fqn_set)) # type: ignore + if fqn.name in self.context.scratch[ClassDefVisitor.BASE_MODEL_CONTEXT_KEY]: + self.inside_base_model = True + self.base_model_fields = { + child for child in node.body.children if isinstance(child, cst.SimpleStatementLine) + } + return + + def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: + self.base_model_fields = set() + return updated_node + + def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None: + if node not in self.base_model_fields: + return + if not self.inside_base_model: + return + self.statement = node + self.in_field = True + for line in node.leading_lines: + if m.matches(line, m.EmptyLine(comment=m.Comment(value=COMMENT))): + self.has_comment = True + + def leave_SimpleStatementLine( + self, original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine + ) -> cst.SimpleStatementLine: + if original_node not in self.base_model_fields: + return updated_node + if self.needs_comment and not self.has_comment: + updated_node = updated_node.with_changes( + leading_lines=[ + *updated_node.leading_lines, + cst.EmptyLine(comment=cst.Comment(value=(COMMENT))), + ], + body=[ + *updated_node.body, + ], + ) + self.statement = None + self.needs_comment = False + self.has_comment = False + self.in_field = False + return updated_node + + def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign) -> cst.Assign | cst.AnnAssign: + if not self.in_field: + return updated_node + if self.inside_base_model: + Undefined = object() + value: cst.BaseExpression | object = Undefined + if m.matches(updated_node.value, m.Call(func=m.Name("Field"))): + assert isinstance(updated_node.value, cst.Call) + args = updated_node.value.args + if args: + default_keywords = [arg.value for arg in args if arg.keyword and arg.keyword.value == "default"] + # NOTE: It has a "default" value as positional argument. + if args[0].keyword is None: + value = args[0].value + # NOTE: It has a "default" keyword argument. + elif default_keywords: + value = default_keywords[0] + else: + value = updated_node.value + if value is Undefined: + self.needs_comment = True + return updated_node + + # Infer simple type annotations + ann_type = None + assert isinstance(value, cst.BaseExpression) + if m.matches(value, m.Name("True") | m.Name("False")): + ann_type = "bool" + elif m.matches(value, m.SimpleString()): + ann_type = "str" + elif m.matches(value, m.Integer()): + ann_type = "int" + elif m.matches(value, m.Float()): + ann_type = "float" + + # If there's a simple inferred type annotation, return that + if ann_type: + return cst.AnnAssign( + target=updated_node.targets[0].target, + annotation=cst.Annotation(cst.Name(ann_type)), + value=updated_node.value, + ) + else: + self.needs_comment = True + return updated_node From ff83f7821664642c2dafdc174e0b58178719d634 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Thu, 18 Apr 2024 00:38:57 -0500 Subject: [PATCH 2/7] =?UTF-8?q?=E2=9C=A8=20Add=20AddAnnotations=20to=20CLI?= =?UTF-8?q?=20default=20codemods=20and=20new=20rule=20BP010?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bump_pydantic/codemods/__init__.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/bump_pydantic/codemods/__init__.py b/bump_pydantic/codemods/__init__.py index 262f8c3..c56e02d 100644 --- a/bump_pydantic/codemods/__init__.py +++ b/bump_pydantic/codemods/__init__.py @@ -4,6 +4,7 @@ from libcst.codemod import ContextAwareTransformer from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor +from bump_pydantic.codemods.add_annotations import AddAnnotationsCommand from bump_pydantic.codemods.add_default_none import AddDefaultNoneCommand from bump_pydantic.codemods.con_func import ConFuncCallCommand from bump_pydantic.codemods.custom_types import CustomTypeCodemod @@ -34,6 +35,8 @@ class Rule(str, Enum): """Replace `con*` functions by `Annotated` versions.""" BP009 = "BP009" """Mark Pydantic "protocol" functions in custom types with proper TODOs.""" + BP010 = "BP010" + """Add type annotations or TODOs to fields without them.""" def gather_codemods(disabled: List[Rule]) -> List[Type[ContextAwareTransformer]]: @@ -67,6 +70,9 @@ def gather_codemods(disabled: List[Rule]) -> List[Type[ContextAwareTransformer]] if Rule.BP009 not in disabled: codemods.append(CustomTypeCodemod) + if Rule.BP010 not in disabled: + codemods.append(AddAnnotationsCommand) + # Those codemods need to be the last ones. codemods.extend([RemoveImportsVisitor, AddImportsVisitor]) return codemods From c1f0ed67ab1f513d63e4e963f8174b24a88b8035 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Thu, 18 Apr 2024 00:40:06 -0500 Subject: [PATCH 3/7] =?UTF-8?q?=E2=9C=85=20Add=20test=20for=20AddAnnotatio?= =?UTF-8?q?nsCommand?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/unit/test_add_annotations.py | 169 +++++++++++++++++++++++++++++ 1 file changed, 169 insertions(+) create mode 100644 tests/unit/test_add_annotations.py diff --git a/tests/unit/test_add_annotations.py b/tests/unit/test_add_annotations.py new file mode 100644 index 0000000..f82df7a --- /dev/null +++ b/tests/unit/test_add_annotations.py @@ -0,0 +1,169 @@ +import textwrap +from pathlib import Path + +import libcst as cst +from libcst import MetadataWrapper, parse_module +from libcst.codemod import CodemodContext, CodemodTest +from libcst.metadata import FullyQualifiedNameProvider +from libcst.testing.utils import UnitTest + +from bump_pydantic.codemods.add_annotations import AddAnnotationsCommand +from bump_pydantic.codemods.class_def_visitor import ClassDefVisitor + + +class TestAddAnnotationsCommand(UnitTest): + def add_annotations(self, file_path: str, code: str) -> cst.Module: + mod = MetadataWrapper( + parse_module(CodemodTest.make_fixture_data(code)), + cache={ + FullyQualifiedNameProvider: FullyQualifiedNameProvider.gen_cache(Path(""), [file_path], None).get( + file_path, "" + ) + }, + ) + mod.resolve_many(AddAnnotationsCommand.METADATA_DEPENDENCIES) + context = CodemodContext(wrapper=mod) + instance = ClassDefVisitor(context=context) + mod.visit(instance) + + instance = AddAnnotationsCommand(context=context) # type: ignore[assignment] + return mod.visit(instance) + + def test_not_a_model(self) -> None: + source = textwrap.dedent( + """ + class Potato: + a = True + """ + ).lstrip() + module = self.add_annotations("some/test/module.py", source) + assert module.code == source + + def test_has_annotation(self) -> None: + source = textwrap.dedent( + """ + from pydantic import BaseModel + + class Potato(BaseModel): + a: bool = True + """ + ).lstrip() + module = self.add_annotations( + "some/test/module.py", + source, + ) + + assert module.code == source + + def test_add_annotations(self) -> None: + source = textwrap.dedent( + """ + from pydantic import BaseModel, Field + + class Potato(BaseModel): + name: str + is_sale = True + tags = ["tag1", "tag2"] + price = 10.5 + description = "Some item" + active = Field(default=True) + ready = Field(True) + age = Field(10, title="Age") + + def do_stuff(self): + something = [1, 2, 3] + return something + """ + ).lstrip() + module = self.add_annotations( + "some/test/module.py", + source, + ) + expected = textwrap.dedent( + """ + from pydantic import BaseModel, Field + + class Potato(BaseModel): + name: str + is_sale: bool = True + # TODO[pydantic]: add type annotation + tags = ["tag1", "tag2"] + price: float = 10.5 + description: str = "Some item" + active: bool = Field(default=True) + ready: bool = Field(True) + age: int = Field(10, title="Age") + + def do_stuff(self): + something = [1, 2, 3] + return something + """ + ).lstrip() + assert module.code == expected + + def test_with_multiple_classes(self) -> None: + source = textwrap.dedent( + """ + from pydantic import BaseModel, Field + + class Foo(BaseModel): + name: str + is_sale = True + tags = ["tag1", "tag2"] + price = 10.5 + description = "Some item" + active = Field(default=True) + ready = Field(True) + age = Field(10, title="Age") + + class Bar(Foo): + sub_name: str + sub_is_sale = True + sub_tags = ["tag1", "tag2"] + sub_price = 10.5 + sub_description = "Some item" + sub_active = Field(default=True) + sub_ready = Field(True) + sub_age = Field(10, title="Age") + + def do_stuff(self): + something = [1, 2, 3] + return something + """ + ).lstrip() + module = self.add_annotations( + "some/test/module.py", + source, + ) + expected = textwrap.dedent( + """ + from pydantic import BaseModel, Field + + class Foo(BaseModel): + name: str + is_sale: bool = True + # TODO[pydantic]: add type annotation + tags = ["tag1", "tag2"] + price: float = 10.5 + description: str = "Some item" + active: bool = Field(default=True) + ready: bool = Field(True) + age: int = Field(10, title="Age") + + class Bar(Foo): + sub_name: str + sub_is_sale: bool = True + # TODO[pydantic]: add type annotation + sub_tags = ["tag1", "tag2"] + sub_price: float = 10.5 + sub_description: str = "Some item" + sub_active: bool = Field(default=True) + sub_ready: bool = Field(True) + sub_age: int = Field(10, title="Age") + + def do_stuff(self): + something = [1, 2, 3] + return something + """ + ).lstrip() + assert module.code == expected From 27154dec634d754aa679812fa8857253cf5f9e8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Thu, 18 Apr 2024 00:46:16 -0500 Subject: [PATCH 4/7] =?UTF-8?q?=F0=9F=93=9D=20Add=20docs=20for=20new=20rul?= =?UTF-8?q?e=20BP010?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/README.md b/README.md index b990c80..a4b08b7 100644 --- a/README.md +++ b/README.md @@ -340,6 +340,44 @@ class SomeThing: field_schema['example'] = "Weird example" ``` +### BP010: Add type annotations or TODO comments to fields without them + +- ✅ Add type annotations based on the default value for a few types that can be inferred, like `bool`, `str`, `int`, `float`. +- ✅ Add `# TODO[pydantic]: add type annotation` comments to fields that can't be inferred. + +The following code will be transformed: + +```py +from pydantic import BaseModel, Field + +class Potato(BaseModel): + name: str + is_sale = True + tags = ["tag1", "tag2"] + price = 10.5 + description = "Some item" + active = Field(default=True) + ready = Field(True) + age = Field(10, title="Age") +``` + +Into: + +```py +from pydantic import BaseModel, Field + +class Potato(BaseModel): + name: str + is_sale: bool = True + # TODO[pydantic]: add type annotation + tags = ["tag1", "tag2"] + price: float = 10.5 + description: str = "Some item" + active: bool = Field(default=True) + ready: bool = Field(True) + age: int = Field(10, title="Age") +``` +