Skip to content

Commit

Permalink
Fix invalid type hinting of initial_channels in WS Connection (Python…
Browse files Browse the repository at this point in the history
  • Loading branch information
junah201 committed Jul 2, 2024
1 parent 8e91cec commit bd20ae7
Showing 1 changed file with 44 additions and 23 deletions.
67 changes: 44 additions & 23 deletions twitchio/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import time
import traceback
from functools import partial
from typing import Union, Optional, List, TYPE_CHECKING
from typing import Union, Callable, Optional, List, TYPE_CHECKING

import aiohttp

Expand All @@ -57,7 +57,7 @@ def __init__(
client: "Client",
token: str = None,
modes: tuple = None,
initial_channels: List[str] = None,
initial_channels: Union[list, tuple, Callable] = None,
retain_cache: Optional[bool] = True,
):
self._loop = loop
Expand Down Expand Up @@ -115,7 +115,8 @@ def __init__(
async def _task_cleanup(self):
while True:
# keep all undone tasks
self._background_tasks = list(filter(lambda task: not task.done(), self._background_tasks))
self._background_tasks = list(
filter(lambda task: not task.done(), self._background_tasks))

# cleanup tasks every 30 seconds
await asyncio.sleep(30)
Expand All @@ -134,13 +135,15 @@ async def _connect(self):
if self._keeper:
self._keeper.cancel() # Stop our current keep alive.
if self.is_alive:
await self._websocket.close() # If for some reason we are in a weird state, close it before retrying.
# If for some reason we are in a weird state, close it before retrying.
await self._websocket.close()
if not self._client._http.nick:
try:
data = await self._client._http.validate(token=self._token)
except AuthenticationError:
await self._client._http.session.close()
self._client._closing.set() # clean up and error out (this is called to avoid calling Client.close in start()
# clean up and error out (this is called to avoid calling Client.close in start()
self._client._closing.set()
raise
self.nick = data["login"]
self.user_id = int(data["user_id"])
Expand All @@ -152,18 +155,21 @@ async def _connect(self):
self._websocket = await session.ws_connect(url=HOST, heartbeat=self._heartbeat)
except Exception as e:
retry = self._backoff.delay()
log.error(f"Websocket connection failure: {e}:: Attempting reconnect in {retry} seconds.")
log.error(
f"Websocket connection failure: {e}:: Attempting reconnect in {retry} seconds.")

await asyncio.sleep(retry)
return await self._connect()

await self.authenticate(self._initial_channels)

self._reconnect_requested = False
self._keeper = asyncio.create_task(self._keep_alive()) # Create our keep alive.
# Create our keep alive.
self._keeper = asyncio.create_task(self._keep_alive())

if not self._task_cleaner or self._task_cleaner.done():
self._task_cleaner = asyncio.create_task(self._task_cleanup()) # Create our task cleaner.
# Create our task cleaner.
self._task_cleaner = asyncio.create_task(self._task_cleanup())

self._ws_ready_event.set()

Expand All @@ -182,28 +188,33 @@ async def _keep_alive(self):
data = msg.data
if data:
log.debug(f" < {data}")
self.dispatch("raw_data", data) # Dispatch our event_raw_data event...
# Dispatch our event_raw_data event...
self.dispatch("raw_data", data)

events = data.split("\r\n")
for event in events:
if not event:
continue
task = asyncio.create_task(self._process_data(event))
task.add_done_callback(partial(self._task_callback, event)) # Process our raw data
# Process our raw data
task.add_done_callback(partial(self._task_callback, event))
self._background_tasks.append(task)

self._background_tasks.append(asyncio.create_task(self._connect()))

def _task_callback(self, data, task):
exc = task.exception()

if isinstance(exc, AuthenticationError): # Check if we failed to log in...
log.error("Authentication error. Please check your credentials and try again.")
# Check if we failed to log in...
if isinstance(exc, AuthenticationError):
log.error(
"Authentication error. Please check your credentials and try again.")
self._close()
elif exc:
# event_error task need to be shielded to avoid cancelling in self._close() function
# we need ensure, that the event will print its traceback
shielded_task = asyncio.shield(asyncio.create_task(self.event_error(exc, data)))
shielded_task = asyncio.shield(
asyncio.create_task(self.event_error(exc, data)))
self._background_tasks.append(shielded_task)

async def send(self, message: str):
Expand All @@ -218,7 +229,8 @@ async def send(self, message: str):
dummy = f"> :{self.nick}!{self.nick}@{self.nick}.tmi.twitch.tv PRIVMSG(ECHO) #{channel} {content}\r\n"

task = asyncio.create_task(self._process_data(dummy))
task.add_done_callback(partial(self._task_callback, dummy)) # Process our raw data
# Process our raw data
task.add_done_callback(partial(self._task_callback, dummy))
self._background_tasks.append(task)
await self._websocket.send_str(message + "\r\n")

Expand All @@ -233,7 +245,8 @@ async def reply(self, msg_id: str, message: str):

dummy = f"> @reply-parent-msg-id={msg_id} :{self.nick}!{self.nick}@{self.nick}.tmi.twitch.tv PRIVMSG(ECHO) #{channel} {content}\r\n"
task = asyncio.create_task(self._process_data(dummy))
task.add_done_callback(partial(self._task_callback, dummy)) # Process our raw data
# Process our raw data
task.add_done_callback(partial(self._task_callback, dummy))
self._background_tasks.append(task)
await self._websocket.send_str(f"@reply-parent-msg-id={msg_id} {message} \r\n")

Expand All @@ -258,7 +271,8 @@ async def authenticate(self, channels: Union[list, tuple]):
await self.send(f"NICK {self.nick}\r\n")

for cap in self.modes:
await self.send(f"CAP REQ :twitch.tv/{cap}") # Ideally no one should overwrite defaults...
# Ideally no one should overwrite defaults...
await self.send(f"CAP REQ :twitch.tv/{cap}")
if not channels and not self._initial_channels:
return
channels = channels or self._initial_channels
Expand Down Expand Up @@ -305,10 +319,12 @@ async def join_channels(self, *channels: str):
channel_count = len(channels)
if channel_count > 20:
timeout = self._assign_timeout(channel_count)
chunks = [channels[i : i + 20] for i in range(0, len(channels), 20)]
chunks = [channels[i: i + 20]
for i in range(0, len(channels), 20)]
for chunk in chunks:
for channel in chunk:
task = asyncio.create_task(self._join_channel(channel, timeout))
task = asyncio.create_task(
self._join_channel(channel, timeout))
self._background_tasks.append(task)

await asyncio.sleep(11)
Expand All @@ -322,7 +338,8 @@ async def _join_channel(self, entry: str, timeout: int):
await self.send(f"JOIN #{channel}\r\n")

self._join_pending[channel] = fut = self._loop.create_future()
self._background_tasks.append(asyncio.create_task(self._join_future_handle(fut, channel, timeout)))
self._background_tasks.append(asyncio.create_task(
self._join_future_handle(fut, channel, timeout)))

async def _join_future_handle(self, fut: asyncio.Future, channel: str, timeout: int):
try:
Expand Down Expand Up @@ -498,7 +515,8 @@ async def _usernotice(self, parsed):
channel = Channel(name=parsed["channel"], websocket=self)
rawData = parsed["groups"][0]
tags = dict(x.split("=", 1) for x in rawData.split(";"))
tags["user-type"] = tags["user-type"].split(":tmi.twitch.tv")[0].strip()
tags["user-type"] = tags["user-type"].split(":tmi.twitch.tv")[
0].strip()

self.dispatch("raw_usernotice", channel, tags)

Expand Down Expand Up @@ -552,7 +570,8 @@ def _cache_add(self, parsed: dict):

if parsed["batches"]:
for u in parsed["batches"]:
user = PartialChatter(name=u, bot=self._client, websocket=self, channel=channel_)
user = PartialChatter(
name=u, bot=self._client, websocket=self, channel=channel_)
self._cache[channel].add(user)
else:
name = parsed["user"] or parsed["nick"]
Expand All @@ -570,7 +589,8 @@ async def _mode(self, parsed): # TODO
pass

async def _reconnect(self, parsed):
log.debug("ACTION: RECONNECT:: Twitch has gracefully closed the connection and will reconnect.")
log.debug(
"ACTION: RECONNECT:: Twitch has gracefully closed the connection and will reconnect.")
self._reconnect_requested = True
self._keeper.cancel()
self._loop.create_task(self._connect())
Expand All @@ -582,7 +602,8 @@ def dispatch(self, event: str, *args, **kwargs):
self._client.run_event(event, *args, **kwargs)

async def event_error(self, error: Exception, data: str = None):
traceback.print_exception(type(error), error, error.__traceback__, file=sys.stderr)
traceback.print_exception(
type(error), error, error.__traceback__, file=sys.stderr)

def _fetch_futures(self):
return [
Expand Down

0 comments on commit bd20ae7

Please sign in to comment.