Semi-stable. Need to check signoffs
This commit is contained in:
parent
5058ae0085
commit
77aa8755ed
|
|
@ -350,7 +350,7 @@ async def nag_jobs(day_of_week: str) -> bool:
|
|||
response += "(scroll missing. Please register for @ pings!)"
|
||||
response += "\n"
|
||||
|
||||
general_id = slack_util.get_slack().get_channel_by_name("#general").id
|
||||
general_id = slack_util.get_slack().get_conversation_by_name("#general").id
|
||||
slack_util.get_slack().send_message(response, general_id)
|
||||
return True
|
||||
|
||||
|
|
|
|||
4
main.py
4
main.py
|
|
@ -53,9 +53,9 @@ def main() -> None:
|
|||
|
||||
event_loop = asyncio.get_event_loop()
|
||||
event_loop.set_debug(slack_util.DEBUG_MODE)
|
||||
message_handling = wrap.respond_messages()
|
||||
event_handling = wrap.handle_events()
|
||||
passive_handling = wrap.run_passives()
|
||||
both = asyncio.gather(message_handling, passive_handling)
|
||||
both = asyncio.gather(event_handling, passive_handling)
|
||||
event_loop.run_until_complete(both)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ class ItsTenPM(slack_util.Passive):
|
|||
await asyncio.sleep(delay)
|
||||
|
||||
# Crow like a rooster
|
||||
slack_util.get_slack().send_message("IT'S 10 PM!", slack_util.get_slack().get_channel_by_name("#random").id)
|
||||
slack_util.get_slack().send_message("IT'S 10 PM!", slack_util.get_slack().get_conversation_by_name("#random").id)
|
||||
|
||||
# Wait a while before trying it again, to prevent duplicates
|
||||
await asyncio.sleep(60)
|
||||
|
|
|
|||
182
slack_util.py
182
slack_util.py
|
|
@ -6,7 +6,7 @@ import sys
|
|||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from time import sleep, time
|
||||
from typing import List, Any, AsyncGenerator, Coroutine, TypeVar
|
||||
from typing import List, Any, AsyncGenerator, Coroutine, TypeVar, Dict
|
||||
from typing import Optional, Generator, Match, Callable, Union, Awaitable
|
||||
|
||||
from slackclient import SlackClient
|
||||
|
|
@ -43,6 +43,20 @@ class Channel:
|
|||
name: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class DirectMessage:
|
||||
id: str
|
||||
user_id: str
|
||||
|
||||
def get_user(self) -> Optional[User]:
|
||||
"""
|
||||
Lookup the user to which this DM corresponds.
|
||||
"""
|
||||
return get_slack().get_user(self.user_id)
|
||||
|
||||
|
||||
Conversation = Union[Channel, DirectMessage]
|
||||
|
||||
"""
|
||||
Objects to represent attributes an event may contain
|
||||
"""
|
||||
|
|
@ -50,7 +64,7 @@ Objects to represent attributes an event may contain
|
|||
|
||||
@dataclass
|
||||
class Event:
|
||||
channel: Optional[ChannelContext] = None
|
||||
conversation: Optional[ConversationContext] = None
|
||||
user: Optional[UserContext] = None
|
||||
message: Optional[MessageContext] = None
|
||||
thread: Optional[ThreadContext] = None
|
||||
|
|
@ -58,11 +72,11 @@ class Event:
|
|||
|
||||
# If this was posted in a specific channel or conversation
|
||||
@dataclass
|
||||
class ChannelContext:
|
||||
channel_id: str
|
||||
class ConversationContext:
|
||||
conversation_id: str
|
||||
|
||||
def get_channel(self) -> Channel:
|
||||
raise NotImplementedError()
|
||||
def get_conversation(self) -> Optional[Conversation]:
|
||||
return get_slack().get_conversation(self.conversation_id)
|
||||
|
||||
|
||||
# If there is a specific user associated with this event
|
||||
|
|
@ -70,8 +84,8 @@ class ChannelContext:
|
|||
class UserContext:
|
||||
user_id: str
|
||||
|
||||
def as_user(self) -> User:
|
||||
raise NotImplementedError()
|
||||
def as_user(self) -> Optional[User]:
|
||||
return get_slack().get_user(self.user_id)
|
||||
|
||||
|
||||
# Whether or not this is a threadable text message
|
||||
|
|
@ -84,7 +98,6 @@ class MessageContext:
|
|||
@dataclass
|
||||
class ThreadContext:
|
||||
thread_ts: str
|
||||
parent_ts: str
|
||||
|
||||
|
||||
# If a file was additionally shared
|
||||
|
|
@ -94,7 +107,7 @@ class File:
|
|||
|
||||
|
||||
"""
|
||||
Objects for interfacing easily with rtm steams
|
||||
Objects for interfacing easily with rtm steams, and handling async events
|
||||
"""
|
||||
|
||||
|
||||
|
|
@ -120,9 +133,16 @@ def message_stream(slack: SlackClient) -> Generator[Event, None, None]:
|
|||
|
||||
# Big logic folks
|
||||
if update["type"] == "message":
|
||||
# For now we only handle these basic types of messages involving text
|
||||
# TODO: Handle "unwrappeable" messages
|
||||
if "text" in update and "ts" in update:
|
||||
event.message = MessageContext(update["ts"], update["text"])
|
||||
event.channel = ChannelContext(update["channel"])
|
||||
if "channel" in update:
|
||||
event.conversation = ConversationContext(update["channel"])
|
||||
if "user" in update:
|
||||
event.user = UserContext(update["user"])
|
||||
if "thread_ts" in update:
|
||||
event.thread = ThreadContext(update["thread_ts"])
|
||||
|
||||
# TODO: Handle more types
|
||||
# We need to
|
||||
|
|
@ -167,10 +187,10 @@ class ClientWrapper(object):
|
|||
self.passives: List[Passive] = []
|
||||
|
||||
# Cache users and channels
|
||||
self.users: dict = {}
|
||||
self.channels: dict = {}
|
||||
self.users: Dict[str, User] = {}
|
||||
self.conversations: Dict[str, Conversation] = {}
|
||||
|
||||
# Scheduled events handling
|
||||
# Scheduled/passive events handling
|
||||
def add_passive(self, per: Passive) -> None:
|
||||
self.passives.append(per)
|
||||
|
||||
|
|
@ -181,21 +201,72 @@ class ClientWrapper(object):
|
|||
awaitables = [p.run() for p in self.passives]
|
||||
await asyncio.gather(*awaitables)
|
||||
|
||||
# Message handling
|
||||
# Incoming slack hook handling
|
||||
def add_hook(self, hook: AbsHook) -> None:
|
||||
self.hooks.append(hook)
|
||||
|
||||
async def respond_messages(self) -> None:
|
||||
async def handle_events(self) -> None:
|
||||
"""
|
||||
Asynchronous tasks that eternally reads and responds to messages.
|
||||
"""
|
||||
async for t in self.spool_tasks():
|
||||
# Create a queue
|
||||
queue = asyncio.Queue()
|
||||
|
||||
# Create a task to put rtm events to the queue
|
||||
async def put_rtm():
|
||||
async for t1 in self.rtm_event_feed():
|
||||
await queue.put(t1)
|
||||
|
||||
rtm_task = asyncio.Task(put_rtm())
|
||||
|
||||
# Create a task to put http events to the queue
|
||||
async def put_http():
|
||||
async for t2 in self.http_event_feed():
|
||||
await queue.put(t2)
|
||||
|
||||
http_task = asyncio.Task(put_http())
|
||||
|
||||
# Create a task to handle all other tasks
|
||||
async def handle_task_loop():
|
||||
async for t3 in self.spool_tasks(queue):
|
||||
sys.stdout.flush()
|
||||
if DEBUG_MODE:
|
||||
await t
|
||||
await t3
|
||||
|
||||
async def spool_tasks(self) -> AsyncGenerator[asyncio.Task, Any]:
|
||||
async for event in self.async_event_feed():
|
||||
# Create a task to read and process events from the queue
|
||||
handler_task = handle_task_loop()
|
||||
await asyncio.gather(rtm_task, http_task, handler_task)
|
||||
|
||||
async def rtm_event_feed(self) -> AsyncGenerator[Event, None]:
|
||||
"""
|
||||
Async wrapper around the message feed.
|
||||
Yields messages awaitably forever.
|
||||
"""
|
||||
# Create the msg feed
|
||||
feed = message_stream(self.slack)
|
||||
|
||||
# Create a simple callable that gets one message from the feed
|
||||
def get_one():
|
||||
return next(feed)
|
||||
|
||||
# Continuously yield async threaded tasks that poll the feed
|
||||
while True:
|
||||
yield await asyncio.get_running_loop().run_in_executor(None, get_one)
|
||||
|
||||
async def http_event_feed(self) -> AsyncGenerator[Event, None]:
|
||||
# Create the server
|
||||
pass
|
||||
|
||||
while True:
|
||||
await asyncio.sleep(30)
|
||||
yield Event()
|
||||
|
||||
async def spool_tasks(self, event_queue: asyncio.Queue) -> AsyncGenerator[asyncio.Task, Any]:
|
||||
"""
|
||||
Read in from async event feed, and spool them out as async tasks
|
||||
"""
|
||||
while True:
|
||||
event: Event = await event_queue.get()
|
||||
# Find which hook, if any, satisfies
|
||||
for hook in list(self.hooks): # Note that we do list(self.hooks) to avoid edit-while-iterating issues
|
||||
# Try invoking each
|
||||
|
|
@ -215,38 +286,46 @@ class ClientWrapper(object):
|
|||
self.hooks.remove(hook)
|
||||
print("Done spawning tasks. Now {} running total.".format(len(asyncio.all_tasks())))
|
||||
|
||||
async def async_event_feed(self) -> AsyncGenerator[Event, None]:
|
||||
"""
|
||||
Async wrapper around the message feed.
|
||||
Yields messages awaitably forever.
|
||||
"""
|
||||
# Create the msg feed
|
||||
feed = message_stream(self.slack)
|
||||
# Data getting/sending
|
||||
|
||||
# Create a simple callable that gets one message from the feed
|
||||
def get_one():
|
||||
return next(feed)
|
||||
def get_conversation(self, conversation_id: str) -> Optional[Conversation]:
|
||||
return self.conversations.get(conversation_id)
|
||||
|
||||
# Continuously yield async threaded tasks that poll the feed
|
||||
while True:
|
||||
yield await asyncio.get_running_loop().run_in_executor(None, get_one)
|
||||
def get_conversation_by_name(self, conversation_identifier: str) -> Optional[Conversation]:
|
||||
# If looking for a direct message, first lookup user, then fetch
|
||||
if conversation_identifier[0] == "@":
|
||||
user_name = conversation_identifier[1:]
|
||||
|
||||
def get_channel(self, channel_id: str) -> Optional[Channel]:
|
||||
return self.channels.get(channel_id)
|
||||
# Find the user by their name
|
||||
raise NotImplementedError("There wasn't a clear use case for this yet, so we've opted to just not use it")
|
||||
|
||||
# If looking for a channel, just lookup normally
|
||||
elif conversation_identifier[0] == "#":
|
||||
channel_name = conversation_identifier[1:]
|
||||
|
||||
def get_channel_by_name(self, channel_name: str) -> Optional[Channel]:
|
||||
# Find the channel in the dict
|
||||
for v in self.channels.values():
|
||||
if v.name == channel_name:
|
||||
return v
|
||||
for channel in self.conversations.values():
|
||||
if channel.name == channel_name:
|
||||
return channel
|
||||
|
||||
# If it doesn't fit the above, we don't know how to process
|
||||
else:
|
||||
raise ValueError("Please give either an #channel-name or @user-name")
|
||||
|
||||
# If we haven't returned already, give up and return None
|
||||
return None
|
||||
|
||||
def get_user(self, user_id: str) -> Optional[Channel]:
|
||||
def get_user(self, user_id: str) -> Optional[User]:
|
||||
return self.users.get(user_id)
|
||||
|
||||
def get_user_by_name(self, user_name: str) -> Optional[User]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def api_call(self, api_method, **kwargs):
|
||||
return self.slack.api_call(api_method, **kwargs)
|
||||
|
||||
# Simpler wrappers around message sending/replying
|
||||
|
||||
def reply(self, event: Event, text: str, in_thread: bool = True) -> dict:
|
||||
"""
|
||||
Replies to a message.
|
||||
|
|
@ -254,7 +333,7 @@ class ClientWrapper(object):
|
|||
Returns the JSON response.
|
||||
"""
|
||||
# Ensure we're actually replying to a valid message
|
||||
assert (event.channel and event.message) is not None
|
||||
assert (event.conversation and event.message) is not None
|
||||
|
||||
# Send in a thread by default
|
||||
if in_thread:
|
||||
|
|
@ -262,9 +341,9 @@ class ClientWrapper(object):
|
|||
thread = event.message.ts
|
||||
if event.thread:
|
||||
thread = event.thread.thread_ts
|
||||
return self.send_message(text, event.channel.channel_id, thread=thread)
|
||||
return self.send_message(text, event.conversation.conversation_id, thread=thread)
|
||||
else:
|
||||
return self.send_message(text, event.channel.channel_id)
|
||||
return self.send_message(text, event.conversation.conversation_id)
|
||||
|
||||
def send_message(self, text: str, channel_id: str, thread: str = None, broadcast: bool = False) -> dict:
|
||||
"""
|
||||
|
|
@ -279,6 +358,8 @@ class ClientWrapper(object):
|
|||
|
||||
return self.api_call("chat.postMessage", **kwargs)
|
||||
|
||||
# Update slack data
|
||||
|
||||
def update_channels(self):
|
||||
"""
|
||||
Queries the slack API for all current channels
|
||||
|
|
@ -292,7 +373,7 @@ class ClientWrapper(object):
|
|||
# Iterate over results
|
||||
while True:
|
||||
# Set args depending on if a cursor exists
|
||||
args = {"limit": 1000, "type": "public_channel,private_channel,mpim,im"}
|
||||
args = {"limit": 1000, "types": "public_channel,private_channel,mpim,im"}
|
||||
if cursor:
|
||||
args["cursor"] = cursor
|
||||
|
||||
|
|
@ -301,6 +382,10 @@ class ClientWrapper(object):
|
|||
# If the response is good, put its results to the dict
|
||||
if channel_dicts["ok"]:
|
||||
for channel_dict in channel_dicts["channels"]:
|
||||
if channel_dict["is_im"]:
|
||||
new_channel = DirectMessage(id=channel_dict["id"],
|
||||
user_id=channel_dict["user"])
|
||||
else:
|
||||
new_channel = Channel(id=channel_dict["id"],
|
||||
name=channel_dict["name"])
|
||||
new_dict[new_channel.id] = new_channel
|
||||
|
|
@ -312,11 +397,10 @@ class ClientWrapper(object):
|
|||
if cursor == "":
|
||||
break
|
||||
|
||||
|
||||
else:
|
||||
print("Warning: failed to retrieve channels. Message: {}".format(channel_dicts))
|
||||
break
|
||||
self.channels = new_dict
|
||||
self.conversations = new_dict
|
||||
|
||||
def update_users(self):
|
||||
"""
|
||||
|
|
@ -427,9 +511,11 @@ class ChannelHook(AbsHook):
|
|||
Returns whether a message should be handled by this dict, returning a Match if so, or None
|
||||
"""
|
||||
# Ensure that this is an event in a specific channel, with a text component
|
||||
if not (event.channel and event.message):
|
||||
if not (event.conversation and event.message and isinstance(event.conversation.get_conversation(), Channel)):
|
||||
return None
|
||||
|
||||
|
||||
|
||||
# Fail if pattern invalid
|
||||
match = None
|
||||
for p in self.patterns:
|
||||
|
|
@ -441,7 +527,7 @@ class ChannelHook(AbsHook):
|
|||
return None
|
||||
|
||||
# Get the channel name
|
||||
channel_name = event.channel.get_channel().name
|
||||
channel_name = event.conversation.get_conversation().name
|
||||
|
||||
# Fail if whitelist defined, and we aren't there
|
||||
if self.channel_whitelist is not None and channel_name not in self.channel_whitelist:
|
||||
|
|
|
|||
Loading…
Reference in New Issue