Experimental Discord bot written in Python
Ви не можете вибрати більше 25 тем Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581
  1. """
  2. Statements that match messages based on an expression and have a list of actions
  3. to take on them.
  4. """
  5. import re
  6. from abc import ABCMeta, abstractmethod
  7. from datetime import datetime, timezone
  8. from typing import Any
  9. from discord import Message, utils as discordutils
  10. from discord.ext.commands import Context
  11. from rocketbot.utils import is_user_id, str_from_quoted_str, timedelta_from_str, \
  12. user_id_from_mention
  13. class PatternError(RuntimeError):
  14. """
  15. Error thrown when parsing a pattern statement.
  16. """
  17. class PatternDeprecationError(PatternError):
  18. """
  19. Error raised by PatternStatement.check_deprecated_syntax.
  20. """
  21. class PatternAction:
  22. """
  23. Describes one action to take on a matched message or its author.
  24. """
  25. def __init__(self, action: str, args: list[Any]):
  26. self.action = action
  27. self.arguments = list(args)
  28. def __str__(self) -> str:
  29. arg_str = ', '.join(self.arguments)
  30. return f'{self.action}({arg_str})'
  31. class PatternExpression(metaclass=ABCMeta):
  32. """
  33. Abstract message matching expression.
  34. """
  35. def __init__(self):
  36. pass
  37. @abstractmethod
  38. def matches(self, message: Message, other_fields: dict[str, Any]) -> bool:
  39. """
  40. Whether a message matches this expression. other_fields are additional
  41. fields that can be queried not contained in the message itself.
  42. """
  43. return False
  44. class PatternSimpleExpression(PatternExpression):
  45. """
  46. Message matching expression with a simple "<field> <operator> <value>"
  47. structure.
  48. """
  49. def __init__(self, field: str, operator: str, value: Any):
  50. super().__init__()
  51. self.field: str = field
  52. self.operator: str = operator
  53. self.value: Any = value
  54. def __field_value(self, message: Message, other_fields: dict[str: Any]) -> Any:
  55. if self.field in ('content.markdown', 'content'):
  56. return message.content
  57. if self.field == 'content.plain':
  58. return discordutils.remove_markdown(message.clean_content)
  59. if self.field == 'author':
  60. return str(message.author.id)
  61. if self.field == 'author.id':
  62. return str(message.author.id)
  63. if self.field == 'author.joinage':
  64. return message.created_at - message.author.joined_at
  65. if self.field == 'author.name':
  66. return message.author.name
  67. if self.field == 'lastmatched':
  68. long_ago = datetime(year=1900, month=1, day=1, hour=0, minute=0, second=0, tzinfo=timezone.utc)
  69. last_matched = other_fields.get('last_matched') or long_ago
  70. return message.created_at - last_matched
  71. raise ValueError(f'Bad field name {self.field}')
  72. def matches(self, message: Message, other_fields: dict[str, Any]) -> bool:
  73. field_value = self.__field_value(message, other_fields)
  74. if self.operator == '==':
  75. if isinstance(field_value, str) and isinstance(self.value, str):
  76. return field_value.lower() == self.value.lower()
  77. return field_value == self.value
  78. if self.operator == '!=':
  79. if isinstance(field_value, str) and isinstance(self.value, str):
  80. return field_value.lower() != self.value.lower()
  81. return field_value != self.value
  82. if self.operator == '<':
  83. return field_value < self.value
  84. if self.operator == '>':
  85. return field_value > self.value
  86. if self.operator == '<=':
  87. return field_value <= self.value
  88. if self.operator == '>=':
  89. return field_value >= self.value
  90. if self.operator == 'contains':
  91. return self.value.lower() in field_value.lower()
  92. if self.operator == '!contains':
  93. return self.value.lower() not in field_value.lower()
  94. if self.operator in ('matches', 'containsword'):
  95. return self.value.search(field_value.lower()) is not None
  96. if self.operator in ('!matches', '!containsword'):
  97. return self.value.search(field_value.lower()) is None
  98. raise ValueError(f'Bad operator {self.operator}')
  99. def __str__(self) -> str:
  100. return f'({self.field} {self.operator} {self.value})'
  101. class PatternCompoundExpression(PatternExpression):
  102. """
  103. Message matching expression that combines several child expressions with
  104. a boolean operator.
  105. """
  106. def __init__(self, operator: str, operands: list[PatternExpression]):
  107. super().__init__()
  108. self.operator = operator
  109. self.operands = list(operands)
  110. def matches(self, message: Message, other_fields: dict[str, Any]) -> bool:
  111. if self.operator == '!':
  112. return not self.operands[0].matches(message, other_fields)
  113. if self.operator == 'and':
  114. for op in self.operands:
  115. if not op.matches(message, other_fields):
  116. return False
  117. return True
  118. if self.operator == 'or':
  119. for op in self.operands:
  120. if op.matches(message, other_fields):
  121. return True
  122. return False
  123. raise ValueError(f'Bad operator "{self.operator}"')
  124. def __str__(self) -> str:
  125. if self.operator == '!':
  126. return f'(!( {self.operands[0]} ))'
  127. strs = map(str, self.operands)
  128. joined = f' {self.operator} '.join(strs)
  129. return f'( {joined} )'
  130. class PatternStatement:
  131. """
  132. A full message match statement. If a message matches the given expression,
  133. the given actions should be performed.
  134. """
  135. DEFAULT_PRIORITY: int = 100
  136. def __init__(self,
  137. name: str,
  138. actions: list[PatternAction],
  139. expression: PatternExpression,
  140. original: str,
  141. priority: int = DEFAULT_PRIORITY):
  142. self.name: str = name
  143. self.actions: list[PatternAction] = list(actions) # PatternAction[]
  144. self.expression: PatternExpression = expression
  145. self.original: str = original
  146. self.priority: int = priority
  147. def check_deprecations(self) -> None:
  148. """
  149. Tests whether this statement uses any deprecated syntax. Will raise a
  150. PatternDeprecationError if one is found.
  151. """
  152. self.__check_deprecations(self.expression)
  153. @classmethod
  154. def __check_deprecations(cls, expression: PatternExpression) -> None:
  155. if isinstance(expression, PatternSimpleExpression):
  156. s: PatternSimpleExpression = expression
  157. if s.field in PatternCompiler.DEPRECATED_FIELDS:
  158. raise PatternDeprecationError(f'"{s.field}" field is deprecated')
  159. elif isinstance(expression, PatternCompoundExpression):
  160. c: PatternCompoundExpression = expression
  161. for oper in c.operands:
  162. cls.__check_deprecations(oper)
  163. def to_json(self) -> dict[str, Any]:
  164. """
  165. Returns a JSON representation of this statement.
  166. """
  167. return {
  168. 'name': self.name,
  169. 'priority': self.priority,
  170. 'statement': self.original,
  171. }
  172. @classmethod
  173. def from_json(cls, json: dict[str, Any]):
  174. """
  175. Gets a PatternStatement from its JSON representation.
  176. """
  177. ps = PatternCompiler.parse_statement(json['name'], json['statement'])
  178. ps.priority = json.get('priority', cls.DEFAULT_PRIORITY)
  179. return ps
  180. class PatternCompiler:
  181. """
  182. Parses a user-provided message filter statement into a PatternStatement.
  183. """
  184. TYPE_FLOAT: str = 'float'
  185. TYPE_ID: str = 'id'
  186. TYPE_INT: str = 'int'
  187. TYPE_MEMBER: str = 'Member'
  188. TYPE_REGEX: str = 'regex'
  189. TYPE_TEXT: str = 'text'
  190. TYPE_TIMESPAN: str = 'timespan'
  191. FIELD_TO_TYPE: dict[str, str] = {
  192. 'author': TYPE_MEMBER,
  193. 'author.id': TYPE_ID,
  194. 'author.joinage': TYPE_TIMESPAN,
  195. 'author.name': TYPE_TEXT,
  196. 'content': TYPE_TEXT, # deprecated, use content.markdown or content.plain
  197. 'content.markdown': TYPE_TEXT,
  198. 'content.plain': TYPE_TEXT,
  199. 'lastmatched': TYPE_TIMESPAN,
  200. }
  201. DEPRECATED_FIELDS: set[str] = set([ 'content' ])
  202. ACTION_TO_ARGS: dict[str, list[str]] = {
  203. 'ban': [],
  204. 'delete': [],
  205. 'kick': [],
  206. 'modinfo': [],
  207. 'modwarn': [],
  208. 'reply': [ TYPE_TEXT ],
  209. }
  210. OPERATORS_IDENTITY: set[str] = set([ '==', '!=' ])
  211. OPERATORS_COMPARISON: set[str] = set([ '<', '>', '<=', '>=' ])
  212. OPERATORS_NUMERIC: set[str] = OPERATORS_IDENTITY | OPERATORS_COMPARISON
  213. OPERATORS_TEXT: set[str] = OPERATORS_IDENTITY | set([
  214. 'contains', '!contains',
  215. 'containsword', '!containsword',
  216. 'matches', '!matches',
  217. ])
  218. OPERATORS_ALL: set[str] = OPERATORS_IDENTITY | OPERATORS_COMPARISON | OPERATORS_TEXT
  219. TYPE_TO_OPERATORS: dict[str, set[str]] = {
  220. TYPE_ID: OPERATORS_IDENTITY,
  221. TYPE_MEMBER: OPERATORS_IDENTITY,
  222. TYPE_TEXT: OPERATORS_TEXT,
  223. TYPE_INT: OPERATORS_NUMERIC,
  224. TYPE_FLOAT: OPERATORS_NUMERIC,
  225. TYPE_TIMESPAN: OPERATORS_NUMERIC,
  226. }
  227. WHITESPACE_CHARS: str = ' \t\n\r'
  228. STRING_QUOTE_CHARS: str = '\'"'
  229. SYMBOL_CHARS: str = 'abcdefghijklmnopqrstuvwxyz.'
  230. VALUE_CHARS: str = '0123456789dhms<@!>'
  231. OP_CHARS: str = '<=>!(),'
  232. MAX_EXPRESSION_NESTING: int = 8
  233. @classmethod
  234. def expression_str_from_context(cls, context: Context, name: str) -> str:
  235. """
  236. Extracts the statement string from an "add" command context.
  237. """
  238. pattern_str: str = context.message.content
  239. command_chain = [ name ]
  240. cmd = context.command
  241. while cmd:
  242. command_chain.insert(0, cmd.name)
  243. cmd = cmd.parent
  244. command_chain[0] = f'{context.prefix}{command_chain[0]}'
  245. for cmd in command_chain:
  246. if pattern_str.startswith(cmd):
  247. pattern_str = pattern_str[len(cmd):].lstrip()
  248. elif pattern_str.startswith(f'"{cmd}"'):
  249. pattern_str = pattern_str[len(cmd) + 2:].lstrip()
  250. return pattern_str
  251. @classmethod
  252. def parse_statement(cls, name: str, statement: str) -> PatternStatement:
  253. """
  254. Parses a user-provided message filter statement into a PatternStatement.
  255. Raises PatternError on failure.
  256. """
  257. tokens: list[str] = cls.__tokenize(statement)
  258. token_index: int = 0
  259. actions, token_index = cls.__read_actions(tokens, token_index)
  260. expression, token_index = cls.__read_expression(tokens, token_index)
  261. return PatternStatement(name, actions, expression, statement)
  262. @classmethod
  263. def __tokenize(cls, statement: str) -> list[str]:
  264. """
  265. Converts a message filter statement into a list of tokens.
  266. """
  267. tokens: list[str] = []
  268. in_quote: bool = False
  269. in_escape: bool = False
  270. all_token_types: set[str] = set([ 'sym', 'op', 'val' ])
  271. possible_token_types: set[str] = set(all_token_types)
  272. current_token: str = ''
  273. for ch in statement:
  274. if in_quote:
  275. if in_escape:
  276. if ch == 'n':
  277. current_token += '\n'
  278. elif ch == 't':
  279. current_token += '\t'
  280. else:
  281. current_token += ch
  282. in_escape = False
  283. elif ch == '\\':
  284. in_escape = True
  285. elif ch == in_quote:
  286. current_token += ch
  287. tokens.append(current_token)
  288. current_token = ''
  289. possible_token_types |= all_token_types
  290. in_quote = False
  291. else:
  292. current_token += ch
  293. else:
  294. if ch in cls.STRING_QUOTE_CHARS:
  295. if len(current_token) > 0:
  296. tokens.append(current_token)
  297. current_token = ''
  298. possible_token_types |= all_token_types
  299. in_quote = ch
  300. current_token = ch
  301. elif ch == '\\':
  302. raise PatternError("Unexpected \\ outside quoted string")
  303. elif ch in cls.WHITESPACE_CHARS:
  304. if len(current_token) > 0:
  305. tokens.append(current_token)
  306. current_token = ''
  307. possible_token_types |= all_token_types
  308. else:
  309. possible_ch_types = set()
  310. if ch in cls.SYMBOL_CHARS:
  311. possible_ch_types.add('sym')
  312. if ch in cls.VALUE_CHARS:
  313. possible_ch_types.add('val')
  314. if ch in cls.OP_CHARS:
  315. possible_ch_types.add('op')
  316. if len(current_token) > 0 and \
  317. possible_ch_types.isdisjoint(possible_token_types):
  318. if len(current_token) > 0:
  319. tokens.append(current_token)
  320. current_token = ''
  321. possible_token_types |= all_token_types
  322. possible_token_types &= possible_ch_types
  323. current_token += ch
  324. if len(current_token) > 0:
  325. tokens.append(current_token)
  326. # Some symbols might be glommed onto other tokens. Split 'em up.
  327. prefixes_to_split = [ '!', '(', ',' ]
  328. suffixes_to_split = [ ')', ',' ]
  329. i = 0
  330. while i < len(tokens):
  331. token = tokens[i]
  332. mutated = False
  333. for prefix in prefixes_to_split:
  334. if token.startswith(prefix) and len(token) > len(prefix):
  335. tokens.insert(i, prefix)
  336. tokens[i + 1] = token[len(prefix):]
  337. i += 1
  338. mutated = True
  339. break
  340. if mutated:
  341. continue
  342. for suffix in suffixes_to_split:
  343. if token.endswith(suffix) and len(token) > len(suffix):
  344. tokens[i] = token[0:-len(suffix)]
  345. tokens.insert(i + 1, suffix)
  346. mutated = True
  347. break
  348. if mutated:
  349. continue
  350. i += 1
  351. return tokens
  352. @classmethod
  353. def __read_actions(cls,
  354. tokens: list[str],
  355. token_index: int) -> tuple[list[PatternAction], int]:
  356. """
  357. Reads the actions from a list of statement tokens. Returns a tuple
  358. containing a list of PatternActions and the token index this method
  359. left off at (the token after the "if").
  360. """
  361. actions: list[PatternAction] = []
  362. current_action_tokens = []
  363. while token_index < len(tokens):
  364. token = tokens[token_index]
  365. if token == 'if':
  366. if len(current_action_tokens) > 0:
  367. a = PatternAction(current_action_tokens[0], \
  368. current_action_tokens[1:])
  369. cls.__validate_action(a)
  370. actions.append(a)
  371. token_index += 1
  372. return (actions, token_index)
  373. elif token == ',':
  374. if len(current_action_tokens) < 1:
  375. raise PatternError('Unexpected ,')
  376. a = PatternAction(current_action_tokens[0], \
  377. current_action_tokens[1:])
  378. cls.__validate_action(a)
  379. actions.append(a)
  380. current_action_tokens = []
  381. else:
  382. current_action_tokens.append(token)
  383. token_index += 1
  384. raise PatternError('Unexpected end of line in action list')
  385. @classmethod
  386. def __validate_action(cls, action: PatternAction) -> None:
  387. args: list[str] = cls.ACTION_TO_ARGS.get(action.action)
  388. if args is None:
  389. raise PatternError(f'Unknown action "{action.action}"')
  390. if len(action.arguments) != len(args):
  391. if len(args) == 0:
  392. raise PatternError(f'Action "{action.action}" expects no ' + \
  393. f'arguments, got {len(action.arguments)}.')
  394. raise PatternError(f'Action "{action.action}" expects ' + \
  395. f'{len(args)} arguments, got {len(action.arguments)}.')
  396. for i, datatype in enumerate(args):
  397. action.arguments[i] = cls.__parse_value(action.arguments[i], datatype)
  398. @classmethod
  399. def __read_expression(cls,
  400. tokens: list[str],
  401. token_index: int,
  402. depth: int = 0,
  403. one_subexpression: bool = False) -> tuple[PatternExpression, int]:
  404. """
  405. Reads an expression from a list of statement tokens. Returns a tuple
  406. containing the PatternExpression and the token index it left off at.
  407. If one_subexpression is True then it will return after reading a
  408. single expression instead of joining multiples (for reading the
  409. subject of a NOT expression).
  410. """
  411. subexpressions = []
  412. last_compound_operator = None
  413. while token_index < len(tokens):
  414. if one_subexpression:
  415. if len(subexpressions) == 1:
  416. return (subexpressions[0], token_index)
  417. if len(subexpressions) > 1:
  418. raise PatternError('Too many subexpressions')
  419. compound_operator = None
  420. if tokens[token_index] == ')':
  421. if len(subexpressions) == 0:
  422. raise PatternError('No subexpressions')
  423. if len(subexpressions) == 1:
  424. return (subexpressions[0], token_index)
  425. return (PatternCompoundExpression(last_compound_operator,
  426. subexpressions), token_index)
  427. if tokens[token_index] in set(["and", "or"]):
  428. compound_operator = tokens[token_index]
  429. if last_compound_operator and \
  430. compound_operator != last_compound_operator:
  431. subexpressions = [
  432. PatternCompoundExpression(last_compound_operator,
  433. subexpressions),
  434. ]
  435. last_compound_operator = compound_operator
  436. token_index += 1
  437. if tokens[token_index] == '!':
  438. (exp, next_index) = cls.__read_expression(tokens, \
  439. token_index + 1, depth + 1, one_subexpression=True)
  440. subexpressions.append(PatternCompoundExpression('!', [exp]))
  441. token_index = next_index
  442. elif tokens[token_index] == '(':
  443. (exp, next_index) = cls.__read_expression(tokens,
  444. token_index + 1, depth + 1)
  445. if tokens[next_index] != ')':
  446. raise PatternError('Expected )')
  447. subexpressions.append(exp)
  448. token_index = next_index + 1
  449. else:
  450. (simple, next_index) = cls.__read_simple_expression(tokens,
  451. token_index, depth)
  452. subexpressions.append(simple)
  453. token_index = next_index
  454. if len(subexpressions) == 0:
  455. raise PatternError('No subexpressions')
  456. elif len(subexpressions) == 1:
  457. return (subexpressions[0], token_index)
  458. else:
  459. return (PatternCompoundExpression(last_compound_operator,
  460. subexpressions), token_index)
  461. @classmethod
  462. def __read_simple_expression(cls,
  463. tokens: list[str],
  464. token_index: int,
  465. depth: int = 0) -> tuple[PatternExpression, int]:
  466. """
  467. Reads a simple expression consisting of a field name, operator, and
  468. comparison value. Returns a tuple of the PatternSimpleExpression and
  469. the token index it left off at.
  470. """
  471. if depth > cls.MAX_EXPRESSION_NESTING:
  472. raise PatternError('Expression nests too deeply')
  473. if token_index >= len(tokens):
  474. raise PatternError('Expected field name, found EOL')
  475. field = tokens[token_index]
  476. token_index += 1
  477. datatype = cls.FIELD_TO_TYPE.get(field)
  478. if datatype is None:
  479. raise PatternError(f'No such field "{field}"')
  480. if token_index >= len(tokens):
  481. raise PatternError('Expected operator, found EOL')
  482. op = tokens[token_index]
  483. token_index += 1
  484. if op == '!':
  485. if token_index >= len(tokens):
  486. raise PatternError('Expected operator, found EOL')
  487. op = '!' + tokens[token_index]
  488. token_index += 1
  489. allowed_ops = cls.TYPE_TO_OPERATORS[datatype]
  490. if op not in allowed_ops:
  491. if op in cls.OPERATORS_ALL:
  492. raise PatternError(f'Operator {op} cannot be used with ' + \
  493. f'field "{field}"')
  494. raise PatternError(f'Unrecognized operator "{op}" - allowed: ' + \
  495. f'{sorted(list(allowed_ops))}')
  496. if token_index >= len(tokens):
  497. raise PatternError('Expected value, found EOL')
  498. value_str = tokens[token_index]
  499. try:
  500. value = cls.__parse_value(value_str, datatype, op)
  501. except ValueError as cause:
  502. raise PatternError(f'Bad value {value_str}') from cause
  503. token_index += 1
  504. exp = PatternSimpleExpression(field, op, value)
  505. return (exp, token_index)
  506. @classmethod
  507. def __parse_value(cls, value: str, datatype: str, op: str = None) -> Any:
  508. """
  509. Converts a value token to its Python value. Raises ValueError on failure.
  510. """
  511. if datatype == cls.TYPE_ID:
  512. if not is_user_id(value):
  513. raise ValueError(f'Illegal user id value: {value}')
  514. return value
  515. if datatype == cls.TYPE_MEMBER:
  516. return user_id_from_mention(value)
  517. if datatype == cls.TYPE_TEXT:
  518. s = str_from_quoted_str(value)
  519. if op in ('matches', '!matches'):
  520. try:
  521. return re.compile(s.lower())
  522. except re.error as e:
  523. raise ValueError(f'Invalid regex: {e}') from e
  524. if op in ('containsword', '!containsword'):
  525. try:
  526. return re.compile(f'\\b{re.escape(s.lower())}\\b')
  527. except re.error as e:
  528. raise ValueError(f'Invalid regex: {e}') from e
  529. return s
  530. if datatype == cls.TYPE_INT:
  531. return int(value)
  532. if datatype == cls.TYPE_FLOAT:
  533. return float(value)
  534. if datatype == cls.TYPE_TIMESPAN:
  535. return timedelta_from_str(value)
  536. raise ValueError(f'Unhandled datatype {datatype}')