Skip to content

Commit

Permalink
Format and remove unnecessary imports
Browse files Browse the repository at this point in the history
  • Loading branch information
viveksilimkhan1 committed Oct 27, 2023
1 parent a133124 commit ee23841
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 64 deletions.
3 changes: 3 additions & 0 deletions spacy_llm/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .bedrock import titan_express, titan_lite
from .hf import dolly_hf, openllama_hf, stablelm_hf
from .langchain import query_langchain
from .rest import anthropic, cohere, noop, openai, palm
Expand All @@ -12,4 +13,6 @@
"openllama_hf",
"palm",
"query_langchain",
"titan_lite",
"titan_express",
]
100 changes: 54 additions & 46 deletions spacy_llm/models/bedrock/model.py
Original file line number Diff line number Diff line change
@@ -1,100 +1,108 @@
import os
import json
import os
import warnings
from enum import Enum
from requests import HTTPError
from typing import Any, Dict, Iterable, Optional, Type, List, Sized, Tuple

from confection import SimpleFrozenDict
from typing import Any, Dict, Iterable, List, Optional

from ...registry import registry

try:
import boto3
import botocore
from botocore.config import Config
except ImportError as err:
print("To use Bedrock, you need to install boto3. Use `pip install boto3` ")
raise err

class Models(str, Enum):
    """IDs of the Amazon Bedrock completion models supported by this wrapper.

    Subclassing ``str`` lets a member compare equal to its raw model-id
    string, so these values can be passed directly as boto3 ``modelId``s.
    """

    # Completion models
    TITAN_EXPRESS = "amazon.titan-text-express-v1"
    TITAN_LITE = "amazon.titan-text-lite-v1"

class Bedrock:
    """Thin callable wrapper around an Amazon Bedrock text-generation model."""

    def __init__(
        self, model_id: str, region: str, config: Dict[Any, Any], max_retries: int = 5
    ):
        """Store the connection and generation settings; no API call is made here.

        model_id (str): ID of the deployed Bedrock model (e.g. a Titan variant).
        region (str): AWS region to call the Bedrock service in.
        config (Dict[Any, Any]): Text-generation config forwarded to the API.
        max_retries (int): Maximum retry attempts for the boto3 client.
        """
        self._region = region
        self._model_id = model_id
        self._config = config
        self._max_retries = max_retries

# @property
def get_session(self) -> Dict[str, str]:

def get_session_kwargs(self) -> Dict[str, Optional[str]]:

# Fetch and check the credentials
profile = os.getenv("AWS_PROFILE") if not None else ""
profile = os.getenv("AWS_PROFILE") if not None else ""
secret_key_id = os.getenv("AWS_ACCESS_KEY_ID")
secret_access_key = os.getenv("AWS_SECRET_ACCESS_KEY")
session_token = os.getenv("AWS_SESSION_TOKEN")

if profile is None:
warnings.warn(
"Could not find the AWS_PROFILE to access the Amazon Bedrock . Ensure you have an AWS_PROFILE "
"set up by making it available as an environment variable 'AWS_PROFILE'."
)
"set up by making it available as an environment variable AWS_PROFILE."
)

if secret_key_id is None:
warnings.warn(
"Could not find the AWS_ACCESS_KEY_ID to access the Amazon Bedrock . Ensure you have an AWS_ACCESS_KEY_ID "
"set up by making it available as an environment variable 'AWS_ACCESS_KEY_ID'."
"set up by making it available as an environment variable AWS_ACCESS_KEY_ID."
)

if secret_access_key is None:
warnings.warn(
"Could not find the AWS_SECRET_ACCESS_KEY to access the Amazon Bedrock . Ensure you have an AWS_SECRET_ACCESS_KEY "
"set up by making it available as an environment variable 'AWS_SECRET_ACCESS_KEY'."
"set up by making it available as an environment variable AWS_SECRET_ACCESS_KEY."
)

if session_token is None:
warnings.warn(
"Could not find the AWS_SESSION_TOKEN to access the Amazon Bedrock . Ensure you have an AWS_SESSION_TOKEN "
"set up by making it available as an environment variable 'AWS_SESSION_TOKEN'."
"set up by making it available as an environment variable AWS_SESSION_TOKEN."
)

assert secret_key_id is not None
assert secret_access_key is not None
assert session_token is not None

session_kwargs = {"profile_name":profile, "region_name":self._region, "aws_access_key_id":secret_key_id, "aws_secret_access_key":secret_access_key, "aws_session_token":session_token}
bedrock = boto3.Session(**session_kwargs)
return bedrock

def __call__(self, prompts: Iterable[str]) -> Iterable[str]:
session_kwargs = {
"profile_name": profile,
"region_name": self._region,
"aws_access_key_id": secret_key_id,
"aws_secret_access_key": secret_access_key,
"aws_session_token": session_token,
}
return session_kwargs

def __call__(self, prompts: Iterable[str]) -> Iterable[str]:
api_responses: List[str] = []
prompts = list(prompts)
api_config = Config(retries = dict(max_attempts = self._max_retries))

def _request(json_data: Dict[str, Any]) -> Dict[str, Any]:
session = self.get_session()
print("Session:", session)
def _request(json_data: str) -> str:
try:
import boto3
except ImportError as err:
warnings.warn(
"To use Bedrock, you need to install boto3. Use pip install boto3 "
)
raise err
from botocore.config import Config

session_kwargs = self.get_session_kwargs()
session = boto3.Session(**session_kwargs)
api_config = Config(retries=dict(max_attempts=self._max_retries))
bedrock = session.client(service_name="bedrock-runtime", config=api_config)
accept = 'application/json'
contentType = 'application/json'
r = bedrock.invoke_model(body=json_data, modelId=self._model_id, accept=accept, contentType=contentType)
responses = json.loads(r['body'].read().decode())['results'][0]['outputText']
accept = "application/json"
contentType = "application/json"
r = bedrock.invoke_model(
body=json_data,
modelId=self._model_id,
accept=accept,
contentType=contentType,
)
responses = json.loads(r["body"].read().decode())["results"][0][
"outputText"
]
return responses

for prompt in prompts:
if self._model_id in [Models.TITAN_LITE, Models.TITAN_EXPRESS]:
responses = _request(json.dumps({"inputText": prompt, "textGenerationConfig":self._config}))
if "error" in responses:
return responses["error"]
responses = _request(
json.dumps(
{"inputText": prompt, "textGenerationConfig": self._config}
)
)

api_responses.append(responses)

Expand Down
42 changes: 25 additions & 17 deletions spacy_llm/models/bedrock/registry.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,50 @@
from typing import Any, Callable, Dict, Iterable
from typing import Any, Callable, Dict, Iterable, List

from confection import SimpleFrozenDict

from ...registry import registry
from .model import Bedrock, Models

# Defaults for the Titan text-generation config and the Bedrock client.
_DEFAULT_RETRIES: int = 5
_DEFAULT_TEMPERATURE: float = 0.0
_DEFAULT_MAX_TOKEN_COUNT: int = 512
_DEFAULT_TOP_P: int = 1
_DEFAULT_STOP_SEQUENCES: List[str] = []


@registry.llm_models("spacy.Bedrock.Titan.Express.v1")
def titan_express(
    region: str,
    model_id: Models = Models.TITAN_EXPRESS,
    config: Dict[Any, Any] = SimpleFrozenDict(
        temperature=_DEFAULT_TEMPERATURE,
        maxTokenCount=_DEFAULT_MAX_TOKEN_COUNT,
        stopSequences=_DEFAULT_STOP_SEQUENCES,
        topP=_DEFAULT_TOP_P,
    ),
    max_retries: int = _DEFAULT_RETRIES,
) -> Callable[[Iterable[str]], Iterable[str]]:
    """Returns Bedrock instance for 'amazon-titan-express' model using boto3 to prompt API.
    model_id (ModelId): ID of the deployed model (titan-express)
    region (str): Specify the AWS region for the service
    config (Dict[Any, Any]): LLM config passed on to the model's initialization.
    max_retries (int): Maximum retry attempts for the Bedrock API client.
    """
    return Bedrock(
        model_id=model_id, region=region, config=config, max_retries=max_retries
    )


@registry.llm_models("spacy.Bedrock.Titan.Lite.v1")
def titan_lite(
region: str,
model_id: Models = Models.TITAN_LITE,
config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE, maxTokenCount=_DEFAULT_MAX_TOKEN_COUNT, stopSequences=_DEFAULT_STOP_SEQUENCES, topP =_DEFAULT_TOP_P),
max_retries: int = _DEFAULT_RETRIES
config: Dict[Any, Any] = SimpleFrozenDict(
temperature=_DEFAULT_TEMPERATURE,
maxTokenCount=_DEFAULT_MAX_TOKEN_COUNT,
stopSequences=_DEFAULT_STOP_SEQUENCES,
topP=_DEFAULT_TOP_P,
),
max_retries: int = _DEFAULT_RETRIES,
) -> Callable[[Iterable[str]], Iterable[str]]:
"""Returns Bedrock instance for 'amazon-titan-lite' model using boto3 to prompt API.
region (str): Specify the AWS region for the service
Expand All @@ -44,9 +53,8 @@ def titan_lite(
config (Dict[Any, Any]): LLM config passed on to the model's initialization.
"""
return Bedrock(
model_id = model_id,
region = region,
model_id=model_id,
region=region,
config=config,
max_retries=max_retries,
)

2 changes: 1 addition & 1 deletion usage_examples/ner_v3_titan/fewshot.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ path = "${paths.examples}"

[components.llm.model]
@llm_models = "spacy.Bedrock.Titan.Express.v1"
region = us-east-1
region = <aws-region>

0 comments on commit ee23841

Please sign in to comment.