"""
Statements that match messages based on an expression and have a list of
actions to take on them.
"""
import re
from abc import ABCMeta, abstractmethod
from datetime import datetime, timezone
from typing import Any, Literal, Optional, Union

from discord import Message
from discord import utils as discordutils
from discord.ext.commands import Context

from rocketbot.utils import (
    is_user_id,
    str_from_quoted_str,
    timedelta_from_str,
    user_id_from_mention,
)

PatternField = Literal['content.markdown', 'content', 'content.plain',
                       'author', 'author.id', 'author.joinage', 'author.name',
                       'lastmatched']
PatternComparisonOperator = Literal['==', '!=', '<', '>', '<=', '>=',
                                    'contains', '!contains',
                                    'matches', '!matches',
                                    'containsword', '!containsword']
PatternBooleanOperator = Literal['!', 'and', 'or']
PatternActionType = Literal['ban', 'delete', 'kick', 'modinfo', 'modwarn',
                            'reply']


class PatternError(RuntimeError):
    """
    Error thrown when parsing a pattern statement.
    """


class PatternDeprecationError(PatternError):
    """
    Error raised by PatternStatement.check_deprecations when a statement uses
    deprecated syntax.
    """


class PatternAction:
    """
    Describes one action to take on a matched message or its author.
    """

    TYPE_BAN: PatternActionType = 'ban'
    TYPE_DELETE: PatternActionType = 'delete'
    TYPE_KICK: PatternActionType = 'kick'
    TYPE_INFORM_MODS: PatternActionType = 'modinfo'
    TYPE_WARN_MODS: PatternActionType = 'modwarn'
    TYPE_REPLY: PatternActionType = 'reply'

    def __init__(self, action: str, args: list[Any]):
        self.action = action
        # Copied so later in-place argument parsing does not mutate the
        # caller's list.
        self.arguments = list(args)

    def __str__(self) -> str:
        # Arguments are typed as Any, so stringify each one before joining.
        arg_str = ', '.join(str(arg) for arg in self.arguments)
        return f'{self.action}({arg_str})'


class PatternExpression(metaclass=ABCMeta):
    """
    Abstract message matching expression.
    """

    def __init__(self):
        pass

    @abstractmethod
    def matches(self, message: Message, other_fields: dict[str, Any]) -> bool:
        """
        Whether a message matches this expression. other_fields are
        additional fields that can be queried which are not contained in the
        message itself.
        """
        return False


class PatternSimpleExpression(PatternExpression):
    """
    Message matching expression with a simple "<field> <operator> <value>"
    structure.
    """

    FIELD_CONTENT_MARKDOWN: PatternField = 'content.markdown'
    FIELD_CONTENT_PLAIN: PatternField = 'content.plain'
    FIELD_AUTHOR_ID: PatternField = 'author.id'
    FIELD_AUTHOR_JOINAGE: PatternField = 'author.joinage'
    FIELD_AUTHOR_NAME: PatternField = 'author.name'
    FIELD_LAST_MATCHED: PatternField = 'lastmatched'

    # Less preferred but recognized field aliases
    ALIAS_FIELD_CONTENT_MARKDOWN: PatternField = 'content'
    ALIAS_FIELD_AUTHOR_ID: PatternField = 'author'

    OP_EQUALS: PatternComparisonOperator = '=='
    OP_NOT_EQUALS: PatternComparisonOperator = '!='
    OP_LESS_THAN: PatternComparisonOperator = '<'
    OP_GREATER_THAN: PatternComparisonOperator = '>'
    OP_LESS_THAN_OR_EQUALS: PatternComparisonOperator = '<='
    OP_GREATER_THAN_OR_EQUALS: PatternComparisonOperator = '>='
    OP_CONTAINS: PatternComparisonOperator = 'contains'
    OP_NOT_CONTAINS: PatternComparisonOperator = '!contains'
    OP_MATCHES: PatternComparisonOperator = 'matches'
    OP_NOT_MATCHES: PatternComparisonOperator = '!matches'
    OP_CONTAINS_WORD: PatternComparisonOperator = 'containsword'
    OP_NOT_CONTAINS_WORD: PatternComparisonOperator = '!containsword'

    def __init__(self, field: PatternField, operator: PatternComparisonOperator,
                 value: Any):
        super().__init__()
        self.field: PatternField = field
        self.operator: PatternComparisonOperator = operator
        # Already converted by PatternCompiler: a str, a compiled regex (for
        # matches/containsword operators), a timespan, etc.
        self.value: Any = value

    def __field_value(self, message: Message, other_fields: dict[str, Any]) -> Any:
        """
        Returns the value of this expression's field for the given message.

        Content and name fields are strings; author.id is the author's id
        converted to a string; author.joinage and lastmatched are the time
        elapsed between another moment and message.created_at.
        Raises ValueError for an unknown field name.
        """
        cls = PatternSimpleExpression
        if self.field in (cls.FIELD_CONTENT_MARKDOWN, cls.ALIAS_FIELD_CONTENT_MARKDOWN):
            return message.content
        if self.field == cls.FIELD_CONTENT_PLAIN:
            return discordutils.remove_markdown(message.clean_content)
        if self.field in (cls.FIELD_AUTHOR_ID, cls.ALIAS_FIELD_AUTHOR_ID):
            return str(message.author.id)
        if self.field == cls.FIELD_AUTHOR_JOINAGE:
            # NOTE(review): discord.py documents Member.joined_at as possibly
            # None, and a plain User has no joined_at at all; either case
            # raises here — confirm how callers handle that.
            return message.created_at - message.author.joined_at
        if self.field == cls.FIELD_AUTHOR_NAME:
            return message.author.name
        if self.field == cls.FIELD_LAST_MATCHED:
            # A statement that has never matched is treated as having last
            # matched a very long time ago.
            long_ago = datetime(year=1900, month=1, day=1, hour=0, minute=0,
                                second=0, tzinfo=timezone.utc)
            last_matched = other_fields.get('last_matched') or long_ago
            return message.created_at - last_matched
        raise ValueError(f'Bad field name "{self.field}"')

    def matches(self, message: Message, other_fields: dict[str, Any]) -> bool:
        """
        Evaluates "<field> <operator> <value>" against the message.

        Equality between two strings is case-insensitive, and the
        contains/matches/containsword operators lowercase the field value
        before testing it. Raises ValueError for an unknown operator.
        """
        cls = PatternSimpleExpression
        field_value = self.__field_value(message, other_fields)
        if self.operator == cls.OP_EQUALS:
            if isinstance(field_value, str) and isinstance(self.value, str):
                return field_value.lower() == self.value.lower()
            return field_value == self.value
        if self.operator == cls.OP_NOT_EQUALS:
            if isinstance(field_value, str) and isinstance(self.value, str):
                return field_value.lower() != self.value.lower()
            return field_value != self.value
        if self.operator == cls.OP_LESS_THAN:
            return field_value < self.value
        if self.operator == cls.OP_GREATER_THAN:
            return field_value > self.value
        if self.operator == cls.OP_LESS_THAN_OR_EQUALS:
            return field_value <= self.value
        if self.operator == cls.OP_GREATER_THAN_OR_EQUALS:
            return field_value >= self.value
        if self.operator == cls.OP_CONTAINS:
            return self.value.lower() in field_value.lower()
        if self.operator == cls.OP_NOT_CONTAINS:
            return self.value.lower() not in field_value.lower()
        # For these operators self.value is a regex compiled by
        # PatternCompiler from the lowercased pattern text.
        if self.operator in (cls.OP_MATCHES, cls.OP_CONTAINS_WORD):
            return self.value.search(field_value.lower()) is not None
        if self.operator in (cls.OP_NOT_MATCHES, cls.OP_NOT_CONTAINS_WORD):
            return self.value.search(field_value.lower()) is None
        raise ValueError(f'Bad operator {self.operator}')

    def __str__(self) -> str:
        return f'({self.field} {self.operator} {self.value})'


class PatternCompoundExpression(PatternExpression):
    """
    Message matching expression that combines several child expressions with
    a boolean operator.
    """

    OP_NOT: PatternBooleanOperator = '!'
    OP_AND: PatternBooleanOperator = 'and'
    OP_OR: PatternBooleanOperator = 'or'

    def __init__(self, operator: PatternBooleanOperator,
                 operands: list[PatternExpression]):
        super().__init__()
        self.operator: PatternBooleanOperator = operator
        self.operands = list(operands)

    def matches(self, message: Message, other_fields: dict[str, Any]) -> bool:
        """
        Evaluates the operands with this expression's boolean operator.

        "!" negates the first operand only. "and"/"or" short-circuit. Raises
        ValueError for an unknown operator.
        """
        cls = PatternCompoundExpression
        if self.operator == cls.OP_NOT:
            return not self.operands[0].matches(message, other_fields)
        if self.operator == cls.OP_AND:
            return all(op.matches(message, other_fields) for op in self.operands)
        if self.operator == cls.OP_OR:
            return any(op.matches(message, other_fields) for op in self.operands)
        raise ValueError(f'Bad operator "{self.operator}"')

    def __str__(self) -> str:
        if self.operator == PatternCompoundExpression.OP_NOT:
            return f'(!( {self.operands[0]} ))'
        joined = f' {self.operator} '.join(map(str, self.operands))
        return f'( {joined} )'


class PatternStatement:
    """
    A full message match statement. If a message matches the given expression,
    the given actions should be performed.
    """

    DEFAULT_PRIORITY: int = 100

    def __init__(self,
                 name: str,
                 actions: list[PatternAction],
                 expression: PatternExpression,
                 original: str,
                 priority: int = DEFAULT_PRIORITY):
        self.name: str = name
        self.actions: list[PatternAction] = list(actions)
        self.expression: PatternExpression = expression
        # The unparsed statement text; this is what to_json persists.
        self.original: str = original
        self.priority: int = priority

    def check_deprecations(self) -> None:
        """
        Tests whether this statement uses any deprecated syntax. Will raise a
        PatternDeprecationError if one is found.
        """
        self.__check_deprecations(self.expression)

    @classmethod
    def __check_deprecations(cls, expression: PatternExpression) -> None:
        """
        Recursively checks an expression tree for fields listed in
        PatternCompiler.DEPRECATED_FIELDS.
        """
        if isinstance(expression, PatternSimpleExpression):
            if expression.field in PatternCompiler.DEPRECATED_FIELDS:
                raise PatternDeprecationError(f'"{expression.field}" field is deprecated')
        elif isinstance(expression, PatternCompoundExpression):
            for oper in expression.operands:
                cls.__check_deprecations(oper)

    def to_json(self) -> dict[str, Any]:
        """
        Returns a JSON representation of this statement.
        """
        return {
            'name': self.name,
            'priority': self.priority,
            'statement': self.original,
        }

    @classmethod
    def from_json(cls, json: dict[str, Any]) -> 'PatternStatement':
        """
        Gets a PatternStatement from its JSON representation. The statement
        text is re-parsed, so this raises PatternError if it is malformed.
        """
        ps = PatternCompiler.parse_statement(json['name'], json['statement'])
        ps.priority = json.get('priority', cls.DEFAULT_PRIORITY)
        return ps


class PatternCompiler:
    """
    Parses a user-provided message filter statement into a PatternStatement.
    """

    DATATYPE_FLOAT: str = 'float'
    DATATYPE_ID: str = 'id'
    DATATYPE_INT: str = 'int'
    DATATYPE_MEMBER: str = 'Member'
    DATATYPE_REGEX: str = 'regex'
    DATATYPE_TEXT: str = 'text'
    DATATYPE_TIMESPAN: str = 'timespan'

    FIELD_TO_DATATYPE: dict[PatternField, str] = {
        PatternSimpleExpression.ALIAS_FIELD_AUTHOR_ID: DATATYPE_MEMBER,
        PatternSimpleExpression.FIELD_AUTHOR_ID: DATATYPE_ID,
        PatternSimpleExpression.FIELD_AUTHOR_JOINAGE: DATATYPE_TIMESPAN,
        PatternSimpleExpression.FIELD_AUTHOR_NAME: DATATYPE_TEXT,
        # deprecated, use content.markdown or content.plain
        PatternSimpleExpression.ALIAS_FIELD_CONTENT_MARKDOWN: DATATYPE_TEXT,
        PatternSimpleExpression.FIELD_CONTENT_MARKDOWN: DATATYPE_TEXT,
        PatternSimpleExpression.FIELD_CONTENT_PLAIN: DATATYPE_TEXT,
        PatternSimpleExpression.FIELD_LAST_MATCHED: DATATYPE_TIMESPAN,
    }
    DEPRECATED_FIELDS: set[PatternField] = {'content'}

    # Datatypes of the arguments each action expects, in order.
    ACTION_TO_ARGS: dict[PatternActionType, list[str]] = {
        PatternAction.TYPE_BAN: [],
        PatternAction.TYPE_DELETE: [],
        PatternAction.TYPE_KICK: [],
        PatternAction.TYPE_INFORM_MODS: [],
        PatternAction.TYPE_WARN_MODS: [],
        PatternAction.TYPE_REPLY: [DATATYPE_TEXT],
    }

    OPERATORS_IDENTITY: set[PatternComparisonOperator] = {
        PatternSimpleExpression.OP_EQUALS,
        PatternSimpleExpression.OP_NOT_EQUALS,
    }
    OPERATORS_COMPARISON: set[PatternComparisonOperator] = {
        PatternSimpleExpression.OP_LESS_THAN,
        PatternSimpleExpression.OP_GREATER_THAN,
        PatternSimpleExpression.OP_LESS_THAN_OR_EQUALS,
        PatternSimpleExpression.OP_GREATER_THAN_OR_EQUALS,
    }
    OPERATORS_NUMERIC: set[PatternComparisonOperator] = \
        OPERATORS_IDENTITY | OPERATORS_COMPARISON
    OPERATORS_TEXT: set[PatternComparisonOperator] = OPERATORS_IDENTITY | {
        PatternSimpleExpression.OP_CONTAINS,
        PatternSimpleExpression.OP_NOT_CONTAINS,
        PatternSimpleExpression.OP_CONTAINS_WORD,
        PatternSimpleExpression.OP_NOT_CONTAINS_WORD,
        PatternSimpleExpression.OP_MATCHES,
        PatternSimpleExpression.OP_NOT_MATCHES,
    }
    OPERATORS_ALL: set[str] = OPERATORS_IDENTITY | OPERATORS_COMPARISON | OPERATORS_TEXT

    DATATYPE_TO_OPERATORS: dict[str, set[PatternComparisonOperator]] = {
        DATATYPE_ID: OPERATORS_IDENTITY,
        DATATYPE_MEMBER: OPERATORS_IDENTITY,
        DATATYPE_TEXT: OPERATORS_TEXT,
        DATATYPE_INT: OPERATORS_NUMERIC,
        DATATYPE_FLOAT: OPERATORS_NUMERIC,
        DATATYPE_TIMESPAN: OPERATORS_NUMERIC,
    }

    # Character classes used by the tokenizer. A run of unquoted characters
    # stays one token as long as every character shares at least one class.
    WHITESPACE_CHARS: str = ' \t\n\r'
    STRING_QUOTE_CHARS: str = '\'"'
    SYMBOL_CHARS: str = 'abcdefghijklmnopqrstuvwxyz.'
    VALUE_CHARS: str = '0123456789dhms<@!>'
    OP_CHARS: str = '<=>!(),'

    MAX_EXPRESSION_NESTING: int = 8

    @classmethod
    def expression_str_from_context(cls, context: Context, name: str) -> str:
        """
        Extracts the statement string from an "add" command context by
        stripping the prefixed command chain and the (optionally
        double-quoted) pattern name from the start of the message.
        """
        pattern_str: str = context.message.content
        command_chain = [name]
        cmd = context.command
        while cmd:
            command_chain.insert(0, cmd.name)
            cmd = cmd.parent
        command_chain[0] = f'{context.prefix}{command_chain[0]}'
        for cmd in command_chain:
            if pattern_str.startswith(cmd):
                pattern_str = pattern_str[len(cmd):].lstrip()
            elif pattern_str.startswith(f'"{cmd}"'):
                pattern_str = pattern_str[len(cmd) + 2:].lstrip()
        return pattern_str

    @classmethod
    def parse_statement(cls, name: str, statement: str) -> PatternStatement:
        """
        Parses a user-provided message filter statement into a
        PatternStatement. Raises PatternError on failure, including when
        tokens (such as an unbalanced ")") are left over after the expression.
        """
        tokens: list[str] = cls.__tokenize(statement)
        token_index: int = 0
        actions, token_index = cls.__read_actions(tokens, token_index)
        expression, token_index = cls.__read_expression(tokens, token_index)
        if token_index < len(tokens):
            # The top-level expression only stops early on a ")" it has no
            # matching "(" for; don't silently drop the rest of the statement.
            raise PatternError(f'Unexpected "{tokens[token_index]}"')
        return PatternStatement(name, actions, expression, statement)

    @classmethod
    def __tokenize(cls, statement: str) -> list[str]:
        """
        Converts a message filter statement into a list of tokens.

        Quoted strings (with \\n, \\t and \\<char> escapes) become single
        tokens that keep their surrounding quotes. Raises PatternError on a
        backslash outside a quoted string.
        """
        tokens: list[str] = []
        in_quote: Union[bool, str] = False
        in_escape: bool = False
        all_token_types: set[str] = {'sym', 'op', 'val'}
        # Character classes every character of current_token belongs to.
        possible_token_types: set[str] = set(all_token_types)
        current_token: str = ''
        for ch in statement:
            if in_quote:
                if in_escape:
                    if ch == 'n':
                        current_token += '\n'
                    elif ch == 't':
                        current_token += '\t'
                    else:
                        current_token += ch
                    in_escape = False
                elif ch == '\\':
                    in_escape = True
                elif ch == in_quote:
                    current_token += ch
                    tokens.append(current_token)
                    current_token = ''
                    possible_token_types |= all_token_types
                    in_quote = False
                else:
                    current_token += ch
            else:
                if ch in cls.STRING_QUOTE_CHARS:
                    if len(current_token) > 0:
                        tokens.append(current_token)
                        possible_token_types |= all_token_types
                    in_quote = ch
                    current_token = ch
                elif ch == '\\':
                    raise PatternError("Unexpected \\ outside quoted string")
                elif ch in cls.WHITESPACE_CHARS:
                    if len(current_token) > 0:
                        tokens.append(current_token)
                        current_token = ''
                        possible_token_types |= all_token_types
                else:
                    possible_ch_types = set()
                    if ch in cls.SYMBOL_CHARS:
                        possible_ch_types.add('sym')
                    if ch in cls.VALUE_CHARS:
                        possible_ch_types.add('val')
                    if ch in cls.OP_CHARS:
                        possible_ch_types.add('op')
                    if len(current_token) > 0 and \
                            possible_ch_types.isdisjoint(possible_token_types):
                        # This character cannot belong to the current token;
                        # start a new one.
                        tokens.append(current_token)
                        current_token = ''
                        possible_token_types |= all_token_types
                    possible_token_types &= possible_ch_types
                    current_token += ch
        if len(current_token) > 0:
            tokens.append(current_token)

        # Some symbols might be glommed onto other tokens. Split 'em up.
        # ("!=" becomes "!", "="; __read_simple_expression rejoins them.)
        prefixes_to_split = ['!', '(', ',']
        suffixes_to_split = [')', ',']
        i = 0
        while i < len(tokens):
            token = tokens[i]
            mutated = False
            for prefix in prefixes_to_split:
                if token.startswith(prefix) and len(token) > len(prefix):
                    tokens.insert(i, prefix)
                    tokens[i + 1] = token[len(prefix):]
                    i += 1
                    mutated = True
                    break
            if mutated:
                continue
            for suffix in suffixes_to_split:
                if token.endswith(suffix) and len(token) > len(suffix):
                    tokens[i] = token[0:-len(suffix)]
                    tokens.insert(i + 1, suffix)
                    mutated = True
                    break
            if mutated:
                # Re-examine the shortened token; it may have more suffixes.
                continue
            i += 1
        return tokens

    @classmethod
    def __read_actions(cls,
                       tokens: list[str],
                       token_index: int) -> tuple[list[PatternAction], int]:
        """
        Reads the comma-separated actions from a list of statement tokens.
        Returns a tuple containing a list of PatternActions and the token index
        this method left off at (the token after the "if"). Raises
        PatternError if no "if" is found or an action is invalid.
        """
        actions: list[PatternAction] = []
        current_action_tokens: list[str] = []
        while token_index < len(tokens):
            token = tokens[token_index]
            if token == 'if':
                if len(current_action_tokens) > 0:
                    a = PatternAction(current_action_tokens[0], current_action_tokens[1:])
                    cls.__validate_action(a)
                    actions.append(a)
                token_index += 1
                return actions, token_index
            if token == ',':
                if len(current_action_tokens) < 1:
                    raise PatternError('Unexpected ,')
                a = PatternAction(current_action_tokens[0], current_action_tokens[1:])
                cls.__validate_action(a)
                actions.append(a)
                current_action_tokens = []
            else:
                current_action_tokens.append(token)
            token_index += 1
        raise PatternError('Unexpected end of line in action list')

    @classmethod
    def __validate_action(cls, action: PatternAction) -> None:
        """
        Checks an action's name and argument count, then converts its
        arguments in place to their declared datatypes. Raises PatternError
        on any problem.
        """
        args: Optional[list[str]] = cls.ACTION_TO_ARGS.get(action.action)
        if args is None:
            raise PatternError(f'Unknown action "{action.action}"')
        if len(action.arguments) != len(args):
            if len(args) == 0:
                raise PatternError(f'Action "{action.action}" expects no ' + \
                                   f'arguments, got {len(action.arguments)}.')
            raise PatternError(f'Action "{action.action}" expects ' + \
                               f'{len(args)} arguments, got {len(action.arguments)}.')
        for i, datatype in enumerate(args):
            raw_value = action.arguments[i]
            try:
                action.arguments[i] = cls.__parse_value(raw_value, datatype)
            except ValueError as cause:
                # Same wrapping as __read_simple_expression so callers only
                # need to handle PatternError.
                raise PatternError(f'Bad value {raw_value}') from cause

    @classmethod
    def __read_expression(cls,
                          tokens: list[str],
                          token_index: int,
                          depth: int = 0,
                          one_subexpression: bool = False) -> tuple[PatternExpression, int]:
        """
        Reads an expression from a list of statement tokens. Returns a tuple
        containing the PatternExpression and the token index it left off at
        (either the end of the tokens or a ")" that the caller must consume).
        If one_subexpression is True then it will return after reading a
        single expression instead of joining multiples (for reading the
        subject of a NOT expression).

        Mixed "and"/"or" operators group strictly left to right: when the
        operator changes, everything read so far becomes a single operand of
        the new operator. Raises PatternError on malformed input.
        """
        if depth > cls.MAX_EXPRESSION_NESTING:
            # Checked here as well as in __read_simple_expression so runs of
            # "(" or "!" cannot recurse without bound.
            raise PatternError('Expression nests too deeply')
        subexpressions: list[PatternExpression] = []
        last_compound_operator: Optional[PatternBooleanOperator] = None
        while token_index < len(tokens):
            if one_subexpression:
                if len(subexpressions) == 1:
                    return subexpressions[0], token_index
                if len(subexpressions) > 1:
                    raise PatternError('Too many subexpressions')
            token = tokens[token_index]
            if token == ')':
                if len(subexpressions) == 0:
                    raise PatternError('No subexpressions')
                if len(subexpressions) == 1:
                    return subexpressions[0], token_index
                return (PatternCompoundExpression(last_compound_operator, subexpressions),
                        token_index)
            if token in {PatternCompoundExpression.OP_AND,
                         PatternCompoundExpression.OP_OR}:
                compound_operator = token
                if last_compound_operator and \
                        compound_operator != last_compound_operator:
                    subexpressions = [
                        PatternCompoundExpression(last_compound_operator, subexpressions),
                    ]
                last_compound_operator = compound_operator
                token_index += 1
                if token_index >= len(tokens):
                    raise PatternError(
                        f'Expected expression after "{compound_operator}", found EOL')
                token = tokens[token_index]
            elif subexpressions:
                # Two operands in a row would otherwise build a compound
                # expression with no operator that fails at match time.
                raise PatternError(f'Expected "and" or "or" before "{token}"')

            if token == PatternCompoundExpression.OP_NOT:
                (exp, next_index) = cls.__read_expression(tokens, token_index + 1,
                                                          depth + 1,
                                                          one_subexpression=True)
                subexpressions.append(
                    PatternCompoundExpression(PatternCompoundExpression.OP_NOT, [exp]))
                token_index = next_index
            elif token == '(':
                (exp, next_index) = cls.__read_expression(tokens, token_index + 1,
                                                          depth + 1)
                if next_index >= len(tokens) or tokens[next_index] != ')':
                    raise PatternError('Expected )')
                subexpressions.append(exp)
                token_index = next_index + 1
            else:
                (simple, next_index) = cls.__read_simple_expression(tokens,
                                                                    token_index,
                                                                    depth)
                subexpressions.append(simple)
                token_index = next_index
        if len(subexpressions) == 0:
            raise PatternError('No subexpressions')
        if len(subexpressions) == 1:
            return subexpressions[0], token_index
        return PatternCompoundExpression(last_compound_operator, subexpressions), token_index

    @classmethod
    def __read_simple_expression(cls,
                                 tokens: list[str],
                                 token_index: int,
                                 depth: int = 0) -> tuple[PatternExpression, int]:
        """
        Reads a simple expression consisting of a field name, operator, and
        comparison value. Returns a tuple of the PatternSimpleExpression and
        the token index it left off at. Raises PatternError on an unknown
        field, a disallowed operator, a bad value, or premature EOL.
        """
        if depth > cls.MAX_EXPRESSION_NESTING:
            raise PatternError('Expression nests too deeply')
        if token_index >= len(tokens):
            raise PatternError('Expected field name, found EOL')
        field: PatternField = tokens[token_index]
        token_index += 1

        datatype = cls.FIELD_TO_DATATYPE.get(field, None)
        if datatype is None:
            raise PatternError(f'No such field "{field}"')

        if token_index >= len(tokens):
            raise PatternError('Expected operator, found EOL')
        op = tokens[token_index]
        token_index += 1

        if op == PatternCompoundExpression.OP_NOT:
            # The tokenizer splits "!" off negated operators ("!=",
            # "!contains", ...); glue it back on.
            if token_index >= len(tokens):
                raise PatternError('Expected operator, found EOL')
            op = '!' + tokens[token_index]
            token_index += 1

        allowed_ops = cls.DATATYPE_TO_OPERATORS[datatype]
        if op not in allowed_ops:
            if op in cls.OPERATORS_ALL:
                raise PatternError(f'Operator {op} cannot be used with ' + \
                                   f'field "{field}"')
            raise PatternError(f'Unrecognized operator "{op}" - allowed: ' + \
                               f'{sorted(allowed_ops)}')

        if token_index >= len(tokens):
            raise PatternError('Expected value, found EOL')
        value_str = tokens[token_index]
        try:
            value = cls.__parse_value(value_str, datatype, op)
        except ValueError as cause:
            raise PatternError(f'Bad value {value_str}') from cause
        token_index += 1

        exp = PatternSimpleExpression(field, op, value)
        return exp, token_index

    @classmethod
    def __parse_value(cls, value: str, datatype: str, op: Optional[str] = None) -> Any:
        """
        Converts a value token to its Python value. For text values used with
        the matches/containsword operators (op), returns a compiled regex
        built from the lowercased text. Raises ValueError on failure.
        """
        if datatype == cls.DATATYPE_ID:
            if not is_user_id(value):
                raise ValueError(f'Illegal user id value: {value}')
            return value
        if datatype == cls.DATATYPE_MEMBER:
            return user_id_from_mention(value)
        if datatype == cls.DATATYPE_TEXT:
            s = str_from_quoted_str(value)
            if op in (PatternSimpleExpression.OP_MATCHES,
                      PatternSimpleExpression.OP_NOT_MATCHES):
                try:
                    return re.compile(s.lower())
                except re.error as e:
                    raise ValueError(f'Invalid regex: {e}') from e
            if op in (PatternSimpleExpression.OP_CONTAINS_WORD,
                      PatternSimpleExpression.OP_NOT_CONTAINS_WORD):
                try:
                    return re.compile(f'\\b{re.escape(s.lower())}\\b')
                except re.error as e:
                    raise ValueError(f'Invalid regex: {e}') from e
            return s
        if datatype == cls.DATATYPE_INT:
            return int(value)
        if datatype == cls.DATATYPE_FLOAT:
            return float(value)
        if datatype == cls.DATATYPE_TIMESPAN:
            return timedelta_from_str(value)
        raise ValueError(f'Unhandled datatype {datatype}')