|
|
@@ -2,166 +2,15 @@
|
|
2
|
2
|
Cog for matching messages against guild-configurable criteria and taking
|
|
3
|
3
|
automated actions on them.
|
|
4
|
4
|
"""
|
|
5
|
|
-import re
|
|
6
|
|
-from abc import ABCMeta, abstractmethod
|
|
7
|
5
|
from discord import Guild, Member, Message, utils as discordutils
|
|
8
|
6
|
from discord.ext import commands
|
|
9
|
7
|
|
|
10
|
8
|
from config import CONFIG
|
|
11
|
9
|
from rocketbot.cogs.basecog import BaseCog, BotMessage, BotMessageReaction
|
|
12
|
10
|
from rocketbot.cogsetting import CogSetting
|
|
|
11
|
+from rocketbot.pattern import PatternCompiler, PatternDeprecationError, \
|
|
|
12
|
+ PatternError, PatternStatement
|
|
13
|
13
|
from rocketbot.storage import Storage
|
|
14
|
|
-from rocketbot.utils import is_user_id, str_from_quoted_str, timedelta_from_str, \
|
|
15
|
|
- user_id_from_mention
|
|
16
|
|
-
|
|
17
|
|
-class PatternAction:
|
|
18
|
|
- """
|
|
19
|
|
- Describes one action to take on a matched message or its author.
|
|
20
|
|
- """
|
|
21
|
|
- def __init__(self, action: str, args: list):
|
|
22
|
|
- self.action = action
|
|
23
|
|
- self.arguments = list(args)
|
|
24
|
|
-
|
|
25
|
|
- def __str__(self) -> str:
|
|
26
|
|
- arg_str = ', '.join(self.arguments)
|
|
27
|
|
- return f'{self.action}({arg_str})'
|
|
28
|
|
-
|
|
29
|
|
-class PatternExpression(metaclass=ABCMeta):
|
|
30
|
|
- """
|
|
31
|
|
- Abstract message matching expression.
|
|
32
|
|
- """
|
|
33
|
|
- def __init__(self):
|
|
34
|
|
- pass
|
|
35
|
|
-
|
|
36
|
|
- @abstractmethod
|
|
37
|
|
- def matches(self, message: Message) -> bool:
|
|
38
|
|
- """
|
|
39
|
|
- Whether a message matches this expression.
|
|
40
|
|
- """
|
|
41
|
|
- return False
|
|
42
|
|
-
|
|
43
|
|
-class PatternSimpleExpression(PatternExpression):
|
|
44
|
|
- """
|
|
45
|
|
- Message matching expression with a simple "<field> <operator> <value>"
|
|
46
|
|
- structure.
|
|
47
|
|
- """
|
|
48
|
|
- def __init__(self, field: str, operator: str, value):
|
|
49
|
|
- super().__init__()
|
|
50
|
|
- self.field = field
|
|
51
|
|
- self.operator = operator
|
|
52
|
|
- self.value = value
|
|
53
|
|
-
|
|
54
|
|
- def __field_value(self, message: Message):
|
|
55
|
|
- if self.field in ('content.markdown', 'content'):
|
|
56
|
|
- return message.content
|
|
57
|
|
- if self.field == 'content.plain':
|
|
58
|
|
- return discordutils.remove_markdown(message.clean_content)
|
|
59
|
|
- if self.field == 'author':
|
|
60
|
|
- return str(message.author.id)
|
|
61
|
|
- if self.field == 'author.id':
|
|
62
|
|
- return str(message.author.id)
|
|
63
|
|
- if self.field == 'author.joinage':
|
|
64
|
|
- return message.created_at - message.author.joined_at
|
|
65
|
|
- if self.field == 'author.name':
|
|
66
|
|
- return message.author.name
|
|
67
|
|
- else:
|
|
68
|
|
- raise ValueError(f'Bad field name {self.field}')
|
|
69
|
|
-
|
|
70
|
|
- def matches(self, message: Message) -> bool:
|
|
71
|
|
- field_value = self.__field_value(message)
|
|
72
|
|
- if self.operator == '==':
|
|
73
|
|
- if isinstance(field_value, str) and isinstance(self.value, str):
|
|
74
|
|
- return field_value.lower() == self.value.lower()
|
|
75
|
|
- return field_value == self.value
|
|
76
|
|
- if self.operator == '!=':
|
|
77
|
|
- if isinstance(field_value, str) and isinstance(self.value, str):
|
|
78
|
|
- return field_value.lower() != self.value.lower()
|
|
79
|
|
- return field_value != self.value
|
|
80
|
|
- if self.operator == '<':
|
|
81
|
|
- return field_value < self.value
|
|
82
|
|
- if self.operator == '>':
|
|
83
|
|
- return field_value > self.value
|
|
84
|
|
- if self.operator == '<=':
|
|
85
|
|
- return field_value <= self.value
|
|
86
|
|
- if self.operator == '>=':
|
|
87
|
|
- return field_value >= self.value
|
|
88
|
|
- if self.operator == 'contains':
|
|
89
|
|
- return self.value.lower() in field_value.lower()
|
|
90
|
|
- if self.operator == '!contains':
|
|
91
|
|
- return self.value.lower() not in field_value.lower()
|
|
92
|
|
- if self.operator == 'matches':
|
|
93
|
|
- p = re.compile(self.value.lower())
|
|
94
|
|
- return p.match(field_value.lower()) is not None
|
|
95
|
|
- if self.operator == '!matches':
|
|
96
|
|
- p = re.compile(self.value.lower())
|
|
97
|
|
- return p.match(field_value.lower()) is None
|
|
98
|
|
- raise ValueError(f'Bad operator {self.operator}')
|
|
99
|
|
-
|
|
100
|
|
- def __str__(self) -> str:
|
|
101
|
|
- return f'({self.field} {self.operator} {self.value})'
|
|
102
|
|
-
|
|
103
|
|
-class PatternCompoundExpression(PatternExpression):
|
|
104
|
|
- """
|
|
105
|
|
- Message matching expression that combines several child expressions with
|
|
106
|
|
- a boolean operator.
|
|
107
|
|
- """
|
|
108
|
|
- def __init__(self, operator: str, operands: list):
|
|
109
|
|
- super().__init__()
|
|
110
|
|
- self.operator = operator
|
|
111
|
|
- self.operands = list(operands)
|
|
112
|
|
-
|
|
113
|
|
- def matches(self, message: Message) -> bool:
|
|
114
|
|
- if self.operator == '!':
|
|
115
|
|
- return not self.operands[0].matches(message)
|
|
116
|
|
- if self.operator == 'and':
|
|
117
|
|
- for op in self.operands:
|
|
118
|
|
- if not op.matches(message):
|
|
119
|
|
- return False
|
|
120
|
|
- return True
|
|
121
|
|
- if self.operator == 'or':
|
|
122
|
|
- for op in self.operands:
|
|
123
|
|
- if op.matches(message):
|
|
124
|
|
- return True
|
|
125
|
|
- return False
|
|
126
|
|
- raise ValueError(f'Bad operator "{self.operator}"')
|
|
127
|
|
-
|
|
128
|
|
- def __str__(self) -> str:
|
|
129
|
|
- if self.operator == '!':
|
|
130
|
|
- return f'(!( {self.operands[0]} ))'
|
|
131
|
|
- strs = map(str, self.operands)
|
|
132
|
|
- joined = f' {self.operator} '.join(strs)
|
|
133
|
|
- return f'( {joined} )'
|
|
134
|
|
-
|
|
135
|
|
-class PatternStatement:
|
|
136
|
|
- """
|
|
137
|
|
- A full message match statement. If a message matches the given expression,
|
|
138
|
|
- the given actions should be performed.
|
|
139
|
|
- """
|
|
140
|
|
- def __init__(self,
|
|
141
|
|
- name: str,
|
|
142
|
|
- actions: list,
|
|
143
|
|
- expression: PatternExpression,
|
|
144
|
|
- original: str):
|
|
145
|
|
- self.name = name
|
|
146
|
|
- self.actions = list(actions) # PatternAction[]
|
|
147
|
|
- self.expression = expression
|
|
148
|
|
- self.original = original
|
|
149
|
|
-
|
|
150
|
|
- def to_json(self) -> dict:
|
|
151
|
|
- """
|
|
152
|
|
- Returns a JSON representation of this statement.
|
|
153
|
|
- """
|
|
154
|
|
- return {
|
|
155
|
|
- 'name': self.name,
|
|
156
|
|
- 'statement': self.original,
|
|
157
|
|
- }
|
|
158
|
|
-
|
|
159
|
|
- @classmethod
|
|
160
|
|
- def from_json(cls, json: dict):
|
|
161
|
|
- """
|
|
162
|
|
- Gets a PatternStatement from its JSON representation.
|
|
163
|
|
- """
|
|
164
|
|
- return PatternCompiler.parse_statement(json['name'], json['statement'])
|
|
165
|
14
|
|
|
166
|
15
|
class PatternContext:
|
|
167
|
16
|
"""
|
|
|
@@ -194,7 +43,12 @@ class PatternCog(BaseCog, name='Pattern Matching'):
|
|
194
|
43
|
pattern_list: list[PatternStatement] = []
|
|
195
|
44
|
for json in jsons:
|
|
196
|
45
|
try:
|
|
197
|
|
- pattern_list.append(PatternStatement.from_json(json))
|
|
|
46
|
+ ps = PatternStatement.from_json(json)
|
|
|
47
|
+ pattern_list.append(ps)
|
|
|
48
|
+ try:
|
|
|
49
|
+ ps.check_deprecations()
|
|
|
50
|
+ except PatternDeprecationError as e:
|
|
|
51
|
+ self.log(guild, f'Pattern {ps.name}: {e}')
|
|
198
|
52
|
except PatternError as e:
|
|
199
|
53
|
self.log(guild, f'Error decoding pattern "{json["name"]}": {e}')
|
|
200
|
54
|
patterns = { p.name:p for p in pattern_list}
|
|
|
@@ -202,7 +56,9 @@ class PatternCog(BaseCog, name='Pattern Matching'):
|
|
202
|
56
|
return patterns
|
|
203
|
57
|
|
|
204
|
58
|
@classmethod
|
|
205
|
|
- def __save_patterns(cls, guild: Guild, patterns: dict[str, PatternStatement]) -> None:
|
|
|
59
|
+ def __save_patterns(cls,
|
|
|
60
|
+ guild: Guild,
|
|
|
61
|
+ patterns: dict[str, PatternStatement]) -> None:
|
|
206
|
62
|
to_save: list[dict] = list(map(PatternStatement.to_json, patterns.values()))
|
|
207
|
63
|
cls.set_guild_setting(guild, cls.SETTING_PATTERNS, to_save)
|
|
208
|
64
|
|
|
|
@@ -226,15 +82,20 @@ class PatternCog(BaseCog, name='Pattern Matching'):
|
|
226
|
82
|
await self.__trigger_actions(message, statement)
|
|
227
|
83
|
break
|
|
228
|
84
|
|
|
229
|
|
- async def __trigger_actions(self, message: Message, statement: PatternStatement) -> None:
|
|
|
85
|
+ async def __trigger_actions(self,
|
|
|
86
|
+ message: Message,
|
|
|
87
|
+ statement: PatternStatement) -> None:
|
|
230
|
88
|
context = PatternContext(message, statement)
|
|
231
|
|
- should_alert_mods = False
|
|
|
89
|
+ should_post_message = False
|
|
|
90
|
+ message_type: int = BotMessage.TYPE_DEFAULT
|
|
232
|
91
|
action_descriptions = []
|
|
233
|
|
- self.log(message.guild, f'Message from {message.author.name} matched pattern "{statement.name}"')
|
|
|
92
|
+ self.log(message.guild, f'Message from {message.author.name} matched ' + \
|
|
|
93
|
+ f'pattern "{statement.name}"')
|
|
234
|
94
|
for action in statement.actions:
|
|
235
|
95
|
if action.action == 'ban':
|
|
236
|
96
|
await message.author.ban(
|
|
237
|
|
- reason=f'Rocketbot: Message matched custom pattern named "{statement.name}"',
|
|
|
97
|
+ reason='Rocketbot: Message matched custom pattern named ' + \
|
|
|
98
|
+ f'"{statement.name}"',
|
|
238
|
99
|
delete_message_days=0)
|
|
239
|
100
|
context.is_banned = True
|
|
240
|
101
|
context.is_kicked = True
|
|
|
@@ -247,12 +108,18 @@ class PatternCog(BaseCog, name='Pattern Matching'):
|
|
247
|
108
|
self.log(message.guild, f'{message.author.name}\'s message deleted')
|
|
248
|
109
|
elif action.action == 'kick':
|
|
249
|
110
|
await message.author.kick(
|
|
250
|
|
- reason=f'Rocketbot: Message matched custom pattern named "{statement.name}"')
|
|
|
111
|
+ reason='Rocketbot: Message matched custom pattern named ' + \
|
|
|
112
|
+ f'"{statement.name}"')
|
|
251
|
113
|
context.is_kicked = True
|
|
252
|
114
|
action_descriptions.append('Author kicked')
|
|
253
|
115
|
self.log(message.guild, f'{message.author.name} kicked')
|
|
|
116
|
+ elif action.action == 'modinfo':
|
|
|
117
|
+ should_post_message = True
|
|
|
118
|
+ message_type = BotMessage.TYPE_INFO
|
|
|
119
|
+ action_descriptions.append('Message logged')
|
|
254
|
120
|
elif action.action == 'modwarn':
|
|
255
|
|
- should_alert_mods = True
|
|
|
121
|
+ should_post_message = True
|
|
|
122
|
+ message_type = BotMessage.TYPE_MOD_WARNING
|
|
256
|
123
|
action_descriptions.append('Mods alerted')
|
|
257
|
124
|
elif action.action == 'reply':
|
|
258
|
125
|
await message.reply(
|
|
|
@@ -260,19 +127,20 @@ class PatternCog(BaseCog, name='Pattern Matching'):
|
|
260
|
127
|
mention_author=False)
|
|
261
|
128
|
action_descriptions.append('Autoreplied')
|
|
262
|
129
|
self.log(message.guild, f'autoreplied to {message.author.name}')
|
|
263
|
|
- bm = BotMessage(
|
|
264
|
|
- message.guild,
|
|
265
|
|
- f'User {message.author.name} tripped custom pattern ' + \
|
|
266
|
|
- f'`{statement.name}`.\n\nAutomatic actions taken:\n• ' + \
|
|
267
|
|
- ('\n• '.join(action_descriptions)),
|
|
268
|
|
- type=BotMessage.TYPE_MOD_WARNING if should_alert_mods else BotMessage.TYPE_INFO,
|
|
269
|
|
- context=context)
|
|
270
|
|
- bm.quote = discordutils.remove_markdown(message.clean_content)
|
|
271
|
|
- await bm.set_reactions(BotMessageReaction.standard_set(
|
|
272
|
|
- did_delete=context.is_deleted,
|
|
273
|
|
- did_kick=context.is_kicked,
|
|
274
|
|
- did_ban=context.is_banned))
|
|
275
|
|
- await self.post_message(bm)
|
|
|
130
|
+ if should_post_message:
|
|
|
131
|
+ bm = BotMessage(
|
|
|
132
|
+ message.guild,
|
|
|
133
|
+ f'User {message.author.name} tripped custom pattern ' + \
|
|
|
134
|
+ f'`{statement.name}`.\n\nAutomatic actions taken:\n• ' + \
|
|
|
135
|
+ ('\n• '.join(action_descriptions)),
|
|
|
136
|
+ type=message_type,
|
|
|
137
|
+ context=context)
|
|
|
138
|
+ bm.quote = discordutils.remove_markdown(message.clean_content)
|
|
|
139
|
+ await bm.set_reactions(BotMessageReaction.standard_set(
|
|
|
140
|
+ did_delete=context.is_deleted,
|
|
|
141
|
+ did_kick=context.is_kicked,
|
|
|
142
|
+ did_ban=context.is_banned))
|
|
|
143
|
+ await self.post_message(bm)
|
|
276
|
144
|
|
|
277
|
145
|
async def on_mod_react(self,
|
|
278
|
146
|
bot_message: BotMessage,
|
|
|
@@ -312,7 +180,8 @@ class PatternCog(BaseCog, name='Pattern Matching'):
|
|
312
|
180
|
brief='Adds a custom pattern',
|
|
313
|
181
|
description='Adds a custom pattern. Patterns use a simplified ' + \
|
|
314
|
182
|
'expression language. Full documentation found here: ' + \
|
|
315
|
|
- 'https://git.rixafrix.com/ialbert/python-app-rocketbot/src/branch/master/patterns.md',
|
|
|
183
|
+ 'https://git.rixafrix.com/ialbert/python-app-rocketbot/src/' + \
|
|
|
184
|
+ 'branch/master/patterns.md',
|
|
316
|
185
|
usage='<pattern_name> <expression...>',
|
|
317
|
186
|
ignore_extra=True
|
|
318
|
187
|
)
|
|
|
@@ -321,6 +190,7 @@ class PatternCog(BaseCog, name='Pattern Matching'):
|
|
321
|
190
|
pattern_str = PatternCompiler.expression_str_from_context(context, name)
|
|
322
|
191
|
try:
|
|
323
|
192
|
statement = PatternCompiler.parse_statement(name, pattern_str)
|
|
|
193
|
+ statement.check_deprecations()
|
|
324
|
194
|
patterns = self.__get_patterns(context.guild)
|
|
325
|
195
|
patterns[name] = statement
|
|
326
|
196
|
self.__save_patterns(context.guild, patterns)
|
|
|
@@ -363,352 +233,3 @@ class PatternCog(BaseCog, name='Pattern Matching'):
|
|
363
|
233
|
for name, statement in sorted(patterns.items()):
|
|
364
|
234
|
msg += f'Pattern `{name}`:\n```\n{statement.original}\n```\n'
|
|
365
|
235
|
await context.message.reply(msg, mention_author=False)
|
|
366
|
|
-
|
|
367
|
|
-class PatternError(RuntimeError):
|
|
368
|
|
- """
|
|
369
|
|
- Error thrown when parsing a pattern statement.
|
|
370
|
|
- """
|
|
371
|
|
-
|
|
372
|
|
-class PatternCompiler:
|
|
373
|
|
- """
|
|
374
|
|
- Parses a user-provided message filter statement into a PatternStatement.
|
|
375
|
|
- """
|
|
376
|
|
- TYPE_ID = 'id'
|
|
377
|
|
- TYPE_MEMBER = 'Member'
|
|
378
|
|
- TYPE_TEXT = 'text'
|
|
379
|
|
- TYPE_INT = 'int'
|
|
380
|
|
- TYPE_FLOAT = 'float'
|
|
381
|
|
- TYPE_TIMESPAN = 'timespan'
|
|
382
|
|
-
|
|
383
|
|
- FIELD_TO_TYPE = {
|
|
384
|
|
- 'content.plain': TYPE_TEXT,
|
|
385
|
|
- 'content.markdown': TYPE_TEXT,
|
|
386
|
|
- 'author': TYPE_MEMBER,
|
|
387
|
|
- 'author.id': TYPE_ID,
|
|
388
|
|
- 'author.name': TYPE_TEXT,
|
|
389
|
|
- 'author.joinage': TYPE_TIMESPAN,
|
|
390
|
|
-
|
|
391
|
|
- 'content': TYPE_TEXT, # deprecated, use content.markdown or content.plain
|
|
392
|
|
- }
|
|
393
|
|
-
|
|
394
|
|
- ACTION_TO_ARGS = {
|
|
395
|
|
- 'ban': [],
|
|
396
|
|
- 'delete': [],
|
|
397
|
|
- 'kick': [],
|
|
398
|
|
- 'modwarn': [],
|
|
399
|
|
- 'reply': [ TYPE_TEXT ],
|
|
400
|
|
- }
|
|
401
|
|
-
|
|
402
|
|
- OPERATORS_IDENTITY = set([ '==', '!=' ])
|
|
403
|
|
- OPERATORS_COMPARISON = set([ '<', '>', '<=', '>=' ])
|
|
404
|
|
- OPERATORS_NUMERIC = OPERATORS_IDENTITY | OPERATORS_COMPARISON
|
|
405
|
|
- OPERATORS_TEXT = OPERATORS_IDENTITY | set([ 'contains', '!contains', 'matches', '!matches' ])
|
|
406
|
|
- OPERATORS_ALL = OPERATORS_IDENTITY | OPERATORS_COMPARISON | OPERATORS_TEXT
|
|
407
|
|
-
|
|
408
|
|
- TYPE_TO_OPERATORS = {
|
|
409
|
|
- TYPE_ID: OPERATORS_IDENTITY,
|
|
410
|
|
- TYPE_MEMBER: OPERATORS_IDENTITY,
|
|
411
|
|
- TYPE_TEXT: OPERATORS_TEXT,
|
|
412
|
|
- TYPE_INT: OPERATORS_NUMERIC,
|
|
413
|
|
- TYPE_FLOAT: OPERATORS_NUMERIC,
|
|
414
|
|
- TYPE_TIMESPAN: OPERATORS_NUMERIC,
|
|
415
|
|
- }
|
|
416
|
|
-
|
|
417
|
|
- WHITESPACE_CHARS = ' \t\n\r'
|
|
418
|
|
- STRING_QUOTE_CHARS = '\'"'
|
|
419
|
|
- SYMBOL_CHARS = 'abcdefghijklmnopqrstuvwxyz.'
|
|
420
|
|
- VALUE_CHARS = '0123456789dhms<@!>'
|
|
421
|
|
- OP_CHARS = '<=>!(),'
|
|
422
|
|
-
|
|
423
|
|
- @classmethod
|
|
424
|
|
- def expression_str_from_context(cls, context: commands.Context, name: str) -> str:
|
|
425
|
|
- """
|
|
426
|
|
- Extracts the statement string from an "add" command context.
|
|
427
|
|
- """
|
|
428
|
|
- pattern_str = context.message.content
|
|
429
|
|
- command_chain = [ name ]
|
|
430
|
|
- cmd = context.command
|
|
431
|
|
- while cmd:
|
|
432
|
|
- command_chain.insert(0, cmd.name)
|
|
433
|
|
- cmd = cmd.parent
|
|
434
|
|
- command_chain[0] = f'{context.prefix}{command_chain[0]}'
|
|
435
|
|
- for cmd in command_chain:
|
|
436
|
|
- if pattern_str.startswith(cmd):
|
|
437
|
|
- pattern_str = pattern_str[len(cmd):].lstrip()
|
|
438
|
|
- elif pattern_str.startswith(f'"{cmd}"'):
|
|
439
|
|
- pattern_str = pattern_str[len(cmd) + 2:].lstrip()
|
|
440
|
|
- return pattern_str
|
|
441
|
|
-
|
|
442
|
|
- @classmethod
|
|
443
|
|
- def parse_statement(cls, name: str, statement: str) -> PatternStatement:
|
|
444
|
|
- """
|
|
445
|
|
- Parses a user-provided message filter statement into a PatternStatement.
|
|
446
|
|
- """
|
|
447
|
|
- tokens = cls.tokenize(statement)
|
|
448
|
|
- token_index = 0
|
|
449
|
|
- actions, token_index = cls.read_actions(tokens, token_index)
|
|
450
|
|
- expression, token_index = cls.read_expression(tokens, token_index)
|
|
451
|
|
- return PatternStatement(name, actions, expression, statement)
|
|
452
|
|
-
|
|
453
|
|
- @classmethod
|
|
454
|
|
- def tokenize(cls, statement: str) -> list:
|
|
455
|
|
- """
|
|
456
|
|
- Converts a message filter statement into a list of tokens.
|
|
457
|
|
- """
|
|
458
|
|
- tokens = []
|
|
459
|
|
- in_quote = False
|
|
460
|
|
- in_escape = False
|
|
461
|
|
- all_token_types = set([ 'sym', 'op', 'val' ])
|
|
462
|
|
- possible_token_types = set(all_token_types)
|
|
463
|
|
- current_token = ''
|
|
464
|
|
- for ch in statement:
|
|
465
|
|
- if in_quote:
|
|
466
|
|
- if in_escape:
|
|
467
|
|
- if ch == 'n':
|
|
468
|
|
- current_token += '\n'
|
|
469
|
|
- elif ch == 't':
|
|
470
|
|
- current_token += '\t'
|
|
471
|
|
- else:
|
|
472
|
|
- current_token += ch
|
|
473
|
|
- in_escape = False
|
|
474
|
|
- elif ch == '\\':
|
|
475
|
|
- in_escape = True
|
|
476
|
|
- elif ch == in_quote:
|
|
477
|
|
- current_token += ch
|
|
478
|
|
- tokens.append(current_token)
|
|
479
|
|
- current_token = ''
|
|
480
|
|
- possible_token_types |= all_token_types
|
|
481
|
|
- in_quote = False
|
|
482
|
|
- else:
|
|
483
|
|
- current_token += ch
|
|
484
|
|
- else:
|
|
485
|
|
- if ch in cls.STRING_QUOTE_CHARS:
|
|
486
|
|
- if len(current_token) > 0:
|
|
487
|
|
- tokens.append(current_token)
|
|
488
|
|
- current_token = ''
|
|
489
|
|
- possible_token_types |= all_token_types
|
|
490
|
|
- in_quote = ch
|
|
491
|
|
- current_token = ch
|
|
492
|
|
- elif ch == '\\':
|
|
493
|
|
- raise PatternError("Unexpected \\ outside quoted string")
|
|
494
|
|
- elif ch in cls.WHITESPACE_CHARS:
|
|
495
|
|
- if len(current_token) > 0:
|
|
496
|
|
- tokens.append(current_token)
|
|
497
|
|
- current_token = ''
|
|
498
|
|
- possible_token_types |= all_token_types
|
|
499
|
|
- else:
|
|
500
|
|
- possible_ch_types = set()
|
|
501
|
|
- if ch in cls.SYMBOL_CHARS:
|
|
502
|
|
- possible_ch_types.add('sym')
|
|
503
|
|
- if ch in cls.VALUE_CHARS:
|
|
504
|
|
- possible_ch_types.add('val')
|
|
505
|
|
- if ch in cls.OP_CHARS:
|
|
506
|
|
- possible_ch_types.add('op')
|
|
507
|
|
- if len(current_token) > 0 and possible_ch_types.isdisjoint(possible_token_types):
|
|
508
|
|
- if len(current_token) > 0:
|
|
509
|
|
- tokens.append(current_token)
|
|
510
|
|
- current_token = ''
|
|
511
|
|
- possible_token_types |= all_token_types
|
|
512
|
|
- possible_token_types &= possible_ch_types
|
|
513
|
|
- current_token += ch
|
|
514
|
|
- if len(current_token) > 0:
|
|
515
|
|
- tokens.append(current_token)
|
|
516
|
|
-
|
|
517
|
|
- # Some symbols might be glommed onto other tokens. Split 'em up.
|
|
518
|
|
- prefixes_to_split = [ '!', '(', ',' ]
|
|
519
|
|
- suffixes_to_split = [ ')', ',' ]
|
|
520
|
|
- i = 0
|
|
521
|
|
- while i < len(tokens):
|
|
522
|
|
- token = tokens[i]
|
|
523
|
|
- mutated = False
|
|
524
|
|
- for prefix in prefixes_to_split:
|
|
525
|
|
- if token.startswith(prefix) and len(token) > len(prefix):
|
|
526
|
|
- tokens.insert(i, prefix)
|
|
527
|
|
- tokens[i + 1] = token[len(prefix):]
|
|
528
|
|
- i += 1
|
|
529
|
|
- mutated = True
|
|
530
|
|
- break
|
|
531
|
|
- if mutated:
|
|
532
|
|
- continue
|
|
533
|
|
- for suffix in suffixes_to_split:
|
|
534
|
|
- if token.endswith(suffix) and len(token) > len(suffix):
|
|
535
|
|
- tokens[i] = token[0:-len(suffix)]
|
|
536
|
|
- tokens.insert(i + 1, suffix)
|
|
537
|
|
- mutated = True
|
|
538
|
|
- break
|
|
539
|
|
- if mutated:
|
|
540
|
|
- continue
|
|
541
|
|
- i += 1
|
|
542
|
|
- return tokens
|
|
543
|
|
-
|
|
544
|
|
- @classmethod
|
|
545
|
|
- def read_actions(cls, tokens: list, token_index: int) -> tuple:
|
|
546
|
|
- """
|
|
547
|
|
- Reads the actions from a list of statement tokens. Returns a tuple
|
|
548
|
|
- containing a list of PatternActions and the token index this method
|
|
549
|
|
- left off at (the token after the "if").
|
|
550
|
|
- """
|
|
551
|
|
- actions = []
|
|
552
|
|
- current_action_tokens = []
|
|
553
|
|
- while token_index < len(tokens):
|
|
554
|
|
- token = tokens[token_index]
|
|
555
|
|
- if token == 'if':
|
|
556
|
|
- if len(current_action_tokens) > 0:
|
|
557
|
|
- a = PatternAction(current_action_tokens[0], current_action_tokens[1:])
|
|
558
|
|
- cls.__validate_action(a)
|
|
559
|
|
- actions.append(a)
|
|
560
|
|
- token_index += 1
|
|
561
|
|
- return (actions, token_index)
|
|
562
|
|
- elif token == ',':
|
|
563
|
|
- if len(current_action_tokens) < 1:
|
|
564
|
|
- raise PatternError('Unexpected ,')
|
|
565
|
|
- a = PatternAction(current_action_tokens[0], current_action_tokens[1:])
|
|
566
|
|
- cls.__validate_action(a)
|
|
567
|
|
- actions.append(a)
|
|
568
|
|
- current_action_tokens = []
|
|
569
|
|
- else:
|
|
570
|
|
- current_action_tokens.append(token)
|
|
571
|
|
- token_index += 1
|
|
572
|
|
- raise PatternError('Unexpected end of line in action list')
|
|
573
|
|
-
|
|
574
|
|
- @classmethod
|
|
575
|
|
- def __validate_action(cls, action: PatternAction) -> None:
|
|
576
|
|
- args = cls.ACTION_TO_ARGS.get(action.action)
|
|
577
|
|
- if args is None:
|
|
578
|
|
- raise PatternError(f'Unknown action "{action.action}"')
|
|
579
|
|
- if len(action.arguments) != len(args):
|
|
580
|
|
- if len(args) == 0:
|
|
581
|
|
- raise PatternError(f'Action "{action.action}" expects no arguments, ' + \
|
|
582
|
|
- f'got {len(action.arguments)}.')
|
|
583
|
|
- else:
|
|
584
|
|
- raise PatternError(f'Action "{action.action}" expects {len(args)} ' + \
|
|
585
|
|
- f'arguments, got {len(action.arguments)}.')
|
|
586
|
|
- for i, datatype in enumerate(args):
|
|
587
|
|
- action.arguments[i] = cls.parse_value(action.arguments[i], datatype)
|
|
588
|
|
-
|
|
589
|
|
- @classmethod
|
|
590
|
|
- def read_expression(cls,
|
|
591
|
|
- tokens: list,
|
|
592
|
|
- token_index: int,
|
|
593
|
|
- depth: int = 0,
|
|
594
|
|
- one_subexpression: bool = False) -> tuple:
|
|
595
|
|
- """
|
|
596
|
|
- Reads an expression from a list of statement tokens. Returns a tuple
|
|
597
|
|
- containing the PatternExpression and the token index it left off at.
|
|
598
|
|
- If one_subexpression is True then it will return after reading a
|
|
599
|
|
- single expression instead of joining multiples (for readong the
|
|
600
|
|
- subject of a NOT expression).
|
|
601
|
|
- """
|
|
602
|
|
- subexpressions = []
|
|
603
|
|
- last_compound_operator = None
|
|
604
|
|
- while token_index < len(tokens):
|
|
605
|
|
- if one_subexpression:
|
|
606
|
|
- if len(subexpressions) == 1:
|
|
607
|
|
- return (subexpressions[0], token_index)
|
|
608
|
|
- if len(subexpressions) > 1:
|
|
609
|
|
- raise PatternError('Too many subexpressions')
|
|
610
|
|
- compound_operator = None
|
|
611
|
|
- if tokens[token_index] == ')':
|
|
612
|
|
- if len(subexpressions) == 0:
|
|
613
|
|
- raise PatternError('No subexpressions')
|
|
614
|
|
- if len(subexpressions) == 1:
|
|
615
|
|
- return (subexpressions[0], token_index)
|
|
616
|
|
- return (PatternCompoundExpression(last_compound_operator, subexpressions), token_index)
|
|
617
|
|
- if tokens[token_index] in set(["and", "or"]):
|
|
618
|
|
- compound_operator = tokens[token_index]
|
|
619
|
|
- if last_compound_operator and compound_operator != last_compound_operator:
|
|
620
|
|
- subexpressions = [ PatternCompoundExpression(last_compound_operator, subexpressions) ]
|
|
621
|
|
- last_compound_operator = compound_operator
|
|
622
|
|
- else:
|
|
623
|
|
- last_compound_operator = compound_operator
|
|
624
|
|
- token_index += 1
|
|
625
|
|
- if tokens[token_index] == '!':
|
|
626
|
|
- (exp, next_index) = cls.read_expression(tokens, token_index + 1, \
|
|
627
|
|
- depth + 1, one_subexpression=True)
|
|
628
|
|
- subexpressions.append(PatternCompoundExpression('!', [exp]))
|
|
629
|
|
- token_index = next_index
|
|
630
|
|
- elif tokens[token_index] == '(':
|
|
631
|
|
- (exp, next_index) = cls.read_expression(tokens, token_index + 1, depth + 1)
|
|
632
|
|
- if tokens[next_index] != ')':
|
|
633
|
|
- raise PatternError('Expected )')
|
|
634
|
|
- subexpressions.append(exp)
|
|
635
|
|
- token_index = next_index + 1
|
|
636
|
|
- else:
|
|
637
|
|
- (simple, next_index) = cls.read_simple_expression(tokens, token_index, depth)
|
|
638
|
|
- subexpressions.append(simple)
|
|
639
|
|
- token_index = next_index
|
|
640
|
|
- if len(subexpressions) == 0:
|
|
641
|
|
- raise PatternError('No subexpressions')
|
|
642
|
|
- elif len(subexpressions) == 1:
|
|
643
|
|
- return (subexpressions[0], token_index)
|
|
644
|
|
- else:
|
|
645
|
|
- return (PatternCompoundExpression(last_compound_operator, subexpressions), token_index)
|
|
646
|
|
-
|
|
647
|
|
- @classmethod
|
|
648
|
|
- def read_simple_expression(cls, tokens: list, token_index: int, depth: int = 0) -> tuple:
|
|
649
|
|
- """
|
|
650
|
|
- Reads a simple expression consisting of a field name, operator, and
|
|
651
|
|
- comparison value. Returns a tuple of the PatternSimpleExpression and
|
|
652
|
|
- the token index it left off at.
|
|
653
|
|
- """
|
|
654
|
|
- if depth > 8:
|
|
655
|
|
- raise PatternError('Expression nests too deeply')
|
|
656
|
|
- if token_index >= len(tokens):
|
|
657
|
|
- raise PatternError('Expected field name, found EOL')
|
|
658
|
|
- field = tokens[token_index]
|
|
659
|
|
- token_index += 1
|
|
660
|
|
-
|
|
661
|
|
- datatype = cls.FIELD_TO_TYPE.get(field)
|
|
662
|
|
- if datatype is None:
|
|
663
|
|
- raise PatternError(f'No such field "{field}"')
|
|
664
|
|
-
|
|
665
|
|
- if token_index >= len(tokens):
|
|
666
|
|
- raise PatternError('Expected operator, found EOL')
|
|
667
|
|
- op = tokens[token_index]
|
|
668
|
|
- token_index += 1
|
|
669
|
|
-
|
|
670
|
|
- if op == '!':
|
|
671
|
|
- if token_index >= len(tokens):
|
|
672
|
|
- raise PatternError('Expected operator, found EOL')
|
|
673
|
|
- op = '!' + tokens[token_index]
|
|
674
|
|
- token_index += 1
|
|
675
|
|
-
|
|
676
|
|
- allowed_ops = cls.TYPE_TO_OPERATORS[datatype]
|
|
677
|
|
- if op not in allowed_ops:
|
|
678
|
|
- if op in cls.OPERATORS_ALL:
|
|
679
|
|
- raise PatternError(f'Operator {op} cannot be used with field "{field}"')
|
|
680
|
|
- raise PatternError(f'Unrecognized operator "{op}" - allowed: {list(allowed_ops)}')
|
|
681
|
|
-
|
|
682
|
|
- if token_index >= len(tokens):
|
|
683
|
|
- raise PatternError('Expected value, found EOL')
|
|
684
|
|
- value = tokens[token_index]
|
|
685
|
|
-
|
|
686
|
|
- try:
|
|
687
|
|
- value = cls.parse_value(value, datatype)
|
|
688
|
|
- except ValueError as cause:
|
|
689
|
|
- raise PatternError(f'Bad value {value}') from cause
|
|
690
|
|
-
|
|
691
|
|
- token_index += 1
|
|
692
|
|
- exp = PatternSimpleExpression(field, op, value)
|
|
693
|
|
- return (exp, token_index)
|
|
694
|
|
-
|
|
695
|
|
- @classmethod
|
|
696
|
|
- def parse_value(cls, value: str, datatype: str):
|
|
697
|
|
- """
|
|
698
|
|
- Converts a value token to its Python value. Raises ValueError on failure.
|
|
699
|
|
- """
|
|
700
|
|
- if datatype == cls.TYPE_ID:
|
|
701
|
|
- if not is_user_id(value):
|
|
702
|
|
- raise ValueError(f'Illegal user id value: {value}')
|
|
703
|
|
- return value
|
|
704
|
|
- if datatype == cls.TYPE_MEMBER:
|
|
705
|
|
- return user_id_from_mention(value)
|
|
706
|
|
- if datatype == cls.TYPE_TEXT:
|
|
707
|
|
- return str_from_quoted_str(value)
|
|
708
|
|
- if datatype == cls.TYPE_INT:
|
|
709
|
|
- return int(value)
|
|
710
|
|
- if datatype == cls.TYPE_FLOAT:
|
|
711
|
|
- return float(value)
|
|
712
|
|
- if datatype == cls.TYPE_TIMESPAN:
|
|
713
|
|
- return timedelta_from_str(value)
|
|
714
|
|
- raise ValueError(f'Unhandled datatype {datatype}')
|