ソースを参照

Llllllots of cleanup after switching to VSCode

master
Rocketsoup 4年前
コミット
192da063ec
12個のファイルの変更375行の追加301行の削除
  1. 6
    0
      .pylintrc
  2. 86
    97
      cogs/basecog.py
  3. 14
    15
      cogs/configcog.py
  4. 6
    4
      cogs/crosspostcog.py
  5. 29
    10
      cogs/generalcog.py
  6. 9
    5
      cogs/joinraidcog.py
  7. 133
    81
      cogs/patterncog.py
  8. 10
    25
      cogs/urlspamcog.py
  9. 20
    26
      rbutils.py
  10. 12
    9
      rocketbot.py
  11. 31
    15
      rscollections.py
  12. 19
    14
      storage.py

+ 6
- 0
.pylintrc ファイルの表示

@@ -0,0 +1,6 @@
1
+[MESSAGES CONTROL]
2
+disable=bad-indentation, invalid-name
3
+
4
+[FORMAT]
5
+
6
+indent-string=' '

+ 86
- 97
cogs/basecog.py ファイルの表示

@@ -1,12 +1,11 @@
1
+from datetime import datetime, timedelta
1 2
 from discord import Guild, Member, Message, PartialEmoji, RawReactionActionEvent, TextChannel
3
+from discord.abc import GuildChannel
2 4
 from discord.ext import commands
3
-from datetime import datetime, timedelta
4 5
 
5
-from abc import ABC, abstractmethod
6 6
 from config import CONFIG
7 7
 from rscollections import AgeBoundDict
8 8
 from storage import ConfigKey, Storage
9
-import json
10 9
 
11 10
 class BotMessageReaction:
12 11
 	"""
@@ -136,18 +135,23 @@ class BotMessage:
136 135
 		return self.__message is not None
137 136
 
138 137
 	def message_id(self):
138
+		'Returns the Message id or None if not sent.'
139 139
 		return self.__message.id if self.__message else None
140 140
 
141
-	def message_sent_at(self):
141
+	def message_sent_at(self) -> datetime:
142
+		'Returns when the message was sent or None if not sent.'
142 143
 		return self.__message.created_at if self.__message else None
143 144
 
145
+	def has_reactions(self) -> bool:
146
+		return len(self.__reactions) > 0
147
+
144 148
 	async def set_text(self, new_text: str) -> None:
145 149
 		"""
146 150
 		Replaces the text of this message. If the message has been sent, it will
147 151
 		be updated.
148 152
 		"""
149 153
 		self.text = new_text
150
-		await self.__update_if_sent()
154
+		await self.update_if_sent()
151 155
 
152 156
 	async def set_reactions(self, reactions: list) -> None:
153 157
 		"""
@@ -158,7 +162,7 @@ class BotMessage:
158 162
 			# No change
159 163
 			return
160 164
 		self.__reactions = reactions.copy() if reactions is not None else []
161
-		await self.__update_if_sent()
165
+		await self.update_if_sent()
162 166
 
163 167
 	async def add_reaction(self, reaction: BotMessageReaction) -> None:
164 168
 		"""
@@ -175,8 +179,7 @@ class BotMessage:
175 179
 		will be updated.
176 180
 		"""
177 181
 		found = False
178
-		for i in range(len(self.__reactions)):
179
-			existing = self.__reactions[i]
182
+		for i, existing in enumerate(self.__reactions):
180 183
 			if existing.emoji == reaction.emoji:
181 184
 				if reaction == self.__reactions[i]:
182 185
 					# No change
@@ -186,34 +189,47 @@ class BotMessage:
186 189
 				break
187 190
 		if not found:
188 191
 			self.__reactions.append(reaction)
189
-		await self.__update_if_sent()
192
+		await self.update_if_sent()
190 193
 
191 194
 	async def remove_reaction(self, reaction_or_emoji) -> None:
192 195
 		"""
193 196
 		Removes a reaction. Can pass either a BotMessageReaction or just the
194 197
 		emoji string. If the message has been sent, it will be updated.
195 198
 		"""
196
-		for i in range(len(self.__reactions)):
197
-			existing = self.__reactions[i]
199
+		for i, existing in enumerate(self.__reactions):
198 200
 			if (isinstance(reaction_or_emoji, str) and existing.emoji == reaction_or_emoji) or \
199
-				(isinstance(reaction_or_emoji, BotMessageReaction) and existing.emoji == reaction_or_emoji.emoji):
201
+				(isinstance(reaction_or_emoji, BotMessageReaction) and \
202
+					existing.emoji == reaction_or_emoji.emoji):
200 203
 				self.__reactions.pop(i)
201
-				await self.__update_if_sent()
204
+				await self.update_if_sent()
202 205
 				return
203 206
 
204 207
 	def reaction_for_emoji(self, emoji) -> BotMessageReaction:
208
+		"""
209
+		Finds the BotMessageReaction for the given emoji or None if not found.
210
+		Accepts either a PartialEmoji or str.
211
+		"""
205 212
 		for reaction in self.__reactions:
206 213
 			if isinstance(emoji, PartialEmoji) and reaction.emoji == emoji.name:
207 214
 				return reaction
208
-			elif isinstance(emoji, str) and reaction.emoji == emoji:
215
+			if isinstance(emoji, str) and reaction.emoji == emoji:
209 216
 				return reaction
210 217
 		return None
211 218
 
212
-	async def __update_if_sent(self) -> None:
219
+	async def update_if_sent(self) -> None:
220
+		"""
221
+		Updates the text and/or reactions on a message if it was sent to
222
+		the guild, otherwise does nothing. Does not need to be called by
223
+		BaseCog subclasses.
224
+		"""
213 225
 		if self.__message:
214
-			await self._update()
226
+			await self.update()
215 227
 
216
-	async def _update(self) -> None:
228
+	async def update(self) -> None:
229
+		"""
230
+		Sends or updates an already sent message based on BotMessage state.
231
+		Does not need to be called by BaseCog subclasses.
232
+		"""
217 233
 		content: str = self.__formatted_message()
218 234
 		if self.__message:
219 235
 			if content != self.__posted_text:
@@ -278,6 +294,12 @@ class BotMessage:
278 294
 		return s
279 295
 
280 296
 class CogSetting:
297
+	"""
298
+	Describes a configuration setting for a guild that can be edited by the
299
+	mods of those guilds. BaseCog can generate "get" and "set" commands
300
+	automatically, reducing the boilerplate of generating commands manually.
301
+	Offers simple validation rules.
302
+	"""
281 303
 	def __init__(self,
282 304
 			name: str,
283 305
 			datatype,
@@ -287,6 +309,23 @@ class CogSetting:
287 309
 			min_value = None,
288 310
 			max_value = None,
289 311
 			enum_values: set = None):
312
+		"""
313
+		Params:
314
+		- name         Setting identifier. Must follow variable naming
315
+		               conventions.
316
+		- datatype     Datatype of the setting. E.g. int, float, str
317
+		- brief        Description of the setting, starting with lower case.
318
+		               Will be inserted into phrases like "Sets <brief>" and
319
+					   "Gets <brief".
320
+		- description  Long-form description. Min, max, and enum values will be
321
+		               appended to the end, so does not need to include these.
322
+		- usage        Description of the value argument in a set command, e.g.
323
+		               "<maxcount:int>"
324
+		- min_value    Smallest allowable value. Must be of the same datatype as
325
+		               the value. None for no minimum.
326
+		- max_value    Largest allowable value. None for no maximum.
327
+		- enum_values  Set of allowed values. None if unconstrained.
328
+		"""
290 329
 		self.name = name
291 330
 		self.datatype = datatype
292 331
 		self.brief = brief
@@ -308,6 +347,10 @@ class CogSetting:
308 347
 			self.usage = f'<{self.name}>'
309 348
 
310 349
 class BaseCog(commands.Cog):
350
+	"""
351
+	Superclass for all Rocketbot cogs. Provides lots of conveniences for
352
+	common tasks.
353
+	"""
311 354
 	def __init__(self, bot):
312 355
 		self.bot = bot
313 356
 		self.are_settings_setup = False
@@ -319,7 +362,8 @@ class BaseCog(commands.Cog):
319 362
 	def get_cog_default(cls, key: str):
320 363
 		"""
321 364
 		Convenience method for getting a cog configuration default from
322
-		`CONFIG['cogs'][<cogname>][<key>]`.
365
+		`CONFIG['cogs'][<cogname>][<key>]`. These values are used for
366
+		CogSettings when no guild-specific value is configured yet.
323 367
 		"""
324 368
 		cogs: dict = CONFIG['cog_defaults']
325 369
 		cog = cogs.get(cls.__name__)
@@ -337,6 +381,8 @@ class BaseCog(commands.Cog):
337 381
 		If the cog has a command group it will be detected automatically and
338 382
 		the commands added to that. Otherwise the commands will be added at
339 383
 		the top level.
384
+
385
+		Changes to settings can be detected by overriding `on_setting_updated`.
340 386
 		"""
341 387
 		self.settings.append(setting)
342 388
 
@@ -365,13 +411,22 @@ class BaseCog(commands.Cog):
365 411
 		"""
366 412
 		Manually sets a setting for the given guild. BaseCog creates "get" and
367 413
 		"set" commands for guild administrators to configure values themselves,
368
-		but this method can be used for hidden settings from code.
414
+		but this method can be used for hidden settings from code. A ValueError
415
+		will be raised if the new value does not pass validation specified in
416
+		the CogSetting.
369 417
 		"""
418
+		if setting.min_value is not None and new_value < setting.min_value:
419
+			raise ValueError(f'{setting.name} must be at least {setting.min_value}')
420
+		if setting.max_value is not None and new_value > setting.max_value:
421
+			raise ValueError(f'{setting.name} must be no more than {setting.max_value}')
422
+		if setting.enum_values and new_value not in setting.enum_values:
423
+			raise ValueError(f'{setting.name} must be one of {setting.enum_values}')
370 424
 		key = f'{cls.__name__}.{setting.name}'
371 425
 		Storage.set_config_value(guild, key, new_value)
372 426
 
373 427
 	@commands.Cog.listener()
374 428
 	async def on_ready(self):
429
+		'Event listener'
375 430
 		self.__set_up_setting_commands()
376 431
 
377 432
 	def __set_up_setting_commands(self):
@@ -406,9 +461,7 @@ class BaseCog(commands.Cog):
406 461
 		under the bot if `None`.
407 462
 		"""
408 463
 		# Manually constructing equivalent of:
409
-		# 	@commands.command(
410
-		# 		brief='Posts a test warning in the configured warning channel.'
411
-		# 	)
464
+		# 	@commands.command()
412 465
 		# 	@commands.has_permissions(ban_members=True)
413 466
 		# 	@commands.guild_only()
414 467
 		# 	async def getvar(self, context):
@@ -573,7 +626,6 @@ class BaseCog(commands.Cog):
573 626
 		"""
574 627
 		Subclass override point for being notified when a CogSetting is edited.
575 628
 		"""
576
-		pass
577 629
 
578 630
 	# Bot message handling
579 631
 
@@ -588,13 +640,17 @@ class BaseCog(commands.Cog):
588 640
 		return bm
589 641
 
590 642
 	async def post_message(self, message: BotMessage) -> bool:
643
+		"""
644
+		Posts a BotMessage to a guild. Returns whether it was successful. If
645
+		the caller wants to listen to reactions they should be added before
646
+		calling this method. Listen to reactions by overriding `on_mod_react`.
647
+		"""
591 648
 		message.source_cog = self
592
-		await message._update()
593
-		guild_messages = self.__bot_messages(message.guild)
594
-		if message.is_sent():
649
+		await message.update()
650
+		if message.has_reactions() and message.is_sent():
651
+			guild_messages = self.__bot_messages(message.guild)
595 652
 			guild_messages[message.message_id()] = message
596
-			return True
597
-		return False
653
+		return message.is_sent()
598 654
 
599 655
 	@commands.Cog.listener()
600 656
 	async def on_raw_reaction_add(self, payload: RawReactionActionEvent):
@@ -645,80 +701,13 @@ class BaseCog(commands.Cog):
645 701
 		Subclass override point for receiving mod reactions to bot messages sent
646 702
 		via `post_message()`.
647 703
 		"""
648
-		pass
649 704
 
650 705
 	# Helpers
651 706
 
652 707
 	@classmethod
653
-	async def validate_param(cls, context: commands.Context, param_name: str, value,
654
-		allowed_types: tuple = None,
655
-		min_value = None,
656
-		max_value = None) -> bool:
657
-		"""
658
-		Convenience method for validating a command parameter is of the expected
659
-		type and in the expected range. Bad values will cause a reply to be sent
660
-		to the original message and a False will be returned. If all checks
661
-		succeed, True will be returned.
662
-		"""
663
-		# TODO: Rework this to use BotMessage
664
-		if allowed_types is not None and not isinstance(value, allowed_types):
665
-			if len(allowed_types) == 1:
666
-				await context.message.reply(f'⚠️ `{param_name}` must be of type ' +
667
-					f'{allowed_types[0]}.', mention_author=False)
668
-			else:
669
-				await context.message.reply(f'⚠️ `{param_name}` must be of types ' +
670
-					f'{allowed_types}.', mention_author=False)
671
-			return False
672
-		if min_value is not None and value < min_value:
673
-			await context.message.reply(f'⚠️ `{param_name}` must be >= {min_value}.',
674
-				mention_author=False)
675
-			return False
676
-		if max_value is not None and value > max_value:
677
-			await context.message.reply(f'⚠️ `{param_name}` must be <= {max_value}.',
678
-				mention_author=False)
679
-		return True
680
-
681
-	@classmethod
682
-	async def warn(cls, guild: Guild, message: str) -> Message:
683
-		"""
684
-		DEPRECATED. Use post_message.
685
-
686
-		Sends a warning message to the configured warning channel for the
687
-		given guild. If no warning channel is configured no action is taken.
688
-		Returns the Message if successful or None if not.
689
-		"""
690
-		channel_id = Storage.get_config_value(guild, ConfigKey.WARNING_CHANNEL_ID)
691
-		if channel_id is None:
692
-			cls.log(guild, '\u0007No warning channel set! No warning issued.')
693
-			return None
694
-		channel: TextChannel = guild.get_channel(channel_id)
695
-		if channel is None:
696
-			cls.log(guild, '\u0007Configured warning channel does not exist!')
697
-			return None
698
-		mention: str = Storage.get_config_value(guild, ConfigKey.WARNING_MENTION)
699
-		text: str = message
700
-		if mention is not None:
701
-			text = f'{mention} {text}'
702
-		msg: Message = await channel.send(text)
703
-		return msg
704
-
705
-	@classmethod
706
-	async def update_warn(cls, warn_message: Message, new_text: str) -> None:
708
+	def log(cls, guild: Guild, message) -> None:
707 709
 		"""
708
-		DEPRECATED. Use post_message.
709
-
710
-		Updates the text of a previously posted `warn`. Includes configured
711
-		mentions if necessary.
710
+		Writes a message to the console. Intended for significant events only.
712 711
 		"""
713
-		text: str = new_text
714
-		mention: str = Storage.get_config_value(
715
-			warn_message.guild,
716
-			ConfigKey.WARNING_MENTION)
717
-		if mention is not None:
718
-			text = f'{mention} {text}'
719
-		await warn_message.edit(content=text)
720
-
721
-	@classmethod
722
-	def log(cls, guild: Guild, message) -> None:
723 712
 		now = datetime.now()
724 713
 		print(f'[{now.strftime("%Y-%m-%dT%H:%M:%S")}|{cls.__name__}|{guild.name}] {message}')

+ 14
- 15
cogs/configcog.py ファイルの表示

@@ -5,33 +5,30 @@ from config import CONFIG
5 5
 from storage import ConfigKey, Storage
6 6
 from cogs.basecog import BaseCog
7 7
 
8
-class ConfigCog(BaseCog):
8
+class ConfigCog(BaseCog, name='Configuration'):
9 9
 	"""
10 10
 	Cog for handling general bot configuration.
11 11
 	"""
12
-	def __init__(self, bot):
13
-		super().__init__(bot)
14
-
15 12
 	@commands.group(
16 13
 		brief='Manages general bot configuration'
17 14
 	)
18 15
 	@commands.has_permissions(ban_members=True)
19 16
 	@commands.guild_only()
20 17
 	async def config(self, context: commands.Context):
21
-		'Command group'
18
+		'General guild configuration command group'
22 19
 		if context.invoked_subcommand is None:
23 20
 			await context.send_help()
24 21
 
25 22
 	@config.command(
26
-		name='setwarningchannel',
27
-		brief='Sets where to post mod warnings',
23
+		brief='Sets the mod warning channel',
28 24
 		description='Run this command in the channel where bot messages ' +
29 25
 			'intended for server moderators should be sent. Other bot ' +
30 26
 			'messages may still be posted in the channel a command was ' +
31 27
 			'invoked in. If no output channel is set, mod-related messages ' +
32 28
 			'will not be posted!',
33 29
 	)
34
-	async def config_setwarningchannel(self, context: commands.Context) -> None:
30
+	async def setwarningchannel(self, context: commands.Context) -> None:
31
+		'Command handler'
35 32
 		guild: Guild = context.guild
36 33
 		channel: TextChannel = context.channel
37 34
 		Storage.set_config_value(guild, ConfigKey.WARNING_CHANNEL_ID,
@@ -41,10 +38,12 @@ class ConfigCog(BaseCog):
41 38
 			mention_author=False)
42 39
 
43 40
 	@config.command(
44
-		name='getwarningchannel',
45
-		brief='Shows the channel where mod warnings are posted',
41
+		brief='Shows the mod warning channel',
42
+		description='Shows the configured channel (if any) where mod ' + \
43
+			'warnings will be posted.',
46 44
 	)
47
-	async def config_getwarningchannel(self, context: commands.Context) -> None:
45
+	async def getwarningchannel(self, context: commands.Context) -> None:
46
+		'Command handler'
48 47
 		guild: Guild = context.guild
49 48
 		channel_id = Storage.get_config_value(guild, ConfigKey.WARNING_CHANNEL_ID)
50 49
 		if channel_id is None:
@@ -58,7 +57,6 @@ class ConfigCog(BaseCog):
58 57
 				mention_author=False)
59 58
 
60 59
 	@config.command(
61
-		name='setwarningmention',
62 60
 		brief='Sets a user/role to mention in warning messages',
63 61
 		usage='<@user|@role>',
64 62
 		description='Configures an role or other prefix to include at the ' +
@@ -66,9 +64,10 @@ class ConfigCog(BaseCog):
66 64
 			'attention of certain users, be sure to specify a properly ' +
67 65
 			'formed @ tag, not just the name of the user/role.'
68 66
 	)
69
-	async def config_setwarningmention(self,
67
+	async def setwarningmention(self,
70 68
 			context: commands.Context,
71 69
 			mention: str = None) -> None:
70
+		'Command handler'
72 71
 		guild: Guild = context.guild
73 72
 		Storage.set_config_value(guild, ConfigKey.WARNING_MENTION, mention)
74 73
 		if mention is None:
@@ -81,12 +80,12 @@ class ConfigCog(BaseCog):
81 80
 				mention_author=False)
82 81
 
83 82
 	@config.command(
84
-		name='getwarningmention',
85 83
 		brief='Shows the user/role to mention in warning messages',
86 84
 		description='Shows the text, if any, that will be prefixed on any ' +
87 85
 			'warning messages.'
88 86
 	)
89
-	async def config_getwarningmention(self, context: commands.Context) -> None:
87
+	async def getwarningmention(self, context: commands.Context) -> None:
88
+		'Command handler'
90 89
 		guild: Guild = context.guild
91 90
 		mention: str = Storage.get_config_value(guild, ConfigKey.WARNING_MENTION)
92 91
 		if mention is None:

+ 6
- 4
cogs/crosspostcog.py ファイルの表示

@@ -1,5 +1,5 @@
1 1
 from datetime import datetime, timedelta
2
-from discord import Guild, Member, Message
2
+from discord import Member, Message
3 3
 from discord.ext import commands
4 4
 
5 5
 from cogs.basecog import BaseCog, BotMessage, BotMessageReaction, CogSetting
@@ -8,6 +8,9 @@ from rscollections import AgeBoundList, SizeBoundDict
8 8
 from storage import Storage
9 9
 
10 10
 class SpamContext:
11
+	"""
12
+	Data about a set of duplicate messages from a user.
13
+	"""
11 14
 	def __init__(self, member, message_hash):
12 15
 		self.member = member
13 16
 		self.message_hash = message_hash
@@ -20,7 +23,7 @@ class SpamContext:
20 23
 		self.deleted_messages = set()  # of Message
21 24
 		self.unique_channels = set()  # of TextChannel
22 25
 
23
-class CrossPostCog(BaseCog):
26
+class CrossPostCog(BaseCog, name='Crosspost Detection'):
24 27
 	"""
25 28
 	Detects a user posting the same text in multiple channels in a short period
26 29
 	of time: a common pattern for spammers. Repeated posts in the same channel
@@ -120,7 +123,6 @@ class CrossPostCog(BaseCog):
120 123
 				continue
121 124
 			key = f'{message.author.id}|{message_hash}'
122 125
 			context = spam_lookup.get(key)
123
-			is_new = context is None
124 126
 			if context is None:
125 127
 				context = SpamContext(message.author, message_hash)
126 128
 				spam_lookup[key] = context
@@ -139,7 +141,6 @@ class CrossPostCog(BaseCog):
139 141
 		channel_count = len(context.unique_channels)
140 142
 		if channel_count >= ban_count:
141 143
 			if not context.is_banned:
142
-				count = len(context.spam_messages)
143 144
 				await context.member.ban(
144 145
 					reason='Rocketbot: Posted same message in ' + \
145 146
 						f'{channel_count} channels. Banned by ' + \
@@ -227,6 +228,7 @@ class CrossPostCog(BaseCog):
227 228
 
228 229
 	@commands.Cog.listener()
229 230
 	async def on_message(self, message: Message):
231
+		'Event handler'
230 232
 		if message.author is None or \
231 233
 				message.author.bot or \
232 234
 				message.channel is None or \

+ 29
- 10
cogs/generalcog.py ファイルの表示

@@ -1,14 +1,18 @@
1
+import re
1 2
 from datetime import datetime, timedelta
2
-from discord import Member, Message
3
+from discord import Message
3 4
 from discord.ext import commands
4
-import re
5 5
 
6 6
 from cogs.basecog import BaseCog, BotMessage
7 7
 from config import CONFIG
8 8
 from rbutils import parse_timedelta, describe_timedelta
9 9
 from storage import ConfigKey, Storage
10 10
 
11
-class GeneralCog(BaseCog):
11
+class GeneralCog(BaseCog, name='General'):
12
+	"""
13
+	Cog for handling high-level bot functionality and commands. Should be the
14
+	first cog added to the bot.
15
+	"""
12 16
 	def __init__(self, bot: commands.Bot):
13 17
 		super().__init__(bot)
14 18
 		self.is_connected = False
@@ -16,20 +20,26 @@ class GeneralCog(BaseCog):
16 20
 
17 21
 	@commands.Cog.listener()
18 22
 	async def on_connect(self):
23
+		'Event handler'
19 24
 		print('on_connect')
20 25
 		self.is_connected = True
21 26
 
22 27
 	@commands.Cog.listener()
23 28
 	async def on_ready(self):
29
+		'Event handler'
24 30
 		print('on_ready')
25 31
 		self.is_ready = True
26 32
 
27 33
 	@commands.command(
28
-		brief='Posts a test warning in the configured warning channel.'
34
+		brief='Posts a test warning',
35
+		description='Tests whether a warning channel is configured for this ' + \
36
+			'guild by posting a test warning. If a mod mention is ' + \
37
+			'configured, that user/role will be tagged in the test warning.',
29 38
 	)
30 39
 	@commands.has_permissions(ban_members=True)
31 40
 	@commands.guild_only()
32 41
 	async def testwarn(self, context):
42
+		'Command handler'
33 43
 		if Storage.get_config_value(context.guild, ConfigKey.WARNING_CHANNEL_ID) is None:
34 44
 			await context.message.reply(
35 45
 				f'{CONFIG["warning_emoji"]} No warning channel set!',
@@ -43,30 +53,39 @@ class GeneralCog(BaseCog):
43 53
 
44 54
 	@commands.command(
45 55
 		brief='Simple test reply',
56
+		description='Replies to the command message. Useful to ensure the ' + \
57
+			'bot is working properly.',
46 58
 	)
47 59
 	async def hello(self, context):
60
+		'Command handler'
48 61
 		await context.message.reply(
49 62
 			f'Hey, {context.author.name}!',
50 63
 		 	mention_author=False)
51 64
 
52 65
 	@commands.command(
53
-		brief='Shuts down the bot (admin only)',
66
+		brief='Shuts down the bot',
67
+		description='Causes the bot script to terminate. Only usable by a ' + \
68
+			'user with server admin permissions.',
54 69
 	)
55 70
 	@commands.has_permissions(administrator=True)
56 71
 	@commands.guild_only()
57
-	async def shutdown(self, context):
72
+	async def shutdown(self, context: commands.Context):
73
+		'Command handler'
74
+		await context.message.add_reaction('👋')
58 75
 		await self.bot.close()
59 76
 
60 77
 	@commands.command(
61 78
 		brief='Mass deletes messages',
62
-		description='Deletes recent messages by the given user. The age is ' +
63
-			'a duration, such as "30s", "5m", "1h30m". Messages far back in ' +
64
-			'the scrollback might not be deleted by this command.',
65
-		usage='<user> <age>'
79
+		description='Deletes recent messages by the given user. The user ' +
80
+			'can be either an @ mention or a numeric user ID. The age is ' +
81
+			'a duration, such as "30s", "5m", "1h30m". Only the most ' +
82
+			'recent 100 messages in each channel are searched.',
83
+		usage='<user:id|mention> <age:timespan>'
66 84
 	)
67 85
 	@commands.has_permissions(manage_messages=True)
68 86
 	@commands.guild_only()
69 87
 	async def deletemessages(self, context, user: str, age: str) -> None:
88
+		'Command handler'
70 89
 		member_id = self.__parse_member_id(user)
71 90
 		if member_id is None:
72 91
 			await context.message.reply(

+ 9
- 5
cogs/joinraidcog.py ファイルの表示

@@ -1,7 +1,7 @@
1
+import weakref
1 2
 from datetime import datetime, timedelta
2 3
 from discord import Guild, Member
3 4
 from discord.ext import commands
4
-import weakref
5 5
 
6 6
 from cogs.basecog import BaseCog, BotMessage, BotMessageReaction, CogSetting
7 7
 from config import CONFIG
@@ -9,6 +9,9 @@ from rscollections import AgeBoundList
9 9
 from storage import Storage
10 10
 
11 11
 class JoinRaidContext:
12
+	"""
13
+	Data about a join raid.
14
+	"""
12 15
 	def __init__(self, join_members: list):
13 16
 		self.join_members = list(join_members)
14 17
 		self.kicked_members = set()
@@ -16,9 +19,10 @@ class JoinRaidContext:
16 19
 		self.warning_message_ref = None
17 20
 
18 21
 	def last_join_time(self) -> datetime:
22
+		'Returns when the most recent member join was, in UTC'
19 23
 		return self.join_members[-1].joined_at
20 24
 
21
-class JoinRaidCog(BaseCog):
25
+class JoinRaidCog(BaseCog, name='Join Raids'):
22 26
 	"""
23 27
 	Cog for monitoring member joins and detecting potential bot raids.
24 28
 	"""
@@ -40,7 +44,7 @@ class JoinRaidCog(BaseCog):
40 44
 			min_value=1.0,
41 45
 			max_value=900.0)
42 46
 
43
-	STATE_KEY_RECENT_JOINS = "JoinRaidCog.recent_joins"	
47
+	STATE_KEY_RECENT_JOINS = "JoinRaidCog.recent_joins"
44 48
 	STATE_KEY_LAST_RAID = "JoinRaidCog.last_raid"
45 49
 
46 50
 	def __init__(self, bot):
@@ -55,7 +59,7 @@ class JoinRaidCog(BaseCog):
55 59
 	@commands.has_permissions(ban_members=True)
56 60
 	@commands.guild_only()
57 61
 	async def joinraid(self, context: commands.Context):
58
-		'Command group'
62
+		'Join raid detection command group'
59 63
 		if context.invoked_subcommand is None:
60 64
 			await context.send_help()
61 65
 
@@ -64,7 +68,7 @@ class JoinRaidCog(BaseCog):
64 68
 			reaction: BotMessageReaction,
65 69
 			reacted_by: Member) -> None:
66 70
 		guild: Guild = bot_message.guild
67
-		raid: JoinRaidRecord = bot_message.context
71
+		raid: JoinRaidContext = bot_message.context
68 72
 		if reaction.emoji == CONFIG['kick_emoji']:
69 73
 			to_kick = set(raid.join_members) - raid.kicked_members
70 74
 			for member in to_kick:

+ 133
- 81
cogs/patterncog.py ファイルの表示

@@ -1,8 +1,7 @@
1
+import re
1 2
 from abc import ABC, abstractmethod
2 3
 from discord import Guild, Member, Message
3 4
 from discord.ext import commands
4
-from datetime import timedelta
5
-import re
6 5
 
7 6
 from cogs.basecog import BaseCog, BotMessage, BotMessageReaction
8 7
 from config import CONFIG
@@ -10,42 +9,58 @@ from rbutils import parse_timedelta
10 9
 from storage import Storage
11 10
 
12 11
 class PatternAction:
13
-	def __init__(self, type: str, args: list):
14
-		self.type = type
12
+	"""
13
+	Describes one action to take on a matched message or its author.
14
+	"""
15
+	def __init__(self, action: str, args: list):
16
+		self.action = action
15 17
 		self.arguments = list(args)
16 18
 
17 19
 	def __str__(self) -> str:
18 20
 		arg_str = ', '.join(self.arguments)
19
-		return f'{self.type}({arg_str})'
21
+		return f'{self.action}({arg_str})'
20 22
 
21 23
 class PatternExpression(ABC):
24
+	"""
25
+	Abstract message matching expression.
26
+	"""
22 27
 	def __init__(self):
23 28
 		pass
24 29
 
25 30
 	@abstractmethod
26 31
 	def matches(self, message: Message) -> bool:
32
+		"""
33
+		Whether a message matches this expression.
34
+		"""
27 35
 		return False
28 36
 
29 37
 class PatternSimpleExpression(PatternExpression):
38
+	"""
39
+	Message matching expression with a simple "<field> <operator> <value>"
40
+	structure.
41
+	"""
30 42
 	def __init__(self, field: str, operator: str, value):
43
+		super().__init__()
31 44
 		self.field = field
32 45
 		self.operator = operator
33 46
 		self.value = value
34 47
 
35
-	def matches(self, message: Message) -> bool:
36
-		field_value = None
48
+	def __field_value(self, message: Message):
37 49
 		if self.field == 'content':
38
-			field_value = message.content
39
-		elif self.field == 'author':
40
-			field_value = str(message.author.id)
41
-		elif self.field == 'author.id':
42
-			field_value = str(message.author.id)
43
-		elif self.field == 'author.joinage':
44
-			field_value = message.created_at - message.author.joined_at
45
-		elif self.field == 'author.name':
46
-			field_value = message.author.name
50
+			return message.content
51
+		if self.field == 'author':
52
+			return str(message.author.id)
53
+		if self.field == 'author.id':
54
+			return str(message.author.id)
55
+		if self.field == 'author.joinage':
56
+			return message.created_at - message.author.joined_at
57
+		if self.field == 'author.name':
58
+			return message.author.name
47 59
 		else:
48 60
 			raise ValueError(f'Bad field name {self.field}')
61
+
62
+	def matches(self, message: Message) -> bool:
63
+		field_value = self.__field_value(message)
49 64
 		if self.operator == '==':
50 65
 			if isinstance(field_value, str) and isinstance(self.value, str):
51 66
 				return field_value.lower() == self.value.lower()
@@ -78,35 +93,42 @@ class PatternSimpleExpression(PatternExpression):
78 93
 		return f'({self.field} {self.operator} {self.value})'
79 94
 
80 95
 class PatternCompoundExpression(PatternExpression):
96
+	"""
97
+	Message matching expression that combines several child expressions with
98
+	a boolean operator.
99
+	"""
81 100
 	def __init__(self, operator: str, operands: list):
101
+		super().__init__()
82 102
 		self.operator = operator
83 103
 		self.operands = list(operands)
84 104
 
85 105
 	def matches(self, message: Message) -> bool:
86 106
 		if self.operator == '!':
87 107
 			return not self.operands[0].matches(message)
88
-		elif self.operator == 'and':
108
+		if self.operator == 'and':
89 109
 			for op in self.operands:
90 110
 				if not op.matches(message):
91 111
 					return False
92 112
 			return True
93
-		elif self.operator == 'or':
113
+		if self.operator == 'or':
94 114
 			for op in self.operands:
95 115
 				if op.matches(message):
96 116
 					return True
97 117
 			return False
98
-		else:
99
-			raise RuntimeError(f'Bad operator "{self.operator}"')
118
+		raise RuntimeError(f'Bad operator "{self.operator}"')
100 119
 
101 120
 	def __str__(self) -> str:
102 121
 		if self.operator == '!':
103 122
 			return f'(!( {self.operands[0]} ))'
104
-		else:
105
-			strs = map(str, self.operands)
106
-			joined = f' {self.operator} '.join(strs)
107
-			return f'( {joined} )'
123
+		strs = map(str, self.operands)
124
+		joined = f' {self.operator} '.join(strs)
125
+		return f'( {joined} )'
108 126
 
109 127
 class PatternStatement:
128
+	"""
129
+	A full message match statement. If a message matches the given expression,
130
+	the given actions should be performed.
131
+	"""
110 132
 	def __init__(self, name: str, actions: list, expression: PatternExpression, original: str):
111 133
 		self.name = name
112 134
 		self.actions = list(actions)  # PatternAction[]
@@ -114,6 +136,10 @@ class PatternStatement:
114 136
 		self.original = original
115 137
 
116 138
 class PatternContext:
139
+	"""
140
+	Data about a message that has matched a configured statement and what
141
+	actions have been carried out.
142
+	"""
117 143
 	def __init__(self, message: Message, statement: PatternStatement):
118 144
 		self.message = message
119 145
 		self.statement = statement
@@ -121,20 +147,11 @@ class PatternContext:
121 147
 		self.is_kicked = False
122 148
 		self.is_banned = False
123 149
 
124
-class PatternCog(BaseCog):
125
-	def __init__(self, bot):
126
-		super().__init__(bot)
127
-
128
-	# def __patterns(self, guild: Guild) -> list:
129
-	# 	patterns = Storage.get_state_value(guild, 'pattern_patterns')
130
-	# 	if patterns is None:
131
-	# 		patterns_encoded = Storage.get_config_value(guild, 'pattern_patterns')
132
-	# 		if patterns_encoded:
133
-	# 			patterns = []
134
-	# 			for pe in patterns_encoded:
135
-	# 				patterns.append(Pattern.decode(pe))
136
-	# 			Storage.set_state_value(guild, 'pattern_patterns', patterns)
137
-	# 	return patterns
150
+class PatternCog(BaseCog, name='Pattern Matching'):
151
+	"""
152
+	Highly flexible cog for performing various actions on messages that match
153
+	various critera. Patterns can be defined by mods for each guild.
154
+	"""
138 155
 
139 156
 	def __get_patterns(self, guild: Guild) -> dict:
140 157
 		patterns = Storage.get_state_value(guild, 'PatternCog.patterns')
@@ -148,12 +165,13 @@ class PatternCog(BaseCog):
148 165
 					try:
149 166
 						ps = PatternCompiler.parse_statement(name, statement)
150 167
 						patterns[name] = ps
151
-					except RuntimeError as e:
152
-						self.log(guild, f'Error parsing saved statement "{name}". Skipping: {statement}')
168
+					except Exception as e:
169
+						self.log(guild, f'Error parsing saved statement "{name}". Skipping: {statement}. Error: {e}')
153 170
 			Storage.set_state_value(guild, 'PatternCog.patterns', patterns)
154 171
 		return patterns
155 172
 
156
-	def __save_patterns(self, guild: Guild, patterns: dict) -> None:
173
+	@classmethod
174
+	def __save_patterns(cls, guild: Guild, patterns: dict) -> None:
157 175
 		to_save = []
158 176
 		for name, statement in patterns.items():
159 177
 			to_save.append({
@@ -164,6 +182,7 @@ class PatternCog(BaseCog):
164 182
 
165 183
 	@commands.Cog.listener()
166 184
 	async def on_message(self, message: Message) -> None:
185
+		'Event listener'
167 186
 		if message.author is None or \
168 187
 				message.author.bot or \
169 188
 				message.channel is None or \
@@ -176,7 +195,7 @@ class PatternCog(BaseCog):
176 195
 			return
177 196
 
178 197
 		patterns = self.__get_patterns(message.guild)
179
-		for name, statement in patterns.items():
198
+		for _, statement in patterns.items():
180 199
 			if statement.expression.matches(message):
181 200
 				await self.__trigger_actions(message, statement)
182 201
 				break
@@ -188,7 +207,7 @@ class PatternCog(BaseCog):
188 207
 		action_descriptions = []
189 208
 		self.log(message.guild, f'Message from {message.author.name} matched pattern "{statement.name}"')
190 209
 		for action in statement.actions:
191
-			if action.type == 'ban':
210
+			if action.action == 'ban':
192 211
 				await message.author.ban(
193 212
 					reason=f'Rocketbot: Message matched custom pattern named "{statement.name}"',
194 213
 					delete_message_days=0)
@@ -196,21 +215,21 @@ class PatternCog(BaseCog):
196 215
 				context.is_kicked = True
197 216
 				action_descriptions.append('Author banned')
198 217
 				self.log(message.guild, f'{message.author.name} banned')
199
-			elif action.type == 'delete':
218
+			elif action.action == 'delete':
200 219
 				await message.delete()
201 220
 				context.is_deleted = True
202 221
 				action_descriptions.append('Message deleted')
203 222
 				self.log(message.guild, f'{message.author.name}\'s message deleted')
204
-			elif action.type == 'kick':
223
+			elif action.action == 'kick':
205 224
 				await message.author.kick(
206 225
 					reason=f'Rocketbot: Message matched custom pattern named "{statement.name}"')
207 226
 				context.is_kicked = True
208 227
 				action_descriptions.append('Author kicked')
209 228
 				self.log(message.guild, f'{message.author.name} kicked')
210
-			elif action.type == 'modwarn':
229
+			elif action.action == 'modwarn':
211 230
 				should_alert_mods = True
212 231
 				action_descriptions.append('Mods alerted')
213
-			elif action.type == 'reply':
232
+			elif action.action == 'reply':
214 233
 				await message.reply(
215 234
 					f'{action.arguments[0]}',
216 235
 					mention_author=False)
@@ -240,13 +259,13 @@ class PatternCog(BaseCog):
240 259
 			context.is_deleted = True
241 260
 		elif reaction.emoji == CONFIG['kick_emoji']:
242 261
 			await context.message.author.kick(
243
-				reason=f'Rocketbot: Message matched custom pattern named ' + \
244
-					'"{statement.name}". Kicked by {reacted_by.name}.')
262
+				reason='Rocketbot: Message matched custom pattern named ' + \
263
+					f'"{context.statement.name}". Kicked by {reacted_by.name}.')
245 264
 			context.is_kicked = True
246 265
 		elif reaction.emoji == CONFIG['ban_emoji']:
247 266
 			await context.message.author.ban(
248
-				reason=f'Rocketbot: Message matched custom pattern named ' + \
249
-					'"{statement.name}". Banned by {reacted_by.name}.',
267
+				reason='Rocketbot: Message matched custom pattern named ' + \
268
+					f'"{context.statement.name}". Banned by {reacted_by.name}.',
250 269
 					delete_message_days=1)
251 270
 			context.is_banned = True
252 271
 		await bot_message.set_reactions(BotMessageReaction.standard_set(
@@ -260,7 +279,7 @@ class PatternCog(BaseCog):
260 279
 	@commands.has_permissions(ban_members=True)
261 280
 	@commands.guild_only()
262 281
 	async def pattern(self, context: commands.Context):
263
-		'Message pattern matching'
282
+		'Message pattern matching command group'
264 283
 		if context.invoked_subcommand is None:
265 284
 			await context.send_help()
266 285
 
@@ -273,6 +292,7 @@ class PatternCog(BaseCog):
273 292
 		ignore_extra=True
274 293
 	)
275 294
 	async def add(self, context: commands.Context, name: str):
295
+		'Command handler'
276 296
 		pattern_str = PatternCompiler.expression_str_from_context(context, name)
277 297
 		try:
278 298
 			statement = PatternCompiler.parse_statement(name, pattern_str)
@@ -292,6 +312,7 @@ class PatternCog(BaseCog):
292 312
 		usage='<pattern_name>'
293 313
 	)
294 314
 	async def remove(self, context: commands.Context, name: str):
315
+		'Command handler'
295 316
 		patterns = self.__get_patterns(context.guild)
296 317
 		if patterns.get(name) is not None:
297 318
 			del patterns[name]
@@ -308,6 +329,7 @@ class PatternCog(BaseCog):
308 329
 		brief='Lists all patterns'
309 330
 	)
310 331
 	async def list(self, context: commands.Context) -> None:
332
+		'Command handler'
311 333
 		patterns = self.__get_patterns(context.guild)
312 334
 		if len(patterns) == 0:
313 335
 			await context.message.reply('No patterns defined.', mention_author=False)
@@ -318,6 +340,9 @@ class PatternCog(BaseCog):
318 340
 		await context.message.reply(msg, mention_author=False)
319 341
 
320 342
 class PatternCompiler:
343
+	"""
344
+	Parses a user-provided message filter statement into a PatternStatement.
345
+	"""
321 346
 	TYPE_ID = 'id'
322 347
 	TYPE_MEMBER = 'Member'
323 348
 	TYPE_TEXT = 'text'
@@ -364,6 +389,9 @@ class PatternCompiler:
364 389
 
365 390
 	@classmethod
366 391
 	def expression_str_from_context(cls, context: commands.Context, name: str) -> str:
392
+		"""
393
+		Extracts the statement string from an "add" command context.
394
+		"""
367 395
 		pattern_str = context.message.content
368 396
 		command_chain = [ name ]
369 397
 		cmd = context.command
@@ -380,6 +408,9 @@ class PatternCompiler:
380 408
 
381 409
 	@classmethod
382 410
 	def parse_statement(cls, name: str, statement: str) -> PatternStatement:
411
+		"""
412
+		Parses a user-provided message filter statement into a PatternStatement.
413
+		"""
383 414
 		tokens = cls.tokenize(statement)
384 415
 		token_index = 0
385 416
 		actions, token_index = cls.read_actions(tokens, token_index)
@@ -388,6 +419,9 @@ class PatternCompiler:
388 419
 
389 420
 	@classmethod
390 421
 	def tokenize(cls, statement: str) -> list:
422
+		"""
423
+		Converts a message filter statement into a list of tokens.
424
+		"""
391 425
 		tokens = []
392 426
 		in_quote = False
393 427
 		in_escape = False
@@ -476,6 +510,11 @@ class PatternCompiler:
476 510
 
477 511
 	@classmethod
478 512
 	def read_actions(cls, tokens: list, token_index: int) -> tuple:
513
+		"""
514
+		Reads the actions from a list of statement tokens. Returns a tuple
515
+		containing a list of PatternActions and the token index this method
516
+		left off at (the token after the "if").
517
+		"""
479 518
 		actions = []
480 519
 		current_action_tokens = []
481 520
 		while token_index < len(tokens):
@@ -501,43 +540,47 @@ class PatternCompiler:
501 540
 
502 541
 	@classmethod
503 542
 	def __validate_action(cls, action: PatternAction) -> None:
504
-		args = cls.ACTION_TO_ARGS.get(action.type)
543
+		args = cls.ACTION_TO_ARGS.get(action.action)
505 544
 		if args is None:
506
-			raise RuntimeError(f'Unknown action "{action.type}"')
545
+			raise RuntimeError(f'Unknown action "{action.action}"')
507 546
 		if len(action.arguments) != len(args):
508
-			arg_list = ', '.join(args)
509 547
 			if len(args) == 0:
510
-				raise RuntimeError(f'Action "{action.type}" expects no arguments, got {len(action.arguments)}.')
548
+				raise RuntimeError(f'Action "{action.action}" expects no arguments, ' + \
549
+					f'got {len(action.arguments)}.')
511 550
 			else:
512
-				raise RuntimeError(f'Action "{action.type}" expects {len(args)} arguments, got {len(action.arguments)}.')
513
-		for i in range(len(args)):
514
-			datatype = args[i]
551
+				raise RuntimeError(f'Action "{action.action}" expects {len(args)} ' + \
552
+					f'arguments, got {len(action.arguments)}.')
553
+		for i, datatype in enumerate(args):
515 554
 			action.arguments[i] = cls.parse_value(action.arguments[i], datatype)
516 555
 
517 556
 	@classmethod
518
-	def read_expression(cls, tokens: list, token_index: int, depth: int = 0, one_subexpression: bool = False) -> tuple:
519
-		# field op value
520
-		# (field op value)
521
-		# !(field op value)
522
-		# field op value and field op value
523
-		# (field op value and field op value) or field op value
524
-		indent = '\t' * depth
557
+	def read_expression(cls,
558
+			tokens: list,
559
+			token_index: int,
560
+			depth: int = 0,
561
+			one_subexpression: bool = False) -> tuple:
562
+		"""
563
+		Reads an expression from a list of statement tokens. Returns a tuple
564
+		containing the PatternExpression and the token index it left off at.
565
+		If one_subexpression is True then it will return after reading a
566
+		single expression instead of joining multiples (for readong the
567
+		subject of a NOT expression).
568
+		"""
525 569
 		subexpressions = []
526 570
 		last_compound_operator = None
527 571
 		while token_index < len(tokens):
528 572
 			if one_subexpression:
529 573
 				if len(subexpressions) == 1:
530 574
 					return (subexpressions[0], token_index)
531
-				elif len(subexpressions) > 1:
575
+				if len(subexpressions) > 1:
532 576
 					raise RuntimeError('Too many subexpressions')
533 577
 			compound_operator = None
534 578
 			if tokens[token_index] == ')':
535 579
 				if len(subexpressions) == 0:
536 580
 					raise RuntimeError('No subexpressions')
537
-				elif len(subexpressions) == 1:
581
+				if len(subexpressions) == 1:
538 582
 					return (subexpressions[0], token_index)
539
-				else:
540
-					return (PatternCompoundExpression(last_compound_operator, subexpressions), token_index)
583
+				return (PatternCompoundExpression(last_compound_operator, subexpressions), token_index)
541 584
 			if tokens[token_index] in set(["and", "or"]):
542 585
 				compound_operator = tokens[token_index]
543 586
 				if last_compound_operator and compound_operator != last_compound_operator:
@@ -547,7 +590,8 @@ class PatternCompiler:
547 590
 					last_compound_operator = compound_operator
548 591
 				token_index += 1
549 592
 			if tokens[token_index] == '!':
550
-				(exp, next_index) = cls.read_expression(tokens, token_index + 1, depth + 1, one_subexpression=True)
593
+				(exp, next_index) = cls.read_expression(tokens, token_index + 1, \
594
+						depth + 1, one_subexpression=True)
551 595
 				subexpressions.append(PatternCompoundExpression('!', [exp]))
552 596
 				token_index = next_index
553 597
 			elif tokens[token_index] == '(':
@@ -569,8 +613,13 @@ class PatternCompiler:
569 613
 
570 614
 	@classmethod
571 615
 	def read_simple_expression(cls, tokens: list, token_index: int, depth: int = 0) -> tuple:
572
-		indent = '\t' * depth
573
-
616
+		"""
617
+		Reads a simple expression consisting of a field name, operator, and
618
+		comparison value. Returns a tuple of the PatternSimpleExpression and
619
+		the token index it left off at.
620
+		"""
621
+		if depth > 8:
622
+			raise RuntimeError('Expression nests too deeply')
574 623
 		if token_index >= len(tokens):
575 624
 			raise RuntimeError('Expected field name, found EOL')
576 625
 		field = tokens[token_index]
@@ -609,20 +658,23 @@ class PatternCompiler:
609 658
 		return (exp, token_index)
610 659
 
611 660
 	@classmethod
612
-	def parse_value(cls, value: str, type: str):
613
-		if type == cls.TYPE_ID:
661
+	def parse_value(cls, value: str, datatype: str):
662
+		"""
663
+		Converts a value token to its Python value.
664
+		"""
665
+		if datatype == cls.TYPE_ID:
614 666
 			p = re.compile('^[0-9]+$')
615 667
 			if p.match(value) is None:
616 668
 				raise ValueError(f'Illegal id value "{value}"')
617 669
 			# Store it as a str so it can be larger than an int
618 670
 			return value
619
-		if type == cls.TYPE_MEMBER:
671
+		if datatype == cls.TYPE_MEMBER:
620 672
 			p = re.compile('^<@!?([0-9]+)>$')
621 673
 			m = p.match(value)
622 674
 			if m is None:
623
-				raise ValueError(f'Illegal member value. Must be an @ mention.')
675
+				raise ValueError('Illegal member value. Must be an @ mention.')
624 676
 			return m.group(1)
625
-		if type == cls.TYPE_TEXT:
677
+		if datatype == cls.TYPE_TEXT:
626 678
 			# Must be quoted.
627 679
 			if len(value) < 2 or \
628 680
 					value[0:1] not in cls.STRING_QUOTE_CHARS or \
@@ -630,10 +682,10 @@ class PatternCompiler:
630 682
 					value[0:1] != value[-1:]:
631 683
 				raise ValueError(f'Not a quoted string value: {value}')
632 684
 			return value[1:-1]
633
-		if type == cls.TYPE_INT:
685
+		if datatype == cls.TYPE_INT:
634 686
 			return int(value)
635
-		if type == cls.TYPE_FLOAT:
687
+		if datatype == cls.TYPE_FLOAT:
636 688
 			return float(value)
637
-		if type == cls.TYPE_TIMESPAN:
689
+		if datatype == cls.TYPE_TIMESPAN:
638 690
 			return parse_timedelta(value)
639 691
 		raise ValueError(f'Unhandled datatype {datatype}')

+ 10
- 25
cogs/urlspamcog.py ファイルの表示

@@ -1,20 +1,23 @@
1
-from discord import Guild, Member, Message
2
-from discord.ext import commands
3 1
 import re
4 2
 from datetime import timedelta
3
+from discord import Member, Message
4
+from discord.ext import commands
5 5
 
6 6
 from cogs.basecog import BaseCog, BotMessage, BotMessageReaction, CogSetting
7 7
 from config import CONFIG
8
-from storage import Storage
8
+from rbutils import describe_timedelta
9 9
 
10 10
 class URLSpamContext:
11
+	"""
12
+	Data about a suspected spam message containing a URL.
13
+	"""
11 14
 	def __init__(self, spam_message: Message):
12 15
 		self.spam_message = spam_message
13 16
 		self.is_deleted = False
14 17
 		self.is_kicked = False
15 18
 		self.is_banned = False
16 19
 
17
-class URLSpamCog(BaseCog):
20
+class URLSpamCog(BaseCog, name='URL Spam'):
18 21
 	"""
19 22
 	Detects users posting URLs who just joined recently: a common spam pattern.
20 23
 	Can be configured to take immediate action or just warn the mods.
@@ -50,12 +53,13 @@ class URLSpamCog(BaseCog):
50 53
 	@commands.has_permissions(ban_members=True)
51 54
 	@commands.guild_only()
52 55
 	async def urlspam(self, context: commands.Context):
53
-		'Command group'
56
+		'URL spam command group'
54 57
 		if context.invoked_subcommand is None:
55 58
 			await context.send_help()
56 59
 
57 60
 	@commands.Cog.listener()
58 61
 	async def on_message(self, message: Message):
62
+		'Event listener'
59 63
 		if message.author is None or \
60 64
 				message.author.bot or \
61 65
 				message.guild is None or \
@@ -73,7 +77,7 @@ class URLSpamCog(BaseCog):
73 77
 		if not self.__contains_url(message.content):
74 78
 			return
75 79
 		join_age = message.created_at - message.author.joined_at
76
-		join_age_str = self.__format_timedelta(join_age)
80
+		join_age_str = describe_timedelta(join_age)
77 81
 		if join_age < min_join_age:
78 82
 			context = URLSpamContext(message)
79 83
 			needs_attention = False
@@ -165,22 +169,3 @@ class URLSpamCog(BaseCog):
165 169
 	def __contains_url(cls, text: str) -> bool:
166 170
 		p = re.compile(r'http(?:s)?://[^\s]+')
167 171
 		return p.search(text) is not None
168
-
169
-	@classmethod
170
-	def __format_timedelta(cls, timespan: timedelta) -> str:
171
-		parts = []
172
-		d = timespan.days
173
-		h = timespan.seconds // 3600
174
-		m = (timespan.seconds // 60) % 60
175
-		s = timespan.seconds % 60
176
-		if d > 0:
177
-			parts.append(f'{d} days')
178
-		if d > 0 or h > 0:
179
-			parts.append(f'{h} hours')
180
-		if d > 0 or h > 0 or m > 0:
181
-			parts.append(f'{m} minutes')
182
-		parts.append(f'{s} seconds')
183
-		# Limit the precision to the two most significant elements
184
-		while len(parts) > 2:
185
-			parts.pop(-1)
186
-		return ' '.join(parts)

+ 20
- 26
rbutils.py ファイルの表示

@@ -1,5 +1,5 @@
1
-from datetime import timedelta
2 1
 import re
2
+from datetime import timedelta
3 3
 
4 4
 def parse_timedelta(s: str) -> timedelta:
5 5
 	"""
@@ -32,28 +32,22 @@ def parse_timedelta(s: str) -> timedelta:
32 32
 	return timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)
33 33
 
34 34
 def describe_timedelta(td: timedelta, max_components: int = 2) -> str:
35
-	values = [
36
-		td.days,
37
-		td.seconds // 3600,
38
-		(td.seconds // 60) % 60,
39
-		td.seconds % 60,
40
-	]
41
-	units = [
42
-		'day' if values[0] == 1 else 'days',
43
-		'hour' if values[1] == 1 else 'hours',
44
-		'minute' if values[2] == 1 else 'minutes',
45
-		'second' if values[3] == 1 else 'seconds',
46
-	]
47
-	while len(values) > 1 and values[0] == 0:
48
-		values.pop(0)
49
-		units.pop(0)
50
-	if len(values) > max_components:
51
-		values = values[0:max_components]
52
-		units = units[0:max_components]
53
-	while len(values) > 1 and values[-1] == 0:
54
-		values.pop(-1)
55
-		units.pop(-1)
56
-	tokens = []
57
-	for i in range(len(values)):
58
-		tokens.append(f'{values[i]} {units[i]}')
59
-	return ' '.join(tokens)
35
+	"""
36
+	Formats a human-readable description of a time span. E.g. "3 days 2 hours".
37
+	"""
38
+	d = td.days
39
+	h = td.seconds // 3600
40
+	m = (td.seconds // 60) % 60
41
+	s = td.seconds % 60
42
+	components = []
43
+	if d != 0:
44
+		components.append('1 day' if d == 1 else f'{d} days')
45
+	if h != 0:
46
+		components.append('1 hour' if h == 1 else f'{h} hours')
47
+	if m != 0:
48
+		components.append('1 minute' if m == 1 else f'{m} minutes')
49
+	if s != 0 or len(components) == 0:
50
+		components.append('1 second' if s == 1 else f'{s} seconds')
51
+	if len(components) > max_components:
52
+		components = components[0:max_components]
53
+	return ' '.join(components)

+ 12
- 9
rocketbot.py ファイルの表示

@@ -18,21 +18,24 @@ from cogs.urlspamcog import URLSpamCog
18 18
 
19 19
 CURRENT_CONFIG_VERSION = 3
20 20
 if (CONFIG.get('__config_version') or 0) < CURRENT_CONFIG_VERSION:
21
-    # If you're getting this error, it means something changed in config.py's
22
-    # format. Consult config.py.sample and compare it to your own config.py.
23
-    # Rename/move any values as needed. When satisfied, update "__config_version"
24
-    # to the value in config.py.sample.
25
-    raise RuntimeError('config.py format may be outdated. Review ' +
26
-        'config.py.sample, update the "__config_version" field to ' +
27
-        f'{CURRENT_CONFIG_VERSION}, and try again.')
21
+	# If you're getting this error, it means something changed in config.py's
22
+	# format. Consult config.py.sample and compare it to your own config.py.
23
+	# Rename/move any values as needed. When satisfied, update "__config_version"
24
+	# to the value in config.py.sample.
25
+	raise RuntimeError('config.py format may be outdated. Review ' +
26
+		'config.py.sample, update the "__config_version" field to ' +
27
+		f'{CURRENT_CONFIG_VERSION}, and try again.')
28 28
 
29 29
 class Rocketbot(commands.Bot):
30
+	"""
31
+	Bot subclass
32
+	"""
30 33
 	def __init__(self, command_prefix, **kwargs):
31 34
 		super().__init__(command_prefix, **kwargs)
32 35
 
33 36
 intents = Intents.default()
34
-intents.messages = True
35
-intents.members = True # To get join/leave events
37
+intents.messages = True  # pylint: disable=assigning-non-slot
38
+intents.members = True  # pylint: disable=assigning-non-slot
36 39
 bot = Rocketbot(command_prefix=CONFIG['command_prefix'], intents=intents)
37 40
 
38 41
 # Core

+ 31
- 15
rscollections.py ファイルの表示

@@ -1,9 +1,9 @@
1
-from abc import ABC, abstractmethod
2
-
3 1
 """
4 2
 Subclasses of list, set, and dict with special behaviors.
5 3
 """
6 4
 
5
+from abc import ABC, abstractmethod
6
+
7 7
 # Abstract collections
8 8
 
9 9
 class AbstractMutableList(list, ABC):
@@ -247,7 +247,10 @@ class SizeBoundList(AbstractMutableList):
247 247
 	however elements will only be discarded following the next mutating
248 248
 	operation. Call `self.purge_old_elements()` to force resizing.
249 249
 	"""
250
-	def __init__(self, max_element_count: int, element_age, *args, **kwargs):
250
+	def __init__(self,
251
+			max_element_count: int,
252
+			element_age,
253
+			*args, **kwargs):
251 254
 		super().__init__(*args, **kwargs)
252 255
 		self.element_age = element_age
253 256
 		self.max_element_count = max_element_count
@@ -272,8 +275,7 @@ class SizeBoundList(AbstractMutableList):
272 275
 		while len(self) > self.max_element_count:
273 276
 			oldest_age = None
274 277
 			oldest_index = -1
275
-			for i in range(len(self)):
276
-				elem = self[i]
278
+			for i, elem in enumerate(self):
277 279
 				age = self.element_age(i, elem)
278 280
 				if oldest_age is None or age < oldest_age:
279 281
 					oldest_age = age
@@ -281,8 +283,8 @@ class SizeBoundList(AbstractMutableList):
281 283
 			self.pop(oldest_index)
282 284
 		self.is_culling = False
283 285
 
284
-	def copy():
285
-		return SizeBoundList(max_element_count, element_age, super())
286
+	def copy(self):
287
+		return SizeBoundList(self.max_element_count, self.element_age, super())
286 288
 
287 289
 class SizeBoundSet(AbstractMutableSet):
288 290
 	"""
@@ -303,7 +305,10 @@ class SizeBoundSet(AbstractMutableSet):
303 305
 	however elements will only be discarded following the next mutating
304 306
 	operation. Call `self.purge_old_elements()` to force resizing.
305 307
 	"""
306
-	def __init__(self, max_element_count: int, element_age, *args, **kwargs):
308
+	def __init__(self,
309
+			max_element_count: int,
310
+			element_age,
311
+			*args, **kwargs):
307 312
 		super().__init__(*args, **kwargs)
308 313
 		self.element_age = element_age
309 314
 		self.max_element_count = max_element_count
@@ -335,8 +340,8 @@ class SizeBoundSet(AbstractMutableSet):
335 340
 			self.remove(oldest_elem)
336 341
 		self.is_culling = False
337 342
 
338
-	def copy():
339
-		return SizeBoundSet(max_element_count, element_age, super())
343
+	def copy(self):
344
+		return SizeBoundSet(self.max_element_count, self.element_age, super())
340 345
 
341 346
 class SizeBoundDict(AbstractMutableDict):
342 347
 	"""
@@ -357,7 +362,10 @@ class SizeBoundDict(AbstractMutableDict):
357 362
 	however elements will only be discarded following the next mutating
358 363
 	operation. Call `self.purge_old_elements()` to force resizing.
359 364
 	"""
360
-	def __init__(self, max_element_count = 10, element_age = (lambda key, value: float(key)), *args, **kwargs):
365
+	def __init__(self,
366
+			max_element_count: int,
367
+			element_age,
368
+			*args, **kwargs):
361 369
 		super().__init__(*args, **kwargs)
362 370
 		self.element_age = element_age
363 371
 		self.max_element_count = max_element_count
@@ -389,8 +397,8 @@ class SizeBoundDict(AbstractMutableDict):
389 397
 			del self[oldest_key]
390 398
 		self.is_culling = False
391 399
 
392
-	def copy():
393
-		return SizeBoundDict(max_element_count, element_age, super())
400
+	def copy(self):
401
+		return SizeBoundDict(self.max_element_count, self.element_age, super())
394 402
 
395 403
 # Collections with limited age of elements
396 404
 
@@ -436,8 +444,7 @@ class AgeBoundList(AbstractMutableList):
436 444
 		min_age = None
437 445
 		max_age = None
438 446
 		ages = {}
439
-		for i in range(len(self)):
440
-			elem = self[i]
447
+		for i, elem in enumerate(self):
441 448
 			age = self.element_age(i, elem)
442 449
 			ages[i] = age
443 450
 			if min_age is None or age < min_age:
@@ -453,6 +460,9 @@ class AgeBoundList(AbstractMutableList):
453 460
 				del self[i]
454 461
 		self.is_culling = False
455 462
 
463
+	def copy(self):
464
+		return AgeBoundList(self.max_age, self.element_age, super())
465
+
456 466
 class AgeBoundSet(AbstractMutableSet):
457 467
 	"""
458 468
 	Subclass of `set` that enforces a maximum "age" of elements.
@@ -511,6 +521,9 @@ class AgeBoundSet(AbstractMutableSet):
511 521
 				self.remove(elem)
512 522
 		self.is_culling = False
513 523
 
524
+	def copy(self):
525
+		return AgeBoundSet(self.max_age, self.element_age, super())
526
+
514 527
 class AgeBoundDict(AbstractMutableDict):
515 528
 	"""
516 529
 	Subclass of `dict` that enforces a maximum "age" of elements.
@@ -568,3 +581,6 @@ class AgeBoundDict(AbstractMutableDict):
568 581
 			if age < cutoff:
569 582
 				del self[key]
570 583
 		self.is_culling = False
584
+
585
+	def copy(self):
586
+		return AgeBoundDict(self.max_age, self.element_age, super())

+ 19
- 14
storage.py ファイルの表示

@@ -1,11 +1,16 @@
1
+"""
2
+Handles storage of persisted and non-persisted data for the bot.
3
+"""
1 4
 import json
2 5
 from os.path import exists
3
-
4 6
 from discord import Guild
5 7
 
6 8
 from config import CONFIG
7 9
 
8 10
 class ConfigKey:
11
+	"""
12
+	Common keys in persisted guild storage.
13
+	"""
9 14
 	WARNING_CHANNEL_ID = 'warning_channel_id'
10 15
 	WARNING_MENTION = 'warning_mention'
11 16
 
@@ -47,16 +52,16 @@ class Storage:
47 52
 		cls.set_state_values(guild, { key: value })
48 53
 
49 54
 	@classmethod
50
-	def set_state_values(cls, guild: Guild, vars: dict) -> None:
55
+	def set_state_values(cls, guild: Guild, values: dict) -> None:
51 56
 		"""
52 57
 		Merges in a set of key-value pairs into the transient state for the
53 58
 		given guild. Any pairs with a value of `None` will be removed from the
54 59
 		transient state.
55 60
 		"""
56
-		if vars is None or len(vars) == 0:
61
+		if values is None or len(values) == 0:
57 62
 			return
58 63
 		state: dict = cls.get_state(guild)
59
-		for key, value in vars.items():
64
+		for key, value in values.items():
60 65
 			if value is None:
61 66
 				del state[key]
62 67
 			else:
@@ -105,21 +110,21 @@ class Storage:
105 110
 		cls.set_config_values(guild, { key: value })
106 111
 
107 112
 	@classmethod
108
-	def set_config_values(cls, guild: Guild, vars: dict) -> None:
113
+	def set_config_values(cls, guild: Guild, values: dict) -> None:
109 114
 		"""
110
-		Merges the given `vars` dict with the saved config for the given guild
111
-		and writes it to disk. `vars` must be JSON-encodable or a `ValueError`
115
+		Merges the given `values` dict with the saved config for the given guild
116
+		and writes it to disk. `values` must be JSON-encodable or a `ValueError`
112 117
 		will be raised. Keys with associated values of `None` will be removed
113 118
 		from the persisted config.
114 119
 		"""
115
-		if vars is None or len(vars) == 0:
120
+		if values is None or len(values) == 0:
116 121
 			return
117 122
 		config: dict = cls.get_config(guild)
118 123
 		try:
119
-			json.dumps(vars)
120
-		except:
121
-			raise ValueError(f'vars not JSON encodable - {vars}')
122
-		for key, value in vars.items():
124
+			json.dumps(values)
125
+		except Exception as e:
126
+			raise ValueError(f'values not JSON encodable - {values}') from e
127
+		for key, value in values.items():
123 128
 			if value is None:
124 129
 				del config[key]
125 130
 			else:
@@ -134,7 +139,7 @@ class Storage:
134 139
 		path: str = cls.__guild_config_path(guild)
135 140
 		cls.__trace(f'Saving config for guild {guild.id} to {path}')
136 141
 		cls.__trace(f'config = {config}')
137
-		with open(path, 'w') as file:
142
+		with open(path, 'w', encoding='utf8') as file:
138 143
 			# Pretty printing to make more legible for debugging
139 144
 			# Sorting keys to help with diffs
140 145
 			json.dump(config, file, indent='\t', sort_keys=True)
@@ -151,7 +156,7 @@ class Storage:
151 156
 			cls.__trace(f'No config on disk for guild {guild.id}. Returning None.')
152 157
 			return None
153 158
 		cls.__trace(f'Loading config from disk for guild {guild.id}')
154
-		with open(path, 'r') as file:
159
+		with open(path, 'r', encoding='utf8') as file:
155 160
 			config = json.load(file)
156 161
 		cls.__trace('State loaded')
157 162
 		return config

読み込み中…
キャンセル
保存