"""
Cog for matching messages against guild-configurable criteria and taking
automated actions on them.
"""
import re
from abc import ABCMeta, abstractmethod

from discord import Guild, Member, Message
from discord.ext import commands

from config import CONFIG
from rocketbot.cogs.basecog import BaseCog, BotMessage, BotMessageReaction
from rocketbot.storage import Storage
from rocketbot.utils import parse_timedelta


class PatternAction:
    """
    Describes one action to take on a matched message or its author.

    `action` is the action keyword (e.g. "ban", "reply"); `arguments` holds
    the action's argument values (raw tokens until validated by
    PatternCompiler, parsed values afterwards).
    """
    def __init__(self, action: str, args: list):
        self.action = action
        self.arguments = list(args)  # copy so the caller's list is not shared

    def __str__(self) -> str:
        # str() each argument so a non-text argument cannot break join().
        arg_str = ', '.join(map(str, self.arguments))
        return f'{self.action}({arg_str})'


class PatternExpression(metaclass=ABCMeta):
    """
    Abstract message matching expression.
    """
    def __init__(self):
        pass

    @abstractmethod
    def matches(self, message: Message) -> bool:
        """
        Whether a message matches this expression.
        """
        return False


class PatternSimpleExpression(PatternExpression):
    """
    Message matching expression with a simple "field operator value"
    structure, e.g. `content contains "spam"`.
    """
    # Operators whose value is a regular expression.
    REGEX_OPERATORS = frozenset(('matches', '!matches'))

    def __init__(self, field: str, operator: str, value):
        """
        Raises re.error if the operator is a regex operator and `value` is
        not a valid regular expression.
        """
        super().__init__()
        self.field = field
        self.operator = operator
        self.value = value
        # Compile once here instead of on every message. Case-insensitivity
        # comes from re.IGNORECASE rather than lowercasing the pattern text,
        # which would corrupt escapes such as \S, \W, \D, \B and \A.
        self.__regex = None
        if operator in self.REGEX_OPERATORS:
            self.__regex = re.compile(value, re.IGNORECASE)

    def __field_value(self, message: Message):
        """
        Extracts the value of this expression's field from a message.
        Member fields are compared as str ids. Raises ValueError for an
        unknown field name.
        """
        if self.field == 'content':
            return message.content
        if self.field == 'author':
            return str(message.author.id)
        if self.field == 'author.id':
            return str(message.author.id)
        if self.field == 'author.joinage':
            return message.created_at - message.author.joined_at
        if self.field == 'author.name':
            return message.author.name
        raise ValueError(f'Bad field name {self.field}')

    def matches(self, message: Message) -> bool:
        """
        Whether the message's field value satisfies the operator/value.
        String equality and the text operators are case-insensitive.
        `matches` is anchored at the start of the text (re.match semantics).
        """
        field_value = self.__field_value(message)
        if self.operator == '==':
            if isinstance(field_value, str) and isinstance(self.value, str):
                return field_value.lower() == self.value.lower()
            return field_value == self.value
        if self.operator == '!=':
            if isinstance(field_value, str) and isinstance(self.value, str):
                return field_value.lower() != self.value.lower()
            return field_value != self.value
        if self.operator == '<':
            return field_value < self.value
        if self.operator == '>':
            return field_value > self.value
        if self.operator == '<=':
            return field_value <= self.value
        if self.operator == '>=':
            return field_value >= self.value
        if self.operator == 'contains':
            return self.value.lower() in field_value.lower()
        if self.operator == '!contains':
            return self.value.lower() not in field_value.lower()
        if self.operator == 'matches':
            return self.__regex.match(field_value) is not None
        if self.operator == '!matches':
            return self.__regex.match(field_value) is None
        raise ValueError(f'Bad operator {self.operator}')

    def __str__(self) -> str:
        return f'({self.field} {self.operator} {self.value})'


class PatternCompoundExpression(PatternExpression):
    """
    Message matching expression that combines several child expressions with
    a boolean operator ("and", "or", or unary "!" applied to operands[0]).
    """
    def __init__(self, operator: str, operands: list):
        super().__init__()
        self.operator = operator
        self.operands = list(operands)

    def matches(self, message: Message) -> bool:
        if self.operator == '!':
            return not self.operands[0].matches(message)
        if self.operator == 'and':
            for op in self.operands:
                if not op.matches(message):
                    return False
            return True
        if self.operator == 'or':
            for op in self.operands:
                if op.matches(message):
                    return True
            return False
        raise RuntimeError(f'Bad operator "{self.operator}"')

    def __str__(self) -> str:
        if self.operator == '!':
            return f'(!( {self.operands[0]} ))'
        strs = map(str, self.operands)
        joined = f' {self.operator} '.join(strs)
        return f'( {joined} )'


class PatternStatement:
    """
    A full message match statement. If a message matches the given expression,
    the given actions should be performed.
    """
    def __init__(self,
            name: str,
            actions: list,
            expression: PatternExpression,
            original: str):
        self.name = name
        self.actions = list(actions)  # PatternAction[]
        self.expression = expression
        self.original = original  # source text, persisted in guild config


class PatternContext:
    """
    Data about a message that has matched a configured statement and what
    actions have been carried out.
    """
    def __init__(self, message: Message, statement: PatternStatement):
        self.message = message
        self.statement = statement
        self.is_deleted = False
        self.is_kicked = False
        self.is_banned = False


class PatternCog(BaseCog, name='Pattern Matching'):
    """
    Highly flexible cog for performing various actions on messages that match
    various criteria. Patterns can be defined by mods for each guild.
    """

    def __get_patterns(self, guild: Guild) -> dict:
        """
        Returns the guild's compiled patterns keyed by name. The saved
        statements are parsed once from config and cached in guild state;
        statements that fail to parse are logged and skipped.
        """
        patterns = Storage.get_state_value(guild, 'PatternCog.patterns')
        if patterns is None:
            patterns = {}
            patterns_encoded = Storage.get_config_value(guild, 'PatternCog.patterns')
            if patterns_encoded:
                for pe in patterns_encoded:
                    name = pe.get('name')
                    statement = pe.get('statement')
                    try:
                        ps = PatternCompiler.parse_statement(name, statement)
                        patterns[name] = ps
                    except PatternError as e:
                        self.log(guild, 'Error parsing saved statement ' + \
                            f'"{name}": "{e}" Statement: {statement}')
            Storage.set_state_value(guild, 'PatternCog.patterns', patterns)
        return patterns

    @classmethod
    def __save_patterns(cls, guild: Guild, patterns: dict) -> None:
        """
        Persists patterns to guild config as a list of
        {'name', 'statement'} dicts holding the original statement text.
        """
        to_save = []
        for name, statement in patterns.items():
            to_save.append({
                'name': name,
                'statement': statement.original,
            })
        Storage.set_config_value(guild, 'PatternCog.patterns', to_save)

    @commands.Cog.listener()
    async def on_message(self, message: Message) -> None:
        'Event listener'
        # Only non-empty human messages in a guild channel are checked.
        if message.author is None or \
                message.author.bot or \
                message.channel is None or \
                message.guild is None or \
                message.content is None or \
                message.content == '':
            return
        if message.author.permissions_in(message.channel).ban_members:
            # Ignore mods
            return
        patterns = self.__get_patterns(message.guild)
        for _, statement in patterns.items():
            if statement.expression.matches(message):
                await self.__trigger_actions(message, statement)
                # Only the first matching pattern acts on a message.
                break

    async def __trigger_actions(self,
            message: Message,
            statement: PatternStatement) -> None:
        """
        Performs the statement's actions, in order, on the matched message,
        then posts a summary bot message with mod reaction controls.
        """
        context = PatternContext(message, statement)
        should_alert_mods = False
        action_descriptions = []
        self.log(message.guild, f'Message from {message.author.name} matched pattern "{statement.name}"')
        for action in statement.actions:
            if action.action == 'ban':
                await message.author.ban(
                    reason=f'Rocketbot: Message matched custom pattern named "{statement.name}"',
                    delete_message_days=0)
                context.is_banned = True
                # A ban also removes the member, so kicking is moot.
                context.is_kicked = True
                action_descriptions.append('Author banned')
                self.log(message.guild, f'{message.author.name} banned')
            elif action.action == 'delete':
                await message.delete()
                context.is_deleted = True
                action_descriptions.append('Message deleted')
                self.log(message.guild, f'{message.author.name}\'s message deleted')
            elif action.action == 'kick':
                await message.author.kick(
                    reason=f'Rocketbot: Message matched custom pattern named "{statement.name}"')
                context.is_kicked = True
                action_descriptions.append('Author kicked')
                self.log(message.guild, f'{message.author.name} kicked')
            elif action.action == 'modwarn':
                should_alert_mods = True
                action_descriptions.append('Mods alerted')
            elif action.action == 'reply':
                reply_text = f'{action.arguments[0]}'
                if context.is_deleted:
                    # A reply references the original message; Discord
                    # rejects references to deleted messages by default, so
                    # post to the channel instead.
                    await message.channel.send(reply_text)
                else:
                    await message.reply(reply_text, mention_author=False)
                action_descriptions.append('Autoreplied')
                self.log(message.guild, f'{message.author.name} autoreplied to')
        bm = BotMessage(
            message.guild,
            f'User {message.author.name} tripped custom pattern ' + \
                f'`{statement.name}`.\n\nAutomatic actions taken:\n• ' + \
                ('\n• '.join(action_descriptions)),
            type=BotMessage.TYPE_MOD_WARNING if should_alert_mods else BotMessage.TYPE_INFO,
            context=context)
        bm.quote = message.content
        await bm.set_reactions(BotMessageReaction.standard_set(
            did_delete=context.is_deleted,
            did_kick=context.is_kicked,
            did_ban=context.is_banned))
        await self.post_message(bm)

    async def on_mod_react(self,
            bot_message: BotMessage,
            reaction: BotMessageReaction,
            reacted_by: Member) -> None:
        """
        Performs the delete/kick/ban action a mod selected by reacting to a
        pattern bot message, then refreshes the available reactions.
        """
        context: PatternContext = bot_message.context
        if reaction.emoji == CONFIG['trash_emoji']:
            await context.message.delete()
            context.is_deleted = True
        elif reaction.emoji == CONFIG['kick_emoji']:
            await context.message.author.kick(
                reason='Rocketbot: Message matched custom pattern named ' + \
                    f'"{context.statement.name}". Kicked by {reacted_by.name}.')
            context.is_kicked = True
        elif reaction.emoji == CONFIG['ban_emoji']:
            await context.message.author.ban(
                reason='Rocketbot: Message matched custom pattern named ' + \
                    f'"{context.statement.name}". Banned by {reacted_by.name}.',
                delete_message_days=1)
            context.is_banned = True
        await bot_message.set_reactions(BotMessageReaction.standard_set(
            did_delete=context.is_deleted,
            did_kick=context.is_kicked,
            did_ban=context.is_banned))

    @commands.group(
        brief='Manages message pattern matching',
    )
    @commands.has_permissions(ban_members=True)
    @commands.guild_only()
    async def pattern(self, context: commands.Context):
        'Message pattern matching command group'
        if context.invoked_subcommand is None:
            await context.send_help()

    @pattern.command(
        brief='Adds a custom pattern',
        description='Adds a custom pattern. Patterns use a simplified ' + \
            'expression language. Full documentation found here: ' + \
            'https://git.rixafrix.com/ialbert/python-app-rocketbot/src/branch/master/patterns.md',
        usage=' ',
        ignore_extra=True
    )
    async def add(self, context: commands.Context, name: str):
        'Command handler'
        # The statement is re-extracted from the raw message text because
        # discord.py's argument splitting would mangle the expression.
        pattern_str = PatternCompiler.expression_str_from_context(context, name)
        try:
            statement = PatternCompiler.parse_statement(name, pattern_str)
            patterns = self.__get_patterns(context.guild)
            patterns[name] = statement
            self.__save_patterns(context.guild, patterns)
            await context.message.reply(
                f'{CONFIG["success_emoji"]} Pattern `{name}` added.',
                mention_author=False)
        except PatternError as e:
            await context.message.reply(
                f'{CONFIG["failure_emoji"]} Error parsing statement. {e}',
                mention_author=False)

    @pattern.command(
        brief='Removes a custom pattern',
        usage=''
    )
    async def remove(self, context: commands.Context, name: str):
        'Command handler'
        patterns = self.__get_patterns(context.guild)
        if patterns.get(name) is not None:
            del patterns[name]
            self.__save_patterns(context.guild, patterns)
            await context.message.reply(
                f'{CONFIG["success_emoji"]} Pattern `{name}` deleted.',
                mention_author=False)
        else:
            await context.message.reply(
                f'{CONFIG["failure_emoji"]} No pattern named `{name}`.',
                mention_author=False)

    @pattern.command(
        brief='Lists all patterns'
    )
    async def list(self, context: commands.Context) -> None:
        'Command handler'
        patterns = self.__get_patterns(context.guild)
        if len(patterns) == 0:
            await context.message.reply('No patterns defined.', mention_author=False)
            return
        # NOTE(review): the combined listing is sent as one message; very
        # many/long patterns could exceed Discord's message length limit.
        msg = ''
        for name, statement in sorted(patterns.items()):
            msg += f'Pattern `{name}`:\n```\n{statement.original}\n```\n'
        await context.message.reply(msg, mention_author=False)


class PatternError(RuntimeError):
    """
    Error thrown when parsing a pattern statement.
    """


class PatternCompiler:
    """
    Parses a user-provided message filter statement into a PatternStatement.

    Statement grammar (informal):
        action [args] [, action [args]]* if expression
    where an expression is a "field operator value" comparison, a `!`
    negation, a parenthesized expression, or expressions joined by `and`/`or`
    (evaluated left to right, without precedence).
    """
    TYPE_ID = 'id'
    TYPE_MEMBER = 'Member'
    TYPE_TEXT = 'text'
    TYPE_INT = 'int'
    TYPE_FLOAT = 'float'
    TYPE_TIMESPAN = 'timespan'

    FIELD_TO_TYPE = {
        'content': TYPE_TEXT,
        'author': TYPE_MEMBER,
        'author.id': TYPE_ID,
        'author.name': TYPE_TEXT,
        'author.joinage': TYPE_TIMESPAN,
    }

    ACTION_TO_ARGS = {
        'ban': [],
        'delete': [],
        'kick': [],
        'modwarn': [],
        'reply': [ TYPE_TEXT ],
    }

    OPERATORS_IDENTITY = set([ '==', '!=' ])
    OPERATORS_COMPARISON = set([ '<', '>', '<=', '>=' ])
    OPERATORS_NUMERIC = OPERATORS_IDENTITY | OPERATORS_COMPARISON
    OPERATORS_TEXT = OPERATORS_IDENTITY | set([ 'contains', '!contains', 'matches', '!matches' ])
    OPERATORS_ALL = OPERATORS_IDENTITY | OPERATORS_COMPARISON | OPERATORS_TEXT

    TYPE_TO_OPERATORS = {
        TYPE_ID: OPERATORS_IDENTITY,
        TYPE_MEMBER: OPERATORS_IDENTITY,
        TYPE_TEXT: OPERATORS_TEXT,
        TYPE_INT: OPERATORS_NUMERIC,
        TYPE_FLOAT: OPERATORS_NUMERIC,
        TYPE_TIMESPAN: OPERATORS_NUMERIC,
    }

    WHITESPACE_CHARS = ' \t\n\r'
    STRING_QUOTE_CHARS = '\'"'
    SYMBOL_CHARS = 'abcdefghijklmnopqrstuvwxyz.'
    VALUE_CHARS = '0123456789dhms<@!>'
    OP_CHARS = '<=>!(),'

    # Maximum nesting of parentheses / negations in an expression.
    MAX_DEPTH = 8

    @classmethod
    def expression_str_from_context(cls, context: commands.Context, name: str) -> str:
        """
        Extracts the statement string from an "add" command context by
        stripping the prefixed command chain and the pattern name (bare or
        double-quoted) from the start of the message text.
        """
        pattern_str = context.message.content
        command_chain = [ name ]
        cmd = context.command
        while cmd:
            command_chain.insert(0, cmd.name)
            cmd = cmd.parent
        command_chain[0] = f'{context.prefix}{command_chain[0]}'
        for cmd in command_chain:
            if pattern_str.startswith(cmd):
                pattern_str = pattern_str[len(cmd):].lstrip()
            elif pattern_str.startswith(f'"{cmd}"'):
                pattern_str = pattern_str[len(cmd) + 2:].lstrip()
        return pattern_str

    @classmethod
    def parse_statement(cls, name: str, statement: str) -> PatternStatement:
        """
        Parses a user-provided message filter statement into a
        PatternStatement. Raises PatternError if the statement is malformed.
        """
        tokens = cls.tokenize(statement)
        token_index = 0
        actions, token_index = cls.read_actions(tokens, token_index)
        expression, token_index = cls.read_expression(tokens, token_index)
        if token_index < len(tokens):
            # read_expression only stops early at an unmatched ")".
            raise PatternError(f'Unexpected "{tokens[token_index]}"')
        return PatternStatement(name, actions, expression, statement)

    @classmethod
    def tokenize(cls, statement: str) -> list:
        """
        Converts a message filter statement into a list of tokens.

        Quoted strings become single tokens that keep their quote characters
        (with \\n, \\t and other backslash escapes resolved). Other tokens
        are split at whitespace and wherever the character class changes
        between symbol, value and operator characters. Raises PatternError
        for a backslash outside a string or an unterminated string.
        """
        tokens = []
        quote_char = None  # the quote character of the open string, if any
        in_escape = False
        all_token_types = set([ 'sym', 'op', 'val' ])
        # Token classes the current token could still belong to.
        possible_token_types = set(all_token_types)
        current_token = ''
        for ch in statement:
            if quote_char:
                if in_escape:
                    if ch == 'n':
                        current_token += '\n'
                    elif ch == 't':
                        current_token += '\t'
                    else:
                        current_token += ch
                    in_escape = False
                elif ch == '\\':
                    in_escape = True
                elif ch == quote_char:
                    current_token += ch
                    tokens.append(current_token)
                    current_token = ''
                    possible_token_types |= all_token_types
                    quote_char = None
                else:
                    current_token += ch
            elif ch in cls.STRING_QUOTE_CHARS:
                if current_token:
                    tokens.append(current_token)
                    current_token = ''
                    possible_token_types |= all_token_types
                quote_char = ch
                current_token = ch
            elif ch == '\\':
                raise PatternError("Unexpected \\ outside quoted string")
            elif ch in cls.WHITESPACE_CHARS:
                if current_token:
                    tokens.append(current_token)
                    current_token = ''
                    possible_token_types |= all_token_types
            else:
                possible_ch_types = set()
                if ch in cls.SYMBOL_CHARS:
                    possible_ch_types.add('sym')
                if ch in cls.VALUE_CHARS:
                    possible_ch_types.add('val')
                if ch in cls.OP_CHARS:
                    possible_ch_types.add('op')
                # Start a new token when this character cannot continue the
                # current one.
                if current_token and possible_ch_types.isdisjoint(possible_token_types):
                    tokens.append(current_token)
                    current_token = ''
                    possible_token_types |= all_token_types
                possible_token_types &= possible_ch_types
                current_token += ch
        if quote_char:
            raise PatternError('Unterminated quoted string')
        if current_token:
            tokens.append(current_token)

        # Some symbols might be glommed onto other tokens. Split 'em up.
        prefixes_to_split = [ '!', '(', ',' ]
        suffixes_to_split = [ ')', ',' ]
        i = 0
        while i < len(tokens):
            token = tokens[i]
            mutated = False
            for prefix in prefixes_to_split:
                if token.startswith(prefix) and len(token) > len(prefix):
                    tokens.insert(i, prefix)
                    tokens[i + 1] = token[len(prefix):]
                    i += 1
                    mutated = True
                    break
            if mutated:
                # Re-examine the remainder for further prefixes.
                continue
            for suffix in suffixes_to_split:
                if token.endswith(suffix) and len(token) > len(suffix):
                    tokens[i] = token[0:-len(suffix)]
                    tokens.insert(i + 1, suffix)
                    mutated = True
                    break
            if mutated:
                # Re-examine the shortened token for further suffixes.
                continue
            i += 1
        return tokens

    @classmethod
    def read_actions(cls, tokens: list, token_index: int) -> tuple:
        """
        Reads the actions from a list of statement tokens. Returns a tuple
        containing a list of PatternActions and the token index this method
        left off at (the token after the "if"). Raises PatternError on a
        stray comma, an invalid action, or a missing "if".
        """
        actions = []
        current_action_tokens = []
        while token_index < len(tokens):
            token = tokens[token_index]
            if token == 'if':
                if len(current_action_tokens) > 0:
                    a = PatternAction(current_action_tokens[0], current_action_tokens[1:])
                    cls.__validate_action(a)
                    actions.append(a)
                token_index += 1
                return (actions, token_index)
            if token == ',':
                if len(current_action_tokens) < 1:
                    raise PatternError('Unexpected ,')
                a = PatternAction(current_action_tokens[0], current_action_tokens[1:])
                cls.__validate_action(a)
                actions.append(a)
                current_action_tokens = []
            else:
                current_action_tokens.append(token)
            token_index += 1
        raise PatternError('Unexpected end of line in action list')

    @classmethod
    def __validate_action(cls, action: PatternAction) -> None:
        """
        Checks an action's name and argument count, and replaces its raw
        argument tokens with parsed values in place. Raises PatternError.
        """
        args = cls.ACTION_TO_ARGS.get(action.action)
        if args is None:
            raise PatternError(f'Unknown action "{action.action}"')
        if len(action.arguments) != len(args):
            if len(args) == 0:
                raise PatternError(f'Action "{action.action}" expects no arguments, ' + \
                    f'got {len(action.arguments)}.')
            raise PatternError(f'Action "{action.action}" expects {len(args)} ' + \
                f'arguments, got {len(action.arguments)}.')
        for i, datatype in enumerate(args):
            try:
                action.arguments[i] = cls.parse_value(action.arguments[i], datatype)
            except ValueError as cause:
                # Callers only handle PatternError; don't leak ValueError.
                raise PatternError(
                    f'Bad argument for action "{action.action}": {cause}') from cause

    @classmethod
    def read_expression(cls,
            tokens: list,
            token_index: int,
            depth: int = 0,
            one_subexpression: bool = False) -> tuple:
        """
        Reads an expression from a list of statement tokens. Returns a tuple
        containing the PatternExpression and the token index it left off at
        (either the end of the tokens or an unconsumed ")").

        Subexpressions must be separated by "and"/"or"; mixing them combines
        left to right. If one_subexpression is True then it will return after
        reading a single expression instead of joining multiples (for reading
        the subject of a NOT expression). Raises PatternError.
        """
        if depth > cls.MAX_DEPTH:
            # Checked here as well as at the leaves so deep "((((" or "!!!!"
            # runs fail cleanly instead of exhausting the recursion limit.
            raise PatternError('Expression nests too deeply')
        subexpressions = []
        last_compound_operator = None
        while token_index < len(tokens):
            if one_subexpression:
                if len(subexpressions) == 1:
                    return (subexpressions[0], token_index)
                if len(subexpressions) > 1:
                    raise PatternError('Too many subexpressions')
            token = tokens[token_index]
            if token == ')':
                if len(subexpressions) == 0:
                    raise PatternError('No subexpressions')
                if len(subexpressions) == 1:
                    return (subexpressions[0], token_index)
                return (PatternCompoundExpression(last_compound_operator, subexpressions), token_index)
            if token in ('and', 'or'):
                if len(subexpressions) == 0:
                    raise PatternError(f'Unexpected "{token}"')
                if last_compound_operator and token != last_compound_operator:
                    # Switching operators: everything so far becomes one
                    # operand of the new operator.
                    subexpressions = [ PatternCompoundExpression(last_compound_operator, subexpressions) ]
                last_compound_operator = token
                token_index += 1
                if token_index >= len(tokens):
                    raise PatternError(f'Expected expression after "{token}", found EOL')
            elif len(subexpressions) > 0:
                # Without an operator the subexpressions could not be combined.
                raise PatternError(f'Expected "and" or "or", found "{token}"')
            if tokens[token_index] == '!':
                (exp, next_index) = cls.read_expression(tokens, token_index + 1, \
                    depth + 1, one_subexpression=True)
                subexpressions.append(PatternCompoundExpression('!', [exp]))
                token_index = next_index
            elif tokens[token_index] == '(':
                (exp, next_index) = cls.read_expression(tokens, token_index + 1, depth + 1)
                if next_index >= len(tokens) or tokens[next_index] != ')':
                    raise PatternError('Expected )')
                subexpressions.append(exp)
                token_index = next_index + 1
            else:
                (simple, next_index) = cls.read_simple_expression(tokens, token_index, depth)
                subexpressions.append(simple)
                token_index = next_index
        if len(subexpressions) == 0:
            raise PatternError('No subexpressions')
        if len(subexpressions) == 1:
            return (subexpressions[0], token_index)
        return (PatternCompoundExpression(last_compound_operator, subexpressions), token_index)

    @classmethod
    def read_simple_expression(cls, tokens: list, token_index: int, depth: int = 0) -> tuple:
        """
        Reads a simple expression consisting of a field name, operator, and
        comparison value. Returns a tuple of the PatternSimpleExpression and
        the token index it left off at. Raises PatternError for an unknown
        field, a disallowed operator, a bad value, or an invalid regex.
        """
        if depth > cls.MAX_DEPTH:
            raise PatternError('Expression nests too deeply')
        if token_index >= len(tokens):
            raise PatternError('Expected field name, found EOL')
        field = tokens[token_index]
        token_index += 1

        datatype = cls.FIELD_TO_TYPE.get(field)
        if datatype is None:
            raise PatternError(f'No such field "{field}"')

        if token_index >= len(tokens):
            raise PatternError('Expected operator, found EOL')
        op = tokens[token_index]
        token_index += 1

        if op == '!':
            # The tokenizer splits "!" off "!=", "!contains", "!matches".
            if token_index >= len(tokens):
                raise PatternError('Expected operator, found EOL')
            op = '!' + tokens[token_index]
            token_index += 1

        allowed_ops = cls.TYPE_TO_OPERATORS[datatype]
        if op not in allowed_ops:
            if op in cls.OPERATORS_ALL:
                raise PatternError(f'Operator {op} cannot be used with field "{field}"')
            raise PatternError(f'Unrecognized operator "{op}" - allowed: {list(allowed_ops)}')

        if token_index >= len(tokens):
            raise PatternError('Expected value, found EOL')
        value = tokens[token_index]
        try:
            value = cls.parse_value(value, datatype)
        except ValueError as cause:
            raise PatternError(f'Bad value {value}') from cause
        token_index += 1

        try:
            exp = PatternSimpleExpression(field, op, value)
        except re.error as cause:
            # Reject bad regexes now rather than failing on every message.
            raise PatternError(f'Invalid regular expression {value}: {cause}') from cause
        return (exp, token_index)

    @classmethod
    def parse_value(cls, value: str, datatype: str):
        """
        Converts a value token to its Python value. Raises ValueError on
        failure.

        ids and member mentions become str ids; text must be a token wrapped
        in matching quotes and is returned without them; timespans are
        delegated to parse_timedelta.
        """
        if datatype == cls.TYPE_ID:
            p = re.compile('^[0-9]+$')
            if p.match(value) is None:
                raise ValueError(f'Illegal id value "{value}"')
            # Store it as a str so it can be larger than an int
            return value
        if datatype == cls.TYPE_MEMBER:
            p = re.compile('^<@!?([0-9]+)>$')
            m = p.match(value)
            if m is None:
                raise ValueError('Illegal member value. Must be an @ mention.')
            return m.group(1)
        if datatype == cls.TYPE_TEXT:
            # Must be quoted.
            if len(value) < 2 or \
                    value[0:1] not in cls.STRING_QUOTE_CHARS or \
                    value[-1:] not in cls.STRING_QUOTE_CHARS or \
                    value[0:1] != value[-1:]:
                raise ValueError(f'Not a quoted string value: {value}')
            return value[1:-1]
        if datatype == cls.TYPE_INT:
            return int(value)
        if datatype == cls.TYPE_FLOAT:
            return float(value)
        if datatype == cls.TYPE_TIMESPAN:
            # NOTE(review): assumes parse_timedelta raises ValueError on bad
            # input rather than returning None — confirm in rocketbot.utils.
            return parse_timedelta(value)
        raise ValueError(f'Unhandled datatype {datatype}')