Experimental Discord bot written in Python
Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

patterncog.py 23KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714
  1. """
  2. Cog for matching messages against guild-configurable criteria and taking
  3. automated actions on them.
  4. """
  5. import re
  6. from abc import ABCMeta, abstractmethod
  7. from discord import Guild, Member, Message, utils as discordutils
  8. from discord.ext import commands
  9. from config import CONFIG
  10. from rocketbot.cogs.basecog import BaseCog, BotMessage, BotMessageReaction
  11. from rocketbot.cogsetting import CogSetting
  12. from rocketbot.storage import Storage
  13. from rocketbot.utils import is_user_id, str_from_quoted_str, timedelta_from_str, \
  14. user_id_from_mention
  class PatternAction:
      """
      Describes one action to take on a matched message or its author.

      `action` is the action keyword (e.g. 'delete', 'reply') and `arguments`
      holds its arguments, which PatternCompiler may replace with parsed values.
      """
      def __init__(self, action: str, args: list):
          self.action = action
          # Copy so later mutation of the caller's list does not affect us.
          self.arguments = list(args)

      def __str__(self) -> str:
          # Arguments may be parsed to non-str values (ints, floats, timedeltas),
          # so stringify each one rather than assuming they are already str.
          arg_str = ', '.join(map(str, self.arguments))
          return f'{self.action}({arg_str})'
  class PatternExpression(metaclass=ABCMeta):
      """
      Abstract base for a message matching expression. Concrete subclasses
      must implement matches(); the class cannot be instantiated directly.
      """
      def __init__(self):
          """Base expressions carry no state of their own."""

      @abstractmethod
      def matches(self, message: Message) -> bool:
          """
          Reports whether the given message satisfies this expression.
          The base implementation never matches.
          """
          return False
  class PatternSimpleExpression(PatternExpression):
      """
      Message matching expression with a simple "<field> <operator> <value>"
      structure.

      String comparisons (==, !=, contains, !contains, matches, !matches) are
      case-insensitive. matches/!matches use re.match, so the regex is anchored
      at the start of the field value but not at its end.
      """
      def __init__(self, field: str, operator: str, value):
          super().__init__()
          self.field = field
          self.operator = operator
          self.value = value

      def __field_value(self, message: Message):
          """
          Extracts the value of self.field from the message.
          Raises ValueError for an unknown field name.
          """
          if self.field in ('content.markdown', 'content'):
              return message.content
          if self.field == 'content.plain':
              return discordutils.remove_markdown(message.clean_content)
          if self.field in ('author', 'author.id'):
              return str(message.author.id)
          if self.field == 'author.joinage':
              # NOTE(review): assumes author.joined_at is set (a guild Member);
              # confirm behavior when it is None.
              return message.created_at - message.author.joined_at
          if self.field == 'author.name':
              return message.author.name
          raise ValueError(f'Bad field name {self.field}')

      def matches(self, message: Message) -> bool:
          """
          Whether the message's field value satisfies the operator and value.
          Raises ValueError for an unknown field or operator.
          """
          field_value = self.__field_value(message)
          if self.operator == '==':
              if isinstance(field_value, str) and isinstance(self.value, str):
                  return field_value.lower() == self.value.lower()
              return field_value == self.value
          if self.operator == '!=':
              if isinstance(field_value, str) and isinstance(self.value, str):
                  return field_value.lower() != self.value.lower()
              return field_value != self.value
          if self.operator == '<':
              return field_value < self.value
          if self.operator == '>':
              return field_value > self.value
          if self.operator == '<=':
              return field_value <= self.value
          if self.operator == '>=':
              return field_value >= self.value
          if self.operator == 'contains':
              return self.value.lower() in field_value.lower()
          if self.operator == '!contains':
              return self.value.lower() not in field_value.lower()
          # Use IGNORECASE rather than lowercasing the pattern: lowercasing
          # corrupts escapes such as \D, \S, \W and \B (and breaks \Z).
          # The re module caches compiled patterns, so this is not recompiled
          # on every call.
          if self.operator == 'matches':
              return re.match(self.value, field_value, re.IGNORECASE) is not None
          if self.operator == '!matches':
              return re.match(self.value, field_value, re.IGNORECASE) is None
          raise ValueError(f'Bad operator {self.operator}')

      def __str__(self) -> str:
          return f'({self.field} {self.operator} {self.value})'
  class PatternCompoundExpression(PatternExpression):
      """
      Expression that combines child expressions with a boolean operator:
      '!' negates its first operand, 'and' requires every operand to match,
      'or' requires at least one. Evaluation short-circuits.
      """
      def __init__(self, operator: str, operands: list):
          super().__init__()
          self.operator = operator
          self.operands = list(operands)

      def matches(self, message: Message) -> bool:
          """
          Evaluates the children against the message. Raises ValueError for an
          unrecognized operator.
          """
          op = self.operator
          children = self.operands
          if op == '!':
              return not children[0].matches(message)
          if op == 'and':
              return all(child.matches(message) for child in children)
          if op == 'or':
              return any(child.matches(message) for child in children)
          raise ValueError(f'Bad operator "{self.operator}"')

      def __str__(self) -> str:
          if self.operator == '!':
              return f'(!( {self.operands[0]} ))'
          separator = f' {self.operator} '
          inner = separator.join(str(child) for child in self.operands)
          return f'( {inner} )'
  class PatternStatement:
      """
      A complete pattern statement: when a message matches `expression`, each
      of `actions` should be carried out. `original` keeps the source text,
      which is what gets persisted by to_json().
      """
      def __init__(self,
              name: str,
              actions: list,
              expression: PatternExpression,
              original: str):
          self.name = name
          self.actions = [*actions]  # list of PatternAction (copied)
          self.expression = expression
          self.original = original

      def to_json(self) -> dict:
          """
          Serializes this statement as a dict holding its name and original
          statement text.
          """
          return dict(name=self.name, statement=self.original)

      @classmethod
      def from_json(cls, json: dict):
          """
          Rebuilds a PatternStatement by re-parsing the dict produced by
          to_json(). Parse failures propagate from PatternCompiler.
          """
          pattern_name = json['name']
          source_text = json['statement']
          return PatternCompiler.parse_statement(pattern_name, source_text)
  class PatternContext:
      """
      Bookkeeping for a message that matched a configured statement: the
      message, the statement it matched, and which moderation actions have
      been applied to it so far.
      """
      def __init__(self, message: Message, statement: PatternStatement):
          self.message, self.statement = message, statement
          # Flags flipped as delete/kick/ban actions are carried out.
          self.is_deleted = self.is_kicked = self.is_banned = False
  class PatternCog(BaseCog, name='Pattern Matching'):
      """
      Highly flexible cog for performing various actions on messages that match
      various criteria. Patterns can be defined by mods for each guild.
      """
      SETTING_PATTERNS = CogSetting('patterns', None)

      def __get_patterns(self, guild: Guild) -> dict[str, PatternStatement]:
          """
          Returns a name -> PatternStatement lookup for the guild. On first use
          the patterns are decoded from the guild setting and cached in the
          guild's state under 'PatternCog.patterns'. Patterns that fail to
          parse are logged and skipped.
          """
          patterns: dict[str, PatternStatement] = Storage.get_state_value(guild,
              'PatternCog.patterns')
          if patterns is None:
              # The setting is declared with None and may be unset for a guild
              # that has never saved a pattern; treat that as an empty list
              # instead of failing to iterate None.
              jsons: list[dict] = self.get_guild_setting(guild, self.SETTING_PATTERNS) or []
              patterns = {}
              for json in jsons:
                  try:
                      statement = PatternStatement.from_json(json)
                  except PatternError as e:
                      self.log(guild, f'Error decoding pattern "{json["name"]}": {e}')
                      continue
                  patterns[statement.name] = statement
              Storage.set_state_value(guild, 'PatternCog.patterns', patterns)
          return patterns

      @classmethod
      def __save_patterns(cls, guild: Guild, patterns: dict[str, PatternStatement]) -> None:
          """Persists the guild's patterns as a list of their JSON forms."""
          to_save: list[dict] = list(map(PatternStatement.to_json, patterns.values()))
          cls.set_guild_setting(guild, cls.SETTING_PATTERNS, to_save)

      @commands.Cog.listener()
      async def on_message(self, message: Message) -> None:
          'Event listener'
          # Skip bots, non-guild messages, empty messages, and anyone who can
          # ban members in this channel (mods).
          if message.author is None or \
                  message.author.bot or \
                  message.channel is None or \
                  message.guild is None or \
                  message.content is None or \
                  message.content == '':
              return
          if message.author.permissions_in(message.channel).ban_members:
              # Ignore mods
              return
          patterns = self.__get_patterns(message.guild)
          for statement in patterns.values():
              if statement.expression.matches(message):
                  # Only the first matching pattern is acted upon.
                  await self.__trigger_actions(message, statement)
                  break

      async def __trigger_actions(self, message: Message, statement: PatternStatement) -> None:
          """
          Performs the statement's actions in order, then posts a summary bot
          message: a mod warning if 'modwarn' was among the actions, otherwise
          an info message.
          """
          context = PatternContext(message, statement)
          should_alert_mods = False
          action_descriptions = []
          self.log(message.guild, f'Message from {message.author.name} matched pattern "{statement.name}"')
          for action in statement.actions:
              if action.action == 'ban':
                  await message.author.ban(
                      reason=f'Rocketbot: Message matched custom pattern named "{statement.name}"',
                      delete_message_days=0)
                  context.is_banned = True
                  # A banned author is also out of the guild, so kick is moot.
                  context.is_kicked = True
                  action_descriptions.append('Author banned')
                  self.log(message.guild, f'{message.author.name} banned')
              elif action.action == 'delete':
                  await message.delete()
                  context.is_deleted = True
                  action_descriptions.append('Message deleted')
                  self.log(message.guild, f'{message.author.name}\'s message deleted')
              elif action.action == 'kick':
                  await message.author.kick(
                      reason=f'Rocketbot: Message matched custom pattern named "{statement.name}"')
                  context.is_kicked = True
                  action_descriptions.append('Author kicked')
                  self.log(message.guild, f'{message.author.name} kicked')
              elif action.action == 'modwarn':
                  should_alert_mods = True
                  action_descriptions.append('Mods alerted')
              elif action.action == 'reply':
                  await message.reply(
                      f'{action.arguments[0]}',
                      mention_author=False)
                  action_descriptions.append('Autoreplied')
                  self.log(message.guild, f'autoreplied to {message.author.name}')
          bm = BotMessage(
              message.guild,
              f'User {message.author.name} tripped custom pattern ' + \
              f'`{statement.name}`.\n\nAutomatic actions taken:\n• ' + \
              ('\n• '.join(action_descriptions)),
              type=BotMessage.TYPE_MOD_WARNING if should_alert_mods else BotMessage.TYPE_INFO,
              context=context)
          bm.quote = discordutils.remove_markdown(message.clean_content)
          await bm.set_reactions(BotMessageReaction.standard_set(
              did_delete=context.is_deleted,
              did_kick=context.is_kicked,
              did_ban=context.is_banned))
          await self.post_message(bm)

      async def on_mod_react(self,
              bot_message: BotMessage,
              reaction: BotMessageReaction,
              reacted_by: Member) -> None:
          """
          Applies a mod's reaction on a posted summary: trash deletes the
          message, kick kicks the author, ban bans the author (deleting one day
          of messages). The reaction set is then refreshed.
          """
          context: PatternContext = bot_message.context
          if reaction.emoji == CONFIG['trash_emoji']:
              await context.message.delete()
              context.is_deleted = True
          elif reaction.emoji == CONFIG['kick_emoji']:
              await context.message.author.kick(
                  reason='Rocketbot: Message matched custom pattern named ' + \
                      f'"{context.statement.name}". Kicked by {reacted_by.name}.')
              context.is_kicked = True
          elif reaction.emoji == CONFIG['ban_emoji']:
              await context.message.author.ban(
                  reason='Rocketbot: Message matched custom pattern named ' + \
                      f'"{context.statement.name}". Banned by {reacted_by.name}.',
                  delete_message_days=1)
              context.is_banned = True
          await bot_message.set_reactions(BotMessageReaction.standard_set(
              did_delete=context.is_deleted,
              did_kick=context.is_kicked,
              did_ban=context.is_banned))

      @commands.group(
          brief='Manages message pattern matching',
      )
      @commands.has_permissions(ban_members=True)
      @commands.guild_only()
      async def pattern(self, context: commands.Context):
          'Message pattern matching command group'
          if context.invoked_subcommand is None:
              await context.send_help()

      @pattern.command(
          brief='Adds a custom pattern',
          description='Adds a custom pattern. Patterns use a simplified ' + \
              'expression language. Full documentation found here: ' + \
              'https://git.rixafrix.com/ialbert/python-app-rocketbot/src/branch/master/patterns.md',
          usage='<pattern_name> <expression...>',
          ignore_extra=True
      )
      async def add(self, context: commands.Context, name: str):
          'Command handler'
          # The statement is taken from the raw message text rather than parsed
          # arguments so quoting and spacing are preserved.
          pattern_str = PatternCompiler.expression_str_from_context(context, name)
          try:
              statement = PatternCompiler.parse_statement(name, pattern_str)
              patterns = self.__get_patterns(context.guild)
              patterns[name] = statement
              self.__save_patterns(context.guild, patterns)
              await context.message.reply(
                  f'{CONFIG["success_emoji"]} Pattern `{name}` added.',
                  mention_author=False)
          except PatternError as e:
              await context.message.reply(
                  f'{CONFIG["failure_emoji"]} Error parsing statement. {e}',
                  mention_author=False)

      @pattern.command(
          brief='Removes a custom pattern',
          usage='<pattern_name>'
      )
      async def remove(self, context: commands.Context, name: str):
          'Command handler'
          patterns = self.__get_patterns(context.guild)
          if name in patterns:
              del patterns[name]
              self.__save_patterns(context.guild, patterns)
              await context.message.reply(
                  f'{CONFIG["success_emoji"]} Pattern `{name}` deleted.',
                  mention_author=False)
          else:
              await context.message.reply(
                  f'{CONFIG["failure_emoji"]} No pattern named `{name}`.',
                  mention_author=False)

      @pattern.command(
          brief='Lists all patterns'
      )
      async def list(self, context: commands.Context) -> None:
          'Command handler'
          patterns = self.__get_patterns(context.guild)
          if len(patterns) == 0:
              await context.message.reply('No patterns defined.', mention_author=False)
              return
          # Discord rejects messages longer than 2000 characters, so spread the
          # listing over several replies when needed.
          max_length = 2000
          msg = ''
          for name, statement in sorted(patterns.items()):
              entry = f'Pattern `{name}`:\n```\n{statement.original}\n```\n'
              if msg and len(msg) + len(entry) > max_length:
                  await context.message.reply(msg, mention_author=False)
                  msg = ''
              msg += entry
          await context.message.reply(msg, mention_author=False)
  class PatternError(RuntimeError):
      """
      Raised when a pattern statement cannot be tokenized, parsed, or
      validated.
      """
  class PatternCompiler:
      """
      Parses a user-provided message filter statement into a PatternStatement.

      Statement shape: "<action>[ <args>][, <action>...] if <expression>".
      Parse problems are reported as PatternError.
      """
      TYPE_ID = 'id'
      TYPE_MEMBER = 'Member'
      TYPE_TEXT = 'text'
      TYPE_INT = 'int'
      TYPE_FLOAT = 'float'
      TYPE_TIMESPAN = 'timespan'

      FIELD_TO_TYPE = {
          'content.plain': TYPE_TEXT,
          'content.markdown': TYPE_TEXT,
          'author': TYPE_MEMBER,
          'author.id': TYPE_ID,
          'author.name': TYPE_TEXT,
          'author.joinage': TYPE_TIMESPAN,
          'content': TYPE_TEXT, # deprecated, use content.markdown or content.plain
      }

      ACTION_TO_ARGS = {
          'ban': [],
          'delete': [],
          'kick': [],
          'modwarn': [],
          'reply': [ TYPE_TEXT ],
      }

      OPERATORS_IDENTITY = set([ '==', '!=' ])
      OPERATORS_COMPARISON = set([ '<', '>', '<=', '>=' ])
      OPERATORS_NUMERIC = OPERATORS_IDENTITY | OPERATORS_COMPARISON
      OPERATORS_TEXT = OPERATORS_IDENTITY | set([ 'contains', '!contains', 'matches', '!matches' ])
      OPERATORS_ALL = OPERATORS_IDENTITY | OPERATORS_COMPARISON | OPERATORS_TEXT

      TYPE_TO_OPERATORS = {
          TYPE_ID: OPERATORS_IDENTITY,
          TYPE_MEMBER: OPERATORS_IDENTITY,
          TYPE_TEXT: OPERATORS_TEXT,
          TYPE_INT: OPERATORS_NUMERIC,
          TYPE_FLOAT: OPERATORS_NUMERIC,
          TYPE_TIMESPAN: OPERATORS_NUMERIC,
      }

      WHITESPACE_CHARS = ' \t\n\r'
      STRING_QUOTE_CHARS = '\'"'
      SYMBOL_CHARS = 'abcdefghijklmnopqrstuvwxyz.'
      VALUE_CHARS = '0123456789dhms<@!>'
      OP_CHARS = '<=>!(),'

      # Deepest allowed nesting of parenthesized / negated subexpressions.
      MAX_DEPTH = 8

      @classmethod
      def expression_str_from_context(cls, context: commands.Context, name: str) -> str:
          """
          Extracts the statement string from an "add" command context by
          stripping the prefixed command chain and the pattern name (optionally
          double-quoted) from the front of the raw message text.
          """
          pattern_str = context.message.content
          command_chain = [ name ]
          cmd = context.command
          while cmd:
              command_chain.insert(0, cmd.name)
              cmd = cmd.parent
          command_chain[0] = f'{context.prefix}{command_chain[0]}'
          for cmd in command_chain:
              if pattern_str.startswith(cmd):
                  pattern_str = pattern_str[len(cmd):].lstrip()
              elif pattern_str.startswith(f'"{cmd}"'):
                  pattern_str = pattern_str[len(cmd) + 2:].lstrip()
          return pattern_str

      @classmethod
      def parse_statement(cls, name: str, statement: str) -> PatternStatement:
          """
          Parses a user-provided message filter statement into a PatternStatement.
          Raises PatternError if the statement is malformed.
          """
          tokens = cls.tokenize(statement)
          token_index = 0
          actions, token_index = cls.read_actions(tokens, token_index)
          expression, token_index = cls.read_expression(tokens, token_index)
          if token_index < len(tokens):
              # The top-level read_expression only stops early at an unmatched
              # ")"; anything left over would otherwise be silently ignored.
              raise PatternError(f'Unexpected "{tokens[token_index]}"')
          return PatternStatement(name, actions, expression, statement)

      @classmethod
      def tokenize(cls, statement: str) -> list:
          """
          Converts a message filter statement into a list of string tokens.

          A quoted string becomes one token that keeps its surrounding quote
          characters; inside it, a backslash followed by n or t yields a newline
          or tab, and any other escaped character is taken literally. Outside
          quotes, characters are grouped into runs of compatible kinds (symbol,
          value, operator), then leading "!", "(" and "," and trailing ")" and ","
          are split off into their own tokens.

          Raises PatternError for a backslash outside a quoted string or an
          unterminated quoted string.
          """
          tokens = []
          in_quote = False  # False, or the quote character that opened the string
          in_escape = False
          all_token_types = set([ 'sym', 'op', 'val' ])
          possible_token_types = set(all_token_types)
          current_token = ''
          for ch in statement:
              if in_quote:
                  if in_escape:
                      if ch == 'n':
                          current_token += '\n'
                      elif ch == 't':
                          current_token += '\t'
                      else:
                          current_token += ch
                      in_escape = False
                  elif ch == '\\':
                      in_escape = True
                  elif ch == in_quote:
                      current_token += ch
                      tokens.append(current_token)
                      current_token = ''
                      possible_token_types |= all_token_types
                      in_quote = False
                  else:
                      current_token += ch
              else:
                  if ch in cls.STRING_QUOTE_CHARS:
                      if len(current_token) > 0:
                          tokens.append(current_token)
                          current_token = ''
                          possible_token_types |= all_token_types
                      in_quote = ch
                      current_token = ch
                  elif ch == '\\':
                      raise PatternError("Unexpected \\ outside quoted string")
                  elif ch in cls.WHITESPACE_CHARS:
                      if len(current_token) > 0:
                          tokens.append(current_token)
                          current_token = ''
                          possible_token_types |= all_token_types
                  else:
                      possible_ch_types = set()
                      if ch in cls.SYMBOL_CHARS:
                          possible_ch_types.add('sym')
                      if ch in cls.VALUE_CHARS:
                          possible_ch_types.add('val')
                      if ch in cls.OP_CHARS:
                          possible_ch_types.add('op')
                      # Start a new token when this character cannot continue
                      # the kind of token currently being built.
                      if len(current_token) > 0 and possible_ch_types.isdisjoint(possible_token_types):
                          tokens.append(current_token)
                          current_token = ''
                          possible_token_types |= all_token_types
                      possible_token_types &= possible_ch_types
                      current_token += ch
          if in_quote:
              raise PatternError(f'Unterminated quoted string {current_token}')
          if len(current_token) > 0:
              tokens.append(current_token)
          # Some symbols might be glommed onto other tokens. Split 'em up.
          prefixes_to_split = [ '!', '(', ',' ]
          suffixes_to_split = [ ')', ',' ]
          i = 0
          while i < len(tokens):
              token = tokens[i]
              mutated = False
              for prefix in prefixes_to_split:
                  if token.startswith(prefix) and len(token) > len(prefix):
                      tokens.insert(i, prefix)
                      tokens[i + 1] = token[len(prefix):]
                      i += 1
                      mutated = True
                      break
              if mutated:
                  continue
              for suffix in suffixes_to_split:
                  if token.endswith(suffix) and len(token) > len(suffix):
                      tokens[i] = token[0:-len(suffix)]
                      tokens.insert(i + 1, suffix)
                      mutated = True
                      break
              if mutated:
                  continue
              i += 1
          return tokens

      @classmethod
      def read_actions(cls, tokens: list, token_index: int) -> tuple:
          """
          Reads the actions from a list of statement tokens. Returns a tuple
          containing a list of PatternActions and the token index this method
          left off at (the token after the "if"). Raises PatternError if no
          "if" is found or an action is invalid.
          """
          actions = []
          current_action_tokens = []
          while token_index < len(tokens):
              token = tokens[token_index]
              if token == 'if':
                  if len(current_action_tokens) > 0:
                      a = PatternAction(current_action_tokens[0], current_action_tokens[1:])
                      cls.__validate_action(a)
                      actions.append(a)
                  token_index += 1
                  return (actions, token_index)
              elif token == ',':
                  if len(current_action_tokens) < 1:
                      raise PatternError('Unexpected ,')
                  a = PatternAction(current_action_tokens[0], current_action_tokens[1:])
                  cls.__validate_action(a)
                  actions.append(a)
                  current_action_tokens = []
              else:
                  current_action_tokens.append(token)
              token_index += 1
          raise PatternError('Unexpected end of line in action list')

      @classmethod
      def __validate_action(cls, action: PatternAction) -> None:
          """
          Checks the action name and argument count, then converts each
          argument in place to its declared type. Raises PatternError.
          """
          args = cls.ACTION_TO_ARGS.get(action.action)
          if args is None:
              raise PatternError(f'Unknown action "{action.action}"')
          if len(action.arguments) != len(args):
              if len(args) == 0:
                  raise PatternError(f'Action "{action.action}" expects no arguments, ' + \
                      f'got {len(action.arguments)}.')
              else:
                  raise PatternError(f'Action "{action.action}" expects {len(args)} ' + \
                      f'arguments, got {len(action.arguments)}.')
          for i, datatype in enumerate(args):
              try:
                  action.arguments[i] = cls.parse_value(action.arguments[i], datatype)
              except ValueError as cause:
                  # parse_value reports failures as ValueError; surface them as
                  # PatternError like every other parse problem.
                  raise PatternError(
                      f'Bad argument {action.arguments[i]} for action "{action.action}"') from cause

      @classmethod
      def read_expression(cls,
              tokens: list,
              token_index: int,
              depth: int = 0,
              one_subexpression: bool = False) -> tuple:
          """
          Reads an expression from a list of statement tokens. Returns a tuple
          containing the PatternExpression and the token index it left off at.
          Stops (without consuming it) at an unmatched ")".
          If one_subexpression is True then it will return after reading a
          single expression instead of joining multiples (for reading the
          subject of a NOT expression).

          Mixed "and"/"or" are grouped left to right: once the operator
          changes, everything read so far is wrapped into one compound operand.
          Raises PatternError on malformed input.
          """
          if depth > cls.MAX_DEPTH:
              raise PatternError('Expression nests too deeply')
          subexpressions = []
          last_compound_operator = None
          while token_index < len(tokens):
              if one_subexpression:
                  if len(subexpressions) == 1:
                      return (subexpressions[0], token_index)
                  if len(subexpressions) > 1:
                      raise PatternError('Too many subexpressions')
              token = tokens[token_index]
              if token == ')':
                  if len(subexpressions) == 0:
                      raise PatternError('No subexpressions')
                  break
              if token in ('and', 'or'):
                  if last_compound_operator and token != last_compound_operator:
                      subexpressions = [ PatternCompoundExpression(last_compound_operator, subexpressions) ]
                  last_compound_operator = token
                  token_index += 1
                  if token_index >= len(tokens):
                      raise PatternError(f'Expected expression after "{token}", found EOL')
                  token = tokens[token_index]
              if token == '!':
                  (exp, next_index) = cls.read_expression(tokens, token_index + 1, \
                      depth + 1, one_subexpression=True)
                  subexpressions.append(PatternCompoundExpression('!', [exp]))
                  token_index = next_index
              elif token == '(':
                  (exp, next_index) = cls.read_expression(tokens, token_index + 1, depth + 1)
                  if next_index >= len(tokens) or tokens[next_index] != ')':
                      raise PatternError('Expected )')
                  subexpressions.append(exp)
                  token_index = next_index + 1
              else:
                  (simple, next_index) = cls.read_simple_expression(tokens, token_index, depth)
                  subexpressions.append(simple)
                  token_index = next_index
          if len(subexpressions) == 0:
              raise PatternError('No subexpressions')
          if len(subexpressions) == 1:
              return (subexpressions[0], token_index)
          if last_compound_operator is None:
              # Without "and"/"or" the compound would have no operator and
              # raise on every message it is evaluated against.
              raise PatternError('Expected "and" or "or" between expressions')
          return (PatternCompoundExpression(last_compound_operator, subexpressions), token_index)

      @classmethod
      def read_simple_expression(cls, tokens: list, token_index: int, depth: int = 0) -> tuple:
          """
          Reads a simple expression consisting of a field name, operator, and
          comparison value. Returns a tuple of the PatternSimpleExpression and
          the token index it left off at. Raises PatternError.
          """
          if depth > cls.MAX_DEPTH:
              raise PatternError('Expression nests too deeply')
          if token_index >= len(tokens):
              raise PatternError('Expected field name, found EOL')
          field = tokens[token_index]
          token_index += 1

          datatype = cls.FIELD_TO_TYPE.get(field)
          if datatype is None:
              raise PatternError(f'No such field "{field}"')

          if token_index >= len(tokens):
              raise PatternError('Expected operator, found EOL')
          op = tokens[token_index]
          token_index += 1

          if op == '!':
              # The tokenizer splits a leading "!" off "!=", "!contains", etc.
              if token_index >= len(tokens):
                  raise PatternError('Expected operator, found EOL')
              op = '!' + tokens[token_index]
              token_index += 1

          allowed_ops = cls.TYPE_TO_OPERATORS[datatype]
          if op not in allowed_ops:
              if op in cls.OPERATORS_ALL:
                  raise PatternError(f'Operator {op} cannot be used with field "{field}"')
              raise PatternError(f'Unrecognized operator "{op}" - allowed: {list(allowed_ops)}')

          if token_index >= len(tokens):
              raise PatternError('Expected value, found EOL')
          value = tokens[token_index]

          try:
              value = cls.parse_value(value, datatype)
          except ValueError as cause:
              raise PatternError(f'Bad value {value}') from cause

          token_index += 1
          exp = PatternSimpleExpression(field, op, value)
          return (exp, token_index)

      @classmethod
      def parse_value(cls, value: str, datatype: str):
          """
          Converts a value token to its Python value. Raises ValueError on failure.
          """
          if datatype == cls.TYPE_ID:
              if not is_user_id(value):
                  raise ValueError(f'Illegal user id value: {value}')
              return value
          if datatype == cls.TYPE_MEMBER:
              return user_id_from_mention(value)
          if datatype == cls.TYPE_TEXT:
              return str_from_quoted_str(value)
          if datatype == cls.TYPE_INT:
              return int(value)
          if datatype == cls.TYPE_FLOAT:
              return float(value)
          if datatype == cls.TYPE_TIMESPAN:
              return timedelta_from_str(value)
          raise ValueError(f'Unhandled datatype {datatype}')