Experimental Discord bot written in Python
Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

pattern.py 18KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579
  1. """
  2. Statements that match messages based on an expression and have a list of actions
  3. to take on them.
  4. """
  5. import re
  6. from abc import ABCMeta, abstractmethod
  7. from datetime import datetime, timezone
  8. from typing import Any, Union
  9. from discord import Message, utils as discordutils
  10. from discord.ext.commands import Context
  11. from rocketbot.utils import is_user_id, str_from_quoted_str, timedelta_from_str, \
  12. user_id_from_mention
  13. class PatternError(RuntimeError):
  14. """
  15. Error thrown when parsing a pattern statement.
  16. """
  17. class PatternDeprecationError(PatternError):
  18. """
  19. Error raised by PatternStatement.check_deprecated_syntax.
  20. """
  21. class PatternAction:
  22. """
  23. Describes one action to take on a matched message or its author.
  24. """
  25. def __init__(self, action: str, args: list[Any]):
  26. self.action = action
  27. self.arguments = list(args)
  28. def __str__(self) -> str:
  29. arg_str = ', '.join(self.arguments)
  30. return f'{self.action}({arg_str})'
  31. class PatternExpression(metaclass=ABCMeta):
  32. """
  33. Abstract message matching expression.
  34. """
  35. def __init__(self):
  36. pass
  37. @abstractmethod
  38. def matches(self, message: Message, other_fields: dict[str, Any]) -> bool:
  39. """
  40. Whether a message matches this expression. other_fields are additional
  41. fields that can be queried not contained in the message itself.
  42. """
  43. return False
  44. class PatternSimpleExpression(PatternExpression):
  45. """
  46. Message matching expression with a simple "<field> <operator> <value>"
  47. structure.
  48. """
  49. def __init__(self, field: str, operator: str, value: Any):
  50. super().__init__()
  51. self.field: str = field
  52. self.operator: str = operator
  53. self.value: Any = value
  54. def __field_value(self, message: Message, other_fields: dict[str, Any]) -> Any:
  55. if self.field in ('content.markdown', 'content'):
  56. return message.content
  57. if self.field == 'content.plain':
  58. return discordutils.remove_markdown(message.clean_content)
  59. if self.field == 'author':
  60. return str(message.author.id)
  61. if self.field == 'author.id':
  62. return str(message.author.id)
  63. if self.field == 'author.joinage':
  64. return message.created_at - message.author.joined_at
  65. if self.field == 'author.name':
  66. return message.author.name
  67. if self.field == 'lastmatched':
  68. long_ago = datetime(year=1900, month=1, day=1, hour=0, minute=0, second=0, tzinfo=timezone.utc)
  69. last_matched = other_fields.get('last_matched') or long_ago
  70. return message.created_at - last_matched
  71. raise ValueError(f'Bad field name {self.field}')
  72. def matches(self, message: Message, other_fields: dict[str, Any]) -> bool:
  73. field_value = self.__field_value(message, other_fields)
  74. if self.operator == '==':
  75. if isinstance(field_value, str) and isinstance(self.value, str):
  76. return field_value.lower() == self.value.lower()
  77. return field_value == self.value
  78. if self.operator == '!=':
  79. if isinstance(field_value, str) and isinstance(self.value, str):
  80. return field_value.lower() != self.value.lower()
  81. return field_value != self.value
  82. if self.operator == '<':
  83. return field_value < self.value
  84. if self.operator == '>':
  85. return field_value > self.value
  86. if self.operator == '<=':
  87. return field_value <= self.value
  88. if self.operator == '>=':
  89. return field_value >= self.value
  90. if self.operator == 'contains':
  91. return self.value.lower() in field_value.lower()
  92. if self.operator == '!contains':
  93. return self.value.lower() not in field_value.lower()
  94. if self.operator in ('matches', 'containsword'):
  95. return self.value.search(field_value.lower()) is not None
  96. if self.operator in ('!matches', '!containsword'):
  97. return self.value.search(field_value.lower()) is None
  98. raise ValueError(f'Bad operator {self.operator}')
  99. def __str__(self) -> str:
  100. return f'({self.field} {self.operator} {self.value})'
  101. class PatternCompoundExpression(PatternExpression):
  102. """
  103. Message matching expression that combines several child expressions with
  104. a boolean operator.
  105. """
  106. def __init__(self, operator: str, operands: list[PatternExpression]):
  107. super().__init__()
  108. self.operator = operator
  109. self.operands = list(operands)
  110. def matches(self, message: Message, other_fields: dict[str, Any]) -> bool:
  111. if self.operator == '!':
  112. return not self.operands[0].matches(message, other_fields)
  113. if self.operator == 'and':
  114. for op in self.operands:
  115. if not op.matches(message, other_fields):
  116. return False
  117. return True
  118. if self.operator == 'or':
  119. for op in self.operands:
  120. if op.matches(message, other_fields):
  121. return True
  122. return False
  123. raise ValueError(f'Bad operator "{self.operator}"')
  124. def __str__(self) -> str:
  125. if self.operator == '!':
  126. return f'(!( {self.operands[0]} ))'
  127. strs = map(str, self.operands)
  128. joined = f' {self.operator} '.join(strs)
  129. return f'( {joined} )'
  130. class PatternStatement:
  131. """
  132. A full message match statement. If a message matches the given expression,
  133. the given actions should be performed.
  134. """
  135. DEFAULT_PRIORITY: int = 100
  136. def __init__(self,
  137. name: str,
  138. actions: list[PatternAction],
  139. expression: PatternExpression,
  140. original: str,
  141. priority: int = DEFAULT_PRIORITY):
  142. self.name: str = name
  143. self.actions: list[PatternAction] = list(actions) # PatternAction[]
  144. self.expression: PatternExpression = expression
  145. self.original: str = original
  146. self.priority: int = priority
  147. def check_deprecations(self) -> None:
  148. """
  149. Tests whether this statement uses any deprecated syntax. Will raise a
  150. PatternDeprecationError if one is found.
  151. """
  152. self.__check_deprecations(self.expression)
  153. @classmethod
  154. def __check_deprecations(cls, expression: PatternExpression) -> None:
  155. if isinstance(expression, PatternSimpleExpression):
  156. s: PatternSimpleExpression = expression
  157. if s.field in PatternCompiler.DEPRECATED_FIELDS:
  158. raise PatternDeprecationError(f'"{s.field}" field is deprecated')
  159. elif isinstance(expression, PatternCompoundExpression):
  160. c: PatternCompoundExpression = expression
  161. for oper in c.operands:
  162. cls.__check_deprecations(oper)
  163. def to_json(self) -> dict[str, Any]:
  164. """
  165. Returns a JSON representation of this statement.
  166. """
  167. return {
  168. 'name': self.name,
  169. 'priority': self.priority,
  170. 'statement': self.original,
  171. }
  172. @classmethod
  173. def from_json(cls, json: dict[str, Any]):
  174. """
  175. Gets a PatternStatement from its JSON representation.
  176. """
  177. ps = PatternCompiler.parse_statement(json['name'], json['statement'])
  178. ps.priority = json.get('priority', cls.DEFAULT_PRIORITY)
  179. return ps
  180. class PatternCompiler:
  181. """
  182. Parses a user-provided message filter statement into a PatternStatement.
  183. """
  184. TYPE_FLOAT: str = 'float'
  185. TYPE_ID: str = 'id'
  186. TYPE_INT: str = 'int'
  187. TYPE_MEMBER: str = 'Member'
  188. TYPE_REGEX: str = 'regex'
  189. TYPE_TEXT: str = 'text'
  190. TYPE_TIMESPAN: str = 'timespan'
  191. FIELD_TO_TYPE: dict[str, str] = {
  192. 'author': TYPE_MEMBER,
  193. 'author.id': TYPE_ID,
  194. 'author.joinage': TYPE_TIMESPAN,
  195. 'author.name': TYPE_TEXT,
  196. 'content': TYPE_TEXT, # deprecated, use content.markdown or content.plain
  197. 'content.markdown': TYPE_TEXT,
  198. 'content.plain': TYPE_TEXT,
  199. 'lastmatched': TYPE_TIMESPAN,
  200. }
  201. DEPRECATED_FIELDS: set[str] = { 'content' }
  202. ACTION_TO_ARGS: dict[str, list[str]] = {
  203. 'ban': [],
  204. 'delete': [],
  205. 'kick': [],
  206. 'modinfo': [],
  207. 'modwarn': [],
  208. 'reply': [ TYPE_TEXT ],
  209. }
  210. OPERATORS_IDENTITY: set[str] = { '==', '!=' }
  211. OPERATORS_COMPARISON: set[str] = { '<', '>', '<=', '>=' }
  212. OPERATORS_NUMERIC: set[str] = OPERATORS_IDENTITY | OPERATORS_COMPARISON
  213. OPERATORS_TEXT: set[str] = OPERATORS_IDENTITY | {
  214. 'contains', '!contains',
  215. 'containsword', '!containsword',
  216. 'matches', '!matches',
  217. }
  218. OPERATORS_ALL: set[str] = OPERATORS_IDENTITY | OPERATORS_COMPARISON | OPERATORS_TEXT
  219. TYPE_TO_OPERATORS: dict[str, set[str]] = {
  220. TYPE_ID: OPERATORS_IDENTITY,
  221. TYPE_MEMBER: OPERATORS_IDENTITY,
  222. TYPE_TEXT: OPERATORS_TEXT,
  223. TYPE_INT: OPERATORS_NUMERIC,
  224. TYPE_FLOAT: OPERATORS_NUMERIC,
  225. TYPE_TIMESPAN: OPERATORS_NUMERIC,
  226. }
  227. WHITESPACE_CHARS: str = ' \t\n\r'
  228. STRING_QUOTE_CHARS: str = '\'"'
  229. SYMBOL_CHARS: str = 'abcdefghijklmnopqrstuvwxyz.'
  230. VALUE_CHARS: str = '0123456789dhms<@!>'
  231. OP_CHARS: str = '<=>!(),'
  232. MAX_EXPRESSION_NESTING: int = 8
  233. @classmethod
  234. def expression_str_from_context(cls, context: Context, name: str) -> str:
  235. """
  236. Extracts the statement string from an "add" command context.
  237. """
  238. pattern_str: str = context.message.content
  239. command_chain = [ name ]
  240. cmd = context.command
  241. while cmd:
  242. command_chain.insert(0, cmd.name)
  243. cmd = cmd.parent
  244. command_chain[0] = f'{context.prefix}{command_chain[0]}'
  245. for cmd in command_chain:
  246. if pattern_str.startswith(cmd):
  247. pattern_str = pattern_str[len(cmd):].lstrip()
  248. elif pattern_str.startswith(f'"{cmd}"'):
  249. pattern_str = pattern_str[len(cmd) + 2:].lstrip()
  250. return pattern_str
  251. @classmethod
  252. def parse_statement(cls, name: str, statement: str) -> PatternStatement:
  253. """
  254. Parses a user-provided message filter statement into a PatternStatement.
  255. Raises PatternError on failure.
  256. """
  257. tokens: list[str] = cls.__tokenize(statement)
  258. token_index: int = 0
  259. actions, token_index = cls.__read_actions(tokens, token_index)
  260. expression, token_index = cls.__read_expression(tokens, token_index)
  261. return PatternStatement(name, actions, expression, statement)
  262. @classmethod
  263. def __tokenize(cls, statement: str) -> list[str]:
  264. """
  265. Converts a message filter statement into a list of tokens.
  266. """
  267. tokens: list[str] = []
  268. in_quote: Union[bool, str] = False
  269. in_escape: bool = False
  270. all_token_types: set[str] = { 'sym', 'op', 'val' }
  271. possible_token_types: set[str] = set(all_token_types)
  272. current_token: str = ''
  273. for ch in statement:
  274. if in_quote:
  275. if in_escape:
  276. if ch == 'n':
  277. current_token += '\n'
  278. elif ch == 't':
  279. current_token += '\t'
  280. else:
  281. current_token += ch
  282. in_escape = False
  283. elif ch == '\\':
  284. in_escape = True
  285. elif ch == in_quote:
  286. current_token += ch
  287. tokens.append(current_token)
  288. current_token = ''
  289. possible_token_types |= all_token_types
  290. in_quote = False
  291. else:
  292. current_token += ch
  293. else:
  294. if ch in cls.STRING_QUOTE_CHARS:
  295. if len(current_token) > 0:
  296. tokens.append(current_token)
  297. possible_token_types |= all_token_types
  298. in_quote = ch
  299. current_token = ch
  300. elif ch == '\\':
  301. raise PatternError("Unexpected \\ outside quoted string")
  302. elif ch in cls.WHITESPACE_CHARS:
  303. if len(current_token) > 0:
  304. tokens.append(current_token)
  305. current_token = ''
  306. possible_token_types |= all_token_types
  307. else:
  308. possible_ch_types = set()
  309. if ch in cls.SYMBOL_CHARS:
  310. possible_ch_types.add('sym')
  311. if ch in cls.VALUE_CHARS:
  312. possible_ch_types.add('val')
  313. if ch in cls.OP_CHARS:
  314. possible_ch_types.add('op')
  315. if len(current_token) > 0 and \
  316. possible_ch_types.isdisjoint(possible_token_types):
  317. if len(current_token) > 0:
  318. tokens.append(current_token)
  319. current_token = ''
  320. possible_token_types |= all_token_types
  321. possible_token_types &= possible_ch_types
  322. current_token += ch
  323. if len(current_token) > 0:
  324. tokens.append(current_token)
  325. # Some symbols might be glommed onto other tokens. Split 'em up.
  326. prefixes_to_split = [ '!', '(', ',' ]
  327. suffixes_to_split = [ ')', ',' ]
  328. i = 0
  329. while i < len(tokens):
  330. token = tokens[i]
  331. mutated = False
  332. for prefix in prefixes_to_split:
  333. if token.startswith(prefix) and len(token) > len(prefix):
  334. tokens.insert(i, prefix)
  335. tokens[i + 1] = token[len(prefix):]
  336. i += 1
  337. mutated = True
  338. break
  339. if mutated:
  340. continue
  341. for suffix in suffixes_to_split:
  342. if token.endswith(suffix) and len(token) > len(suffix):
  343. tokens[i] = token[0:-len(suffix)]
  344. tokens.insert(i + 1, suffix)
  345. mutated = True
  346. break
  347. if mutated:
  348. continue
  349. i += 1
  350. return tokens
  351. @classmethod
  352. def __read_actions(cls,
  353. tokens: list[str],
  354. token_index: int) -> tuple[list[PatternAction], int]:
  355. """
  356. Reads the actions from a list of statement tokens. Returns a tuple
  357. containing a list of PatternActions and the token index this method
  358. left off at (the token after the "if").
  359. """
  360. actions: list[PatternAction] = []
  361. current_action_tokens = []
  362. while token_index < len(tokens):
  363. token = tokens[token_index]
  364. if token == 'if':
  365. if len(current_action_tokens) > 0:
  366. a = PatternAction(current_action_tokens[0],
  367. current_action_tokens[1:])
  368. cls.__validate_action(a)
  369. actions.append(a)
  370. token_index += 1
  371. return actions, token_index
  372. elif token == ',':
  373. if len(current_action_tokens) < 1:
  374. raise PatternError('Unexpected ,')
  375. a = PatternAction(current_action_tokens[0],
  376. current_action_tokens[1:])
  377. cls.__validate_action(a)
  378. actions.append(a)
  379. current_action_tokens = []
  380. else:
  381. current_action_tokens.append(token)
  382. token_index += 1
  383. raise PatternError('Unexpected end of line in action list')
  384. @classmethod
  385. def __validate_action(cls, action: PatternAction) -> None:
  386. args: list[str] = cls.ACTION_TO_ARGS.get(action.action)
  387. if args is None:
  388. raise PatternError(f'Unknown action "{action.action}"')
  389. if len(action.arguments) != len(args):
  390. if len(args) == 0:
  391. raise PatternError(f'Action "{action.action}" expects no ' + \
  392. f'arguments, got {len(action.arguments)}.')
  393. raise PatternError(f'Action "{action.action}" expects ' + \
  394. f'{len(args)} arguments, got {len(action.arguments)}.')
  395. for i, datatype in enumerate(args):
  396. action.arguments[i] = cls.__parse_value(action.arguments[i], datatype)
  397. @classmethod
  398. def __read_expression(cls,
  399. tokens: list[str],
  400. token_index: int,
  401. depth: int = 0,
  402. one_subexpression: bool = False) -> tuple[PatternExpression, int]:
  403. """
  404. Reads an expression from a list of statement tokens. Returns a tuple
  405. containing the PatternExpression and the token index it left off at.
  406. If one_subexpression is True then it will return after reading a
  407. single expression instead of joining multiples (for reading the
  408. subject of a NOT expression).
  409. """
  410. subexpressions = []
  411. last_compound_operator = None
  412. while token_index < len(tokens):
  413. if one_subexpression:
  414. if len(subexpressions) == 1:
  415. return subexpressions[0], token_index
  416. if len(subexpressions) > 1:
  417. raise PatternError('Too many subexpressions')
  418. if tokens[token_index] == ')':
  419. if len(subexpressions) == 0:
  420. raise PatternError('No subexpressions')
  421. if len(subexpressions) == 1:
  422. return subexpressions[0], token_index
  423. return (PatternCompoundExpression(last_compound_operator,
  424. subexpressions), token_index)
  425. if tokens[token_index] in { "and", "or" }:
  426. compound_operator = tokens[token_index]
  427. if last_compound_operator and \
  428. compound_operator != last_compound_operator:
  429. subexpressions = [
  430. PatternCompoundExpression(last_compound_operator,
  431. subexpressions),
  432. ]
  433. last_compound_operator = compound_operator
  434. token_index += 1
  435. if tokens[token_index] == '!':
  436. (exp, next_index) = cls.__read_expression(tokens,
  437. token_index + 1, depth + 1, one_subexpression=True)
  438. subexpressions.append(PatternCompoundExpression('!', [exp]))
  439. token_index = next_index
  440. elif tokens[token_index] == '(':
  441. (exp, next_index) = cls.__read_expression(tokens,
  442. token_index + 1, depth + 1)
  443. if tokens[next_index] != ')':
  444. raise PatternError('Expected )')
  445. subexpressions.append(exp)
  446. token_index = next_index + 1
  447. else:
  448. (simple, next_index) = cls.__read_simple_expression(tokens,
  449. token_index, depth)
  450. subexpressions.append(simple)
  451. token_index = next_index
  452. if len(subexpressions) == 0:
  453. raise PatternError('No subexpressions')
  454. elif len(subexpressions) == 1:
  455. return subexpressions[0], token_index
  456. else:
  457. return PatternCompoundExpression(last_compound_operator,
  458. subexpressions), token_index
  459. @classmethod
  460. def __read_simple_expression(cls,
  461. tokens: list[str],
  462. token_index: int,
  463. depth: int = 0) -> tuple[PatternExpression, int]:
  464. """
  465. Reads a simple expression consisting of a field name, operator, and
  466. comparison value. Returns a tuple of the PatternSimpleExpression and
  467. the token index it left off at.
  468. """
  469. if depth > cls.MAX_EXPRESSION_NESTING:
  470. raise PatternError('Expression nests too deeply')
  471. if token_index >= len(tokens):
  472. raise PatternError('Expected field name, found EOL')
  473. field = tokens[token_index]
  474. token_index += 1
  475. datatype = cls.FIELD_TO_TYPE.get(field)
  476. if datatype is None:
  477. raise PatternError(f'No such field "{field}"')
  478. if token_index >= len(tokens):
  479. raise PatternError('Expected operator, found EOL')
  480. op = tokens[token_index]
  481. token_index += 1
  482. if op == '!':
  483. if token_index >= len(tokens):
  484. raise PatternError('Expected operator, found EOL')
  485. op = '!' + tokens[token_index]
  486. token_index += 1
  487. allowed_ops = cls.TYPE_TO_OPERATORS[datatype]
  488. if op not in allowed_ops:
  489. if op in cls.OPERATORS_ALL:
  490. raise PatternError(f'Operator {op} cannot be used with ' + \
  491. f'field "{field}"')
  492. raise PatternError(f'Unrecognized operator "{op}" - allowed: ' + \
  493. f'{sorted(list(allowed_ops))}')
  494. if token_index >= len(tokens):
  495. raise PatternError('Expected value, found EOL')
  496. value_str = tokens[token_index]
  497. try:
  498. value = cls.__parse_value(value_str, datatype, op)
  499. except ValueError as cause:
  500. raise PatternError(f'Bad value {value_str}') from cause
  501. token_index += 1
  502. exp = PatternSimpleExpression(field, op, value)
  503. return exp, token_index
  504. @classmethod
  505. def __parse_value(cls, value: str, datatype: str, op: str = None) -> Any:
  506. """
  507. Converts a value token to its Python value. Raises ValueError on failure.
  508. """
  509. if datatype == cls.TYPE_ID:
  510. if not is_user_id(value):
  511. raise ValueError(f'Illegal user id value: {value}')
  512. return value
  513. if datatype == cls.TYPE_MEMBER:
  514. return user_id_from_mention(value)
  515. if datatype == cls.TYPE_TEXT:
  516. s = str_from_quoted_str(value)
  517. if op in ('matches', '!matches'):
  518. try:
  519. return re.compile(s.lower())
  520. except re.error as e:
  521. raise ValueError(f'Invalid regex: {e}') from e
  522. if op in ('containsword', '!containsword'):
  523. try:
  524. return re.compile(f'\\b{re.escape(s.lower())}\\b')
  525. except re.error as e:
  526. raise ValueError(f'Invalid regex: {e}') from e
  527. return s
  528. if datatype == cls.TYPE_INT:
  529. return int(value)
  530. if datatype == cls.TYPE_FLOAT:
  531. return float(value)
  532. if datatype == cls.TYPE_TIMESPAN:
  533. return timedelta_from_str(value)
  534. raise ValueError(f'Unhandled datatype {datatype}')