Skip to content

Commit

Permalink
Update model.py
Browse files Browse the repository at this point in the history
  • Loading branch information
viveksilimkhan1 authored Oct 24, 2023
1 parent 77b3c2c commit cd342cc
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions spacy_llm/models/bedrock/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(
self._max_retries = max_retries

# @property
def get_session(self) -> Dict[str, str]:
def get_session_kwargs(self) -> Dict[str, str]:

# Fetch and check the credentials
profile = os.getenv("AWS_PROFILE") if not None else ""
Expand Down Expand Up @@ -69,7 +69,7 @@ def __call__(self, prompts: Iterable[str]) -> Iterable[str]:
api_responses: List[str] = []
prompts = list(prompts)

def _request(json_data: Dict[str, Any]) -> Dict[str, Any]:
def _request(json_data: str) -> Dict[str, Any]:
try:
import boto3
import botocore
Expand All @@ -79,6 +79,7 @@ def _request(json_data: Dict[str, Any]) -> Dict[str, Any]:
from botocore.config import Config
session_kwargs = self.get_session_kwargs()
session = boto3.Session(**session_kwargs)
api_config = Config(retries = dict(max_attempts = self._max_retries))
print("Session:", session)
bedrock = session.client(service_name="bedrock-runtime", config=api_config)
accept = "application/json"
Expand Down

0 comments on commit cd342cc

Please sign in to comment.