Non-whitelisted commands in channel_hooks now by default are allowed in DM's
This commit is contained in:
parent
6aabcc72e1
commit
1e2be3f80d
|
|
@ -401,3 +401,23 @@ refresh_hook = slack_util.ChannelHook(refresh_callback,
|
|||
"update points"
|
||||
],
|
||||
channel_whitelist=["#command-center"])
|
||||
|
||||
block_action = """
|
||||
[
|
||||
{
|
||||
"type": "actions",
|
||||
"block_id": "test_block_id",
|
||||
"elements": [
|
||||
{
|
||||
"type": "button",
|
||||
"action_id": "test_action_id",
|
||||
"text": {
|
||||
"type": "plain_text",
|
||||
"text": "Send payload",
|
||||
"emoji": false
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -68,6 +68,7 @@ class Event:
|
|||
user: Optional[UserContext] = None
|
||||
message: Optional[MessageContext] = None
|
||||
thread: Optional[ThreadContext] = None
|
||||
interaction: Optional[InteractiveContext] = None
|
||||
|
||||
|
||||
# If this was posted in a specific channel or conversation
|
||||
|
|
@ -100,6 +101,15 @@ class ThreadContext:
|
|||
thread_ts: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class InteractiveContext:
|
||||
response_url: str # Used to confirm/respond to requests
|
||||
trigger_id: str # Used to open popups
|
||||
block_id: str # Identifies the block of the interacted component
|
||||
action_id: str # Identifies the interacted component
|
||||
action_value: str # Identifies the selected value in the component
|
||||
|
||||
|
||||
# If a file was additionally shared
|
||||
@dataclass
|
||||
class File:
|
||||
|
|
@ -129,7 +139,7 @@ def message_stream(slack: SlackClient) -> Generator[Event, None, None]:
|
|||
# Handle each
|
||||
for update in update_list:
|
||||
print("Message received: {}".format(update))
|
||||
yield dict_to_event(update)
|
||||
yield message_dict_to_event(update)
|
||||
|
||||
except (SlackNotConnected, OSError) as e:
|
||||
print("Error while reading messages:")
|
||||
|
|
@ -142,7 +152,7 @@ def message_stream(slack: SlackClient) -> Generator[Event, None, None]:
|
|||
print("Connection failed - retrying")
|
||||
|
||||
|
||||
def dict_to_event(update: dict) -> Event:
|
||||
def message_dict_to_event(update: dict) -> Event:
|
||||
"""
|
||||
Converts a dict update to an actual event.
|
||||
"""
|
||||
|
|
@ -219,18 +229,10 @@ class ClientWrapper(object):
|
|||
queue = asyncio.Queue()
|
||||
|
||||
# Create a task to put rtm events to the queue
|
||||
async def put_rtm():
|
||||
async for t1 in self.rtm_event_feed():
|
||||
await queue.put(t1)
|
||||
|
||||
rtm_task = asyncio.Task(put_rtm())
|
||||
rtm_task = asyncio.create_task(self.rtm_event_feed(queue))
|
||||
|
||||
# Create a task to put http events to the queue
|
||||
async def put_http():
|
||||
async for t2 in self.http_event_feed():
|
||||
await queue.put(t2)
|
||||
|
||||
http_task = asyncio.Task(put_http())
|
||||
http_task = asyncio.create_task(self.http_event_feed(queue))
|
||||
|
||||
# Create a task to handle all other tasks
|
||||
async def handle_task_loop():
|
||||
|
|
@ -239,11 +241,10 @@ class ClientWrapper(object):
|
|||
if DEBUG_MODE:
|
||||
await t3
|
||||
|
||||
# Create a task to read and process events from the queue
|
||||
handler_task = handle_task_loop()
|
||||
await asyncio.gather(rtm_task, http_task, handler_task)
|
||||
# Handle them all
|
||||
await asyncio.gather(rtm_task, http_task, handle_task_loop())
|
||||
|
||||
async def rtm_event_feed(self) -> AsyncGenerator[Event, None]:
|
||||
async def rtm_event_feed(self, msg_queue: asyncio.Queue) -> None:
|
||||
"""
|
||||
Async wrapper around the message feed.
|
||||
Yields messages awaitably forever.
|
||||
|
|
@ -257,26 +258,58 @@ class ClientWrapper(object):
|
|||
|
||||
# Continuously yield async threaded tasks that poll the feed
|
||||
while True:
|
||||
yield await asyncio.get_running_loop().run_in_executor(None, get_one)
|
||||
next_event = await asyncio.get_running_loop().run_in_executor(None, get_one)
|
||||
await msg_queue.put(next_event)
|
||||
|
||||
async def http_event_feed(self) -> AsyncGenerator[Event, None]:
|
||||
async def http_event_feed(self, event_queue: asyncio.Queue) -> None:
|
||||
# Create a callback to convert requests to events
|
||||
async def interr(request: web.Request):
|
||||
return web.Response()
|
||||
if request.can_read_body:
|
||||
# Get the payload
|
||||
body_dict = await request.json()
|
||||
payload = body_dict["payload"]
|
||||
|
||||
# Handle each action separately
|
||||
if "actions" in payload:
|
||||
for action in payload["actions"]:
|
||||
|
||||
# Start building the event
|
||||
ev = Event()
|
||||
|
||||
# Get the user who clicked the button
|
||||
ev.user = UserContext(payload["user"]["id"])
|
||||
|
||||
# Get the channel it was clicked in
|
||||
ev.conversation = ConversationContext(payload["channel"]["id"])
|
||||
|
||||
# Get the message this button/action was attached to
|
||||
ev.interaction = InteractiveContext(payload["response_url"],
|
||||
payload["trigger_id"],
|
||||
action["block_id"],
|
||||
action["action_id"],
|
||||
action["value"])
|
||||
|
||||
# Put it in the queue
|
||||
await event_queue.put(ev)
|
||||
|
||||
# Respond that everything is fine
|
||||
return web.Response(status=200)
|
||||
else:
|
||||
# If we can't read it, get mad
|
||||
return web.Response(status=400)
|
||||
|
||||
# Create the server
|
||||
app = web.Application()
|
||||
app.add_routes([web.get('/bothttpcallback/', interr)])
|
||||
app.add_routes([web.get('/bothttpcallback', interr)])
|
||||
|
||||
# Asynchronously serve that boy up
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, 'localhost', 8080)
|
||||
site = web.TCPSite(runner, 'localhost', 31019)
|
||||
await site.start()
|
||||
|
||||
while True:
|
||||
await asyncio.sleep(30)
|
||||
yield Event()
|
||||
# print("Server up")
|
||||
# while True:
|
||||
# await asyncio.sleep(30)
|
||||
|
||||
async def spool_tasks(self, event_queue: asyncio.Queue) -> AsyncGenerator[asyncio.Task, Any]:
|
||||
"""
|
||||
|
|
@ -502,7 +535,8 @@ class ChannelHook(AbsHook):
|
|||
patterns: Union[str, List[str]],
|
||||
channel_whitelist: Optional[List[str]] = None,
|
||||
channel_blacklist: Optional[List[str]] = None,
|
||||
consumer: bool = True):
|
||||
consumer: bool = True,
|
||||
allow_dms: bool = True):
|
||||
super(ChannelHook, self).__init__(consumer)
|
||||
|
||||
# Save all
|
||||
|
|
@ -513,6 +547,7 @@ class ChannelHook(AbsHook):
|
|||
self.channel_whitelist = channel_whitelist
|
||||
self.channel_blacklist = channel_blacklist
|
||||
self.callback = callback
|
||||
self.allows_dms = allow_dms
|
||||
|
||||
# Remedy some sensible defaults
|
||||
if self.channel_blacklist is None:
|
||||
|
|
@ -527,7 +562,7 @@ class ChannelHook(AbsHook):
|
|||
Returns whether a message should be handled by this dict, returning a Match if so, or None
|
||||
"""
|
||||
# Ensure that this is an event in a specific channel, with a text component
|
||||
if not (event.conversation and event.message and isinstance(event.conversation.get_conversation(), Channel)):
|
||||
if not (event.conversation and event.message):
|
||||
return None
|
||||
|
||||
# Fail if pattern invalid
|
||||
|
|
@ -541,7 +576,12 @@ class ChannelHook(AbsHook):
|
|||
return None
|
||||
|
||||
# Get the channel name
|
||||
if isinstance(event.conversation.get_conversation(), Channel):
|
||||
channel_name = event.conversation.get_conversation().name
|
||||
elif self.allows_dms:
|
||||
channel_name = "DIRECT_MSG"
|
||||
else:
|
||||
return None
|
||||
|
||||
# Fail if whitelist defined, and we aren't there
|
||||
if self.channel_whitelist is not None and channel_name not in self.channel_whitelist:
|
||||
|
|
|
|||
Loading…
Reference in New Issue