"""
Custom message pattern matching: a small statement language that lets mods
describe which messages to act on, plus the cog that applies those statements.
"""
import re
from abc import ABC, abstractmethod

from discord import Guild, Member, Message
from discord.ext import commands

from cogs.basecog import BaseCog, BotMessage, BotMessageReaction
from config import CONFIG
from rbutils import parse_timedelta
from storage import Storage


class PatternAction:
    """
    Describes one action to take on a matched message or its author.
    """

    def __init__(self, action: str, args: list):
        self.action = action
        self.arguments = list(args)

    def __str__(self) -> str:
        # Arguments may have been converted to non-str values by
        # PatternCompiler.parse_value, so stringify before joining.
        arg_str = ', '.join(map(str, self.arguments))
        return f'{self.action}({arg_str})'


class PatternExpression(ABC):
    """
    Abstract message matching expression.
    """

    def __init__(self):
        pass

    @abstractmethod
    def matches(self, message: Message) -> bool:
        """
        Whether a message matches this expression.
        """
        return False


class PatternSimpleExpression(PatternExpression):
    """
    Message matching expression with a simple "<field> <operator> <value>"
    structure.
    """

    def __init__(self, field: str, operator: str, value):
        """
        Raises re.error if the operator is a regex operator and the value is
        not a valid regular expression.
        """
        super().__init__()
        self.field = field
        self.operator = operator
        self.value = value
        # Compile regex values once, up front. This rejects a bad regex when
        # the pattern is defined rather than on every incoming message, and
        # uses IGNORECASE instead of lowercasing the pattern source (which
        # would silently turn escapes such as \S or \W into \s or \w).
        self._regex = None
        if operator in ('matches', '!matches'):
            self._regex = re.compile(value, re.IGNORECASE)

    def __field_value(self, message: Message):
        """
        Extracts the value of this expression's field from a message.
        Raises ValueError for an unknown field name.
        """
        if self.field == 'content':
            return message.content
        if self.field in ('author', 'author.id'):
            # Ids are compared as strings; see PatternCompiler.parse_value.
            return str(message.author.id)
        if self.field == 'author.joinage':
            # NOTE(review): presumably joined_at can be None for some authors,
            # which would make this subtraction raise - confirm.
            return message.created_at - message.author.joined_at
        if self.field == 'author.name':
            return message.author.name
        raise ValueError(f'Bad field name {self.field}')

    def matches(self, message: Message) -> bool:
        """
        Whether a message matches this expression. String comparisons are
        case-insensitive. Raises ValueError for an unknown field or operator.
        """
        field_value = self.__field_value(message)
        both_text = isinstance(field_value, str) and isinstance(self.value, str)
        if self.operator == '==':
            if both_text:
                return field_value.lower() == self.value.lower()
            return field_value == self.value
        if self.operator == '!=':
            if both_text:
                return field_value.lower() != self.value.lower()
            return field_value != self.value
        if self.operator == '<':
            return field_value < self.value
        if self.operator == '>':
            return field_value > self.value
        if self.operator == '<=':
            return field_value <= self.value
        if self.operator == '>=':
            return field_value >= self.value
        if self.operator == 'contains':
            return self.value.lower() in field_value.lower()
        if self.operator == '!contains':
            return self.value.lower() not in field_value.lower()
        # Regex operators anchor at the start of the field value only
        # (Pattern.match), not at the end.
        if self.operator == 'matches':
            return self._regex.match(field_value) is not None
        if self.operator == '!matches':
            return self._regex.match(field_value) is None
        raise ValueError(f'Bad operator {self.operator}')

    def __str__(self) -> str:
        return f'({self.field} {self.operator} {self.value})'


class PatternCompoundExpression(PatternExpression):
    """
    Message matching expression that combines several child expressions with
    a boolean operator.
    """

    def __init__(self, operator: str, operands: list):
        super().__init__()
        self.operator = operator
        self.operands = list(operands)

    def matches(self, message: Message) -> bool:
        """
        Whether a message matches this expression. "and" / "or" short-circuit
        in operand order. Raises RuntimeError for an unknown operator.
        """
        if self.operator == '!':
            return not self.operands[0].matches(message)
        if self.operator == 'and':
            return all(op.matches(message) for op in self.operands)
        if self.operator == 'or':
            return any(op.matches(message) for op in self.operands)
        raise RuntimeError(f'Bad operator "{self.operator}"')

    def __str__(self) -> str:
        if self.operator == '!':
            return f'(!( {self.operands[0]} ))'
        joined = f' {self.operator} '.join(map(str, self.operands))
        return f'( {joined} )'


class PatternStatement:
    """
    A full message match statement. If a message matches the given
    expression, the given actions should be performed.
    """

    def __init__(self,
            name: str,
            actions: list,
            expression: PatternExpression,
            original: str):
        self.name = name
        self.actions = list(actions)  # PatternAction[]
        self.expression = expression
        self.original = original  # statement source text, as entered


class PatternContext:
    """
    Data about a message that has matched a configured statement and what
    actions have been carried out.
    """

    def __init__(self, message: Message, statement: PatternStatement):
        self.message = message
        self.statement = statement
        self.is_deleted = False
        self.is_kicked = False
        self.is_banned = False


class PatternCog(BaseCog, name='Pattern Matching'):
    """
    Highly flexible cog for performing various actions on messages that match
    various criteria. Patterns can be defined by mods for each guild.
    """

    def __get_patterns(self, guild: Guild) -> dict:
        """
        Returns the guild's patterns as a dict of name -> PatternStatement.
        On a cache miss the saved statements are re-parsed from config;
        statements that fail to parse are logged and skipped.
        """
        patterns = Storage.get_state_value(guild, 'PatternCog.patterns')
        if patterns is None:
            patterns = {}
            patterns_encoded = Storage.get_config_value(guild, 'PatternCog.patterns')
            if patterns_encoded:
                for pe in patterns_encoded:
                    name = pe.get('name')
                    statement = pe.get('statement')
                    try:
                        patterns[name] = PatternCompiler.parse_statement(name, statement)
                    except Exception as e:
                        # One bad saved statement must not prevent the rest
                        # from loading.
                        self.log(guild,
                            f'Error parsing saved statement "{name}". '
                            f'Skipping: {statement}. Error: {e}')
            Storage.set_state_value(guild, 'PatternCog.patterns', patterns)
        return patterns

    @classmethod
    def __save_patterns(cls, guild: Guild, patterns: dict) -> None:
        """
        Persists the patterns to config as a list of name / statement source
        pairs.
        """
        to_save = [
            {
                'name': name,
                'statement': statement.original,
            }
            for name, statement in patterns.items()
        ]
        Storage.set_config_value(guild, 'PatternCog.patterns', to_save)

    @commands.Cog.listener()
    async def on_message(self, message: Message) -> None:
        'Event listener'
        if message.author is None or \
                message.author.bot or \
                message.channel is None or \
                message.guild is None or \
                message.content is None or \
                message.content == '':
            return
        if message.author.permissions_in(message.channel).ban_members:
            # Ignore mods
            return

        patterns = self.__get_patterns(message.guild)
        for statement in patterns.values():
            if statement.expression.matches(message):
                # Only the first matching pattern is acted on.
                await self.__trigger_actions(message, statement)
                break

    async def __trigger_actions(self, message: Message, statement: PatternStatement) -> None:
        """
        Performs each of the statement's actions in order, then posts a bot
        message summarizing what was done.
        """
        context = PatternContext(message, statement)
        should_alert_mods = False
        action_descriptions = []
        self.log(message.guild,
            f'Message from {message.author.name} matched pattern "{statement.name}"')
        for action in statement.actions:
            if action.action == 'ban':
                await message.author.ban(
                    reason=f'Rocketbot: Message matched custom pattern named "{statement.name}"',
                    delete_message_days=0)
                context.is_banned = True
                # A ban is also recorded as a kick.
                context.is_kicked = True
                action_descriptions.append('Author banned')
                self.log(message.guild, f'{message.author.name} banned')
            elif action.action == 'delete':
                await message.delete()
                context.is_deleted = True
                action_descriptions.append('Message deleted')
                self.log(message.guild, f"{message.author.name}'s message deleted")
            elif action.action == 'kick':
                await message.author.kick(
                    reason=f'Rocketbot: Message matched custom pattern named "{statement.name}"')
                context.is_kicked = True
                action_descriptions.append('Author kicked')
                self.log(message.guild, f'{message.author.name} kicked')
            elif action.action == 'modwarn':
                should_alert_mods = True
                action_descriptions.append('Mods alerted')
            elif action.action == 'reply':
                await message.reply(
                    f'{action.arguments[0]}',
                    mention_author=False)
                action_descriptions.append('Autoreplied')
                self.log(message.guild, f'{message.author.name} autoreplied to')
        bm = BotMessage(
            message.guild,
            f'User {message.author.name} tripped custom pattern '
            f'`{statement.name}`.\n\nAutomatic actions taken:\n• '
            + '\n• '.join(action_descriptions),
            type=BotMessage.TYPE_MOD_WARNING if should_alert_mods else BotMessage.TYPE_INFO,
            context=context)
        bm.quote = message.content
        await bm.set_reactions(BotMessageReaction.standard_set(
            did_delete=context.is_deleted,
            did_kick=context.is_kicked,
            did_ban=context.is_banned))
        await self.post_message(bm)

    async def on_mod_react(self,
            bot_message: BotMessage,
            reaction: BotMessageReaction,
            reacted_by: Member) -> None:
        """
        Carries out the delete / kick / ban selected by a reaction on one of
        this cog's bot messages, then refreshes the message's reactions.
        """
        context: PatternContext = bot_message.context
        if reaction.emoji == CONFIG['trash_emoji']:
            await context.message.delete()
            context.is_deleted = True
        elif reaction.emoji == CONFIG['kick_emoji']:
            await context.message.author.kick(
                reason='Rocketbot: Message matched custom pattern named '
                    f'"{context.statement.name}". Kicked by {reacted_by.name}.')
            context.is_kicked = True
        elif reaction.emoji == CONFIG['ban_emoji']:
            await context.message.author.ban(
                reason='Rocketbot: Message matched custom pattern named '
                    f'"{context.statement.name}". Banned by {reacted_by.name}.',
                delete_message_days=1)
            context.is_banned = True
        await bot_message.set_reactions(BotMessageReaction.standard_set(
            did_delete=context.is_deleted,
            did_kick=context.is_kicked,
            did_ban=context.is_banned))

    @commands.group(
        brief='Manages message pattern matching',
    )
    @commands.has_permissions(ban_members=True)
    @commands.guild_only()
    async def pattern(self, context: commands.Context):
        'Message pattern matching command group'
        if context.invoked_subcommand is None:
            await context.send_help()

    @pattern.command(
        brief='Adds a custom pattern',
        description='Adds a custom pattern. Patterns use a simplified '
            'expression language. Full documentation found here: '
            'https://git.rixafrix.com/ialbert/python-app-rocketbot/src/branch/master/patterns.md',
        usage='<pattern_name> <statement>',
        ignore_extra=True
    )
    async def add(self, context: commands.Context, name: str):
        'Command handler'
        # The statement is taken from the raw message text rather than from
        # parsed command arguments.
        pattern_str = PatternCompiler.expression_str_from_context(context, name)
        try:
            statement = PatternCompiler.parse_statement(name, pattern_str)
            patterns = self.__get_patterns(context.guild)
            patterns[name] = statement
            self.__save_patterns(context.guild, patterns)
            await context.message.reply(
                f'{CONFIG["success_emoji"]} Pattern `{name}` added.',
                mention_author=False)
        except Exception as e:
            # Command boundary: report any failure back to the invoking mod.
            await context.message.reply(
                f'{CONFIG["failure_emoji"]} Error parsing statement. {e}',
                mention_author=False)

    @pattern.command(
        brief='Removes a custom pattern',
        usage='<pattern_name>'
    )
    async def remove(self, context: commands.Context, name: str):
        'Command handler'
        patterns = self.__get_patterns(context.guild)
        if patterns.get(name) is not None:
            del patterns[name]
            self.__save_patterns(context.guild, patterns)
            await context.message.reply(
                f'{CONFIG["success_emoji"]} Pattern `{name}` deleted.',
                mention_author=False)
        else:
            await context.message.reply(
                f'{CONFIG["failure_emoji"]} No pattern named `{name}`.',
                mention_author=False)

    @pattern.command(
        brief='Lists all patterns'
    )
    async def list(self, context: commands.Context) -> None:
        'Command handler'
        patterns = self.__get_patterns(context.guild)
        if len(patterns) == 0:
            await context.message.reply('No patterns defined.', mention_author=False)
            return
        msg = ''.join(
            f'Pattern `{name}`:\n```\n{statement.original}\n```\n'
            for name, statement in sorted(patterns.items()))
        await context.message.reply(msg, mention_author=False)


class PatternCompiler:
    """
    Parses a user-provided message filter statement into a PatternStatement.
    """
    TYPE_ID = 'id'
    TYPE_MEMBER = 'Member'
    TYPE_TEXT = 'text'
    TYPE_INT = 'int'
    TYPE_FLOAT = 'float'
    TYPE_TIMESPAN = 'timespan'

    # Field name -> datatype of the value it is compared against.
    FIELD_TO_TYPE = {
        'content': TYPE_TEXT,
        'author': TYPE_MEMBER,
        'author.id': TYPE_ID,
        'author.name': TYPE_TEXT,
        'author.joinage': TYPE_TIMESPAN,
    }

    # Action name -> datatypes of its arguments, in order.
    ACTION_TO_ARGS = {
        'ban': [],
        'delete': [],
        'kick': [],
        'modwarn': [],
        'reply': [ TYPE_TEXT ],
    }

    OPERATORS_IDENTITY = { '==', '!=' }
    OPERATORS_COMPARISON = { '<', '>', '<=', '>=' }
    OPERATORS_NUMERIC = OPERATORS_IDENTITY | OPERATORS_COMPARISON
    OPERATORS_TEXT = OPERATORS_IDENTITY | { 'contains', '!contains', 'matches', '!matches' }
    OPERATORS_ALL = OPERATORS_IDENTITY | OPERATORS_COMPARISON | OPERATORS_TEXT

    # Datatype -> operators allowed on fields of that type.
    TYPE_TO_OPERATORS = {
        TYPE_ID: OPERATORS_IDENTITY,
        TYPE_MEMBER: OPERATORS_IDENTITY,
        TYPE_TEXT: OPERATORS_TEXT,
        TYPE_INT: OPERATORS_NUMERIC,
        TYPE_FLOAT: OPERATORS_NUMERIC,
        TYPE_TIMESPAN: OPERATORS_NUMERIC,
    }

    # Character classes used by the tokenizer. A character may belong to
    # more than one class.
    WHITESPACE_CHARS = ' \t\n\r'
    STRING_QUOTE_CHARS = '\'"'
    SYMBOL_CHARS = 'abcdefghijklmnopqrstuvwxyz.'
    VALUE_CHARS = '0123456789dhms<@!>'
    OP_CHARS = '<=>!(),'

    @classmethod
    def expression_str_from_context(cls, context: commands.Context, name: str) -> str:
        """
        Extracts the statement string from an "add" command context by
        stripping the prefixed command chain and the pattern name (bare or
        double-quoted) from the front of the message content.
        """
        pattern_str = context.message.content
        command_chain = [ name ]
        cmd = context.command
        while cmd:
            command_chain.insert(0, cmd.name)
            cmd = cmd.parent
        command_chain[0] = f'{context.prefix}{command_chain[0]}'
        for cmd in command_chain:
            if pattern_str.startswith(cmd):
                pattern_str = pattern_str[len(cmd):].lstrip()
            elif pattern_str.startswith(f'"{cmd}"'):
                # + 2 accounts for the surrounding quotes
                pattern_str = pattern_str[len(cmd) + 2:].lstrip()
        return pattern_str

    @classmethod
    def parse_statement(cls, name: str, statement: str) -> PatternStatement:
        """
        Parses a user-provided message filter statement into a
        PatternStatement. Raises RuntimeError, ValueError, or re.error if
        the statement is malformed.
        """
        tokens = cls.tokenize(statement)
        token_index = 0
        actions, token_index = cls.read_actions(tokens, token_index)
        expression, token_index = cls.read_expression(tokens, token_index)
        return PatternStatement(name, actions, expression, statement)

    @classmethod
    def tokenize(cls, statement: str) -> list:
        """
        Converts a message filter statement into a list of tokens. Quoted
        string tokens keep their surrounding quotes, with backslash escapes
        already resolved. Raises RuntimeError on a backslash outside a
        string or on an unterminated string.
        """
        tokens = []
        in_quote = False  # False, or the quote character that opened the string
        in_escape = False
        all_token_types = { 'sym', 'op', 'val' }
        # Token types the token being built could still turn out to be.
        possible_token_types = set(all_token_types)
        current_token = ''
        for ch in statement:
            if in_quote:
                if in_escape:
                    if ch == 'n':
                        current_token += '\n'
                    elif ch == 't':
                        current_token += '\t'
                    else:
                        current_token += ch
                    in_escape = False
                elif ch == '\\':
                    in_escape = True
                elif ch == in_quote:
                    current_token += ch
                    tokens.append(current_token)
                    current_token = ''
                    possible_token_types |= all_token_types
                    in_quote = False
                else:
                    current_token += ch
            else:
                if ch in cls.STRING_QUOTE_CHARS:
                    if len(current_token) > 0:
                        tokens.append(current_token)
                        current_token = ''
                        possible_token_types |= all_token_types
                    in_quote = ch
                    current_token = ch
                elif ch == '\\':
                    raise RuntimeError("Unexpected \\")
                elif ch in cls.WHITESPACE_CHARS:
                    if len(current_token) > 0:
                        tokens.append(current_token)
                        current_token = ''
                        possible_token_types |= all_token_types
                else:
                    possible_ch_types = set()
                    if ch in cls.SYMBOL_CHARS:
                        possible_ch_types.add('sym')
                    if ch in cls.VALUE_CHARS:
                        possible_ch_types.add('val')
                    if ch in cls.OP_CHARS:
                        possible_ch_types.add('op')
                    # A character that cannot extend the current token
                    # starts a new one.
                    if len(current_token) > 0 and \
                            possible_ch_types.isdisjoint(possible_token_types):
                        tokens.append(current_token)
                        current_token = ''
                        possible_token_types |= all_token_types
                    possible_token_types &= possible_ch_types
                    current_token += ch
        if in_quote:
            # Previously an unterminated string was passed along as a token
            # and could be accepted if it happened to end in an escaped quote.
            raise RuntimeError('Unterminated string')
        if len(current_token) > 0:
            tokens.append(current_token)

        # Some symbols might be glommed onto other tokens. Split 'em up.
        # (Quoted strings begin and end with a quote so are never split.)
        prefixes_to_split = [ '!', '(', ',' ]
        suffixes_to_split = [ ')', ',' ]
        i = 0
        while i < len(tokens):
            token = tokens[i]
            mutated = False
            for prefix in prefixes_to_split:
                if token.startswith(prefix) and len(token) > len(prefix):
                    tokens.insert(i, prefix)
                    tokens[i + 1] = token[len(prefix):]
                    # Step past the prefix; the remainder is re-examined.
                    i += 1
                    mutated = True
                    break
            if mutated:
                continue
            for suffix in suffixes_to_split:
                if token.endswith(suffix) and len(token) > len(suffix):
                    tokens[i] = token[0:-len(suffix)]
                    tokens.insert(i + 1, suffix)
                    mutated = True
                    break
            if mutated:
                # Stay on the same index to re-examine the shortened token.
                continue
            i += 1
        return tokens

    @classmethod
    def read_actions(cls, tokens: list, token_index: int) -> tuple:
        """
        Reads the actions from a list of statement tokens. Returns a tuple
        containing a list of PatternActions and the token index this method
        left off at (the token after the "if"). Raises RuntimeError if no
        "if" is found or an action is invalid.
        """
        actions = []
        current_action_tokens = []
        while token_index < len(tokens):
            token = tokens[token_index]
            if token == 'if':
                if len(current_action_tokens) > 0:
                    a = PatternAction(current_action_tokens[0], current_action_tokens[1:])
                    cls.__validate_action(a)
                    actions.append(a)
                token_index += 1
                return (actions, token_index)
            elif token == ',':
                if len(current_action_tokens) < 1:
                    raise RuntimeError('Unexpected ,')
                a = PatternAction(current_action_tokens[0], current_action_tokens[1:])
                cls.__validate_action(a)
                actions.append(a)
                current_action_tokens = []
            else:
                current_action_tokens.append(token)
            token_index += 1
        raise RuntimeError('Unexpected end of line')

    @classmethod
    def __validate_action(cls, action: PatternAction) -> None:
        """
        Checks that the action exists and has the right number of arguments,
        then converts each argument token in place to its parsed value.
        Raises RuntimeError or ValueError if the action is invalid.
        """
        args = cls.ACTION_TO_ARGS.get(action.action)
        if args is None:
            raise RuntimeError(f'Unknown action "{action.action}"')
        if len(action.arguments) != len(args):
            if len(args) == 0:
                raise RuntimeError(f'Action "{action.action}" expects no arguments, '
                    f'got {len(action.arguments)}.')
            raise RuntimeError(f'Action "{action.action}" expects {len(args)} '
                f'arguments, got {len(action.arguments)}.')
        for i, datatype in enumerate(args):
            action.arguments[i] = cls.parse_value(action.arguments[i], datatype)

    @classmethod
    def read_expression(cls,
            tokens: list,
            token_index: int,
            depth: int = 0,
            one_subexpression: bool = False) -> tuple:
        """
        Reads an expression from a list of statement tokens. Returns a tuple
        containing the PatternExpression and the token index it left off at.
        If one_subexpression is True then it will return after reading a
        single expression instead of joining multiples (for reading the
        subject of a NOT expression). Raises RuntimeError if the expression
        is malformed.
        """
        subexpressions = []
        last_compound_operator = None
        while token_index < len(tokens):
            if one_subexpression and len(subexpressions) > 0:
                return (subexpressions[0], token_index)
            token = tokens[token_index]
            if token == ')':
                # Leave the ")" for the caller to consume.
                if len(subexpressions) == 0:
                    raise RuntimeError('No subexpressions')
                if len(subexpressions) == 1:
                    return (subexpressions[0], token_index)
                return (PatternCompoundExpression(last_compound_operator, subexpressions),
                    token_index)
            if token in ('and', 'or'):
                if last_compound_operator and token != last_compound_operator:
                    # Operator changed: fold everything read so far into one
                    # operand of the new operator.
                    subexpressions = [
                        PatternCompoundExpression(last_compound_operator, subexpressions),
                    ]
                last_compound_operator = token
                token_index += 1
                if token_index >= len(tokens):
                    # Previously surfaced as a bare IndexError.
                    raise RuntimeError(f'Expected expression after "{token}", found EOL')
                token = tokens[token_index]
            elif len(subexpressions) > 0:
                # Two operands with nothing joining them. Previously this
                # parsed into a compound with operator None that only failed
                # when a message was evaluated.
                raise RuntimeError(f'Expected "and" or "or", found "{token}"')
            if token == '!':
                (exp, next_index) = cls.read_expression(tokens, token_index + 1,
                    depth + 1, one_subexpression=True)
                subexpressions.append(PatternCompoundExpression('!', [exp]))
                token_index = next_index
            elif token == '(':
                (exp, next_index) = cls.read_expression(tokens, token_index + 1, depth + 1)
                if next_index >= len(tokens) or tokens[next_index] != ')':
                    raise RuntimeError('Expected )')
                subexpressions.append(exp)
                token_index = next_index + 1
            else:
                (simple, next_index) = cls.read_simple_expression(tokens, token_index, depth)
                subexpressions.append(simple)
                token_index = next_index
        if len(subexpressions) == 0:
            raise RuntimeError('No subexpressions')
        if len(subexpressions) == 1:
            return (subexpressions[0], token_index)
        return (PatternCompoundExpression(last_compound_operator, subexpressions),
            token_index)

    @classmethod
    def read_simple_expression(cls, tokens: list, token_index: int, depth: int = 0) -> tuple:
        """
        Reads a simple expression consisting of a field name, operator, and
        comparison value. Returns a tuple of the PatternSimpleExpression and
        the token index it left off at. Raises RuntimeError or ValueError if
        the expression is malformed or nests deeper than 8 levels.
        """
        if depth > 8:
            raise RuntimeError('Expression nests too deeply')

        if token_index >= len(tokens):
            raise RuntimeError('Expected field name, found EOL')
        field = tokens[token_index]
        token_index += 1

        datatype = cls.FIELD_TO_TYPE.get(field)
        if datatype is None:
            raise RuntimeError(f'No such field "{field}"')

        if token_index >= len(tokens):
            raise RuntimeError('Expected operator, found EOL')
        op = tokens[token_index]
        token_index += 1

        if op == '!':
            # The tokenizer splits a leading "!" off operators such as "!="
            # and "!contains"; glue it back on.
            if token_index >= len(tokens):
                raise RuntimeError('Expected operator, found EOL')
            op = '!' + tokens[token_index]
            token_index += 1

        allowed_ops = cls.TYPE_TO_OPERATORS[datatype]
        if op not in allowed_ops:
            if op in cls.OPERATORS_ALL:
                raise RuntimeError(f'Operator {op} cannot be used with field "{field}"')
            # Sorted so the message is stable between runs.
            raise RuntimeError(
                f'Unrecognized operator "{op}" - allowed: {sorted(allowed_ops)}')

        if token_index >= len(tokens):
            raise RuntimeError('Expected value, found EOL')
        value = cls.parse_value(tokens[token_index], datatype)
        token_index += 1

        exp = PatternSimpleExpression(field, op, value)
        return (exp, token_index)

    @classmethod
    def parse_value(cls, value: str, datatype: str):
        """
        Converts a value token to its Python value. Raises ValueError if the
        token is not valid for the datatype or the datatype is unknown.
        """
        if datatype == cls.TYPE_ID:
            p = re.compile('^[0-9]+$')
            if p.match(value) is None:
                raise ValueError(f'Illegal id value "{value}"')
            # Store it as a str so it can be larger than an int
            return value
        if datatype == cls.TYPE_MEMBER:
            p = re.compile('^<@!?([0-9]+)>$')
            m = p.match(value)
            if m is None:
                raise ValueError('Illegal member value. Must be an @ mention.')
            return m.group(1)
        if datatype == cls.TYPE_TEXT:
            # Must be quoted.
            if len(value) < 2 or \
                    value[0:1] not in cls.STRING_QUOTE_CHARS or \
                    value[-1:] not in cls.STRING_QUOTE_CHARS or \
                    value[0:1] != value[-1:]:
                raise ValueError(f'Not a quoted string value: {value}')
            return value[1:-1]
        if datatype == cls.TYPE_INT:
            return int(value)
        if datatype == cls.TYPE_FLOAT:
            return float(value)
        if datatype == cls.TYPE_TIMESPAN:
            return parse_timedelta(value)
        raise ValueError(f'Unhandled datatype {datatype}')