From fc73ba12dfdb97ef5560d88e42fc08227b35fdea Mon Sep 17 00:00:00 2001 From: Micah Pegman Date: Mon, 10 Jun 2024 18:45:37 +0000 Subject: [PATCH] fix some tests --- gql/transport/aiohttp_websockets.py | 37 ++++++++++++++++------ tests/test_aiohttp_websocket_exceptions.py | 2 +- 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index aa0000de..5cf88bbb 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -16,8 +16,7 @@ ) import aiohttp -from aiohttp.client_reqrep import Fingerprint -from aiohttp.helpers import BasicAuth, hdrs +from aiohttp import hdrs, BasicAuth, Fingerprint, WSMsgType from aiohttp.typedefs import LooseHeaders, StrOrURL from graphql import DocumentNode, ExecutionResult, print_ast from multidict import CIMultiDict, CIMultiDictProxy @@ -32,6 +31,11 @@ ) from gql.transport.websockets_base import ListenerQueue +try: + from json.decoder import JSONDecodeError +except ImportError: + from simplejson import JSONDecodeError + log = logging.getLogger("gql.transport.aiohttp_websockets") @@ -149,7 +153,7 @@ def __init__( self.close_exception: Optional[Exception] = None def _parse_answer_graphqlws( - self, json_answer: Dict[str, Any] + self, answer: Dict[str, Any] ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: """Parse the answer received from the server if the server supports the graphql-ws protocol. @@ -175,14 +179,14 @@ def _parse_answer_graphqlws( execution_result: Optional[ExecutionResult] = None try: - answer_type = str(json_answer.get("type")) + answer_type = str(answer.get("type")) if answer_type in ["next", "error", "complete"]: - answer_id = int(str(json_answer.get("id"))) + answer_id = int(str(answer.get("id"))) if answer_type == "next" or answer_type == "error": - payload = json_answer.get("payload") + payload = answer.get("payload") if answer_type == "next": @@ -213,7 +217,7 @@ def _parse_answer_graphqlws( ) elif answer_type in ["ping", "pong", "connection_ack"]: - self.payloads[answer_type] = json_answer.get("payload", None) + self.payloads[answer_type] = answer.get("payload", None) else: raise ValueError @@ -223,7 +227,7 @@ def _parse_answer_graphqlws( except ValueError as e: raise TransportProtocolError( - f"Server did not return a GraphQL result: {json_answer}" + f"Server did not return a GraphQL result: {answer}" ) from e return answer_type, answer_id, execution_result @@ -471,14 +475,27 @@ async def _send(self, message: Dict[str, Any]) -> None: raise e async def _receive(self) -> Dict[str, Any]: + log.debug("Entering _receive()") if self.websocket is None: raise TransportClosed("WebSocket connection is closed") - answer = await self.websocket.receive_json() + try: + answer = await self.websocket.receive_json() + except TypeError as e: + answer = await self.websocket.receive() + if answer.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING): + self._fail(e, clean_close=True) + raise ConnectionResetError + else: + self._fail(e, clean_close=False) + except JSONDecodeError as e: + self._fail(e) log.info("<<< %s", answer) + log.debug("Exiting _receive()") + return answer def _remove_listener(self, query_id) -> None: @@ -546,6 +563,8 @@ async def _handle_answer( async def _receive_data_loop(self) -> None: """Main asyncio task which will listen to the incoming messages and will call the parse_answer and handle_answer methods of the subclass.""" + log.debug("Entering _receive_data_loop()") + try: while True: diff --git a/tests/test_aiohttp_websocket_exceptions.py b/tests/test_aiohttp_websocket_exceptions.py index b2e53188..d50ac887 100644 --- a/tests/test_aiohttp_websocket_exceptions.py +++ b/tests/test_aiohttp_websocket_exceptions.py @@ -250,7 +250,7 @@ async def test_aiohttp_websocket_transport_protocol_errors( query = gql("query { hello }") - with pytest.raises(TransportProtocolError): + with pytest.raises((TransportProtocolError, TransportQueryError)): await session.execute(query)