diff --git a/twitchio/websocket.py b/twitchio/websocket.py index da61a6a1..d4073e7a 100644 --- a/twitchio/websocket.py +++ b/twitchio/websocket.py @@ -31,7 +31,7 @@ import time import traceback from functools import partial -from typing import Union, Optional, List, TYPE_CHECKING +from typing import Union, Callable, Optional, List, TYPE_CHECKING import aiohttp @@ -57,7 +57,7 @@ def __init__( client: "Client", token: str = None, modes: tuple = None, - initial_channels: List[str] = None, + initial_channels: Union[list, tuple, Callable] = None, retain_cache: Optional[bool] = True, ): self._loop = loop @@ -115,7 +115,8 @@ def __init__( async def _task_cleanup(self): while True: # keep all undone tasks - self._background_tasks = list(filter(lambda task: not task.done(), self._background_tasks)) + self._background_tasks = list( + filter(lambda task: not task.done(), self._background_tasks)) # cleanup tasks every 30 seconds await asyncio.sleep(30) @@ -134,13 +135,15 @@ async def _connect(self): if self._keeper: self._keeper.cancel() # Stop our current keep alive. if self.is_alive: - await self._websocket.close() # If for some reason we are in a weird state, close it before retrying. + # If for some reason we are in a weird state, close it before retrying. + await self._websocket.close() if not self._client._http.nick: try: data = await self._client._http.validate(token=self._token) except AuthenticationError: await self._client._http.session.close() - self._client._closing.set() # clean up and error out (this is called to avoid calling Client.close in start() + # clean up and error out (this is called to avoid calling Client.close in start() + self._client._closing.set() raise self.nick = data["login"] self.user_id = int(data["user_id"]) @@ -152,7 +155,8 @@ async def _connect(self): self._websocket = await session.ws_connect(url=HOST, heartbeat=self._heartbeat) except Exception as e: retry = self._backoff.delay() - log.error(f"Websocket connection failure: {e}:: Attempting reconnect in {retry} seconds.") + log.error( + f"Websocket connection failure: {e}:: Attempting reconnect in {retry} seconds.") await asyncio.sleep(retry) return await self._connect() @@ -160,10 +164,12 @@ async def _connect(self): await self.authenticate(self._initial_channels) self._reconnect_requested = False - self._keeper = asyncio.create_task(self._keep_alive()) # Create our keep alive. + # Create our keep alive. + self._keeper = asyncio.create_task(self._keep_alive()) if not self._task_cleaner or self._task_cleaner.done(): - self._task_cleaner = asyncio.create_task(self._task_cleanup()) # Create our task cleaner. + # Create our task cleaner. + self._task_cleaner = asyncio.create_task(self._task_cleanup()) self._ws_ready_event.set() @@ -182,14 +188,16 @@ async def _keep_alive(self): data = msg.data if data: log.debug(f" < {data}") - self.dispatch("raw_data", data) # Dispatch our event_raw_data event... + # Dispatch our event_raw_data event... + self.dispatch("raw_data", data) events = data.split("\r\n") for event in events: if not event: continue task = asyncio.create_task(self._process_data(event)) - task.add_done_callback(partial(self._task_callback, event)) # Process our raw data + # Process our raw data + task.add_done_callback(partial(self._task_callback, event)) self._background_tasks.append(task) self._background_tasks.append(asyncio.create_task(self._connect())) @@ -197,13 +205,16 @@ async def _keep_alive(self): def _task_callback(self, data, task): exc = task.exception() - if isinstance(exc, AuthenticationError): # Check if we failed to log in... - log.error("Authentication error. Please check your credentials and try again.") + # Check if we failed to log in... + if isinstance(exc, AuthenticationError): + log.error( + "Authentication error. Please check your credentials and try again.") self._close() elif exc: # event_error task need to be shielded to avoid cancelling in self._close() function # we need ensure, that the event will print its traceback - shielded_task = asyncio.shield(asyncio.create_task(self.event_error(exc, data))) + shielded_task = asyncio.shield( + asyncio.create_task(self.event_error(exc, data))) self._background_tasks.append(shielded_task) async def send(self, message: str): @@ -218,7 +229,8 @@ async def send(self, message: str): dummy = f"> :{self.nick}!{self.nick}@{self.nick}.tmi.twitch.tv PRIVMSG(ECHO) #{channel} {content}\r\n" task = asyncio.create_task(self._process_data(dummy)) - task.add_done_callback(partial(self._task_callback, dummy)) # Process our raw data + # Process our raw data + task.add_done_callback(partial(self._task_callback, dummy)) self._background_tasks.append(task) await self._websocket.send_str(message + "\r\n") @@ -233,7 +245,8 @@ async def reply(self, msg_id: str, message: str): dummy = f"> @reply-parent-msg-id={msg_id} :{self.nick}!{self.nick}@{self.nick}.tmi.twitch.tv PRIVMSG(ECHO) #{channel} {content}\r\n" task = asyncio.create_task(self._process_data(dummy)) - task.add_done_callback(partial(self._task_callback, dummy)) # Process our raw data + # Process our raw data + task.add_done_callback(partial(self._task_callback, dummy)) self._background_tasks.append(task) await self._websocket.send_str(f"@reply-parent-msg-id={msg_id} {message} \r\n") @@ -258,7 +271,8 @@ async def authenticate(self, channels: Union[list, tuple]): await self.send(f"NICK {self.nick}\r\n") for cap in self.modes: - await self.send(f"CAP REQ :twitch.tv/{cap}") # Ideally no one should overwrite defaults... + # Ideally no one should overwrite defaults... + await self.send(f"CAP REQ :twitch.tv/{cap}") if not channels and not self._initial_channels: return channels = channels or self._initial_channels @@ -305,10 +319,12 @@ async def join_channels(self, *channels: str): channel_count = len(channels) if channel_count > 20: timeout = self._assign_timeout(channel_count) - chunks = [channels[i : i + 20] for i in range(0, len(channels), 20)] + chunks = [channels[i: i + 20] + for i in range(0, len(channels), 20)] for chunk in chunks: for channel in chunk: - task = asyncio.create_task(self._join_channel(channel, timeout)) + task = asyncio.create_task( + self._join_channel(channel, timeout)) self._background_tasks.append(task) await asyncio.sleep(11) @@ -322,7 +338,8 @@ async def _join_channel(self, entry: str, timeout: int): await self.send(f"JOIN #{channel}\r\n") self._join_pending[channel] = fut = self._loop.create_future() - self._background_tasks.append(asyncio.create_task(self._join_future_handle(fut, channel, timeout))) + self._background_tasks.append(asyncio.create_task( + self._join_future_handle(fut, channel, timeout))) async def _join_future_handle(self, fut: asyncio.Future, channel: str, timeout: int): try: @@ -498,7 +515,8 @@ async def _usernotice(self, parsed): channel = Channel(name=parsed["channel"], websocket=self) rawData = parsed["groups"][0] tags = dict(x.split("=", 1) for x in rawData.split(";")) - tags["user-type"] = tags["user-type"].split(":tmi.twitch.tv")[0].strip() + tags["user-type"] = tags["user-type"].split(":tmi.twitch.tv")[ + 0].strip() self.dispatch("raw_usernotice", channel, tags) @@ -552,7 +570,8 @@ def _cache_add(self, parsed: dict): if parsed["batches"]: for u in parsed["batches"]: - user = PartialChatter(name=u, bot=self._client, websocket=self, channel=channel_) + user = PartialChatter( + name=u, bot=self._client, websocket=self, channel=channel_) self._cache[channel].add(user) else: name = parsed["user"] or parsed["nick"] @@ -570,7 +589,8 @@ async def _mode(self, parsed): # TODO pass async def _reconnect(self, parsed): - log.debug("ACTION: RECONNECT:: Twitch has gracefully closed the connection and will reconnect.") + log.debug( + "ACTION: RECONNECT:: Twitch has gracefully closed the connection and will reconnect.") self._reconnect_requested = True self._keeper.cancel() self._loop.create_task(self._connect()) @@ -582,7 +602,8 @@ def dispatch(self, event: str, *args, **kwargs): self._client.run_event(event, *args, **kwargs) async def event_error(self, error: Exception, data: str = None): - traceback.print_exception(type(error), error, error.__traceback__, file=sys.stderr) + traceback.print_exception( + type(error), error, error.__traceback__, file=sys.stderr) def _fetch_futures(self): return [