ソースを参照

Annotating pattern, utils

master
Rocketsoup 4年前
コミット
a0a2e5834f
2個のファイルの変更63行の追加61行の削除
  1. 36
    36
      rocketbot/pattern.py
  2. 27
    25
      rocketbot/utils.py

+ 36
- 36
rocketbot/pattern.py ファイルの表示

@@ -43,7 +43,7 @@ class PatternExpression(metaclass=ABCMeta):
43 43
 		pass
44 44
 
45 45
 	@abstractmethod
46
-	def matches(self, message: Message, other_fields: dict[str: Any]) -> bool:
46
+	def matches(self, message: Message, other_fields: dict[str, Any]) -> bool:
47 47
 		"""
48 48
 		Whether a message matches this expression. other_fields are additional
49 49
 		fields that can be queried not contained in the message itself.
@@ -57,9 +57,9 @@ class PatternSimpleExpression(PatternExpression):
57 57
 	"""
58 58
 	def __init__(self, field: str, operator: str, value: Any):
59 59
 		super().__init__()
60
-		self.field = field
61
-		self.operator = operator
62
-		self.value = value
60
+		self.field: str = field
61
+		self.operator: str = operator
62
+		self.value: Any = value
63 63
 
64 64
 	def __field_value(self, message: Message, other_fields: dict[str: Any]) -> Any:
65 65
 		if self.field in ('content.markdown', 'content'):
@@ -149,7 +149,7 @@ class PatternStatement:
149 149
 	the given actions should be performed.
150 150
 	"""
151 151
 
152
-	DEFAULT_PRIORITY = 100
152
+	DEFAULT_PRIORITY: int = 100
153 153
 
154 154
 	def __init__(self,
155 155
 			name: str,
@@ -157,11 +157,11 @@ class PatternStatement:
157 157
 			expression: PatternExpression,
158 158
 			original: str,
159 159
 			priority: int = DEFAULT_PRIORITY):
160
-		self.name = name
161
-		self.actions = list(actions)  # PatternAction[]
162
-		self.expression = expression
163
-		self.original = original
164
-		self.priority = priority
160
+		self.name: str = name
161
+		self.actions: list[PatternAction] = list(actions)  # PatternAction[]
162
+		self.expression: PatternExpression = expression
163
+		self.original: str = original
164
+		self.priority: int = priority
165 165
 
166 166
 	def check_deprecations(self) -> None:
167 167
 		"""
@@ -181,7 +181,7 @@ class PatternStatement:
181 181
 			for oper in c.operands:
182 182
 				cls.__check_deprecations(oper)
183 183
 
184
-	def to_json(self) -> dict:
184
+	def to_json(self) -> dict[str, Any]:
185 185
 		"""
186 186
 		Returns a JSON representation of this statement.
187 187
 		"""
@@ -192,7 +192,7 @@ class PatternStatement:
192 192
 		}
193 193
 
194 194
 	@classmethod
195
-	def from_json(cls, json: dict):
195
+	def from_json(cls, json: dict[str, Any]):
196 196
 		"""
197 197
 		Gets a PatternStatement from its JSON representation.
198 198
 		"""
@@ -204,13 +204,13 @@ class PatternCompiler:
204 204
 	"""
205 205
 	Parses a user-provided message filter statement into a PatternStatement.
206 206
 	"""
207
-	TYPE_FLOAT = 'float'
208
-	TYPE_ID = 'id'
209
-	TYPE_INT = 'int'
210
-	TYPE_MEMBER = 'Member'
211
-	TYPE_REGEX = 'regex'
212
-	TYPE_TEXT = 'text'
213
-	TYPE_TIMESPAN = 'timespan'
207
+	TYPE_FLOAT: str = 'float'
208
+	TYPE_ID: str = 'id'
209
+	TYPE_INT: str = 'int'
210
+	TYPE_MEMBER: str = 'Member'
211
+	TYPE_REGEX: str = 'regex'
212
+	TYPE_TEXT: str = 'text'
213
+	TYPE_TIMESPAN: str = 'timespan'
214 214
 
215 215
 	FIELD_TO_TYPE: dict[str, str] = {
216 216
 		'author': TYPE_MEMBER,
@@ -235,13 +235,13 @@ class PatternCompiler:
235 235
 
236 236
 	OPERATORS_IDENTITY: set[str] = set([ '==', '!=' ])
237 237
 	OPERATORS_COMPARISON: set[str] = set([ '<', '>', '<=', '>=' ])
238
-	OPERATORS_NUMERIC = OPERATORS_IDENTITY | OPERATORS_COMPARISON
239
-	OPERATORS_TEXT = OPERATORS_IDENTITY | set([
238
+	OPERATORS_NUMERIC: set[str] = OPERATORS_IDENTITY | OPERATORS_COMPARISON
239
+	OPERATORS_TEXT: set[str] = OPERATORS_IDENTITY | set([
240 240
 		'contains', '!contains',
241 241
 		'containsword', '!containsword',
242 242
 		'matches', '!matches',
243 243
 	])
244
-	OPERATORS_ALL = OPERATORS_IDENTITY | OPERATORS_COMPARISON | OPERATORS_TEXT
244
+	OPERATORS_ALL: set[str] = OPERATORS_IDENTITY | OPERATORS_COMPARISON | OPERATORS_TEXT
245 245
 
246 246
 	TYPE_TO_OPERATORS: dict[str, set[str]] = {
247 247
 		TYPE_ID: OPERATORS_IDENTITY,
@@ -252,20 +252,20 @@ class PatternCompiler:
252 252
 		TYPE_TIMESPAN: OPERATORS_NUMERIC,
253 253
 	}
254 254
 
255
-	WHITESPACE_CHARS = ' \t\n\r'
256
-	STRING_QUOTE_CHARS = '\'"'
257
-	SYMBOL_CHARS = 'abcdefghijklmnopqrstuvwxyz.'
258
-	VALUE_CHARS = '0123456789dhms<@!>'
259
-	OP_CHARS = '<=>!(),'
255
+	WHITESPACE_CHARS: str = ' \t\n\r'
256
+	STRING_QUOTE_CHARS: str = '\'"'
257
+	SYMBOL_CHARS: str = 'abcdefghijklmnopqrstuvwxyz.'
258
+	VALUE_CHARS: str = '0123456789dhms<@!>'
259
+	OP_CHARS: str = '<=>!(),'
260 260
 
261
-	MAX_EXPRESSION_NESTING = 8
261
+	MAX_EXPRESSION_NESTING: int = 8
262 262
 
263 263
 	@classmethod
264 264
 	def expression_str_from_context(cls, context: Context, name: str) -> str:
265 265
 		"""
266 266
 		Extracts the statement string from an "add" command context.
267 267
 		"""
268
-		pattern_str = context.message.content
268
+		pattern_str: str = context.message.content
269 269
 		command_chain = [ name ]
270 270
 		cmd = context.command
271 271
 		while cmd:
@@ -285,8 +285,8 @@ class PatternCompiler:
285 285
 		Parses a user-provided message filter statement into a PatternStatement.
286 286
 		Raises PatternError on failure.
287 287
 		"""
288
-		tokens = cls.__tokenize(statement)
289
-		token_index = 0
288
+		tokens: list[str] = cls.__tokenize(statement)
289
+		token_index: int = 0
290 290
 		actions, token_index = cls.__read_actions(tokens, token_index)
291 291
 		expression, token_index = cls.__read_expression(tokens, token_index)
292 292
 		return PatternStatement(name, actions, expression, statement)
@@ -297,11 +297,11 @@ class PatternCompiler:
297 297
 		Converts a message filter statement into a list of tokens.
298 298
 		"""
299 299
 		tokens: list[str] = []
300
-		in_quote = False
301
-		in_escape = False
302
-		all_token_types = set([ 'sym', 'op', 'val' ])
303
-		possible_token_types = set(all_token_types)
304
-		current_token = ''
300
+		in_quote: bool = False
301
+		in_escape: bool = False
302
+		all_token_types: set[str] = set([ 'sym', 'op', 'val' ])
303
+		possible_token_types: set[str] = set(all_token_types)
304
+		current_token: str = ''
305 305
 		for ch in statement:
306 306
 			if in_quote:
307 307
 				if in_escape:

+ 27
- 25
rocketbot/utils.py ファイルの表示

@@ -3,6 +3,8 @@ General utility functions.
3 3
 """
4 4
 import re
5 5
 from datetime import datetime, timedelta
6
+from typing import Any, Union
7
+
6 8
 from discord import Guild
7 9
 from discord.ext.commands import Cog, Group
8 10
 
@@ -15,14 +17,14 @@ def timedelta_from_str(s: str) -> timedelta:
15 17
 	"1h30m"
16 18
 	"73d18h22m52s"
17 19
 	"""
18
-	p = re.compile('^(?:[0-9]+[dhms])+$')
20
+	p: re.Pattern = re.compile('^(?:[0-9]+[dhms])+$')
19 21
 	if p.match(s) is None:
20 22
 		raise ValueError("Illegal timespan value '{s}'.")
21 23
 	p = re.compile('([0-9]+)([dhms])')
22
-	days = 0
23
-	hours = 0
24
-	minutes = 0
25
-	seconds = 0
24
+	days: int = 0
25
+	hours: int = 0
26
+	minutes: int = 0
27
+	seconds: int = 0
26 28
 	for m in p.finditer(s):
27 29
 		scalar = int(m.group(1))
28 30
 		unit = m.group(2)
@@ -40,11 +42,11 @@ def str_from_timedelta(td: timedelta) -> str:
40 42
 	"""
41 43
 	Encodes a timedelta as a str. E.g. "3d2h"
42 44
 	"""
43
-	d = td.days
44
-	h = td.seconds // 3600
45
-	m = (td.seconds // 60) % 60
46
-	s = td.seconds % 60
47
-	components = []
45
+	d: int = td.days
46
+	h: int = td.seconds // 3600
47
+	m: int = (td.seconds // 60) % 60
48
+	s: int = td.seconds % 60
49
+	components: list[str] = []
48 50
 	if d != 0:
49 51
 		components.append(f'{d}d')
50 52
 	if h != 0:
@@ -59,11 +61,11 @@ def describe_timedelta(td: timedelta, max_components: int = 2) -> str:
59 61
 	"""
60 62
 	Formats a human-readable description of a time span. E.g. "3 days 2 hours".
61 63
 	"""
62
-	d = td.days
63
-	h = td.seconds // 3600
64
-	m = (td.seconds // 60) % 60
65
-	s = td.seconds % 60
66
-	components = []
64
+	d: int = td.days
65
+	h: int = td.seconds // 3600
66
+	m: int = (td.seconds // 60) % 60
67
+	s: int = td.seconds % 60
68
+	components: list[str] = []
67 69
 	if d != 0:
68 70
 		components.append('1 day' if d == 1 else f'{d} days')
69 71
 	if h != 0:
@@ -84,20 +86,20 @@ def first_command_group(cog: Cog) -> Group:
84 86
 			return member
85 87
 	return None
86 88
 
87
-def bot_log(guild: Guild, cog_class, message: str) -> None:
89
+def bot_log(guild: Guild, cog_class, message: Any) -> None:
88 90
 	'Logs a message to stdout with time, cog, and guild info.'
89
-	now = datetime.now() # local
91
+	now: datetime = datetime.now() # local
90 92
 	s = f'[{now.strftime("%Y-%m-%dT%H:%M:%S")}|'
91 93
 	s += f'{cog_class.__name__}|' if cog_class else '-|'
92 94
 	s += f'{guild.name}] ' if guild else '-] '
93
-	s += message
95
+	s += str(message)
94 96
 	print(s)
95 97
 
96
-__QUOTE_CHARS = '\'"'
97
-__ID_REGEX = re.compile('^[0-9]{17,20}$')
98
-__MENTION_REGEX = re.compile('^<@[!&]([0-9]{17,20})>$')
99
-__USER_MENTION_REGEX = re.compile('^<@!([0-9]{17,20})>$')
100
-__ROLE_MENTION_REGEX = re.compile('^<@&([0-9]{17,20})>$')
98
+__QUOTE_CHARS: str = '\'"'
99
+__ID_REGEX: re.Pattern = re.compile('^[0-9]{17,20}$')
100
+__MENTION_REGEX: re.Pattern = re.compile('^<@[!&]([0-9]{17,20})>$')
101
+__USER_MENTION_REGEX: re.Pattern = re.compile('^<@!([0-9]{17,20})>$')
102
+__ROLE_MENTION_REGEX: re.Pattern = re.compile('^<@&([0-9]{17,20})>$')
101 103
 
102 104
 def is_user_id(val: str) -> bool:
103 105
 	'Tests if a string is in user/role ID format.'
@@ -122,11 +124,11 @@ def user_id_from_mention(mention: str) -> str:
122 124
 		return m.group(1)
123 125
 	raise ValueError(f'"{mention}" is not an @ user mention')
124 126
 
125
-def mention_from_user_id(user_id: str) -> str:
127
+def mention_from_user_id(user_id: Union[str, int]) -> str:
126 128
 	'Returns a markdown user mention from a user id.'
127 129
 	return f'<@!{user_id}>'
128 130
 
129
-def mention_from_role_id(role_id: str) -> str:
131
+def mention_from_role_id(role_id: Union[str, int]) -> str:
130 132
 	'Returns a markdown role mention from a role id.'
131 133
 	return f'<@&{role_id}>'
132 134
 

読み込み中…
キャンセル
保存