Skip to content

Commit

Permalink
(fix): support generating correct code snippets when extending base c…
Browse files Browse the repository at this point in the history
…lient in python (#3097)

* fix

* (fix): support generating accurate code snippets when writing to base_client.py

* specify changelog

* fix check
  • Loading branch information
dsinghvi authored Mar 3, 2024
1 parent 279101a commit 6fc9abc
Show file tree
Hide file tree
Showing 151 changed files with 3,013 additions and 27 deletions.
36 changes: 36 additions & 0 deletions generators/python/sdk/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,42 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.11.8-rc0] - 2024-02-27

- Beta: Introduce a `client` custom config that allows you to specify class_name and
filename for the client. This configuration can be used in several ways:

1. Rename your client class:

```yml
config:
client:
class_name: Imdb
```
2. Add custom functions to your generated SDK:
```yml
config:
client:
class_name: BaseImdb
filename: base_client.py
exported_class_name: Imdb
exported_filename: client.py
```
Often times you may want to add additional methods or utilites to the
generated client. The easiest way to do this is to configure Fern to write
the autogenerated client in another file and extend it on your own.
With the configuration above, Fern's Python SDK generator will create a
class called `BaseImdb` and `AsyncBaseImdb` and put them in a file called
`base_client.py`. As a user, you can extend both these classes with
custom utilities.

To make sure the code snippets in the generated SDK are accurate you can
specify `exported_class_name` and `exported_filename`.

## [0.11.7] - 2024-02-27

- Improvement: Introduces a flag `use_str_enums` to swap from using proper Enum classes to using Literals to represent enums. This change allows for forward compatibility of enums, since the user will receive the string back.
Expand Down
2 changes: 1 addition & 1 deletion generators/python/sdk/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.11.7
0.11.8-rc0
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,13 @@ def __init__(
self._snippet_writer = snippet_writer

def generate(self, source_file: SourceFile) -> GeneratedRootClient:
exported_client_class_name = self._context.get_class_name_for_exported_root_client()
builder = RootClientGenerator.GeneratedRootClientBuilder(
module_path=self._context.get_module_path_in_project(
self._context.get_filepath_for_root_client().to_module().path
self._context.get_filepath_for_exported_root_client().to_module().path
),
class_name=self._class_name,
async_class_name=self._async_class_name,
class_name=self._context.get_class_name_for_exported_root_client(),
async_class_name="Async" + exported_client_class_name,
constructor_parameters=self._client_wrapper_constructor_params,
)
generated_root_client = builder.build()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,19 @@ def get_class_name_of_async_subpackage_service(self, subpackage_id: ir_types.Sub
...

@abstractmethod
def get_filepath_for_root_client(self) -> Filepath:
def get_filepath_for_generated_root_client(self) -> Filepath:
...

@abstractmethod
def get_class_name_for_root_client(self) -> str:
def get_class_name_for_generated_root_client(self) -> str:
...

@abstractmethod
def get_filepath_for_exported_root_client(self) -> Filepath:
...

@abstractmethod
def get_class_name_for_exported_root_client(self) -> str:
...

@abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,24 +34,33 @@ def __init__(
custom_config=custom_config,
project_module_path=project_module_path,
)
client_class_name = custom_config.client_class_name or (
pascal_case(generator_config.organization) + pascal_case(generator_config.workspace_name)
client_class_name = (
custom_config.client_class_name
or custom_config.client.class_name
or (pascal_case(generator_config.organization) + pascal_case(generator_config.workspace_name))
)
exported_client_class_name = custom_config.client.exported_class_name or client_class_name
client_filename = custom_config.client_filename or custom_config.client.filename
self._error_declaration_referencer = ErrorDeclarationReferencer(
skip_resources_module=custom_config.improved_imports
)
self._environments_enum_declaration_referencer = EnvironmentsEnumDeclarationReferencer(
client_class_name=client_class_name, skip_resources_module=custom_config.improved_imports
client_class_name=exported_client_class_name, skip_resources_module=custom_config.improved_imports
)
self._subpackage_client_declaration_referencer = SubpackageClientDeclarationReferencer(
skip_resources_module=custom_config.improved_imports
)
self._subpackage_async_client_declaration_referencer = SubpackageAsyncClientDeclarationReferencer(
skip_resources_module=custom_config.improved_imports
)
self._root_client_declaration_referencer = RootClientDeclarationReferencer(
root_class_name=client_class_name,
root_client_filename=custom_config.client_filename,
self._root_generated_client_declaration_referencer = RootClientDeclarationReferencer(
client_class_name=client_class_name,
client_filename=client_filename,
skip_resources_module=custom_config.improved_imports,
)
self._root_exported_client_declaration_referencer = RootClientDeclarationReferencer(
client_class_name=exported_client_class_name,
client_filename=custom_config.client.exported_filename or client_filename,
skip_resources_module=custom_config.improved_imports,
)
self._custom_config = custom_config
Expand Down Expand Up @@ -90,11 +99,17 @@ def get_reference_to_subpackage_service(self, subpackage_id: ir_types.Subpackage
subpackage = self.ir.subpackages[subpackage_id]
return self._subpackage_client_declaration_referencer.get_class_reference(name=subpackage)

def get_filepath_for_root_client(self) -> Filepath:
return self._root_client_declaration_referencer.get_filepath(name=None)
def get_filepath_for_generated_root_client(self) -> Filepath:
return self._root_generated_client_declaration_referencer.get_filepath(name=None)

def get_class_name_for_generated_root_client(self) -> str:
return self._root_generated_client_declaration_referencer.get_class_name(name=None)

def get_filepath_for_exported_root_client(self) -> Filepath:
return self._root_exported_client_declaration_referencer.get_filepath(name=None)

def get_class_name_for_root_client(self) -> str:
return self._root_client_declaration_referencer.get_class_name(name=None)
def get_class_name_for_exported_root_client(self) -> str:
return self._root_exported_client_declaration_referencer.get_class_name(name=None)

def get_filepath_for_async_subpackage_service(self, subpackage_id: ir_types.SubpackageId) -> Filepath:
subpackage = self.ir.subpackages[subpackage_id]
Expand Down
22 changes: 20 additions & 2 deletions generators/python/src/fern_python/generators/sdk/custom_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,24 @@ class SdkPydanticModelCustomConfig(PydanticModelCustomConfig):
require_optional_fields: bool = False


class ClientConfiguration(pydantic.BaseModel):
# The filename where the auto-generated client
# lives
filename: str = "client.py"
class_name: Optional[str] = None
# The filename of the exported client which
# will be used in code snippets
exported_filename: str = "client.py"
exported_class_name: Optional[str] = None

class Config:
extra = pydantic.Extra.forbid


class SDKCustomConfig(pydantic.BaseModel):
extra_dependencies: Dict[str, str] = {}
skip_formatting: bool = False
client_class_name: Optional[str] = None
client_filename: str = "client.py"
client: ClientConfiguration = ClientConfiguration()
include_union_utils: bool = False
use_api_name_in_package: bool = False
package_name: Optional[str] = None
Expand All @@ -30,5 +43,10 @@ class SDKCustomConfig(pydantic.BaseModel):
# Python SDK by removing nested `resources` directoy
improved_imports: bool = False

# deprecated, use client config instead
client_class_name: Optional[str] = None
# deprecated, use client config instead
client_filename: Optional[str] = None

class Config:
extra = pydantic.Extra.forbid
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@


class RootClientDeclarationReferencer(SdkDeclarationReferencer[None]):
def __init__(self, root_class_name: str, root_client_filename: str, skip_resources_module: bool):
def __init__(self, client_class_name: str, client_filename: str, skip_resources_module: bool):
super().__init__(skip_resources_module=skip_resources_module)
self._root_class_name = root_class_name
self._root_client_filename = root_client_filename
self._client_class_name = client_class_name
self._client_filename = client_filename

def get_filepath(self, *, name: None) -> Filepath:
return Filepath(
directories=(),
file=Filepath.FilepathPart(module_name=self._root_client_filename[:-3]),
# the [:-3] removes the .py extension
file=Filepath.FilepathPart(module_name=self._client_filename[:-3]),
)

def get_class_name(self, *, name: None) -> str:
return self._root_class_name
return self._client_class_name
14 changes: 10 additions & 4 deletions generators/python/src/fern_python/generators/sdk/sdk_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,15 @@ def run(
) -> None:
custom_config = SDKCustomConfig.parse_obj(generator_config.custom_config or {})

if not custom_config.client_filename.endswith(".py"):
if custom_config.client_filename is not None and not custom_config.client_filename.endswith(".py"):
raise RuntimeError("client_filename must end in .py")

if not custom_config.client.filename.endswith(".py"):
raise RuntimeError("client_location.filename must end in .py")

if not custom_config.client.exported_filename.endswith(".py"):
raise RuntimeError("client_location.exported_filename must end in .py")

for dep, version in custom_config.extra_dependencies.items():
project.add_dependency(dependency=AST.Dependency(name=dep, version=version))

Expand Down Expand Up @@ -225,16 +231,16 @@ def _generate_root_client(
snippet_registry: SnippetRegistry,
snippet_writer: SnippetWriter,
) -> GeneratedRootClient:
filepath = context.get_filepath_for_root_client()
filepath = context.get_filepath_for_generated_root_client()
source_file = SourceFileFactory.create(
project=project, filepath=filepath, generator_exec_wrapper=generator_exec_wrapper
)
generated_root_client = RootClientGenerator(
context=context,
package=ir.root_package,
generated_environment=generated_environment,
class_name=context.get_class_name_for_root_client(),
async_class_name="Async" + context.get_class_name_for_root_client(),
class_name=context.get_class_name_for_generated_root_client(),
async_class_name="Async" + context.get_class_name_for_generated_root_client(),
snippet_registry=snippet_registry,
snippet_writer=snippet_writer,
).generate(source_file=source_file)
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
Loading

0 comments on commit 6fc9abc

Please sign in to comment.