from abc import ABC, abstractmethod
from datetime import timedelta
import re

from discord import Guild, Member, Message
from discord.ext import commands

from cogs.basecog import BaseCog, BotMessage, BotMessageReaction
from config import CONFIG
from storage import Storage


class PatternAction:
    """One action to perform when a pattern statement matches a message."""

    def __init__(self, type: str, args: list):
        # Action name; validated against PatternCompiler.ACTION_TO_ARGS.
        self.type = type
        # Raw argument tokens; PatternCompiler replaces them in place with
        # parsed values during validation.
        self.arguments = list(args)

    def __str__(self) -> str:
        arg_str = ', '.join(map(str, self.arguments))
        return f'{self.type}({arg_str})'


class PatternExpression(ABC):
    """Abstract base for a boolean condition evaluated against a message."""

    def __init__(self):
        pass

    @abstractmethod
    def matches(self, message: Message) -> bool:
        """Return whether ``message`` satisfies this expression."""
        return False


class PatternSimpleExpression(PatternExpression):
    """A single ``field operator value`` comparison against a message."""

    def __init__(self, field: str, operator: str, value):
        super().__init__()
        self.field = field
        self.operator = operator
        self.value = value
        # Compile regex operands once instead of on every message. Using
        # re.IGNORECASE (rather than lowercasing the pattern source) keeps
        # case-sensitive escapes such as \S, \W, \D and \Z intact.
        # Raises re.error for an invalid pattern.
        self._regex = None
        if operator in ('matches', '!matches'):
            self._regex = re.compile(value, re.IGNORECASE)

    def _field_value(self, message: Message):
        """Extract the value of ``self.field`` from ``message``.

        Raises:
            ValueError: if the field name is not recognized.
        """
        if self.field == 'content':
            return message.content
        if self.field in ('author', 'author.id'):
            # IDs are compared as strings (see PatternCompiler.parse_value).
            return str(message.author.id)
        if self.field == 'author.joinage':
            return message.created_at - message.author.joined_at
        if self.field == 'author.name':
            return message.author.name
        raise ValueError(f'Bad field name {self.field}')

    def matches(self, message: Message) -> bool:
        """Evaluate the comparison against ``message``.

        String equality and the text operators are case-insensitive.

        Raises:
            ValueError: on an unknown field or operator.
        """
        field_value = self._field_value(message)
        op = self.operator
        if op in ('==', '!='):
            if isinstance(field_value, str) and isinstance(self.value, str):
                equal = field_value.lower() == self.value.lower()
            else:
                equal = field_value == self.value
            return equal if op == '==' else not equal
        if op == '<':
            return field_value < self.value
        if op == '>':
            return field_value > self.value
        if op == '<=':
            return field_value <= self.value
        if op == '>=':
            return field_value >= self.value
        if op == 'contains':
            return self.value.lower() in field_value.lower()
        if op == '!contains':
            return self.value.lower() not in field_value.lower()
        if op == 'matches':
            return self._regex.match(field_value) is not None
        if op == '!matches':
            return self._regex.match(field_value) is None
        raise ValueError(f'Bad operator {self.operator}')

    def __str__(self) -> str:
        return f'({self.field} {self.operator} {self.value})'


class PatternCompoundExpression(PatternExpression):
    """Combines operands with ``!`` (negation of the first operand), ``and``
    or ``or``."""

    def __init__(self, operator: str, operands: list):
        super().__init__()
        self.operator = operator
        self.operands = list(operands)

    def matches(self, message: Message) -> bool:
        """Evaluate the compound expression, short-circuiting left to right.

        Raises:
            RuntimeError: if the operator is not ``!``, ``and`` or ``or``.
        """
        if self.operator == '!':
            return not self.operands[0].matches(message)
        if self.operator == 'and':
            return all(op.matches(message) for op in self.operands)
        if self.operator == 'or':
            return any(op.matches(message) for op in self.operands)
        raise RuntimeError(f'Bad operator "{self.operator}"')

    def __str__(self) -> str:
        if self.operator == '!':
            return f'(!( {self.operands[0]} ))'
        joined = f' {self.operator} '.join(map(str, self.operands))
        return f'( {joined} )'


class PatternStatement:
    """A named, compiled pattern: actions to run when the expression matches."""

    def __init__(self, name: str, actions: list, expression: PatternExpression, original: str):
        self.name = name
        self.actions = list(actions)  # PatternAction[]
        self.expression = expression
        # Source text of the statement; this is what gets persisted.
        self.original = original


class PatternContext:
    """Per-match state attached to the mod-facing BotMessage, tracking which
    moderation actions have already been taken."""

    def __init__(self, message: Message, statement: PatternStatement):
        self.message = message
        self.statement = statement
        self.is_deleted = False
        self.is_kicked = False
        self.is_banned = False


# Cog that checks guild messages against custom moderator-defined patterns and
# performs each matching pattern's actions. (No class docstring on purpose:
# discord.py uses a Cog's docstring as its help description.)
class PatternCog(BaseCog):
    def __init__(self, bot):
        super().__init__(bot)

    def __get_patterns(self, guild: Guild) -> dict:
        """Return the guild's compiled patterns as a dict of name -> PatternStatement.

        Compiled statements are cached in guild state. On a cache miss they
        are recompiled from the source strings saved in guild config; saved
        statements that no longer parse are logged and skipped.
        """
        patterns = Storage.get_state_value(guild, 'PatternCog.patterns')
        if patterns is None:
            patterns = {}
            patterns_encoded = Storage.get_config_value(guild, 'PatternCog.patterns')
            if patterns_encoded:
                for pe in patterns_encoded:
                    name = pe.get('name')
                    statement = pe.get('statement')
                    try:
                        patterns[name] = PatternCompiler.parse_statement(name, statement)
                    except (RuntimeError, ValueError):
                        # parse_value raises ValueError for malformed literals;
                        # catching only RuntimeError let one bad saved pattern
                        # break message handling for the whole guild.
                        self.log(guild, f'Error parsing saved statement "{name}". Skipping: {statement}')
            Storage.set_state_value(guild, 'PatternCog.patterns', patterns)
        return patterns

    def __save_patterns(self, guild: Guild, patterns: dict) -> None:
        """Persist the source text of every pattern to guild config."""
        to_save = [
            {
                'name': name,
                'statement': statement.original,
            }
            for name, statement in patterns.items()
        ]
        Storage.set_config_value(guild, 'PatternCog.patterns', to_save)

    @commands.Cog.listener()
    async def on_message(self, message: Message) -> None:
        """Check each non-empty guild message from a non-bot, non-mod author
        against the guild's patterns; trigger only the first match."""
        if (message.author is None or
                message.author.bot or
                message.channel is None or
                message.guild is None or
                message.content is None or
                message.content == ''):
            return
        if message.author.permissions_in(message.channel).ban_members:
            # Ignore mods
            return
        patterns = self.__get_patterns(message.guild)
        for statement in patterns.values():
            if statement.expression.matches(message):
                await self.__trigger_actions(message, statement)
                break

    async def __trigger_actions(self, message: Message, statement: PatternStatement) -> None:
        """Run the statement's actions in order, then post a summary BotMessage
        (a mod warning if ``modwarn`` was among the actions, otherwise info)."""
        context = PatternContext(message, statement)
        should_alert_mods = False
        action_descriptions = []
        self.log(message.guild, f'Message from {message.author.name} matched pattern "{statement.name}"')
        for action in statement.actions:
            if action.type == 'ban':
                await message.author.ban(
                    reason=f'Rocketbot: Message matched custom pattern named "{statement.name}"',
                    delete_message_days=0)
                context.is_banned = True
                # A ban also removes the member, so treat them as kicked too.
                context.is_kicked = True
                action_descriptions.append('Author banned')
                self.log(message.guild, f'{message.author.name} banned')
            elif action.type == 'delete':
                await message.delete()
                context.is_deleted = True
                action_descriptions.append('Message deleted')
                self.log(message.guild, f'{message.author.name}\'s message deleted')
            elif action.type == 'kick':
                await message.author.kick(
                    reason=f'Rocketbot: Message matched custom pattern named "{statement.name}"')
                context.is_kicked = True
                action_descriptions.append('Author kicked')
                self.log(message.guild, f'{message.author.name} kicked')
            elif action.type == 'modwarn':
                should_alert_mods = True
                action_descriptions.append('Mods alerted')
            elif action.type == 'reply':
                # NOTE(review): replying after a 'delete' action in the same
                # statement targets a deleted message — confirm this is OK.
                await message.reply(
                    f'{action.arguments[0]}',
                    mention_author=False)
                action_descriptions.append('Autoreplied')
                self.log(message.guild, f'{message.author.name} autoreplied to')
        bm = BotMessage(
            message.guild,
            f'User {message.author.name} tripped custom pattern '
            f'`{statement.name}`.\n\nAutomatic actions taken:\n• '
            + '\n• '.join(action_descriptions),
            type=BotMessage.TYPE_MOD_WARNING if should_alert_mods else BotMessage.TYPE_INFO,
            context=context)
        bm.quote = message.content
        await bm.set_reactions(BotMessageReaction.standard_set(
            did_delete=context.is_deleted,
            did_kick=context.is_kicked,
            did_ban=context.is_banned))
        await self.post_message(bm)

    async def on_mod_react(self,
            bot_message: BotMessage,
            reaction: BotMessageReaction,
            reacted_by: Member) -> None:
        """Apply the delete/kick/ban chosen by a reaction on one of this cog's
        BotMessages, then refresh the offered reactions.

        Presumably dispatched by BaseCog when a moderator reacts — confirm.
        """
        context: PatternContext = bot_message.context
        statement_name = context.statement.name
        if reaction.emoji == CONFIG['trash_emoji']:
            await context.message.delete()
            context.is_deleted = True
        elif reaction.emoji == CONFIG['kick_emoji']:
            await context.message.author.kick(
                reason=f'Rocketbot: Message matched custom pattern named '
                       f'"{statement_name}". Kicked by {reacted_by.name}.')
            context.is_kicked = True
        elif reaction.emoji == CONFIG['ban_emoji']:
            await context.message.author.ban(
                reason=f'Rocketbot: Message matched custom pattern named '
                       f'"{statement_name}". Banned by {reacted_by.name}.',
                delete_message_days=1)
            context.is_banned = True
        await bot_message.set_reactions(BotMessageReaction.standard_set(
            did_delete=context.is_deleted,
            did_kick=context.is_kicked,
            did_ban=context.is_banned))

    # Command group; requires ban_members and only works in guilds. The
    # docstring below is the group's help text, so it is kept verbatim.
    @commands.group(
        brief='Manages message pattern matching',
    )
    @commands.has_permissions(ban_members=True)
    @commands.guild_only()
    async def pattern(self, context: commands.Context):
        'Message pattern matching'
        if context.invoked_subcommand is None:
            await context.send_help()

    # `pattern add <name> <statement>`: compiles and stores a pattern,
    # replacing any existing pattern of the same name. No docstring: it would
    # become the command's help text.
    @pattern.command(
        brief='Adds a custom pattern',
        description='Adds a custom pattern. Patterns use a simplified ' +
            'expression language. Full documentation found here: ' +
            'https://git.rixafrix.com/ialbert/python-app-rocketbot/src/branch/master/patterns.md',
        usage=' ',
        ignore_extra=True
    )
    async def add(self, context: commands.Context, name: str):
        pattern_str = PatternCompiler.expression_str_from_context(context, name)
        # Only the parse is guarded: storage/reply failures are not parse
        # errors and should surface through the normal command error path.
        try:
            statement = PatternCompiler.parse_statement(name, pattern_str)
        except Exception as e:
            await context.message.reply(
                f'{CONFIG["failure_emoji"]} Error parsing statement. {e}',
                mention_author=False)
            return
        patterns = self.__get_patterns(context.guild)
        patterns[name] = statement
        self.__save_patterns(context.guild, patterns)
        await context.message.reply(
            f'{CONFIG["success_emoji"]} Pattern `{name}` added.',
            mention_author=False)

    # `pattern remove <name>`: deletes a stored pattern.
    @pattern.command(
        brief='Removes a custom pattern',
        usage=''
    )
    async def remove(self, context: commands.Context, name: str):
        patterns = self.__get_patterns(context.guild)
        if name in patterns:
            del patterns[name]
            self.__save_patterns(context.guild, patterns)
            await context.message.reply(
                f'{CONFIG["success_emoji"]} Pattern `{name}` deleted.',
                mention_author=False)
        else:
            await context.message.reply(
                f'{CONFIG["failure_emoji"]} No pattern named `{name}`.',
                mention_author=False)

    # `pattern list`: shows every pattern's source, sorted by name.
    @pattern.command(
        brief='Lists all patterns'
    )
    async def list(self, context: commands.Context) -> None:
        patterns = self.__get_patterns(context.guild)
        if len(patterns) == 0:
            await context.message.reply('No patterns defined.', mention_author=False)
            return
        msg = ''.join(
            f'Pattern `{name}`:\n```\n{statement.original}\n```\n'
            for name, statement in sorted(patterns.items()))
        await context.message.reply(msg, mention_author=False)


class PatternCompiler:
    """Parses pattern statements of the form ``<actions> if <expression>``.

    The comma-separated actions before ``if`` are validated against
    ACTION_TO_ARGS; the remainder is parsed into a PatternExpression tree.
    Mixed ``and``/``or`` chains group strictly left to right (no precedence).
    Parse errors are raised as RuntimeError or ValueError.
    """

    TYPE_ID = 'id'
    TYPE_MEMBER = 'Member'
    TYPE_TEXT = 'text'
    TYPE_INT = 'int'
    TYPE_FLOAT = 'float'
    TYPE_TIMESPAN = 'timespan'

    FIELD_TO_TYPE = {
        'content': TYPE_TEXT,
        'author': TYPE_MEMBER,
        'author.id': TYPE_ID,
        'author.name': TYPE_TEXT,
        'author.joinage': TYPE_TIMESPAN,
    }

    # Action name -> list of argument datatypes it requires.
    ACTION_TO_ARGS = {
        'ban': [],
        'delete': [],
        'kick': [],
        'modwarn': [],
        'reply': [TYPE_TEXT],
    }

    OPERATORS_IDENTITY = {'==', '!='}
    OPERATORS_COMPARISON = {'<', '>', '<=', '>='}
    OPERATORS_NUMERIC = OPERATORS_IDENTITY | OPERATORS_COMPARISON
    OPERATORS_TEXT = OPERATORS_IDENTITY | {'contains', '!contains', 'matches', '!matches'}
    OPERATORS_ALL = OPERATORS_IDENTITY | OPERATORS_COMPARISON | OPERATORS_TEXT
    OPERATORS_COMPOUND = {'and', 'or'}

    TYPE_TO_OPERATORS = {
        TYPE_ID: OPERATORS_IDENTITY,
        TYPE_MEMBER: OPERATORS_IDENTITY,
        TYPE_TEXT: OPERATORS_TEXT,
        TYPE_INT: OPERATORS_NUMERIC,
        TYPE_FLOAT: OPERATORS_NUMERIC,
        TYPE_TIMESPAN: OPERATORS_NUMERIC,
    }

    WHITESPACE_CHARS = ' \t\n\r'
    STRING_QUOTE_CHARS = '\'"'
    SYMBOL_CHARS = 'abcdefghijklmnopqrstuvwxyz.'
    VALUE_CHARS = '0123456789dhms<@!>'
    OP_CHARS = '<=>!(),'

    @classmethod
    def expression_str_from_context(cls, context: commands.Context, name: str) -> str:
        """Strip the command prefix, the command chain and the pattern name
        from the invoking message, leaving only the statement text."""
        pattern_str = context.message.content
        command_chain = [name]
        cmd = context.command
        while cmd:
            command_chain.insert(0, cmd.name)
            cmd = cmd.parent
        command_chain[0] = f'{context.prefix}{command_chain[0]}'
        for cmd in command_chain:
            if pattern_str.startswith(cmd):
                pattern_str = pattern_str[len(cmd):].lstrip()
        return pattern_str

    @classmethod
    def parse_statement(cls, name: str, statement: str) -> PatternStatement:
        """Compile ``statement`` into a PatternStatement.

        Raises:
            RuntimeError, ValueError: on any syntax or value error.
        """
        tokens = cls.tokenize(statement)
        token_index = 0
        actions, token_index = cls.read_actions(tokens, token_index)
        expression, token_index = cls.read_expression(tokens, token_index)
        if token_index < len(tokens):
            # Only an unmatched ')' stops the top-level expression early;
            # silently dropping what follows would make the pattern broader
            # than the author wrote.
            raise RuntimeError(f'Unexpected "{tokens[token_index]}"')
        return PatternStatement(name, actions, expression, statement)

    @classmethod
    def tokenize(cls, statement: str) -> list:
        """Split ``statement`` into tokens.

        Quoted strings (with \\n, \\t and backslash escapes) become single
        tokens that keep their quotes. Outside quotes, a run of characters
        stays one token while some token class (symbol/op/value) admits every
        character; leading ``!``/``(``/``,`` and trailing ``)``/``,`` are then
        split off into their own tokens.
        """
        tokens = []
        in_quote = False  # False, or the quote character that opened the string
        in_escape = False
        all_token_types = {'sym', 'op', 'val'}
        possible_token_types = set(all_token_types)
        current_token = ''
        for ch in statement:
            if in_quote:
                if in_escape:
                    if ch == 'n':
                        current_token += '\n'
                    elif ch == 't':
                        current_token += '\t'
                    else:
                        current_token += ch
                    in_escape = False
                elif ch == '\\':
                    in_escape = True
                elif ch == in_quote:
                    current_token += ch
                    tokens.append(current_token)
                    current_token = ''
                    possible_token_types |= all_token_types
                    in_quote = False
                else:
                    current_token += ch
            elif ch in cls.STRING_QUOTE_CHARS:
                if len(current_token) > 0:
                    tokens.append(current_token)
                    current_token = ''
                    possible_token_types |= all_token_types
                in_quote = ch
                current_token = ch
            elif ch == '\\':
                raise RuntimeError("Unexpected \\")
            elif ch in cls.WHITESPACE_CHARS:
                if len(current_token) > 0:
                    tokens.append(current_token)
                    current_token = ''
                    possible_token_types |= all_token_types
            else:
                possible_ch_types = set()
                if ch in cls.SYMBOL_CHARS:
                    possible_ch_types.add('sym')
                if ch in cls.VALUE_CHARS:
                    possible_ch_types.add('val')
                if ch in cls.OP_CHARS:
                    possible_ch_types.add('op')
                # Start a new token when this character cannot belong to any
                # token class the current token could still be.
                if len(current_token) > 0 and possible_ch_types.isdisjoint(possible_token_types):
                    tokens.append(current_token)
                    current_token = ''
                    possible_token_types |= all_token_types
                possible_token_types &= possible_ch_types
                current_token += ch
        if len(current_token) > 0:
            tokens.append(current_token)

        # Some symbols might be glommed onto other tokens. Split 'em up.
        prefixes_to_split = ['!', '(', ',']
        suffixes_to_split = [')', ',']
        i = 0
        while i < len(tokens):
            token = tokens[i]
            mutated = False
            for prefix in prefixes_to_split:
                if token.startswith(prefix) and len(token) > len(prefix):
                    tokens.insert(i, prefix)
                    tokens[i + 1] = token[len(prefix):]
                    i += 1
                    mutated = True
                    break
            if mutated:
                continue
            for suffix in suffixes_to_split:
                if token.endswith(suffix) and len(token) > len(suffix):
                    tokens[i] = token[0:-len(suffix)]
                    tokens.insert(i + 1, suffix)
                    mutated = True
                    break
            if mutated:
                continue
            i += 1
        return tokens

    @classmethod
    def read_actions(cls, tokens: list, token_index: int) -> tuple:
        """Read comma-separated actions up to and including the ``if`` token.

        Returns:
            (list of validated PatternAction, index of the token after ``if``)
        Raises:
            RuntimeError: on a stray comma, a bad action, or no ``if``.
        """
        actions = []
        current_action_tokens = []
        while token_index < len(tokens):
            token = tokens[token_index]
            if token == 'if':
                if len(current_action_tokens) > 0:
                    a = PatternAction(current_action_tokens[0], current_action_tokens[1:])
                    cls.__validate_action(a)
                    actions.append(a)
                token_index += 1
                return (actions, token_index)
            elif token == ',':
                if len(current_action_tokens) < 1:
                    raise RuntimeError('Unexpected ,')
                a = PatternAction(current_action_tokens[0], current_action_tokens[1:])
                cls.__validate_action(a)
                actions.append(a)
                current_action_tokens = []
            else:
                current_action_tokens.append(token)
            token_index += 1
        raise RuntimeError('Unexpected end of line')

    @classmethod
    def __validate_action(cls, action: PatternAction) -> None:
        """Check the action name and argument count, then replace each raw
        argument token with its parsed value (in place)."""
        args = cls.ACTION_TO_ARGS.get(action.type)
        if args is None:
            raise RuntimeError(f'Unknown action "{action.type}"')
        if len(action.arguments) != len(args):
            if len(args) == 0:
                raise RuntimeError(f'Action "{action.type}" expects no arguments, got {len(action.arguments)}.')
            else:
                raise RuntimeError(f'Action "{action.type}" expects {len(args)} arguments, got {len(action.arguments)}.')
        for i, datatype in enumerate(args):
            action.arguments[i] = cls.parse_value(action.arguments[i], datatype)

    @staticmethod
    def _collapse_subexpressions(subexpressions: list, operator) -> PatternExpression:
        """Return the lone subexpression, or join several under ``operator``."""
        if len(subexpressions) == 0:
            raise RuntimeError('No subexpressions')
        if len(subexpressions) == 1:
            return subexpressions[0]
        return PatternCompoundExpression(operator, subexpressions)

    @classmethod
    def read_expression(cls,
            tokens: list,
            token_index: int,
            depth: int = 0,
            one_subexpression: bool = False) -> tuple:
        """Parse an expression starting at ``token_index``.

        Accepted forms include:
            field op value
            (field op value)
            !(field op value)
            field op value and field op value
            (field op value and field op value) or field op value

        Stops at end of input or at an unconsumed ``)``. With
        ``one_subexpression`` it returns after a single operand (used for
        ``!``).

        Returns:
            (PatternExpression, index of the first unconsumed token)
        Raises:
            RuntimeError: on malformed syntax.
        """
        subexpressions = []
        last_compound_operator = None
        while token_index < len(tokens):
            if one_subexpression:
                if len(subexpressions) == 1:
                    return (subexpressions[0], token_index)
                elif len(subexpressions) > 1:
                    raise RuntimeError('Too many subexpressions')
            token = tokens[token_index]
            if token == ')':
                return (cls._collapse_subexpressions(subexpressions, last_compound_operator), token_index)
            if token in cls.OPERATORS_COMPOUND:
                if last_compound_operator and token != last_compound_operator:
                    # Operator changed: group everything so far under the
                    # previous operator (left-to-right, no precedence).
                    subexpressions = [PatternCompoundExpression(last_compound_operator, subexpressions)]
                last_compound_operator = token
                token_index += 1
                if token_index >= len(tokens):
                    raise RuntimeError(f'Expected expression after "{token}", found EOL')
            elif len(subexpressions) > 0:
                # Without this check two adjacent conditions produced a
                # compound with operator None that failed on every message.
                raise RuntimeError(f'Expected "and" or "or", found "{token}"')

            if tokens[token_index] == '!':
                (exp, next_index) = cls.read_expression(tokens, token_index + 1, depth + 1, one_subexpression=True)
                subexpressions.append(PatternCompoundExpression('!', [exp]))
                token_index = next_index
            elif tokens[token_index] == '(':
                (exp, next_index) = cls.read_expression(tokens, token_index + 1, depth + 1)
                if next_index >= len(tokens) or tokens[next_index] != ')':
                    raise RuntimeError('Expected )')
                subexpressions.append(exp)
                token_index = next_index + 1
            else:
                (simple, next_index) = cls.read_simple_expression(tokens, token_index, depth)
                subexpressions.append(simple)
                token_index = next_index
        return (cls._collapse_subexpressions(subexpressions, last_compound_operator), token_index)

    @classmethod
    def read_simple_expression(cls, tokens: list, token_index: int, depth: int = 0) -> tuple:
        """Parse ``field op value`` starting at ``token_index``.

        A ``!`` token followed by an operator token is merged into one
        operator (``!=``, ``!contains``, ``!matches``).

        Returns:
            (PatternSimpleExpression, index of the next token)
        Raises:
            RuntimeError, ValueError: on an unknown field, a disallowed
            operator, a bad value, or an invalid regular expression.
        """
        if token_index >= len(tokens):
            raise RuntimeError('Expected field name, found EOL')
        field = tokens[token_index]
        token_index += 1

        datatype = cls.FIELD_TO_TYPE.get(field)
        if datatype is None:
            raise RuntimeError(f'No such field "{field}"')

        if token_index >= len(tokens):
            raise RuntimeError('Expected operator, found EOL')
        op = tokens[token_index]
        token_index += 1

        if op == '!':
            if token_index >= len(tokens):
                raise RuntimeError('Expected operator, found EOL')
            op = '!' + tokens[token_index]
            token_index += 1

        allowed_ops = cls.TYPE_TO_OPERATORS[datatype]
        if op not in allowed_ops:
            if op in cls.OPERATORS_ALL:
                raise RuntimeError(f'Operator {op} cannot be used with field "{field}"')
            else:
                raise RuntimeError(f'Unrecognized operator "{op}" - allowed: {list(allowed_ops)}')

        if token_index >= len(tokens):
            raise RuntimeError('Expected value, found EOL')
        value = cls.parse_value(tokens[token_index], datatype)
        token_index += 1

        try:
            exp = PatternSimpleExpression(field, op, value)
        except re.error as e:
            # Report bad regexes at parse time rather than on every message.
            raise RuntimeError(f'Invalid regular expression "{value}": {e}') from e
        return (exp, token_index)

    @classmethod
    def parse_value(cls, value: str, type: str):
        """Convert a value token to the Python value for datatype ``type``.

        ids and member mentions become digit strings (so they may exceed an
        int), text must be quoted and is unquoted, and timespans such as
        ``5m30s`` become ``timedelta`` (repeated units are summed).

        Raises:
            ValueError: for malformed id/member/text/int/float values or an
                unhandled datatype.
            RuntimeError: for a malformed timespan.
        """
        if type == cls.TYPE_ID:
            p = re.compile('^[0-9]+$')
            if p.match(value) is None:
                raise ValueError(f'Illegal id value "{value}"')
            # Store it as a str so it can be larger than an int
            return value
        if type == cls.TYPE_MEMBER:
            p = re.compile('^<@!?([0-9]+)>$')
            m = p.match(value)
            if m is None:
                raise ValueError('Illegal member value. Must be an @ mention.')
            return m.group(1)
        if type == cls.TYPE_TEXT:
            # Must be quoted.
            if len(value) < 2 or \
                    value[0:1] not in cls.STRING_QUOTE_CHARS or \
                    value[-1:] not in cls.STRING_QUOTE_CHARS or \
                    value[0:1] != value[-1:]:
                raise ValueError(f'Not a quoted string value: {value}')
            return value[1:-1]
        if type == cls.TYPE_INT:
            return int(value)
        if type == cls.TYPE_FLOAT:
            return float(value)
        if type == cls.TYPE_TIMESPAN:
            p = re.compile('^(?:[0-9]+[dhms])+$')
            if p.match(value) is None:
                raise RuntimeError(f"Illegal timespan value \"{value}\". Must be like \"100d\", \"5m30s\", etc.")
            totals = {'d': 0, 'h': 0, 'm': 0, 's': 0}
            for m in re.finditer('([0-9]+)([dhms])', value):
                totals[m.group(2)] += int(m.group(1))
            return timedelta(
                days=totals['d'],
                hours=totals['h'],
                minutes=totals['m'],
                seconds=totals['s'])
        raise ValueError(f'Unhandled datatype {type}')