Experimental Discord bot written in Python
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

patterncog.py 23KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3
  1. """
  2. Cog for matching messages against guild-configurable criteria and taking
  3. automated actions on them.
  4. """
  5. import re
  6. from abc import ABCMeta, abstractmethod
  7. from discord import Guild, Member, Message
  8. from discord.ext import commands
  9. from config import CONFIG
  10. from rocketbot.cogs.basecog import BaseCog, BotMessage, BotMessageReaction
  11. from rocketbot.storage import Storage
  12. from rocketbot.utils import parse_timedelta
  13. class PatternAction:
  14. """
  15. Describes one action to take on a matched message or its author.
  16. """
  17. def __init__(self, action: str, args: list):
  18. self.action = action
  19. self.arguments = list(args)
  20. def __str__(self) -> str:
  21. arg_str = ', '.join(self.arguments)
  22. return f'{self.action}({arg_str})'
  23. class PatternExpression(metaclass=ABCMeta):
  24. """
  25. Abstract message matching expression.
  26. """
  27. def __init__(self):
  28. pass
  29. @abstractmethod
  30. def matches(self, message: Message) -> bool:
  31. """
  32. Whether a message matches this expression.
  33. """
  34. return False
  35. class PatternSimpleExpression(PatternExpression):
  36. """
  37. Message matching expression with a simple "<field> <operator> <value>"
  38. structure.
  39. """
  40. def __init__(self, field: str, operator: str, value):
  41. super().__init__()
  42. self.field = field
  43. self.operator = operator
  44. self.value = value
  45. def __field_value(self, message: Message):
  46. if self.field == 'content':
  47. return message.content
  48. if self.field == 'author':
  49. return str(message.author.id)
  50. if self.field == 'author.id':
  51. return str(message.author.id)
  52. if self.field == 'author.joinage':
  53. return message.created_at - message.author.joined_at
  54. if self.field == 'author.name':
  55. return message.author.name
  56. else:
  57. raise ValueError(f'Bad field name {self.field}')
  58. def matches(self, message: Message) -> bool:
  59. field_value = self.__field_value(message)
  60. if self.operator == '==':
  61. if isinstance(field_value, str) and isinstance(self.value, str):
  62. return field_value.lower() == self.value.lower()
  63. return field_value == self.value
  64. if self.operator == '!=':
  65. if isinstance(field_value, str) and isinstance(self.value, str):
  66. return field_value.lower() != self.value.lower()
  67. return field_value != self.value
  68. if self.operator == '<':
  69. return field_value < self.value
  70. if self.operator == '>':
  71. return field_value > self.value
  72. if self.operator == '<=':
  73. return field_value <= self.value
  74. if self.operator == '>=':
  75. return field_value >= self.value
  76. if self.operator == 'contains':
  77. return self.value.lower() in field_value.lower()
  78. if self.operator == '!contains':
  79. return self.value.lower() not in field_value.lower()
  80. if self.operator == 'matches':
  81. p = re.compile(self.value.lower())
  82. return p.match(field_value.lower()) is not None
  83. if self.operator == '!matches':
  84. p = re.compile(self.value.lower())
  85. return p.match(field_value.lower()) is None
  86. raise ValueError(f'Bad operator {self.operator}')
  87. def __str__(self) -> str:
  88. return f'({self.field} {self.operator} {self.value})'
  89. class PatternCompoundExpression(PatternExpression):
  90. """
  91. Message matching expression that combines several child expressions with
  92. a boolean operator.
  93. """
  94. def __init__(self, operator: str, operands: list):
  95. super().__init__()
  96. self.operator = operator
  97. self.operands = list(operands)
  98. def matches(self, message: Message) -> bool:
  99. if self.operator == '!':
  100. return not self.operands[0].matches(message)
  101. if self.operator == 'and':
  102. for op in self.operands:
  103. if not op.matches(message):
  104. return False
  105. return True
  106. if self.operator == 'or':
  107. for op in self.operands:
  108. if op.matches(message):
  109. return True
  110. return False
  111. raise RuntimeError(f'Bad operator "{self.operator}"')
  112. def __str__(self) -> str:
  113. if self.operator == '!':
  114. return f'(!( {self.operands[0]} ))'
  115. strs = map(str, self.operands)
  116. joined = f' {self.operator} '.join(strs)
  117. return f'( {joined} )'
  118. class PatternStatement:
  119. """
  120. A full message match statement. If a message matches the given expression,
  121. the given actions should be performed.
  122. """
  123. def __init__(self, name: str, actions: list, expression: PatternExpression, original: str):
  124. self.name = name
  125. self.actions = list(actions) # PatternAction[]
  126. self.expression = expression
  127. self.original = original
  128. class PatternContext:
  129. """
  130. Data about a message that has matched a configured statement and what
  131. actions have been carried out.
  132. """
  133. def __init__(self, message: Message, statement: PatternStatement):
  134. self.message = message
  135. self.statement = statement
  136. self.is_deleted = False
  137. self.is_kicked = False
  138. self.is_banned = False
  139. class PatternCog(BaseCog, name='Pattern Matching'):
  140. """
  141. Highly flexible cog for performing various actions on messages that match
  142. various critera. Patterns can be defined by mods for each guild.
  143. """
  144. def __get_patterns(self, guild: Guild) -> dict:
  145. patterns = Storage.get_state_value(guild, 'PatternCog.patterns')
  146. if patterns is None:
  147. patterns = {}
  148. patterns_encoded = Storage.get_config_value(guild, 'PatternCog.patterns')
  149. if patterns_encoded:
  150. for pe in patterns_encoded:
  151. name = pe.get('name')
  152. statement = pe.get('statement')
  153. try:
  154. ps = PatternCompiler.parse_statement(name, statement)
  155. patterns[name] = ps
  156. except PatternError as e:
  157. self.log(guild, 'Error parsing saved statement ' + \
  158. f'"{name}": "{e}" Statement: {statement}')
  159. Storage.set_state_value(guild, 'PatternCog.patterns', patterns)
  160. return patterns
  161. @classmethod
  162. def __save_patterns(cls, guild: Guild, patterns: dict) -> None:
  163. to_save = []
  164. for name, statement in patterns.items():
  165. to_save.append({
  166. 'name': name,
  167. 'statement': statement.original,
  168. })
  169. Storage.set_config_value(guild, 'PatternCog.patterns', to_save)
  170. @commands.Cog.listener()
  171. async def on_message(self, message: Message) -> None:
  172. 'Event listener'
  173. if message.author is None or \
  174. message.author.bot or \
  175. message.channel is None or \
  176. message.guild is None or \
  177. message.content is None or \
  178. message.content == '':
  179. return
  180. if message.author.permissions_in(message.channel).ban_members:
  181. # Ignore mods
  182. return
  183. patterns = self.__get_patterns(message.guild)
  184. for _, statement in patterns.items():
  185. if statement.expression.matches(message):
  186. await self.__trigger_actions(message, statement)
  187. break
  188. async def __trigger_actions(self, message: Message, statement: PatternStatement) -> None:
  189. context = PatternContext(message, statement)
  190. should_alert_mods = False
  191. action_descriptions = []
  192. self.log(message.guild, f'Message from {message.author.name} matched pattern "{statement.name}"')
  193. for action in statement.actions:
  194. if action.action == 'ban':
  195. await message.author.ban(
  196. reason=f'Rocketbot: Message matched custom pattern named "{statement.name}"',
  197. delete_message_days=0)
  198. context.is_banned = True
  199. context.is_kicked = True
  200. action_descriptions.append('Author banned')
  201. self.log(message.guild, f'{message.author.name} banned')
  202. elif action.action == 'delete':
  203. await message.delete()
  204. context.is_deleted = True
  205. action_descriptions.append('Message deleted')
  206. self.log(message.guild, f'{message.author.name}\'s message deleted')
  207. elif action.action == 'kick':
  208. await message.author.kick(
  209. reason=f'Rocketbot: Message matched custom pattern named "{statement.name}"')
  210. context.is_kicked = True
  211. action_descriptions.append('Author kicked')
  212. self.log(message.guild, f'{message.author.name} kicked')
  213. elif action.action == 'modwarn':
  214. should_alert_mods = True
  215. action_descriptions.append('Mods alerted')
  216. elif action.action == 'reply':
  217. await message.reply(
  218. f'{action.arguments[0]}',
  219. mention_author=False)
  220. action_descriptions.append('Autoreplied')
  221. self.log(message.guild, f'{message.author.name} autoreplied to')
  222. bm = BotMessage(
  223. message.guild,
  224. f'User {message.author.name} tripped custom pattern ' + \
  225. f'`{statement.name}`.\n\nAutomatic actions taken:\n• ' + \
  226. ('\n• '.join(action_descriptions)),
  227. type=BotMessage.TYPE_MOD_WARNING if should_alert_mods else BotMessage.TYPE_INFO,
  228. context=context)
  229. bm.quote = message.content
  230. await bm.set_reactions(BotMessageReaction.standard_set(
  231. did_delete=context.is_deleted,
  232. did_kick=context.is_kicked,
  233. did_ban=context.is_banned))
  234. await self.post_message(bm)
  235. async def on_mod_react(self,
  236. bot_message: BotMessage,
  237. reaction: BotMessageReaction,
  238. reacted_by: Member) -> None:
  239. context: PatternContext = bot_message.context
  240. if reaction.emoji == CONFIG['trash_emoji']:
  241. await context.message.delete()
  242. context.is_deleted = True
  243. elif reaction.emoji == CONFIG['kick_emoji']:
  244. await context.message.author.kick(
  245. reason='Rocketbot: Message matched custom pattern named ' + \
  246. f'"{context.statement.name}". Kicked by {reacted_by.name}.')
  247. context.is_kicked = True
  248. elif reaction.emoji == CONFIG['ban_emoji']:
  249. await context.message.author.ban(
  250. reason='Rocketbot: Message matched custom pattern named ' + \
  251. f'"{context.statement.name}". Banned by {reacted_by.name}.',
  252. delete_message_days=1)
  253. context.is_banned = True
  254. await bot_message.set_reactions(BotMessageReaction.standard_set(
  255. did_delete=context.is_deleted,
  256. did_kick=context.is_kicked,
  257. did_ban=context.is_banned))
  258. @commands.group(
  259. brief='Manages message pattern matching',
  260. )
  261. @commands.has_permissions(ban_members=True)
  262. @commands.guild_only()
  263. async def pattern(self, context: commands.Context):
  264. 'Message pattern matching command group'
  265. if context.invoked_subcommand is None:
  266. await context.send_help()
  267. @pattern.command(
  268. brief='Adds a custom pattern',
  269. description='Adds a custom pattern. Patterns use a simplified ' + \
  270. 'expression language. Full documentation found here: ' + \
  271. 'https://git.rixafrix.com/ialbert/python-app-rocketbot/src/branch/master/patterns.md',
  272. usage='<pattern_name> <expression...>',
  273. ignore_extra=True
  274. )
  275. async def add(self, context: commands.Context, name: str):
  276. 'Command handler'
  277. pattern_str = PatternCompiler.expression_str_from_context(context, name)
  278. try:
  279. statement = PatternCompiler.parse_statement(name, pattern_str)
  280. patterns = self.__get_patterns(context.guild)
  281. patterns[name] = statement
  282. self.__save_patterns(context.guild, patterns)
  283. await context.message.reply(
  284. f'{CONFIG["success_emoji"]} Pattern `{name}` added.',
  285. mention_author=False)
  286. except PatternError as e:
  287. await context.message.reply(
  288. f'{CONFIG["failure_emoji"]} Error parsing statement. {e}',
  289. mention_author=False)
  290. @pattern.command(
  291. brief='Removes a custom pattern',
  292. usage='<pattern_name>'
  293. )
  294. async def remove(self, context: commands.Context, name: str):
  295. 'Command handler'
  296. patterns = self.__get_patterns(context.guild)
  297. if patterns.get(name) is not None:
  298. del patterns[name]
  299. self.__save_patterns(context.guild, patterns)
  300. await context.message.reply(
  301. f'{CONFIG["success_emoji"]} Pattern `{name}` deleted.',
  302. mention_author=False)
  303. else:
  304. await context.message.reply(
  305. f'{CONFIG["failure_emoji"]} No pattern named `{name}`.',
  306. mention_author=False)
  307. @pattern.command(
  308. brief='Lists all patterns'
  309. )
  310. async def list(self, context: commands.Context) -> None:
  311. 'Command handler'
  312. patterns = self.__get_patterns(context.guild)
  313. if len(patterns) == 0:
  314. await context.message.reply('No patterns defined.', mention_author=False)
  315. return
  316. msg = ''
  317. for name, statement in sorted(patterns.items()):
  318. msg += f'Pattern `{name}`:\n```\n{statement.original}\n```\n'
  319. await context.message.reply(msg, mention_author=False)
  320. class PatternError(RuntimeError):
  321. """
  322. Error thrown when parsing a pattern statement.
  323. """
  324. class PatternCompiler:
  325. """
  326. Parses a user-provided message filter statement into a PatternStatement.
  327. """
  328. TYPE_ID = 'id'
  329. TYPE_MEMBER = 'Member'
  330. TYPE_TEXT = 'text'
  331. TYPE_INT = 'int'
  332. TYPE_FLOAT = 'float'
  333. TYPE_TIMESPAN = 'timespan'
  334. FIELD_TO_TYPE = {
  335. 'content': TYPE_TEXT,
  336. 'author': TYPE_MEMBER,
  337. 'author.id': TYPE_ID,
  338. 'author.name': TYPE_TEXT,
  339. 'author.joinage': TYPE_TIMESPAN,
  340. }
  341. ACTION_TO_ARGS = {
  342. 'ban': [],
  343. 'delete': [],
  344. 'kick': [],
  345. 'modwarn': [],
  346. 'reply': [ TYPE_TEXT ],
  347. }
  348. OPERATORS_IDENTITY = set([ '==', '!=' ])
  349. OPERATORS_COMPARISON = set([ '<', '>', '<=', '>=' ])
  350. OPERATORS_NUMERIC = OPERATORS_IDENTITY | OPERATORS_COMPARISON
  351. OPERATORS_TEXT = OPERATORS_IDENTITY | set([ 'contains', '!contains', 'matches', '!matches' ])
  352. OPERATORS_ALL = OPERATORS_IDENTITY | OPERATORS_COMPARISON | OPERATORS_TEXT
  353. TYPE_TO_OPERATORS = {
  354. TYPE_ID: OPERATORS_IDENTITY,
  355. TYPE_MEMBER: OPERATORS_IDENTITY,
  356. TYPE_TEXT: OPERATORS_TEXT,
  357. TYPE_INT: OPERATORS_NUMERIC,
  358. TYPE_FLOAT: OPERATORS_NUMERIC,
  359. TYPE_TIMESPAN: OPERATORS_NUMERIC,
  360. }
  361. WHITESPACE_CHARS = ' \t\n\r'
  362. STRING_QUOTE_CHARS = '\'"'
  363. SYMBOL_CHARS = 'abcdefghijklmnopqrstuvwxyz.'
  364. VALUE_CHARS = '0123456789dhms<@!>'
  365. OP_CHARS = '<=>!(),'
  366. @classmethod
  367. def expression_str_from_context(cls, context: commands.Context, name: str) -> str:
  368. """
  369. Extracts the statement string from an "add" command context.
  370. """
  371. pattern_str = context.message.content
  372. command_chain = [ name ]
  373. cmd = context.command
  374. while cmd:
  375. command_chain.insert(0, cmd.name)
  376. cmd = cmd.parent
  377. command_chain[0] = f'{context.prefix}{command_chain[0]}'
  378. for cmd in command_chain:
  379. if pattern_str.startswith(cmd):
  380. pattern_str = pattern_str[len(cmd):].lstrip()
  381. elif pattern_str.startswith(f'"{cmd}"'):
  382. pattern_str = pattern_str[len(cmd) + 2:].lstrip()
  383. return pattern_str
  384. @classmethod
  385. def parse_statement(cls, name: str, statement: str) -> PatternStatement:
  386. """
  387. Parses a user-provided message filter statement into a PatternStatement.
  388. """
  389. tokens = cls.tokenize(statement)
  390. token_index = 0
  391. actions, token_index = cls.read_actions(tokens, token_index)
  392. expression, token_index = cls.read_expression(tokens, token_index)
  393. return PatternStatement(name, actions, expression, statement)
  394. @classmethod
  395. def tokenize(cls, statement: str) -> list:
  396. """
  397. Converts a message filter statement into a list of tokens.
  398. """
  399. tokens = []
  400. in_quote = False
  401. in_escape = False
  402. all_token_types = set([ 'sym', 'op', 'val' ])
  403. possible_token_types = set(all_token_types)
  404. current_token = ''
  405. for ch in statement:
  406. if in_quote:
  407. if in_escape:
  408. if ch == 'n':
  409. current_token += '\n'
  410. elif ch == 't':
  411. current_token += '\t'
  412. else:
  413. current_token += ch
  414. in_escape = False
  415. elif ch == '\\':
  416. in_escape = True
  417. elif ch == in_quote:
  418. current_token += ch
  419. tokens.append(current_token)
  420. current_token = ''
  421. possible_token_types |= all_token_types
  422. in_quote = False
  423. else:
  424. current_token += ch
  425. else:
  426. if ch in cls.STRING_QUOTE_CHARS:
  427. if len(current_token) > 0:
  428. tokens.append(current_token)
  429. current_token = ''
  430. possible_token_types |= all_token_types
  431. in_quote = ch
  432. current_token = ch
  433. elif ch == '\\':
  434. raise PatternError("Unexpected \\ outside quoted string")
  435. elif ch in cls.WHITESPACE_CHARS:
  436. if len(current_token) > 0:
  437. tokens.append(current_token)
  438. current_token = ''
  439. possible_token_types |= all_token_types
  440. else:
  441. possible_ch_types = set()
  442. if ch in cls.SYMBOL_CHARS:
  443. possible_ch_types.add('sym')
  444. if ch in cls.VALUE_CHARS:
  445. possible_ch_types.add('val')
  446. if ch in cls.OP_CHARS:
  447. possible_ch_types.add('op')
  448. if len(current_token) > 0 and possible_ch_types.isdisjoint(possible_token_types):
  449. if len(current_token) > 0:
  450. tokens.append(current_token)
  451. current_token = ''
  452. possible_token_types |= all_token_types
  453. possible_token_types &= possible_ch_types
  454. current_token += ch
  455. if len(current_token) > 0:
  456. tokens.append(current_token)
  457. # Some symbols might be glommed onto other tokens. Split 'em up.
  458. prefixes_to_split = [ '!', '(', ',' ]
  459. suffixes_to_split = [ ')', ',' ]
  460. i = 0
  461. while i < len(tokens):
  462. token = tokens[i]
  463. mutated = False
  464. for prefix in prefixes_to_split:
  465. if token.startswith(prefix) and len(token) > len(prefix):
  466. tokens.insert(i, prefix)
  467. tokens[i + 1] = token[len(prefix):]
  468. i += 1
  469. mutated = True
  470. break
  471. if mutated:
  472. continue
  473. for suffix in suffixes_to_split:
  474. if token.endswith(suffix) and len(token) > len(suffix):
  475. tokens[i] = token[0:-len(suffix)]
  476. tokens.insert(i + 1, suffix)
  477. mutated = True
  478. break
  479. if mutated:
  480. continue
  481. i += 1
  482. return tokens
  483. @classmethod
  484. def read_actions(cls, tokens: list, token_index: int) -> tuple:
  485. """
  486. Reads the actions from a list of statement tokens. Returns a tuple
  487. containing a list of PatternActions and the token index this method
  488. left off at (the token after the "if").
  489. """
  490. actions = []
  491. current_action_tokens = []
  492. while token_index < len(tokens):
  493. token = tokens[token_index]
  494. if token == 'if':
  495. if len(current_action_tokens) > 0:
  496. a = PatternAction(current_action_tokens[0], current_action_tokens[1:])
  497. cls.__validate_action(a)
  498. actions.append(a)
  499. token_index += 1
  500. return (actions, token_index)
  501. elif token == ',':
  502. if len(current_action_tokens) < 1:
  503. raise PatternError('Unexpected ,')
  504. a = PatternAction(current_action_tokens[0], current_action_tokens[1:])
  505. cls.__validate_action(a)
  506. actions.append(a)
  507. current_action_tokens = []
  508. else:
  509. current_action_tokens.append(token)
  510. token_index += 1
  511. raise PatternError('Unexpected end of line in action list')
  512. @classmethod
  513. def __validate_action(cls, action: PatternAction) -> None:
  514. args = cls.ACTION_TO_ARGS.get(action.action)
  515. if args is None:
  516. raise PatternError(f'Unknown action "{action.action}"')
  517. if len(action.arguments) != len(args):
  518. if len(args) == 0:
  519. raise PatternError(f'Action "{action.action}" expects no arguments, ' + \
  520. f'got {len(action.arguments)}.')
  521. else:
  522. raise PatternError(f'Action "{action.action}" expects {len(args)} ' + \
  523. f'arguments, got {len(action.arguments)}.')
  524. for i, datatype in enumerate(args):
  525. action.arguments[i] = cls.parse_value(action.arguments[i], datatype)
  526. @classmethod
  527. def read_expression(cls,
  528. tokens: list,
  529. token_index: int,
  530. depth: int = 0,
  531. one_subexpression: bool = False) -> tuple:
  532. """
  533. Reads an expression from a list of statement tokens. Returns a tuple
  534. containing the PatternExpression and the token index it left off at.
  535. If one_subexpression is True then it will return after reading a
  536. single expression instead of joining multiples (for readong the
  537. subject of a NOT expression).
  538. """
  539. subexpressions = []
  540. last_compound_operator = None
  541. while token_index < len(tokens):
  542. if one_subexpression:
  543. if len(subexpressions) == 1:
  544. return (subexpressions[0], token_index)
  545. if len(subexpressions) > 1:
  546. raise PatternError('Too many subexpressions')
  547. compound_operator = None
  548. if tokens[token_index] == ')':
  549. if len(subexpressions) == 0:
  550. raise PatternError('No subexpressions')
  551. if len(subexpressions) == 1:
  552. return (subexpressions[0], token_index)
  553. return (PatternCompoundExpression(last_compound_operator, subexpressions), token_index)
  554. if tokens[token_index] in set(["and", "or"]):
  555. compound_operator = tokens[token_index]
  556. if last_compound_operator and compound_operator != last_compound_operator:
  557. subexpressions = [ PatternCompoundExpression(last_compound_operator, subexpressions) ]
  558. last_compound_operator = compound_operator
  559. else:
  560. last_compound_operator = compound_operator
  561. token_index += 1
  562. if tokens[token_index] == '!':
  563. (exp, next_index) = cls.read_expression(tokens, token_index + 1, \
  564. depth + 1, one_subexpression=True)
  565. subexpressions.append(PatternCompoundExpression('!', [exp]))
  566. token_index = next_index
  567. elif tokens[token_index] == '(':
  568. (exp, next_index) = cls.read_expression(tokens, token_index + 1, depth + 1)
  569. if tokens[next_index] != ')':
  570. raise PatternError('Expected )')
  571. subexpressions.append(exp)
  572. token_index = next_index + 1
  573. else:
  574. (simple, next_index) = cls.read_simple_expression(tokens, token_index, depth)
  575. subexpressions.append(simple)
  576. token_index = next_index
  577. if len(subexpressions) == 0:
  578. raise PatternError('No subexpressions')
  579. elif len(subexpressions) == 1:
  580. return (subexpressions[0], token_index)
  581. else:
  582. return (PatternCompoundExpression(last_compound_operator, subexpressions), token_index)
  583. @classmethod
  584. def read_simple_expression(cls, tokens: list, token_index: int, depth: int = 0) -> tuple:
  585. """
  586. Reads a simple expression consisting of a field name, operator, and
  587. comparison value. Returns a tuple of the PatternSimpleExpression and
  588. the token index it left off at.
  589. """
  590. if depth > 8:
  591. raise PatternError('Expression nests too deeply')
  592. if token_index >= len(tokens):
  593. raise PatternError('Expected field name, found EOL')
  594. field = tokens[token_index]
  595. token_index += 1
  596. datatype = cls.FIELD_TO_TYPE.get(field)
  597. if datatype is None:
  598. raise PatternError(f'No such field "{field}"')
  599. if token_index >= len(tokens):
  600. raise PatternError('Expected operator, found EOL')
  601. op = tokens[token_index]
  602. token_index += 1
  603. if op == '!':
  604. if token_index >= len(tokens):
  605. raise PatternError('Expected operator, found EOL')
  606. op = '!' + tokens[token_index]
  607. token_index += 1
  608. allowed_ops = cls.TYPE_TO_OPERATORS[datatype]
  609. if op not in allowed_ops:
  610. if op in cls.OPERATORS_ALL:
  611. raise PatternError(f'Operator {op} cannot be used with field "{field}"')
  612. raise PatternError(f'Unrecognized operator "{op}" - allowed: {list(allowed_ops)}')
  613. if token_index >= len(tokens):
  614. raise PatternError('Expected value, found EOL')
  615. value = tokens[token_index]
  616. try:
  617. value = cls.parse_value(value, datatype)
  618. except ValueError as cause:
  619. raise PatternError(f'Bad value {value}') from cause
  620. token_index += 1
  621. exp = PatternSimpleExpression(field, op, value)
  622. return (exp, token_index)
  623. @classmethod
  624. def parse_value(cls, value: str, datatype: str):
  625. """
  626. Converts a value token to its Python value. Raises ValueError on failure.
  627. """
  628. if datatype == cls.TYPE_ID:
  629. p = re.compile('^[0-9]+$')
  630. if p.match(value) is None:
  631. raise ValueError(f'Illegal id value "{value}"')
  632. # Store it as a str so it can be larger than an int
  633. return value
  634. if datatype == cls.TYPE_MEMBER:
  635. p = re.compile('^<@!?([0-9]+)>$')
  636. m = p.match(value)
  637. if m is None:
  638. raise ValueError('Illegal member value. Must be an @ mention.')
  639. return m.group(1)
  640. if datatype == cls.TYPE_TEXT:
  641. # Must be quoted.
  642. if len(value) < 2 or \
  643. value[0:1] not in cls.STRING_QUOTE_CHARS or \
  644. value[-1:] not in cls.STRING_QUOTE_CHARS or \
  645. value[0:1] != value[-1:]:
  646. raise ValueError(f'Not a quoted string value: {value}')
  647. return value[1:-1]
  648. if datatype == cls.TYPE_INT:
  649. return int(value)
  650. if datatype == cls.TYPE_FLOAT:
  651. return float(value)
  652. if datatype == cls.TYPE_TIMESPAN:
  653. return parse_timedelta(value)
  654. raise ValueError(f'Unhandled datatype {datatype}')