From a8b276dd2714277d032287563b41a111b8fbcb50 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Mon, 15 Jul 2024 17:04:37 +0200 Subject: [PATCH] Modify transport init parameters Adding connect_args parameter to be able to provide any argument to the ws_connect method Removing the following parameters (they can now be provided in the connect_args dict): - autoclose - autoping - compress - max_msg_size - verify_ssl - method Renaming protocols to subprotocols to be more similar to the websockets transport --- gql/transport/aiohttp_websockets.py | 51 +++++++++++++-------------- tests/conftest.py | 2 +- tests/test_aiohttp_websocket_query.py | 6 ++-- 3 files changed, 29 insertions(+), 30 deletions(-) diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index 9c28f233..6186610f 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -17,7 +17,7 @@ ) import aiohttp -from aiohttp import BasicAuth, Fingerprint, WSMsgType, hdrs +from aiohttp import BasicAuth, Fingerprint, WSMsgType from aiohttp.typedefs import LooseHeaders, StrOrURL from graphql import DocumentNode, ExecutionResult, print_ast from multidict import CIMultiDictProxy @@ -110,10 +110,7 @@ def __init__( self, url: StrOrURL, *, - method: str = hdrs.METH_GET, - protocols: Collection[str] = (), - autoclose: bool = True, - autoping: bool = True, + subprotocols: Optional[Collection[str]] = None, heartbeat: Optional[float] = None, auth: Optional[BasicAuth] = None, origin: Optional[str] = None, @@ -121,12 +118,9 @@ def __init__( headers: Optional[LooseHeaders] = None, proxy: Optional[StrOrURL] = None, proxy_auth: Optional[BasicAuth] = None, + proxy_headers: Optional[LooseHeaders] = None, ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = None, ssl_context: Optional[SSLContext] = None, - verify_ssl: Optional[bool] = True, - proxy_headers: Optional[LooseHeaders] = None, - compress: int = 0, - max_msg_size: int = 4 * 1024 * 1024, websocket_close_timeout: float = 10.0, receive_timeout: Optional[float] = None, ssl_close_timeout: Optional[Union[int, float]] = 10, @@ -139,32 +133,31 @@ def __init__( pong_timeout: Optional[Union[int, float]] = None, answer_pings: bool = True, client_session_args: Optional[Dict[str, Any]] = None, + connect_args: Dict[str, Any] = {}, ) -> None: self.url: StrOrURL = url - self.headers: Optional[LooseHeaders] = headers - self.auth: Optional[BasicAuth] = auth - self.autoclose: bool = autoclose - self.autoping: bool = autoping - self.compress: int = compress self.heartbeat: Optional[float] = heartbeat - self.max_msg_size: int = max_msg_size - self.method: str = method + self.auth: Optional[BasicAuth] = auth self.origin: Optional[str] = origin self.params: Optional[Mapping[str, str]] = params - self.protocols: Collection[str] = protocols + self.headers: Optional[LooseHeaders] = headers + self.proxy: Optional[StrOrURL] = proxy self.proxy_auth: Optional[BasicAuth] = proxy_auth self.proxy_headers: Optional[LooseHeaders] = proxy_headers - self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout + self.ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = ssl self.ssl_context: Optional[SSLContext] = ssl_context + self.websocket_close_timeout: float = websocket_close_timeout self.receive_timeout: Optional[float] = receive_timeout + + self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout self.connect_timeout: Optional[Union[int, float]] = connect_timeout self.close_timeout: Optional[Union[int, float]] = close_timeout self.ack_timeout: Optional[Union[int, float]] = ack_timeout self.keep_alive_timeout: Optional[Union[int, float]] = keep_alive_timeout - self.verify_ssl: Optional[bool] = verify_ssl + self.init_payload: Dict[str, Any] = init_payload # We need to set an event loop here if there is none @@ -221,12 +214,15 @@ def __init__( """pong_received is an asyncio Event which will fire each time a pong is received with the graphql-ws protocol""" - self.supported_subprotocols: Collection[str] = protocols or ( + self.supported_subprotocols: Collection[str] = subprotocols or ( self.APOLLO_SUBPROTOCOL, self.GRAPHQLWS_SUBPROTOCOL, ) + self.close_exception: Optional[Exception] = None + self.client_session_args = client_session_args + self.connect_args = connect_args def _parse_answer_graphqlws( self, answer: Dict[str, Any] @@ -782,28 +778,29 @@ async def connect(self) -> None: if self.websocket is None and not self._connecting: self._connecting = True + connect_args: Dict[str, Any] = {} + + # Adding custom parameters passed from init + if self.connect_args: + connect_args.update(self.connect_args) + try: self.websocket = await self.session.ws_connect( - method=self.method, url=self.url, headers=self.headers, auth=self.auth, - autoclose=self.autoclose, - autoping=self.autoping, - compress=self.compress, heartbeat=self.heartbeat, - max_msg_size=self.max_msg_size, origin=self.origin, params=self.params, protocols=self.supported_subprotocols, proxy=self.proxy, proxy_auth=self.proxy_auth, proxy_headers=self.proxy_headers, + timeout=self.websocket_close_timeout, receive_timeout=self.receive_timeout, ssl=self.ssl, ssl_context=None, - timeout=self.websocket_close_timeout, - verify_ssl=self.verify_ssl, + **connect_args, ) finally: self._connecting = False diff --git a/tests/conftest.py b/tests/conftest.py index bd68982b..ee288eea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -516,7 +516,7 @@ async def client_and_aiohttp_websocket_graphql_server(graphqlws_server): url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" sample_transport = AIOHTTPWebsocketsTransport( url=url, - protocols=[AIOHTTPWebsocketsTransport.GRAPHQLWS_SUBPROTOCOL], + subprotocols=[AIOHTTPWebsocketsTransport.GRAPHQLWS_SUBPROTOCOL], ) async with Client(transport=sample_transport) as session: diff --git a/tests/test_aiohttp_websocket_query.py b/tests/test_aiohttp_websocket_query.py index 8d6fbab9..6fb8eafa 100644 --- a/tests/test_aiohttp_websocket_query.py +++ b/tests/test_aiohttp_websocket_query.py @@ -499,10 +499,12 @@ async def test_aiohttp_websocket_add_extra_parameters_to_connect(event_loop, ser url = f"ws://{server.hostname}:{server.port}/graphql" - # Increase max payload size to avoid websockets.exceptions.PayloadTooBig exceptions + # Increase max payload size transport = AIOHTTPWebsocketsTransport( url=url, - max_msg_size=(2**21), + connect_args={ + "max_msg_size": 2**21, + }, ) query = gql(query1_str)