Moving pattern stuff to separate file. Small improvements.

master
Rocketsoup 4 years ago
parent
commit
c6750d9b75
3 changed files with 631 additions and 537 deletions
  1. patterns.md (+18, -13)
  2. rocketbot/cogs/patterncog.py (+45, -524)
  3. rocketbot/pattern.py (+568, -0)

+ 18
- 13
patterns.md

@@ -33,6 +33,7 @@ Available actions:
33 33
 * `ban` - Bans the user. The "reason" in the audit log will reference the pattern name.
34 34
 * `delete` - Deletes the message.
35 35
 * `kick` - Kicks the user. The "reason" in the audit log will reference the pattern name.
36
+* `modinfo` - Posts an informative message in the bot warning channel but does not tag the mods. Useful for logging a pattern that is mildly harmful but not worth getting immediate mod attention. Message will have reactions to delete, kick, or ban.
36 37
 * `modwarn` - Tags the mods in a warning message. The message will offer quick actions to manually delete the message, kick the user, and ban the user (assuming the other actions didn't already do one or more of these things)
37 38
 * `reply "message"` - Makes Rocketbot automatically reply to their message with the given text.
38 39
 
@@ -44,7 +45,7 @@ The simplest expression just consists of a message field, a comparison operator,
44 45
 content.plain contains "forbidden"
45 46
 ```
46 47
 
47
-The message will match if its `content.plain` `contains` the word `"forbidden"`.
48
+The message will match if its `content.plain` field `contains` the word `"forbidden"`.
48 49
 
49 50
 The available operators and type of value depend on the field being accessed.
50 51
 
@@ -54,25 +55,27 @@ The available operators and type of value depends on the field being accessed.
54 55
 * `content.markdown` - The raw markdown of the message. This contains all markdown characters, and mentions are of the `<@!0000000>` form. Available operators: `==`, `!=`, `contains`, `!contains`, `matches`, `!matches`. Comparison value must be a quoted string.
55 56
 * `author` - Who sent the message. Available operators: `==`, `!=`. Comparison value must be a user mention (an @ that Discord will tab-complete for you).
56 57
 * `author.id` - The numeric ID of the user who sent the message. Available operators: `==`, `!=`. Comparison value must be a numeric user ID.
57
-* `author.name` - The username of the author. Available operators: `==`, `!=`, `contains`, `!contains`, `matches`, `!matches`. Comparison value must be a quoted string.
58
+* `author.name` - The username of the author. Available operators: `==`, `!=`, `contains`, `!contains`, `containsword`, `!containsword`, `matches`, `!matches`. Comparison value must be a quoted string.
58 59
 * `author.joinage` - How much time elapsed between when the author joined and when the message was sent. If the user has joined and left multiple times, this is measured from the most recent join. Available operators: `==`, `!=`, `<`, `>`, `<=`, `>=`. Comparison value must be a timespan (see below).
59 60
 
60 61
 #### Operators
61 62
 
62
-* `==` - The values are equal
63
-* `!=` - The values are not equal
64
-* `<` - The field is less than the given value
65
-* `>` - The field is greater than the given value
66
-* `<=` - The field is less than or equal to the given value
67
-* `>=` - The field is greater than or equal to the given value
68
-* `contains` - The value is contained somewhere in the field value
69
-* `!contains` - The value is not contained anywhere in the field value
70
-* `matches` - The given regular expression matches the field value
71
-* `!matches` - The given regular expression does not match the field value
63
+* `==` - The values are equal.
64
+* `!=` - The values are not equal.
65
+* `<` - The field is less than the given value.
66
+* `>` - The field is greater than the given value.
67
+* `<=` - The field is less than or equal to the given value.
68
+* `>=` - The field is greater than or equal to the given value.
69
+* `contains` - The value is contained somewhere in the field value. Will match parts of words (e.g. "cat" will match "scatter").
70
+* `!contains` - The value is not contained anywhere in the field value.
71
+* `containsword` - The value is contained somewhere in the field value as a whole word (e.g. "cat" will not match "scatter").
72
+* `!containsword` - The value is not contained anywhere in the field value as a whole word.
73
+* `matches` - The given regular expression matches part of the field value.
74
+* `!matches` - The given regular expression does not match any part of the field value.
72 75
 
73 76
 #### Values
74 77
 
75
-Text values must be enclosed in double quote (`"`) characters.
78
+Text values must be enclosed in double quote (`"`) characters. To include a literal quote character, escape it with a backslash, e.g. `"string with \"quotes\" in it"`. To include a literal backslash, write two backslashes, e.g. `"string with \\ backslash"`.
76 79
 
77 80
 Timespans consist of one or more pairs of a number and a unit letter ("d" for days, "h" for hours, "m" for minutes, "s" for seconds). Examples:
78 81
 
@@ -80,6 +83,8 @@ Timespans consist of one or more pairs of a number and a unit letter ("d" for da
80 83
 * `1h30m` - 1 hour, 30 minutes
81 84
 * `99d9h9m9s` - 99 days, 9 hours, 9 minutes, 9 seconds
82 85
 
86
+Regular expressions are provided in double quotes like a regular string. Because a backslash is the escape character inside a quoted string, the backslash of a regex character class must itself be doubled, e.g. `"foo\\s+bar"` for "foo" and "bar" separated by whitespace.
87
+
83 88
 ### Compound Expressions
84 89
 
85 90
 Multiple expressions can be combined with "and" or "or". For example:
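
The difference between `contains`, `containsword`, and `matches` described in the operator list above can be sketched in Python as follows. This is only an illustration, not the bot's implementation: the function names are hypothetical, and it assumes that `containsword` behaves like a regex search anchored on word boundaries and that all three operators ignore case.

```python
import re

def contains(field_value: str, value: str) -> bool:
    # Plain substring test; matches parts of words.
    return value.lower() in field_value.lower()

def contains_word(field_value: str, value: str) -> bool:
    # Assumption: "whole word" means the value sits between regex word
    # boundaries. re.escape keeps the value from being read as a regex.
    pattern = r'\b' + re.escape(value) + r'\b'
    return re.search(pattern, field_value, re.IGNORECASE) is not None

def matches(field_value: str, regex: str) -> bool:
    # re.search (not re.match) so the regex may match any part of the field.
    return re.search(regex, field_value, re.IGNORECASE) is not None

print(contains('scatter', 'cat'))                    # substring hit
print(contains_word('scatter', 'cat'))               # no whole-word hit
print(contains_word('The CAT sat.', 'cat'))          # whole-word hit
print(matches('say foo   bar now', r'foo\s+bar'))    # partial regex hit
```

Note that word boundaries only behave intuitively when the value starts and ends with a letter, digit, or underscore; a value such as `c++` would need different handling.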

+ 45
- 524
rocketbot/cogs/patterncog.py

@@ -2,166 +2,15 @@
2 2
 Cog for matching messages against guild-configurable criteria and taking
3 3
 automated actions on them.
4 4
 """
5
-import re
6
-from abc import ABCMeta, abstractmethod
7 5
 from discord import Guild, Member, Message, utils as discordutils
8 6
 from discord.ext import commands
9 7
 
10 8
 from config import CONFIG
11 9
 from rocketbot.cogs.basecog import BaseCog, BotMessage, BotMessageReaction
12 10
 from rocketbot.cogsetting import CogSetting
11
+from rocketbot.pattern import PatternCompiler, PatternDeprecationError, \
12
+	PatternError, PatternStatement
13 13
 from rocketbot.storage import Storage
14
-from rocketbot.utils import is_user_id, str_from_quoted_str, timedelta_from_str, \
15
-	user_id_from_mention
16
-
17
-class PatternAction:
18
-	"""
19
-	Describes one action to take on a matched message or its author.
20
-	"""
21
-	def __init__(self, action: str, args: list):
22
-		self.action = action
23
-		self.arguments = list(args)
24
-
25
-	def __str__(self) -> str:
26
-		arg_str = ', '.join(self.arguments)
27
-		return f'{self.action}({arg_str})'
28
-
29
-class PatternExpression(metaclass=ABCMeta):
30
-	"""
31
-	Abstract message matching expression.
32
-	"""
33
-	def __init__(self):
34
-		pass
35
-
36
-	@abstractmethod
37
-	def matches(self, message: Message) -> bool:
38
-		"""
39
-		Whether a message matches this expression.
40
-		"""
41
-		return False
42
-
43
-class PatternSimpleExpression(PatternExpression):
44
-	"""
45
-	Message matching expression with a simple "<field> <operator> <value>"
46
-	structure.
47
-	"""
48
-	def __init__(self, field: str, operator: str, value):
49
-		super().__init__()
50
-		self.field = field
51
-		self.operator = operator
52
-		self.value = value
53
-
54
-	def __field_value(self, message: Message):
55
-		if self.field in ('content.markdown', 'content'):
56
-			return message.content
57
-		if self.field == 'content.plain':
58
-			return discordutils.remove_markdown(message.clean_content)
59
-		if self.field == 'author':
60
-			return str(message.author.id)
61
-		if self.field == 'author.id':
62
-			return str(message.author.id)
63
-		if self.field == 'author.joinage':
64
-			return message.created_at - message.author.joined_at
65
-		if self.field == 'author.name':
66
-			return message.author.name
67
-		else:
68
-			raise ValueError(f'Bad field name {self.field}')
69
-
70
-	def matches(self, message: Message) -> bool:
71
-		field_value = self.__field_value(message)
72
-		if self.operator == '==':
73
-			if isinstance(field_value, str) and isinstance(self.value, str):
74
-				return field_value.lower() == self.value.lower()
75
-			return field_value == self.value
76
-		if self.operator == '!=':
77
-			if isinstance(field_value, str) and isinstance(self.value, str):
78
-				return field_value.lower() != self.value.lower()
79
-			return field_value != self.value
80
-		if self.operator == '<':
81
-			return field_value < self.value
82
-		if self.operator == '>':
83
-			return field_value > self.value
84
-		if self.operator == '<=':
85
-			return field_value <= self.value
86
-		if self.operator == '>=':
87
-			return field_value >= self.value
88
-		if self.operator == 'contains':
89
-			return self.value.lower() in field_value.lower()
90
-		if self.operator == '!contains':
91
-			return self.value.lower() not in field_value.lower()
92
-		if self.operator == 'matches':
93
-			p = re.compile(self.value.lower())
94
-			return p.match(field_value.lower()) is not None
95
-		if self.operator == '!matches':
96
-			p = re.compile(self.value.lower())
97
-			return p.match(field_value.lower()) is None
98
-		raise ValueError(f'Bad operator {self.operator}')
99
-
100
-	def __str__(self) -> str:
101
-		return f'({self.field} {self.operator} {self.value})'
102
-
103
-class PatternCompoundExpression(PatternExpression):
104
-	"""
105
-	Message matching expression that combines several child expressions with
106
-	a boolean operator.
107
-	"""
108
-	def __init__(self, operator: str, operands: list):
109
-		super().__init__()
110
-		self.operator = operator
111
-		self.operands = list(operands)
112
-
113
-	def matches(self, message: Message) -> bool:
114
-		if self.operator == '!':
115
-			return not self.operands[0].matches(message)
116
-		if self.operator == 'and':
117
-			for op in self.operands:
118
-				if not op.matches(message):
119
-					return False
120
-			return True
121
-		if self.operator == 'or':
122
-			for op in self.operands:
123
-				if op.matches(message):
124
-					return True
125
-			return False
126
-		raise ValueError(f'Bad operator "{self.operator}"')
127
-
128
-	def __str__(self) -> str:
129
-		if self.operator == '!':
130
-			return f'(!( {self.operands[0]} ))'
131
-		strs = map(str, self.operands)
132
-		joined = f' {self.operator} '.join(strs)
133
-		return f'( {joined} )'
134
-
135
-class PatternStatement:
136
-	"""
137
-	A full message match statement. If a message matches the given expression,
138
-	the given actions should be performed.
139
-	"""
140
-	def __init__(self,
141
-			name: str,
142
-			actions: list,
143
-			expression: PatternExpression,
144
-			original: str):
145
-		self.name = name
146
-		self.actions = list(actions)  # PatternAction[]
147
-		self.expression = expression
148
-		self.original = original
149
-
150
-	def to_json(self) -> dict:
151
-		"""
152
-		Returns a JSON representation of this statement.
153
-		"""
154
-		return {
155
-			'name': self.name,
156
-			'statement': self.original,
157
-		}
158
-
159
-	@classmethod
160
-	def from_json(cls, json: dict):
161
-		"""
162
-		Gets a PatternStatement from its JSON representation.
163
-		"""
164
-		return PatternCompiler.parse_statement(json['name'], json['statement'])
165 14
 
166 15
 class PatternContext:
167 16
 	"""
@@ -194,7 +43,12 @@ class PatternCog(BaseCog, name='Pattern Matching'):
194 43
 			pattern_list: list[PatternStatement] = []
195 44
 			for json in jsons:
196 45
 				try:
197
-					pattern_list.append(PatternStatement.from_json(json))
46
+					ps = PatternStatement.from_json(json)
47
+					pattern_list.append(ps)
48
+					try:
49
+						ps.check_deprecations()
50
+					except PatternDeprecationError as e:
51
+						self.log(guild, f'Pattern {ps.name}: {e}')
198 52
 				except PatternError as e:
199 53
 					self.log(guild, f'Error decoding pattern "{json["name"]}": {e}')
200 54
 			patterns = { p.name:p for p in pattern_list}
@@ -202,7 +56,9 @@ class PatternCog(BaseCog, name='Pattern Matching'):
202 56
 		return patterns
203 57
 
204 58
 	@classmethod
205
-	def __save_patterns(cls, guild: Guild, patterns: dict[str, PatternStatement]) -> None:
59
+	def __save_patterns(cls,
60
+			guild: Guild,
61
+			patterns: dict[str, PatternStatement]) -> None:
206 62
 		to_save: list[dict] = list(map(PatternStatement.to_json, patterns.values()))
207 63
 		cls.set_guild_setting(guild, cls.SETTING_PATTERNS, to_save)
208 64
 
@@ -226,15 +82,20 @@ class PatternCog(BaseCog, name='Pattern Matching'):
226 82
 				await self.__trigger_actions(message, statement)
227 83
 				break
228 84
 
229
-	async def __trigger_actions(self, message: Message, statement: PatternStatement) -> None:
85
+	async def __trigger_actions(self,
86
+			message: Message,
87
+			statement: PatternStatement) -> None:
230 88
 		context = PatternContext(message, statement)
231
-		should_alert_mods = False
89
+		should_post_message = False
90
+		message_type: int = BotMessage.TYPE_DEFAULT
232 91
 		action_descriptions = []
233
-		self.log(message.guild, f'Message from {message.author.name} matched pattern "{statement.name}"')
92
+		self.log(message.guild, f'Message from {message.author.name} matched ' + \
93
+			f'pattern "{statement.name}"')
234 94
 		for action in statement.actions:
235 95
 			if action.action == 'ban':
236 96
 				await message.author.ban(
237
-					reason=f'Rocketbot: Message matched custom pattern named "{statement.name}"',
97
+					reason='Rocketbot: Message matched custom pattern named ' + \
98
+						f'"{statement.name}"',
238 99
 					delete_message_days=0)
239 100
 				context.is_banned = True
240 101
 				context.is_kicked = True
@@ -247,12 +108,18 @@ class PatternCog(BaseCog, name='Pattern Matching'):
247 108
 				self.log(message.guild, f'{message.author.name}\'s message deleted')
248 109
 			elif action.action == 'kick':
249 110
 				await message.author.kick(
250
-					reason=f'Rocketbot: Message matched custom pattern named "{statement.name}"')
111
+					reason='Rocketbot: Message matched custom pattern named ' + \
112
+						f'"{statement.name}"')
251 113
 				context.is_kicked = True
252 114
 				action_descriptions.append('Author kicked')
253 115
 				self.log(message.guild, f'{message.author.name} kicked')
116
+			elif action.action == 'modinfo':
117
+				should_post_message = True
118
+				message_type = BotMessage.TYPE_INFO
119
+				action_descriptions.append('Message logged')
254 120
 			elif action.action == 'modwarn':
255
-				should_alert_mods = True
121
+				should_post_message = True
122
+				message_type = BotMessage.TYPE_MOD_WARNING
256 123
 				action_descriptions.append('Mods alerted')
257 124
 			elif action.action == 'reply':
258 125
 				await message.reply(
@@ -260,19 +127,20 @@ class PatternCog(BaseCog, name='Pattern Matching'):
260 127
 					mention_author=False)
261 128
 				action_descriptions.append('Autoreplied')
262 129
 				self.log(message.guild, f'autoreplied to {message.author.name}')
263
-		bm = BotMessage(
264
-			message.guild,
265
-			f'User {message.author.name} tripped custom pattern ' + \
266
-				f'`{statement.name}`.\n\nAutomatic actions taken:\n• ' + \
267
-				('\n• '.join(action_descriptions)),
268
-			type=BotMessage.TYPE_MOD_WARNING if should_alert_mods else BotMessage.TYPE_INFO,
269
-			context=context)
270
-		bm.quote = discordutils.remove_markdown(message.clean_content)
271
-		await bm.set_reactions(BotMessageReaction.standard_set(
272
-			did_delete=context.is_deleted,
273
-			did_kick=context.is_kicked,
274
-			did_ban=context.is_banned))
275
-		await self.post_message(bm)
130
+		if should_post_message:
131
+			bm = BotMessage(
132
+				message.guild,
133
+				f'User {message.author.name} tripped custom pattern ' + \
134
+					f'`{statement.name}`.\n\nAutomatic actions taken:\n• ' + \
135
+					('\n• '.join(action_descriptions)),
136
+				type=message_type,
137
+				context=context)
138
+			bm.quote = discordutils.remove_markdown(message.clean_content)
139
+			await bm.set_reactions(BotMessageReaction.standard_set(
140
+				did_delete=context.is_deleted,
141
+				did_kick=context.is_kicked,
142
+				did_ban=context.is_banned))
143
+			await self.post_message(bm)
276 144
 
277 145
 	async def on_mod_react(self,
278 146
 			bot_message: BotMessage,
@@ -312,7 +180,8 @@ class PatternCog(BaseCog, name='Pattern Matching'):
312 180
 		brief='Adds a custom pattern',
313 181
 		description='Adds a custom pattern. Patterns use a simplified ' + \
314 182
 			'expression language. Full documentation found here: ' + \
315
-			'https://git.rixafrix.com/ialbert/python-app-rocketbot/src/branch/master/patterns.md',
183
+			'https://git.rixafrix.com/ialbert/python-app-rocketbot/src/' + \
184
+			'branch/master/patterns.md',
316 185
 		usage='<pattern_name> <expression...>',
317 186
 		ignore_extra=True
318 187
 	)
@@ -321,6 +190,7 @@ class PatternCog(BaseCog, name='Pattern Matching'):
321 190
 		pattern_str = PatternCompiler.expression_str_from_context(context, name)
322 191
 		try:
323 192
 			statement = PatternCompiler.parse_statement(name, pattern_str)
193
+			statement.check_deprecations()
324 194
 			patterns = self.__get_patterns(context.guild)
325 195
 			patterns[name] = statement
326 196
 			self.__save_patterns(context.guild, patterns)
@@ -363,352 +233,3 @@ class PatternCog(BaseCog, name='Pattern Matching'):
363 233
 		for name, statement in sorted(patterns.items()):
364 234
 			msg += f'Pattern `{name}`:\n```\n{statement.original}\n```\n'
365 235
 		await context.message.reply(msg, mention_author=False)
366
-
367
-class PatternError(RuntimeError):
368
-	"""
369
-	Error thrown when parsing a pattern statement.
370
-	"""
371
-
372
-class PatternCompiler:
373
-	"""
374
-	Parses a user-provided message filter statement into a PatternStatement.
375
-	"""
376
-	TYPE_ID = 'id'
377
-	TYPE_MEMBER = 'Member'
378
-	TYPE_TEXT = 'text'
379
-	TYPE_INT = 'int'
380
-	TYPE_FLOAT = 'float'
381
-	TYPE_TIMESPAN = 'timespan'
382
-
383
-	FIELD_TO_TYPE = {
384
-		'content.plain': TYPE_TEXT,
385
-		'content.markdown': TYPE_TEXT,
386
-		'author': TYPE_MEMBER,
387
-		'author.id': TYPE_ID,
388
-		'author.name': TYPE_TEXT,
389
-		'author.joinage': TYPE_TIMESPAN,
390
-
391
-		'content': TYPE_TEXT, # deprecated, use content.markdown or content.plain
392
-	}
393
-
394
-	ACTION_TO_ARGS = {
395
-		'ban': [],
396
-		'delete': [],
397
-		'kick': [],
398
-		'modwarn': [],
399
-		'reply': [ TYPE_TEXT ],
400
-	}
401
-
402
-	OPERATORS_IDENTITY = set([ '==', '!=' ])
403
-	OPERATORS_COMPARISON = set([ '<', '>', '<=', '>=' ])
404
-	OPERATORS_NUMERIC = OPERATORS_IDENTITY | OPERATORS_COMPARISON
405
-	OPERATORS_TEXT = OPERATORS_IDENTITY | set([ 'contains', '!contains', 'matches', '!matches' ])
406
-	OPERATORS_ALL = OPERATORS_IDENTITY | OPERATORS_COMPARISON | OPERATORS_TEXT
407
-
408
-	TYPE_TO_OPERATORS = {
409
-		TYPE_ID: OPERATORS_IDENTITY,
410
-		TYPE_MEMBER: OPERATORS_IDENTITY,
411
-		TYPE_TEXT: OPERATORS_TEXT,
412
-		TYPE_INT: OPERATORS_NUMERIC,
413
-		TYPE_FLOAT: OPERATORS_NUMERIC,
414
-		TYPE_TIMESPAN: OPERATORS_NUMERIC,
415
-	}
416
-
417
-	WHITESPACE_CHARS = ' \t\n\r'
418
-	STRING_QUOTE_CHARS = '\'"'
419
-	SYMBOL_CHARS = 'abcdefghijklmnopqrstuvwxyz.'
420
-	VALUE_CHARS = '0123456789dhms<@!>'
421
-	OP_CHARS = '<=>!(),'
422
-
423
-	@classmethod
424
-	def expression_str_from_context(cls, context: commands.Context, name: str) -> str:
425
-		"""
426
-		Extracts the statement string from an "add" command context.
427
-		"""
428
-		pattern_str = context.message.content
429
-		command_chain = [ name ]
430
-		cmd = context.command
431
-		while cmd:
432
-			command_chain.insert(0, cmd.name)
433
-			cmd = cmd.parent
434
-		command_chain[0] = f'{context.prefix}{command_chain[0]}'
435
-		for cmd in command_chain:
436
-			if pattern_str.startswith(cmd):
437
-				pattern_str = pattern_str[len(cmd):].lstrip()
438
-			elif pattern_str.startswith(f'"{cmd}"'):
439
-				pattern_str = pattern_str[len(cmd) + 2:].lstrip()
440
-		return pattern_str
441
-
442
-	@classmethod
443
-	def parse_statement(cls, name: str, statement: str) -> PatternStatement:
444
-		"""
445
-		Parses a user-provided message filter statement into a PatternStatement.
446
-		"""
447
-		tokens = cls.tokenize(statement)
448
-		token_index = 0
449
-		actions, token_index = cls.read_actions(tokens, token_index)
450
-		expression, token_index = cls.read_expression(tokens, token_index)
451
-		return PatternStatement(name, actions, expression, statement)
452
-
453
-	@classmethod
454
-	def tokenize(cls, statement: str) -> list:
455
-		"""
456
-		Converts a message filter statement into a list of tokens.
457
-		"""
458
-		tokens = []
459
-		in_quote = False
460
-		in_escape = False
461
-		all_token_types = set([ 'sym', 'op', 'val' ])
462
-		possible_token_types = set(all_token_types)
463
-		current_token = ''
464
-		for ch in statement:
465
-			if in_quote:
466
-				if in_escape:
467
-					if ch == 'n':
468
-						current_token += '\n'
469
-					elif ch == 't':
470
-						current_token += '\t'
471
-					else:
472
-						current_token += ch
473
-					in_escape = False
474
-				elif ch == '\\':
475
-					in_escape = True
476
-				elif ch == in_quote:
477
-					current_token += ch
478
-					tokens.append(current_token)
479
-					current_token = ''
480
-					possible_token_types |= all_token_types
481
-					in_quote = False
482
-				else:
483
-					current_token += ch
484
-			else:
485
-				if ch in cls.STRING_QUOTE_CHARS:
486
-					if len(current_token) > 0:
487
-						tokens.append(current_token)
488
-						current_token = ''
489
-						possible_token_types |= all_token_types
490
-					in_quote = ch
491
-					current_token = ch
492
-				elif ch == '\\':
493
-					raise PatternError("Unexpected \\ outside quoted string")
494
-				elif ch in cls.WHITESPACE_CHARS:
495
-					if len(current_token) > 0:
496
-						tokens.append(current_token)
497
-					current_token = ''
498
-					possible_token_types |= all_token_types
499
-				else:
500
-					possible_ch_types = set()
501
-					if ch in cls.SYMBOL_CHARS:
502
-						possible_ch_types.add('sym')
503
-					if ch in cls.VALUE_CHARS:
504
-						possible_ch_types.add('val')
505
-					if ch in cls.OP_CHARS:
506
-						possible_ch_types.add('op')
507
-					if len(current_token) > 0 and possible_ch_types.isdisjoint(possible_token_types):
508
-						if len(current_token) > 0:
509
-							tokens.append(current_token)
510
-							current_token = ''
511
-							possible_token_types |= all_token_types
512
-					possible_token_types &= possible_ch_types
513
-					current_token += ch
514
-		if len(current_token) > 0:
515
-			tokens.append(current_token)
516
-
517
-		# Some symbols might be glommed onto other tokens. Split 'em up.
518
-		prefixes_to_split = [ '!', '(', ',' ]
519
-		suffixes_to_split = [ ')', ',' ]
520
-		i = 0
521
-		while i < len(tokens):
522
-			token = tokens[i]
523
-			mutated = False
524
-			for prefix in prefixes_to_split:
525
-				if token.startswith(prefix) and len(token) > len(prefix):
526
-					tokens.insert(i, prefix)
527
-					tokens[i + 1] = token[len(prefix):]
528
-					i += 1
529
-					mutated = True
530
-					break
531
-			if mutated:
532
-				continue
533
-			for suffix in suffixes_to_split:
534
-				if token.endswith(suffix) and len(token) > len(suffix):
535
-					tokens[i] = token[0:-len(suffix)]
536
-					tokens.insert(i + 1, suffix)
537
-					mutated = True
538
-					break
539
-			if mutated:
540
-				continue
541
-			i += 1
542
-		return tokens
543
-
544
-	@classmethod
545
-	def read_actions(cls, tokens: list, token_index: int) -> tuple:
546
-		"""
547
-		Reads the actions from a list of statement tokens. Returns a tuple
548
-		containing a list of PatternActions and the token index this method
549
-		left off at (the token after the "if").
550
-		"""
551
-		actions = []
552
-		current_action_tokens = []
553
-		while token_index < len(tokens):
554
-			token = tokens[token_index]
555
-			if token == 'if':
556
-				if len(current_action_tokens) > 0:
557
-					a = PatternAction(current_action_tokens[0], current_action_tokens[1:])
558
-					cls.__validate_action(a)
559
-					actions.append(a)
560
-				token_index += 1
561
-				return (actions, token_index)
562
-			elif token == ',':
563
-				if len(current_action_tokens) < 1:
564
-					raise PatternError('Unexpected ,')
565
-				a = PatternAction(current_action_tokens[0], current_action_tokens[1:])
566
-				cls.__validate_action(a)
567
-				actions.append(a)
568
-				current_action_tokens = []
569
-			else:
570
-				current_action_tokens.append(token)
571
-			token_index += 1
572
-		raise PatternError('Unexpected end of line in action list')
573
-
574
-	@classmethod
575
-	def __validate_action(cls, action: PatternAction) -> None:
576
-		args = cls.ACTION_TO_ARGS.get(action.action)
577
-		if args is None:
578
-			raise PatternError(f'Unknown action "{action.action}"')
579
-		if len(action.arguments) != len(args):
580
-			if len(args) == 0:
581
-				raise PatternError(f'Action "{action.action}" expects no arguments, ' + \
582
-					f'got {len(action.arguments)}.')
583
-			else:
584
-				raise PatternError(f'Action "{action.action}" expects {len(args)} ' + \
585
-					f'arguments, got {len(action.arguments)}.')
586
-		for i, datatype in enumerate(args):
587
-			action.arguments[i] = cls.parse_value(action.arguments[i], datatype)
588
-
589
-	@classmethod
590
-	def read_expression(cls,
591
-			tokens: list,
592
-			token_index: int,
593
-			depth: int = 0,
594
-			one_subexpression: bool = False) -> tuple:
595
-		"""
596
-		Reads an expression from a list of statement tokens. Returns a tuple
597
-		containing the PatternExpression and the token index it left off at.
598
-		If one_subexpression is True then it will return after reading a
599
-		single expression instead of joining multiples (for reading the
600
-		subject of a NOT expression).
601
-		"""
602
-		subexpressions = []
603
-		last_compound_operator = None
604
-		while token_index < len(tokens):
605
-			if one_subexpression:
606
-				if len(subexpressions) == 1:
607
-					return (subexpressions[0], token_index)
608
-				if len(subexpressions) > 1:
609
-					raise PatternError('Too many subexpressions')
610
-			compound_operator = None
611
-			if tokens[token_index] == ')':
612
-				if len(subexpressions) == 0:
613
-					raise PatternError('No subexpressions')
614
-				if len(subexpressions) == 1:
615
-					return (subexpressions[0], token_index)
616
-				return (PatternCompoundExpression(last_compound_operator, subexpressions), token_index)
617
-			if tokens[token_index] in set(["and", "or"]):
618
-				compound_operator = tokens[token_index]
619
-				if last_compound_operator and compound_operator != last_compound_operator:
620
-					subexpressions = [ PatternCompoundExpression(last_compound_operator, subexpressions) ]
621
-					last_compound_operator = compound_operator
622
-				else:
623
-					last_compound_operator = compound_operator
624
-				token_index += 1
625
-			if tokens[token_index] == '!':
626
-				(exp, next_index) = cls.read_expression(tokens, token_index + 1, \
627
-						depth + 1, one_subexpression=True)
628
-				subexpressions.append(PatternCompoundExpression('!', [exp]))
629
-				token_index = next_index
630
-			elif tokens[token_index] == '(':
631
-				(exp, next_index) = cls.read_expression(tokens, token_index + 1, depth + 1)
632
-				if tokens[next_index] != ')':
633
-					raise PatternError('Expected )')
634
-				subexpressions.append(exp)
635
-				token_index = next_index + 1
636
-			else:
637
-				(simple, next_index) = cls.read_simple_expression(tokens, token_index, depth)
638
-				subexpressions.append(simple)
639
-				token_index = next_index
640
-		if len(subexpressions) == 0:
641
-			raise PatternError('No subexpressions')
642
-		elif len(subexpressions) == 1:
643
-			return (subexpressions[0], token_index)
644
-		else:
645
-			return (PatternCompoundExpression(last_compound_operator, subexpressions), token_index)
646
-
647
-	@classmethod
648
-	def read_simple_expression(cls, tokens: list, token_index: int, depth: int = 0) -> tuple:
649
-		"""
650
-		Reads a simple expression consisting of a field name, operator, and
651
-		comparison value. Returns a tuple of the PatternSimpleExpression and
652
-		the token index it left off at.
653
-		"""
654
-		if depth > 8:
655
-			raise PatternError('Expression nests too deeply')
656
-		if token_index >= len(tokens):
657
-			raise PatternError('Expected field name, found EOL')
658
-		field = tokens[token_index]
659
-		token_index += 1
660
-
661
-		datatype = cls.FIELD_TO_TYPE.get(field)
662
-		if datatype is None:
663
-			raise PatternError(f'No such field "{field}"')
664
-
665
-		if token_index >= len(tokens):
666
-			raise PatternError('Expected operator, found EOL')
667
-		op = tokens[token_index]
668
-		token_index += 1
669
-
670
-		if op == '!':
671
-			if token_index >= len(tokens):
672
-				raise PatternError('Expected operator, found EOL')
673
-			op = '!' + tokens[token_index]
674
-			token_index += 1
675
-
676
-		allowed_ops = cls.TYPE_TO_OPERATORS[datatype]
677
-		if op not in allowed_ops:
678
-			if op in cls.OPERATORS_ALL:
679
-				raise PatternError(f'Operator {op} cannot be used with field "{field}"')
680
-			raise PatternError(f'Unrecognized operator "{op}" - allowed: {list(allowed_ops)}')
681
-
682
-		if token_index >= len(tokens):
683
-			raise PatternError('Expected value, found EOL')
684
-		value = tokens[token_index]
685
-
686
-		try:
687
-			value = cls.parse_value(value, datatype)
688
-		except ValueError as cause:
689
-			raise PatternError(f'Bad value {value}') from cause
690
-
691
-		token_index += 1
692
-		exp = PatternSimpleExpression(field, op, value)
693
-		return (exp, token_index)
694
-
695
-	@classmethod
696
-	def parse_value(cls, value: str, datatype: str):
697
-		"""
698
-		Converts a value token to its Python value. Raises ValueError on failure.
699
-		"""
700
-		if datatype == cls.TYPE_ID:
701
-			if not is_user_id(value):
702
-				raise ValueError(f'Illegal user id value: {value}')
703
-			return value
704
-		if datatype == cls.TYPE_MEMBER:
705
-			return user_id_from_mention(value)
706
-		if datatype == cls.TYPE_TEXT:
707
-			return str_from_quoted_str(value)
708
-		if datatype == cls.TYPE_INT:
709
-			return int(value)
710
-		if datatype == cls.TYPE_FLOAT:
711
-			return float(value)
712
-		if datatype == cls.TYPE_TIMESPAN:
713
-			return timedelta_from_str(value)
714
-		raise ValueError(f'Unhandled datatype {datatype}')
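
The quoted-string handling in the tokenizer above explains why regular expressions in patterns need doubled backslashes. The sketch below isolates that escape logic in a hypothetical helper, `unescape_quoted`, which is not part of the codebase. It assumes double-quoted input only and folds quote stripping and unescaping into one step, whereas the real code keeps the quotes on the token and converts it later.

```python
import re

def unescape_quoted(token: str) -> str:
    """Hypothetical helper: decode a double-quoted pattern string."""
    if len(token) < 2 or token[0] != '"' or token[-1] != '"':
        raise ValueError('expected a double-quoted string')
    out = []
    in_escape = False
    for ch in token[1:-1]:
        if in_escape:
            # \n and \t become control characters; any other escaped
            # character (including \" and \\) is taken literally.
            out.append({'n': '\n', 't': '\t'}.get(ch, ch))
            in_escape = False
        elif ch == '\\':
            in_escape = True
        else:
            out.append(ch)
    if in_escape:
        raise ValueError('dangling backslash at end of string')
    return ''.join(out)

# As typed by a user in a pattern: "foo\\s+bar"
regex = unescape_quoted(r'"foo\\s+bar"')
print(regex)                                      # the regex the engine sees
print(re.search(regex, 'foo   bar') is not None)  # whitespace class works
```

With a single backslash (`"foo\s+bar"`), the same logic would drop the backslash and leave `foos+bar`, which is why the documentation asks for `\\s`.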

+ 568
- 0
rocketbot/pattern.py

@@ -0,0 +1,568 @@
1
+"""
2
+Statements that match messages based on an expression and have a list of actions
3
+to take on them.
4
+"""
5
+import re
6
+from abc import ABCMeta, abstractmethod
7
+from typing import Any
8
+
9
+from discord import Message, utils as discordutils
10
+from discord.ext.commands import Context
11
+
12
+from rocketbot.utils import is_user_id, str_from_quoted_str, timedelta_from_str, \
13
+	user_id_from_mention
14
+
15
+class PatternError(RuntimeError):
16
+	"""
17
+	Error thrown when parsing a pattern statement.
18
+	"""
19
+
20
+class PatternDeprecationError(PatternError):
21
+	"""
22
+	Error raised by PatternStatement.check_deprecations.
23
+	"""
24
+
25
+class PatternAction:
26
+	"""
27
+	Describes one action to take on a matched message or its author.
28
+	"""
29
+	def __init__(self, action: str, args: list[Any]):
30
+		self.action = action
31
+		self.arguments = list(args)
32
+
33
+	def __str__(self) -> str:
34
+		arg_str = ', '.join(self.arguments)
35
+		return f'{self.action}({arg_str})'
36
+
37
+class PatternExpression(metaclass=ABCMeta):
38
+	"""
39
+	Abstract message matching expression.
40
+	"""
41
+	def __init__(self):
42
+		pass
43
+
44
+	@abstractmethod
45
+	def matches(self, message: Message) -> bool:
46
+		"""
47
+		Whether a message matches this expression.
48
+		"""
49
+		return False
50
+
51
+class PatternSimpleExpression(PatternExpression):
52
+	"""
53
+	Message matching expression with a simple "<field> <operator> <value>"
54
+	structure.
55
+	"""
56
+	def __init__(self, field: str, operator: str, value: Any):
57
+		super().__init__()
58
+		self.field = field
59
+		self.operator = operator
60
+		self.value = value
61
+
62
+	def __field_value(self, message: Message) -> Any:
63
+		if self.field in ('content.markdown', 'content'):
64
+			return message.content
65
+		if self.field == 'content.plain':
66
+			return discordutils.remove_markdown(message.clean_content)
67
+		if self.field == 'author':
68
+			return str(message.author.id)
69
+		if self.field == 'author.id':
70
+			return str(message.author.id)
71
+		if self.field == 'author.joinage':
72
+			return message.created_at - message.author.joined_at
73
+		if self.field == 'author.name':
74
+			return message.author.name
75
+		else:
76
+			raise ValueError(f'Bad field name {self.field}')
77
+
78
+	def matches(self, message: Message) -> bool:
79
+		field_value = self.__field_value(message)
80
+		if self.operator == '==':
81
+			if isinstance(field_value, str) and isinstance(self.value, str):
82
+				return field_value.lower() == self.value.lower()
83
+			return field_value == self.value
84
+		if self.operator == '!=':
85
+			if isinstance(field_value, str) and isinstance(self.value, str):
86
+				return field_value.lower() != self.value.lower()
87
+			return field_value != self.value
88
+		if self.operator == '<':
89
+			return field_value < self.value
90
+		if self.operator == '>':
91
+			return field_value > self.value
92
+		if self.operator == '<=':
93
+			return field_value <= self.value
94
+		if self.operator == '>=':
95
+			return field_value >= self.value
96
+		if self.operator == 'contains':
97
+			return self.value.lower() in field_value.lower()
98
+		if self.operator == '!contains':
99
+			return self.value.lower() not in field_value.lower()
100
+		if self.operator in ('matches', 'containsword'):
101
+			return self.value.search(field_value.lower()) is not None
102
+		if self.operator in ('!matches', '!containsword'):
103
+			return self.value.search(field_value.lower()) is None
104
+		raise ValueError(f'Bad operator {self.operator}')
105
+
106
+	def __str__(self) -> str:
107
+		return f'({self.field} {self.operator} {self.value})'
108
+
109
+class PatternCompoundExpression(PatternExpression):
110
+	"""
111
+	Message matching expression that combines several child expressions with
112
+	a boolean operator.
113
+	"""
114
+	def __init__(self, operator: str, operands: list[PatternExpression]):
115
+		super().__init__()
116
+		self.operator = operator
117
+		self.operands = list(operands)
118
+
119
+	def matches(self, message: Message) -> bool:
120
+		if self.operator == '!':
121
+			return not self.operands[0].matches(message)
122
+		if self.operator == 'and':
123
+			for op in self.operands:
124
+				if not op.matches(message):
125
+					return False
126
+			return True
127
+		if self.operator == 'or':
128
+			for op in self.operands:
129
+				if op.matches(message):
130
+					return True
131
+			return False
132
+		raise ValueError(f'Bad operator "{self.operator}"')
133
+
134
+	def __str__(self) -> str:
135
+		if self.operator == '!':
136
+			return f'(!( {self.operands[0]} ))'
137
+		strs = map(str, self.operands)
138
+		joined = f' {self.operator} '.join(strs)
139
+		return f'( {joined} )'
140
+
141
+class PatternStatement:
142
+	"""
143
+	A full message match statement. If a message matches the given expression,
144
+	the given actions should be performed.
145
+	"""
146
+	def __init__(self,
147
+			name: str,
148
+			actions: list[PatternAction],
149
+			expression: PatternExpression,
150
+			original: str):
151
+		self.name = name
152
+		self.actions = list(actions)  # PatternAction[]
153
+		self.expression = expression
154
+		self.original = original
155
+
156
+	def check_deprecations(self) -> None:
157
+		"""
158
+		Tests whether this statement uses any deprecated syntax. Will raise a
159
+		PatternDeprecationError if one is found.
160
+		"""
161
+		self.__check_deprecations(self.expression)
162
+
163
+	@classmethod
164
+	def __check_deprecations(cls, expression: PatternExpression) -> None:
165
+		if isinstance(expression, PatternSimpleExpression):
166
+			s: PatternSimpleExpression = expression
167
+			if s.field in PatternCompiler.DEPRECATED_FIELDS:
168
+				raise PatternDeprecationError(f'"{s.field}" field is deprecated')
169
+		elif isinstance(expression, PatternCompoundExpression):
170
+			c: PatternCompoundExpression = expression
171
+			for oper in c.operands:
172
+				cls.__check_deprecations(oper)
173
+
174
+	def to_json(self) -> dict:
175
+		"""
176
+		Returns a JSON representation of this statement.
177
+		"""
178
+		return {
179
+			'name': self.name,
180
+			'statement': self.original,
181
+		}
182
+
183
+	@classmethod
184
+	def from_json(cls, json: dict):
185
+		"""
186
+		Gets a PatternStatement from its JSON representation.
187
+		"""
188
+		return PatternCompiler.parse_statement(json['name'], json['statement'])
189
+
190
+class PatternCompiler:
191
+	"""
192
+	Parses a user-provided message filter statement into a PatternStatement.
193
+	"""
194
+	TYPE_FLOAT = 'float'
195
+	TYPE_ID = 'id'
196
+	TYPE_INT = 'int'
197
+	TYPE_MEMBER = 'Member'
198
+	TYPE_REGEX = 'regex'
199
+	TYPE_TEXT = 'text'
200
+	TYPE_TIMESPAN = 'timespan'
201
+
202
+	FIELD_TO_TYPE: dict[str, str] = {
203
+		'content.plain': TYPE_TEXT,
204
+		'content.markdown': TYPE_TEXT,
205
+		'author': TYPE_MEMBER,
206
+		'author.id': TYPE_ID,
207
+		'author.name': TYPE_TEXT,
208
+		'author.joinage': TYPE_TIMESPAN,
209
+
210
+		'content': TYPE_TEXT, # deprecated, use content.markdown or content.plain
211
+	}
212
+	DEPRECATED_FIELDS: set[str] = set([ 'content' ])
213
+
214
+	ACTION_TO_ARGS: dict[str, list[str]] = {
215
+		'ban': [],
216
+		'delete': [],
217
+		'kick': [],
218
+		'modinfo': [],
219
+		'modwarn': [],
220
+		'reply': [ TYPE_TEXT ],
221
+	}
222
+
223
+	OPERATORS_IDENTITY: set[str] = set([ '==', '!=' ])
224
+	OPERATORS_COMPARISON: set[str] = set([ '<', '>', '<=', '>=' ])
225
+	OPERATORS_NUMERIC = OPERATORS_IDENTITY | OPERATORS_COMPARISON
226
+	OPERATORS_TEXT = OPERATORS_IDENTITY | set([
227
+		'contains', '!contains',
228
+		'containsword', '!containsword',
229
+		'matches', '!matches',
230
+	])
231
+	OPERATORS_ALL = OPERATORS_IDENTITY | OPERATORS_COMPARISON | OPERATORS_TEXT
232
+
233
+	TYPE_TO_OPERATORS: dict[str, set[str]] = {
234
+		TYPE_ID: OPERATORS_IDENTITY,
235
+		TYPE_MEMBER: OPERATORS_IDENTITY,
236
+		TYPE_TEXT: OPERATORS_TEXT,
237
+		TYPE_INT: OPERATORS_NUMERIC,
238
+		TYPE_FLOAT: OPERATORS_NUMERIC,
239
+		TYPE_TIMESPAN: OPERATORS_NUMERIC,
240
+	}
241
+
242
+	WHITESPACE_CHARS = ' \t\n\r'
243
+	STRING_QUOTE_CHARS = '\'"'
244
+	SYMBOL_CHARS = 'abcdefghijklmnopqrstuvwxyz.'
245
+	VALUE_CHARS = '0123456789dhms<@!>'
246
+	OP_CHARS = '<=>!(),'
247
+
248
+	MAX_EXPRESSION_NESTING = 8
249
+
250
+	@classmethod
251
+	def expression_str_from_context(cls, context: Context, name: str) -> str:
252
+		"""
253
+		Extracts the statement string from an "add" command context.
254
+		"""
255
+		pattern_str = context.message.content
256
+		command_chain = [ name ]
257
+		cmd = context.command
258
+		while cmd:
259
+			command_chain.insert(0, cmd.name)
260
+			cmd = cmd.parent
261
+		command_chain[0] = f'{context.prefix}{command_chain[0]}'
262
+		for cmd in command_chain:
263
+			if pattern_str.startswith(cmd):
264
+				pattern_str = pattern_str[len(cmd):].lstrip()
265
+			elif pattern_str.startswith(f'"{cmd}"'):
266
+				pattern_str = pattern_str[len(cmd) + 2:].lstrip()
267
+		return pattern_str
268
+
269
+	@classmethod
270
+	def parse_statement(cls, name: str, statement: str) -> PatternStatement:
271
+		"""
272
+		Parses a user-provided message filter statement into a PatternStatement.
273
+		Raises PatternError on failure.
274
+		"""
275
+		tokens = cls.__tokenize(statement)
276
+		token_index = 0
277
+		actions, token_index = cls.__read_actions(tokens, token_index)
278
+		expression, token_index = cls.__read_expression(tokens, token_index)
279
+		return PatternStatement(name, actions, expression, statement)
280
+
281
+	@classmethod
282
+	def __tokenize(cls, statement: str) -> list[str]:
283
+		"""
284
+		Converts a message filter statement into a list of tokens.
285
+		"""
286
+		tokens: list[str] = []
287
+		in_quote = False
288
+		in_escape = False
289
+		all_token_types = set([ 'sym', 'op', 'val' ])
290
+		possible_token_types = set(all_token_types)
291
+		current_token = ''
292
+		for ch in statement:
293
+			if in_quote:
294
+				if in_escape:
295
+					if ch == 'n':
296
+						current_token += '\n'
297
+					elif ch == 't':
298
+						current_token += '\t'
299
+					else:
300
+						current_token += ch
301
+					in_escape = False
302
+				elif ch == '\\':
303
+					in_escape = True
304
+				elif ch == in_quote:
305
+					current_token += ch
306
+					tokens.append(current_token)
307
+					current_token = ''
308
+					possible_token_types |= all_token_types
309
+					in_quote = False
310
+				else:
311
+					current_token += ch
312
+			else:
313
+				if ch in cls.STRING_QUOTE_CHARS:
314
+					if len(current_token) > 0:
315
+						tokens.append(current_token)
316
+						current_token = ''
317
+						possible_token_types |= all_token_types
318
+					in_quote = ch
319
+					current_token = ch
320
+				elif ch == '\\':
321
+					raise PatternError("Unexpected \\ outside quoted string")
322
+				elif ch in cls.WHITESPACE_CHARS:
323
+					if len(current_token) > 0:
324
+						tokens.append(current_token)
325
+					current_token = ''
326
+					possible_token_types |= all_token_types
327
+				else:
328
+					possible_ch_types = set()
329
+					if ch in cls.SYMBOL_CHARS:
330
+						possible_ch_types.add('sym')
331
+					if ch in cls.VALUE_CHARS:
332
+						possible_ch_types.add('val')
333
+					if ch in cls.OP_CHARS:
334
+						possible_ch_types.add('op')
335
+					if len(current_token) > 0 and \
336
+							possible_ch_types.isdisjoint(possible_token_types):
337
+						if len(current_token) > 0:
338
+							tokens.append(current_token)
339
+							current_token = ''
340
+							possible_token_types |= all_token_types
341
+					possible_token_types &= possible_ch_types
342
+					current_token += ch
343
+		if len(current_token) > 0:
344
+			tokens.append(current_token)
345
+
346
+		# Some symbols might be glommed onto other tokens. Split 'em up.
347
+		prefixes_to_split = [ '!', '(', ',' ]
348
+		suffixes_to_split = [ ')', ',' ]
349
+		i = 0
350
+		while i < len(tokens):
351
+			token = tokens[i]
352
+			mutated = False
353
+			for prefix in prefixes_to_split:
354
+				if token.startswith(prefix) and len(token) > len(prefix):
355
+					tokens.insert(i, prefix)
356
+					tokens[i + 1] = token[len(prefix):]
357
+					i += 1
358
+					mutated = True
359
+					break
360
+			if mutated:
361
+				continue
362
+			for suffix in suffixes_to_split:
363
+				if token.endswith(suffix) and len(token) > len(suffix):
364
+					tokens[i] = token[0:-len(suffix)]
365
+					tokens.insert(i + 1, suffix)
366
+					mutated = True
367
+					break
368
+			if mutated:
369
+				continue
370
+			i += 1
371
+		return tokens
372
+
373
+	@classmethod
374
+	def __read_actions(cls,
375
+			tokens: list[str],
376
+			token_index: int) -> tuple[list[PatternAction], int]:
377
+		"""
378
+		Reads the actions from a list of statement tokens. Returns a tuple
379
+		containing a list of PatternActions and the token index this method
380
+		left off at (the token after the "if").
381
+		"""
382
+		actions: list[PatternAction] = []
383
+		current_action_tokens = []
384
+		while token_index < len(tokens):
385
+			token = tokens[token_index]
386
+			if token == 'if':
387
+				if len(current_action_tokens) > 0:
388
+					a = PatternAction(current_action_tokens[0], \
389
+						current_action_tokens[1:])
390
+					cls.__validate_action(a)
391
+					actions.append(a)
392
+				token_index += 1
393
+				return (actions, token_index)
394
+			elif token == ',':
395
+				if len(current_action_tokens) < 1:
396
+					raise PatternError('Unexpected ,')
397
+				a = PatternAction(current_action_tokens[0], \
398
+					current_action_tokens[1:])
399
+				cls.__validate_action(a)
400
+				actions.append(a)
401
+				current_action_tokens = []
402
+			else:
403
+				current_action_tokens.append(token)
404
+			token_index += 1
405
+		raise PatternError('Unexpected end of line in action list')
406
+
407
+	@classmethod
408
+	def __validate_action(cls, action: PatternAction) -> None:
409
+		args: list[str] = cls.ACTION_TO_ARGS.get(action.action)
410
+		if args is None:
411
+			raise PatternError(f'Unknown action "{action.action}"')
412
+		if len(action.arguments) != len(args):
413
+			if len(args) == 0:
414
+				raise PatternError(f'Action "{action.action}" expects no ' + \
415
+					f'arguments, got {len(action.arguments)}.')
416
+			raise PatternError(f'Action "{action.action}" expects ' + \
417
+				f'{len(args)} arguments, got {len(action.arguments)}.')
418
+		for i, datatype in enumerate(args):
419
+			action.arguments[i] = cls.__parse_value(action.arguments[i], datatype)
420
+
421
+	@classmethod
422
+	def __read_expression(cls,
423
+			tokens: list[str],
424
+			token_index: int,
425
+			depth: int = 0,
426
+			one_subexpression: bool = False) -> tuple[PatternExpression, int]:
427
+		"""
428
+		Reads an expression from a list of statement tokens. Returns a tuple
429
+		containing the PatternExpression and the token index it left off at.
430
+		If one_subexpression is True then it will return after reading a
431
+		single expression instead of joining multiples (for reading the
432
+		subject of a NOT expression).
433
+		"""
434
+		subexpressions = []
435
+		last_compound_operator = None
436
+		while token_index < len(tokens):
437
+			if one_subexpression:
438
+				if len(subexpressions) == 1:
439
+					return (subexpressions[0], token_index)
440
+				if len(subexpressions) > 1:
441
+					raise PatternError('Too many subexpressions')
442
+			compound_operator = None
443
+			if tokens[token_index] == ')':
444
+				if len(subexpressions) == 0:
445
+					raise PatternError('No subexpressions')
446
+				if len(subexpressions) == 1:
447
+					return (subexpressions[0], token_index)
448
+				return (PatternCompoundExpression(last_compound_operator,
449
+					subexpressions), token_index)
450
+			if tokens[token_index] in set(["and", "or"]):
451
+				compound_operator = tokens[token_index]
452
+				if last_compound_operator and \
453
+						compound_operator != last_compound_operator:
454
+					subexpressions = [
455
+						PatternCompoundExpression(last_compound_operator,
456
+							subexpressions),
457
+					]
458
+				last_compound_operator = compound_operator
459
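+				token_index += 1
+				if token_index >= len(tokens):
+					raise PatternError(f'Expected expression after "{compound_operator}"')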
460
+			if tokens[token_index] == '!':
461
+				(exp, next_index) = cls.__read_expression(tokens, \
462
+						token_index + 1, depth + 1, one_subexpression=True)
463
+				subexpressions.append(PatternCompoundExpression('!', [exp]))
464
+				token_index = next_index
465
+			elif tokens[token_index] == '(':
466
+				(exp, next_index) = cls.__read_expression(tokens,
467
+					token_index + 1, depth + 1)
468
+				if next_index >= len(tokens) or tokens[next_index] != ')':
469
+					raise PatternError('Expected )')
470
+				subexpressions.append(exp)
471
+				token_index = next_index + 1
472
+			else:
473
+				(simple, next_index) = cls.__read_simple_expression(tokens,
474
+					token_index, depth)
475
+				subexpressions.append(simple)
476
+				token_index = next_index
477
+		if len(subexpressions) == 0:
478
+			raise PatternError('No subexpressions')
479
+		elif len(subexpressions) == 1:
480
+			return (subexpressions[0], token_index)
481
+		else:
482
+			return (PatternCompoundExpression(last_compound_operator,
483
+				subexpressions), token_index)
484
+
485
+	@classmethod
486
+	def __read_simple_expression(cls,
487
+			tokens: list[str],
488
+			token_index: int,
489
+			depth: int = 0) -> tuple[PatternExpression, int]:
490
+		"""
491
+		Reads a simple expression consisting of a field name, operator, and
492
+		comparison value. Returns a tuple of the PatternSimpleExpression and
493
+		the token index it left off at.
494
+		"""
495
+		if depth > cls.MAX_EXPRESSION_NESTING:
496
+			raise PatternError('Expression nests too deeply')
497
+		if token_index >= len(tokens):
498
+			raise PatternError('Expected field name, found EOL')
499
+		field = tokens[token_index]
500
+		token_index += 1
501
+
502
+		datatype = cls.FIELD_TO_TYPE.get(field)
503
+		if datatype is None:
504
+			raise PatternError(f'No such field "{field}"')
505
+
506
+		if token_index >= len(tokens):
507
+			raise PatternError('Expected operator, found EOL')
508
+		op = tokens[token_index]
509
+		token_index += 1
510
+
511
+		if op == '!':
512
+			if token_index >= len(tokens):
513
+				raise PatternError('Expected operator, found EOL')
514
+			op = '!' + tokens[token_index]
515
+			token_index += 1
516
+
517
+		allowed_ops = cls.TYPE_TO_OPERATORS[datatype]
518
+		if op not in allowed_ops:
519
+			if op in cls.OPERATORS_ALL:
520
+				raise PatternError(f'Operator {op} cannot be used with ' + \
521
+					f'field "{field}"')
522
+			raise PatternError(f'Unrecognized operator "{op}" - allowed: ' + \
523
+				f'{sorted(list(allowed_ops))}')
524
+
525
+		if token_index >= len(tokens):
526
+			raise PatternError('Expected value, found EOL')
527
+		value_str = tokens[token_index]
528
+
529
+		try:
530
+			value = cls.__parse_value(value_str, datatype, op)
531
+		except ValueError as cause:
532
+			raise PatternError(f'Bad value {value_str}') from cause
533
+
534
+		token_index += 1
535
+		exp = PatternSimpleExpression(field, op, value)
536
+		return (exp, token_index)
537
+
538
+	@classmethod
539
+	def __parse_value(cls, value: str, datatype: str, op: str = None) -> Any:
540
+		"""
541
+		Converts a value token to its Python value. Raises ValueError on failure.
542
+		"""
543
+		if datatype == cls.TYPE_ID:
544
+			if not is_user_id(value):
545
+				raise ValueError(f'Illegal user id value: {value}')
546
+			return value
547
+		if datatype == cls.TYPE_MEMBER:
548
+			return user_id_from_mention(value)
549
+		if datatype == cls.TYPE_TEXT:
550
+			s = str_from_quoted_str(value)
551
+			if op in ('matches', '!matches'):
552
+				try:
553
+					return re.compile(s.lower())
554
+				except re.error as e:
555
+					raise ValueError(f'Invalid regex: {e}') from e
556
+			if op in ('containsword', '!containsword'):
557
+				try:
558
+					return re.compile(f'\\b{re.escape(s.lower())}\\b')
559
+				except re.error as e:
560
+					raise ValueError(f'Invalid regex: {e}') from e
561
+			return s
562
+		if datatype == cls.TYPE_INT:
563
+			return int(value)
564
+		if datatype == cls.TYPE_FLOAT:
565
+			return float(value)
566
+		if datatype == cls.TYPE_TIMESPAN:
567
+			return timedelta_from_str(value)
568
+		raise ValueError(f'Unhandled datatype {datatype}')
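
The `containsword` operator in `__parse_value` above wraps the escaped, lowercased search text in `\b` word-boundary anchors, and `matches()` then searches the lowercased field value. A minimal standalone sketch of that technique follows; the function names `compile_containsword` and `contains_word` are hypothetical and do not appear in the commit.

```python
import re


def compile_containsword(word: str) -> re.Pattern:
    # Escape the word so regex metacharacters are treated literally, then
    # anchor both ends on word boundaries.
    return re.compile(f'\\b{re.escape(word.lower())}\\b')


def contains_word(text: str, word: str) -> bool:
    # Lowercase the text as well so the comparison is case-insensitive.
    return compile_containsword(word).search(text.lower()) is not None
```

With this approach `contains_word('This is FORBIDDEN here', 'forbidden')` is true, while `contains_word('unforbidden', 'forbidden')` is false because there is no word boundary before the match.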

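`PatternCompoundExpression.matches` in the diff above combines child expressions with `!`, `and`, and `or`, stopping as soon as the result is known. The following sketch shows the same short-circuit evaluation on plain strings rather than Discord messages. `Leaf` and `Compound` are hypothetical stand-ins for the commit's expression classes, simplified so the example runs without `discord.py`.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class Leaf:
    """Stand-in for a simple expression: wraps one predicate on the text."""
    test: Callable[[str], bool]

    def matches(self, text: str) -> bool:
        return self.test(text)


@dataclass
class Compound:
    """Stand-in for a compound expression with a boolean operator."""
    operator: str
    operands: list

    def matches(self, text: str) -> bool:
        if self.operator == '!':
            return not self.operands[0].matches(text)
        if self.operator == 'and':
            # all() stops at the first operand that does not match.
            return all(op.matches(text) for op in self.operands)
        if self.operator == 'or':
            # any() stops at the first operand that matches.
            return any(op.matches(text) for op in self.operands)
        raise ValueError(f'Bad operator "{self.operator}"')


has_spam = Leaf(lambda t: 'spam' in t.lower())
has_link = Leaf(lambda t: 'http' in t.lower())
rule = Compound('and', [has_spam, Compound('!', [has_link])])
```

Here `rule` matches text that contains "spam" but no link, which mirrors how a statement such as `a and !b` would be nested by the parser.
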