Semi-stable. Need to check signoffs

This commit is contained in:
Jacob Henry 2019-02-22 21:15:54 -05:00
parent 5058ae0085
commit 77aa8755ed
4 changed files with 145 additions and 59 deletions

View File

@ -350,7 +350,7 @@ async def nag_jobs(day_of_week: str) -> bool:
response += "(scroll missing. Please register for @ pings!)"
response += "\n"
general_id = slack_util.get_slack().get_channel_by_name("#general").id
general_id = slack_util.get_slack().get_conversation_by_name("#general").id
slack_util.get_slack().send_message(response, general_id)
return True

View File

@ -53,9 +53,9 @@ def main() -> None:
event_loop = asyncio.get_event_loop()
event_loop.set_debug(slack_util.DEBUG_MODE)
message_handling = wrap.respond_messages()
event_handling = wrap.handle_events()
passive_handling = wrap.run_passives()
both = asyncio.gather(message_handling, passive_handling)
both = asyncio.gather(event_handling, passive_handling)
event_loop.run_until_complete(both)

View File

@ -25,7 +25,7 @@ class ItsTenPM(slack_util.Passive):
await asyncio.sleep(delay)
# Crow like a rooster
slack_util.get_slack().send_message("IT'S 10 PM!", slack_util.get_slack().get_channel_by_name("#random").id)
slack_util.get_slack().send_message("IT'S 10 PM!", slack_util.get_slack().get_conversation_by_name("#random").id)
# Wait a while before trying it again, to prevent duplicates
await asyncio.sleep(60)

View File

@ -6,7 +6,7 @@ import sys
import traceback
from dataclasses import dataclass
from time import sleep, time
from typing import List, Any, AsyncGenerator, Coroutine, TypeVar
from typing import List, Any, AsyncGenerator, Coroutine, TypeVar, Dict
from typing import Optional, Generator, Match, Callable, Union, Awaitable
from slackclient import SlackClient
@ -43,6 +43,20 @@ class Channel:
name: str
@dataclass
class DirectMessage:
id: str
user_id: str
def get_user(self) -> Optional[User]:
"""
Lookup the user to which this DM corresponds.
"""
return get_slack().get_user(self.user_id)
Conversation = Union[Channel, DirectMessage]
"""
Objects to represent attributes an event may contain
"""
@ -50,7 +64,7 @@ Objects to represent attributes an event may contain
@dataclass
class Event:
channel: Optional[ChannelContext] = None
conversation: Optional[ConversationContext] = None
user: Optional[UserContext] = None
message: Optional[MessageContext] = None
thread: Optional[ThreadContext] = None
@ -58,11 +72,11 @@ class Event:
# If this was posted in a specific channel or conversation
@dataclass
class ChannelContext:
channel_id: str
class ConversationContext:
conversation_id: str
def get_channel(self) -> Channel:
raise NotImplementedError()
def get_conversation(self) -> Optional[Conversation]:
return get_slack().get_conversation(self.conversation_id)
# If there is a specific user associated with this event
@ -70,8 +84,8 @@ class ChannelContext:
class UserContext:
user_id: str
def as_user(self) -> User:
raise NotImplementedError()
def as_user(self) -> Optional[User]:
return get_slack().get_user(self.user_id)
# Whether or not this is a threadable text message
@ -84,7 +98,6 @@ class MessageContext:
@dataclass
class ThreadContext:
thread_ts: str
parent_ts: str
# If a file was additionally shared
@ -94,7 +107,7 @@ class File:
"""
Objects for interfacing easily with rtm steams
Objects for interfacing easily with rtm steams, and handling async events
"""
@ -120,9 +133,16 @@ def message_stream(slack: SlackClient) -> Generator[Event, None, None]:
# Big logic folks
if update["type"] == "message":
event.message = MessageContext(update["ts"], update["text"])
event.channel = ChannelContext(update["channel"])
event.user = UserContext(update["user"])
# For now we only handle these basic types of messages involving text
# TODO: Handle "unwrappeable" messages
if "text" in update and "ts" in update:
event.message = MessageContext(update["ts"], update["text"])
if "channel" in update:
event.conversation = ConversationContext(update["channel"])
if "user" in update:
event.user = UserContext(update["user"])
if "thread_ts" in update:
event.thread = ThreadContext(update["thread_ts"])
# TODO: Handle more types
# We need to
@ -167,10 +187,10 @@ class ClientWrapper(object):
self.passives: List[Passive] = []
# Cache users and channels
self.users: dict = {}
self.channels: dict = {}
self.users: Dict[str, User] = {}
self.conversations: Dict[str, Conversation] = {}
# Scheduled events handling
# Scheduled/passive events handling
def add_passive(self, per: Passive) -> None:
self.passives.append(per)
@ -181,21 +201,72 @@ class ClientWrapper(object):
awaitables = [p.run() for p in self.passives]
await asyncio.gather(*awaitables)
# Message handling
# Incoming slack hook handling
def add_hook(self, hook: AbsHook) -> None:
self.hooks.append(hook)
async def respond_messages(self) -> None:
async def handle_events(self) -> None:
"""
Asynchronous tasks that eternally reads and responds to messages.
"""
async for t in self.spool_tasks():
sys.stdout.flush()
if DEBUG_MODE:
await t
# Create a queue
queue = asyncio.Queue()
async def spool_tasks(self) -> AsyncGenerator[asyncio.Task, Any]:
async for event in self.async_event_feed():
# Create a task to put rtm events to the queue
async def put_rtm():
async for t1 in self.rtm_event_feed():
await queue.put(t1)
rtm_task = asyncio.Task(put_rtm())
# Create a task to put http events to the queue
async def put_http():
async for t2 in self.http_event_feed():
await queue.put(t2)
http_task = asyncio.Task(put_http())
# Create a task to handle all other tasks
async def handle_task_loop():
async for t3 in self.spool_tasks(queue):
sys.stdout.flush()
if DEBUG_MODE:
await t3
# Create a task to read and process events from the queue
handler_task = handle_task_loop()
await asyncio.gather(rtm_task, http_task, handler_task)
async def rtm_event_feed(self) -> AsyncGenerator[Event, None]:
"""
Async wrapper around the message feed.
Yields messages awaitably forever.
"""
# Create the msg feed
feed = message_stream(self.slack)
# Create a simple callable that gets one message from the feed
def get_one():
return next(feed)
# Continuously yield async threaded tasks that poll the feed
while True:
yield await asyncio.get_running_loop().run_in_executor(None, get_one)
async def http_event_feed(self) -> AsyncGenerator[Event, None]:
# Create the server
pass
while True:
await asyncio.sleep(30)
yield Event()
async def spool_tasks(self, event_queue: asyncio.Queue) -> AsyncGenerator[asyncio.Task, Any]:
"""
Read in from async event feed, and spool them out as async tasks
"""
while True:
event: Event = await event_queue.get()
# Find which hook, if any, satisfies
for hook in list(self.hooks): # Note that we do list(self.hooks) to avoid edit-while-iterating issues
# Try invoking each
@ -215,38 +286,46 @@ class ClientWrapper(object):
self.hooks.remove(hook)
print("Done spawning tasks. Now {} running total.".format(len(asyncio.all_tasks())))
async def async_event_feed(self) -> AsyncGenerator[Event, None]:
"""
Async wrapper around the message feed.
Yields messages awaitably forever.
"""
# Create the msg feed
feed = message_stream(self.slack)
# Data getting/sending
# Create a simple callable that gets one message from the feed
def get_one():
return next(feed)
def get_conversation(self, conversation_id: str) -> Optional[Conversation]:
return self.conversations.get(conversation_id)
# Continuously yield async threaded tasks that poll the feed
while True:
yield await asyncio.get_running_loop().run_in_executor(None, get_one)
def get_conversation_by_name(self, conversation_identifier: str) -> Optional[Conversation]:
# If looking for a direct message, first lookup user, then fetch
if conversation_identifier[0] == "@":
user_name = conversation_identifier[1:]
def get_channel(self, channel_id: str) -> Optional[Channel]:
return self.channels.get(channel_id)
# Find the user by their name
raise NotImplementedError("There wasn't a clear use case for this yet, so we've opted to just not use it")
def get_channel_by_name(self, channel_name: str) -> Optional[Channel]:
# Find the channel in the dict
for v in self.channels.values():
if v.name == channel_name:
return v
# If looking for a channel, just lookup normally
elif conversation_identifier[0] == "#":
channel_name = conversation_identifier[1:]
# Find the channel in the dict
for channel in self.conversations.values():
if channel.name == channel_name:
return channel
# If it doesn't fit the above, we don't know how to process
else:
raise ValueError("Please give either an #channel-name or @user-name")
# If we haven't returned already, give up and return None
return None
def get_user(self, user_id: str) -> Optional[Channel]:
def get_user(self, user_id: str) -> Optional[User]:
return self.users.get(user_id)
def get_user_by_name(self, user_name: str) -> Optional[User]:
raise NotImplementedError()
def api_call(self, api_method, **kwargs):
return self.slack.api_call(api_method, **kwargs)
# Simpler wrappers around message sending/replying
def reply(self, event: Event, text: str, in_thread: bool = True) -> dict:
"""
Replies to a message.
@ -254,7 +333,7 @@ class ClientWrapper(object):
Returns the JSON response.
"""
# Ensure we're actually replying to a valid message
assert (event.channel and event.message) is not None
assert (event.conversation and event.message) is not None
# Send in a thread by default
if in_thread:
@ -262,9 +341,9 @@ class ClientWrapper(object):
thread = event.message.ts
if event.thread:
thread = event.thread.thread_ts
return self.send_message(text, event.channel.channel_id, thread=thread)
return self.send_message(text, event.conversation.conversation_id, thread=thread)
else:
return self.send_message(text, event.channel.channel_id)
return self.send_message(text, event.conversation.conversation_id)
def send_message(self, text: str, channel_id: str, thread: str = None, broadcast: bool = False) -> dict:
"""
@ -279,6 +358,8 @@ class ClientWrapper(object):
return self.api_call("chat.postMessage", **kwargs)
# Update slack data
def update_channels(self):
"""
Queries the slack API for all current channels
@ -292,7 +373,7 @@ class ClientWrapper(object):
# Iterate over results
while True:
# Set args depending on if a cursor exists
args = {"limit": 1000, "type": "public_channel,private_channel,mpim,im"}
args = {"limit": 1000, "types": "public_channel,private_channel,mpim,im"}
if cursor:
args["cursor"] = cursor
@ -301,8 +382,12 @@ class ClientWrapper(object):
# If the response is good, put its results to the dict
if channel_dicts["ok"]:
for channel_dict in channel_dicts["channels"]:
new_channel = Channel(id=channel_dict["id"],
name=channel_dict["name"])
if channel_dict["is_im"]:
new_channel = DirectMessage(id=channel_dict["id"],
user_id=channel_dict["user"])
else:
new_channel = Channel(id=channel_dict["id"],
name=channel_dict["name"])
new_dict[new_channel.id] = new_channel
# Fetch the cursor
@ -312,11 +397,10 @@ class ClientWrapper(object):
if cursor == "":
break
else:
print("Warning: failed to retrieve channels. Message: {}".format(channel_dicts))
break
self.channels = new_dict
self.conversations = new_dict
def update_users(self):
"""
@ -427,9 +511,11 @@ class ChannelHook(AbsHook):
Returns whether a message should be handled by this dict, returning a Match if so, or None
"""
# Ensure that this is an event in a specific channel, with a text component
if not (event.channel and event.message):
if not (event.conversation and event.message and isinstance(event.conversation.get_conversation(), Channel)):
return None
# Fail if pattern invalid
match = None
for p in self.patterns:
@ -441,7 +527,7 @@ class ChannelHook(AbsHook):
return None
# Get the channel name
channel_name = event.channel.get_channel().name
channel_name = event.conversation.get_conversation().name
# Fail if whitelist defined, and we aren't there
if self.channel_whitelist is not None and channel_name not in self.channel_whitelist: