Non-whitelisted commands in channel_hooks now by default are allowed in DM's

This commit is contained in:
Jacob Henry 2019-03-01 16:02:25 -05:00
parent 6aabcc72e1
commit 1e2be3f80d
2 changed files with 88 additions and 28 deletions

View File

@ -401,3 +401,23 @@ refresh_hook = slack_util.ChannelHook(refresh_callback,
"update points"
],
channel_whitelist=["#command-center"])
block_action = """
[
{
"type": "actions",
"block_id": "test_block_id",
"elements": [
{
"type": "button",
"action_id": "test_action_id",
"text": {
"type": "plain_text",
"text": "Send payload",
"emoji": false
}
}
]
}
]
"""

View File

@ -68,6 +68,7 @@ class Event:
user: Optional[UserContext] = None
message: Optional[MessageContext] = None
thread: Optional[ThreadContext] = None
interaction: Optional[InteractiveContext] = None
# If this was posted in a specific channel or conversation
@ -100,6 +101,15 @@ class ThreadContext:
thread_ts: str
@dataclass
class InteractiveContext:
response_url: str # Used to confirm/respond to requests
trigger_id: str # Used to open popups
block_id: str # Identifies the block of the interacted component
action_id: str # Identifies the interacted component
action_value: str # Identifies the selected value in the component
# If a file was additionally shared
@dataclass
class File:
@ -129,7 +139,7 @@ def message_stream(slack: SlackClient) -> Generator[Event, None, None]:
# Handle each
for update in update_list:
print("Message received: {}".format(update))
yield dict_to_event(update)
yield message_dict_to_event(update)
except (SlackNotConnected, OSError) as e:
print("Error while reading messages:")
@ -142,7 +152,7 @@ def message_stream(slack: SlackClient) -> Generator[Event, None, None]:
print("Connection failed - retrying")
def dict_to_event(update: dict) -> Event:
def message_dict_to_event(update: dict) -> Event:
"""
Converts a dict update to an actual event.
"""
@ -219,18 +229,10 @@ class ClientWrapper(object):
queue = asyncio.Queue()
# Create a task to put rtm events to the queue
async def put_rtm():
async for t1 in self.rtm_event_feed():
await queue.put(t1)
rtm_task = asyncio.Task(put_rtm())
rtm_task = asyncio.create_task(self.rtm_event_feed(queue))
# Create a task to put http events to the queue
async def put_http():
async for t2 in self.http_event_feed():
await queue.put(t2)
http_task = asyncio.Task(put_http())
http_task = asyncio.create_task(self.http_event_feed(queue))
# Create a task to handle all other tasks
async def handle_task_loop():
@ -239,11 +241,10 @@ class ClientWrapper(object):
if DEBUG_MODE:
await t3
# Create a task to read and process events from the queue
handler_task = handle_task_loop()
await asyncio.gather(rtm_task, http_task, handler_task)
# Handle them all
await asyncio.gather(rtm_task, http_task, handle_task_loop())
async def rtm_event_feed(self) -> AsyncGenerator[Event, None]:
async def rtm_event_feed(self, msg_queue: asyncio.Queue) -> None:
"""
Async wrapper around the message feed.
Yields messages awaitably forever.
@ -257,26 +258,58 @@ class ClientWrapper(object):
# Continuously yield async threaded tasks that poll the feed
while True:
yield await asyncio.get_running_loop().run_in_executor(None, get_one)
next_event = await asyncio.get_running_loop().run_in_executor(None, get_one)
await msg_queue.put(next_event)
async def http_event_feed(self) -> AsyncGenerator[Event, None]:
async def http_event_feed(self, event_queue: asyncio.Queue) -> None:
# Create a callback to convert requests to events
async def interr(request: web.Request):
return web.Response()
if request.can_read_body:
# Get the payload
body_dict = await request.json()
payload = body_dict["payload"]
# Handle each action separately
if "actions" in payload:
for action in payload["actions"]:
# Start building the event
ev = Event()
# Get the user who clicked the button
ev.user = UserContext(payload["user"]["id"])
# Get the channel it was clicked in
ev.conversation = ConversationContext(payload["channel"]["id"])
# Get the message this button/action was attached to
ev.interaction = InteractiveContext(payload["response_url"],
payload["trigger_id"],
action["block_id"],
action["action_id"],
action["value"])
# Put it in the queue
await event_queue.put(ev)
# Respond that everything is fine
return web.Response(status=200)
else:
# If we can't read it, get mad
return web.Response(status=400)
# Create the server
app = web.Application()
app.add_routes([web.get('/bothttpcallback/', interr)])
app.add_routes([web.get('/bothttpcallback', interr)])
# Asynchronously serve that boy up
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, 'localhost', 8080)
site = web.TCPSite(runner, 'localhost', 31019)
await site.start()
while True:
await asyncio.sleep(30)
yield Event()
# print("Server up")
# while True:
# await asyncio.sleep(30)
async def spool_tasks(self, event_queue: asyncio.Queue) -> AsyncGenerator[asyncio.Task, Any]:
"""
@ -502,7 +535,8 @@ class ChannelHook(AbsHook):
patterns: Union[str, List[str]],
channel_whitelist: Optional[List[str]] = None,
channel_blacklist: Optional[List[str]] = None,
consumer: bool = True):
consumer: bool = True,
allow_dms: bool = True):
super(ChannelHook, self).__init__(consumer)
# Save all
@ -513,6 +547,7 @@ class ChannelHook(AbsHook):
self.channel_whitelist = channel_whitelist
self.channel_blacklist = channel_blacklist
self.callback = callback
self.allows_dms = allow_dms
# Remedy some sensible defaults
if self.channel_blacklist is None:
@ -527,7 +562,7 @@ class ChannelHook(AbsHook):
Returns whether a message should be handled by this dict, returning a Match if so, or None
"""
# Ensure that this is an event in a specific channel, with a text component
if not (event.conversation and event.message and isinstance(event.conversation.get_conversation(), Channel)):
if not (event.conversation and event.message):
return None
# Fail if pattern invalid
@ -541,7 +576,12 @@ class ChannelHook(AbsHook):
return None
# Get the channel name
channel_name = event.conversation.get_conversation().name
if isinstance(event.conversation.get_conversation(), Channel):
channel_name = event.conversation.get_conversation().name
elif self.allows_dms:
channel_name = "DIRECT_MSG"
else:
return None
# Fail if whitelist defined, and we aren't there
if self.channel_whitelist is not None and channel_name not in self.channel_whitelist: