| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581 |
- """
- Statements that match messages based on an expression and have a list of actions
- to take on them.
- """
- import re
- from abc import ABCMeta, abstractmethod
- from datetime import datetime
- from typing import Any
-
- from discord import Message, utils as discordutils
- from discord.ext.commands import Context
-
- from rocketbot.utils import is_user_id, str_from_quoted_str, timedelta_from_str, \
- user_id_from_mention
-
class PatternError(RuntimeError):
    """
    Raised when a pattern statement cannot be parsed.
    """
-
class PatternDeprecationError(PatternError):
    """
    Error raised by PatternStatement.check_deprecations when a statement uses
    a deprecated field.
    """
-
class PatternAction:
    """
    Describes one action to take on a matched message or its author.

    ``action`` is the action name and ``arguments`` is a private copy of the
    argument list passed to the constructor.
    """
    def __init__(self, action: str, args: list[Any]):
        self.action = action
        # Copy so later mutation of the caller's list does not affect us.
        self.arguments = list(args)

    def __str__(self) -> str:
        # Arguments are typed as Any, so stringify each one before joining;
        # str.join raises TypeError on non-string items.
        arg_str = ', '.join(map(str, self.arguments))
        return f'{self.action}({arg_str})'
-
class PatternExpression(metaclass=ABCMeta):
    """
    Base class for every message matching expression. Concrete subclasses
    must implement ``matches``.
    """
    def __init__(self):
        """Nothing to set up; exists so subclasses can chain to it."""

    @abstractmethod
    def matches(self, message: Message, other_fields: dict[str, Any]) -> bool:
        """
        Tells whether ``message`` satisfies this expression. ``other_fields``
        carries extra queryable values that are not stored on the message
        itself.
        """
        return False
-
class PatternSimpleExpression(PatternExpression):
    """
    Message matching expression with a simple "<field> <operator> <value>"
    structure.

    ``value`` is the already-parsed comparison value: a compiled regex for
    the ``matches``/``containsword`` operator families, otherwise whatever
    type the field is compared against.
    """
    def __init__(self, field: str, operator: str, value: Any):
        super().__init__()
        self.field: str = field
        self.operator: str = operator
        self.value: Any = value

    def __field_value(self, message: Message, other_fields: dict[str, Any]) -> Any:
        """
        Extracts the value named by ``self.field`` from the message (or from
        ``other_fields`` for ``lastmatched``). Raises ValueError for an
        unknown field name.
        """
        if self.field in ('content.markdown', 'content'):
            return message.content
        if self.field == 'content.plain':
            return discordutils.remove_markdown(message.clean_content)
        if self.field in ('author', 'author.id'):
            # Both fields compare against the author's id as a string.
            return str(message.author.id)
        if self.field == 'author.joinage':
            # NOTE(review): presumably joined_at can be None for authors that
            # are not guild members, which would raise here - confirm.
            return message.created_at - message.author.joined_at
        if self.field == 'author.name':
            return message.author.name
        if self.field == 'lastmatched':
            created_at = message.created_at
            # "Never matched" sentinel. It inherits created_at's tzinfo so the
            # subtraction works whether the timestamp is naive or aware;
            # mixing the two raises TypeError.
            long_ago = datetime(year=1900, month=1, day=1, hour=0, minute=0,
                second=0, tzinfo=created_at.tzinfo)
            last_matched = other_fields.get('last_matched') or long_ago
            return created_at - last_matched
        raise ValueError(f'Bad field name {self.field}')

    def matches(self, message: Message, other_fields: dict[str, Any]) -> bool:
        """
        Evaluates this expression against a message. String comparisons are
        case-insensitive. Raises ValueError for an unknown field or operator.
        """
        field_value = self.__field_value(message, other_fields)
        both_text = isinstance(field_value, str) and isinstance(self.value, str)
        if self.operator == '==':
            if both_text:
                return field_value.lower() == self.value.lower()
            return field_value == self.value
        if self.operator == '!=':
            if both_text:
                return field_value.lower() != self.value.lower()
            return field_value != self.value
        if self.operator == '<':
            return field_value < self.value
        if self.operator == '>':
            return field_value > self.value
        if self.operator == '<=':
            return field_value <= self.value
        if self.operator == '>=':
            return field_value >= self.value
        if self.operator == 'contains':
            return self.value.lower() in field_value.lower()
        if self.operator == '!contains':
            return self.value.lower() not in field_value.lower()
        # For the operators below self.value is a compiled regex; the field
        # text is lowercased before searching.
        if self.operator in ('matches', 'containsword'):
            return self.value.search(field_value.lower()) is not None
        if self.operator in ('!matches', '!containsword'):
            return self.value.search(field_value.lower()) is None
        raise ValueError(f'Bad operator {self.operator}')

    def __str__(self) -> str:
        return f'({self.field} {self.operator} {self.value})'
-
class PatternCompoundExpression(PatternExpression):
    """
    Expression built from child expressions joined by a boolean operator:
    '!' (negates the first operand), 'and', or 'or'.
    """
    def __init__(self, operator: str, operands: list[PatternExpression]):
        super().__init__()
        self.operator = operator
        self.operands = list(operands)

    def matches(self, message: Message, other_fields: dict[str, Any]) -> bool:
        """
        Evaluates the children with short-circuiting. Raises ValueError when
        the operator is not one of '!', 'and', 'or'.
        """
        kind = self.operator
        if kind == '!':
            return not self.operands[0].matches(message, other_fields)
        if kind == 'and':
            return all(child.matches(message, other_fields)
                for child in self.operands)
        if kind == 'or':
            return any(child.matches(message, other_fields)
                for child in self.operands)
        raise ValueError(f'Bad operator "{self.operator}"')

    def __str__(self) -> str:
        if self.operator == '!':
            return f'(!( {self.operands[0]} ))'
        separator = f' {self.operator} '
        body = separator.join(str(child) for child in self.operands)
        return f'( {body} )'
-
class PatternStatement:
    """
    One complete message-matching rule: when a message satisfies the
    expression, the listed actions are the ones to carry out.
    """

    DEFAULT_PRIORITY: int = 100

    def __init__(self,
            name: str,
            actions: list[PatternAction],
            expression: PatternExpression,
            original: str,
            priority: int = DEFAULT_PRIORITY):
        self.name: str = name
        # Private copy of the caller's action list.
        self.actions: list[PatternAction] = list(actions)
        self.expression: PatternExpression = expression
        # The statement text this object was built from; used by to_json.
        self.original: str = original
        self.priority: int = priority

    def check_deprecations(self) -> None:
        """
        Walks this statement's expression looking for deprecated syntax and
        raises PatternDeprecationError on the first hit.
        """
        self.__check_deprecations(self.expression)

    @classmethod
    def __check_deprecations(cls, expression: PatternExpression) -> None:
        # Leaf: a simple expression is deprecated when its field is listed.
        if isinstance(expression, PatternSimpleExpression):
            if expression.field in PatternCompiler.DEPRECATED_FIELDS:
                raise PatternDeprecationError(f'"{expression.field}" field is deprecated')
            return
        # Branch: recurse into every child operand.
        if isinstance(expression, PatternCompoundExpression):
            for child in expression.operands:
                cls.__check_deprecations(child)

    def to_json(self) -> dict[str, Any]:
        """
        Builds the JSON-serializable dict form of this statement.
        """
        return dict(
            name=self.name,
            priority=self.priority,
            statement=self.original,
        )

    @classmethod
    def from_json(cls, json: dict[str, Any]):
        """
        Rebuilds a PatternStatement from the dict produced by to_json. A
        missing 'priority' key falls back to DEFAULT_PRIORITY.
        """
        statement = PatternCompiler.parse_statement(json['name'], json['statement'])
        statement.priority = json.get('priority', cls.DEFAULT_PRIORITY)
        return statement
-
class PatternCompiler:
    """
    Parses a user-provided message filter statement into a PatternStatement.

    A statement is a comma-separated list of actions, the keyword ``if``,
    and then a boolean expression built from "<field> <operator> <value>"
    comparisons joined with ``and``/``or``, negated with ``!``, and grouped
    with parentheses.
    """
    TYPE_FLOAT: str = 'float'
    TYPE_ID: str = 'id'
    TYPE_INT: str = 'int'
    TYPE_MEMBER: str = 'Member'
    TYPE_REGEX: str = 'regex'
    TYPE_TEXT: str = 'text'
    TYPE_TIMESPAN: str = 'timespan'

    # Datatype of each field that may appear in an expression.
    FIELD_TO_TYPE: dict[str, str] = {
        'author': TYPE_MEMBER,
        'author.id': TYPE_ID,
        'author.joinage': TYPE_TIMESPAN,
        'author.name': TYPE_TEXT,
        'content': TYPE_TEXT, # deprecated, use content.markdown or content.plain
        'content.markdown': TYPE_TEXT,
        'content.plain': TYPE_TEXT,
        'lastmatched': TYPE_TIMESPAN,
    }
    DEPRECATED_FIELDS: set[str] = {'content'}

    # Argument datatypes expected by each action, in order.
    ACTION_TO_ARGS: dict[str, list[str]] = {
        'ban': [],
        'delete': [],
        'kick': [],
        'modinfo': [],
        'modwarn': [],
        'reply': [ TYPE_TEXT ],
    }

    OPERATORS_IDENTITY: set[str] = {'==', '!='}
    OPERATORS_COMPARISON: set[str] = {'<', '>', '<=', '>='}
    OPERATORS_NUMERIC: set[str] = OPERATORS_IDENTITY | OPERATORS_COMPARISON
    OPERATORS_TEXT: set[str] = OPERATORS_IDENTITY | {
        'contains', '!contains',
        'containsword', '!containsword',
        'matches', '!matches',
    }
    OPERATORS_ALL: set[str] = OPERATORS_IDENTITY | OPERATORS_COMPARISON | OPERATORS_TEXT

    # Operators permitted for each datatype.
    TYPE_TO_OPERATORS: dict[str, set[str]] = {
        TYPE_ID: OPERATORS_IDENTITY,
        TYPE_MEMBER: OPERATORS_IDENTITY,
        TYPE_TEXT: OPERATORS_TEXT,
        TYPE_INT: OPERATORS_NUMERIC,
        TYPE_FLOAT: OPERATORS_NUMERIC,
        TYPE_TIMESPAN: OPERATORS_NUMERIC,
    }

    # Character classes used by the tokenizer.
    WHITESPACE_CHARS: str = ' \t\n\r'
    STRING_QUOTE_CHARS: str = '\'"'
    SYMBOL_CHARS: str = 'abcdefghijklmnopqrstuvwxyz.'
    VALUE_CHARS: str = '0123456789dhms<@!>'
    OP_CHARS: str = '<=>!(),'

    MAX_EXPRESSION_NESTING: int = 8

    @classmethod
    def expression_str_from_context(cls, context: Context, name: str) -> str:
        """
        Extracts the statement string from an "add" command context by
        stripping the prefixed command chain and the pattern name (bare or
        double-quoted) from the front of the message content.
        """
        pattern_str: str = context.message.content
        # Build [root command, ..., subcommand, name] by walking up parents.
        command_chain = [ name ]
        cmd = context.command
        while cmd:
            command_chain.insert(0, cmd.name)
            cmd = cmd.parent
        command_chain[0] = f'{context.prefix}{command_chain[0]}'
        for part in command_chain:
            if pattern_str.startswith(part):
                pattern_str = pattern_str[len(part):].lstrip()
            elif pattern_str.startswith(f'"{part}"'):
                # +2 accounts for the surrounding double quotes.
                pattern_str = pattern_str[len(part) + 2:].lstrip()
        return pattern_str

    @classmethod
    def parse_statement(cls, name: str, statement: str) -> PatternStatement:
        """
        Parses a user-provided message filter statement into a PatternStatement.
        Raises PatternError on failure, including when tokens are left over
        after the expression (e.g. an unmatched ")").
        """
        tokens: list[str] = cls.__tokenize(statement)
        actions, token_index = cls.__read_actions(tokens, 0)
        expression, token_index = cls.__read_expression(tokens, token_index)
        # The top-level expression only stops early on a stray ")". Reject it
        # rather than silently ignoring the rest of the statement.
        if token_index < len(tokens):
            raise PatternError(f'Unexpected "{tokens[token_index]}"')
        return PatternStatement(name, actions, expression, statement)

    @classmethod
    def __tokenize(cls, statement: str) -> list[str]:
        """
        Converts a message filter statement into a list of tokens. Quoted
        strings become a single token that keeps its surrounding quotes, with
        backslash escapes resolved. Raises PatternError on a backslash
        outside a quoted string.
        """
        tokens: list[str] = []
        quote_char: str = ''  # the opening quote while inside a string, else ''
        in_escape: bool = False
        all_token_types: set[str] = {'sym', 'op', 'val'}
        # Token kinds the characters collected so far could still belong to.
        possible_token_types: set[str] = set(all_token_types)
        current_token: str = ''
        for ch in statement:
            if quote_char:
                if in_escape:
                    if ch == 'n':
                        current_token += '\n'
                    elif ch == 't':
                        current_token += '\t'
                    else:
                        current_token += ch
                    in_escape = False
                elif ch == '\\':
                    in_escape = True
                elif ch == quote_char:
                    current_token += ch
                    tokens.append(current_token)
                    current_token = ''
                    possible_token_types |= all_token_types
                    quote_char = ''
                else:
                    current_token += ch
            elif ch in cls.STRING_QUOTE_CHARS:
                if len(current_token) > 0:
                    tokens.append(current_token)
                    possible_token_types |= all_token_types
                quote_char = ch
                current_token = ch
            elif ch == '\\':
                raise PatternError("Unexpected \\ outside quoted string")
            elif ch in cls.WHITESPACE_CHARS:
                if len(current_token) > 0:
                    tokens.append(current_token)
                    current_token = ''
                    possible_token_types |= all_token_types
            else:
                possible_ch_types = set()
                if ch in cls.SYMBOL_CHARS:
                    possible_ch_types.add('sym')
                if ch in cls.VALUE_CHARS:
                    possible_ch_types.add('val')
                if ch in cls.OP_CHARS:
                    possible_ch_types.add('op')
                # A character that cannot continue the current token's kind
                # ends that token and starts a new one.
                if len(current_token) > 0 and \
                        possible_ch_types.isdisjoint(possible_token_types):
                    tokens.append(current_token)
                    current_token = ''
                    possible_token_types |= all_token_types
                possible_token_types &= possible_ch_types
                current_token += ch
        if len(current_token) > 0:
            tokens.append(current_token)

        # Some symbols might be glommed onto other tokens. Split 'em up.
        # Note this also splits "!=" into "!" and "="; the expression reader
        # glues a leading "!" back onto the operator that follows it.
        prefixes_to_split = [ '!', '(', ',' ]
        suffixes_to_split = [ ')', ',' ]
        i = 0
        while i < len(tokens):
            token = tokens[i]
            mutated = False
            for prefix in prefixes_to_split:
                if token.startswith(prefix) and len(token) > len(prefix):
                    tokens.insert(i, prefix)
                    tokens[i + 1] = token[len(prefix):]
                    # Step past the prefix; the remainder is re-examined.
                    i += 1
                    mutated = True
                    break
            if mutated:
                continue
            for suffix in suffixes_to_split:
                if token.endswith(suffix) and len(token) > len(suffix):
                    tokens[i] = token[0:-len(suffix)]
                    tokens.insert(i + 1, suffix)
                    mutated = True
                    break
            if mutated:
                # Stay on the shortened token so it is re-examined.
                continue
            i += 1
        return tokens

    @classmethod
    def __build_action(cls, action_tokens: list[str]) -> PatternAction:
        """
        Builds and validates one PatternAction from its tokens (the action
        name followed by its arguments).
        """
        action = PatternAction(action_tokens[0], action_tokens[1:])
        cls.__validate_action(action)
        return action

    @classmethod
    def __read_actions(cls,
            tokens: list[str],
            token_index: int) -> tuple[list[PatternAction], int]:
        """
        Reads the actions from a list of statement tokens. Returns a tuple
        containing a list of PatternActions and the token index this method
        left off at (the token after the "if"). Raises PatternError if the
        tokens run out before an "if" is found.
        """
        actions: list[PatternAction] = []
        current_action_tokens: list[str] = []
        while token_index < len(tokens):
            token = tokens[token_index]
            if token == 'if':
                if len(current_action_tokens) > 0:
                    actions.append(cls.__build_action(current_action_tokens))
                return (actions, token_index + 1)
            if token == ',':
                if len(current_action_tokens) < 1:
                    raise PatternError('Unexpected ,')
                actions.append(cls.__build_action(current_action_tokens))
                current_action_tokens = []
            else:
                current_action_tokens.append(token)
            token_index += 1
        raise PatternError('Unexpected end of line in action list')

    @classmethod
    def __validate_action(cls, action: PatternAction) -> None:
        """
        Checks that the action is known and has the expected number of
        arguments, then replaces each argument token with its parsed value.
        Raises PatternError on any problem.
        """
        args = cls.ACTION_TO_ARGS.get(action.action)
        if args is None:
            raise PatternError(f'Unknown action "{action.action}"')
        if len(action.arguments) != len(args):
            if len(args) == 0:
                raise PatternError(f'Action "{action.action}" expects no ' + \
                    f'arguments, got {len(action.arguments)}.')
            raise PatternError(f'Action "{action.action}" expects ' + \
                f'{len(args)} arguments, got {len(action.arguments)}.')
        for i, datatype in enumerate(args):
            raw = action.arguments[i]
            try:
                action.arguments[i] = cls.__parse_value(raw, datatype)
            except ValueError as cause:
                # Keep the documented contract: parse failures surface as
                # PatternError, not a bare ValueError.
                raise PatternError(f'Bad value {raw}') from cause

    @classmethod
    def __combine(cls,
            operator: Any,
            subexpressions: list[PatternExpression]) -> PatternExpression:
        """
        Collapses the subexpressions read so far into one expression: the
        sole subexpression if there is only one, otherwise a compound joined
        by ``operator``. Raises PatternError when there is nothing to
        combine or no operator to combine with.
        """
        if len(subexpressions) == 0:
            raise PatternError('No subexpressions')
        if len(subexpressions) == 1:
            return subexpressions[0]
        if operator is None:
            # Several comparisons with no and/or between them would build a
            # compound that can never be evaluated.
            raise PatternError('Expected "and" or "or" between expressions')
        return PatternCompoundExpression(operator, subexpressions)

    @classmethod
    def __read_expression(cls,
            tokens: list[str],
            token_index: int,
            depth: int = 0,
            one_subexpression: bool = False) -> tuple[PatternExpression, int]:
        """
        Reads an expression from a list of statement tokens. Returns a tuple
        containing the PatternExpression and the token index it left off at.
        If one_subexpression is True then it will return after reading a
        single expression instead of joining multiples (for reading the
        subject of a NOT expression). When a ")" is encountered the returned
        index points at it rather than past it. Operators are applied left
        to right with no precedence. Raises PatternError on malformed input.
        """
        # Bound the recursion here too so that runs of "(" or "!" cannot
        # recurse without limit before any comparison is reached.
        if depth > cls.MAX_EXPRESSION_NESTING:
            raise PatternError('Expression nests too deeply')
        subexpressions: list[PatternExpression] = []
        last_compound_operator = None
        while token_index < len(tokens):
            if one_subexpression:
                if len(subexpressions) == 1:
                    return (subexpressions[0], token_index)
                if len(subexpressions) > 1:
                    raise PatternError('Too many subexpressions')
            token = tokens[token_index]
            if token == ')':
                return (cls.__combine(last_compound_operator, subexpressions),
                    token_index)
            if token in ('and', 'or'):
                if last_compound_operator and token != last_compound_operator:
                    # Operator changed: fold what we have so far into one
                    # operand of the new operator.
                    subexpressions = [
                        PatternCompoundExpression(last_compound_operator,
                            subexpressions),
                    ]
                last_compound_operator = token
                token_index += 1
                if token_index >= len(tokens):
                    raise PatternError(
                        f'Expected expression after "{token}", found EOL')
                token = tokens[token_index]
            if token == '!':
                (exp, next_index) = cls.__read_expression(tokens,
                    token_index + 1, depth + 1, one_subexpression=True)
                subexpressions.append(PatternCompoundExpression('!', [exp]))
                token_index = next_index
            elif token == '(':
                (exp, next_index) = cls.__read_expression(tokens,
                    token_index + 1, depth + 1)
                if next_index >= len(tokens) or tokens[next_index] != ')':
                    raise PatternError('Expected )')
                subexpressions.append(exp)
                token_index = next_index + 1
            else:
                (simple, next_index) = cls.__read_simple_expression(tokens,
                    token_index, depth)
                subexpressions.append(simple)
                token_index = next_index
        return (cls.__combine(last_compound_operator, subexpressions),
            token_index)

    @classmethod
    def __read_simple_expression(cls,
            tokens: list[str],
            token_index: int,
            depth: int = 0) -> tuple[PatternExpression, int]:
        """
        Reads a simple expression consisting of a field name, operator, and
        comparison value. Returns a tuple of the PatternSimpleExpression and
        the token index it left off at. Raises PatternError on an unknown
        field, a disallowed operator, a bad value, or missing tokens.
        """
        if depth > cls.MAX_EXPRESSION_NESTING:
            raise PatternError('Expression nests too deeply')
        if token_index >= len(tokens):
            raise PatternError('Expected field name, found EOL')
        field = tokens[token_index]
        token_index += 1

        datatype = cls.FIELD_TO_TYPE.get(field)
        if datatype is None:
            raise PatternError(f'No such field "{field}"')

        if token_index >= len(tokens):
            raise PatternError('Expected operator, found EOL')
        op = tokens[token_index]
        token_index += 1

        if op == '!':
            # The tokenizer splits a leading "!" off; rejoin it with the
            # operator that follows ("!" + "=" -> "!=").
            if token_index >= len(tokens):
                raise PatternError('Expected operator, found EOL')
            op = '!' + tokens[token_index]
            token_index += 1

        allowed_ops = cls.TYPE_TO_OPERATORS[datatype]
        if op not in allowed_ops:
            if op in cls.OPERATORS_ALL:
                raise PatternError(f'Operator {op} cannot be used with ' + \
                    f'field "{field}"')
            raise PatternError(f'Unrecognized operator "{op}" - allowed: ' + \
                f'{sorted(allowed_ops)}')

        if token_index >= len(tokens):
            raise PatternError('Expected value, found EOL')
        value_str = tokens[token_index]

        try:
            value = cls.__parse_value(value_str, datatype, op)
        except ValueError as cause:
            raise PatternError(f'Bad value {value_str}') from cause

        token_index += 1
        exp = PatternSimpleExpression(field, op, value)
        return (exp, token_index)

    @classmethod
    def __parse_value(cls, value: str, datatype: str, op: str = None) -> Any:
        """
        Converts a value token to its Python value. Raises ValueError on failure.
        For text values the optional operator selects the result: a compiled
        regex for matches/!matches, a compiled whole-word regex for
        containsword/!containsword, otherwise the unquoted string.
        """
        if datatype == cls.TYPE_ID:
            if not is_user_id(value):
                raise ValueError(f'Illegal user id value: {value}')
            return value
        if datatype == cls.TYPE_MEMBER:
            return user_id_from_mention(value)
        if datatype == cls.TYPE_TEXT:
            s = str_from_quoted_str(value)
            # Regexes are compiled from the lowercased text.
            if op in ('matches', '!matches'):
                try:
                    return re.compile(s.lower())
                except re.error as e:
                    raise ValueError(f'Invalid regex: {e}') from e
            if op in ('containsword', '!containsword'):
                try:
                    return re.compile(f'\\b{re.escape(s.lower())}\\b')
                except re.error as e:
                    raise ValueError(f'Invalid regex: {e}') from e
            return s
        if datatype == cls.TYPE_INT:
            return int(value)
        if datatype == cls.TYPE_FLOAT:
            return float(value)
        if datatype == cls.TYPE_TIMESPAN:
            return timedelta_from_str(value)
        raise ValueError(f'Unhandled datatype {datatype}')
|