소스 검색

Pattern cog converted to slash commands with some minimal autocomplete

pull/13/head
Rocketsoup 2 달 전
부모
커밋
eb786293e7
2개의 변경된 파일300개의 추가작업 그리고 154개의 파일을 삭제
  1. 166
    73
      rocketbot/cogs/patterncog.py
  2. 134
    81
      rocketbot/pattern.py

+ 166
- 73
rocketbot/cogs/patterncog.py 파일 보기

@@ -2,11 +2,12 @@
2 2
 Cog for matching messages against guild-configurable criteria and taking
3 3
 automated actions on them.
4 4
 """
5
+import re
5 6
 from datetime import datetime
6
-from typing import Optional, Literal
7
+from typing import Optional
7 8
 
8 9
 from discord import Guild, Member, Message, utils as discordutils, Permissions, Interaction
9
-from discord.app_commands import Group, rename
10
+from discord.app_commands import Choice, Group, autocomplete, rename
10 11
 from discord.ext import commands
11 12
 
12 13
 from config import CONFIG
@@ -16,6 +17,7 @@ from rocketbot.cogsetting import CogSetting
16 17
 from rocketbot.pattern import PatternCompiler, PatternDeprecationError, \
17 18
 	PatternError, PatternStatement
18 19
 from rocketbot.storage import Storage
20
+from rocketbot.utils import dump_stacktrace
19 21
 
20 22
 class PatternContext:
21 23
 	"""
@@ -29,6 +31,59 @@ class PatternContext:
29 31
 		self.is_kicked = False
30 32
 		self.is_banned = False
31 33
 
34
+async def pattern_name_autocomplete(interaction: Interaction, current: str) -> list[Choice[str]]:
35
+	choices: list[Choice[str]] = []
36
+	try:
37
+		if interaction.guild is None:
38
+			return []
39
+		patterns: dict[str, PatternStatement] = PatternCog.shared.get_patterns(interaction.guild)
40
+		current_normal = current.lower().strip()
41
+		for name in sorted(patterns.keys()):
42
+			if len(current_normal) == 0 or current_normal.startswith(name.lower()):
43
+				choices.append(Choice(name=name, value=name))
44
+	except BaseException as e:
45
+		dump_stacktrace(e)
46
+	return choices
47
+
48
+async def action_autocomplete(interaction: Interaction, current: str) -> list[Choice[str]]:
49
+	# FIXME: WORK IN PROGRESS
50
+	print(f'autocomplete action - current = "{current}"')
51
+	regex = re.compile('^(.*?)([a-zA-Z]+)$')
52
+	match: Optional[re.Match[str]] = regex.match(current)
53
+	initial: str = ''
54
+	stub: str = current
55
+	if match:
56
+		initial = match.group(1).strip()
57
+		stub = match.group(2)
58
+	if PatternCompiler.ACTION_TO_ARGS.get(stub, None) is not None:
59
+		# Matches perfectly. Suggest another instead of completing the current.
60
+		initial = current.strip() + ', '
61
+		stub = ''
62
+	print(f'initial = "{initial}", stub = "{stub}"')
63
+
64
+	options: list[Choice[str]] = []
65
+	for action in sorted(PatternCompiler.ACTION_TO_ARGS.keys()):
66
+		if len(stub) == 0 or action.startswith(stub.lower()):
67
+			arg_types = PatternCompiler.ACTION_TO_ARGS[action]
68
+			arg_type_strs = []
69
+			for arg_type in arg_types:
70
+				if arg_type == PatternCompiler.TYPE_TEXT:
71
+					arg_type_strs.append('"message"')
72
+				else:
73
+					raise ValueError(f'Argument type {arg_type} not yet supported')
74
+			suffix = '' if len(arg_type_strs) == 0 else ' ' + (' '.join(arg_type_strs))
75
+			options.append(Choice(name=action, value=f'{initial.strip()} {action}{suffix}'))
76
+	return options
77
+
78
+async def priority_autocomplete(interaction: Interaction, current: str) -> list[Choice[str]]:
79
+	return [
80
+		Choice(name='very low (50)', value=50),
81
+		Choice(name='low (75)', value=75),
82
+		Choice(name='normal (100)', value=100),
83
+		Choice(name='high (125)', value=125),
84
+		Choice(name='very high (150)', value=150),
85
+	]
86
+
32 87
 class PatternCog(BaseCog, name='Pattern Matching'):
33 88
 	"""
34 89
 	Highly flexible cog for performing various actions on messages that match
@@ -37,6 +92,8 @@ class PatternCog(BaseCog, name='Pattern Matching'):
37 92
 
38 93
 	SETTING_PATTERNS = CogSetting('patterns', None)
39 94
 
95
+	shared: Optional['PatternCog'] = None
96
+
40 97
 	def __init__(self, bot: Rocketbot):
41 98
 		super().__init__(
42 99
 			bot,
@@ -44,8 +101,9 @@ class PatternCog(BaseCog, name='Pattern Matching'):
44 101
 			name='patterns',
45 102
 			short_description='Manages message pattern matching.',
46 103
 		)
104
+		PatternCog.shared = self
47 105
 
48
-	def __get_patterns(self, guild: Guild) -> dict[str, PatternStatement]:
106
+	def get_patterns(self, guild: Guild) -> dict[str, PatternStatement]:
49 107
 		"""
50 108
 		Returns a name -> PatternStatement lookup for the guild.
51 109
 		"""
@@ -104,7 +162,7 @@ class PatternCog(BaseCog, name='Pattern Matching'):
104 162
 			# Ignore mods
105 163
 			return
106 164
 
107
-		patterns = self.__get_patterns(message.guild)
165
+		patterns = self.get_patterns(message.guild)
108 166
 		for statement in sorted(patterns.values(), key=lambda s : s.priority, reverse=True):
109 167
 			other_fields = {
110 168
 				'last_matched': self.__get_last_matched(message.guild, statement.name),
@@ -199,117 +257,152 @@ class PatternCog(BaseCog, name='Pattern Matching'):
199 257
 			did_kick=context.is_kicked,
200 258
 			did_ban=context.is_banned))
201 259
 
202
-	spattern = Group(
260
+	pattern = Group(
203 261
 		name='pattern',
204 262
 		description='Manages message pattern matching.',
205 263
 		guild_only=True,
206 264
 		default_permissions=Permissions(Permissions.manage_messages.flag),
207 265
 	)
208
-	@spattern.command()
266
+
267
+	@pattern.command()
209 268
 	@rename(expression='if')
210
-	async def test(
269
+	@autocomplete(
270
+		name=pattern_name_autocomplete,
271
+		# actions=action_autocomplete
272
+	)
273
+	async def add(
211 274
 			self,
212 275
 			interaction: Interaction,
276
+			name: str,
213 277
 			actions: str,
214 278
 			expression: str
215 279
 	) -> None:
216
-		vals = [ actions, expression ]
217
-		arg_list = ''
218
-		for arg in vals:
219
-			if arg is not None:
220
-				arg_list += f'- "`{arg}`"\n'
221
-		await interaction.response.send_message(
222
-			"Got /pattern test call with arguments\n" + arg_list,
223
-			ephemeral=True,
224
-		)
280
+		"""
281
+		Adds a custom pattern.
225 282
 
226
-	@commands.group(
227
-		brief='Manages message pattern matching',
228
-	)
229
-	@commands.has_permissions(ban_members=True)
230
-	@commands.guild_only()
231
-	async def pattern(self, context: commands.Context):
232
-		"""Message pattern matching command group"""
233
-		if context.invoked_subcommand is None:
234
-			await context.send_help()
283
+		Adds a custom pattern. Patterns use a simplified
284
+		expression language. Full documentation found here:
285
+		https://git.rixafrix.com/ialbert/python-app-rocketbot/src/branch/main/docs/patterns.md
235 286
 
236
-	@pattern.command(
237
-		brief='Adds a custom pattern',
238
-		description='Adds a custom pattern. Patterns use a simplified ' + \
239
-			'expression language. Full documentation found here: ' + \
240
-			'https://git.rixafrix.com/ialbert/python-app-rocketbot/src/' + \
241
-			'branch/main/docs/patterns.md',
242
-		usage='<pattern_name> <expression...>',
243
-		ignore_extra=True
244
-	)
245
-	async def add(self, context: commands.Context, name: str):
246
-		"""Command handler"""
247
-		pattern_str = PatternCompiler.expression_str_from_context(context, name)
287
+		Parameters
288
+		----------
289
+		interaction : Interaction
290
+		name : str
291
+			A name for the pattern.
292
+		actions : str
293
+			One or more actions to take when a message matches the expression.
294
+		expression : str
295
+			Criteria for matching chat messages.
296
+		"""
297
+		pattern_str = f'{actions} if {expression}'
298
+		guild = interaction.guild
248 299
 		try:
249 300
 			statement = PatternCompiler.parse_statement(name, pattern_str)
250 301
 			statement.check_deprecations()
251
-			patterns = self.__get_patterns(context.guild)
302
+			patterns = self.get_patterns(guild)
252 303
 			patterns[name] = statement
253
-			self.__save_patterns(context.guild, patterns)
254
-			await context.message.reply(
304
+			self.__save_patterns(guild, patterns)
305
+			await interaction.response.send_message(
255 306
 				f'{CONFIG["success_emoji"]} Pattern `{name}` added.',
256
-				mention_author=False)
307
+				ephemeral=True,
308
+			)
257 309
 		except PatternError as e:
258
-			await context.message.reply(
310
+			await interaction.response.send_message(
259 311
 				f'{CONFIG["failure_emoji"]} Error parsing statement. {e}',
260
-				mention_author=False)
312
+				ephemeral=True,
313
+			)
261 314
 
262 315
 	@pattern.command(
263
-		brief='Removes a custom pattern',
264
-		usage='<pattern_name>'
316
+		description='Removes a custom pattern',
317
+		extras={
318
+			'usage': '<pattern_name>',
319
+		},
265 320
 	)
266
-	async def remove(self, context: commands.Context, name: str):
267
-		"""Command handler"""
268
-		patterns = self.__get_patterns(context.guild)
321
+	@autocomplete(name=pattern_name_autocomplete)
322
+	async def remove(self, interaction: Interaction, name: str):
323
+		"""
324
+		Command handler
325
+
326
+		Parameters
327
+		----------
328
+		interaction: Interaction
329
+		name: str
330
+			Name of the pattern to remove.
331
+		"""
332
+		guild = interaction.guild
333
+		patterns = self.get_patterns(guild)
269 334
 		if patterns.get(name) is not None:
270 335
 			del patterns[name]
271
-			self.__save_patterns(context.guild, patterns)
272
-			await context.message.reply(
336
+			self.__save_patterns(guild, patterns)
337
+			await interaction.response.send_message(
273 338
 				f'{CONFIG["success_emoji"]} Pattern `{name}` deleted.',
274
-				mention_author=False)
339
+				ephemeral=True,
340
+			)
275 341
 		else:
276
-			await context.message.reply(
342
+			await interaction.response.send_message(
277 343
 				f'{CONFIG["failure_emoji"]} No pattern named `{name}`.',
278
-				mention_author=False)
344
+				ephemeral=True,
345
+			)
279 346
 
280 347
 	@pattern.command(
281
-		brief='Lists all patterns'
348
+		description='Lists all patterns',
282 349
 	)
283
-	async def list(self, context: commands.Context) -> None:
284
-		"""Command handler"""
285
-		patterns = self.__get_patterns(context.guild)
350
+	async def list(self, interaction: Interaction) -> None:
351
+		"""
352
+		Command handler
353
+
354
+		Parameters
355
+		----------
356
+		interaction: Interaction
357
+		"""
358
+		guild = interaction.guild
359
+		patterns = self.get_patterns(guild)
286 360
 		if len(patterns) == 0:
287
-			await context.message.reply('No patterns defined.', mention_author=False)
361
+			await interaction.response.send_message(
362
+				'No patterns defined.',
363
+				ephemeral=True,
364
+			)
288 365
 			return
289 366
 		msg = ''
290 367
 		for name, statement in sorted(patterns.items()):
291 368
 			msg += f'Pattern `{name}` (priority={statement.priority}):\n```\n{statement.original}\n```\n'
292
-		await context.message.reply(msg, mention_author=False)
369
+		await interaction.response.send_message(msg, ephemeral=True)
293 370
 
294 371
 	@pattern.command(
295
-		brief='Sets a pattern\'s priority level',
296
-		description='Sets the priority for a pattern. Messages are checked ' +
297
-			'against patterns with the highest priority first. Patterns with ' +
298
-			'the same priority may be checked in arbitrary order. Default ' +
299
-			'priority is 100.',
372
+		description='Sets a pattern\'s priority level',
373
+		extras={
374
+			'long_description': 'Sets the priority for a pattern. Messages are checked ' +
375
+				'against patterns with the highest priority first. Patterns with ' +
376
+				'the same priority may be checked in arbitrary order. Default ' +
377
+				'priority is 100.',
378
+		},
300 379
 	)
301
-	async def setpriority(self, context: commands.Context, name: str, priority: int) -> None:
302
-		"""Command handler"""
303
-		patterns = self.__get_patterns(context.guild)
380
+	@autocomplete(name=pattern_name_autocomplete, priority=priority_autocomplete)
381
+	async def setpriority(self, interaction: Interaction, name: str, priority: int) -> None:
382
+		"""
383
+		Command handler
384
+
385
+		Parameters
386
+		----------
387
+		interaction: Interaction
388
+		name: str
389
+			A name for the pattern
390
+		priority: int
391
+			Priority for evaluating the pattern. Default is 100. Higher values match first.
392
+		"""
393
+		guild = interaction.guild
394
+		patterns = self.get_patterns(guild)
304 395
 		statement = patterns.get(name)
305 396
 		if statement is None:
306
-			await context.message.reply(
397
+			await interaction.response.send_message(
307 398
 				f'{CONFIG["failure_emoji"]} No such pattern `{name}`',
308
-				mention_author=False)
399
+				ephemeral=True,
400
+			)
309 401
 			return
310 402
 		statement.priority = priority
311
-		self.__save_patterns(context.guild, patterns)
312
-		await context.message.reply(
403
+		self.__save_patterns(guild, patterns)
404
+		await interaction.response.send_message(
313 405
 			f'{CONFIG["success_emoji"]} Priority for pattern `{name}` ' + \
314
-			f'updated to `{priority}`.',
315
-			mention_author=False)
406
+				f'updated to `{priority}`.',
407
+			ephemeral=True,
408
+		)

+ 134
- 81
rocketbot/pattern.py 파일 보기

@@ -5,7 +5,7 @@ to take on them.
5 5
 import re
6 6
 from abc import ABCMeta, abstractmethod
7 7
 from datetime import datetime, timezone
8
-from typing import Any, Union
8
+from typing import Any, Union, Literal
9 9
 
10 10
 from discord import Message, utils as discordutils
11 11
 from discord.ext.commands import Context
@@ -13,6 +13,11 @@ from discord.ext.commands import Context
13 13
 from rocketbot.utils import is_user_id, str_from_quoted_str, timedelta_from_str, \
14 14
 	user_id_from_mention
15 15
 
16
+PatternField = Literal['content.markdown', 'content', 'content.plain', 'author', 'author.id', 'author.joinage', 'author.name', 'lastmatched']
17
+PatternComparisonOperator = Literal['==', '!=', '<', '>', '<=', '>=', 'contains', '!contains', 'matches', '!matches', 'containsword', '!containsword']
18
+PatternBooleanOperator = Literal['!', 'and', 'or']
19
+PatternActionType = Literal['ban', 'delete', 'kick', 'modinfo', 'modwarn', 'reply']
20
+
16 21
 class PatternError(RuntimeError):
17 22
 	"""
18 23
 	Error thrown when parsing a pattern statement.
@@ -27,6 +32,14 @@ class PatternAction:
27 32
 	"""
28 33
 	Describes one action to take on a matched message or its author.
29 34
 	"""
35
+
36
+	TYPE_BAN: PatternActionType = 'ban'
37
+	TYPE_DELETE: PatternActionType = 'delete'
38
+	TYPE_KICK: PatternActionType = 'kick'
39
+	TYPE_INFORM_MODS: PatternActionType = 'modinfo'
40
+	TYPE_WARN_MODS: PatternActionType = 'modwarn'
41
+	TYPE_REPLY: PatternActionType = 'reply'
42
+
30 43
 	def __init__(self, action: str, args: list[Any]):
31 44
 		self.action = action
32 45
 		self.arguments = list(args)
@@ -55,56 +68,81 @@ class PatternSimpleExpression(PatternExpression):
55 68
 	Message matching expression with a simple "<field> <operator> <value>"
56 69
 	structure.
57 70
 	"""
58
-	def __init__(self, field: str, operator: str, value: Any):
71
+
72
+	FIELD_CONTENT_MARKDOWN: PatternField = 'content.markdown'
73
+	FIELD_CONTENT_PLAIN: PatternField = 'content.plain'
74
+	FIELD_AUTHOR_ID: PatternField = 'author.id'
75
+	FIELD_AUTHOR_JOINAGE: PatternField = 'author.joinage'
76
+	FIELD_AUTHOR_NAME: PatternField = 'author.name'
77
+	FIELD_LAST_MATCHED: PatternField = 'lastmatched'
78
+
79
+	# Less preferred but recognized field aliases
80
+	ALIAS_FIELD_CONTENT_MARKDOWN: PatternField = 'content'
81
+	ALIAS_FIELD_AUTHOR_ID: PatternField = 'author'
82
+
83
+	OP_EQUALS: PatternComparisonOperator = '=='
84
+	OP_NOT_EQUALS: PatternComparisonOperator = '!='
85
+	OP_LESS_THAN: PatternComparisonOperator = '<'
86
+	OP_GREATER_THAN: PatternComparisonOperator = '>'
87
+	OP_LESS_THAN_OR_EQUALS: PatternComparisonOperator = '<='
88
+	OP_GREATER_THAN_OR_EQUALS: PatternComparisonOperator = '>='
89
+	OP_CONTAINS: PatternComparisonOperator = 'contains'
90
+	OP_NOT_CONTAINS: PatternComparisonOperator = '!contains'
91
+	OP_MATCHES: PatternComparisonOperator = 'matches'
92
+	OP_NOT_MATCHES: PatternComparisonOperator = '!matches'
93
+	OP_CONTAINS_WORD: PatternComparisonOperator = 'containsword'
94
+	OP_NOT_CONTAINS_WORD: PatternComparisonOperator = '!containsword'
95
+
96
+	def __init__(self, field: PatternField, operator: PatternComparisonOperator, value: Any):
59 97
 		super().__init__()
60
-		self.field: str = field
61
-		self.operator: str = operator
98
+		self.field: PatternField = field
99
+		self.operator: PatternComparisonOperator = operator
62 100
 		self.value: Any = value
63 101
 
64 102
 	def __field_value(self, message: Message, other_fields: dict[str, Any]) -> Any:
65
-		if self.field in ('content.markdown', 'content'):
103
+		cls = PatternSimpleExpression
104
+		if self.field in (cls.FIELD_CONTENT_MARKDOWN, cls.ALIAS_FIELD_CONTENT_MARKDOWN):
66 105
 			return message.content
67
-		if self.field == 'content.plain':
106
+		if self.field == cls.FIELD_CONTENT_PLAIN:
68 107
 			return discordutils.remove_markdown(message.clean_content)
69
-		if self.field == 'author':
108
+		if self.field in (cls.FIELD_AUTHOR_ID, cls.ALIAS_FIELD_AUTHOR_ID):
70 109
 			return str(message.author.id)
71
-		if self.field == 'author.id':
72
-			return str(message.author.id)
73
-		if self.field == 'author.joinage':
110
+		if self.field == cls.FIELD_AUTHOR_JOINAGE:
74 111
 			return message.created_at - message.author.joined_at
75
-		if self.field == 'author.name':
112
+		if self.field == cls.FIELD_AUTHOR_NAME:
76 113
 			return message.author.name
77
-		if self.field == 'lastmatched':
114
+		if self.field == cls.FIELD_LAST_MATCHED:
78 115
 			long_ago = datetime(year=1900, month=1, day=1, hour=0, minute=0, second=0, tzinfo=timezone.utc)
79 116
 			last_matched = other_fields.get('last_matched') or long_ago
80 117
 			return message.created_at - last_matched
81
-		raise ValueError(f'Bad field name {self.field}')
118
+		raise ValueError(f'Bad field name "{self.field}"')
82 119
 
83 120
 	def matches(self, message: Message, other_fields: dict[str, Any]) -> bool:
121
+		cls = PatternSimpleExpression
84 122
 		field_value = self.__field_value(message, other_fields)
85
-		if self.operator == '==':
123
+		if self.operator == cls.OP_EQUALS:
86 124
 			if isinstance(field_value, str) and isinstance(self.value, str):
87 125
 				return field_value.lower() == self.value.lower()
88 126
 			return field_value == self.value
89
-		if self.operator == '!=':
127
+		if self.operator == cls.OP_NOT_EQUALS:
90 128
 			if isinstance(field_value, str) and isinstance(self.value, str):
91 129
 				return field_value.lower() != self.value.lower()
92 130
 			return field_value != self.value
93
-		if self.operator == '<':
131
+		if self.operator == cls.OP_LESS_THAN:
94 132
 			return field_value < self.value
95
-		if self.operator == '>':
133
+		if self.operator == cls.OP_GREATER_THAN:
96 134
 			return field_value > self.value
97
-		if self.operator == '<=':
135
+		if self.operator == cls.OP_LESS_THAN_OR_EQUALS:
98 136
 			return field_value <= self.value
99
-		if self.operator == '>=':
137
+		if self.operator == cls.OP_GREATER_THAN_OR_EQUALS:
100 138
 			return field_value >= self.value
101
-		if self.operator == 'contains':
139
+		if self.operator == cls.OP_CONTAINS:
102 140
 			return self.value.lower() in field_value.lower()
103
-		if self.operator == '!contains':
141
+		if self.operator == cls.OP_NOT_CONTAINS:
104 142
 			return self.value.lower() not in field_value.lower()
105
-		if self.operator in ('matches', 'containsword'):
143
+		if self.operator in (cls.OP_MATCHES, cls.OP_CONTAINS_WORD):
106 144
 			return self.value.search(field_value.lower()) is not None
107
-		if self.operator in ('!matches', '!containsword'):
145
+		if self.operator in (cls.OP_NOT_MATCHES, cls.OP_NOT_CONTAINS_WORD):
108 146
 			return self.value.search(field_value.lower()) is None
109 147
 		raise ValueError(f'Bad operator {self.operator}')
110 148
 
@@ -116,20 +154,24 @@ class PatternCompoundExpression(PatternExpression):
116 154
 	Message matching expression that combines several child expressions with
117 155
 	a boolean operator.
118 156
 	"""
119
-	def __init__(self, operator: str, operands: list[PatternExpression]):
157
+	OP_NOT = '!'
158
+	OP_AND = 'and'
159
+	OP_OR = 'or'
160
+
161
+	def __init__(self, operator: PatternBooleanOperator, operands: list[PatternExpression]):
120 162
 		super().__init__()
121
-		self.operator = operator
163
+		self.operator: PatternBooleanOperator = operator
122 164
 		self.operands = list(operands)
123 165
 
124 166
 	def matches(self, message: Message, other_fields: dict[str, Any]) -> bool:
125
-		if self.operator == '!':
167
+		if self.operator == PatternCompoundExpression.OP_NOT:
126 168
 			return not self.operands[0].matches(message, other_fields)
127
-		if self.operator == 'and':
169
+		if self.operator == PatternCompoundExpression.OP_AND:
128 170
 			for op in self.operands:
129 171
 				if not op.matches(message, other_fields):
130 172
 					return False
131 173
 			return True
132
-		if self.operator == 'or':
174
+		if self.operator == PatternCompoundExpression.OP_OR:
133 175
 			for op in self.operands:
134 176
 				if op.matches(message, other_fields):
135 177
 					return True
@@ -137,7 +179,7 @@ class PatternCompoundExpression(PatternExpression):
137 179
 		raise ValueError(f'Bad operator "{self.operator}"')
138 180
 
139 181
 	def __str__(self) -> str:
140
-		if self.operator == '!':
182
+		if self.operator == PatternCompoundExpression.OP_NOT:
141 183
 			return f'(!( {self.operands[0]} ))'
142 184
 		strs = map(str, self.operands)
143 185
 		joined = f' {self.operator} '.join(strs)
@@ -204,52 +246,63 @@ class PatternCompiler:
204 246
 	"""
205 247
 	Parses a user-provided message filter statement into a PatternStatement.
206 248
 	"""
207
-	TYPE_FLOAT: str = 'float'
208
-	TYPE_ID: str = 'id'
209
-	TYPE_INT: str = 'int'
210
-	TYPE_MEMBER: str = 'Member'
211
-	TYPE_REGEX: str = 'regex'
212
-	TYPE_TEXT: str = 'text'
213
-	TYPE_TIMESPAN: str = 'timespan'
214
-
215
-	FIELD_TO_TYPE: dict[str, str] = {
216
-		'author': TYPE_MEMBER,
217
-		'author.id': TYPE_ID,
218
-		'author.joinage': TYPE_TIMESPAN,
219
-		'author.name': TYPE_TEXT,
220
-		'content': TYPE_TEXT, # deprecated, use content.markdown or content.plain
221
-		'content.markdown': TYPE_TEXT,
222
-		'content.plain': TYPE_TEXT,
223
-		'lastmatched': TYPE_TIMESPAN,
249
+	DATATYPE_FLOAT: str = 'float'
250
+	DATATYPE_ID: str = 'id'
251
+	DATATYPE_INT: str = 'int'
252
+	DATATYPE_MEMBER: str = 'Member'
253
+	DATATYPE_REGEX: str = 'regex'
254
+	DATATYPE_TEXT: str = 'text'
255
+	DATATYPE_TIMESPAN: str = 'timespan'
256
+
257
+	FIELD_TO_DATATYPE: dict[PatternField, str] = {
258
+		PatternSimpleExpression.ALIAS_FIELD_AUTHOR_ID: DATATYPE_MEMBER,
259
+		PatternSimpleExpression.FIELD_AUTHOR_ID: DATATYPE_ID,
260
+		PatternSimpleExpression.FIELD_AUTHOR_JOINAGE: DATATYPE_TIMESPAN,
261
+		PatternSimpleExpression.FIELD_AUTHOR_NAME: DATATYPE_TEXT,
262
+		PatternSimpleExpression.ALIAS_FIELD_CONTENT_MARKDOWN: DATATYPE_TEXT, # deprecated, use content.markdown or content.plain
263
+		PatternSimpleExpression.FIELD_CONTENT_MARKDOWN: DATATYPE_TEXT,
264
+		PatternSimpleExpression.FIELD_CONTENT_PLAIN: DATATYPE_TEXT,
265
+		PatternSimpleExpression.FIELD_LAST_MATCHED: DATATYPE_TIMESPAN,
224 266
 	}
225
-	DEPRECATED_FIELDS: set[str] = { 'content' }
226
-
227
-	ACTION_TO_ARGS: dict[str, list[str]] = {
228
-		'ban': [],
229
-		'delete': [],
230
-		'kick': [],
231
-		'modinfo': [],
232
-		'modwarn': [],
233
-		'reply': [ TYPE_TEXT ],
267
+	DEPRECATED_FIELDS: set[PatternField] = { 'content' }
268
+
269
+	ACTION_TO_ARGS: dict[PatternActionType, list[str]] = {
270
+		PatternAction.TYPE_BAN: [],
271
+		PatternAction.TYPE_DELETE: [],
272
+		PatternAction.TYPE_KICK: [],
273
+		PatternAction.TYPE_INFORM_MODS: [],
274
+		PatternAction.TYPE_WARN_MODS: [],
275
+		PatternAction.TYPE_REPLY: [ DATATYPE_TEXT ],
234 276
 	}
235 277
 
236
-	OPERATORS_IDENTITY: set[str] = { '==', '!=' }
237
-	OPERATORS_COMPARISON: set[str] = { '<', '>', '<=', '>=' }
238
-	OPERATORS_NUMERIC: set[str] = OPERATORS_IDENTITY | OPERATORS_COMPARISON
239
-	OPERATORS_TEXT: set[str] = OPERATORS_IDENTITY | {
240
-		'contains', '!contains',
241
-		'containsword', '!containsword',
242
-		'matches', '!matches',
278
+	OPERATORS_IDENTITY: set[PatternComparisonOperator] = {
279
+		PatternSimpleExpression.OP_EQUALS,
280
+		PatternSimpleExpression.OP_NOT_EQUALS,
281
+	}
282
+	OPERATORS_COMPARISON: set[PatternComparisonOperator] = {
283
+		PatternSimpleExpression.OP_LESS_THAN,
284
+		PatternSimpleExpression.OP_GREATER_THAN,
285
+		PatternSimpleExpression.OP_LESS_THAN_OR_EQUALS,
286
+		PatternSimpleExpression.OP_GREATER_THAN_OR_EQUALS,
287
+	}
288
+	OPERATORS_NUMERIC: set[PatternComparisonOperator] = OPERATORS_IDENTITY | OPERATORS_COMPARISON
289
+	OPERATORS_TEXT: set[PatternComparisonOperator] = OPERATORS_IDENTITY | {
290
+		PatternSimpleExpression.OP_CONTAINS,
291
+		PatternSimpleExpression.OP_NOT_CONTAINS,
292
+		PatternSimpleExpression.OP_CONTAINS_WORD,
293
+		PatternSimpleExpression.OP_NOT_CONTAINS_WORD,
294
+		PatternSimpleExpression.OP_MATCHES,
295
+		PatternSimpleExpression.OP_NOT_MATCHES,
243 296
 	}
244 297
 	OPERATORS_ALL: set[str] = OPERATORS_IDENTITY | OPERATORS_COMPARISON | OPERATORS_TEXT
245 298
 
246
-	TYPE_TO_OPERATORS: dict[str, set[str]] = {
247
-		TYPE_ID: OPERATORS_IDENTITY,
248
-		TYPE_MEMBER: OPERATORS_IDENTITY,
249
-		TYPE_TEXT: OPERATORS_TEXT,
250
-		TYPE_INT: OPERATORS_NUMERIC,
251
-		TYPE_FLOAT: OPERATORS_NUMERIC,
252
-		TYPE_TIMESPAN: OPERATORS_NUMERIC,
299
+	DATATYPE_TO_OPERATORS: dict[str, set[PatternComparisonOperator]] = {
300
+		DATATYPE_ID: OPERATORS_IDENTITY,
301
+		DATATYPE_MEMBER: OPERATORS_IDENTITY,
302
+		DATATYPE_TEXT: OPERATORS_TEXT,
303
+		DATATYPE_INT: OPERATORS_NUMERIC,
304
+		DATATYPE_FLOAT: OPERATORS_NUMERIC,
305
+		DATATYPE_TIMESPAN: OPERATORS_NUMERIC,
253 306
 	}
254 307
 
255 308
 	WHITESPACE_CHARS: str = ' \t\n\r'
@@ -458,7 +511,7 @@ class PatternCompiler:
458 511
 					return subexpressions[0], token_index
459 512
 				return (PatternCompoundExpression(last_compound_operator,
460 513
 					subexpressions), token_index)
461
-			if tokens[token_index] in { "and", "or" }:
514
+			if tokens[token_index] in { PatternCompoundExpression.OP_AND, PatternCompoundExpression.OP_OR }:
462 515
 				compound_operator = tokens[token_index]
463 516
 				if last_compound_operator and \
464 517
 						compound_operator != last_compound_operator:
@@ -468,7 +521,7 @@ class PatternCompiler:
468 521
 					]
469 522
 				last_compound_operator = compound_operator
470 523
 				token_index += 1
471
-			if tokens[token_index] == '!':
524
+			if tokens[token_index] == PatternCompoundExpression.OP_NOT:
472 525
 				(exp, next_index) = cls.__read_expression(tokens,
473 526
 						token_index + 1, depth + 1, one_subexpression=True)
474 527
 				subexpressions.append(PatternCompoundExpression('!', [exp]))
@@ -507,10 +560,10 @@ class PatternCompiler:
507 560
 			raise PatternError('Expression nests too deeply')
508 561
 		if token_index >= len(tokens):
509 562
 			raise PatternError('Expected field name, found EOL')
510
-		field = tokens[token_index]
563
+		field: PatternField = tokens[token_index]
511 564
 		token_index += 1
512 565
 
513
-		datatype = cls.FIELD_TO_TYPE.get(field)
566
+		datatype = cls.FIELD_TO_DATATYPE.get(field, None)
514 567
 		if datatype is None:
515 568
 			raise PatternError(f'No such field "{field}"')
516 569
 
@@ -519,13 +572,13 @@ class PatternCompiler:
519 572
 		op = tokens[token_index]
520 573
 		token_index += 1
521 574
 
522
-		if op == '!':
575
+		if op == PatternCompoundExpression.OP_NOT:
523 576
 			if token_index >= len(tokens):
524 577
 				raise PatternError('Expected operator, found EOL')
525 578
 			op = '!' + tokens[token_index]
526 579
 			token_index += 1
527 580
 
528
-		allowed_ops = cls.TYPE_TO_OPERATORS[datatype]
581
+		allowed_ops = cls.DATATYPE_TO_OPERATORS[datatype]
529 582
 		if op not in allowed_ops:
530 583
 			if op in cls.OPERATORS_ALL:
531 584
 				raise PatternError(f'Operator {op} cannot be used with ' + \
@@ -551,13 +604,13 @@ class PatternCompiler:
551 604
 		"""
552 605
 		Converts a value token to its Python value. Raises ValueError on failure.
553 606
 		"""
554
-		if datatype == cls.TYPE_ID:
607
+		if datatype == cls.DATATYPE_ID:
555 608
 			if not is_user_id(value):
556 609
 				raise ValueError(f'Illegal user id value: {value}')
557 610
 			return value
558
-		if datatype == cls.TYPE_MEMBER:
611
+		if datatype == cls.DATATYPE_MEMBER:
559 612
 			return user_id_from_mention(value)
560
-		if datatype == cls.TYPE_TEXT:
613
+		if datatype == cls.DATATYPE_TEXT:
561 614
 			s = str_from_quoted_str(value)
562 615
 			if op in ('matches', '!matches'):
563 616
 				try:
@@ -570,10 +623,10 @@ class PatternCompiler:
570 623
 				except re.error as e:
571 624
 					raise ValueError(f'Invalid regex: {e}') from e
572 625
 			return s
573
-		if datatype == cls.TYPE_INT:
626
+		if datatype == cls.DATATYPE_INT:
574 627
 			return int(value)
575
-		if datatype == cls.TYPE_FLOAT:
628
+		if datatype == cls.DATATYPE_FLOAT:
576 629
 			return float(value)
577
-		if datatype == cls.TYPE_TIMESPAN:
630
+		if datatype == cls.DATATYPE_TIMESPAN:
578 631
 			return timedelta_from_str(value)
579 632
 		raise ValueError(f'Unhandled datatype {datatype}')

Loading…
취소
저장