| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653 |
- from abc import ABC, abstractmethod
- from discord import Guild, Member, Message
- from discord.ext import commands
- from datetime import timedelta
- import re
-
- from cogs.basecog import BaseCog, BotMessage, BotMessageReaction
- from config import CONFIG
- from storage import Storage
-
class PatternAction:
    """One action to perform when a pattern statement matches.

    Attributes:
        type: the action keyword, e.g. 'ban', 'delete', 'kick', 'modwarn', 'reply'.
        arguments: the action's arguments. After PatternCompiler validation
            these are parsed values (for 'reply', the unquoted reply text).
    """

    def __init__(self, type: str, args: list):
        self.type = type
        # Copy the list so later in-place parsing of arguments does not
        # mutate the caller's list.
        self.arguments = list(args)

    def __str__(self) -> str:
        # map(str, ...) so parsed, non-string arguments don't break join().
        arg_str = ', '.join(map(str, self.arguments))
        return f'{self.type}({arg_str})'
-
class PatternExpression(ABC):
    """Abstract base for a boolean condition tested against a message."""

    def __init__(self):
        # No shared state; concrete subclasses define their own fields.
        pass

    @abstractmethod
    def matches(self, message: Message) -> bool:
        """Return True if *message* satisfies this condition.

        Subclasses must override this; the base implementation answers False.
        """
        return False
-
class PatternSimpleExpression(PatternExpression):
    """A single `field operator value` comparison against a message.

    Supported fields: 'content', 'author' / 'author.id' (the author's id as a
    str), 'author.joinage' (message time minus the author's join time, a
    timedelta), and 'author.name'. String equality, containment and regex
    operators compare case-insensitively.
    """

    def __init__(self, field: str, operator: str, value):
        self.field = field
        self.operator = operator
        self.value = value
        # Compiled lazily the first time a regex operator is evaluated.
        self._regex = None

    def _field_value(self, message: Message):
        """Extract this expression's field from *message*.

        Raises:
            ValueError: if the field name is not recognized.
        """
        if self.field == 'content':
            return message.content
        if self.field in ('author', 'author.id'):
            # Ids are compared as strings (see PatternCompiler.parse_value).
            return str(message.author.id)
        if self.field == 'author.joinage':
            return message.created_at - message.author.joined_at
        if self.field == 'author.name':
            return message.author.name
        raise ValueError(f'Bad field name {self.field}')

    def _compiled_regex(self):
        """Return the compiled regex for self.value, compiling it on first use.

        re.IGNORECASE is used instead of lowercasing the pattern: lowercasing
        flips escapes like \\D, \\S, \\W, \\B into their opposites and turns
        \\Z into an invalid escape.
        """
        regex = getattr(self, '_regex', None)
        if regex is None:
            regex = re.compile(self.value, re.IGNORECASE)
            self._regex = regex
        return regex

    def matches(self, message: Message) -> bool:
        """Evaluate the comparison against *message*.

        Raises:
            ValueError: on an unknown field or operator.
        """
        field_value = self._field_value(message)
        if self.operator == '==':
            if isinstance(field_value, str) and isinstance(self.value, str):
                return field_value.lower() == self.value.lower()
            return field_value == self.value
        if self.operator == '!=':
            if isinstance(field_value, str) and isinstance(self.value, str):
                return field_value.lower() != self.value.lower()
            return field_value != self.value
        if self.operator == '<':
            return field_value < self.value
        if self.operator == '>':
            return field_value > self.value
        if self.operator == '<=':
            return field_value <= self.value
        if self.operator == '>=':
            return field_value >= self.value
        if self.operator == 'contains':
            return self.value.lower() in field_value.lower()
        if self.operator == '!contains':
            return self.value.lower() not in field_value.lower()
        if self.operator == 'matches':
            # re.match anchors at the start of the field only.
            return self._compiled_regex().match(field_value) is not None
        if self.operator == '!matches':
            return self._compiled_regex().match(field_value) is None
        raise ValueError(f'Bad operator {self.operator}')

    def __str__(self) -> str:
        return f'({self.field} {self.operator} {self.value})'
-
class PatternCompoundExpression(PatternExpression):
    """Logical combination of sub-expressions.

    operator is '!' (negates operands[0]), 'and' (all operands match), or
    'or' (any operand matches). Evaluation short-circuits.
    """

    def __init__(self, operator: str, operands: list):
        self.operator = operator
        self.operands = list(operands)

    def matches(self, message: Message) -> bool:
        """Evaluate the combination against *message*.

        Raises:
            RuntimeError: if the operator is not '!', 'and' or 'or'.
        """
        op = self.operator
        if op == '!':
            return not self.operands[0].matches(message)
        if op == 'and':
            return all(sub.matches(message) for sub in self.operands)
        if op == 'or':
            return any(sub.matches(message) for sub in self.operands)
        raise RuntimeError(f'Bad operator "{self.operator}"')

    def __str__(self) -> str:
        if self.operator == '!':
            return f'(!( {self.operands[0]} ))'
        separator = f' {self.operator} '
        inner = separator.join(str(sub) for sub in self.operands)
        return '( ' + inner + ' )'
-
class PatternStatement:
    """A compiled pattern: its name, actions, trigger expression and source text."""

    def __init__(self, name: str, actions: list, expression: PatternExpression, original: str):
        self.name = name
        # Defensive copy; the elements are PatternAction instances.
        self.actions = [*actions]
        self.expression = expression
        # The statement text as written, kept so it can be saved and listed verbatim.
        self.original = original
-
class PatternContext:
    """Mutable record of what has been done in response to one matched message."""

    def __init__(self, message: Message, statement: PatternStatement):
        self.message = message
        self.statement = statement
        # Action flags; PatternCog passes these to BotMessageReaction.standard_set.
        self.is_deleted = self.is_kicked = self.is_banned = False
-
# Cog that checks guild messages against moderator-defined patterns and runs
# the configured actions (ban, kick, delete, reply, mod alert) on a match.
# NOTE: documented with comments rather than docstrings on the class and the
# command callbacks, because discord.py turns those docstrings into help text.
class PatternCog(BaseCog):
    def __init__(self, bot):
        super().__init__(bot)

    def __get_patterns(self, guild: Guild) -> dict:
        """Return the guild's compiled patterns as {name: PatternStatement}.

        Compiled statements are cached in guild state. On a cache miss they
        are recompiled from the saved config; entries that fail to parse are
        logged and skipped so one bad entry can't disable the rest.
        """
        patterns = Storage.get_state_value(guild, 'PatternsCog.patterns')
        if patterns is None:
            patterns = {}
            patterns_encoded = Storage.get_config_value(guild, 'PatternsCog.patterns')
            for pe in patterns_encoded or []:
                name = pe.get('name')
                statement = pe.get('statement')
                try:
                    patterns[name] = PatternCompiler.parse_statement(name, statement)
                except (RuntimeError, ValueError) as e:
                    # PatternCompiler.parse_value raises ValueError (bad id,
                    # mention or string literal) as well as RuntimeError.
                    self.log(guild, f'Error parsing saved statement "{name}". Skipping: {statement} ({e})')
            Storage.set_state_value(guild, 'PatternsCog.patterns', patterns)
        return patterns

    def __save_patterns(self, guild: Guild, patterns: dict) -> None:
        """Persist patterns to guild config as a list of {'name', 'statement'} dicts.

        Only each statement's original source text is stored; it is recompiled
        on load.
        """
        to_save = [
            {
                'name': name,
                'statement': statement.original,
            }
            for name, statement in patterns.items()
        ]
        Storage.set_config_value(guild, 'PatternsCog.patterns', to_save)

    @commands.Cog.listener()
    async def on_message(self, message: Message) -> None:
        """Run the first matching pattern's actions on a non-bot guild message."""
        if message.author is None or \
                message.author.bot or \
                message.channel is None or \
                message.guild is None or \
                message.content is None or \
                message.content == '':
            return
        if message.author.permissions_in(message.channel).ban_members:
            # Ignore mods
            return

        patterns = self.__get_patterns(message.guild)
        for statement in patterns.values():
            if statement.expression.matches(message):
                await self.__trigger_actions(message, statement)
                # Only the first matching pattern fires.
                break

    async def __trigger_actions(self, message: Message, statement: PatternStatement) -> None:
        """Run statement's actions on message in order, then post a summary.

        The summary BotMessage is a mod warning if a 'modwarn' action ran,
        otherwise info, and its reactions reflect the actions already taken.
        """
        context = PatternContext(message, statement)
        should_alert_mods = False
        action_descriptions = []
        self.log(message.guild, f'Message from {message.author.name} matched pattern "{statement.name}"')
        for action in statement.actions:
            if action.type == 'ban':
                await message.author.ban(
                    reason=f'Rocketbot: Message matched custom pattern named "{statement.name}"',
                    delete_message_days=0)
                context.is_banned = True
                # A ban also removes the member from the guild.
                context.is_kicked = True
                action_descriptions.append('Author banned')
                self.log(message.guild, f'{message.author.name} banned')
            elif action.type == 'delete':
                await message.delete()
                context.is_deleted = True
                action_descriptions.append('Message deleted')
                self.log(message.guild, f'{message.author.name}\'s message deleted')
            elif action.type == 'kick':
                await message.author.kick(
                    reason=f'Rocketbot: Message matched custom pattern named "{statement.name}"')
                context.is_kicked = True
                action_descriptions.append('Author kicked')
                self.log(message.guild, f'{message.author.name} kicked')
            elif action.type == 'modwarn':
                should_alert_mods = True
                action_descriptions.append('Mods alerted')
            elif action.type == 'reply':
                await message.reply(
                    f'{action.arguments[0]}',
                    mention_author=False)
                action_descriptions.append('Autoreplied')
                self.log(message.guild, f'{message.author.name} autoreplied to')
        bm = BotMessage(
            message.guild,
            f'User {message.author.name} tripped custom pattern ' +
            f'`{statement.name}`.\n\nAutomatic actions taken:\n• ' +
            ('\n• '.join(action_descriptions)),
            type=BotMessage.TYPE_MOD_WARNING if should_alert_mods else BotMessage.TYPE_INFO,
            context=context)
        bm.quote = message.content
        await bm.set_reactions(BotMessageReaction.standard_set(
            did_delete=context.is_deleted,
            did_kick=context.is_kicked,
            did_ban=context.is_banned))
        await self.post_message(bm)

    async def on_mod_react(self,
            bot_message: BotMessage,
            reaction: BotMessageReaction,
            reacted_by: Member) -> None:
        """Apply the moderator action chosen by reaction, then refresh the reactions."""
        context: PatternContext = bot_message.context
        pattern_name = context.statement.name
        if reaction.emoji == CONFIG['trash_emoji']:
            await context.message.delete()
            context.is_deleted = True
        elif reaction.emoji == CONFIG['kick_emoji']:
            # Both halves are f-strings so the pattern name and moderator are
            # actually interpolated into the audit-log reason.
            await context.message.author.kick(
                reason=(f'Rocketbot: Message matched custom pattern named '
                        f'"{pattern_name}". Kicked by {reacted_by.name}.'))
            context.is_kicked = True
        elif reaction.emoji == CONFIG['ban_emoji']:
            await context.message.author.ban(
                reason=(f'Rocketbot: Message matched custom pattern named '
                        f'"{pattern_name}". Banned by {reacted_by.name}.'),
                delete_message_days=1)
            context.is_banned = True
        await bot_message.set_reactions(BotMessageReaction.standard_set(
            did_delete=context.is_deleted,
            did_kick=context.is_kicked,
            did_ban=context.is_banned))

    @commands.group(
        brief='Manages message pattern matching',
    )
    @commands.has_permissions(ban_members=True)
    @commands.guild_only()
    async def patterns(self, context: commands.Context):
        'Message pattern matching'
        if context.invoked_subcommand is None:
            await context.send_help()

    # Adds (or replaces) a named pattern; the expression is everything after
    # the name in the raw message text.
    @patterns.command(
        brief='Adds a custom pattern',
        description='Adds a custom pattern. Patterns use a simplified expression language. Full documentation found here: https://git.rixafrix.com/ialbert/python-app-rocketbot/src/branch/master/patterns.md',
        usage='<pattern_name> <expression...>',
        ignore_extra=True
    )
    async def add(self, context: commands.Context, name: str):
        pattern_str = PatternCompiler.expression_str_from_context(context, name)
        try:
            statement = PatternCompiler.parse_statement(name, pattern_str)
            patterns = self.__get_patterns(context.guild)
            patterns[name] = statement
            self.__save_patterns(context.guild, patterns)
            await context.message.reply(
                f'{CONFIG["success_emoji"]} Pattern `{name}` added.',
                mention_author=False)
        except Exception as e:
            # Parse errors are reported back to the moderator verbatim.
            await context.message.reply(
                f'{CONFIG["failure_emoji"]} Error parsing statement. {e}',
                mention_author=False)

    @patterns.command(
        brief='Removes a custom pattern',
        usage='<pattern_name>'
    )
    async def remove(self, context: commands.Context, name: str):
        patterns = self.__get_patterns(context.guild)
        if patterns.get(name) is not None:
            del patterns[name]
            self.__save_patterns(context.guild, patterns)
            await context.message.reply(
                f'{CONFIG["success_emoji"]} Pattern `{name}` deleted.',
                mention_author=False)
        else:
            await context.message.reply(
                f'{CONFIG["failure_emoji"]} No pattern named `{name}`.',
                mention_author=False)

    # Lists all patterns sorted by name, split across several replies when
    # needed to stay under Discord's 2000-character message limit.
    @patterns.command(
        brief='Lists all patterns'
    )
    async def list(self, context: commands.Context) -> None:
        max_message_length = 2000
        patterns = self.__get_patterns(context.guild)
        if len(patterns) == 0:
            await context.message.reply('No patterns defined.', mention_author=False)
            return
        msg = ''
        for name, statement in sorted(patterns.items()):
            entry = f'Pattern `{name}`:\n```\n{statement.original}\n```\n'
            if msg and len(msg) + len(entry) > max_message_length:
                await context.message.reply(msg, mention_author=False)
                msg = ''
            msg += entry
        await context.message.reply(msg, mention_author=False)
-
class PatternCompiler:
    """Parses pattern statements into PatternStatement objects.

    Statement grammar (informal):
        action [args...] [, action [args...]]... if expression
    where expression is `field op value` terms combined with 'and', 'or',
    '!' and parentheses. Mixed 'and'/'or' chains group left to right with
    no precedence: `a or b and c` parses as `(a or b) and c`.
    """
    TYPE_ID = 'id'
    TYPE_MEMBER = 'Member'
    TYPE_TEXT = 'text'
    TYPE_INT = 'int'
    TYPE_FLOAT = 'float'
    TYPE_TIMESPAN = 'timespan'

    FIELD_TO_TYPE = {
        'content': TYPE_TEXT,
        'author': TYPE_MEMBER,
        'author.id': TYPE_ID,
        'author.name': TYPE_TEXT,
        'author.joinage': TYPE_TIMESPAN,
    }

    ACTION_TO_ARGS = {
        'ban': [],
        'delete': [],
        'kick': [],
        'modwarn': [],
        'reply': [ TYPE_TEXT ],
    }

    OPERATORS_IDENTITY = { '==', '!=' }
    OPERATORS_COMPARISON = { '<', '>', '<=', '>=' }
    OPERATORS_NUMERIC = OPERATORS_IDENTITY | OPERATORS_COMPARISON
    OPERATORS_TEXT = OPERATORS_IDENTITY | { 'contains', '!contains', 'matches', '!matches' }
    OPERATORS_ALL = OPERATORS_IDENTITY | OPERATORS_COMPARISON | OPERATORS_TEXT
    OPERATORS_REGEX = { 'matches', '!matches' }

    TYPE_TO_OPERATORS = {
        TYPE_ID: OPERATORS_IDENTITY,
        TYPE_MEMBER: OPERATORS_IDENTITY,
        TYPE_TEXT: OPERATORS_TEXT,
        TYPE_INT: OPERATORS_NUMERIC,
        TYPE_FLOAT: OPERATORS_NUMERIC,
        TYPE_TIMESPAN: OPERATORS_NUMERIC,
    }

    WHITESPACE_CHARS = ' \t\n\r'
    STRING_QUOTE_CHARS = '\'"'
    SYMBOL_CHARS = 'abcdefghijklmnopqrstuvwxyz.'
    VALUE_CHARS = '0123456789dhms<@!>'
    OP_CHARS = '<=>!(),'

    # Value formats used by parse_value, compiled once.
    ID_PATTERN = re.compile(r'^[0-9]+$')
    MEMBER_PATTERN = re.compile(r'^<@!?([0-9]+)>$')
    TIMESPAN_PATTERN = re.compile(r'^(?:[0-9]+[dhms])+$')
    TIMESPAN_PART_PATTERN = re.compile(r'([0-9]+)([dhms])')

    @classmethod
    def expression_str_from_context(cls, context: commands.Context, name: str) -> str:
        """Return the raw text of the invoking message after the command chain and *name*.

        Strips the prefixed root command, each subcommand name, and the
        pattern name, in order, along with whitespace after each.
        """
        pattern_str = context.message.content
        command_chain = [ name ]
        cmd = context.command
        while cmd:
            command_chain.insert(0, cmd.name)
            cmd = cmd.parent
        command_chain[0] = f'{context.prefix}{command_chain[0]}'
        for cmd in command_chain:
            if pattern_str.startswith(cmd):
                pattern_str = pattern_str[len(cmd):].lstrip()
        return pattern_str

    @classmethod
    def parse_statement(cls, name: str, statement: str) -> PatternStatement:
        """Compile *statement* into a PatternStatement named *name*.

        Raises:
            RuntimeError / ValueError: on any syntax or value error.
        """
        tokens = cls.tokenize(statement)
        actions, token_index = cls.read_actions(tokens, 0)
        expression, token_index = cls.read_expression(tokens, token_index)
        if token_index < len(tokens):
            # At the top level read_expression only stops early at a ')'
            # with no matching '('; don't silently drop the rest.
            raise RuntimeError(f'Unexpected {tokens[token_index]}')
        return PatternStatement(name, actions, expression, statement)

    @classmethod
    def tokenize(cls, statement: str) -> list:
        """Split *statement* into tokens.

        Quoted strings (with \\n, \\t and \\<char> escapes) become one token
        that keeps its quotes. Other tokens break on whitespace or when the
        character class changes (symbol / value / operator). Leading '!', '('
        and ',' and trailing ')' and ',' are then split into their own tokens.

        Raises:
            RuntimeError: on a backslash outside a string or an unterminated string.
        """
        tokens = []
        in_quote = False  # False, or the quote character that opened the current string
        in_escape = False
        all_token_types = { 'sym', 'op', 'val' }
        possible_token_types = set(all_token_types)
        current_token = ''
        for ch in statement:
            if in_quote:
                if in_escape:
                    if ch == 'n':
                        current_token += '\n'
                    elif ch == 't':
                        current_token += '\t'
                    else:
                        current_token += ch
                    in_escape = False
                elif ch == '\\':
                    in_escape = True
                elif ch == in_quote:
                    current_token += ch
                    tokens.append(current_token)
                    current_token = ''
                    possible_token_types |= all_token_types
                    in_quote = False
                else:
                    current_token += ch
            elif ch in cls.STRING_QUOTE_CHARS:
                if current_token:
                    tokens.append(current_token)
                possible_token_types |= all_token_types
                in_quote = ch
                current_token = ch
            elif ch == '\\':
                raise RuntimeError("Unexpected \\")
            elif ch in cls.WHITESPACE_CHARS:
                if current_token:
                    tokens.append(current_token)
                    current_token = ''
                possible_token_types |= all_token_types
            else:
                possible_ch_types = set()
                if ch in cls.SYMBOL_CHARS:
                    possible_ch_types.add('sym')
                if ch in cls.VALUE_CHARS:
                    possible_ch_types.add('val')
                if ch in cls.OP_CHARS:
                    possible_ch_types.add('op')
                if current_token and possible_ch_types.isdisjoint(possible_token_types):
                    # Character class changed: close the current token.
                    tokens.append(current_token)
                    current_token = ''
                    possible_token_types |= all_token_types
                possible_token_types &= possible_ch_types
                current_token += ch
        if in_quote:
            # Otherwise an escaped closing quote could yield a token that
            # merely looks like a complete quoted string.
            raise RuntimeError(f'Unterminated string {current_token}')
        if current_token:
            tokens.append(current_token)

        # Some symbols might be glommed onto other tokens. Split 'em up.
        prefixes_to_split = [ '!', '(', ',' ]
        suffixes_to_split = [ ')', ',' ]
        i = 0
        while i < len(tokens):
            token = tokens[i]
            mutated = False
            for prefix in prefixes_to_split:
                if token.startswith(prefix) and len(token) > len(prefix):
                    tokens.insert(i, prefix)
                    tokens[i + 1] = token[len(prefix):]
                    i += 1
                    mutated = True
                    break
            if mutated:
                continue
            for suffix in suffixes_to_split:
                if token.endswith(suffix) and len(token) > len(suffix):
                    tokens[i] = token[0:-len(suffix)]
                    tokens.insert(i + 1, suffix)
                    mutated = True
                    break
            if mutated:
                continue
            i += 1
        return tokens

    @classmethod
    def read_actions(cls, tokens: list, token_index: int) -> tuple:
        """Read comma-separated actions up to and including 'if'.

        Returns:
            (list of validated PatternAction, index of the token after 'if').

        Raises:
            RuntimeError: on a stray ',', an invalid action, or no 'if'.
        """
        actions = []
        current_action_tokens = []
        while token_index < len(tokens):
            token = tokens[token_index]
            if token == 'if':
                if len(current_action_tokens) > 0:
                    a = PatternAction(current_action_tokens[0], current_action_tokens[1:])
                    cls.__validate_action(a)
                    actions.append(a)
                token_index += 1
                return (actions, token_index)
            elif token == ',':
                if len(current_action_tokens) < 1:
                    raise RuntimeError('Unexpected ,')
                a = PatternAction(current_action_tokens[0], current_action_tokens[1:])
                cls.__validate_action(a)
                actions.append(a)
                current_action_tokens = []
            else:
                current_action_tokens.append(token)
            token_index += 1
        raise RuntimeError('Unexpected end of line')

    @classmethod
    def __validate_action(cls, action: PatternAction) -> None:
        """Check the action's name and arity, and parse its arguments in place.

        Raises:
            RuntimeError: on an unknown action or wrong argument count.
            ValueError: if an argument doesn't parse as its declared type.
        """
        args = cls.ACTION_TO_ARGS.get(action.type)
        if args is None:
            raise RuntimeError(f'Unknown action "{action.type}"')
        if len(action.arguments) != len(args):
            if len(args) == 0:
                raise RuntimeError(f'Action "{action.type}" expects no arguments, got {len(action.arguments)}.')
            raise RuntimeError(f'Action "{action.type}" expects {len(args)} arguments, got {len(action.arguments)}.')
        for i, datatype in enumerate(args):
            action.arguments[i] = cls.parse_value(action.arguments[i], datatype)

    @classmethod
    def read_expression(cls, tokens: list, token_index: int, depth: int = 0, one_subexpression: bool = False) -> tuple:
        """Read an expression starting at *token_index*.

        Accepted forms:
            field op value
            (field op value)
            !(field op value)
            field op value and field op value
            (field op value and field op value) or field op value

        Stops at an unmatched ')' (left unconsumed) or at end of tokens. With
        one_subexpression=True (used for '!') it stops after one operand.
        *depth* is the nesting level, passed down to nested reads.

        Returns:
            (PatternExpression, index of the first unconsumed token).

        Raises:
            RuntimeError: on malformed input.
        """
        subexpressions = []
        last_compound_operator = None
        while token_index < len(tokens):
            if one_subexpression and subexpressions:
                # '!' applies to exactly one operand.
                return (subexpressions[0], token_index)
            token = tokens[token_index]
            if token == ')':
                break
            if subexpressions:
                # Between operands there must be a connective.
                if token not in ('and', 'or'):
                    raise RuntimeError(f'Expected "and" or "or", found "{token}"')
                if last_compound_operator and token != last_compound_operator:
                    # Switching connectives groups everything so far (left to
                    # right, no precedence).
                    subexpressions = [ PatternCompoundExpression(last_compound_operator, subexpressions) ]
                last_compound_operator = token
                token_index += 1
                if token_index >= len(tokens):
                    raise RuntimeError(f'Expected expression after "{token}", found EOL')
                token = tokens[token_index]
            if token == '!':
                (exp, next_index) = cls.read_expression(tokens, token_index + 1, depth + 1, one_subexpression=True)
                subexpressions.append(PatternCompoundExpression('!', [exp]))
                token_index = next_index
            elif token == '(':
                (exp, next_index) = cls.read_expression(tokens, token_index + 1, depth + 1)
                if next_index >= len(tokens) or tokens[next_index] != ')':
                    raise RuntimeError('Expected )')
                subexpressions.append(exp)
                token_index = next_index + 1
            else:
                (simple, next_index) = cls.read_simple_expression(tokens, token_index, depth)
                subexpressions.append(simple)
                token_index = next_index
        if len(subexpressions) == 0:
            raise RuntimeError('No subexpressions')
        if len(subexpressions) == 1:
            return (subexpressions[0], token_index)
        return (PatternCompoundExpression(last_compound_operator, subexpressions), token_index)

    @classmethod
    def read_simple_expression(cls, tokens: list, token_index: int, depth: int = 0) -> tuple:
        """Read one `field op value` term.

        A '!' token followed by an operator word (e.g. '!', 'contains') is
        joined into a negated operator. For regex operators the value is
        compile-checked here so a bad pattern is rejected at save time.
        *depth* is unused and kept for interface compatibility.

        Returns:
            (PatternSimpleExpression, index of the next token).

        Raises:
            RuntimeError / ValueError: on unknown fields, disallowed
            operators, bad values or premature end of input.
        """
        if token_index >= len(tokens):
            raise RuntimeError('Expected field name, found EOL')
        field = tokens[token_index]
        token_index += 1

        datatype = cls.FIELD_TO_TYPE.get(field)
        if datatype is None:
            raise RuntimeError(f'No such field "{field}"')

        if token_index >= len(tokens):
            raise RuntimeError('Expected operator, found EOL')
        op = tokens[token_index]
        token_index += 1

        if op == '!':
            if token_index >= len(tokens):
                raise RuntimeError('Expected operator, found EOL')
            op = '!' + tokens[token_index]
            token_index += 1

        allowed_ops = cls.TYPE_TO_OPERATORS[datatype]
        if op not in allowed_ops:
            if op in cls.OPERATORS_ALL:
                raise RuntimeError(f'Operator {op} cannot be used with field "{field}"')
            raise RuntimeError(f'Unrecognized operator "{op}" - allowed: {list(allowed_ops)}')

        if token_index >= len(tokens):
            raise RuntimeError('Expected value, found EOL')
        value = cls.parse_value(tokens[token_index], datatype)
        token_index += 1

        if op in cls.OPERATORS_REGEX:
            try:
                re.compile(value)
            except re.error as e:
                raise RuntimeError(f'Invalid regular expression "{value}": {e}') from e

        exp = PatternSimpleExpression(field, op, value)
        return (exp, token_index)

    @classmethod
    def parse_value(cls, value: str, type: str):
        """Convert a token to a value of the given TYPE_* datatype.

        id -> str of digits; Member -> the mentioned user id as str; text ->
        contents of a quoted string; int/float -> number; timespan ->
        timedelta from units d/h/m/s (repeated units add up).

        Raises:
            ValueError: for malformed id/member/text/int/float values or an
                unknown datatype.
            RuntimeError: for a malformed timespan.
        """
        if type == cls.TYPE_ID:
            if cls.ID_PATTERN.match(value) is None:
                raise ValueError(f'Illegal id value "{value}"')
            # Store it as a str so it can be larger than an int
            return value
        if type == cls.TYPE_MEMBER:
            m = cls.MEMBER_PATTERN.match(value)
            if m is None:
                raise ValueError('Illegal member value. Must be an @ mention.')
            return m.group(1)
        if type == cls.TYPE_TEXT:
            # Must be quoted.
            if len(value) < 2 or \
                    value[0:1] not in cls.STRING_QUOTE_CHARS or \
                    value[-1:] not in cls.STRING_QUOTE_CHARS or \
                    value[0:1] != value[-1:]:
                raise ValueError(f'Not a quoted string value: {value}')
            return value[1:-1]
        if type == cls.TYPE_INT:
            return int(value)
        if type == cls.TYPE_FLOAT:
            return float(value)
        if type == cls.TYPE_TIMESPAN:
            if cls.TIMESPAN_PATTERN.match(value) is None:
                raise RuntimeError(f"Illegal timespan value \"{value}\". Must be like \"100d\", \"5m30s\", etc.")
            totals = { 'd': 0, 'h': 0, 'm': 0, 's': 0 }
            for m in cls.TIMESPAN_PART_PATTERN.finditer(value):
                totals[m.group(2)] += int(m.group(1))
            return timedelta(days=totals['d'], hours=totals['h'],
                             minutes=totals['m'], seconds=totals['s'])
        raise ValueError(f'Unhandled datatype {type}')
|