|
|
@@ -1,75 +1,166 @@
|
|
1
|
|
-from discord import Guild, Message
|
|
|
1
|
+from abc import ABC, abstractmethod
|
|
|
2
|
+from discord import Guild, Member, Message
|
|
2
|
3
|
from discord.ext import commands
|
|
3
|
4
|
from datetime import timedelta
|
|
|
5
|
+import re
|
|
4
|
6
|
|
|
5
|
|
-from cogs.basecog import BaseCog, BotMessage
|
|
|
7
|
+from cogs.basecog import BaseCog, BotMessage, BotMessageReaction
|
|
|
8
|
+from config import CONFIG
|
|
6
|
9
|
from storage import Storage
|
|
7
|
10
|
|
|
8
|
|
-class Criterion:
|
|
9
|
|
- def __init__(self, type, **kwargs):
|
|
|
11
|
+class PatternAction:
|
|
|
12
|
+ def __init__(self, type: str, args: list):
|
|
10
|
13
|
self.type = type
|
|
11
|
|
- if type == 'contains':
|
|
12
|
|
- text = kwargs['text']
|
|
13
|
|
- self.text = text
|
|
14
|
|
- self.test = lambda m : text.lower() in m.content.lower()
|
|
15
|
|
- elif type == 'joinage':
|
|
16
|
|
- min = kwargs['min']
|
|
17
|
|
- self.min = min
|
|
18
|
|
- self.test = lambda m : m.created_at - m.author.joined_at < min
|
|
19
|
|
- else:
|
|
20
|
|
- raise RuntimeError(f'Unknown criterion type "{type}"')
|
|
|
14
|
+ self.arguments = list(args)
|
|
|
15
|
+
|
|
|
16
|
+ def __str__(self) -> str:
|
|
|
17
|
+ arg_str = ', '.join(self.arguments)
|
|
|
18
|
+ return f'{self.type}({arg_str})'
|
|
21
|
19
|
|
|
|
20
|
+class PatternExpression(ABC):
|
|
|
21
|
+ def __init__(self):
|
|
|
22
|
+ pass
|
|
|
23
|
+
|
|
|
24
|
+ @abstractmethod
|
|
22
|
25
|
def matches(self, message: Message) -> bool:
|
|
23
|
|
- return self.test(message)
|
|
|
26
|
+ return False
|
|
24
|
27
|
|
|
25
|
|
- @classmethod
|
|
26
|
|
- def decode(cls, val: dict):
|
|
27
|
|
- type = val['type']
|
|
28
|
|
- if type == 'contains':
|
|
29
|
|
- return Criterion(type, text=val['text'])
|
|
30
|
|
- elif type == 'joinage':
|
|
31
|
|
- return Criterion(type, min=timedelta(seconds=val['min']))
|
|
32
|
|
-
|
|
33
|
|
-class Pattern:
|
|
34
|
|
- def __init__(self, criteria: list, action: str, must_match_all: bool = True):
|
|
35
|
|
- self.criteria = criteria
|
|
36
|
|
- self.action = action
|
|
37
|
|
- self.must_match_all = must_match_all
|
|
|
28
|
+class PatternSimpleExpression(PatternExpression):
|
|
|
29
|
+ def __init__(self, field: str, operator: str, value):
|
|
|
30
|
+ self.field = field
|
|
|
31
|
+ self.operator = operator
|
|
|
32
|
+ self.value = value
|
|
38
|
33
|
|
|
39
|
34
|
def matches(self, message: Message) -> bool:
|
|
40
|
|
- for criterion in self.criteria:
|
|
41
|
|
- crit_matches = criterion.matches(message)
|
|
42
|
|
- if crit_matches and not self.must_match_all:
|
|
43
|
|
- return True
|
|
44
|
|
- if not crit_matches and self.must_match_all:
|
|
45
|
|
- return False
|
|
46
|
|
- return self.must_match_all
|
|
|
35
|
+ field_value = None
|
|
|
36
|
+ if self.field == 'content':
|
|
|
37
|
+ field_value = message.content
|
|
|
38
|
+ elif self.field == 'author':
|
|
|
39
|
+ field_value = str(message.author.id)
|
|
|
40
|
+ elif self.field == 'author.id':
|
|
|
41
|
+ field_value = str(message.author.id)
|
|
|
42
|
+ elif self.field == 'author.joinage':
|
|
|
43
|
+ field_value = message.created_at - message.author.joined_at
|
|
|
44
|
+ elif self.field == 'author.name':
|
|
|
45
|
+ field_value = message.author.name
|
|
|
46
|
+ else:
|
|
|
47
|
+ raise ValueError(f'Bad field name {self.field}')
|
|
|
48
|
+ if self.operator == '==':
|
|
|
49
|
+ if isinstance(field_value, str) and isinstance(self.value, str):
|
|
|
50
|
+ return field_value.lower() == self.value.lower()
|
|
|
51
|
+ return field_value == self.value
|
|
|
52
|
+ if self.operator == '!=':
|
|
|
53
|
+ if isinstance(field_value, str) and isinstance(self.value, str):
|
|
|
54
|
+ return field_value.lower() != self.value.lower()
|
|
|
55
|
+ return field_value != self.value
|
|
|
56
|
+ if self.operator == '<':
|
|
|
57
|
+ return field_value < self.value
|
|
|
58
|
+ if self.operator == '>':
|
|
|
59
|
+ return field_value > self.value
|
|
|
60
|
+ if self.operator == '<=':
|
|
|
61
|
+ return field_value <= self.value
|
|
|
62
|
+ if self.operator == '>=':
|
|
|
63
|
+ return field_value >= self.value
|
|
|
64
|
+ if self.operator == 'contains':
|
|
|
65
|
+ return self.value.lower() in field_value.lower()
|
|
|
66
|
+ if self.operator == '!contains':
|
|
|
67
|
+ return self.value.lower() not in field_value.lower()
|
|
|
68
|
+ if self.operator == 'matches':
|
|
|
69
|
+ p = re.compile(self.value.lower())
|
|
|
70
|
+ return p.match(field_value.lower()) is not None
|
|
|
71
|
+ if self.operator == '!matches':
|
|
|
72
|
+ p = re.compile(self.value.lower())
|
|
|
73
|
+ return p.match(field_value.lower()) is None
|
|
|
74
|
+ raise ValueError(f'Bad operator {self.operator}')
|
|
47
|
75
|
|
|
48
|
|
- @classmethod
|
|
49
|
|
- def decode(cls, val: dict):
|
|
50
|
|
- match_all = val.get('must_match_all')
|
|
51
|
|
- action = val.get('action')
|
|
52
|
|
- encoded_criteria = val.get('criteria')
|
|
53
|
|
- criteria = []
|
|
54
|
|
- for ec in encoded_criteria:
|
|
55
|
|
- criteria.append(Criterion.decode(ec))
|
|
56
|
|
- return Pattern(criteria, action, match_all if isinstance(match_all, bool) else True)
|
|
|
76
|
+ def __str__(self) -> str:
|
|
|
77
|
+ return f'({self.field} {self.operator} {self.value})'
|
|
|
78
|
+
|
|
|
79
|
+class PatternCompoundExpression(PatternExpression):
|
|
|
80
|
+ def __init__(self, operator: str, operands: list):
|
|
|
81
|
+ self.operator = operator
|
|
|
82
|
+ self.operands = list(operands)
|
|
|
83
|
+
|
|
|
84
|
+ def matches(self, message: Message) -> bool:
|
|
|
85
|
+ if self.operator == '!':
|
|
|
86
|
+ return not self.operands[0].matches(message)
|
|
|
87
|
+ elif self.operator == 'and':
|
|
|
88
|
+ for op in self.operands:
|
|
|
89
|
+ if not op.matches(message):
|
|
|
90
|
+ return False
|
|
|
91
|
+ return True
|
|
|
92
|
+ elif self.operator == 'or':
|
|
|
93
|
+ for op in self.operands:
|
|
|
94
|
+ if op.matches(message):
|
|
|
95
|
+ return True
|
|
|
96
|
+ return False
|
|
|
97
|
+ else:
|
|
|
98
|
+ raise RuntimeError(f'Bad operator "{self.operator}"')
|
|
|
99
|
+
|
|
|
100
|
+ def __str__(self) -> str:
|
|
|
101
|
+ if self.operator == '!':
|
|
|
102
|
+ return f'(!( {self.operands[0]} ))'
|
|
|
103
|
+ else:
|
|
|
104
|
+ strs = map(str, self.operands)
|
|
|
105
|
+ joined = f' {self.operator} '.join(strs)
|
|
|
106
|
+ return f'( {joined} )'
|
|
|
107
|
+
|
|
|
108
|
+class PatternStatement:
|
|
|
109
|
+ def __init__(self, name: str, actions: list, expression: PatternExpression, original: str):
|
|
|
110
|
+ self.name = name
|
|
|
111
|
+ self.actions = list(actions) # PatternAction[]
|
|
|
112
|
+ self.expression = expression
|
|
|
113
|
+ self.original = original
|
|
|
114
|
+
|
|
|
115
|
+class PatternContext:
|
|
|
116
|
+ def __init__(self, message: Message, statement: PatternStatement):
|
|
|
117
|
+ self.message = message
|
|
|
118
|
+ self.statement = statement
|
|
|
119
|
+ self.is_deleted = False
|
|
|
120
|
+ self.is_kicked = False
|
|
|
121
|
+ self.is_banned = False
|
|
57
|
122
|
|
|
58
|
123
|
class PatternCog(BaseCog):
|
|
59
|
124
|
def __init__(self, bot):
|
|
60
|
125
|
super().__init__(bot)
|
|
61
|
126
|
|
|
62
|
|
- def __patterns(self, guild: Guild) -> list:
|
|
63
|
|
- patterns = Storage.get_state_value(guild, 'pattern_patterns')
|
|
|
127
|
+ # def __patterns(self, guild: Guild) -> list:
|
|
|
128
|
+ # patterns = Storage.get_state_value(guild, 'pattern_patterns')
|
|
|
129
|
+ # if patterns is None:
|
|
|
130
|
+ # patterns_encoded = Storage.get_config_value(guild, 'pattern_patterns')
|
|
|
131
|
+ # if patterns_encoded:
|
|
|
132
|
+ # patterns = []
|
|
|
133
|
+ # for pe in patterns_encoded:
|
|
|
134
|
+ # patterns.append(Pattern.decode(pe))
|
|
|
135
|
+ # Storage.set_state_value(guild, 'pattern_patterns', patterns)
|
|
|
136
|
+ # return patterns
|
|
|
137
|
+
|
|
|
138
|
+ def __get_patterns(self, guild: Guild) -> dict:
|
|
|
139
|
+ patterns = Storage.get_state_value(guild, 'PatternsCog.patterns')
|
|
64
|
140
|
if patterns is None:
|
|
65
|
|
- patterns_encoded = Storage.get_config_value(guild, 'pattern_patterns')
|
|
|
141
|
+ patterns = {}
|
|
|
142
|
+ patterns_encoded = Storage.get_config_value(guild, 'PatternsCog.patterns')
|
|
66
|
143
|
if patterns_encoded:
|
|
67
|
|
- patterns = []
|
|
68
|
144
|
for pe in patterns_encoded:
|
|
69
|
|
- patterns.append(Pattern.decode(pe))
|
|
70
|
|
- Storage.set_state_value(guild, 'pattern_patterns', patterns)
|
|
|
145
|
+ name = pe.get('name')
|
|
|
146
|
+ statement = pe.get('statement')
|
|
|
147
|
+ try:
|
|
|
148
|
+ ps = PatternCompiler.parse_statement(name, statement)
|
|
|
149
|
+ patterns[name] = ps
|
|
|
150
|
+ except RuntimeError as e:
|
|
|
151
|
+ self.log(guild, f'Error parsing saved statement "{name}". Skipping: {statement}')
|
|
|
152
|
+ Storage.set_state_value(guild, 'PatternsCog.patterns', patterns)
|
|
71
|
153
|
return patterns
|
|
72
|
154
|
|
|
|
155
|
+ def __save_patterns(self, guild: Guild, patterns: dict) -> None:
|
|
|
156
|
+ to_save = []
|
|
|
157
|
+ for name, statement in patterns.items():
|
|
|
158
|
+ to_save.append({
|
|
|
159
|
+ 'name': name,
|
|
|
160
|
+ 'statement': statement.original,
|
|
|
161
|
+ })
|
|
|
162
|
+ Storage.set_config_value(guild, 'PatternsCog.patterns', to_save)
|
|
|
163
|
+
|
|
73
|
164
|
@commands.Cog.listener()
|
|
74
|
165
|
async def on_message(self, message: Message) -> None:
|
|
75
|
166
|
if message.author is None or \
|
|
|
@@ -82,78 +173,85 @@ class PatternCog(BaseCog):
|
|
82
|
173
|
if message.author.permissions_in(message.channel).ban_members:
|
|
83
|
174
|
# Ignore mods
|
|
84
|
175
|
return
|
|
85
|
|
- patterns = self.__patterns(message.guild)
|
|
86
|
|
- for pattern in patterns:
|
|
87
|
|
- if pattern.matches(message):
|
|
88
|
|
- text = None
|
|
89
|
|
- if pattern.action == 'delete':
|
|
90
|
|
- await message.delete()
|
|
91
|
|
- text = f'Message from {message.author.mention} matched ' + \
|
|
92
|
|
- 'banned pattern. Deleted.'
|
|
93
|
|
- self.log(message.guild, 'Message matched pattern. Deleted.')
|
|
94
|
|
- elif pattern.action == 'kick':
|
|
95
|
|
- await message.delete()
|
|
96
|
|
- await message.author.kick(reason='Rocketbot: Message matched banned pattern')
|
|
97
|
|
- text = f'Message from {message.author.mention} matched ' + \
|
|
98
|
|
- 'banned pattern. Message deleted and user kicked.'
|
|
99
|
|
- self.log(message.guild,
|
|
100
|
|
- '\u0007Message matched pattern. Kicked ' + \
|
|
101
|
|
- f'{message.author.name} ({message.author.id}).')
|
|
102
|
|
- elif pattern.action == 'ban':
|
|
103
|
|
- await message.delete()
|
|
104
|
|
- await message.author.ban(reason='Rocketbot: Message matched banned pattern')
|
|
105
|
|
- text = f'Message from {message.author.mention} matched ' + \
|
|
106
|
|
- 'banned pattern. Message deleted and user banned.'
|
|
107
|
|
- self.log(message.guild,
|
|
108
|
|
- '\u0007Message matched pattern. Banned ' + \
|
|
109
|
|
- f'{message.author_name} ({message.author.id}).')
|
|
110
|
|
- if text:
|
|
111
|
|
- m = BotMessage(message.guild,
|
|
112
|
|
- text = msg,
|
|
113
|
|
- type = BotMessage.TYPE_MOD_WARNING)
|
|
114
|
|
- m.quote = message.content
|
|
115
|
|
- await self.post_message(m)
|
|
|
176
|
+
|
|
|
177
|
+ patterns = self.__get_patterns(message.guild)
|
|
|
178
|
+ for name, statement in patterns.items():
|
|
|
179
|
+ if statement.expression.matches(message):
|
|
|
180
|
+ await self.__trigger_actions(message, statement)
|
|
116
|
181
|
break
|
|
117
|
182
|
|
|
118
|
|
- """
|
|
119
|
|
- Expression language samples:
|
|
120
|
|
-
|
|
121
|
|
- content contains "poop"
|
|
122
|
|
- content contains "poop" and content contains "tinkle"
|
|
123
|
|
- joinage < 600s
|
|
124
|
|
- (content contains "this" and content contains "that") or content contains "whatever"
|
|
125
|
|
-
|
|
126
|
|
- <field> <op> <value>
|
|
127
|
|
-
|
|
128
|
|
- Fields:
|
|
129
|
|
- content
|
|
130
|
|
- author.id
|
|
131
|
|
- author.name
|
|
132
|
|
- author.joinage
|
|
133
|
|
-
|
|
134
|
|
- Ops:
|
|
135
|
|
- ==
|
|
136
|
|
- !=
|
|
137
|
|
- <
|
|
138
|
|
- >
|
|
139
|
|
- <=
|
|
140
|
|
- >=
|
|
141
|
|
- contains, !contains -- plain strings
|
|
142
|
|
- matches, !matches -- regexes
|
|
143
|
|
-
|
|
144
|
|
- Value types:
|
|
145
|
|
- timedelta (600, 600s, 10m, 5m30s)
|
|
146
|
|
- number
|
|
147
|
|
- string
|
|
148
|
|
- regex
|
|
149
|
|
- mention
|
|
150
|
|
-
|
|
151
|
|
- Evaluation
|
|
152
|
|
- and
|
|
153
|
|
- or
|
|
154
|
|
- ( )
|
|
155
|
|
- !( )
|
|
156
|
|
- """
|
|
|
183
|
+
|
|
|
184
|
+ async def __trigger_actions(self, message: Message, statement: PatternStatement) -> None:
|
|
|
185
|
+ context = PatternContext(message, statement)
|
|
|
186
|
+ should_alert_mods = False
|
|
|
187
|
+ action_descriptions = []
|
|
|
188
|
+ self.log(message.guild, f'Message from {message.author.name} matched pattern "{statement.name}"')
|
|
|
189
|
+ for action in statement.actions:
|
|
|
190
|
+ if action.type == 'ban':
|
|
|
191
|
+ await message.author.ban(
|
|
|
192
|
+ reason=f'Rocketbot: Message matched custom pattern named "{statement.name}"',
|
|
|
193
|
+ delete_message_days=0)
|
|
|
194
|
+ context.is_banned = True
|
|
|
195
|
+ context.is_kicked = True
|
|
|
196
|
+ action_descriptions.append('Author banned')
|
|
|
197
|
+ self.log(message.guild, f'{message.author.name} banned')
|
|
|
198
|
+ elif action.type == 'delete':
|
|
|
199
|
+ await message.delete()
|
|
|
200
|
+ context.is_deleted = True
|
|
|
201
|
+ action_descriptions.append('Message deleted')
|
|
|
202
|
+ self.log(message.guild, f'{message.author.name}\'s message deleted')
|
|
|
203
|
+ elif action.type == 'kick':
|
|
|
204
|
+ await message.author.kick(
|
|
|
205
|
+ reason=f'Rocketbot: Message matched custom pattern named "{statement.name}"')
|
|
|
206
|
+ context.is_kicked = True
|
|
|
207
|
+ action_descriptions.append('Author kicked')
|
|
|
208
|
+ self.log(message.guild, f'{message.author.name} kicked')
|
|
|
209
|
+ elif action.type == 'modwarn':
|
|
|
210
|
+ should_alert_mods = True
|
|
|
211
|
+ action_descriptions.append('Mods alerted')
|
|
|
212
|
+ elif action.type == 'reply':
|
|
|
213
|
+ await message.reply(
|
|
|
214
|
+ f'{action.arguments[0]}',
|
|
|
215
|
+ mention_author=False)
|
|
|
216
|
+ action_descriptions.append('Autoreplied')
|
|
|
217
|
+ self.log(message.guild, f'{message.author.name} autoreplied to')
|
|
|
218
|
+ bm = BotMessage(
|
|
|
219
|
+ message.guild,
|
|
|
220
|
+ f'User {message.author.name} tripped custom pattern ' + \
|
|
|
221
|
+ f'`{statement.name}`.\n\nAutomatic actions taken:\n• ' + \
|
|
|
222
|
+ ('\n• '.join(action_descriptions)),
|
|
|
223
|
+ type=BotMessage.TYPE_MOD_WARNING if should_alert_mods else BotMessage.TYPE_INFO,
|
|
|
224
|
+ context=context)
|
|
|
225
|
+ bm.quote = message.content
|
|
|
226
|
+ await bm.set_reactions(BotMessageReaction.standard_set(
|
|
|
227
|
+ did_delete=context.is_deleted,
|
|
|
228
|
+ did_kick=context.is_kicked,
|
|
|
229
|
+ did_ban=context.is_banned))
|
|
|
230
|
+ await self.post_message(bm)
|
|
|
231
|
+
|
|
|
232
|
+ async def on_mod_react(self,
|
|
|
233
|
+ bot_message: BotMessage,
|
|
|
234
|
+ reaction: BotMessageReaction,
|
|
|
235
|
+ reacted_by: Member) -> None:
|
|
|
236
|
+ context: PatternContext = bot_message.context
|
|
|
237
|
+ if reaction.emoji == CONFIG['trash_emoji']:
|
|
|
238
|
+ await context.message.delete()
|
|
|
239
|
+ context.is_deleted = True
|
|
|
240
|
+ elif reaction.emoji == CONFIG['kick_emoji']:
|
|
|
241
|
+ await context.message.author.kick(
|
|
|
242
|
+ reason=f'Rocketbot: Message matched custom pattern named ' + \
|
|
|
243
|
+ '"{statement.name}". Kicked by {reacted_by.name}.')
|
|
|
244
|
+ context.is_kicked = True
|
|
|
245
|
+ elif reaction.emoji == CONFIG['ban_emoji']:
|
|
|
246
|
+ await context.message.author.ban(
|
|
|
247
|
+ reason=f'Rocketbot: Message matched custom pattern named ' + \
|
|
|
248
|
+ '"{statement.name}". Banned by {reacted_by.name}.',
|
|
|
249
|
+ delete_message_days=1)
|
|
|
250
|
+ context.is_banned = True
|
|
|
251
|
+ await bot_message.set_reactions(BotMessageReaction.standard_set(
|
|
|
252
|
+ did_delete=context.is_deleted,
|
|
|
253
|
+ did_kick=context.is_kicked,
|
|
|
254
|
+ did_ban=context.is_banned))
|
|
157
|
255
|
|
|
158
|
256
|
@commands.group(
|
|
159
|
257
|
brief='Manages message pattern matching',
|
|
|
@@ -165,10 +263,391 @@ class PatternCog(BaseCog):
|
|
165
|
263
|
if context.invoked_subcommand is None:
|
|
166
|
264
|
await context.send_help()
|
|
167
|
265
|
|
|
168
|
|
- @patterns.command()
|
|
169
|
|
- async def addpattern(self, context: commands.Context, name: str, expression: str, *args):
|
|
170
|
|
- print(f'Pattern name: {name}')
|
|
|
266
|
+ @patterns.command(
|
|
|
267
|
+ brief='Adds a custom pattern',
|
|
|
268
|
+ description='Adds a custom pattern. Patterns use a simplified expression language. Full documentation found here: https://git.rixafrix.com/ialbert/python-app-rocketbot/src/branch/master/patterns.md',
|
|
|
269
|
+ usage='<pattern_name> <expression...>',
|
|
|
270
|
+ ignore_extra=True
|
|
|
271
|
+ )
|
|
|
272
|
+ async def add(self, context: commands.Context, name: str):
|
|
|
273
|
+ pattern_str = PatternCompiler.expression_str_from_context(context, name)
|
|
|
274
|
+ try:
|
|
|
275
|
+ statement = PatternCompiler.parse_statement(name, pattern_str)
|
|
|
276
|
+ patterns = self.__get_patterns(context.guild)
|
|
|
277
|
+ patterns[name] = statement
|
|
|
278
|
+ self.__save_patterns(context.guild, patterns)
|
|
|
279
|
+ await context.message.reply(
|
|
|
280
|
+ f'{CONFIG["success_emoji"]} Pattern `{name}` added.',
|
|
|
281
|
+ mention_author=False)
|
|
|
282
|
+ except Exception as e:
|
|
|
283
|
+ await context.message.reply(
|
|
|
284
|
+ f'{CONFIG["failure_emoji"]} Error parsing statement. {e}',
|
|
|
285
|
+ mention_author=False)
|
|
|
286
|
+
|
|
|
287
|
+ @patterns.command(
|
|
|
288
|
+ brief='Removes a custom pattern',
|
|
|
289
|
+ usage='<pattern_name>'
|
|
|
290
|
+ )
|
|
|
291
|
+ async def remove(self, context: commands.Context, name: str):
|
|
|
292
|
+ patterns = self.__get_patterns(context.guild)
|
|
|
293
|
+ if patterns.get(name) is not None:
|
|
|
294
|
+ del patterns[name]
|
|
|
295
|
+ self.__save_patterns(context.guild, patterns)
|
|
|
296
|
+ await context.message.reply(
|
|
|
297
|
+ f'{CONFIG["success_emoji"]} Pattern `{name}` deleted.',
|
|
|
298
|
+ mention_author=False)
|
|
|
299
|
+ else:
|
|
|
300
|
+ await context.message.reply(
|
|
|
301
|
+ f'{CONFIG["failure_emoji"]} No pattern named `{name}`.',
|
|
|
302
|
+ mention_author=False)
|
|
|
303
|
+
|
|
|
304
|
+ @patterns.command(
|
|
|
305
|
+ brief='Lists all patterns'
|
|
|
306
|
+ )
|
|
|
307
|
+ async def list(self, context: commands.Context) -> None:
|
|
|
308
|
+ patterns = self.__get_patterns(context.guild)
|
|
|
309
|
+ if len(patterns) == 0:
|
|
|
310
|
+ await context.message.reply('No patterns defined.', mention_author=False)
|
|
|
311
|
+ return
|
|
|
312
|
+ msg = ''
|
|
|
313
|
+ for name, statement in sorted(patterns.items()):
|
|
|
314
|
+ msg += f'Pattern `{name}`:\n```\n{statement.original}\n```\n'
|
|
|
315
|
+ await context.message.reply(msg, mention_author=False)
|
|
|
316
|
+
|
|
|
317
|
+class PatternCompiler:
|
|
|
318
|
+ TYPE_ID = 'id'
|
|
|
319
|
+ TYPE_MEMBER = 'Member'
|
|
|
320
|
+ TYPE_TEXT = 'text'
|
|
|
321
|
+ TYPE_INT = 'int'
|
|
|
322
|
+ TYPE_FLOAT = 'float'
|
|
|
323
|
+ TYPE_TIMESPAN = 'timespan'
|
|
|
324
|
+
|
|
|
325
|
+ FIELD_TO_TYPE = {
|
|
|
326
|
+ 'content': TYPE_TEXT,
|
|
|
327
|
+ 'author': TYPE_MEMBER,
|
|
|
328
|
+ 'author.id': TYPE_ID,
|
|
|
329
|
+ 'author.name': TYPE_TEXT,
|
|
|
330
|
+ 'author.joinage': TYPE_TIMESPAN,
|
|
|
331
|
+ }
|
|
|
332
|
+
|
|
|
333
|
+ ACTION_TO_ARGS = {
|
|
|
334
|
+ 'ban': [],
|
|
|
335
|
+ 'delete': [],
|
|
|
336
|
+ 'kick': [],
|
|
|
337
|
+ 'modwarn': [],
|
|
|
338
|
+ 'reply': [ TYPE_TEXT ],
|
|
|
339
|
+ }
|
|
|
340
|
+
|
|
|
341
|
+ OPERATORS_IDENTITY = set([ '==', '!=' ])
|
|
|
342
|
+ OPERATORS_COMPARISON = set([ '<', '>', '<=', '>=' ])
|
|
|
343
|
+ OPERATORS_NUMERIC = OPERATORS_IDENTITY | OPERATORS_COMPARISON
|
|
|
344
|
+ OPERATORS_TEXT = OPERATORS_IDENTITY | set([ 'contains', '!contains', 'matches', '!matches' ])
|
|
|
345
|
+ OPERATORS_ALL = OPERATORS_IDENTITY | OPERATORS_COMPARISON | OPERATORS_TEXT
|
|
|
346
|
+
|
|
|
347
|
+ TYPE_TO_OPERATORS = {
|
|
|
348
|
+ TYPE_ID: OPERATORS_IDENTITY,
|
|
|
349
|
+ TYPE_MEMBER: OPERATORS_IDENTITY,
|
|
|
350
|
+ TYPE_TEXT: OPERATORS_TEXT,
|
|
|
351
|
+ TYPE_INT: OPERATORS_NUMERIC,
|
|
|
352
|
+ TYPE_FLOAT: OPERATORS_NUMERIC,
|
|
|
353
|
+ TYPE_TIMESPAN: OPERATORS_NUMERIC,
|
|
|
354
|
+ }
|
|
|
355
|
+
|
|
|
356
|
+ WHITESPACE_CHARS = ' \t\n\r'
|
|
|
357
|
+ STRING_QUOTE_CHARS = '\'"'
|
|
|
358
|
+ SYMBOL_CHARS = 'abcdefghijklmnopqrstuvwxyz.'
|
|
|
359
|
+ VALUE_CHARS = '0123456789dhms<@!>'
|
|
|
360
|
+ OP_CHARS = '<=>!(),'
|
|
|
361
|
+
|
|
|
362
|
+ @classmethod
|
|
|
363
|
+ def expression_str_from_context(cls, context: commands.Context, name: str) -> str:
|
|
|
364
|
+ pattern_str = context.message.content
|
|
|
365
|
+ command_chain = [ name ]
|
|
|
366
|
+ cmd = context.command
|
|
|
367
|
+ while cmd:
|
|
|
368
|
+ command_chain.insert(0, cmd.name)
|
|
|
369
|
+ cmd = cmd.parent
|
|
|
370
|
+ command_chain[0] = f'{context.prefix}{command_chain[0]}'
|
|
|
371
|
+ for cmd in command_chain:
|
|
|
372
|
+ if pattern_str.startswith(cmd):
|
|
|
373
|
+ pattern_str = pattern_str[len(cmd):].lstrip()
|
|
|
374
|
+ return pattern_str
|
|
|
375
|
+
|
|
|
376
|
+ @classmethod
|
|
|
377
|
+ def parse_statement(cls, name: str, statement: str) -> PatternStatement:
|
|
|
378
|
+ tokens = cls.tokenize(statement)
|
|
|
379
|
+ token_index = 0
|
|
|
380
|
+ actions, token_index = cls.read_actions(tokens, token_index)
|
|
|
381
|
+ expression, token_index = cls.read_expression(tokens, token_index)
|
|
|
382
|
+ return PatternStatement(name, actions, expression, statement)
|
|
|
383
|
+
|
|
|
384
|
+ @classmethod
|
|
|
385
|
+ def tokenize(cls, statement: str) -> list:
|
|
171
|
386
|
tokens = []
|
|
172
|
|
- tokens.append(expression)
|
|
173
|
|
- tokens += args
|
|
174
|
|
- print('Expression: ' + (' '.join(tokens)))
|
|
|
387
|
+ in_quote = False
|
|
|
388
|
+ in_escape = False
|
|
|
389
|
+ all_token_types = set([ 'sym', 'op', 'val' ])
|
|
|
390
|
+ possible_token_types = set(all_token_types)
|
|
|
391
|
+ current_token = ''
|
|
|
392
|
+ for ch in statement:
|
|
|
393
|
+ if in_quote:
|
|
|
394
|
+ if in_escape:
|
|
|
395
|
+ if ch == 'n':
|
|
|
396
|
+ current_token += '\n'
|
|
|
397
|
+ elif ch == 't':
|
|
|
398
|
+ current_token += '\t'
|
|
|
399
|
+ else:
|
|
|
400
|
+ current_token += ch
|
|
|
401
|
+ in_escape = False
|
|
|
402
|
+ elif ch == '\\':
|
|
|
403
|
+ in_escape = True
|
|
|
404
|
+ elif ch == in_quote:
|
|
|
405
|
+ current_token += ch
|
|
|
406
|
+ tokens.append(current_token)
|
|
|
407
|
+ current_token = ''
|
|
|
408
|
+ possible_token_types |= all_token_types
|
|
|
409
|
+ in_quote = False
|
|
|
410
|
+ else:
|
|
|
411
|
+ current_token += ch
|
|
|
412
|
+ else:
|
|
|
413
|
+ if ch in cls.STRING_QUOTE_CHARS:
|
|
|
414
|
+ if len(current_token) > 0:
|
|
|
415
|
+ tokens.append(current_token)
|
|
|
416
|
+ current_token = ''
|
|
|
417
|
+ possible_token_types |= all_token_types
|
|
|
418
|
+ in_quote = ch
|
|
|
419
|
+ current_token = ch
|
|
|
420
|
+ elif ch == '\\':
|
|
|
421
|
+ raise RuntimeError("Unexpected \\")
|
|
|
422
|
+ elif ch in cls.WHITESPACE_CHARS:
|
|
|
423
|
+ if len(current_token) > 0:
|
|
|
424
|
+ tokens.append(current_token)
|
|
|
425
|
+ current_token = ''
|
|
|
426
|
+ possible_token_types |= all_token_types
|
|
|
427
|
+ else:
|
|
|
428
|
+ possible_ch_types = set()
|
|
|
429
|
+ if ch in cls.SYMBOL_CHARS:
|
|
|
430
|
+ possible_ch_types.add('sym')
|
|
|
431
|
+ if ch in cls.VALUE_CHARS:
|
|
|
432
|
+ possible_ch_types.add('val')
|
|
|
433
|
+ if ch in cls.OP_CHARS:
|
|
|
434
|
+ possible_ch_types.add('op')
|
|
|
435
|
+ if len(current_token) > 0 and possible_ch_types.isdisjoint(possible_token_types):
|
|
|
436
|
+ if len(current_token) > 0:
|
|
|
437
|
+ tokens.append(current_token)
|
|
|
438
|
+ current_token = ''
|
|
|
439
|
+ possible_token_types |= all_token_types
|
|
|
440
|
+ possible_token_types &= possible_ch_types
|
|
|
441
|
+ current_token += ch
|
|
|
442
|
+ if len(current_token) > 0:
|
|
|
443
|
+ tokens.append(current_token)
|
|
|
444
|
+
|
|
|
445
|
+ # Some symbols might be glommed onto other tokens. Split 'em up.
|
|
|
446
|
+ prefixes_to_split = [ '!', '(', ',' ]
|
|
|
447
|
+ suffixes_to_split = [ ')', ',' ]
|
|
|
448
|
+ i = 0
|
|
|
449
|
+ while i < len(tokens):
|
|
|
450
|
+ token = tokens[i]
|
|
|
451
|
+ mutated = False
|
|
|
452
|
+ for prefix in prefixes_to_split:
|
|
|
453
|
+ if token.startswith(prefix) and len(token) > len(prefix):
|
|
|
454
|
+ tokens.insert(i, prefix)
|
|
|
455
|
+ tokens[i + 1] = token[len(prefix):]
|
|
|
456
|
+ i += 1
|
|
|
457
|
+ mutated = True
|
|
|
458
|
+ break
|
|
|
459
|
+ if mutated:
|
|
|
460
|
+ continue
|
|
|
461
|
+ for suffix in suffixes_to_split:
|
|
|
462
|
+ if token.endswith(suffix) and len(token) > len(suffix):
|
|
|
463
|
+ tokens[i] = token[0:-len(suffix)]
|
|
|
464
|
+ tokens.insert(i + 1, suffix)
|
|
|
465
|
+ mutated = True
|
|
|
466
|
+ break
|
|
|
467
|
+ if mutated:
|
|
|
468
|
+ continue
|
|
|
469
|
+ i += 1
|
|
|
470
|
+ return tokens
|
|
|
471
|
+
|
|
|
472
|
+ @classmethod
|
|
|
473
|
+ def read_actions(cls, tokens: list, token_index: int) -> tuple:
|
|
|
474
|
+ actions = []
|
|
|
475
|
+ current_action_tokens = []
|
|
|
476
|
+ while token_index < len(tokens):
|
|
|
477
|
+ token = tokens[token_index]
|
|
|
478
|
+ if token == 'if':
|
|
|
479
|
+ if len(current_action_tokens) > 0:
|
|
|
480
|
+ a = PatternAction(current_action_tokens[0], current_action_tokens[1:])
|
|
|
481
|
+ cls.__validate_action(a)
|
|
|
482
|
+ actions.append(a)
|
|
|
483
|
+ token_index += 1
|
|
|
484
|
+ return (actions, token_index)
|
|
|
485
|
+ elif token == ',':
|
|
|
486
|
+ if len(current_action_tokens) < 1:
|
|
|
487
|
+ raise RuntimeError('Unexpected ,')
|
|
|
488
|
+ a = PatternAction(current_action_tokens[0], current_action_tokens[1:])
|
|
|
489
|
+ cls.__validate_action(a)
|
|
|
490
|
+ actions.append(a)
|
|
|
491
|
+ current_action_tokens = []
|
|
|
492
|
+ else:
|
|
|
493
|
+ current_action_tokens.append(token)
|
|
|
494
|
+ token_index += 1
|
|
|
495
|
+ raise RuntimeError('Unexpected end of line')
|
|
|
496
|
+
|
|
|
497
|
+ @classmethod
|
|
|
498
|
+ def __validate_action(cls, action: PatternAction) -> None:
|
|
|
499
|
+ args = cls.ACTION_TO_ARGS.get(action.type)
|
|
|
500
|
+ if args is None:
|
|
|
501
|
+ raise RuntimeError(f'Unknown action "{action.type}"')
|
|
|
502
|
+ if len(action.arguments) != len(args):
|
|
|
503
|
+ arg_list = ', '.join(args)
|
|
|
504
|
+ if len(args) == 0:
|
|
|
505
|
+ raise RuntimeError(f'Action "{action.type}" expects no arguments, got {len(action.arguments)}.')
|
|
|
506
|
+ else:
|
|
|
507
|
+ raise RuntimeError(f'Action "{action.type}" expects {len(args)} arguments, got {len(action.arguments)}.')
|
|
|
508
|
+ for i in range(len(args)):
|
|
|
509
|
+ datatype = args[i]
|
|
|
510
|
+ action.arguments[i] = cls.parse_value(action.arguments[i], datatype)
|
|
|
511
|
+
|
|
|
512
|
+ @classmethod
|
|
|
513
|
+ def read_expression(cls, tokens: list, token_index: int, depth: int = 0, one_subexpression: bool = False) -> tuple:
|
|
|
514
|
+ # field op value
|
|
|
515
|
+ # (field op value)
|
|
|
516
|
+ # !(field op value)
|
|
|
517
|
+ # field op value and field op value
|
|
|
518
|
+ # (field op value and field op value) or field op value
|
|
|
519
|
+ indent = '\t' * depth
|
|
|
520
|
+ subexpressions = []
|
|
|
521
|
+ last_compound_operator = None
|
|
|
522
|
+ while token_index < len(tokens):
|
|
|
523
|
+ if one_subexpression:
|
|
|
524
|
+ if len(subexpressions) == 1:
|
|
|
525
|
+ return (subexpressions[0], token_index)
|
|
|
526
|
+ elif len(subexpressions) > 1:
|
|
|
527
|
+ raise RuntimeError('Too many subexpressions')
|
|
|
528
|
+ compound_operator = None
|
|
|
529
|
+ if tokens[token_index] == ')':
|
|
|
530
|
+ if len(subexpressions) == 0:
|
|
|
531
|
+ raise RuntimeError('No subexpressions')
|
|
|
532
|
+ elif len(subexpressions) == 1:
|
|
|
533
|
+ return (subexpressions[0], token_index)
|
|
|
534
|
+ else:
|
|
|
535
|
+ return (PatternCompoundExpression(last_compound_operator, subexpressions), token_index)
|
|
|
536
|
+ if tokens[token_index] in set(["and", "or"]):
|
|
|
537
|
+ compound_operator = tokens[token_index]
|
|
|
538
|
+ if last_compound_operator and compound_operator != last_compound_operator:
|
|
|
539
|
+ subexpressions = [ PatternCompoundExpression(last_compound_operator, subexpressions) ]
|
|
|
540
|
+ last_compound_operator = compound_operator
|
|
|
541
|
+ else:
|
|
|
542
|
+ last_compound_operator = compound_operator
|
|
|
543
|
+ token_index += 1
|
|
|
544
|
+ if tokens[token_index] == '!':
|
|
|
545
|
+ (exp, next_index) = cls.read_expression(tokens, token_index + 1, depth + 1, one_subexpression=True)
|
|
|
546
|
+ subexpressions.append(PatternCompoundExpression('!', [exp]))
|
|
|
547
|
+ token_index = next_index
|
|
|
548
|
+ elif tokens[token_index] == '(':
|
|
|
549
|
+ (exp, next_index) = cls.read_expression(tokens, token_index + 1, depth + 1)
|
|
|
550
|
+ if tokens[next_index] != ')':
|
|
|
551
|
+ raise RuntimeError('Expected )')
|
|
|
552
|
+ subexpressions.append(exp)
|
|
|
553
|
+ token_index = next_index + 1
|
|
|
554
|
+ else:
|
|
|
555
|
+ (simple, next_index) = cls.read_simple_expression(tokens, token_index, depth)
|
|
|
556
|
+ subexpressions.append(simple)
|
|
|
557
|
+ token_index = next_index
|
|
|
558
|
+ if len(subexpressions) == 0:
|
|
|
559
|
+ raise RuntimeError('No subexpressions')
|
|
|
560
|
+ elif len(subexpressions) == 1:
|
|
|
561
|
+ return (subexpressions[0], token_index)
|
|
|
562
|
+ else:
|
|
|
563
|
+ return (PatternCompoundExpression(last_compound_operator, subexpressions), token_index)
|
|
|
564
|
+
|
|
|
565
|
+ @classmethod
|
|
|
566
|
+ def read_simple_expression(cls, tokens: list, token_index: int, depth: int = 0) -> tuple:
|
|
|
567
|
+ indent = '\t' * depth
|
|
|
568
|
+
|
|
|
569
|
+ if token_index >= len(tokens):
|
|
|
570
|
+ raise RuntimeError('Expected field name, found EOL')
|
|
|
571
|
+ field = tokens[token_index]
|
|
|
572
|
+ token_index += 1
|
|
|
573
|
+
|
|
|
574
|
+ datatype = cls.FIELD_TO_TYPE.get(field)
|
|
|
575
|
+ if datatype is None:
|
|
|
576
|
+ raise RuntimeError(f'No such field "{field}"')
|
|
|
577
|
+
|
|
|
578
|
+ if token_index >= len(tokens):
|
|
|
579
|
+ raise RuntimeError('Expected operator, found EOL')
|
|
|
580
|
+ op = tokens[token_index]
|
|
|
581
|
+ token_index += 1
|
|
|
582
|
+
|
|
|
583
|
+ if op == '!':
|
|
|
584
|
+ if token_index >= len(tokens):
|
|
|
585
|
+ raise RuntimeError('Expected operator, found EOL')
|
|
|
586
|
+ op = '!' + tokens[token_index]
|
|
|
587
|
+ token_index += 1
|
|
|
588
|
+
|
|
|
589
|
+ allowed_ops = cls.TYPE_TO_OPERATORS[datatype]
|
|
|
590
|
+ if op not in allowed_ops:
|
|
|
591
|
+ if op in cls.OPERATORS_ALL:
|
|
|
592
|
+ raise RuntimeError(f'Operator {op} cannot be used with field "{field}"')
|
|
|
593
|
+ else:
|
|
|
594
|
+ raise RuntimeError(f'Unrecognized operator "{op}" - allowed: {list(allowed_ops)}')
|
|
|
595
|
+
|
|
|
596
|
+ if token_index >= len(tokens):
|
|
|
597
|
+ raise RuntimeError('Expected value, found EOL')
|
|
|
598
|
+ value = tokens[token_index]
|
|
|
599
|
+
|
|
|
600
|
+ value = cls.parse_value(value, datatype)
|
|
|
601
|
+
|
|
|
602
|
+ token_index += 1
|
|
|
603
|
+ exp = PatternSimpleExpression(field, op, value)
|
|
|
604
|
+ return (exp, token_index)
|
|
|
605
|
+
|
|
|
606
|
+ @classmethod
|
|
|
607
|
+ def parse_value(cls, value: str, type: str):
|
|
|
608
|
+ if type == cls.TYPE_ID:
|
|
|
609
|
+ p = re.compile('^[0-9]+$')
|
|
|
610
|
+ if p.match(value) is None:
|
|
|
611
|
+ raise ValueError(f'Illegal id value "{value}"')
|
|
|
612
|
+ # Store it as a str so it can be larger than an int
|
|
|
613
|
+ return value
|
|
|
614
|
+ if type == cls.TYPE_MEMBER:
|
|
|
615
|
+ p = re.compile('^<@!?([0-9]+)>$')
|
|
|
616
|
+ m = p.match(value)
|
|
|
617
|
+ if m is None:
|
|
|
618
|
+ raise ValueError(f'Illegal member value. Must be an @ mention.')
|
|
|
619
|
+ return m.group(1)
|
|
|
620
|
+ if type == cls.TYPE_TEXT:
|
|
|
621
|
+ # Must be quoted.
|
|
|
622
|
+ if len(value) < 2 or \
|
|
|
623
|
+ value[0:1] not in cls.STRING_QUOTE_CHARS or \
|
|
|
624
|
+ value[-1:] not in cls.STRING_QUOTE_CHARS or \
|
|
|
625
|
+ value[0:1] != value[-1:]:
|
|
|
626
|
+ raise ValueError(f'Not a quoted string value: {value}')
|
|
|
627
|
+ return value[1:-1]
|
|
|
628
|
+ if type == cls.TYPE_INT:
|
|
|
629
|
+ return int(value)
|
|
|
630
|
+ if type == cls.TYPE_FLOAT:
|
|
|
631
|
+ return float(value)
|
|
|
632
|
+ if type == cls.TYPE_TIMESPAN:
|
|
|
633
|
+ p = re.compile('^(?:[0-9]+[dhms])+$')
|
|
|
634
|
+ if p.match(value) is None:
|
|
|
635
|
+ raise RuntimeError("Illegal timespan value \"{value}\". Must be like \"100d\", \"5m30s\", etc.")
|
|
|
636
|
+ p = re.compile('([0-9]+)([dhms])')
|
|
|
637
|
+ days = 0
|
|
|
638
|
+ hours = 0
|
|
|
639
|
+ minutes = 0
|
|
|
640
|
+ seconds = 0
|
|
|
641
|
+ for m in p.finditer(value):
|
|
|
642
|
+ scalar = int(m.group(1))
|
|
|
643
|
+ unit = m.group(2)
|
|
|
644
|
+ if unit == 'd':
|
|
|
645
|
+ days = scalar
|
|
|
646
|
+ elif unit == 'h':
|
|
|
647
|
+ hours = scalar
|
|
|
648
|
+ elif unit == 'm':
|
|
|
649
|
+ minutes = scalar
|
|
|
650
|
+ elif unit == 's':
|
|
|
651
|
+ seconds = scalar
|
|
|
652
|
+ return timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)
|
|
|
653
|
+ raise ValueError(f'Unhandled datatype {datatype}')
|