Experimental Discord bot written in Python
Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

pattern.py 17KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576
  1. """
  2. Statements that match messages based on an expression and have a list of actions
  3. to take on them.
  4. """
  5. import re
  6. from abc import ABCMeta, abstractmethod
  7. from typing import Any
  8. from discord import Message, utils as discordutils
  9. from discord.ext.commands import Context
  10. from rocketbot.utils import is_user_id, str_from_quoted_str, timedelta_from_str, \
  11. user_id_from_mention
  12. class PatternError(RuntimeError):
  13. """
  14. Error thrown when parsing a pattern statement.
  15. """
  16. class PatternDeprecationError(PatternError):
  17. """
  18. Error raised by PatternStatement.check_deprecated_syntax.
  19. """
  20. class PatternAction:
  21. """
  22. Describes one action to take on a matched message or its author.
  23. """
  24. def __init__(self, action: str, args: list[Any]):
  25. self.action = action
  26. self.arguments = list(args)
  27. def __str__(self) -> str:
  28. arg_str = ', '.join(self.arguments)
  29. return f'{self.action}({arg_str})'
  30. class PatternExpression(metaclass=ABCMeta):
  31. """
  32. Abstract message matching expression.
  33. """
  34. def __init__(self):
  35. pass
  36. @abstractmethod
  37. def matches(self, message: Message) -> bool:
  38. """
  39. Whether a message matches this expression.
  40. """
  41. return False
  42. class PatternSimpleExpression(PatternExpression):
  43. """
  44. Message matching expression with a simple "<field> <operator> <value>"
  45. structure.
  46. """
  47. def __init__(self, field: str, operator: str, value: Any):
  48. super().__init__()
  49. self.field = field
  50. self.operator = operator
  51. self.value = value
  52. def __field_value(self, message: Message) -> Any:
  53. if self.field in ('content.markdown', 'content'):
  54. return message.content
  55. if self.field == 'content.plain':
  56. return discordutils.remove_markdown(message.clean_content)
  57. if self.field == 'author':
  58. return str(message.author.id)
  59. if self.field == 'author.id':
  60. return str(message.author.id)
  61. if self.field == 'author.joinage':
  62. return message.created_at - message.author.joined_at
  63. if self.field == 'author.name':
  64. return message.author.name
  65. else:
  66. raise ValueError(f'Bad field name {self.field}')
  67. def matches(self, message: Message) -> bool:
  68. field_value = self.__field_value(message)
  69. if self.operator == '==':
  70. if isinstance(field_value, str) and isinstance(self.value, str):
  71. return field_value.lower() == self.value.lower()
  72. return field_value == self.value
  73. if self.operator == '!=':
  74. if isinstance(field_value, str) and isinstance(self.value, str):
  75. return field_value.lower() != self.value.lower()
  76. return field_value != self.value
  77. if self.operator == '<':
  78. return field_value < self.value
  79. if self.operator == '>':
  80. return field_value > self.value
  81. if self.operator == '<=':
  82. return field_value <= self.value
  83. if self.operator == '>=':
  84. return field_value >= self.value
  85. if self.operator == 'contains':
  86. return self.value.lower() in field_value.lower()
  87. if self.operator == '!contains':
  88. return self.value.lower() not in field_value.lower()
  89. if self.operator in ('matches', 'containsword'):
  90. return self.value.search(field_value.lower()) is not None
  91. if self.operator in ('!matches', '!containsword'):
  92. return self.value.search(field_value.lower()) is None
  93. raise ValueError(f'Bad operator {self.operator}')
  94. def __str__(self) -> str:
  95. return f'({self.field} {self.operator} {self.value})'
  96. class PatternCompoundExpression(PatternExpression):
  97. """
  98. Message matching expression that combines several child expressions with
  99. a boolean operator.
  100. """
  101. def __init__(self, operator: str, operands: list[PatternExpression]):
  102. super().__init__()
  103. self.operator = operator
  104. self.operands = list(operands)
  105. def matches(self, message: Message) -> bool:
  106. if self.operator == '!':
  107. return not self.operands[0].matches(message)
  108. if self.operator == 'and':
  109. for op in self.operands:
  110. if not op.matches(message):
  111. return False
  112. return True
  113. if self.operator == 'or':
  114. for op in self.operands:
  115. if op.matches(message):
  116. return True
  117. return False
  118. raise ValueError(f'Bad operator "{self.operator}"')
  119. def __str__(self) -> str:
  120. if self.operator == '!':
  121. return f'(!( {self.operands[0]} ))'
  122. strs = map(str, self.operands)
  123. joined = f' {self.operator} '.join(strs)
  124. return f'( {joined} )'
  125. class PatternStatement:
  126. """
  127. A full message match statement. If a message matches the given expression,
  128. the given actions should be performed.
  129. """
  130. DEFAULT_PRIORITY = 100
  131. def __init__(self,
  132. name: str,
  133. actions: list[PatternAction],
  134. expression: PatternExpression,
  135. original: str,
  136. priority: int = DEFAULT_PRIORITY):
  137. self.name = name
  138. self.actions = list(actions) # PatternAction[]
  139. self.expression = expression
  140. self.original = original
  141. self.priority = priority
  142. def check_deprecations(self) -> None:
  143. """
  144. Tests whether this statement uses any deprecated syntax. Will raise a
  145. PatternDeprecationError if one is found.
  146. """
  147. self.__check_deprecations(self.expression)
  148. @classmethod
  149. def __check_deprecations(cls, expression: PatternExpression) -> None:
  150. if isinstance(expression, PatternSimpleExpression):
  151. s: PatternSimpleExpression = expression
  152. if s.field in PatternCompiler.DEPRECATED_FIELDS:
  153. raise PatternDeprecationError(f'"{s.field}" field is deprecated')
  154. elif isinstance(expression, PatternCompoundExpression):
  155. c: PatternCompoundExpression = expression
  156. for oper in c.operands:
  157. cls.__check_deprecations(oper)
  158. def to_json(self) -> dict:
  159. """
  160. Returns a JSON representation of this statement.
  161. """
  162. return {
  163. 'name': self.name,
  164. 'priority': self.priority,
  165. 'statement': self.original,
  166. }
  167. @classmethod
  168. def from_json(cls, json: dict):
  169. """
  170. Gets a PatternStatement from its JSON representation.
  171. """
  172. ps = PatternCompiler.parse_statement(json['name'], json['statement'])
  173. ps.priority = json.get('priority', cls.DEFAULT_PRIORITY)
  174. return ps
  175. class PatternCompiler:
  176. """
  177. Parses a user-provided message filter statement into a PatternStatement.
  178. """
  179. TYPE_FLOAT = 'float'
  180. TYPE_ID = 'id'
  181. TYPE_INT = 'int'
  182. TYPE_MEMBER = 'Member'
  183. TYPE_REGEX = 'regex'
  184. TYPE_TEXT = 'text'
  185. TYPE_TIMESPAN = 'timespan'
  186. FIELD_TO_TYPE: dict[str, str] = {
  187. 'content.plain': TYPE_TEXT,
  188. 'content.markdown': TYPE_TEXT,
  189. 'author': TYPE_MEMBER,
  190. 'author.id': TYPE_ID,
  191. 'author.name': TYPE_TEXT,
  192. 'author.joinage': TYPE_TIMESPAN,
  193. 'content': TYPE_TEXT, # deprecated, use content.markdown or content.plain
  194. }
  195. DEPRECATED_FIELDS: set[str] = set([ 'content' ])
  196. ACTION_TO_ARGS: dict[str, list[str]] = {
  197. 'ban': [],
  198. 'delete': [],
  199. 'kick': [],
  200. 'modinfo': [],
  201. 'modwarn': [],
  202. 'reply': [ TYPE_TEXT ],
  203. }
  204. OPERATORS_IDENTITY: set[str] = set([ '==', '!=' ])
  205. OPERATORS_COMPARISON: set[str] = set([ '<', '>', '<=', '>=' ])
  206. OPERATORS_NUMERIC = OPERATORS_IDENTITY | OPERATORS_COMPARISON
  207. OPERATORS_TEXT = OPERATORS_IDENTITY | set([
  208. 'contains', '!contains',
  209. 'containsword', '!containsword',
  210. 'matches', '!matches',
  211. ])
  212. OPERATORS_ALL = OPERATORS_IDENTITY | OPERATORS_COMPARISON | OPERATORS_TEXT
  213. TYPE_TO_OPERATORS: dict[str, set[str]] = {
  214. TYPE_ID: OPERATORS_IDENTITY,
  215. TYPE_MEMBER: OPERATORS_IDENTITY,
  216. TYPE_TEXT: OPERATORS_TEXT,
  217. TYPE_INT: OPERATORS_NUMERIC,
  218. TYPE_FLOAT: OPERATORS_NUMERIC,
  219. TYPE_TIMESPAN: OPERATORS_NUMERIC,
  220. }
  221. WHITESPACE_CHARS = ' \t\n\r'
  222. STRING_QUOTE_CHARS = '\'"'
  223. SYMBOL_CHARS = 'abcdefghijklmnopqrstuvwxyz.'
  224. VALUE_CHARS = '0123456789dhms<@!>'
  225. OP_CHARS = '<=>!(),'
  226. MAX_EXPRESSION_NESTING = 8
  227. @classmethod
  228. def expression_str_from_context(cls, context: Context, name: str) -> str:
  229. """
  230. Extracts the statement string from an "add" command context.
  231. """
  232. pattern_str = context.message.content
  233. command_chain = [ name ]
  234. cmd = context.command
  235. while cmd:
  236. command_chain.insert(0, cmd.name)
  237. cmd = cmd.parent
  238. command_chain[0] = f'{context.prefix}{command_chain[0]}'
  239. for cmd in command_chain:
  240. if pattern_str.startswith(cmd):
  241. pattern_str = pattern_str[len(cmd):].lstrip()
  242. elif pattern_str.startswith(f'"{cmd}"'):
  243. pattern_str = pattern_str[len(cmd) + 2:].lstrip()
  244. return pattern_str
  245. @classmethod
  246. def parse_statement(cls, name: str, statement: str) -> PatternStatement:
  247. """
  248. Parses a user-provided message filter statement into a PatternStatement.
  249. Raises PatternError on failure.
  250. """
  251. tokens = cls.__tokenize(statement)
  252. token_index = 0
  253. actions, token_index = cls.__read_actions(tokens, token_index)
  254. expression, token_index = cls.__read_expression(tokens, token_index)
  255. return PatternStatement(name, actions, expression, statement)
  256. @classmethod
  257. def __tokenize(cls, statement: str) -> list[str]:
  258. """
  259. Converts a message filter statement into a list of tokens.
  260. """
  261. tokens: list[str] = []
  262. in_quote = False
  263. in_escape = False
  264. all_token_types = set([ 'sym', 'op', 'val' ])
  265. possible_token_types = set(all_token_types)
  266. current_token = ''
  267. for ch in statement:
  268. if in_quote:
  269. if in_escape:
  270. if ch == 'n':
  271. current_token += '\n'
  272. elif ch == 't':
  273. current_token += '\t'
  274. else:
  275. current_token += ch
  276. in_escape = False
  277. elif ch == '\\':
  278. in_escape = True
  279. elif ch == in_quote:
  280. current_token += ch
  281. tokens.append(current_token)
  282. current_token = ''
  283. possible_token_types |= all_token_types
  284. in_quote = False
  285. else:
  286. current_token += ch
  287. else:
  288. if ch in cls.STRING_QUOTE_CHARS:
  289. if len(current_token) > 0:
  290. tokens.append(current_token)
  291. current_token = ''
  292. possible_token_types |= all_token_types
  293. in_quote = ch
  294. current_token = ch
  295. elif ch == '\\':
  296. raise PatternError("Unexpected \\ outside quoted string")
  297. elif ch in cls.WHITESPACE_CHARS:
  298. if len(current_token) > 0:
  299. tokens.append(current_token)
  300. current_token = ''
  301. possible_token_types |= all_token_types
  302. else:
  303. possible_ch_types = set()
  304. if ch in cls.SYMBOL_CHARS:
  305. possible_ch_types.add('sym')
  306. if ch in cls.VALUE_CHARS:
  307. possible_ch_types.add('val')
  308. if ch in cls.OP_CHARS:
  309. possible_ch_types.add('op')
  310. if len(current_token) > 0 and \
  311. possible_ch_types.isdisjoint(possible_token_types):
  312. if len(current_token) > 0:
  313. tokens.append(current_token)
  314. current_token = ''
  315. possible_token_types |= all_token_types
  316. possible_token_types &= possible_ch_types
  317. current_token += ch
  318. if len(current_token) > 0:
  319. tokens.append(current_token)
  320. # Some symbols might be glommed onto other tokens. Split 'em up.
  321. prefixes_to_split = [ '!', '(', ',' ]
  322. suffixes_to_split = [ ')', ',' ]
  323. i = 0
  324. while i < len(tokens):
  325. token = tokens[i]
  326. mutated = False
  327. for prefix in prefixes_to_split:
  328. if token.startswith(prefix) and len(token) > len(prefix):
  329. tokens.insert(i, prefix)
  330. tokens[i + 1] = token[len(prefix):]
  331. i += 1
  332. mutated = True
  333. break
  334. if mutated:
  335. continue
  336. for suffix in suffixes_to_split:
  337. if token.endswith(suffix) and len(token) > len(suffix):
  338. tokens[i] = token[0:-len(suffix)]
  339. tokens.insert(i + 1, suffix)
  340. mutated = True
  341. break
  342. if mutated:
  343. continue
  344. i += 1
  345. return tokens
  346. @classmethod
  347. def __read_actions(cls,
  348. tokens: list[str],
  349. token_index: int) -> tuple[list[PatternAction], int]:
  350. """
  351. Reads the actions from a list of statement tokens. Returns a tuple
  352. containing a list of PatternActions and the token index this method
  353. left off at (the token after the "if").
  354. """
  355. actions: list[PatternAction] = []
  356. current_action_tokens = []
  357. while token_index < len(tokens):
  358. token = tokens[token_index]
  359. if token == 'if':
  360. if len(current_action_tokens) > 0:
  361. a = PatternAction(current_action_tokens[0], \
  362. current_action_tokens[1:])
  363. cls.__validate_action(a)
  364. actions.append(a)
  365. token_index += 1
  366. return (actions, token_index)
  367. elif token == ',':
  368. if len(current_action_tokens) < 1:
  369. raise PatternError('Unexpected ,')
  370. a = PatternAction(current_action_tokens[0], \
  371. current_action_tokens[1:])
  372. cls.__validate_action(a)
  373. actions.append(a)
  374. current_action_tokens = []
  375. else:
  376. current_action_tokens.append(token)
  377. token_index += 1
  378. raise PatternError('Unexpected end of line in action list')
  379. @classmethod
  380. def __validate_action(cls, action: PatternAction) -> None:
  381. args: list[str] = cls.ACTION_TO_ARGS.get(action.action)
  382. if args is None:
  383. raise PatternError(f'Unknown action "{action.action}"')
  384. if len(action.arguments) != len(args):
  385. if len(args) == 0:
  386. raise PatternError(f'Action "{action.action}" expects no ' + \
  387. f'arguments, got {len(action.arguments)}.')
  388. raise PatternError(f'Action "{action.action}" expects ' + \
  389. f'{len(args)} arguments, got {len(action.arguments)}.')
  390. for i, datatype in enumerate(args):
  391. action.arguments[i] = cls.__parse_value(action.arguments[i], datatype)
  392. @classmethod
  393. def __read_expression(cls,
  394. tokens: list[str],
  395. token_index: int,
  396. depth: int = 0,
  397. one_subexpression: bool = False) -> tuple[PatternExpression, int]:
  398. """
  399. Reads an expression from a list of statement tokens. Returns a tuple
  400. containing the PatternExpression and the token index it left off at.
  401. If one_subexpression is True then it will return after reading a
  402. single expression instead of joining multiples (for reading the
  403. subject of a NOT expression).
  404. """
  405. subexpressions = []
  406. last_compound_operator = None
  407. while token_index < len(tokens):
  408. if one_subexpression:
  409. if len(subexpressions) == 1:
  410. return (subexpressions[0], token_index)
  411. if len(subexpressions) > 1:
  412. raise PatternError('Too many subexpressions')
  413. compound_operator = None
  414. if tokens[token_index] == ')':
  415. if len(subexpressions) == 0:
  416. raise PatternError('No subexpressions')
  417. if len(subexpressions) == 1:
  418. return (subexpressions[0], token_index)
  419. return (PatternCompoundExpression(last_compound_operator,
  420. subexpressions), token_index)
  421. if tokens[token_index] in set(["and", "or"]):
  422. compound_operator = tokens[token_index]
  423. if last_compound_operator and \
  424. compound_operator != last_compound_operator:
  425. subexpressions = [
  426. PatternCompoundExpression(last_compound_operator,
  427. subexpressions),
  428. ]
  429. last_compound_operator = compound_operator
  430. token_index += 1
  431. if tokens[token_index] == '!':
  432. (exp, next_index) = cls.__read_expression(tokens, \
  433. token_index + 1, depth + 1, one_subexpression=True)
  434. subexpressions.append(PatternCompoundExpression('!', [exp]))
  435. token_index = next_index
  436. elif tokens[token_index] == '(':
  437. (exp, next_index) = cls.__read_expression(tokens,
  438. token_index + 1, depth + 1)
  439. if tokens[next_index] != ')':
  440. raise PatternError('Expected )')
  441. subexpressions.append(exp)
  442. token_index = next_index + 1
  443. else:
  444. (simple, next_index) = cls.__read_simple_expression(tokens,
  445. token_index, depth)
  446. subexpressions.append(simple)
  447. token_index = next_index
  448. if len(subexpressions) == 0:
  449. raise PatternError('No subexpressions')
  450. elif len(subexpressions) == 1:
  451. return (subexpressions[0], token_index)
  452. else:
  453. return (PatternCompoundExpression(last_compound_operator,
  454. subexpressions), token_index)
  455. @classmethod
  456. def __read_simple_expression(cls,
  457. tokens: list[str],
  458. token_index: int,
  459. depth: int = 0) -> tuple[PatternExpression, int]:
  460. """
  461. Reads a simple expression consisting of a field name, operator, and
  462. comparison value. Returns a tuple of the PatternSimpleExpression and
  463. the token index it left off at.
  464. """
  465. if depth > cls.MAX_EXPRESSION_NESTING:
  466. raise PatternError('Expression nests too deeply')
  467. if token_index >= len(tokens):
  468. raise PatternError('Expected field name, found EOL')
  469. field = tokens[token_index]
  470. token_index += 1
  471. datatype = cls.FIELD_TO_TYPE.get(field)
  472. if datatype is None:
  473. raise PatternError(f'No such field "{field}"')
  474. if token_index >= len(tokens):
  475. raise PatternError('Expected operator, found EOL')
  476. op = tokens[token_index]
  477. token_index += 1
  478. if op == '!':
  479. if token_index >= len(tokens):
  480. raise PatternError('Expected operator, found EOL')
  481. op = '!' + tokens[token_index]
  482. token_index += 1
  483. allowed_ops = cls.TYPE_TO_OPERATORS[datatype]
  484. if op not in allowed_ops:
  485. if op in cls.OPERATORS_ALL:
  486. raise PatternError(f'Operator {op} cannot be used with ' + \
  487. f'field "{field}"')
  488. raise PatternError(f'Unrecognized operator "{op}" - allowed: ' + \
  489. f'{sorted(list(allowed_ops))}')
  490. if token_index >= len(tokens):
  491. raise PatternError('Expected value, found EOL')
  492. value_str = tokens[token_index]
  493. try:
  494. value = cls.__parse_value(value_str, datatype, op)
  495. except ValueError as cause:
  496. raise PatternError(f'Bad value {value_str}') from cause
  497. token_index += 1
  498. exp = PatternSimpleExpression(field, op, value)
  499. return (exp, token_index)
  500. @classmethod
  501. def __parse_value(cls, value: str, datatype: str, op: str = None) -> Any:
  502. """
  503. Converts a value token to its Python value. Raises ValueError on failure.
  504. """
  505. if datatype == cls.TYPE_ID:
  506. if not is_user_id(value):
  507. raise ValueError(f'Illegal user id value: {value}')
  508. return value
  509. if datatype == cls.TYPE_MEMBER:
  510. return user_id_from_mention(value)
  511. if datatype == cls.TYPE_TEXT:
  512. s = str_from_quoted_str(value)
  513. if op in ('matches', '!matches'):
  514. try:
  515. return re.compile(s.lower())
  516. except re.error as e:
  517. raise ValueError(f'Invalid regex: {e}') from e
  518. if op in ('containsword', '!containsword'):
  519. try:
  520. return re.compile(f'\\b{re.escape(s.lower())}\\b')
  521. except re.error as e:
  522. raise ValueError(f'Invalid regex: {e}') from e
  523. return s
  524. if datatype == cls.TYPE_INT:
  525. return int(value)
  526. if datatype == cls.TYPE_FLOAT:
  527. return float(value)
  528. if datatype == cls.TYPE_TIMESPAN:
  529. return timedelta_from_str(value)
  530. raise ValueError(f'Unhandled datatype {datatype}')