Experimental Discord bot written in Python
Вы не можете выбрать более 25 тем Темы должны начинаться с буквы или цифры, могут содержать дефисы(-) и должны содержать не более 35 символов.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568
  1. """
  2. Statements that match messages based on an expression and have a list of actions
  3. to take on them.
  4. """
  5. import re
  6. from abc import ABCMeta, abstractmethod
  7. from typing import Any
  8. from discord import Message, utils as discordutils
  9. from discord.ext.commands import Context
  10. from rocketbot.utils import is_user_id, str_from_quoted_str, timedelta_from_str, \
  11. user_id_from_mention
  12. class PatternError(RuntimeError):
  13. """
  14. Error thrown when parsing a pattern statement.
  15. """
  16. class PatternDeprecationError(PatternError):
  17. """
  18. Error raised by PatternStatement.check_deprecated_syntax.
  19. """
  20. class PatternAction:
  21. """
  22. Describes one action to take on a matched message or its author.
  23. """
  24. def __init__(self, action: str, args: list[Any]):
  25. self.action = action
  26. self.arguments = list(args)
  27. def __str__(self) -> str:
  28. arg_str = ', '.join(self.arguments)
  29. return f'{self.action}({arg_str})'
  30. class PatternExpression(metaclass=ABCMeta):
  31. """
  32. Abstract message matching expression.
  33. """
  34. def __init__(self):
  35. pass
  36. @abstractmethod
  37. def matches(self, message: Message) -> bool:
  38. """
  39. Whether a message matches this expression.
  40. """
  41. return False
  42. class PatternSimpleExpression(PatternExpression):
  43. """
  44. Message matching expression with a simple "<field> <operator> <value>"
  45. structure.
  46. """
  47. def __init__(self, field: str, operator: str, value: Any):
  48. super().__init__()
  49. self.field = field
  50. self.operator = operator
  51. self.value = value
  52. def __field_value(self, message: Message) -> Any:
  53. if self.field in ('content.markdown', 'content'):
  54. return message.content
  55. if self.field == 'content.plain':
  56. return discordutils.remove_markdown(message.clean_content)
  57. if self.field == 'author':
  58. return str(message.author.id)
  59. if self.field == 'author.id':
  60. return str(message.author.id)
  61. if self.field == 'author.joinage':
  62. return message.created_at - message.author.joined_at
  63. if self.field == 'author.name':
  64. return message.author.name
  65. else:
  66. raise ValueError(f'Bad field name {self.field}')
  67. def matches(self, message: Message) -> bool:
  68. field_value = self.__field_value(message)
  69. if self.operator == '==':
  70. if isinstance(field_value, str) and isinstance(self.value, str):
  71. return field_value.lower() == self.value.lower()
  72. return field_value == self.value
  73. if self.operator == '!=':
  74. if isinstance(field_value, str) and isinstance(self.value, str):
  75. return field_value.lower() != self.value.lower()
  76. return field_value != self.value
  77. if self.operator == '<':
  78. return field_value < self.value
  79. if self.operator == '>':
  80. return field_value > self.value
  81. if self.operator == '<=':
  82. return field_value <= self.value
  83. if self.operator == '>=':
  84. return field_value >= self.value
  85. if self.operator == 'contains':
  86. return self.value.lower() in field_value.lower()
  87. if self.operator == '!contains':
  88. return self.value.lower() not in field_value.lower()
  89. if self.operator in ('matches', 'containsword'):
  90. return self.value.search(field_value.lower()) is not None
  91. if self.operator in ('!matches', '!containsword'):
  92. return self.value.search(field_value.lower()) is None
  93. raise ValueError(f'Bad operator {self.operator}')
  94. def __str__(self) -> str:
  95. return f'({self.field} {self.operator} {self.value})'
  96. class PatternCompoundExpression(PatternExpression):
  97. """
  98. Message matching expression that combines several child expressions with
  99. a boolean operator.
  100. """
  101. def __init__(self, operator: str, operands: list[PatternExpression]):
  102. super().__init__()
  103. self.operator = operator
  104. self.operands = list(operands)
  105. def matches(self, message: Message) -> bool:
  106. if self.operator == '!':
  107. return not self.operands[0].matches(message)
  108. if self.operator == 'and':
  109. for op in self.operands:
  110. if not op.matches(message):
  111. return False
  112. return True
  113. if self.operator == 'or':
  114. for op in self.operands:
  115. if op.matches(message):
  116. return True
  117. return False
  118. raise ValueError(f'Bad operator "{self.operator}"')
  119. def __str__(self) -> str:
  120. if self.operator == '!':
  121. return f'(!( {self.operands[0]} ))'
  122. strs = map(str, self.operands)
  123. joined = f' {self.operator} '.join(strs)
  124. return f'( {joined} )'
  125. class PatternStatement:
  126. """
  127. A full message match statement. If a message matches the given expression,
  128. the given actions should be performed.
  129. """
  130. def __init__(self,
  131. name: str,
  132. actions: list[PatternAction],
  133. expression: PatternExpression,
  134. original: str):
  135. self.name = name
  136. self.actions = list(actions) # PatternAction[]
  137. self.expression = expression
  138. self.original = original
  139. def check_deprecations(self) -> None:
  140. """
  141. Tests whether this statement uses any deprecated syntax. Will raise a
  142. PatternDeprecationError if one is found.
  143. """
  144. self.__check_deprecations(self.expression)
  145. @classmethod
  146. def __check_deprecations(cls, expression: PatternExpression) -> None:
  147. if isinstance(expression, PatternSimpleExpression):
  148. s: PatternSimpleExpression = expression
  149. if s.field in PatternCompiler.DEPRECATED_FIELDS:
  150. raise PatternDeprecationError(f'"{s.field}" field is deprecated')
  151. elif isinstance(expression, PatternCompoundExpression):
  152. c: PatternCompoundExpression = expression
  153. for oper in c.operands:
  154. cls.__check_deprecations(oper)
  155. def to_json(self) -> dict:
  156. """
  157. Returns a JSON representation of this statement.
  158. """
  159. return {
  160. 'name': self.name,
  161. 'statement': self.original,
  162. }
  163. @classmethod
  164. def from_json(cls, json: dict):
  165. """
  166. Gets a PatternStatement from its JSON representation.
  167. """
  168. return PatternCompiler.parse_statement(json['name'], json['statement'])
  169. class PatternCompiler:
  170. """
  171. Parses a user-provided message filter statement into a PatternStatement.
  172. """
  173. TYPE_FLOAT = 'float'
  174. TYPE_ID = 'id'
  175. TYPE_INT = 'int'
  176. TYPE_MEMBER = 'Member'
  177. TYPE_REGEX = 'regex'
  178. TYPE_TEXT = 'text'
  179. TYPE_TIMESPAN = 'timespan'
  180. FIELD_TO_TYPE: dict[str, str] = {
  181. 'content.plain': TYPE_TEXT,
  182. 'content.markdown': TYPE_TEXT,
  183. 'author': TYPE_MEMBER,
  184. 'author.id': TYPE_ID,
  185. 'author.name': TYPE_TEXT,
  186. 'author.joinage': TYPE_TIMESPAN,
  187. 'content': TYPE_TEXT, # deprecated, use content.markdown or content.plain
  188. }
  189. DEPRECATED_FIELDS: set[str] = set([ 'content' ])
  190. ACTION_TO_ARGS: dict[str, list[str]] = {
  191. 'ban': [],
  192. 'delete': [],
  193. 'kick': [],
  194. 'modinfo': [],
  195. 'modwarn': [],
  196. 'reply': [ TYPE_TEXT ],
  197. }
  198. OPERATORS_IDENTITY: set[str] = set([ '==', '!=' ])
  199. OPERATORS_COMPARISON: set[str] = set([ '<', '>', '<=', '>=' ])
  200. OPERATORS_NUMERIC = OPERATORS_IDENTITY | OPERATORS_COMPARISON
  201. OPERATORS_TEXT = OPERATORS_IDENTITY | set([
  202. 'contains', '!contains',
  203. 'containsword', '!containsword',
  204. 'matches', '!matches',
  205. ])
  206. OPERATORS_ALL = OPERATORS_IDENTITY | OPERATORS_COMPARISON | OPERATORS_TEXT
  207. TYPE_TO_OPERATORS: dict[str, set[str]] = {
  208. TYPE_ID: OPERATORS_IDENTITY,
  209. TYPE_MEMBER: OPERATORS_IDENTITY,
  210. TYPE_TEXT: OPERATORS_TEXT,
  211. TYPE_INT: OPERATORS_NUMERIC,
  212. TYPE_FLOAT: OPERATORS_NUMERIC,
  213. TYPE_TIMESPAN: OPERATORS_NUMERIC,
  214. }
  215. WHITESPACE_CHARS = ' \t\n\r'
  216. STRING_QUOTE_CHARS = '\'"'
  217. SYMBOL_CHARS = 'abcdefghijklmnopqrstuvwxyz.'
  218. VALUE_CHARS = '0123456789dhms<@!>'
  219. OP_CHARS = '<=>!(),'
  220. MAX_EXPRESSION_NESTING = 8
  221. @classmethod
  222. def expression_str_from_context(cls, context: Context, name: str) -> str:
  223. """
  224. Extracts the statement string from an "add" command context.
  225. """
  226. pattern_str = context.message.content
  227. command_chain = [ name ]
  228. cmd = context.command
  229. while cmd:
  230. command_chain.insert(0, cmd.name)
  231. cmd = cmd.parent
  232. command_chain[0] = f'{context.prefix}{command_chain[0]}'
  233. for cmd in command_chain:
  234. if pattern_str.startswith(cmd):
  235. pattern_str = pattern_str[len(cmd):].lstrip()
  236. elif pattern_str.startswith(f'"{cmd}"'):
  237. pattern_str = pattern_str[len(cmd) + 2:].lstrip()
  238. return pattern_str
  239. @classmethod
  240. def parse_statement(cls, name: str, statement: str) -> PatternStatement:
  241. """
  242. Parses a user-provided message filter statement into a PatternStatement.
  243. Raises PatternError on failure.
  244. """
  245. tokens = cls.__tokenize(statement)
  246. token_index = 0
  247. actions, token_index = cls.__read_actions(tokens, token_index)
  248. expression, token_index = cls.__read_expression(tokens, token_index)
  249. return PatternStatement(name, actions, expression, statement)
  250. @classmethod
  251. def __tokenize(cls, statement: str) -> list[str]:
  252. """
  253. Converts a message filter statement into a list of tokens.
  254. """
  255. tokens: list[str] = []
  256. in_quote = False
  257. in_escape = False
  258. all_token_types = set([ 'sym', 'op', 'val' ])
  259. possible_token_types = set(all_token_types)
  260. current_token = ''
  261. for ch in statement:
  262. if in_quote:
  263. if in_escape:
  264. if ch == 'n':
  265. current_token += '\n'
  266. elif ch == 't':
  267. current_token += '\t'
  268. else:
  269. current_token += ch
  270. in_escape = False
  271. elif ch == '\\':
  272. in_escape = True
  273. elif ch == in_quote:
  274. current_token += ch
  275. tokens.append(current_token)
  276. current_token = ''
  277. possible_token_types |= all_token_types
  278. in_quote = False
  279. else:
  280. current_token += ch
  281. else:
  282. if ch in cls.STRING_QUOTE_CHARS:
  283. if len(current_token) > 0:
  284. tokens.append(current_token)
  285. current_token = ''
  286. possible_token_types |= all_token_types
  287. in_quote = ch
  288. current_token = ch
  289. elif ch == '\\':
  290. raise PatternError("Unexpected \\ outside quoted string")
  291. elif ch in cls.WHITESPACE_CHARS:
  292. if len(current_token) > 0:
  293. tokens.append(current_token)
  294. current_token = ''
  295. possible_token_types |= all_token_types
  296. else:
  297. possible_ch_types = set()
  298. if ch in cls.SYMBOL_CHARS:
  299. possible_ch_types.add('sym')
  300. if ch in cls.VALUE_CHARS:
  301. possible_ch_types.add('val')
  302. if ch in cls.OP_CHARS:
  303. possible_ch_types.add('op')
  304. if len(current_token) > 0 and \
  305. possible_ch_types.isdisjoint(possible_token_types):
  306. if len(current_token) > 0:
  307. tokens.append(current_token)
  308. current_token = ''
  309. possible_token_types |= all_token_types
  310. possible_token_types &= possible_ch_types
  311. current_token += ch
  312. if len(current_token) > 0:
  313. tokens.append(current_token)
  314. # Some symbols might be glommed onto other tokens. Split 'em up.
  315. prefixes_to_split = [ '!', '(', ',' ]
  316. suffixes_to_split = [ ')', ',' ]
  317. i = 0
  318. while i < len(tokens):
  319. token = tokens[i]
  320. mutated = False
  321. for prefix in prefixes_to_split:
  322. if token.startswith(prefix) and len(token) > len(prefix):
  323. tokens.insert(i, prefix)
  324. tokens[i + 1] = token[len(prefix):]
  325. i += 1
  326. mutated = True
  327. break
  328. if mutated:
  329. continue
  330. for suffix in suffixes_to_split:
  331. if token.endswith(suffix) and len(token) > len(suffix):
  332. tokens[i] = token[0:-len(suffix)]
  333. tokens.insert(i + 1, suffix)
  334. mutated = True
  335. break
  336. if mutated:
  337. continue
  338. i += 1
  339. return tokens
  340. @classmethod
  341. def __read_actions(cls,
  342. tokens: list[str],
  343. token_index: int) -> tuple[list[PatternAction], int]:
  344. """
  345. Reads the actions from a list of statement tokens. Returns a tuple
  346. containing a list of PatternActions and the token index this method
  347. left off at (the token after the "if").
  348. """
  349. actions: list[PatternAction] = []
  350. current_action_tokens = []
  351. while token_index < len(tokens):
  352. token = tokens[token_index]
  353. if token == 'if':
  354. if len(current_action_tokens) > 0:
  355. a = PatternAction(current_action_tokens[0], \
  356. current_action_tokens[1:])
  357. cls.__validate_action(a)
  358. actions.append(a)
  359. token_index += 1
  360. return (actions, token_index)
  361. elif token == ',':
  362. if len(current_action_tokens) < 1:
  363. raise PatternError('Unexpected ,')
  364. a = PatternAction(current_action_tokens[0], \
  365. current_action_tokens[1:])
  366. cls.__validate_action(a)
  367. actions.append(a)
  368. current_action_tokens = []
  369. else:
  370. current_action_tokens.append(token)
  371. token_index += 1
  372. raise PatternError('Unexpected end of line in action list')
  373. @classmethod
  374. def __validate_action(cls, action: PatternAction) -> None:
  375. args: list[str] = cls.ACTION_TO_ARGS.get(action.action)
  376. if args is None:
  377. raise PatternError(f'Unknown action "{action.action}"')
  378. if len(action.arguments) != len(args):
  379. if len(args) == 0:
  380. raise PatternError(f'Action "{action.action}" expects no ' + \
  381. f'arguments, got {len(action.arguments)}.')
  382. raise PatternError(f'Action "{action.action}" expects ' + \
  383. f'{len(args)} arguments, got {len(action.arguments)}.')
  384. for i, datatype in enumerate(args):
  385. action.arguments[i] = cls.__parse_value(action.arguments[i], datatype)
  386. @classmethod
  387. def __read_expression(cls,
  388. tokens: list[str],
  389. token_index: int,
  390. depth: int = 0,
  391. one_subexpression: bool = False) -> tuple[PatternExpression, int]:
  392. """
  393. Reads an expression from a list of statement tokens. Returns a tuple
  394. containing the PatternExpression and the token index it left off at.
  395. If one_subexpression is True then it will return after reading a
  396. single expression instead of joining multiples (for reading the
  397. subject of a NOT expression).
  398. """
  399. subexpressions = []
  400. last_compound_operator = None
  401. while token_index < len(tokens):
  402. if one_subexpression:
  403. if len(subexpressions) == 1:
  404. return (subexpressions[0], token_index)
  405. if len(subexpressions) > 1:
  406. raise PatternError('Too many subexpressions')
  407. compound_operator = None
  408. if tokens[token_index] == ')':
  409. if len(subexpressions) == 0:
  410. raise PatternError('No subexpressions')
  411. if len(subexpressions) == 1:
  412. return (subexpressions[0], token_index)
  413. return (PatternCompoundExpression(last_compound_operator,
  414. subexpressions), token_index)
  415. if tokens[token_index] in set(["and", "or"]):
  416. compound_operator = tokens[token_index]
  417. if last_compound_operator and \
  418. compound_operator != last_compound_operator:
  419. subexpressions = [
  420. PatternCompoundExpression(last_compound_operator,
  421. subexpressions),
  422. ]
  423. last_compound_operator = compound_operator
  424. token_index += 1
  425. if tokens[token_index] == '!':
  426. (exp, next_index) = cls.__read_expression(tokens, \
  427. token_index + 1, depth + 1, one_subexpression=True)
  428. subexpressions.append(PatternCompoundExpression('!', [exp]))
  429. token_index = next_index
  430. elif tokens[token_index] == '(':
  431. (exp, next_index) = cls.__read_expression(tokens,
  432. token_index + 1, depth + 1)
  433. if tokens[next_index] != ')':
  434. raise PatternError('Expected )')
  435. subexpressions.append(exp)
  436. token_index = next_index + 1
  437. else:
  438. (simple, next_index) = cls.__read_simple_expression(tokens,
  439. token_index, depth)
  440. subexpressions.append(simple)
  441. token_index = next_index
  442. if len(subexpressions) == 0:
  443. raise PatternError('No subexpressions')
  444. elif len(subexpressions) == 1:
  445. return (subexpressions[0], token_index)
  446. else:
  447. return (PatternCompoundExpression(last_compound_operator,
  448. subexpressions), token_index)
  449. @classmethod
  450. def __read_simple_expression(cls,
  451. tokens: list[str],
  452. token_index: int,
  453. depth: int = 0) -> tuple[PatternExpression, int]:
  454. """
  455. Reads a simple expression consisting of a field name, operator, and
  456. comparison value. Returns a tuple of the PatternSimpleExpression and
  457. the token index it left off at.
  458. """
  459. if depth > cls.MAX_EXPRESSION_NESTING:
  460. raise PatternError('Expression nests too deeply')
  461. if token_index >= len(tokens):
  462. raise PatternError('Expected field name, found EOL')
  463. field = tokens[token_index]
  464. token_index += 1
  465. datatype = cls.FIELD_TO_TYPE.get(field)
  466. if datatype is None:
  467. raise PatternError(f'No such field "{field}"')
  468. if token_index >= len(tokens):
  469. raise PatternError('Expected operator, found EOL')
  470. op = tokens[token_index]
  471. token_index += 1
  472. if op == '!':
  473. if token_index >= len(tokens):
  474. raise PatternError('Expected operator, found EOL')
  475. op = '!' + tokens[token_index]
  476. token_index += 1
  477. allowed_ops = cls.TYPE_TO_OPERATORS[datatype]
  478. if op not in allowed_ops:
  479. if op in cls.OPERATORS_ALL:
  480. raise PatternError(f'Operator {op} cannot be used with ' + \
  481. f'field "{field}"')
  482. raise PatternError(f'Unrecognized operator "{op}" - allowed: ' + \
  483. f'{sorted(list(allowed_ops))}')
  484. if token_index >= len(tokens):
  485. raise PatternError('Expected value, found EOL')
  486. value_str = tokens[token_index]
  487. try:
  488. value = cls.__parse_value(value_str, datatype, op)
  489. except ValueError as cause:
  490. raise PatternError(f'Bad value {value_str}') from cause
  491. token_index += 1
  492. exp = PatternSimpleExpression(field, op, value)
  493. return (exp, token_index)
  494. @classmethod
  495. def __parse_value(cls, value: str, datatype: str, op: str = None) -> Any:
  496. """
  497. Converts a value token to its Python value. Raises ValueError on failure.
  498. """
  499. if datatype == cls.TYPE_ID:
  500. if not is_user_id(value):
  501. raise ValueError(f'Illegal user id value: {value}')
  502. return value
  503. if datatype == cls.TYPE_MEMBER:
  504. return user_id_from_mention(value)
  505. if datatype == cls.TYPE_TEXT:
  506. s = str_from_quoted_str(value)
  507. if op in ('matches', '!matches'):
  508. try:
  509. return re.compile(s.lower())
  510. except re.error as e:
  511. raise ValueError(f'Invalid regex: {e}') from e
  512. if op in ('containsword', '!containsword'):
  513. try:
  514. return re.compile(f'\\b{re.escape(s.lower())}\\b')
  515. except re.error as e:
  516. raise ValueError(f'Invalid regex: {e}') from e
  517. return s
  518. if datatype == cls.TYPE_INT:
  519. return int(value)
  520. if datatype == cls.TYPE_FLOAT:
  521. return float(value)
  522. if datatype == cls.TYPE_TIMESPAN:
  523. return timedelta_from_str(value)
  524. raise ValueError(f'Unhandled datatype {datatype}')