From 77aa8755ede8230453d3f4dcb070a586509f6bce Mon Sep 17 00:00:00 2001 From: Jacob Henry Date: Fri, 22 Feb 2019 21:15:54 -0500 Subject: [PATCH] Semi-stable. Need to check signoffs --- job_commands.py | 2 +- main.py | 4 +- periodicals.py | 2 +- slack_util.py | 196 ++++++++++++++++++++++++++++++++++-------------- 4 files changed, 145 insertions(+), 59 deletions(-) diff --git a/job_commands.py b/job_commands.py index eb8a0af..5dd83f1 100644 --- a/job_commands.py +++ b/job_commands.py @@ -350,7 +350,7 @@ async def nag_jobs(day_of_week: str) -> bool: response += "(scroll missing. Please register for @ pings!)" response += "\n" - general_id = slack_util.get_slack().get_channel_by_name("#general").id + general_id = slack_util.get_slack().get_conversation_by_name("#general").id slack_util.get_slack().send_message(response, general_id) return True diff --git a/main.py b/main.py index 1aaf65d..a25b861 100644 --- a/main.py +++ b/main.py @@ -53,9 +53,9 @@ def main() -> None: event_loop = asyncio.get_event_loop() event_loop.set_debug(slack_util.DEBUG_MODE) - message_handling = wrap.respond_messages() + event_handling = wrap.handle_events() passive_handling = wrap.run_passives() - both = asyncio.gather(message_handling, passive_handling) + both = asyncio.gather(event_handling, passive_handling) event_loop.run_until_complete(both) diff --git a/periodicals.py b/periodicals.py index 24d9145..b9af2c6 100644 --- a/periodicals.py +++ b/periodicals.py @@ -25,7 +25,7 @@ class ItsTenPM(slack_util.Passive): await asyncio.sleep(delay) # Crow like a rooster - slack_util.get_slack().send_message("IT'S 10 PM!", slack_util.get_slack().get_channel_by_name("#random").id) + slack_util.get_slack().send_message("IT'S 10 PM!", slack_util.get_slack().get_conversation_by_name("#random").id) # Wait a while before trying it again, to prevent duplicates await asyncio.sleep(60) diff --git a/slack_util.py b/slack_util.py index f480146..79b994e 100644 --- a/slack_util.py +++ b/slack_util.py @@ -6,7 +6,7 @@ import sys import traceback from dataclasses import dataclass from time import sleep, time -from typing import List, Any, AsyncGenerator, Coroutine, TypeVar +from typing import List, Any, AsyncGenerator, Coroutine, TypeVar, Dict from typing import Optional, Generator, Match, Callable, Union, Awaitable from slackclient import SlackClient @@ -43,6 +43,20 @@ class Channel: name: str +@dataclass +class DirectMessage: + id: str + user_id: str + + def get_user(self) -> Optional[User]: + """ + Lookup the user to which this DM corresponds. + """ + return get_slack().get_user(self.user_id) + + +Conversation = Union[Channel, DirectMessage] + """ Objects to represent attributes an event may contain """ @@ -50,7 +64,7 @@ Objects to represent attributes an event may contain @dataclass class Event: - channel: Optional[ChannelContext] = None + conversation: Optional[ConversationContext] = None user: Optional[UserContext] = None message: Optional[MessageContext] = None thread: Optional[ThreadContext] = None @@ -58,11 +72,11 @@ class Event: # If this was posted in a specific channel or conversation @dataclass -class ChannelContext: - channel_id: str +class ConversationContext: + conversation_id: str - def get_channel(self) -> Channel: - raise NotImplementedError() + def get_conversation(self) -> Optional[Conversation]: + return get_slack().get_conversation(self.conversation_id) # If there is a specific user associated with this event @@ -70,8 +84,8 @@ class ChannelContext: class UserContext: user_id: str - def as_user(self) -> User: - raise NotImplementedError() + def as_user(self) -> Optional[User]: + return get_slack().get_user(self.user_id) # Whether or not this is a threadable text message @@ -84,7 +98,6 @@ class MessageContext: @dataclass class ThreadContext: thread_ts: str - parent_ts: str # If a file was additionally shared @@ -94,7 +107,7 @@ class File: """ -Objects for interfacing easily with rtm steams +Objects for interfacing easily with rtm steams, and handling async events """ @@ -120,9 +133,16 @@ def message_stream(slack: SlackClient) -> Generator[Event, None, None]: # Big logic folks if update["type"] == "message": - event.message = MessageContext(update["ts"], update["text"]) - event.channel = ChannelContext(update["channel"]) - event.user = UserContext(update["user"]) + # For now we only handle these basic types of messages involving text + # TODO: Handle "unwrappeable" messages + if "text" in update and "ts" in update: + event.message = MessageContext(update["ts"], update["text"]) + if "channel" in update: + event.conversation = ConversationContext(update["channel"]) + if "user" in update: + event.user = UserContext(update["user"]) + if "thread_ts" in update: + event.thread = ThreadContext(update["thread_ts"]) # TODO: Handle more types # We need to @@ -167,10 +187,10 @@ class ClientWrapper(object): self.passives: List[Passive] = [] # Cache users and channels - self.users: dict = {} - self.channels: dict = {} + self.users: Dict[str, User] = {} + self.conversations: Dict[str, Conversation] = {} - # Scheduled events handling + # Scheduled/passive events handling def add_passive(self, per: Passive) -> None: self.passives.append(per) @@ -181,21 +201,72 @@ class ClientWrapper(object): awaitables = [p.run() for p in self.passives] await asyncio.gather(*awaitables) - # Message handling + # Incoming slack hook handling def add_hook(self, hook: AbsHook) -> None: self.hooks.append(hook) - async def respond_messages(self) -> None: + async def handle_events(self) -> None: """ Asynchronous tasks that eternally reads and responds to messages. """ - async for t in self.spool_tasks(): - sys.stdout.flush() - if DEBUG_MODE: - await t + # Create a queue + queue = asyncio.Queue() - async def spool_tasks(self) -> AsyncGenerator[asyncio.Task, Any]: - async for event in self.async_event_feed(): + # Create a task to put rtm events to the queue + async def put_rtm(): + async for t1 in self.rtm_event_feed(): + await queue.put(t1) + + rtm_task = asyncio.Task(put_rtm()) + + # Create a task to put http events to the queue + async def put_http(): + async for t2 in self.http_event_feed(): + await queue.put(t2) + + http_task = asyncio.Task(put_http()) + + # Create a task to handle all other tasks + async def handle_task_loop(): + async for t3 in self.spool_tasks(queue): + sys.stdout.flush() + if DEBUG_MODE: + await t3 + + # Create a task to read and process events from the queue + handler_task = handle_task_loop() + await asyncio.gather(rtm_task, http_task, handler_task) + + async def rtm_event_feed(self) -> AsyncGenerator[Event, None]: + """ + Async wrapper around the message feed. + Yields messages awaitably forever. + """ + # Create the msg feed + feed = message_stream(self.slack) + + # Create a simple callable that gets one message from the feed + def get_one(): + return next(feed) + + # Continuously yield async threaded tasks that poll the feed + while True: + yield await asyncio.get_running_loop().run_in_executor(None, get_one) + + async def http_event_feed(self) -> AsyncGenerator[Event, None]: + # Create the server + pass + + while True: + await asyncio.sleep(30) + yield Event() + + async def spool_tasks(self, event_queue: asyncio.Queue) -> AsyncGenerator[asyncio.Task, Any]: + """ + Read in from async event feed, and spool them out as async tasks + """ + while True: + event: Event = await event_queue.get() # Find which hook, if any, satisfies for hook in list(self.hooks): # Note that we do list(self.hooks) to avoid edit-while-iterating issues # Try invoking each @@ -215,38 +286,46 @@ class ClientWrapper(object): self.hooks.remove(hook) print("Done spawning tasks. Now {} running total.".format(len(asyncio.all_tasks()))) - async def async_event_feed(self) -> AsyncGenerator[Event, None]: - """ - Async wrapper around the message feed. - Yields messages awaitably forever. - """ - # Create the msg feed - feed = message_stream(self.slack) + # Data getting/sending - # Create a simple callable that gets one message from the feed - def get_one(): - return next(feed) + def get_conversation(self, conversation_id: str) -> Optional[Conversation]: + return self.conversations.get(conversation_id) - # Continuously yield async threaded tasks that poll the feed - while True: - yield await asyncio.get_running_loop().run_in_executor(None, get_one) + def get_conversation_by_name(self, conversation_identifier: str) -> Optional[Conversation]: + # If looking for a direct message, first lookup user, then fetch + if conversation_identifier[0] == "@": + user_name = conversation_identifier[1:] - def get_channel(self, channel_id: str) -> Optional[Channel]: - return self.channels.get(channel_id) + # Find the user by their name + raise NotImplementedError("There wasn't a clear use case for this yet, so we've opted to just not use it") - def get_channel_by_name(self, channel_name: str) -> Optional[Channel]: - # Find the channel in the dict - for v in self.channels.values(): - if v.name == channel_name: - return v + # If looking for a channel, just lookup normally + elif conversation_identifier[0] == "#": + channel_name = conversation_identifier[1:] + + # Find the channel in the dict + for channel in self.conversations.values(): + if channel.name == channel_name: + return channel + + # If it doesn't fit the above, we don't know how to process + else: + raise ValueError("Please give either an #channel-name or @user-name") + + # If we haven't returned already, give up and return None return None - def get_user(self, user_id: str) -> Optional[Channel]: + def get_user(self, user_id: str) -> Optional[User]: return self.users.get(user_id) + def get_user_by_name(self, user_name: str) -> Optional[User]: + raise NotImplementedError() + def api_call(self, api_method, **kwargs): return self.slack.api_call(api_method, **kwargs) + # Simpler wrappers around message sending/replying + def reply(self, event: Event, text: str, in_thread: bool = True) -> dict: """ Replies to a message. @@ -254,7 +333,7 @@ class ClientWrapper(object): Returns the JSON response. """ # Ensure we're actually replying to a valid message - assert (event.channel and event.message) is not None + assert (event.conversation and event.message) is not None # Send in a thread by default if in_thread: @@ -262,9 +341,9 @@ class ClientWrapper(object): thread = event.message.ts if event.thread: thread = event.thread.thread_ts - return self.send_message(text, event.channel.channel_id, thread=thread) + return self.send_message(text, event.conversation.conversation_id, thread=thread) else: - return self.send_message(text, event.channel.channel_id) + return self.send_message(text, event.conversation.conversation_id) def send_message(self, text: str, channel_id: str, thread: str = None, broadcast: bool = False) -> dict: """ @@ -279,6 +358,8 @@ class ClientWrapper(object): return self.api_call("chat.postMessage", **kwargs) + # Update slack data + def update_channels(self): """ Queries the slack API for all current channels @@ -292,7 +373,7 @@ class ClientWrapper(object): # Iterate over results while True: # Set args depending on if a cursor exists - args = {"limit": 1000, "type": "public_channel,private_channel,mpim,im"} + args = {"limit": 1000, "types": "public_channel,private_channel,mpim,im"} if cursor: args["cursor"] = cursor @@ -301,8 +382,12 @@ class ClientWrapper(object): # If the response is good, put its results to the dict if channel_dicts["ok"]: for channel_dict in channel_dicts["channels"]: - new_channel = Channel(id=channel_dict["id"], - name=channel_dict["name"]) + if channel_dict["is_im"]: + new_channel = DirectMessage(id=channel_dict["id"], + user_id=channel_dict["user"]) + else: + new_channel = Channel(id=channel_dict["id"], + name=channel_dict["name"]) new_dict[new_channel.id] = new_channel # Fetch the cursor @@ -312,11 +397,10 @@ class ClientWrapper(object): if cursor == "": break - else: print("Warning: failed to retrieve channels. Message: {}".format(channel_dicts)) break - self.channels = new_dict + self.conversations = new_dict def update_users(self): """ @@ -427,9 +511,11 @@ class ChannelHook(AbsHook): Returns whether a message should be handled by this dict, returning a Match if so, or None """ # Ensure that this is an event in a specific channel, with a text component - if not (event.channel and event.message): + if not (event.conversation and event.message and isinstance(event.conversation.get_conversation(), Channel)): return None + + # Fail if pattern invalid match = None for p in self.patterns: @@ -441,7 +527,7 @@ class ChannelHook(AbsHook): return None # Get the channel name - channel_name = event.channel.get_channel().name + channel_name = event.conversation.get_conversation().name # Fail if whitelist defined, and we aren't there if self.channel_whitelist is not None and channel_name not in self.channel_whitelist: