Kaynağa Gözat

Pattern cog converted to slash commands with some minimal autocomplete

pull/13/head
Rocketsoup 2 ay önce
ebeveyn
işleme
eb786293e7
2 değiştirilmiş dosya ile 300 ekleme ve 154 silme
  1. 166
    73
      rocketbot/cogs/patterncog.py
  2. 134
    81
      rocketbot/pattern.py

+ 166
- 73
rocketbot/cogs/patterncog.py Dosyayı Görüntüle

2
 Cog for matching messages against guild-configurable criteria and taking
2
 Cog for matching messages against guild-configurable criteria and taking
3
 automated actions on them.
3
 automated actions on them.
4
 """
4
 """
5
+import re
5
 from datetime import datetime
6
 from datetime import datetime
6
-from typing import Optional, Literal
7
+from typing import Optional
7
 
8
 
8
 from discord import Guild, Member, Message, utils as discordutils, Permissions, Interaction
9
 from discord import Guild, Member, Message, utils as discordutils, Permissions, Interaction
9
-from discord.app_commands import Group, rename
10
+from discord.app_commands import Choice, Group, autocomplete, rename
10
 from discord.ext import commands
11
 from discord.ext import commands
11
 
12
 
12
 from config import CONFIG
13
 from config import CONFIG
16
 from rocketbot.pattern import PatternCompiler, PatternDeprecationError, \
17
 from rocketbot.pattern import PatternCompiler, PatternDeprecationError, \
17
 	PatternError, PatternStatement
18
 	PatternError, PatternStatement
18
 from rocketbot.storage import Storage
19
 from rocketbot.storage import Storage
20
+from rocketbot.utils import dump_stacktrace
19
 
21
 
20
 class PatternContext:
22
 class PatternContext:
21
 	"""
23
 	"""
29
 		self.is_kicked = False
31
 		self.is_kicked = False
30
 		self.is_banned = False
32
 		self.is_banned = False
31
 
33
 
34
+async def pattern_name_autocomplete(interaction: Interaction, current: str) -> list[Choice[str]]:
35
+	choices: list[Choice[str]] = []
36
+	try:
37
+		if interaction.guild is None:
38
+			return []
39
+		patterns: dict[str, PatternStatement] = PatternCog.shared.get_patterns(interaction.guild)
40
+		current_normal = current.lower().strip()
41
+		for name in sorted(patterns.keys()):
42
+			if len(current_normal) == 0 or current_normal.startswith(name.lower()):
43
+				choices.append(Choice(name=name, value=name))
44
+	except BaseException as e:
45
+		dump_stacktrace(e)
46
+	return choices
47
+
48
+async def action_autocomplete(interaction: Interaction, current: str) -> list[Choice[str]]:
49
+	# FIXME: WORK IN PROGRESS
50
+	print(f'autocomplete action - current = "{current}"')
51
+	regex = re.compile('^(.*?)([a-zA-Z]+)$')
52
+	match: Optional[re.Match[str]] = regex.match(current)
53
+	initial: str = ''
54
+	stub: str = current
55
+	if match:
56
+		initial = match.group(1).strip()
57
+		stub = match.group(2)
58
+	if PatternCompiler.ACTION_TO_ARGS.get(stub, None) is not None:
59
+		# Matches perfectly. Suggest another instead of completing the current.
60
+		initial = current.strip() + ', '
61
+		stub = ''
62
+	print(f'initial = "{initial}", stub = "{stub}"')
63
+
64
+	options: list[Choice[str]] = []
65
+	for action in sorted(PatternCompiler.ACTION_TO_ARGS.keys()):
66
+		if len(stub) == 0 or action.startswith(stub.lower()):
67
+			arg_types = PatternCompiler.ACTION_TO_ARGS[action]
68
+			arg_type_strs = []
69
+			for arg_type in arg_types:
70
+				if arg_type == PatternCompiler.TYPE_TEXT:
71
+					arg_type_strs.append('"message"')
72
+				else:
73
+					raise ValueError(f'Argument type {arg_type} not yet supported')
74
+			suffix = '' if len(arg_type_strs) == 0 else ' ' + (' '.join(arg_type_strs))
75
+			options.append(Choice(name=action, value=f'{initial.strip()} {action}{suffix}'))
76
+	return options
77
+
78
+async def priority_autocomplete(interaction: Interaction, current: str) -> list[Choice[str]]:
79
+	return [
80
+		Choice(name='very low (50)', value=50),
81
+		Choice(name='low (75)', value=75),
82
+		Choice(name='normal (100)', value=100),
83
+		Choice(name='high (125)', value=125),
84
+		Choice(name='very high (150)', value=150),
85
+	]
86
+
32
 class PatternCog(BaseCog, name='Pattern Matching'):
87
 class PatternCog(BaseCog, name='Pattern Matching'):
33
 	"""
88
 	"""
34
 	Highly flexible cog for performing various actions on messages that match
89
 	Highly flexible cog for performing various actions on messages that match
37
 
92
 
38
 	SETTING_PATTERNS = CogSetting('patterns', None)
93
 	SETTING_PATTERNS = CogSetting('patterns', None)
39
 
94
 
95
+	shared: Optional['PatternCog'] = None
96
+
40
 	def __init__(self, bot: Rocketbot):
97
 	def __init__(self, bot: Rocketbot):
41
 		super().__init__(
98
 		super().__init__(
42
 			bot,
99
 			bot,
44
 			name='patterns',
101
 			name='patterns',
45
 			short_description='Manages message pattern matching.',
102
 			short_description='Manages message pattern matching.',
46
 		)
103
 		)
104
+		PatternCog.shared = self
47
 
105
 
48
-	def __get_patterns(self, guild: Guild) -> dict[str, PatternStatement]:
106
+	def get_patterns(self, guild: Guild) -> dict[str, PatternStatement]:
49
 		"""
107
 		"""
50
 		Returns a name -> PatternStatement lookup for the guild.
108
 		Returns a name -> PatternStatement lookup for the guild.
51
 		"""
109
 		"""
104
 			# Ignore mods
162
 			# Ignore mods
105
 			return
163
 			return
106
 
164
 
107
-		patterns = self.__get_patterns(message.guild)
165
+		patterns = self.get_patterns(message.guild)
108
 		for statement in sorted(patterns.values(), key=lambda s : s.priority, reverse=True):
166
 		for statement in sorted(patterns.values(), key=lambda s : s.priority, reverse=True):
109
 			other_fields = {
167
 			other_fields = {
110
 				'last_matched': self.__get_last_matched(message.guild, statement.name),
168
 				'last_matched': self.__get_last_matched(message.guild, statement.name),
199
 			did_kick=context.is_kicked,
257
 			did_kick=context.is_kicked,
200
 			did_ban=context.is_banned))
258
 			did_ban=context.is_banned))
201
 
259
 
202
-	spattern = Group(
260
+	pattern = Group(
203
 		name='pattern',
261
 		name='pattern',
204
 		description='Manages message pattern matching.',
262
 		description='Manages message pattern matching.',
205
 		guild_only=True,
263
 		guild_only=True,
206
 		default_permissions=Permissions(Permissions.manage_messages.flag),
264
 		default_permissions=Permissions(Permissions.manage_messages.flag),
207
 	)
265
 	)
208
-	@spattern.command()
266
+
267
+	@pattern.command()
209
 	@rename(expression='if')
268
 	@rename(expression='if')
210
-	async def test(
269
+	@autocomplete(
270
+		name=pattern_name_autocomplete,
271
+		# actions=action_autocomplete
272
+	)
273
+	async def add(
211
 			self,
274
 			self,
212
 			interaction: Interaction,
275
 			interaction: Interaction,
276
+			name: str,
213
 			actions: str,
277
 			actions: str,
214
 			expression: str
278
 			expression: str
215
 	) -> None:
279
 	) -> None:
216
-		vals = [ actions, expression ]
217
-		arg_list = ''
218
-		for arg in vals:
219
-			if arg is not None:
220
-				arg_list += f'- "`{arg}`"\n'
221
-		await interaction.response.send_message(
222
-			"Got /pattern test call with arguments\n" + arg_list,
223
-			ephemeral=True,
224
-		)
280
+		"""
281
+		Adds a custom pattern.
225
 
282
 
226
-	@commands.group(
227
-		brief='Manages message pattern matching',
228
-	)
229
-	@commands.has_permissions(ban_members=True)
230
-	@commands.guild_only()
231
-	async def pattern(self, context: commands.Context):
232
-		"""Message pattern matching command group"""
233
-		if context.invoked_subcommand is None:
234
-			await context.send_help()
283
+		Adds a custom pattern. Patterns use a simplified
284
+		expression language. Full documentation found here:
285
+		https://git.rixafrix.com/ialbert/python-app-rocketbot/src/branch/main/docs/patterns.md
235
 
286
 
236
-	@pattern.command(
237
-		brief='Adds a custom pattern',
238
-		description='Adds a custom pattern. Patterns use a simplified ' + \
239
-			'expression language. Full documentation found here: ' + \
240
-			'https://git.rixafrix.com/ialbert/python-app-rocketbot/src/' + \
241
-			'branch/main/docs/patterns.md',
242
-		usage='<pattern_name> <expression...>',
243
-		ignore_extra=True
244
-	)
245
-	async def add(self, context: commands.Context, name: str):
246
-		"""Command handler"""
247
-		pattern_str = PatternCompiler.expression_str_from_context(context, name)
287
+		Parameters
288
+		----------
289
+		interaction : Interaction
290
+		name : str
291
+			A name for the pattern.
292
+		actions : str
293
+			One or more actions to take when a message matches the expression.
294
+		expression : str
295
+			Criteria for matching chat messages.
296
+		"""
297
+		pattern_str = f'{actions} if {expression}'
298
+		guild = interaction.guild
248
 		try:
299
 		try:
249
 			statement = PatternCompiler.parse_statement(name, pattern_str)
300
 			statement = PatternCompiler.parse_statement(name, pattern_str)
250
 			statement.check_deprecations()
301
 			statement.check_deprecations()
251
-			patterns = self.__get_patterns(context.guild)
302
+			patterns = self.get_patterns(guild)
252
 			patterns[name] = statement
303
 			patterns[name] = statement
253
-			self.__save_patterns(context.guild, patterns)
254
-			await context.message.reply(
304
+			self.__save_patterns(guild, patterns)
305
+			await interaction.response.send_message(
255
 				f'{CONFIG["success_emoji"]} Pattern `{name}` added.',
306
 				f'{CONFIG["success_emoji"]} Pattern `{name}` added.',
256
-				mention_author=False)
307
+				ephemeral=True,
308
+			)
257
 		except PatternError as e:
309
 		except PatternError as e:
258
-			await context.message.reply(
310
+			await interaction.response.send_message(
259
 				f'{CONFIG["failure_emoji"]} Error parsing statement. {e}',
311
 				f'{CONFIG["failure_emoji"]} Error parsing statement. {e}',
260
-				mention_author=False)
312
+				ephemeral=True,
313
+			)
261
 
314
 
262
 	@pattern.command(
315
 	@pattern.command(
263
-		brief='Removes a custom pattern',
264
-		usage='<pattern_name>'
316
+		description='Removes a custom pattern',
317
+		extras={
318
+			'usage': '<pattern_name>',
319
+		},
265
 	)
320
 	)
266
-	async def remove(self, context: commands.Context, name: str):
267
-		"""Command handler"""
268
-		patterns = self.__get_patterns(context.guild)
321
+	@autocomplete(name=pattern_name_autocomplete)
322
+	async def remove(self, interaction: Interaction, name: str):
323
+		"""
324
+		Command handler
325
+
326
+		Parameters
327
+		----------
328
+		interaction: Interaction
329
+		name: str
330
+			Name of the pattern to remove.
331
+		"""
332
+		guild = interaction.guild
333
+		patterns = self.get_patterns(guild)
269
 		if patterns.get(name) is not None:
334
 		if patterns.get(name) is not None:
270
 			del patterns[name]
335
 			del patterns[name]
271
-			self.__save_patterns(context.guild, patterns)
272
-			await context.message.reply(
336
+			self.__save_patterns(guild, patterns)
337
+			await interaction.response.send_message(
273
 				f'{CONFIG["success_emoji"]} Pattern `{name}` deleted.',
338
 				f'{CONFIG["success_emoji"]} Pattern `{name}` deleted.',
274
-				mention_author=False)
339
+				ephemeral=True,
340
+			)
275
 		else:
341
 		else:
276
-			await context.message.reply(
342
+			await interaction.response.send_message(
277
 				f'{CONFIG["failure_emoji"]} No pattern named `{name}`.',
343
 				f'{CONFIG["failure_emoji"]} No pattern named `{name}`.',
278
-				mention_author=False)
344
+				ephemeral=True,
345
+			)
279
 
346
 
280
 	@pattern.command(
347
 	@pattern.command(
281
-		brief='Lists all patterns'
348
+		description='Lists all patterns',
282
 	)
349
 	)
283
-	async def list(self, context: commands.Context) -> None:
284
-		"""Command handler"""
285
-		patterns = self.__get_patterns(context.guild)
350
+	async def list(self, interaction: Interaction) -> None:
351
+		"""
352
+		Command handler
353
+
354
+		Parameters
355
+		----------
356
+		interaction: Interaction
357
+		"""
358
+		guild = interaction.guild
359
+		patterns = self.get_patterns(guild)
286
 		if len(patterns) == 0:
360
 		if len(patterns) == 0:
287
-			await context.message.reply('No patterns defined.', mention_author=False)
361
+			await interaction.response.send_message(
362
+				'No patterns defined.',
363
+				ephemeral=True,
364
+			)
288
 			return
365
 			return
289
 		msg = ''
366
 		msg = ''
290
 		for name, statement in sorted(patterns.items()):
367
 		for name, statement in sorted(patterns.items()):
291
 			msg += f'Pattern `{name}` (priority={statement.priority}):\n```\n{statement.original}\n```\n'
368
 			msg += f'Pattern `{name}` (priority={statement.priority}):\n```\n{statement.original}\n```\n'
292
-		await context.message.reply(msg, mention_author=False)
369
+		await interaction.response.send_message(msg, ephemeral=True)
293
 
370
 
294
 	@pattern.command(
371
 	@pattern.command(
295
-		brief='Sets a pattern\'s priority level',
296
-		description='Sets the priority for a pattern. Messages are checked ' +
297
-			'against patterns with the highest priority first. Patterns with ' +
298
-			'the same priority may be checked in arbitrary order. Default ' +
299
-			'priority is 100.',
372
+		description='Sets a pattern\'s priority level',
373
+		extras={
374
+			'long_description': 'Sets the priority for a pattern. Messages are checked ' +
375
+				'against patterns with the highest priority first. Patterns with ' +
376
+				'the same priority may be checked in arbitrary order. Default ' +
377
+				'priority is 100.',
378
+		},
300
 	)
379
 	)
301
-	async def setpriority(self, context: commands.Context, name: str, priority: int) -> None:
302
-		"""Command handler"""
303
-		patterns = self.__get_patterns(context.guild)
380
+	@autocomplete(name=pattern_name_autocomplete, priority=priority_autocomplete)
381
+	async def setpriority(self, interaction: Interaction, name: str, priority: int) -> None:
382
+		"""
383
+		Command handler
384
+
385
+		Parameters
386
+		----------
387
+		interaction: Interaction
388
+		name: str
389
+			A name for the pattern
390
+		priority: int
391
+			Priority for evaluating the pattern. Default is 100. Higher values match first.
392
+		"""
393
+		guild = interaction.guild
394
+		patterns = self.get_patterns(guild)
304
 		statement = patterns.get(name)
395
 		statement = patterns.get(name)
305
 		if statement is None:
396
 		if statement is None:
306
-			await context.message.reply(
397
+			await interaction.response.send_message(
307
 				f'{CONFIG["failure_emoji"]} No such pattern `{name}`',
398
 				f'{CONFIG["failure_emoji"]} No such pattern `{name}`',
308
-				mention_author=False)
399
+				ephemeral=True,
400
+			)
309
 			return
401
 			return
310
 		statement.priority = priority
402
 		statement.priority = priority
311
-		self.__save_patterns(context.guild, patterns)
312
-		await context.message.reply(
403
+		self.__save_patterns(guild, patterns)
404
+		await interaction.response.send_message(
313
 			f'{CONFIG["success_emoji"]} Priority for pattern `{name}` ' + \
405
 			f'{CONFIG["success_emoji"]} Priority for pattern `{name}` ' + \
314
-			f'updated to `{priority}`.',
315
-			mention_author=False)
406
+				f'updated to `{priority}`.',
407
+			ephemeral=True,
408
+		)

+ 134
- 81
rocketbot/pattern.py Dosyayı Görüntüle

5
 import re
5
 import re
6
 from abc import ABCMeta, abstractmethod
6
 from abc import ABCMeta, abstractmethod
7
 from datetime import datetime, timezone
7
 from datetime import datetime, timezone
8
-from typing import Any, Union
8
+from typing import Any, Union, Literal
9
 
9
 
10
 from discord import Message, utils as discordutils
10
 from discord import Message, utils as discordutils
11
 from discord.ext.commands import Context
11
 from discord.ext.commands import Context
13
 from rocketbot.utils import is_user_id, str_from_quoted_str, timedelta_from_str, \
13
 from rocketbot.utils import is_user_id, str_from_quoted_str, timedelta_from_str, \
14
 	user_id_from_mention
14
 	user_id_from_mention
15
 
15
 
16
+PatternField = Literal['content.markdown', 'content', 'content.plain', 'author', 'author.id', 'author.joinage', 'author.name', 'lastmatched']
17
+PatternComparisonOperator = Literal['==', '!=', '<', '>', '<=', '>=', 'contains', '!contains', 'matches', '!matches', 'containsword', '!containsword']
18
+PatternBooleanOperator = Literal['!', 'and', 'or']
19
+PatternActionType = Literal['ban', 'delete', 'kick', 'modinfo', 'modwarn', 'reply']
20
+
16
 class PatternError(RuntimeError):
21
 class PatternError(RuntimeError):
17
 	"""
22
 	"""
18
 	Error thrown when parsing a pattern statement.
23
 	Error thrown when parsing a pattern statement.
27
 	"""
32
 	"""
28
 	Describes one action to take on a matched message or its author.
33
 	Describes one action to take on a matched message or its author.
29
 	"""
34
 	"""
35
+
36
+	TYPE_BAN: PatternActionType = 'ban'
37
+	TYPE_DELETE: PatternActionType = 'delete'
38
+	TYPE_KICK: PatternActionType = 'kick'
39
+	TYPE_INFORM_MODS: PatternActionType = 'modinfo'
40
+	TYPE_WARN_MODS: PatternActionType = 'modwarn'
41
+	TYPE_REPLY: PatternActionType = 'reply'
42
+
30
 	def __init__(self, action: str, args: list[Any]):
43
 	def __init__(self, action: str, args: list[Any]):
31
 		self.action = action
44
 		self.action = action
32
 		self.arguments = list(args)
45
 		self.arguments = list(args)
55
 	Message matching expression with a simple "<field> <operator> <value>"
68
 	Message matching expression with a simple "<field> <operator> <value>"
56
 	structure.
69
 	structure.
57
 	"""
70
 	"""
58
-	def __init__(self, field: str, operator: str, value: Any):
71
+
72
+	FIELD_CONTENT_MARKDOWN: PatternField = 'content.markdown'
73
+	FIELD_CONTENT_PLAIN: PatternField = 'content.plain'
74
+	FIELD_AUTHOR_ID: PatternField = 'author.id'
75
+	FIELD_AUTHOR_JOINAGE: PatternField = 'author.joinage'
76
+	FIELD_AUTHOR_NAME: PatternField = 'author.name'
77
+	FIELD_LAST_MATCHED: PatternField = 'lastmatched'
78
+
79
+	# Less preferred but recognized field aliases
80
+	ALIAS_FIELD_CONTENT_MARKDOWN: PatternField = 'content'
81
+	ALIAS_FIELD_AUTHOR_ID: PatternField = 'author'
82
+
83
+	OP_EQUALS: PatternComparisonOperator = '=='
84
+	OP_NOT_EQUALS: PatternComparisonOperator = '!='
85
+	OP_LESS_THAN: PatternComparisonOperator = '<'
86
+	OP_GREATER_THAN: PatternComparisonOperator = '>'
87
+	OP_LESS_THAN_OR_EQUALS: PatternComparisonOperator = '<='
88
+	OP_GREATER_THAN_OR_EQUALS: PatternComparisonOperator = '>='
89
+	OP_CONTAINS: PatternComparisonOperator = 'contains'
90
+	OP_NOT_CONTAINS: PatternComparisonOperator = '!contains'
91
+	OP_MATCHES: PatternComparisonOperator = 'matches'
92
+	OP_NOT_MATCHES: PatternComparisonOperator = '!matches'
93
+	OP_CONTAINS_WORD: PatternComparisonOperator = 'containsword'
94
+	OP_NOT_CONTAINS_WORD: PatternComparisonOperator = '!containsword'
95
+
96
+	def __init__(self, field: PatternField, operator: PatternComparisonOperator, value: Any):
59
 		super().__init__()
97
 		super().__init__()
60
-		self.field: str = field
61
-		self.operator: str = operator
98
+		self.field: PatternField = field
99
+		self.operator: PatternComparisonOperator = operator
62
 		self.value: Any = value
100
 		self.value: Any = value
63
 
101
 
64
 	def __field_value(self, message: Message, other_fields: dict[str, Any]) -> Any:
102
 	def __field_value(self, message: Message, other_fields: dict[str, Any]) -> Any:
65
-		if self.field in ('content.markdown', 'content'):
103
+		cls = PatternSimpleExpression
104
+		if self.field in (cls.FIELD_CONTENT_MARKDOWN, cls.ALIAS_FIELD_CONTENT_MARKDOWN):
66
 			return message.content
105
 			return message.content
67
-		if self.field == 'content.plain':
106
+		if self.field == cls.FIELD_CONTENT_PLAIN:
68
 			return discordutils.remove_markdown(message.clean_content)
107
 			return discordutils.remove_markdown(message.clean_content)
69
-		if self.field == 'author':
108
+		if self.field in (cls.FIELD_AUTHOR_ID, cls.ALIAS_FIELD_AUTHOR_ID):
70
 			return str(message.author.id)
109
 			return str(message.author.id)
71
-		if self.field == 'author.id':
72
-			return str(message.author.id)
73
-		if self.field == 'author.joinage':
110
+		if self.field == cls.FIELD_AUTHOR_JOINAGE:
74
 			return message.created_at - message.author.joined_at
111
 			return message.created_at - message.author.joined_at
75
-		if self.field == 'author.name':
112
+		if self.field == cls.FIELD_AUTHOR_NAME:
76
 			return message.author.name
113
 			return message.author.name
77
-		if self.field == 'lastmatched':
114
+		if self.field == cls.FIELD_LAST_MATCHED:
78
 			long_ago = datetime(year=1900, month=1, day=1, hour=0, minute=0, second=0, tzinfo=timezone.utc)
115
 			long_ago = datetime(year=1900, month=1, day=1, hour=0, minute=0, second=0, tzinfo=timezone.utc)
79
 			last_matched = other_fields.get('last_matched') or long_ago
116
 			last_matched = other_fields.get('last_matched') or long_ago
80
 			return message.created_at - last_matched
117
 			return message.created_at - last_matched
81
-		raise ValueError(f'Bad field name {self.field}')
118
+		raise ValueError(f'Bad field name "{self.field}"')
82
 
119
 
83
 	def matches(self, message: Message, other_fields: dict[str, Any]) -> bool:
120
 	def matches(self, message: Message, other_fields: dict[str, Any]) -> bool:
121
+		cls = PatternSimpleExpression
84
 		field_value = self.__field_value(message, other_fields)
122
 		field_value = self.__field_value(message, other_fields)
85
-		if self.operator == '==':
123
+		if self.operator == cls.OP_EQUALS:
86
 			if isinstance(field_value, str) and isinstance(self.value, str):
124
 			if isinstance(field_value, str) and isinstance(self.value, str):
87
 				return field_value.lower() == self.value.lower()
125
 				return field_value.lower() == self.value.lower()
88
 			return field_value == self.value
126
 			return field_value == self.value
89
-		if self.operator == '!=':
127
+		if self.operator == cls.OP_NOT_EQUALS:
90
 			if isinstance(field_value, str) and isinstance(self.value, str):
128
 			if isinstance(field_value, str) and isinstance(self.value, str):
91
 				return field_value.lower() != self.value.lower()
129
 				return field_value.lower() != self.value.lower()
92
 			return field_value != self.value
130
 			return field_value != self.value
93
-		if self.operator == '<':
131
+		if self.operator == cls.OP_LESS_THAN:
94
 			return field_value < self.value
132
 			return field_value < self.value
95
-		if self.operator == '>':
133
+		if self.operator == cls.OP_GREATER_THAN:
96
 			return field_value > self.value
134
 			return field_value > self.value
97
-		if self.operator == '<=':
135
+		if self.operator == cls.OP_LESS_THAN_OR_EQUALS:
98
 			return field_value <= self.value
136
 			return field_value <= self.value
99
-		if self.operator == '>=':
137
+		if self.operator == cls.OP_GREATER_THAN_OR_EQUALS:
100
 			return field_value >= self.value
138
 			return field_value >= self.value
101
-		if self.operator == 'contains':
139
+		if self.operator == cls.OP_CONTAINS:
102
 			return self.value.lower() in field_value.lower()
140
 			return self.value.lower() in field_value.lower()
103
-		if self.operator == '!contains':
141
+		if self.operator == cls.OP_NOT_CONTAINS:
104
 			return self.value.lower() not in field_value.lower()
142
 			return self.value.lower() not in field_value.lower()
105
-		if self.operator in ('matches', 'containsword'):
143
+		if self.operator in (cls.OP_MATCHES, cls.OP_CONTAINS_WORD):
106
 			return self.value.search(field_value.lower()) is not None
144
 			return self.value.search(field_value.lower()) is not None
107
-		if self.operator in ('!matches', '!containsword'):
145
+		if self.operator in (cls.OP_NOT_MATCHES, cls.OP_NOT_CONTAINS_WORD):
108
 			return self.value.search(field_value.lower()) is None
146
 			return self.value.search(field_value.lower()) is None
109
 		raise ValueError(f'Bad operator {self.operator}')
147
 		raise ValueError(f'Bad operator {self.operator}')
110
 
148
 
116
 	Message matching expression that combines several child expressions with
154
 	Message matching expression that combines several child expressions with
117
 	a boolean operator.
155
 	a boolean operator.
118
 	"""
156
 	"""
119
-	def __init__(self, operator: str, operands: list[PatternExpression]):
157
+	OP_NOT = '!'
158
+	OP_AND = 'and'
159
+	OP_OR = 'or'
160
+
161
+	def __init__(self, operator: PatternBooleanOperator, operands: list[PatternExpression]):
120
 		super().__init__()
162
 		super().__init__()
121
-		self.operator = operator
163
+		self.operator: PatternBooleanOperator = operator
122
 		self.operands = list(operands)
164
 		self.operands = list(operands)
123
 
165
 
124
 	def matches(self, message: Message, other_fields: dict[str, Any]) -> bool:
166
 	def matches(self, message: Message, other_fields: dict[str, Any]) -> bool:
125
-		if self.operator == '!':
167
+		if self.operator == PatternCompoundExpression.OP_NOT:
126
 			return not self.operands[0].matches(message, other_fields)
168
 			return not self.operands[0].matches(message, other_fields)
127
-		if self.operator == 'and':
169
+		if self.operator == PatternCompoundExpression.OP_AND:
128
 			for op in self.operands:
170
 			for op in self.operands:
129
 				if not op.matches(message, other_fields):
171
 				if not op.matches(message, other_fields):
130
 					return False
172
 					return False
131
 			return True
173
 			return True
132
-		if self.operator == 'or':
174
+		if self.operator == PatternCompoundExpression.OP_OR:
133
 			for op in self.operands:
175
 			for op in self.operands:
134
 				if op.matches(message, other_fields):
176
 				if op.matches(message, other_fields):
135
 					return True
177
 					return True
137
 		raise ValueError(f'Bad operator "{self.operator}"')
179
 		raise ValueError(f'Bad operator "{self.operator}"')
138
 
180
 
139
 	def __str__(self) -> str:
181
 	def __str__(self) -> str:
140
-		if self.operator == '!':
182
+		if self.operator == PatternCompoundExpression.OP_NOT:
141
 			return f'(!( {self.operands[0]} ))'
183
 			return f'(!( {self.operands[0]} ))'
142
 		strs = map(str, self.operands)
184
 		strs = map(str, self.operands)
143
 		joined = f' {self.operator} '.join(strs)
185
 		joined = f' {self.operator} '.join(strs)
204
 	"""
246
 	"""
205
 	Parses a user-provided message filter statement into a PatternStatement.
247
 	Parses a user-provided message filter statement into a PatternStatement.
206
 	"""
248
 	"""
207
-	TYPE_FLOAT: str = 'float'
208
-	TYPE_ID: str = 'id'
209
-	TYPE_INT: str = 'int'
210
-	TYPE_MEMBER: str = 'Member'
211
-	TYPE_REGEX: str = 'regex'
212
-	TYPE_TEXT: str = 'text'
213
-	TYPE_TIMESPAN: str = 'timespan'
214
-
215
-	FIELD_TO_TYPE: dict[str, str] = {
216
-		'author': TYPE_MEMBER,
217
-		'author.id': TYPE_ID,
218
-		'author.joinage': TYPE_TIMESPAN,
219
-		'author.name': TYPE_TEXT,
220
-		'content': TYPE_TEXT, # deprecated, use content.markdown or content.plain
221
-		'content.markdown': TYPE_TEXT,
222
-		'content.plain': TYPE_TEXT,
223
-		'lastmatched': TYPE_TIMESPAN,
249
+	DATATYPE_FLOAT: str = 'float'
250
+	DATATYPE_ID: str = 'id'
251
+	DATATYPE_INT: str = 'int'
252
+	DATATYPE_MEMBER: str = 'Member'
253
+	DATATYPE_REGEX: str = 'regex'
254
+	DATATYPE_TEXT: str = 'text'
255
+	DATATYPE_TIMESPAN: str = 'timespan'
256
+
257
+	FIELD_TO_DATATYPE: dict[PatternField, str] = {
258
+		PatternSimpleExpression.ALIAS_FIELD_AUTHOR_ID: DATATYPE_MEMBER,
259
+		PatternSimpleExpression.FIELD_AUTHOR_ID: DATATYPE_ID,
260
+		PatternSimpleExpression.FIELD_AUTHOR_JOINAGE: DATATYPE_TIMESPAN,
261
+		PatternSimpleExpression.FIELD_AUTHOR_NAME: DATATYPE_TEXT,
262
+		PatternSimpleExpression.ALIAS_FIELD_CONTENT_MARKDOWN: DATATYPE_TEXT, # deprecated, use content.markdown or content.plain
263
+		PatternSimpleExpression.FIELD_CONTENT_MARKDOWN: DATATYPE_TEXT,
264
+		PatternSimpleExpression.FIELD_CONTENT_PLAIN: DATATYPE_TEXT,
265
+		PatternSimpleExpression.FIELD_LAST_MATCHED: DATATYPE_TIMESPAN,
224
 	}
266
 	}
225
-	DEPRECATED_FIELDS: set[str] = { 'content' }
226
-
227
-	ACTION_TO_ARGS: dict[str, list[str]] = {
228
-		'ban': [],
229
-		'delete': [],
230
-		'kick': [],
231
-		'modinfo': [],
232
-		'modwarn': [],
233
-		'reply': [ TYPE_TEXT ],
267
+	DEPRECATED_FIELDS: set[PatternField] = { 'content' }
268
+
269
+	ACTION_TO_ARGS: dict[PatternActionType, list[str]] = {
270
+		PatternAction.TYPE_BAN: [],
271
+		PatternAction.TYPE_DELETE: [],
272
+		PatternAction.TYPE_KICK: [],
273
+		PatternAction.TYPE_INFORM_MODS: [],
274
+		PatternAction.TYPE_WARN_MODS: [],
275
+		PatternAction.TYPE_REPLY: [ DATATYPE_TEXT ],
234
 	}
276
 	}
235
 
277
 
236
-	OPERATORS_IDENTITY: set[str] = { '==', '!=' }
237
-	OPERATORS_COMPARISON: set[str] = { '<', '>', '<=', '>=' }
238
-	OPERATORS_NUMERIC: set[str] = OPERATORS_IDENTITY | OPERATORS_COMPARISON
239
-	OPERATORS_TEXT: set[str] = OPERATORS_IDENTITY | {
240
-		'contains', '!contains',
241
-		'containsword', '!containsword',
242
-		'matches', '!matches',
278
+	OPERATORS_IDENTITY: set[PatternComparisonOperator] = {
279
+		PatternSimpleExpression.OP_EQUALS,
280
+		PatternSimpleExpression.OP_NOT_EQUALS,
281
+	}
282
+	OPERATORS_COMPARISON: set[PatternComparisonOperator] = {
283
+		PatternSimpleExpression.OP_LESS_THAN,
284
+		PatternSimpleExpression.OP_GREATER_THAN,
285
+		PatternSimpleExpression.OP_LESS_THAN_OR_EQUALS,
286
+		PatternSimpleExpression.OP_GREATER_THAN_OR_EQUALS,
287
+	}
288
+	OPERATORS_NUMERIC: set[PatternComparisonOperator] = OPERATORS_IDENTITY | OPERATORS_COMPARISON
289
+	OPERATORS_TEXT: set[PatternComparisonOperator] = OPERATORS_IDENTITY | {
290
+		PatternSimpleExpression.OP_CONTAINS,
291
+		PatternSimpleExpression.OP_NOT_CONTAINS,
292
+		PatternSimpleExpression.OP_CONTAINS_WORD,
293
+		PatternSimpleExpression.OP_NOT_CONTAINS_WORD,
294
+		PatternSimpleExpression.OP_MATCHES,
295
+		PatternSimpleExpression.OP_NOT_MATCHES,
243
 	}
296
 	}
244
 	OPERATORS_ALL: set[str] = OPERATORS_IDENTITY | OPERATORS_COMPARISON | OPERATORS_TEXT
297
 	OPERATORS_ALL: set[str] = OPERATORS_IDENTITY | OPERATORS_COMPARISON | OPERATORS_TEXT
245
 
298
 
246
-	TYPE_TO_OPERATORS: dict[str, set[str]] = {
247
-		TYPE_ID: OPERATORS_IDENTITY,
248
-		TYPE_MEMBER: OPERATORS_IDENTITY,
249
-		TYPE_TEXT: OPERATORS_TEXT,
250
-		TYPE_INT: OPERATORS_NUMERIC,
251
-		TYPE_FLOAT: OPERATORS_NUMERIC,
252
-		TYPE_TIMESPAN: OPERATORS_NUMERIC,
299
+	DATATYPE_TO_OPERATORS: dict[str, set[PatternComparisonOperator]] = {
300
+		DATATYPE_ID: OPERATORS_IDENTITY,
301
+		DATATYPE_MEMBER: OPERATORS_IDENTITY,
302
+		DATATYPE_TEXT: OPERATORS_TEXT,
303
+		DATATYPE_INT: OPERATORS_NUMERIC,
304
+		DATATYPE_FLOAT: OPERATORS_NUMERIC,
305
+		DATATYPE_TIMESPAN: OPERATORS_NUMERIC,
253
 	}
306
 	}
254
 
307
 
255
 	WHITESPACE_CHARS: str = ' \t\n\r'
308
 	WHITESPACE_CHARS: str = ' \t\n\r'
458
 					return subexpressions[0], token_index
511
 					return subexpressions[0], token_index
459
 				return (PatternCompoundExpression(last_compound_operator,
512
 				return (PatternCompoundExpression(last_compound_operator,
460
 					subexpressions), token_index)
513
 					subexpressions), token_index)
461
-			if tokens[token_index] in { "and", "or" }:
514
+			if tokens[token_index] in { PatternCompoundExpression.OP_AND, PatternCompoundExpression.OP_OR }:
462
 				compound_operator = tokens[token_index]
515
 				compound_operator = tokens[token_index]
463
 				if last_compound_operator and \
516
 				if last_compound_operator and \
464
 						compound_operator != last_compound_operator:
517
 						compound_operator != last_compound_operator:
468
 					]
521
 					]
469
 				last_compound_operator = compound_operator
522
 				last_compound_operator = compound_operator
470
 				token_index += 1
523
 				token_index += 1
471
-			if tokens[token_index] == '!':
524
+			if tokens[token_index] == PatternCompoundExpression.OP_NOT:
472
 				(exp, next_index) = cls.__read_expression(tokens,
525
 				(exp, next_index) = cls.__read_expression(tokens,
473
 						token_index + 1, depth + 1, one_subexpression=True)
526
 						token_index + 1, depth + 1, one_subexpression=True)
474
 				subexpressions.append(PatternCompoundExpression('!', [exp]))
527
 				subexpressions.append(PatternCompoundExpression('!', [exp]))
507
 			raise PatternError('Expression nests too deeply')
560
 			raise PatternError('Expression nests too deeply')
508
 		if token_index >= len(tokens):
561
 		if token_index >= len(tokens):
509
 			raise PatternError('Expected field name, found EOL')
562
 			raise PatternError('Expected field name, found EOL')
510
-		field = tokens[token_index]
563
+		field: PatternField = tokens[token_index]
511
 		token_index += 1
564
 		token_index += 1
512
 
565
 
513
-		datatype = cls.FIELD_TO_TYPE.get(field)
566
+		datatype = cls.FIELD_TO_DATATYPE.get(field, None)
514
 		if datatype is None:
567
 		if datatype is None:
515
 			raise PatternError(f'No such field "{field}"')
568
 			raise PatternError(f'No such field "{field}"')
516
 
569
 
519
 		op = tokens[token_index]
572
 		op = tokens[token_index]
520
 		token_index += 1
573
 		token_index += 1
521
 
574
 
522
-		if op == '!':
575
+		if op == PatternCompoundExpression.OP_NOT:
523
 			if token_index >= len(tokens):
576
 			if token_index >= len(tokens):
524
 				raise PatternError('Expected operator, found EOL')
577
 				raise PatternError('Expected operator, found EOL')
525
 			op = '!' + tokens[token_index]
578
 			op = '!' + tokens[token_index]
526
 			token_index += 1
579
 			token_index += 1
527
 
580
 
528
-		allowed_ops = cls.TYPE_TO_OPERATORS[datatype]
581
+		allowed_ops = cls.DATATYPE_TO_OPERATORS[datatype]
529
 		if op not in allowed_ops:
582
 		if op not in allowed_ops:
530
 			if op in cls.OPERATORS_ALL:
583
 			if op in cls.OPERATORS_ALL:
531
 				raise PatternError(f'Operator {op} cannot be used with ' + \
584
 				raise PatternError(f'Operator {op} cannot be used with ' + \
551
 		"""
604
 		"""
552
 		Converts a value token to its Python value. Raises ValueError on failure.
605
 		Converts a value token to its Python value. Raises ValueError on failure.
553
 		"""
606
 		"""
554
-		if datatype == cls.TYPE_ID:
607
+		if datatype == cls.DATATYPE_ID:
555
 			if not is_user_id(value):
608
 			if not is_user_id(value):
556
 				raise ValueError(f'Illegal user id value: {value}')
609
 				raise ValueError(f'Illegal user id value: {value}')
557
 			return value
610
 			return value
558
-		if datatype == cls.TYPE_MEMBER:
611
+		if datatype == cls.DATATYPE_MEMBER:
559
 			return user_id_from_mention(value)
612
 			return user_id_from_mention(value)
560
-		if datatype == cls.TYPE_TEXT:
613
+		if datatype == cls.DATATYPE_TEXT:
561
 			s = str_from_quoted_str(value)
614
 			s = str_from_quoted_str(value)
562
 			if op in ('matches', '!matches'):
615
 			if op in ('matches', '!matches'):
563
 				try:
616
 				try:
570
 				except re.error as e:
623
 				except re.error as e:
571
 					raise ValueError(f'Invalid regex: {e}') from e
624
 					raise ValueError(f'Invalid regex: {e}') from e
572
 			return s
625
 			return s
573
-		if datatype == cls.TYPE_INT:
626
+		if datatype == cls.DATATYPE_INT:
574
 			return int(value)
627
 			return int(value)
575
-		if datatype == cls.TYPE_FLOAT:
628
+		if datatype == cls.DATATYPE_FLOAT:
576
 			return float(value)
629
 			return float(value)
577
-		if datatype == cls.TYPE_TIMESPAN:
630
+		if datatype == cls.DATATYPE_TIMESPAN:
578
 			return timedelta_from_str(value)
631
 			return timedelta_from_str(value)
579
 		raise ValueError(f'Unhandled datatype {datatype}')
632
 		raise ValueError(f'Unhandled datatype {datatype}')

Loading…
İptal
Kaydet