Skip to content

Commit

Permalink
fix some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mpegman-scwx committed Jun 10, 2024
1 parent 6584f54 commit fc73ba1
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 10 deletions.
37 changes: 28 additions & 9 deletions gql/transport/aiohttp_websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
)

import aiohttp
from aiohttp.client_reqrep import Fingerprint
from aiohttp.helpers import BasicAuth, hdrs
from aiohttp import hdrs, BasicAuth, Fingerprint, WSMsgType
from aiohttp.typedefs import LooseHeaders, StrOrURL
from graphql import DocumentNode, ExecutionResult, print_ast
from multidict import CIMultiDict, CIMultiDictProxy
Expand All @@ -32,6 +31,11 @@
)
from gql.transport.websockets_base import ListenerQueue

try:
from json.decoder import JSONDecodeError
except ImportError:
from simplejson import JSONDecodeError

log = logging.getLogger("gql.transport.aiohttp_websockets")


Expand Down Expand Up @@ -149,7 +153,7 @@ def __init__(
self.close_exception: Optional[Exception] = None

def _parse_answer_graphqlws(
self, json_answer: Dict[str, Any]
self, answer: Dict[str, Any]
) -> Tuple[str, Optional[int], Optional[ExecutionResult]]:
"""Parse the answer received from the server if the server supports the
graphql-ws protocol.
Expand All @@ -175,14 +179,14 @@ def _parse_answer_graphqlws(
execution_result: Optional[ExecutionResult] = None

try:
answer_type = str(json_answer.get("type"))
answer_type = str(answer.get("type"))

if answer_type in ["next", "error", "complete"]:
answer_id = int(str(json_answer.get("id")))
answer_id = int(str(answer.get("id")))

if answer_type == "next" or answer_type == "error":

payload = json_answer.get("payload")
payload = answer.get("payload")

if answer_type == "next":

Expand Down Expand Up @@ -213,7 +217,7 @@ def _parse_answer_graphqlws(
)

elif answer_type in ["ping", "pong", "connection_ack"]:
self.payloads[answer_type] = json_answer.get("payload", None)
self.payloads[answer_type] = answer.get("payload", None)

else:
raise ValueError
Expand All @@ -223,7 +227,7 @@ def _parse_answer_graphqlws(

except ValueError as e:
raise TransportProtocolError(
f"Server did not return a GraphQL result: {json_answer}"
f"Server did not return a GraphQL result: {answer}"
) from e

return answer_type, answer_id, execution_result
Expand Down Expand Up @@ -471,14 +475,27 @@ async def _send(self, message: Dict[str, Any]) -> None:
raise e

async def _receive(self) -> Dict[str, Any]:
log.debug("Entering _receive()")

if self.websocket is None:
raise TransportClosed("WebSocket connection is closed")

answer = await self.websocket.receive_json()
try:
answer = await self.websocket.receive_json()
except TypeError as e:
answer = await self.websocket.receive()
if answer.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING):
self._fail(e, clean_close=True)
raise ConnectionResetError
else:
self._fail(e, clean_close=False)
except JSONDecodeError as e:
self._fail(e)

log.info("<<< %s", answer)

log.debug("Exiting _receive()")

return answer

def _remove_listener(self, query_id) -> None:
Expand Down Expand Up @@ -546,6 +563,8 @@ async def _handle_answer(
async def _receive_data_loop(self) -> None:
"""Main asyncio task which will listen to the incoming messages and will
call the parse_answer and handle_answer methods of the subclass."""
log.debug("Entering _receive_data_loop()")

try:
while True:

Expand Down
2 changes: 1 addition & 1 deletion tests/test_aiohttp_websocket_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ async def test_aiohttp_websocket_transport_protocol_errors(

query = gql("query { hello }")

with pytest.raises(TransportProtocolError):
with pytest.raises((TransportProtocolError, TransportQueryError)):
await session.execute(query)


Expand Down

0 comments on commit fc73ba1

Please sign in to comment.