Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Add codemod and rule to add type annotations or TODO comment for fields with a default and no type annotations #163

Merged
merged 7 commits into from
Apr 18, 2024
38 changes: 38 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,44 @@ class SomeThing:
field_schema['example'] = "Weird example"
```

### BP010: Add type annotations or TODO comments to fields without them

- ✅ Add type annotations based on the default value for a few types that can be inferred, like `bool`, `str`, `int`, `float`.
- ✅ Add `# TODO[pydantic]: add type annotation` comments to fields that can't be inferred.

The following code will be transformed:

```py
from pydantic import BaseModel, Field

class Potato(BaseModel):
name: str
is_sale = True
tags = ["tag1", "tag2"]
price = 10.5
description = "Some item"
active = Field(default=True)
ready = Field(True)
age = Field(10, title="Age")
```

Into:

```py
from pydantic import BaseModel, Field

class Potato(BaseModel):
name: str
is_sale: bool = True
# TODO[pydantic]: add type annotation
tags = ["tag1", "tag2"]
price: float = 10.5
description: str = "Some item"
active: bool = Field(default=True)
ready: bool = Field(True)
age: int = Field(10, title="Age")
```

<!-- ### BP010: Replace `pydantic.parse_obj_as` by `pydantic.TypeAdapter`

- ✅ Replace `pydantic.parse_obj_as(T, obj)` to `pydantic.TypeAdapter(T).validate_python(obj)`.
Expand Down
6 changes: 6 additions & 0 deletions bump_pydantic/codemods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from libcst.codemod import ContextAwareTransformer
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor

from bump_pydantic.codemods.add_annotations import AddAnnotationsCommand
from bump_pydantic.codemods.add_default_none import AddDefaultNoneCommand
from bump_pydantic.codemods.con_func import ConFuncCallCommand
from bump_pydantic.codemods.custom_types import CustomTypeCodemod
Expand Down Expand Up @@ -34,6 +35,8 @@ class Rule(str, Enum):
"""Replace `con*` functions by `Annotated` versions."""
BP009 = "BP009"
"""Mark Pydantic "protocol" functions in custom types with proper TODOs."""
BP010 = "BP010"
"""Add type annotations or TODOs to fields without them."""


def gather_codemods(disabled: List[Rule]) -> List[Type[ContextAwareTransformer]]:
Expand Down Expand Up @@ -67,6 +70,9 @@ def gather_codemods(disabled: List[Rule]) -> List[Type[ContextAwareTransformer]]
if Rule.BP009 not in disabled:
codemods.append(CustomTypeCodemod)

if Rule.BP010 not in disabled:
codemods.append(AddAnnotationsCommand)

# Those codemods need to be the last ones.
codemods.extend([RemoveImportsVisitor, AddImportsVisitor])
return codemods
158 changes: 158 additions & 0 deletions bump_pydantic/codemods/add_annotations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
from __future__ import annotations

import libcst as cst
import libcst.matchers as m
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
from libcst.metadata import FullyQualifiedNameProvider, QualifiedName

from bump_pydantic.codemods.class_def_visitor import ClassDefVisitor

COMMENT = "# TODO[pydantic]: add type annotation"


class AddAnnotationsCommand(VisitorBasedCodemodCommand):
"""This codemod adds a type annotation or TODO comment to pydantic fields without
a type annotation.

Example::
# Before
```py
from pydantic import BaseModel, Field

class Foo(BaseModel):
name: str
is_sale = True
tags = ["tag1", "tag2"]
price = 10.5
description = "Some item"
active = Field(default=True)
ready = Field(True)
age = Field(10, title="Age")
```

# After
```py
from pydantic import BaseModel, Field

class Foo(BaseModel):
name: str
is_sale: bool = True
# TODO[pydantic]: add type annotation
tags = ["tag1", "tag2"]
price: float = 10.5
description: str = "Some item"
active: bool = Field(default=True)
ready: bool = Field(True)
age: int = Field(10, title="Age")
```
"""

METADATA_DEPENDENCIES = (FullyQualifiedNameProvider,)

def __init__(self, context: CodemodContext) -> None:
super().__init__(context)

self.inside_base_model = False
self.base_model_fields: set[cst.Assign | cst.AnnAssign | cst.SimpleStatementLine] = set()
self.statement: cst.SimpleStatementLine | None = None
self.needs_comment = False
self.has_comment = False
self.in_field = False

def visit_ClassDef(self, node: cst.ClassDef) -> None:
fqn_set = self.get_metadata(FullyQualifiedNameProvider, node)

if not fqn_set:
return None

fqn: QualifiedName = next(iter(fqn_set)) # type: ignore
if fqn.name in self.context.scratch[ClassDefVisitor.BASE_MODEL_CONTEXT_KEY]:
self.inside_base_model = True
self.base_model_fields = {
child for child in node.body.children if isinstance(child, cst.SimpleStatementLine)
}
return

def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
self.base_model_fields = set()
return updated_node

def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
if node not in self.base_model_fields:
return
if not self.inside_base_model:
return
self.statement = node
self.in_field = True
for line in node.leading_lines:
if m.matches(line, m.EmptyLine(comment=m.Comment(value=COMMENT))):
self.has_comment = True

def leave_SimpleStatementLine(
self, original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine
) -> cst.SimpleStatementLine:
if original_node not in self.base_model_fields:
return updated_node
if self.needs_comment and not self.has_comment:
updated_node = updated_node.with_changes(
leading_lines=[
*updated_node.leading_lines,
cst.EmptyLine(comment=cst.Comment(value=(COMMENT))),
],
body=[
*updated_node.body,
],
)
self.statement = None
self.needs_comment = False
self.has_comment = False
self.in_field = False
return updated_node

def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign) -> cst.Assign | cst.AnnAssign:
if not self.in_field:
return updated_node
if self.inside_base_model:
if m.matches(updated_node, m.Assign(targets=[m.AssignTarget(target=m.Name("model_config"))])):
return updated_node
Undefined = object()
value: cst.BaseExpression | object = Undefined
if m.matches(updated_node.value, m.Call(func=m.Name("Field"))):
assert isinstance(updated_node.value, cst.Call)
args = updated_node.value.args
if args:
default_keywords = [arg.value for arg in args if arg.keyword and arg.keyword.value == "default"]
# NOTE: It has a "default" value as positional argument.
if args[0].keyword is None:
value = args[0].value
# NOTE: It has a "default" keyword argument.
elif default_keywords:
value = default_keywords[0]
else:
value = updated_node.value
if value is Undefined:
self.needs_comment = True
return updated_node

# Infer simple type annotations
ann_type = None
assert isinstance(value, cst.BaseExpression)
if m.matches(value, m.Name("True") | m.Name("False")):
ann_type = "bool"
elif m.matches(value, m.SimpleString()):
ann_type = "str"
elif m.matches(value, m.Integer()):
ann_type = "int"
elif m.matches(value, m.Float()):
ann_type = "float"

# If there's a simple inferred type annotation, return that
if ann_type:
return cst.AnnAssign(
target=updated_node.targets[0].target,
annotation=cst.Annotation(cst.Name(ann_type)),
value=updated_node.value,
)
else:
self.needs_comment = True
return updated_node
8 changes: 4 additions & 4 deletions tests/integration/cases/replace_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@
"from pydantic import BaseModel, Field",
"",
"",
"class A(str, Enum):",
"class E(str, Enum):",
" a = 'a'",
" b = 'b'",
"",
"class A(BaseModel):",
" a: A = Field(A.a, const=True)",
" a: E = Field(E.a, const=True)",
],
),
expected=File(
Expand All @@ -70,12 +70,12 @@
"from typing import Literal",
"",
"",
"class A(str, Enum):",
"class E(str, Enum):",
" a = 'a'",
" b = 'b'",
"",
"class A(BaseModel):",
" a: Literal[A.a] = A.a",
" a: Literal[E.a] = E.a",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests were failing because both classes had the same name, one would override the other.

And then the new "add type annotations" codemod would modify the enum, just because the name is the same. But as it's not really valid to have the same name for two classes in the same file, I just updated it. 🤓

],
),
),
Expand Down
Loading