Semi-stable. Need to check signoffs

This commit is contained in:
Jacob Henry 2019-02-22 21:15:54 -05:00
parent 5058ae0085
commit 77aa8755ed
4 changed files with 145 additions and 59 deletions

View File

@ -350,7 +350,7 @@ async def nag_jobs(day_of_week: str) -> bool:
response += "(scroll missing. Please register for @ pings!)" response += "(scroll missing. Please register for @ pings!)"
response += "\n" response += "\n"
general_id = slack_util.get_slack().get_channel_by_name("#general").id general_id = slack_util.get_slack().get_conversation_by_name("#general").id
slack_util.get_slack().send_message(response, general_id) slack_util.get_slack().send_message(response, general_id)
return True return True

View File

@ -53,9 +53,9 @@ def main() -> None:
event_loop = asyncio.get_event_loop() event_loop = asyncio.get_event_loop()
event_loop.set_debug(slack_util.DEBUG_MODE) event_loop.set_debug(slack_util.DEBUG_MODE)
message_handling = wrap.respond_messages() event_handling = wrap.handle_events()
passive_handling = wrap.run_passives() passive_handling = wrap.run_passives()
both = asyncio.gather(message_handling, passive_handling) both = asyncio.gather(event_handling, passive_handling)
event_loop.run_until_complete(both) event_loop.run_until_complete(both)

View File

@ -25,7 +25,7 @@ class ItsTenPM(slack_util.Passive):
await asyncio.sleep(delay) await asyncio.sleep(delay)
# Crow like a rooster # Crow like a rooster
slack_util.get_slack().send_message("IT'S 10 PM!", slack_util.get_slack().get_channel_by_name("#random").id) slack_util.get_slack().send_message("IT'S 10 PM!", slack_util.get_slack().get_conversation_by_name("#random").id)
# Wait a while before trying it again, to prevent duplicates # Wait a while before trying it again, to prevent duplicates
await asyncio.sleep(60) await asyncio.sleep(60)

View File

@ -6,7 +6,7 @@ import sys
import traceback import traceback
from dataclasses import dataclass from dataclasses import dataclass
from time import sleep, time from time import sleep, time
from typing import List, Any, AsyncGenerator, Coroutine, TypeVar from typing import List, Any, AsyncGenerator, Coroutine, TypeVar, Dict
from typing import Optional, Generator, Match, Callable, Union, Awaitable from typing import Optional, Generator, Match, Callable, Union, Awaitable
from slackclient import SlackClient from slackclient import SlackClient
@ -43,6 +43,20 @@ class Channel:
name: str name: str
@dataclass
class DirectMessage:
id: str
user_id: str
def get_user(self) -> Optional[User]:
"""
Lookup the user to which this DM corresponds.
"""
return get_slack().get_user(self.user_id)
Conversation = Union[Channel, DirectMessage]
""" """
Objects to represent attributes an event may contain Objects to represent attributes an event may contain
""" """
@ -50,7 +64,7 @@ Objects to represent attributes an event may contain
@dataclass @dataclass
class Event: class Event:
channel: Optional[ChannelContext] = None conversation: Optional[ConversationContext] = None
user: Optional[UserContext] = None user: Optional[UserContext] = None
message: Optional[MessageContext] = None message: Optional[MessageContext] = None
thread: Optional[ThreadContext] = None thread: Optional[ThreadContext] = None
@ -58,11 +72,11 @@ class Event:
# If this was posted in a specific channel or conversation # If this was posted in a specific channel or conversation
@dataclass @dataclass
class ChannelContext: class ConversationContext:
channel_id: str conversation_id: str
def get_channel(self) -> Channel: def get_conversation(self) -> Optional[Conversation]:
raise NotImplementedError() return get_slack().get_conversation(self.conversation_id)
# If there is a specific user associated with this event # If there is a specific user associated with this event
@ -70,8 +84,8 @@ class ChannelContext:
class UserContext: class UserContext:
user_id: str user_id: str
def as_user(self) -> User: def as_user(self) -> Optional[User]:
raise NotImplementedError() return get_slack().get_user(self.user_id)
# Whether or not this is a threadable text message # Whether or not this is a threadable text message
@ -84,7 +98,6 @@ class MessageContext:
@dataclass @dataclass
class ThreadContext: class ThreadContext:
thread_ts: str thread_ts: str
parent_ts: str
# If a file was additionally shared # If a file was additionally shared
@ -94,7 +107,7 @@ class File:
""" """
Objects for interfacing easily with rtm steams Objects for interfacing easily with rtm steams, and handling async events
""" """
@ -120,9 +133,16 @@ def message_stream(slack: SlackClient) -> Generator[Event, None, None]:
# Big logic folks # Big logic folks
if update["type"] == "message": if update["type"] == "message":
event.message = MessageContext(update["ts"], update["text"]) # For now we only handle these basic types of messages involving text
event.channel = ChannelContext(update["channel"]) # TODO: Handle "unwrappeable" messages
event.user = UserContext(update["user"]) if "text" in update and "ts" in update:
event.message = MessageContext(update["ts"], update["text"])
if "channel" in update:
event.conversation = ConversationContext(update["channel"])
if "user" in update:
event.user = UserContext(update["user"])
if "thread_ts" in update:
event.thread = ThreadContext(update["thread_ts"])
# TODO: Handle more types # TODO: Handle more types
# We need to # We need to
@ -167,10 +187,10 @@ class ClientWrapper(object):
self.passives: List[Passive] = [] self.passives: List[Passive] = []
# Cache users and channels # Cache users and channels
self.users: dict = {} self.users: Dict[str, User] = {}
self.channels: dict = {} self.conversations: Dict[str, Conversation] = {}
# Scheduled events handling # Scheduled/passive events handling
def add_passive(self, per: Passive) -> None: def add_passive(self, per: Passive) -> None:
self.passives.append(per) self.passives.append(per)
@ -181,21 +201,72 @@ class ClientWrapper(object):
awaitables = [p.run() for p in self.passives] awaitables = [p.run() for p in self.passives]
await asyncio.gather(*awaitables) await asyncio.gather(*awaitables)
# Message handling # Incoming slack hook handling
def add_hook(self, hook: AbsHook) -> None: def add_hook(self, hook: AbsHook) -> None:
self.hooks.append(hook) self.hooks.append(hook)
async def respond_messages(self) -> None: async def handle_events(self) -> None:
""" """
Asynchronous tasks that eternally reads and responds to messages. Asynchronous tasks that eternally reads and responds to messages.
""" """
async for t in self.spool_tasks(): # Create a queue
sys.stdout.flush() queue = asyncio.Queue()
if DEBUG_MODE:
await t
async def spool_tasks(self) -> AsyncGenerator[asyncio.Task, Any]: # Create a task to put rtm events to the queue
async for event in self.async_event_feed(): async def put_rtm():
async for t1 in self.rtm_event_feed():
await queue.put(t1)
rtm_task = asyncio.Task(put_rtm())
# Create a task to put http events to the queue
async def put_http():
async for t2 in self.http_event_feed():
await queue.put(t2)
http_task = asyncio.Task(put_http())
# Create a task to handle all other tasks
async def handle_task_loop():
async for t3 in self.spool_tasks(queue):
sys.stdout.flush()
if DEBUG_MODE:
await t3
# Create a task to read and process events from the queue
handler_task = handle_task_loop()
await asyncio.gather(rtm_task, http_task, handler_task)
async def rtm_event_feed(self) -> AsyncGenerator[Event, None]:
"""
Async wrapper around the message feed.
Yields messages awaitably forever.
"""
# Create the msg feed
feed = message_stream(self.slack)
# Create a simple callable that gets one message from the feed
def get_one():
return next(feed)
# Continuously yield async threaded tasks that poll the feed
while True:
yield await asyncio.get_running_loop().run_in_executor(None, get_one)
async def http_event_feed(self) -> AsyncGenerator[Event, None]:
# Create the server
pass
while True:
await asyncio.sleep(30)
yield Event()
async def spool_tasks(self, event_queue: asyncio.Queue) -> AsyncGenerator[asyncio.Task, Any]:
"""
Read in from async event feed, and spool them out as async tasks
"""
while True:
event: Event = await event_queue.get()
# Find which hook, if any, satisfies # Find which hook, if any, satisfies
for hook in list(self.hooks): # Note that we do list(self.hooks) to avoid edit-while-iterating issues for hook in list(self.hooks): # Note that we do list(self.hooks) to avoid edit-while-iterating issues
# Try invoking each # Try invoking each
@ -215,38 +286,46 @@ class ClientWrapper(object):
self.hooks.remove(hook) self.hooks.remove(hook)
print("Done spawning tasks. Now {} running total.".format(len(asyncio.all_tasks()))) print("Done spawning tasks. Now {} running total.".format(len(asyncio.all_tasks())))
async def async_event_feed(self) -> AsyncGenerator[Event, None]: # Data getting/sending
"""
Async wrapper around the message feed.
Yields messages awaitably forever.
"""
# Create the msg feed
feed = message_stream(self.slack)
# Create a simple callable that gets one message from the feed def get_conversation(self, conversation_id: str) -> Optional[Conversation]:
def get_one(): return self.conversations.get(conversation_id)
return next(feed)
# Continuously yield async threaded tasks that poll the feed def get_conversation_by_name(self, conversation_identifier: str) -> Optional[Conversation]:
while True: # If looking for a direct message, first lookup user, then fetch
yield await asyncio.get_running_loop().run_in_executor(None, get_one) if conversation_identifier[0] == "@":
user_name = conversation_identifier[1:]
def get_channel(self, channel_id: str) -> Optional[Channel]: # Find the user by their name
return self.channels.get(channel_id) raise NotImplementedError("There wasn't a clear use case for this yet, so we've opted to just not use it")
def get_channel_by_name(self, channel_name: str) -> Optional[Channel]: # If looking for a channel, just lookup normally
# Find the channel in the dict elif conversation_identifier[0] == "#":
for v in self.channels.values(): channel_name = conversation_identifier[1:]
if v.name == channel_name:
return v # Find the channel in the dict
for channel in self.conversations.values():
if channel.name == channel_name:
return channel
# If it doesn't fit the above, we don't know how to process
else:
raise ValueError("Please give either an #channel-name or @user-name")
# If we haven't returned already, give up and return None
return None return None
def get_user(self, user_id: str) -> Optional[Channel]: def get_user(self, user_id: str) -> Optional[User]:
return self.users.get(user_id) return self.users.get(user_id)
def get_user_by_name(self, user_name: str) -> Optional[User]:
raise NotImplementedError()
def api_call(self, api_method, **kwargs): def api_call(self, api_method, **kwargs):
return self.slack.api_call(api_method, **kwargs) return self.slack.api_call(api_method, **kwargs)
# Simpler wrappers around message sending/replying
def reply(self, event: Event, text: str, in_thread: bool = True) -> dict: def reply(self, event: Event, text: str, in_thread: bool = True) -> dict:
""" """
Replies to a message. Replies to a message.
@ -254,7 +333,7 @@ class ClientWrapper(object):
Returns the JSON response. Returns the JSON response.
""" """
# Ensure we're actually replying to a valid message # Ensure we're actually replying to a valid message
assert (event.channel and event.message) is not None assert (event.conversation and event.message) is not None
# Send in a thread by default # Send in a thread by default
if in_thread: if in_thread:
@ -262,9 +341,9 @@ class ClientWrapper(object):
thread = event.message.ts thread = event.message.ts
if event.thread: if event.thread:
thread = event.thread.thread_ts thread = event.thread.thread_ts
return self.send_message(text, event.channel.channel_id, thread=thread) return self.send_message(text, event.conversation.conversation_id, thread=thread)
else: else:
return self.send_message(text, event.channel.channel_id) return self.send_message(text, event.conversation.conversation_id)
def send_message(self, text: str, channel_id: str, thread: str = None, broadcast: bool = False) -> dict: def send_message(self, text: str, channel_id: str, thread: str = None, broadcast: bool = False) -> dict:
""" """
@ -279,6 +358,8 @@ class ClientWrapper(object):
return self.api_call("chat.postMessage", **kwargs) return self.api_call("chat.postMessage", **kwargs)
# Update slack data
def update_channels(self): def update_channels(self):
""" """
Queries the slack API for all current channels Queries the slack API for all current channels
@ -292,7 +373,7 @@ class ClientWrapper(object):
# Iterate over results # Iterate over results
while True: while True:
# Set args depending on if a cursor exists # Set args depending on if a cursor exists
args = {"limit": 1000, "type": "public_channel,private_channel,mpim,im"} args = {"limit": 1000, "types": "public_channel,private_channel,mpim,im"}
if cursor: if cursor:
args["cursor"] = cursor args["cursor"] = cursor
@ -301,8 +382,12 @@ class ClientWrapper(object):
# If the response is good, put its results to the dict # If the response is good, put its results to the dict
if channel_dicts["ok"]: if channel_dicts["ok"]:
for channel_dict in channel_dicts["channels"]: for channel_dict in channel_dicts["channels"]:
new_channel = Channel(id=channel_dict["id"], if channel_dict["is_im"]:
name=channel_dict["name"]) new_channel = DirectMessage(id=channel_dict["id"],
user_id=channel_dict["user"])
else:
new_channel = Channel(id=channel_dict["id"],
name=channel_dict["name"])
new_dict[new_channel.id] = new_channel new_dict[new_channel.id] = new_channel
# Fetch the cursor # Fetch the cursor
@ -312,11 +397,10 @@ class ClientWrapper(object):
if cursor == "": if cursor == "":
break break
else: else:
print("Warning: failed to retrieve channels. Message: {}".format(channel_dicts)) print("Warning: failed to retrieve channels. Message: {}".format(channel_dicts))
break break
self.channels = new_dict self.conversations = new_dict
def update_users(self): def update_users(self):
""" """
@ -427,9 +511,11 @@ class ChannelHook(AbsHook):
Returns whether a message should be handled by this dict, returning a Match if so, or None Returns whether a message should be handled by this dict, returning a Match if so, or None
""" """
# Ensure that this is an event in a specific channel, with a text component # Ensure that this is an event in a specific channel, with a text component
if not (event.channel and event.message): if not (event.conversation and event.message and isinstance(event.conversation.get_conversation(), Channel)):
return None return None
# Fail if pattern invalid # Fail if pattern invalid
match = None match = None
for p in self.patterns: for p in self.patterns:
@ -441,7 +527,7 @@ class ChannelHook(AbsHook):
return None return None
# Get the channel name # Get the channel name
channel_name = event.channel.get_channel().name channel_name = event.conversation.get_conversation().name
# Fail if whitelist defined, and we aren't there # Fail if whitelist defined, and we aren't there
if self.channel_whitelist is not None and channel_name not in self.channel_whitelist: if self.channel_whitelist is not None and channel_name not in self.channel_whitelist: