Skip to content

Commit

Permalink
Formatted registry.py
Browse files Browse the repository at this point in the history
  • Loading branch information
viveksilimkhan1 authored Oct 25, 2023
1 parent 621d2d1 commit c35c88e
Showing 1 changed file with 19 additions and 11 deletions.
30 changes: 19 additions & 11 deletions spacy_llm/models/bedrock/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,31 +11,40 @@
_DEFAULT_TOP_P: int = 1
_DEFAULT_STOP_SEQUENCES: List[str] = []


@registry.llm_models("spacy.Bedrock.Titan.Express.v1")
def titan_express(
region: str,
model_id: Models = Models.TITAN_EXPRESS,
config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE, maxTokenCount=_DEFAULT_MAX_TOKEN_COUNT, stopSequences=_DEFAULT_STOP_SEQUENCES, topP =_DEFAULT_TOP_P),
max_retries: int = _DEFAULT_RETRIES
config: Dict[Any, Any] = SimpleFrozenDict(
temperature=_DEFAULT_TEMPERATURE,
maxTokenCount=_DEFAULT_MAX_TOKEN_COUNT,
stopSequences=_DEFAULT_STOP_SEQUENCES,
topP=_DEFAULT_TOP_P,
),
max_retries: int = _DEFAULT_RETRIES,
) -> Callable[[Iterable[str]], Iterable[str]]:
"""Returns Bedrock instance for 'amazon-titan-express' model using boto3 to prompt API.
model_id (ModelId): ID of the deployed model (titan-express)
region (str): Specify the AWS region for the service
config (Dict[Any, Any]): LLM config passed on to the model's initialization.
"""
return Bedrock(
model_id = model_id,
region = region,
config=config,
max_retries=max_retries
model_id=model_id, region=region, config=config, max_retries=max_retries
)


@registry.llm_models("spacy.Bedrock.Titan.Lite.v1")
def titan_lite(
region: str,
model_id: Models = Models.TITAN_LITE,
config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE, maxTokenCount=_DEFAULT_MAX_TOKEN_COUNT, stopSequences=_DEFAULT_STOP_SEQUENCES, topP =_DEFAULT_TOP_P),
max_retries: int = _DEFAULT_RETRIES
config: Dict[Any, Any] = SimpleFrozenDict(
temperature=_DEFAULT_TEMPERATURE,
maxTokenCount=_DEFAULT_MAX_TOKEN_COUNT,
stopSequences=_DEFAULT_STOP_SEQUENCES,
topP=_DEFAULT_TOP_P,
),
max_retries: int = _DEFAULT_RETRIES,
) -> Callable[[Iterable[str]], Iterable[str]]:
"""Returns Bedrock instance for 'amazon-titan-lite' model using boto3 to prompt API.
region (str): Specify the AWS region for the service
Expand All @@ -44,9 +53,8 @@ def titan_lite(
config (Dict[Any, Any]): LLM config passed on to the model's initialization.
"""
return Bedrock(
model_id = model_id,
region = region,
model_id=model_id,
region=region,
config=config,
max_retries=max_retries,
)

0 comments on commit c35c88e

Please sign in to comment.