Pārlūkot izejas kodu

Llllllots of cleanup after switching to VSCode

master
Rocketsoup 4 gadus atpakaļ
vecāks
revīzija
192da063ec
12 mainītis faili ar 375 papildinājumiem un 301 dzēšanām
  1. 6
    0
      .pylintrc
  2. 86
    97
      cogs/basecog.py
  3. 14
    15
      cogs/configcog.py
  4. 6
    4
      cogs/crosspostcog.py
  5. 29
    10
      cogs/generalcog.py
  6. 9
    5
      cogs/joinraidcog.py
  7. 133
    81
      cogs/patterncog.py
  8. 10
    25
      cogs/urlspamcog.py
  9. 20
    26
      rbutils.py
  10. 12
    9
      rocketbot.py
  11. 31
    15
      rscollections.py
  12. 19
    14
      storage.py

+ 6
- 0
.pylintrc Parādīt failu

1
+[MESSAGES CONTROL]
2
+disable=bad-indentation, invalid-name
3
+
4
+[FORMAT]
5
+
6
+indent-string=' '

+ 86
- 97
cogs/basecog.py Parādīt failu

1
+from datetime import datetime, timedelta
1
 from discord import Guild, Member, Message, PartialEmoji, RawReactionActionEvent, TextChannel
2
 from discord import Guild, Member, Message, PartialEmoji, RawReactionActionEvent, TextChannel
3
+from discord.abc import GuildChannel
2
 from discord.ext import commands
4
 from discord.ext import commands
3
-from datetime import datetime, timedelta
4
 
5
 
5
-from abc import ABC, abstractmethod
6
 from config import CONFIG
6
 from config import CONFIG
7
 from rscollections import AgeBoundDict
7
 from rscollections import AgeBoundDict
8
 from storage import ConfigKey, Storage
8
 from storage import ConfigKey, Storage
9
-import json
10
 
9
 
11
 class BotMessageReaction:
10
 class BotMessageReaction:
12
 	"""
11
 	"""
136
 		return self.__message is not None
135
 		return self.__message is not None
137
 
136
 
138
 	def message_id(self):
137
 	def message_id(self):
138
+		'Returns the Message id or None if not sent.'
139
 		return self.__message.id if self.__message else None
139
 		return self.__message.id if self.__message else None
140
 
140
 
141
-	def message_sent_at(self):
141
+	def message_sent_at(self) -> datetime:
142
+		'Returns when the message was sent or None if not sent.'
142
 		return self.__message.created_at if self.__message else None
143
 		return self.__message.created_at if self.__message else None
143
 
144
 
145
+	def has_reactions(self) -> bool:
146
+		return len(self.__reactions) > 0
147
+
144
 	async def set_text(self, new_text: str) -> None:
148
 	async def set_text(self, new_text: str) -> None:
145
 		"""
149
 		"""
146
 		Replaces the text of this message. If the message has been sent, it will
150
 		Replaces the text of this message. If the message has been sent, it will
147
 		be updated.
151
 		be updated.
148
 		"""
152
 		"""
149
 		self.text = new_text
153
 		self.text = new_text
150
-		await self.__update_if_sent()
154
+		await self.update_if_sent()
151
 
155
 
152
 	async def set_reactions(self, reactions: list) -> None:
156
 	async def set_reactions(self, reactions: list) -> None:
153
 		"""
157
 		"""
158
 			# No change
162
 			# No change
159
 			return
163
 			return
160
 		self.__reactions = reactions.copy() if reactions is not None else []
164
 		self.__reactions = reactions.copy() if reactions is not None else []
161
-		await self.__update_if_sent()
165
+		await self.update_if_sent()
162
 
166
 
163
 	async def add_reaction(self, reaction: BotMessageReaction) -> None:
167
 	async def add_reaction(self, reaction: BotMessageReaction) -> None:
164
 		"""
168
 		"""
175
 		will be updated.
179
 		will be updated.
176
 		"""
180
 		"""
177
 		found = False
181
 		found = False
178
-		for i in range(len(self.__reactions)):
179
-			existing = self.__reactions[i]
182
+		for i, existing in enumerate(self.__reactions):
180
 			if existing.emoji == reaction.emoji:
183
 			if existing.emoji == reaction.emoji:
181
 				if reaction == self.__reactions[i]:
184
 				if reaction == self.__reactions[i]:
182
 					# No change
185
 					# No change
186
 				break
189
 				break
187
 		if not found:
190
 		if not found:
188
 			self.__reactions.append(reaction)
191
 			self.__reactions.append(reaction)
189
-		await self.__update_if_sent()
192
+		await self.update_if_sent()
190
 
193
 
191
 	async def remove_reaction(self, reaction_or_emoji) -> None:
194
 	async def remove_reaction(self, reaction_or_emoji) -> None:
192
 		"""
195
 		"""
193
 		Removes a reaction. Can pass either a BotMessageReaction or just the
196
 		Removes a reaction. Can pass either a BotMessageReaction or just the
194
 		emoji string. If the message has been sent, it will be updated.
197
 		emoji string. If the message has been sent, it will be updated.
195
 		"""
198
 		"""
196
-		for i in range(len(self.__reactions)):
197
-			existing = self.__reactions[i]
199
+		for i, existing in enumerate(self.__reactions):
198
 			if (isinstance(reaction_or_emoji, str) and existing.emoji == reaction_or_emoji) or \
200
 			if (isinstance(reaction_or_emoji, str) and existing.emoji == reaction_or_emoji) or \
199
-				(isinstance(reaction_or_emoji, BotMessageReaction) and existing.emoji == reaction_or_emoji.emoji):
201
+				(isinstance(reaction_or_emoji, BotMessageReaction) and \
202
+					existing.emoji == reaction_or_emoji.emoji):
200
 				self.__reactions.pop(i)
203
 				self.__reactions.pop(i)
201
-				await self.__update_if_sent()
204
+				await self.update_if_sent()
202
 				return
205
 				return
203
 
206
 
204
 	def reaction_for_emoji(self, emoji) -> BotMessageReaction:
207
 	def reaction_for_emoji(self, emoji) -> BotMessageReaction:
208
+		"""
209
+		Finds the BotMessageReaction for the given emoji or None if not found.
210
+		Accepts either a PartialEmoji or str.
211
+		"""
205
 		for reaction in self.__reactions:
212
 		for reaction in self.__reactions:
206
 			if isinstance(emoji, PartialEmoji) and reaction.emoji == emoji.name:
213
 			if isinstance(emoji, PartialEmoji) and reaction.emoji == emoji.name:
207
 				return reaction
214
 				return reaction
208
-			elif isinstance(emoji, str) and reaction.emoji == emoji:
215
+			if isinstance(emoji, str) and reaction.emoji == emoji:
209
 				return reaction
216
 				return reaction
210
 		return None
217
 		return None
211
 
218
 
212
-	async def __update_if_sent(self) -> None:
219
+	async def update_if_sent(self) -> None:
220
+		"""
221
+		Updates the text and/or reactions on a message if it was sent to
222
+		the guild, otherwise does nothing. Does not need to be called by
223
+		BaseCog subclasses.
224
+		"""
213
 		if self.__message:
225
 		if self.__message:
214
-			await self._update()
226
+			await self.update()
215
 
227
 
216
-	async def _update(self) -> None:
228
+	async def update(self) -> None:
229
+		"""
230
+		Sends or updates an already sent message based on BotMessage state.
231
+		Does not need to be called by BaseCog subclasses.
232
+		"""
217
 		content: str = self.__formatted_message()
233
 		content: str = self.__formatted_message()
218
 		if self.__message:
234
 		if self.__message:
219
 			if content != self.__posted_text:
235
 			if content != self.__posted_text:
278
 		return s
294
 		return s
279
 
295
 
280
 class CogSetting:
296
 class CogSetting:
297
+	"""
298
+	Describes a configuration setting for a guild that can be edited by the
299
+	mods of those guilds. BaseCog can generate "get" and "set" commands
300
+	automatically, reducing the boilerplate of generating commands manually.
301
+	Offers simple validation rules.
302
+	"""
281
 	def __init__(self,
303
 	def __init__(self,
282
 			name: str,
304
 			name: str,
283
 			datatype,
305
 			datatype,
287
 			min_value = None,
309
 			min_value = None,
288
 			max_value = None,
310
 			max_value = None,
289
 			enum_values: set = None):
311
 			enum_values: set = None):
312
+		"""
313
+		Params:
314
+		- name         Setting identifier. Must follow variable naming
315
+		               conventions.
316
+		- datatype     Datatype of the setting. E.g. int, float, str
317
+		- brief        Description of the setting, starting with lower case.
318
+		               Will be inserted into phrases like "Sets <brief>" and
319
+					   "Gets <brief".
320
+		- description  Long-form description. Min, max, and enum values will be
321
+		               appended to the end, so does not need to include these.
322
+		- usage        Description of the value argument in a set command, e.g.
323
+		               "<maxcount:int>"
324
+		- min_value    Smallest allowable value. Must be of the same datatype as
325
+		               the value. None for no minimum.
326
+		- max_value    Largest allowable value. None for no maximum.
327
+		- enum_values  Set of allowed values. None if unconstrained.
328
+		"""
290
 		self.name = name
329
 		self.name = name
291
 		self.datatype = datatype
330
 		self.datatype = datatype
292
 		self.brief = brief
331
 		self.brief = brief
308
 			self.usage = f'<{self.name}>'
347
 			self.usage = f'<{self.name}>'
309
 
348
 
310
 class BaseCog(commands.Cog):
349
 class BaseCog(commands.Cog):
350
+	"""
351
+	Superclass for all Rocketbot cogs. Provides lots of conveniences for
352
+	common tasks.
353
+	"""
311
 	def __init__(self, bot):
354
 	def __init__(self, bot):
312
 		self.bot = bot
355
 		self.bot = bot
313
 		self.are_settings_setup = False
356
 		self.are_settings_setup = False
319
 	def get_cog_default(cls, key: str):
362
 	def get_cog_default(cls, key: str):
320
 		"""
363
 		"""
321
 		Convenience method for getting a cog configuration default from
364
 		Convenience method for getting a cog configuration default from
322
-		`CONFIG['cogs'][<cogname>][<key>]`.
365
+		`CONFIG['cogs'][<cogname>][<key>]`. These values are used for
366
+		CogSettings when no guild-specific value is configured yet.
323
 		"""
367
 		"""
324
 		cogs: dict = CONFIG['cog_defaults']
368
 		cogs: dict = CONFIG['cog_defaults']
325
 		cog = cogs.get(cls.__name__)
369
 		cog = cogs.get(cls.__name__)
337
 		If the cog has a command group it will be detected automatically and
381
 		If the cog has a command group it will be detected automatically and
338
 		the commands added to that. Otherwise the commands will be added at
382
 		the commands added to that. Otherwise the commands will be added at
339
 		the top level.
383
 		the top level.
384
+
385
+		Changes to settings can be detected by overriding `on_setting_updated`.
340
 		"""
386
 		"""
341
 		self.settings.append(setting)
387
 		self.settings.append(setting)
342
 
388
 
365
 		"""
411
 		"""
366
 		Manually sets a setting for the given guild. BaseCog creates "get" and
412
 		Manually sets a setting for the given guild. BaseCog creates "get" and
367
 		"set" commands for guild administrators to configure values themselves,
413
 		"set" commands for guild administrators to configure values themselves,
368
-		but this method can be used for hidden settings from code.
414
+		but this method can be used for hidden settings from code. A ValueError
415
+		will be raised if the new value does not pass validation specified in
416
+		the CogSetting.
369
 		"""
417
 		"""
418
+		if setting.min_value is not None and new_value < setting.min_value:
419
+			raise ValueError(f'{setting.name} must be at least {setting.min_value}')
420
+		if setting.max_value is not None and new_value > setting.max_value:
421
+			raise ValueError(f'{setting.name} must be no more than {setting.max_value}')
422
+		if setting.enum_values and new_value not in setting.enum_values:
423
+			raise ValueError(f'{setting.name} must be one of {setting.enum_values}')
370
 		key = f'{cls.__name__}.{setting.name}'
424
 		key = f'{cls.__name__}.{setting.name}'
371
 		Storage.set_config_value(guild, key, new_value)
425
 		Storage.set_config_value(guild, key, new_value)
372
 
426
 
373
 	@commands.Cog.listener()
427
 	@commands.Cog.listener()
374
 	async def on_ready(self):
428
 	async def on_ready(self):
429
+		'Event listener'
375
 		self.__set_up_setting_commands()
430
 		self.__set_up_setting_commands()
376
 
431
 
377
 	def __set_up_setting_commands(self):
432
 	def __set_up_setting_commands(self):
406
 		under the bot if `None`.
461
 		under the bot if `None`.
407
 		"""
462
 		"""
408
 		# Manually constructing equivalent of:
463
 		# Manually constructing equivalent of:
409
-		# 	@commands.command(
410
-		# 		brief='Posts a test warning in the configured warning channel.'
411
-		# 	)
464
+		# 	@commands.command()
412
 		# 	@commands.has_permissions(ban_members=True)
465
 		# 	@commands.has_permissions(ban_members=True)
413
 		# 	@commands.guild_only()
466
 		# 	@commands.guild_only()
414
 		# 	async def getvar(self, context):
467
 		# 	async def getvar(self, context):
573
 		"""
626
 		"""
574
 		Subclass override point for being notified when a CogSetting is edited.
627
 		Subclass override point for being notified when a CogSetting is edited.
575
 		"""
628
 		"""
576
-		pass
577
 
629
 
578
 	# Bot message handling
630
 	# Bot message handling
579
 
631
 
588
 		return bm
640
 		return bm
589
 
641
 
590
 	async def post_message(self, message: BotMessage) -> bool:
642
 	async def post_message(self, message: BotMessage) -> bool:
643
+		"""
644
+		Posts a BotMessage to a guild. Returns whether it was successful. If
645
+		the caller wants to listen to reactions they should be added before
646
+		calling this method. Listen to reactions by overriding `on_mod_react`.
647
+		"""
591
 		message.source_cog = self
648
 		message.source_cog = self
592
-		await message._update()
593
-		guild_messages = self.__bot_messages(message.guild)
594
-		if message.is_sent():
649
+		await message.update()
650
+		if message.has_reactions() and message.is_sent():
651
+			guild_messages = self.__bot_messages(message.guild)
595
 			guild_messages[message.message_id()] = message
652
 			guild_messages[message.message_id()] = message
596
-			return True
597
-		return False
653
+		return message.is_sent()
598
 
654
 
599
 	@commands.Cog.listener()
655
 	@commands.Cog.listener()
600
 	async def on_raw_reaction_add(self, payload: RawReactionActionEvent):
656
 	async def on_raw_reaction_add(self, payload: RawReactionActionEvent):
645
 		Subclass override point for receiving mod reactions to bot messages sent
701
 		Subclass override point for receiving mod reactions to bot messages sent
646
 		via `post_message()`.
702
 		via `post_message()`.
647
 		"""
703
 		"""
648
-		pass
649
 
704
 
650
 	# Helpers
705
 	# Helpers
651
 
706
 
652
 	@classmethod
707
 	@classmethod
653
-	async def validate_param(cls, context: commands.Context, param_name: str, value,
654
-		allowed_types: tuple = None,
655
-		min_value = None,
656
-		max_value = None) -> bool:
657
-		"""
658
-		Convenience method for validating a command parameter is of the expected
659
-		type and in the expected range. Bad values will cause a reply to be sent
660
-		to the original message and a False will be returned. If all checks
661
-		succeed, True will be returned.
662
-		"""
663
-		# TODO: Rework this to use BotMessage
664
-		if allowed_types is not None and not isinstance(value, allowed_types):
665
-			if len(allowed_types) == 1:
666
-				await context.message.reply(f'⚠️ `{param_name}` must be of type ' +
667
-					f'{allowed_types[0]}.', mention_author=False)
668
-			else:
669
-				await context.message.reply(f'⚠️ `{param_name}` must be of types ' +
670
-					f'{allowed_types}.', mention_author=False)
671
-			return False
672
-		if min_value is not None and value < min_value:
673
-			await context.message.reply(f'⚠️ `{param_name}` must be >= {min_value}.',
674
-				mention_author=False)
675
-			return False
676
-		if max_value is not None and value > max_value:
677
-			await context.message.reply(f'⚠️ `{param_name}` must be <= {max_value}.',
678
-				mention_author=False)
679
-		return True
680
-
681
-	@classmethod
682
-	async def warn(cls, guild: Guild, message: str) -> Message:
683
-		"""
684
-		DEPRECATED. Use post_message.
685
-
686
-		Sends a warning message to the configured warning channel for the
687
-		given guild. If no warning channel is configured no action is taken.
688
-		Returns the Message if successful or None if not.
689
-		"""
690
-		channel_id = Storage.get_config_value(guild, ConfigKey.WARNING_CHANNEL_ID)
691
-		if channel_id is None:
692
-			cls.log(guild, '\u0007No warning channel set! No warning issued.')
693
-			return None
694
-		channel: TextChannel = guild.get_channel(channel_id)
695
-		if channel is None:
696
-			cls.log(guild, '\u0007Configured warning channel does not exist!')
697
-			return None
698
-		mention: str = Storage.get_config_value(guild, ConfigKey.WARNING_MENTION)
699
-		text: str = message
700
-		if mention is not None:
701
-			text = f'{mention} {text}'
702
-		msg: Message = await channel.send(text)
703
-		return msg
704
-
705
-	@classmethod
706
-	async def update_warn(cls, warn_message: Message, new_text: str) -> None:
708
+	def log(cls, guild: Guild, message) -> None:
707
 		"""
709
 		"""
708
-		DEPRECATED. Use post_message.
709
-
710
-		Updates the text of a previously posted `warn`. Includes configured
711
-		mentions if necessary.
710
+		Writes a message to the console. Intended for significant events only.
712
 		"""
711
 		"""
713
-		text: str = new_text
714
-		mention: str = Storage.get_config_value(
715
-			warn_message.guild,
716
-			ConfigKey.WARNING_MENTION)
717
-		if mention is not None:
718
-			text = f'{mention} {text}'
719
-		await warn_message.edit(content=text)
720
-
721
-	@classmethod
722
-	def log(cls, guild: Guild, message) -> None:
723
 		now = datetime.now()
712
 		now = datetime.now()
724
 		print(f'[{now.strftime("%Y-%m-%dT%H:%M:%S")}|{cls.__name__}|{guild.name}] {message}')
713
 		print(f'[{now.strftime("%Y-%m-%dT%H:%M:%S")}|{cls.__name__}|{guild.name}] {message}')

+ 14
- 15
cogs/configcog.py Parādīt failu

5
 from storage import ConfigKey, Storage
5
 from storage import ConfigKey, Storage
6
 from cogs.basecog import BaseCog
6
 from cogs.basecog import BaseCog
7
 
7
 
8
-class ConfigCog(BaseCog):
8
+class ConfigCog(BaseCog, name='Configuration'):
9
 	"""
9
 	"""
10
 	Cog for handling general bot configuration.
10
 	Cog for handling general bot configuration.
11
 	"""
11
 	"""
12
-	def __init__(self, bot):
13
-		super().__init__(bot)
14
-
15
 	@commands.group(
12
 	@commands.group(
16
 		brief='Manages general bot configuration'
13
 		brief='Manages general bot configuration'
17
 	)
14
 	)
18
 	@commands.has_permissions(ban_members=True)
15
 	@commands.has_permissions(ban_members=True)
19
 	@commands.guild_only()
16
 	@commands.guild_only()
20
 	async def config(self, context: commands.Context):
17
 	async def config(self, context: commands.Context):
21
-		'Command group'
18
+		'General guild configuration command group'
22
 		if context.invoked_subcommand is None:
19
 		if context.invoked_subcommand is None:
23
 			await context.send_help()
20
 			await context.send_help()
24
 
21
 
25
 	@config.command(
22
 	@config.command(
26
-		name='setwarningchannel',
27
-		brief='Sets where to post mod warnings',
23
+		brief='Sets the mod warning channel',
28
 		description='Run this command in the channel where bot messages ' +
24
 		description='Run this command in the channel where bot messages ' +
29
 			'intended for server moderators should be sent. Other bot ' +
25
 			'intended for server moderators should be sent. Other bot ' +
30
 			'messages may still be posted in the channel a command was ' +
26
 			'messages may still be posted in the channel a command was ' +
31
 			'invoked in. If no output channel is set, mod-related messages ' +
27
 			'invoked in. If no output channel is set, mod-related messages ' +
32
 			'will not be posted!',
28
 			'will not be posted!',
33
 	)
29
 	)
34
-	async def config_setwarningchannel(self, context: commands.Context) -> None:
30
+	async def setwarningchannel(self, context: commands.Context) -> None:
31
+		'Command handler'
35
 		guild: Guild = context.guild
32
 		guild: Guild = context.guild
36
 		channel: TextChannel = context.channel
33
 		channel: TextChannel = context.channel
37
 		Storage.set_config_value(guild, ConfigKey.WARNING_CHANNEL_ID,
34
 		Storage.set_config_value(guild, ConfigKey.WARNING_CHANNEL_ID,
41
 			mention_author=False)
38
 			mention_author=False)
42
 
39
 
43
 	@config.command(
40
 	@config.command(
44
-		name='getwarningchannel',
45
-		brief='Shows the channel where mod warnings are posted',
41
+		brief='Shows the mod warning channel',
42
+		description='Shows the configured channel (if any) where mod ' + \
43
+			'warnings will be posted.',
46
 	)
44
 	)
47
-	async def config_getwarningchannel(self, context: commands.Context) -> None:
45
+	async def getwarningchannel(self, context: commands.Context) -> None:
46
+		'Command handler'
48
 		guild: Guild = context.guild
47
 		guild: Guild = context.guild
49
 		channel_id = Storage.get_config_value(guild, ConfigKey.WARNING_CHANNEL_ID)
48
 		channel_id = Storage.get_config_value(guild, ConfigKey.WARNING_CHANNEL_ID)
50
 		if channel_id is None:
49
 		if channel_id is None:
58
 				mention_author=False)
57
 				mention_author=False)
59
 
58
 
60
 	@config.command(
59
 	@config.command(
61
-		name='setwarningmention',
62
 		brief='Sets a user/role to mention in warning messages',
60
 		brief='Sets a user/role to mention in warning messages',
63
 		usage='<@user|@role>',
61
 		usage='<@user|@role>',
64
 		description='Configures an role or other prefix to include at the ' +
62
 		description='Configures an role or other prefix to include at the ' +
66
 			'attention of certain users, be sure to specify a properly ' +
64
 			'attention of certain users, be sure to specify a properly ' +
67
 			'formed @ tag, not just the name of the user/role.'
65
 			'formed @ tag, not just the name of the user/role.'
68
 	)
66
 	)
69
-	async def config_setwarningmention(self,
67
+	async def setwarningmention(self,
70
 			context: commands.Context,
68
 			context: commands.Context,
71
 			mention: str = None) -> None:
69
 			mention: str = None) -> None:
70
+		'Command handler'
72
 		guild: Guild = context.guild
71
 		guild: Guild = context.guild
73
 		Storage.set_config_value(guild, ConfigKey.WARNING_MENTION, mention)
72
 		Storage.set_config_value(guild, ConfigKey.WARNING_MENTION, mention)
74
 		if mention is None:
73
 		if mention is None:
81
 				mention_author=False)
80
 				mention_author=False)
82
 
81
 
83
 	@config.command(
82
 	@config.command(
84
-		name='getwarningmention',
85
 		brief='Shows the user/role to mention in warning messages',
83
 		brief='Shows the user/role to mention in warning messages',
86
 		description='Shows the text, if any, that will be prefixed on any ' +
84
 		description='Shows the text, if any, that will be prefixed on any ' +
87
 			'warning messages.'
85
 			'warning messages.'
88
 	)
86
 	)
89
-	async def config_getwarningmention(self, context: commands.Context) -> None:
87
+	async def getwarningmention(self, context: commands.Context) -> None:
88
+		'Command handler'
90
 		guild: Guild = context.guild
89
 		guild: Guild = context.guild
91
 		mention: str = Storage.get_config_value(guild, ConfigKey.WARNING_MENTION)
90
 		mention: str = Storage.get_config_value(guild, ConfigKey.WARNING_MENTION)
92
 		if mention is None:
91
 		if mention is None:

+ 6
- 4
cogs/crosspostcog.py Parādīt failu

1
 from datetime import datetime, timedelta
1
 from datetime import datetime, timedelta
2
-from discord import Guild, Member, Message
2
+from discord import Member, Message
3
 from discord.ext import commands
3
 from discord.ext import commands
4
 
4
 
5
 from cogs.basecog import BaseCog, BotMessage, BotMessageReaction, CogSetting
5
 from cogs.basecog import BaseCog, BotMessage, BotMessageReaction, CogSetting
8
 from storage import Storage
8
 from storage import Storage
9
 
9
 
10
 class SpamContext:
10
 class SpamContext:
11
+	"""
12
+	Data about a set of duplicate messages from a user.
13
+	"""
11
 	def __init__(self, member, message_hash):
14
 	def __init__(self, member, message_hash):
12
 		self.member = member
15
 		self.member = member
13
 		self.message_hash = message_hash
16
 		self.message_hash = message_hash
20
 		self.deleted_messages = set()  # of Message
23
 		self.deleted_messages = set()  # of Message
21
 		self.unique_channels = set()  # of TextChannel
24
 		self.unique_channels = set()  # of TextChannel
22
 
25
 
23
-class CrossPostCog(BaseCog):
26
+class CrossPostCog(BaseCog, name='Crosspost Detection'):
24
 	"""
27
 	"""
25
 	Detects a user posting the same text in multiple channels in a short period
28
 	Detects a user posting the same text in multiple channels in a short period
26
 	of time: a common pattern for spammers. Repeated posts in the same channel
29
 	of time: a common pattern for spammers. Repeated posts in the same channel
120
 				continue
123
 				continue
121
 			key = f'{message.author.id}|{message_hash}'
124
 			key = f'{message.author.id}|{message_hash}'
122
 			context = spam_lookup.get(key)
125
 			context = spam_lookup.get(key)
123
-			is_new = context is None
124
 			if context is None:
126
 			if context is None:
125
 				context = SpamContext(message.author, message_hash)
127
 				context = SpamContext(message.author, message_hash)
126
 				spam_lookup[key] = context
128
 				spam_lookup[key] = context
139
 		channel_count = len(context.unique_channels)
141
 		channel_count = len(context.unique_channels)
140
 		if channel_count >= ban_count:
142
 		if channel_count >= ban_count:
141
 			if not context.is_banned:
143
 			if not context.is_banned:
142
-				count = len(context.spam_messages)
143
 				await context.member.ban(
144
 				await context.member.ban(
144
 					reason='Rocketbot: Posted same message in ' + \
145
 					reason='Rocketbot: Posted same message in ' + \
145
 						f'{channel_count} channels. Banned by ' + \
146
 						f'{channel_count} channels. Banned by ' + \
227
 
228
 
228
 	@commands.Cog.listener()
229
 	@commands.Cog.listener()
229
 	async def on_message(self, message: Message):
230
 	async def on_message(self, message: Message):
231
+		'Event handler'
230
 		if message.author is None or \
232
 		if message.author is None or \
231
 				message.author.bot or \
233
 				message.author.bot or \
232
 				message.channel is None or \
234
 				message.channel is None or \

+ 29
- 10
cogs/generalcog.py Parādīt failu

1
+import re
1
 from datetime import datetime, timedelta
2
 from datetime import datetime, timedelta
2
-from discord import Member, Message
3
+from discord import Message
3
 from discord.ext import commands
4
 from discord.ext import commands
4
-import re
5
 
5
 
6
 from cogs.basecog import BaseCog, BotMessage
6
 from cogs.basecog import BaseCog, BotMessage
7
 from config import CONFIG
7
 from config import CONFIG
8
 from rbutils import parse_timedelta, describe_timedelta
8
 from rbutils import parse_timedelta, describe_timedelta
9
 from storage import ConfigKey, Storage
9
 from storage import ConfigKey, Storage
10
 
10
 
11
-class GeneralCog(BaseCog):
11
+class GeneralCog(BaseCog, name='General'):
12
+	"""
13
+	Cog for handling high-level bot functionality and commands. Should be the
14
+	first cog added to the bot.
15
+	"""
12
 	def __init__(self, bot: commands.Bot):
16
 	def __init__(self, bot: commands.Bot):
13
 		super().__init__(bot)
17
 		super().__init__(bot)
14
 		self.is_connected = False
18
 		self.is_connected = False
16
 
20
 
17
 	@commands.Cog.listener()
21
 	@commands.Cog.listener()
18
 	async def on_connect(self):
22
 	async def on_connect(self):
23
+		'Event handler'
19
 		print('on_connect')
24
 		print('on_connect')
20
 		self.is_connected = True
25
 		self.is_connected = True
21
 
26
 
22
 	@commands.Cog.listener()
27
 	@commands.Cog.listener()
23
 	async def on_ready(self):
28
 	async def on_ready(self):
29
+		'Event handler'
24
 		print('on_ready')
30
 		print('on_ready')
25
 		self.is_ready = True
31
 		self.is_ready = True
26
 
32
 
27
 	@commands.command(
33
 	@commands.command(
28
-		brief='Posts a test warning in the configured warning channel.'
34
+		brief='Posts a test warning',
35
+		description='Tests whether a warning channel is configured for this ' + \
36
+			'guild by posting a test warning. If a mod mention is ' + \
37
+			'configured, that user/role will be tagged in the test warning.',
29
 	)
38
 	)
30
 	@commands.has_permissions(ban_members=True)
39
 	@commands.has_permissions(ban_members=True)
31
 	@commands.guild_only()
40
 	@commands.guild_only()
32
 	async def testwarn(self, context):
41
 	async def testwarn(self, context):
42
+		'Command handler'
33
 		if Storage.get_config_value(context.guild, ConfigKey.WARNING_CHANNEL_ID) is None:
43
 		if Storage.get_config_value(context.guild, ConfigKey.WARNING_CHANNEL_ID) is None:
34
 			await context.message.reply(
44
 			await context.message.reply(
35
 				f'{CONFIG["warning_emoji"]} No warning channel set!',
45
 				f'{CONFIG["warning_emoji"]} No warning channel set!',
43
 
53
 
44
 	@commands.command(
54
 	@commands.command(
45
 		brief='Simple test reply',
55
 		brief='Simple test reply',
56
+		description='Replies to the command message. Useful to ensure the ' + \
57
+			'bot is working properly.',
46
 	)
58
 	)
47
 	async def hello(self, context):
59
 	async def hello(self, context):
60
+		'Command handler'
48
 		await context.message.reply(
61
 		await context.message.reply(
49
 			f'Hey, {context.author.name}!',
62
 			f'Hey, {context.author.name}!',
50
 		 	mention_author=False)
63
 		 	mention_author=False)
51
 
64
 
52
 	@commands.command(
65
 	@commands.command(
53
-		brief='Shuts down the bot (admin only)',
66
+		brief='Shuts down the bot',
67
+		description='Causes the bot script to terminate. Only usable by a ' + \
68
+			'user with server admin permissions.',
54
 	)
69
 	)
55
 	@commands.has_permissions(administrator=True)
70
 	@commands.has_permissions(administrator=True)
56
 	@commands.guild_only()
71
 	@commands.guild_only()
57
-	async def shutdown(self, context):
72
+	async def shutdown(self, context: commands.Context):
73
+		'Command handler'
74
+		await context.message.add_reaction('👋')
58
 		await self.bot.close()
75
 		await self.bot.close()
59
 
76
 
60
 	@commands.command(
77
 	@commands.command(
61
 		brief='Mass deletes messages',
78
 		brief='Mass deletes messages',
62
-		description='Deletes recent messages by the given user. The age is ' +
63
-			'a duration, such as "30s", "5m", "1h30m". Messages far back in ' +
64
-			'the scrollback might not be deleted by this command.',
65
-		usage='<user> <age>'
79
+		description='Deletes recent messages by the given user. The user ' +
80
+			'can be either an @ mention or a numeric user ID. The age is ' +
81
+			'a duration, such as "30s", "5m", "1h30m". Only the most ' +
82
+			'recent 100 messages in each channel are searched.',
83
+		usage='<user:id|mention> <age:timespan>'
66
 	)
84
 	)
67
 	@commands.has_permissions(manage_messages=True)
85
 	@commands.has_permissions(manage_messages=True)
68
 	@commands.guild_only()
86
 	@commands.guild_only()
69
 	async def deletemessages(self, context, user: str, age: str) -> None:
87
 	async def deletemessages(self, context, user: str, age: str) -> None:
88
+		'Command handler'
70
 		member_id = self.__parse_member_id(user)
89
 		member_id = self.__parse_member_id(user)
71
 		if member_id is None:
90
 		if member_id is None:
72
 			await context.message.reply(
91
 			await context.message.reply(

+ 9
- 5
cogs/joinraidcog.py Parādīt failu

1
+import weakref
1
 from datetime import datetime, timedelta
2
 from datetime import datetime, timedelta
2
 from discord import Guild, Member
3
 from discord import Guild, Member
3
 from discord.ext import commands
4
 from discord.ext import commands
4
-import weakref
5
 
5
 
6
 from cogs.basecog import BaseCog, BotMessage, BotMessageReaction, CogSetting
6
 from cogs.basecog import BaseCog, BotMessage, BotMessageReaction, CogSetting
7
 from config import CONFIG
7
 from config import CONFIG
9
 from storage import Storage
9
 from storage import Storage
10
 
10
 
11
 class JoinRaidContext:
11
 class JoinRaidContext:
12
+	"""
13
+	Data about a join raid.
14
+	"""
12
 	def __init__(self, join_members: list):
15
 	def __init__(self, join_members: list):
13
 		self.join_members = list(join_members)
16
 		self.join_members = list(join_members)
14
 		self.kicked_members = set()
17
 		self.kicked_members = set()
16
 		self.warning_message_ref = None
19
 		self.warning_message_ref = None
17
 
20
 
18
 	def last_join_time(self) -> datetime:
21
 	def last_join_time(self) -> datetime:
22
+		'Returns when the most recent member join was, in UTC'
19
 		return self.join_members[-1].joined_at
23
 		return self.join_members[-1].joined_at
20
 
24
 
21
-class JoinRaidCog(BaseCog):
25
+class JoinRaidCog(BaseCog, name='Join Raids'):
22
 	"""
26
 	"""
23
 	Cog for monitoring member joins and detecting potential bot raids.
27
 	Cog for monitoring member joins and detecting potential bot raids.
24
 	"""
28
 	"""
40
 			min_value=1.0,
44
 			min_value=1.0,
41
 			max_value=900.0)
45
 			max_value=900.0)
42
 
46
 
43
-	STATE_KEY_RECENT_JOINS = "JoinRaidCog.recent_joins"	
47
+	STATE_KEY_RECENT_JOINS = "JoinRaidCog.recent_joins"
44
 	STATE_KEY_LAST_RAID = "JoinRaidCog.last_raid"
48
 	STATE_KEY_LAST_RAID = "JoinRaidCog.last_raid"
45
 
49
 
46
 	def __init__(self, bot):
50
 	def __init__(self, bot):
55
 	@commands.has_permissions(ban_members=True)
59
 	@commands.has_permissions(ban_members=True)
56
 	@commands.guild_only()
60
 	@commands.guild_only()
57
 	async def joinraid(self, context: commands.Context):
61
 	async def joinraid(self, context: commands.Context):
58
-		'Command group'
62
+		'Join raid detection command group'
59
 		if context.invoked_subcommand is None:
63
 		if context.invoked_subcommand is None:
60
 			await context.send_help()
64
 			await context.send_help()
61
 
65
 
64
 			reaction: BotMessageReaction,
68
 			reaction: BotMessageReaction,
65
 			reacted_by: Member) -> None:
69
 			reacted_by: Member) -> None:
66
 		guild: Guild = bot_message.guild
70
 		guild: Guild = bot_message.guild
67
-		raid: JoinRaidRecord = bot_message.context
71
+		raid: JoinRaidContext = bot_message.context
68
 		if reaction.emoji == CONFIG['kick_emoji']:
72
 		if reaction.emoji == CONFIG['kick_emoji']:
69
 			to_kick = set(raid.join_members) - raid.kicked_members
73
 			to_kick = set(raid.join_members) - raid.kicked_members
70
 			for member in to_kick:
74
 			for member in to_kick:

+ 133
- 81
cogs/patterncog.py Parādīt failu

1
+import re
1
 from abc import ABC, abstractmethod
2
 from abc import ABC, abstractmethod
2
 from discord import Guild, Member, Message
3
 from discord import Guild, Member, Message
3
 from discord.ext import commands
4
 from discord.ext import commands
4
-from datetime import timedelta
5
-import re
6
 
5
 
7
 from cogs.basecog import BaseCog, BotMessage, BotMessageReaction
6
 from cogs.basecog import BaseCog, BotMessage, BotMessageReaction
8
 from config import CONFIG
7
 from config import CONFIG
10
 from storage import Storage
9
 from storage import Storage
11
 
10
 
12
 class PatternAction:
11
 class PatternAction:
13
-	def __init__(self, type: str, args: list):
14
-		self.type = type
12
+	"""
13
+	Describes one action to take on a matched message or its author.
14
+	"""
15
+	def __init__(self, action: str, args: list):
16
+		self.action = action
15
 		self.arguments = list(args)
17
 		self.arguments = list(args)
16
 
18
 
17
 	def __str__(self) -> str:
19
 	def __str__(self) -> str:
18
 		arg_str = ', '.join(self.arguments)
20
 		arg_str = ', '.join(self.arguments)
19
-		return f'{self.type}({arg_str})'
21
+		return f'{self.action}({arg_str})'
20
 
22
 
21
 class PatternExpression(ABC):
23
 class PatternExpression(ABC):
24
+	"""
25
+	Abstract message matching expression.
26
+	"""
22
 	def __init__(self):
27
 	def __init__(self):
23
 		pass
28
 		pass
24
 
29
 
25
 	@abstractmethod
30
 	@abstractmethod
26
 	def matches(self, message: Message) -> bool:
31
 	def matches(self, message: Message) -> bool:
32
+		"""
33
+		Whether a message matches this expression.
34
+		"""
27
 		return False
35
 		return False
28
 
36
 
29
 class PatternSimpleExpression(PatternExpression):
37
 class PatternSimpleExpression(PatternExpression):
38
+	"""
39
+	Message matching expression with a simple "<field> <operator> <value>"
40
+	structure.
41
+	"""
30
 	def __init__(self, field: str, operator: str, value):
42
 	def __init__(self, field: str, operator: str, value):
43
+		super().__init__()
31
 		self.field = field
44
 		self.field = field
32
 		self.operator = operator
45
 		self.operator = operator
33
 		self.value = value
46
 		self.value = value
34
 
47
 
35
-	def matches(self, message: Message) -> bool:
36
-		field_value = None
48
+	def __field_value(self, message: Message):
37
 		if self.field == 'content':
49
 		if self.field == 'content':
38
-			field_value = message.content
39
-		elif self.field == 'author':
40
-			field_value = str(message.author.id)
41
-		elif self.field == 'author.id':
42
-			field_value = str(message.author.id)
43
-		elif self.field == 'author.joinage':
44
-			field_value = message.created_at - message.author.joined_at
45
-		elif self.field == 'author.name':
46
-			field_value = message.author.name
50
+			return message.content
51
+		if self.field == 'author':
52
+			return str(message.author.id)
53
+		if self.field == 'author.id':
54
+			return str(message.author.id)
55
+		if self.field == 'author.joinage':
56
+			return message.created_at - message.author.joined_at
57
+		if self.field == 'author.name':
58
+			return message.author.name
47
 		else:
59
 		else:
48
 			raise ValueError(f'Bad field name {self.field}')
60
 			raise ValueError(f'Bad field name {self.field}')
61
+
62
+	def matches(self, message: Message) -> bool:
63
+		field_value = self.__field_value(message)
49
 		if self.operator == '==':
64
 		if self.operator == '==':
50
 			if isinstance(field_value, str) and isinstance(self.value, str):
65
 			if isinstance(field_value, str) and isinstance(self.value, str):
51
 				return field_value.lower() == self.value.lower()
66
 				return field_value.lower() == self.value.lower()
78
 		return f'({self.field} {self.operator} {self.value})'
93
 		return f'({self.field} {self.operator} {self.value})'
79
 
94
 
80
 class PatternCompoundExpression(PatternExpression):
95
 class PatternCompoundExpression(PatternExpression):
96
+	"""
97
+	Message matching expression that combines several child expressions with
98
+	a boolean operator.
99
+	"""
81
 	def __init__(self, operator: str, operands: list):
100
 	def __init__(self, operator: str, operands: list):
101
+		super().__init__()
82
 		self.operator = operator
102
 		self.operator = operator
83
 		self.operands = list(operands)
103
 		self.operands = list(operands)
84
 
104
 
85
 	def matches(self, message: Message) -> bool:
105
 	def matches(self, message: Message) -> bool:
86
 		if self.operator == '!':
106
 		if self.operator == '!':
87
 			return not self.operands[0].matches(message)
107
 			return not self.operands[0].matches(message)
88
-		elif self.operator == 'and':
108
+		if self.operator == 'and':
89
 			for op in self.operands:
109
 			for op in self.operands:
90
 				if not op.matches(message):
110
 				if not op.matches(message):
91
 					return False
111
 					return False
92
 			return True
112
 			return True
93
-		elif self.operator == 'or':
113
+		if self.operator == 'or':
94
 			for op in self.operands:
114
 			for op in self.operands:
95
 				if op.matches(message):
115
 				if op.matches(message):
96
 					return True
116
 					return True
97
 			return False
117
 			return False
98
-		else:
99
-			raise RuntimeError(f'Bad operator "{self.operator}"')
118
+		raise RuntimeError(f'Bad operator "{self.operator}"')
100
 
119
 
101
 	def __str__(self) -> str:
120
 	def __str__(self) -> str:
102
 		if self.operator == '!':
121
 		if self.operator == '!':
103
 			return f'(!( {self.operands[0]} ))'
122
 			return f'(!( {self.operands[0]} ))'
104
-		else:
105
-			strs = map(str, self.operands)
106
-			joined = f' {self.operator} '.join(strs)
107
-			return f'( {joined} )'
123
+		strs = map(str, self.operands)
124
+		joined = f' {self.operator} '.join(strs)
125
+		return f'( {joined} )'
108
 
126
 
109
 class PatternStatement:
127
 class PatternStatement:
128
+	"""
129
+	A full message match statement. If a message matches the given expression,
130
+	the given actions should be performed.
131
+	"""
110
 	def __init__(self, name: str, actions: list, expression: PatternExpression, original: str):
132
 	def __init__(self, name: str, actions: list, expression: PatternExpression, original: str):
111
 		self.name = name
133
 		self.name = name
112
 		self.actions = list(actions)  # PatternAction[]
134
 		self.actions = list(actions)  # PatternAction[]
114
 		self.original = original
136
 		self.original = original
115
 
137
 
116
 class PatternContext:
138
 class PatternContext:
139
+	"""
140
+	Data about a message that has matched a configured statement and what
141
+	actions have been carried out.
142
+	"""
117
 	def __init__(self, message: Message, statement: PatternStatement):
143
 	def __init__(self, message: Message, statement: PatternStatement):
118
 		self.message = message
144
 		self.message = message
119
 		self.statement = statement
145
 		self.statement = statement
121
 		self.is_kicked = False
147
 		self.is_kicked = False
122
 		self.is_banned = False
148
 		self.is_banned = False
123
 
149
 
124
-class PatternCog(BaseCog):
125
-	def __init__(self, bot):
126
-		super().__init__(bot)
127
-
128
-	# def __patterns(self, guild: Guild) -> list:
129
-	# 	patterns = Storage.get_state_value(guild, 'pattern_patterns')
130
-	# 	if patterns is None:
131
-	# 		patterns_encoded = Storage.get_config_value(guild, 'pattern_patterns')
132
-	# 		if patterns_encoded:
133
-	# 			patterns = []
134
-	# 			for pe in patterns_encoded:
135
-	# 				patterns.append(Pattern.decode(pe))
136
-	# 			Storage.set_state_value(guild, 'pattern_patterns', patterns)
137
-	# 	return patterns
150
+class PatternCog(BaseCog, name='Pattern Matching'):
151
+	"""
152
+	Highly flexible cog for performing various actions on messages that match
153
+	various critera. Patterns can be defined by mods for each guild.
154
+	"""
138
 
155
 
139
 	def __get_patterns(self, guild: Guild) -> dict:
156
 	def __get_patterns(self, guild: Guild) -> dict:
140
 		patterns = Storage.get_state_value(guild, 'PatternCog.patterns')
157
 		patterns = Storage.get_state_value(guild, 'PatternCog.patterns')
148
 					try:
165
 					try:
149
 						ps = PatternCompiler.parse_statement(name, statement)
166
 						ps = PatternCompiler.parse_statement(name, statement)
150
 						patterns[name] = ps
167
 						patterns[name] = ps
151
-					except RuntimeError as e:
152
-						self.log(guild, f'Error parsing saved statement "{name}". Skipping: {statement}')
168
+					except Exception as e:
169
+						self.log(guild, f'Error parsing saved statement "{name}". Skipping: {statement}. Error: {e}')
153
 			Storage.set_state_value(guild, 'PatternCog.patterns', patterns)
170
 			Storage.set_state_value(guild, 'PatternCog.patterns', patterns)
154
 		return patterns
171
 		return patterns
155
 
172
 
156
-	def __save_patterns(self, guild: Guild, patterns: dict) -> None:
173
+	@classmethod
174
+	def __save_patterns(cls, guild: Guild, patterns: dict) -> None:
157
 		to_save = []
175
 		to_save = []
158
 		for name, statement in patterns.items():
176
 		for name, statement in patterns.items():
159
 			to_save.append({
177
 			to_save.append({
164
 
182
 
165
 	@commands.Cog.listener()
183
 	@commands.Cog.listener()
166
 	async def on_message(self, message: Message) -> None:
184
 	async def on_message(self, message: Message) -> None:
185
+		'Event listener'
167
 		if message.author is None or \
186
 		if message.author is None or \
168
 				message.author.bot or \
187
 				message.author.bot or \
169
 				message.channel is None or \
188
 				message.channel is None or \
176
 			return
195
 			return
177
 
196
 
178
 		patterns = self.__get_patterns(message.guild)
197
 		patterns = self.__get_patterns(message.guild)
179
-		for name, statement in patterns.items():
198
+		for _, statement in patterns.items():
180
 			if statement.expression.matches(message):
199
 			if statement.expression.matches(message):
181
 				await self.__trigger_actions(message, statement)
200
 				await self.__trigger_actions(message, statement)
182
 				break
201
 				break
188
 		action_descriptions = []
207
 		action_descriptions = []
189
 		self.log(message.guild, f'Message from {message.author.name} matched pattern "{statement.name}"')
208
 		self.log(message.guild, f'Message from {message.author.name} matched pattern "{statement.name}"')
190
 		for action in statement.actions:
209
 		for action in statement.actions:
191
-			if action.type == 'ban':
210
+			if action.action == 'ban':
192
 				await message.author.ban(
211
 				await message.author.ban(
193
 					reason=f'Rocketbot: Message matched custom pattern named "{statement.name}"',
212
 					reason=f'Rocketbot: Message matched custom pattern named "{statement.name}"',
194
 					delete_message_days=0)
213
 					delete_message_days=0)
196
 				context.is_kicked = True
215
 				context.is_kicked = True
197
 				action_descriptions.append('Author banned')
216
 				action_descriptions.append('Author banned')
198
 				self.log(message.guild, f'{message.author.name} banned')
217
 				self.log(message.guild, f'{message.author.name} banned')
199
-			elif action.type == 'delete':
218
+			elif action.action == 'delete':
200
 				await message.delete()
219
 				await message.delete()
201
 				context.is_deleted = True
220
 				context.is_deleted = True
202
 				action_descriptions.append('Message deleted')
221
 				action_descriptions.append('Message deleted')
203
 				self.log(message.guild, f'{message.author.name}\'s message deleted')
222
 				self.log(message.guild, f'{message.author.name}\'s message deleted')
204
-			elif action.type == 'kick':
223
+			elif action.action == 'kick':
205
 				await message.author.kick(
224
 				await message.author.kick(
206
 					reason=f'Rocketbot: Message matched custom pattern named "{statement.name}"')
225
 					reason=f'Rocketbot: Message matched custom pattern named "{statement.name}"')
207
 				context.is_kicked = True
226
 				context.is_kicked = True
208
 				action_descriptions.append('Author kicked')
227
 				action_descriptions.append('Author kicked')
209
 				self.log(message.guild, f'{message.author.name} kicked')
228
 				self.log(message.guild, f'{message.author.name} kicked')
210
-			elif action.type == 'modwarn':
229
+			elif action.action == 'modwarn':
211
 				should_alert_mods = True
230
 				should_alert_mods = True
212
 				action_descriptions.append('Mods alerted')
231
 				action_descriptions.append('Mods alerted')
213
-			elif action.type == 'reply':
232
+			elif action.action == 'reply':
214
 				await message.reply(
233
 				await message.reply(
215
 					f'{action.arguments[0]}',
234
 					f'{action.arguments[0]}',
216
 					mention_author=False)
235
 					mention_author=False)
240
 			context.is_deleted = True
259
 			context.is_deleted = True
241
 		elif reaction.emoji == CONFIG['kick_emoji']:
260
 		elif reaction.emoji == CONFIG['kick_emoji']:
242
 			await context.message.author.kick(
261
 			await context.message.author.kick(
243
-				reason=f'Rocketbot: Message matched custom pattern named ' + \
244
-					'"{statement.name}". Kicked by {reacted_by.name}.')
262
+				reason='Rocketbot: Message matched custom pattern named ' + \
263
+					f'"{context.statement.name}". Kicked by {reacted_by.name}.')
245
 			context.is_kicked = True
264
 			context.is_kicked = True
246
 		elif reaction.emoji == CONFIG['ban_emoji']:
265
 		elif reaction.emoji == CONFIG['ban_emoji']:
247
 			await context.message.author.ban(
266
 			await context.message.author.ban(
248
-				reason=f'Rocketbot: Message matched custom pattern named ' + \
249
-					'"{statement.name}". Banned by {reacted_by.name}.',
267
+				reason='Rocketbot: Message matched custom pattern named ' + \
268
+					f'"{context.statement.name}". Banned by {reacted_by.name}.',
250
 					delete_message_days=1)
269
 					delete_message_days=1)
251
 			context.is_banned = True
270
 			context.is_banned = True
252
 		await bot_message.set_reactions(BotMessageReaction.standard_set(
271
 		await bot_message.set_reactions(BotMessageReaction.standard_set(
260
 	@commands.has_permissions(ban_members=True)
279
 	@commands.has_permissions(ban_members=True)
261
 	@commands.guild_only()
280
 	@commands.guild_only()
262
 	async def pattern(self, context: commands.Context):
281
 	async def pattern(self, context: commands.Context):
263
-		'Message pattern matching'
282
+		'Message pattern matching command group'
264
 		if context.invoked_subcommand is None:
283
 		if context.invoked_subcommand is None:
265
 			await context.send_help()
284
 			await context.send_help()
266
 
285
 
273
 		ignore_extra=True
292
 		ignore_extra=True
274
 	)
293
 	)
275
 	async def add(self, context: commands.Context, name: str):
294
 	async def add(self, context: commands.Context, name: str):
295
+		'Command handler'
276
 		pattern_str = PatternCompiler.expression_str_from_context(context, name)
296
 		pattern_str = PatternCompiler.expression_str_from_context(context, name)
277
 		try:
297
 		try:
278
 			statement = PatternCompiler.parse_statement(name, pattern_str)
298
 			statement = PatternCompiler.parse_statement(name, pattern_str)
292
 		usage='<pattern_name>'
312
 		usage='<pattern_name>'
293
 	)
313
 	)
294
 	async def remove(self, context: commands.Context, name: str):
314
 	async def remove(self, context: commands.Context, name: str):
315
+		'Command handler'
295
 		patterns = self.__get_patterns(context.guild)
316
 		patterns = self.__get_patterns(context.guild)
296
 		if patterns.get(name) is not None:
317
 		if patterns.get(name) is not None:
297
 			del patterns[name]
318
 			del patterns[name]
308
 		brief='Lists all patterns'
329
 		brief='Lists all patterns'
309
 	)
330
 	)
310
 	async def list(self, context: commands.Context) -> None:
331
 	async def list(self, context: commands.Context) -> None:
332
+		'Command handler'
311
 		patterns = self.__get_patterns(context.guild)
333
 		patterns = self.__get_patterns(context.guild)
312
 		if len(patterns) == 0:
334
 		if len(patterns) == 0:
313
 			await context.message.reply('No patterns defined.', mention_author=False)
335
 			await context.message.reply('No patterns defined.', mention_author=False)
318
 		await context.message.reply(msg, mention_author=False)
340
 		await context.message.reply(msg, mention_author=False)
319
 
341
 
320
 class PatternCompiler:
342
 class PatternCompiler:
343
+	"""
344
+	Parses a user-provided message filter statement into a PatternStatement.
345
+	"""
321
 	TYPE_ID = 'id'
346
 	TYPE_ID = 'id'
322
 	TYPE_MEMBER = 'Member'
347
 	TYPE_MEMBER = 'Member'
323
 	TYPE_TEXT = 'text'
348
 	TYPE_TEXT = 'text'
364
 
389
 
365
 	@classmethod
390
 	@classmethod
366
 	def expression_str_from_context(cls, context: commands.Context, name: str) -> str:
391
 	def expression_str_from_context(cls, context: commands.Context, name: str) -> str:
392
+		"""
393
+		Extracts the statement string from an "add" command context.
394
+		"""
367
 		pattern_str = context.message.content
395
 		pattern_str = context.message.content
368
 		command_chain = [ name ]
396
 		command_chain = [ name ]
369
 		cmd = context.command
397
 		cmd = context.command
380
 
408
 
381
 	@classmethod
409
 	@classmethod
382
 	def parse_statement(cls, name: str, statement: str) -> PatternStatement:
410
 	def parse_statement(cls, name: str, statement: str) -> PatternStatement:
411
+		"""
412
+		Parses a user-provided message filter statement into a PatternStatement.
413
+		"""
383
 		tokens = cls.tokenize(statement)
414
 		tokens = cls.tokenize(statement)
384
 		token_index = 0
415
 		token_index = 0
385
 		actions, token_index = cls.read_actions(tokens, token_index)
416
 		actions, token_index = cls.read_actions(tokens, token_index)
388
 
419
 
389
 	@classmethod
420
 	@classmethod
390
 	def tokenize(cls, statement: str) -> list:
421
 	def tokenize(cls, statement: str) -> list:
422
+		"""
423
+		Converts a message filter statement into a list of tokens.
424
+		"""
391
 		tokens = []
425
 		tokens = []
392
 		in_quote = False
426
 		in_quote = False
393
 		in_escape = False
427
 		in_escape = False
476
 
510
 
477
 	@classmethod
511
 	@classmethod
478
 	def read_actions(cls, tokens: list, token_index: int) -> tuple:
512
 	def read_actions(cls, tokens: list, token_index: int) -> tuple:
513
+		"""
514
+		Reads the actions from a list of statement tokens. Returns a tuple
515
+		containing a list of PatternActions and the token index this method
516
+		left off at (the token after the "if").
517
+		"""
479
 		actions = []
518
 		actions = []
480
 		current_action_tokens = []
519
 		current_action_tokens = []
481
 		while token_index < len(tokens):
520
 		while token_index < len(tokens):
501
 
540
 
502
 	@classmethod
541
 	@classmethod
503
 	def __validate_action(cls, action: PatternAction) -> None:
542
 	def __validate_action(cls, action: PatternAction) -> None:
504
-		args = cls.ACTION_TO_ARGS.get(action.type)
543
+		args = cls.ACTION_TO_ARGS.get(action.action)
505
 		if args is None:
544
 		if args is None:
506
-			raise RuntimeError(f'Unknown action "{action.type}"')
545
+			raise RuntimeError(f'Unknown action "{action.action}"')
507
 		if len(action.arguments) != len(args):
546
 		if len(action.arguments) != len(args):
508
-			arg_list = ', '.join(args)
509
 			if len(args) == 0:
547
 			if len(args) == 0:
510
-				raise RuntimeError(f'Action "{action.type}" expects no arguments, got {len(action.arguments)}.')
548
+				raise RuntimeError(f'Action "{action.action}" expects no arguments, ' + \
549
+					f'got {len(action.arguments)}.')
511
 			else:
550
 			else:
512
-				raise RuntimeError(f'Action "{action.type}" expects {len(args)} arguments, got {len(action.arguments)}.')
513
-		for i in range(len(args)):
514
-			datatype = args[i]
551
+				raise RuntimeError(f'Action "{action.action}" expects {len(args)} ' + \
552
+					f'arguments, got {len(action.arguments)}.')
553
+		for i, datatype in enumerate(args):
515
 			action.arguments[i] = cls.parse_value(action.arguments[i], datatype)
554
 			action.arguments[i] = cls.parse_value(action.arguments[i], datatype)
516
 
555
 
517
 	@classmethod
556
 	@classmethod
518
-	def read_expression(cls, tokens: list, token_index: int, depth: int = 0, one_subexpression: bool = False) -> tuple:
519
-		# field op value
520
-		# (field op value)
521
-		# !(field op value)
522
-		# field op value and field op value
523
-		# (field op value and field op value) or field op value
524
-		indent = '\t' * depth
557
+	def read_expression(cls,
558
+			tokens: list,
559
+			token_index: int,
560
+			depth: int = 0,
561
+			one_subexpression: bool = False) -> tuple:
562
+		"""
563
+		Reads an expression from a list of statement tokens. Returns a tuple
564
+		containing the PatternExpression and the token index it left off at.
565
+		If one_subexpression is True then it will return after reading a
566
+		single expression instead of joining multiples (for readong the
567
+		subject of a NOT expression).
568
+		"""
525
 		subexpressions = []
569
 		subexpressions = []
526
 		last_compound_operator = None
570
 		last_compound_operator = None
527
 		while token_index < len(tokens):
571
 		while token_index < len(tokens):
528
 			if one_subexpression:
572
 			if one_subexpression:
529
 				if len(subexpressions) == 1:
573
 				if len(subexpressions) == 1:
530
 					return (subexpressions[0], token_index)
574
 					return (subexpressions[0], token_index)
531
-				elif len(subexpressions) > 1:
575
+				if len(subexpressions) > 1:
532
 					raise RuntimeError('Too many subexpressions')
576
 					raise RuntimeError('Too many subexpressions')
533
 			compound_operator = None
577
 			compound_operator = None
534
 			if tokens[token_index] == ')':
578
 			if tokens[token_index] == ')':
535
 				if len(subexpressions) == 0:
579
 				if len(subexpressions) == 0:
536
 					raise RuntimeError('No subexpressions')
580
 					raise RuntimeError('No subexpressions')
537
-				elif len(subexpressions) == 1:
581
+				if len(subexpressions) == 1:
538
 					return (subexpressions[0], token_index)
582
 					return (subexpressions[0], token_index)
539
-				else:
540
-					return (PatternCompoundExpression(last_compound_operator, subexpressions), token_index)
583
+				return (PatternCompoundExpression(last_compound_operator, subexpressions), token_index)
541
 			if tokens[token_index] in set(["and", "or"]):
584
 			if tokens[token_index] in set(["and", "or"]):
542
 				compound_operator = tokens[token_index]
585
 				compound_operator = tokens[token_index]
543
 				if last_compound_operator and compound_operator != last_compound_operator:
586
 				if last_compound_operator and compound_operator != last_compound_operator:
547
 					last_compound_operator = compound_operator
590
 					last_compound_operator = compound_operator
548
 				token_index += 1
591
 				token_index += 1
549
 			if tokens[token_index] == '!':
592
 			if tokens[token_index] == '!':
550
-				(exp, next_index) = cls.read_expression(tokens, token_index + 1, depth + 1, one_subexpression=True)
593
+				(exp, next_index) = cls.read_expression(tokens, token_index + 1, \
594
+						depth + 1, one_subexpression=True)
551
 				subexpressions.append(PatternCompoundExpression('!', [exp]))
595
 				subexpressions.append(PatternCompoundExpression('!', [exp]))
552
 				token_index = next_index
596
 				token_index = next_index
553
 			elif tokens[token_index] == '(':
597
 			elif tokens[token_index] == '(':
569
 
613
 
570
 	@classmethod
614
 	@classmethod
571
 	def read_simple_expression(cls, tokens: list, token_index: int, depth: int = 0) -> tuple:
615
 	def read_simple_expression(cls, tokens: list, token_index: int, depth: int = 0) -> tuple:
572
-		indent = '\t' * depth
573
-
616
+		"""
617
+		Reads a simple expression consisting of a field name, operator, and
618
+		comparison value. Returns a tuple of the PatternSimpleExpression and
619
+		the token index it left off at.
620
+		"""
621
+		if depth > 8:
622
+			raise RuntimeError('Expression nests too deeply')
574
 		if token_index >= len(tokens):
623
 		if token_index >= len(tokens):
575
 			raise RuntimeError('Expected field name, found EOL')
624
 			raise RuntimeError('Expected field name, found EOL')
576
 		field = tokens[token_index]
625
 		field = tokens[token_index]
609
 		return (exp, token_index)
658
 		return (exp, token_index)
610
 
659
 
611
 	@classmethod
660
 	@classmethod
612
-	def parse_value(cls, value: str, type: str):
613
-		if type == cls.TYPE_ID:
661
+	def parse_value(cls, value: str, datatype: str):
662
+		"""
663
+		Converts a value token to its Python value.
664
+		"""
665
+		if datatype == cls.TYPE_ID:
614
 			p = re.compile('^[0-9]+$')
666
 			p = re.compile('^[0-9]+$')
615
 			if p.match(value) is None:
667
 			if p.match(value) is None:
616
 				raise ValueError(f'Illegal id value "{value}"')
668
 				raise ValueError(f'Illegal id value "{value}"')
617
 			# Store it as a str so it can be larger than an int
669
 			# Store it as a str so it can be larger than an int
618
 			return value
670
 			return value
619
-		if type == cls.TYPE_MEMBER:
671
+		if datatype == cls.TYPE_MEMBER:
620
 			p = re.compile('^<@!?([0-9]+)>$')
672
 			p = re.compile('^<@!?([0-9]+)>$')
621
 			m = p.match(value)
673
 			m = p.match(value)
622
 			if m is None:
674
 			if m is None:
623
-				raise ValueError(f'Illegal member value. Must be an @ mention.')
675
+				raise ValueError('Illegal member value. Must be an @ mention.')
624
 			return m.group(1)
676
 			return m.group(1)
625
-		if type == cls.TYPE_TEXT:
677
+		if datatype == cls.TYPE_TEXT:
626
 			# Must be quoted.
678
 			# Must be quoted.
627
 			if len(value) < 2 or \
679
 			if len(value) < 2 or \
628
 					value[0:1] not in cls.STRING_QUOTE_CHARS or \
680
 					value[0:1] not in cls.STRING_QUOTE_CHARS or \
630
 					value[0:1] != value[-1:]:
682
 					value[0:1] != value[-1:]:
631
 				raise ValueError(f'Not a quoted string value: {value}')
683
 				raise ValueError(f'Not a quoted string value: {value}')
632
 			return value[1:-1]
684
 			return value[1:-1]
633
-		if type == cls.TYPE_INT:
685
+		if datatype == cls.TYPE_INT:
634
 			return int(value)
686
 			return int(value)
635
-		if type == cls.TYPE_FLOAT:
687
+		if datatype == cls.TYPE_FLOAT:
636
 			return float(value)
688
 			return float(value)
637
-		if type == cls.TYPE_TIMESPAN:
689
+		if datatype == cls.TYPE_TIMESPAN:
638
 			return parse_timedelta(value)
690
 			return parse_timedelta(value)
639
 		raise ValueError(f'Unhandled datatype {datatype}')
691
 		raise ValueError(f'Unhandled datatype {datatype}')

+ 10
- 25
cogs/urlspamcog.py Parādīt failu

1
-from discord import Guild, Member, Message
2
-from discord.ext import commands
3
 import re
1
 import re
4
 from datetime import timedelta
2
 from datetime import timedelta
3
+from discord import Member, Message
4
+from discord.ext import commands
5
 
5
 
6
 from cogs.basecog import BaseCog, BotMessage, BotMessageReaction, CogSetting
6
 from cogs.basecog import BaseCog, BotMessage, BotMessageReaction, CogSetting
7
 from config import CONFIG
7
 from config import CONFIG
8
-from storage import Storage
8
+from rbutils import describe_timedelta
9
 
9
 
10
 class URLSpamContext:
10
 class URLSpamContext:
11
+	"""
12
+	Data about a suspected spam message containing a URL.
13
+	"""
11
 	def __init__(self, spam_message: Message):
14
 	def __init__(self, spam_message: Message):
12
 		self.spam_message = spam_message
15
 		self.spam_message = spam_message
13
 		self.is_deleted = False
16
 		self.is_deleted = False
14
 		self.is_kicked = False
17
 		self.is_kicked = False
15
 		self.is_banned = False
18
 		self.is_banned = False
16
 
19
 
17
-class URLSpamCog(BaseCog):
20
+class URLSpamCog(BaseCog, name='URL Spam'):
18
 	"""
21
 	"""
19
 	Detects users posting URLs who just joined recently: a common spam pattern.
22
 	Detects users posting URLs who just joined recently: a common spam pattern.
20
 	Can be configured to take immediate action or just warn the mods.
23
 	Can be configured to take immediate action or just warn the mods.
50
 	@commands.has_permissions(ban_members=True)
53
 	@commands.has_permissions(ban_members=True)
51
 	@commands.guild_only()
54
 	@commands.guild_only()
52
 	async def urlspam(self, context: commands.Context):
55
 	async def urlspam(self, context: commands.Context):
53
-		'Command group'
56
+		'URL spam command group'
54
 		if context.invoked_subcommand is None:
57
 		if context.invoked_subcommand is None:
55
 			await context.send_help()
58
 			await context.send_help()
56
 
59
 
57
 	@commands.Cog.listener()
60
 	@commands.Cog.listener()
58
 	async def on_message(self, message: Message):
61
 	async def on_message(self, message: Message):
62
+		'Event listener'
59
 		if message.author is None or \
63
 		if message.author is None or \
60
 				message.author.bot or \
64
 				message.author.bot or \
61
 				message.guild is None or \
65
 				message.guild is None or \
73
 		if not self.__contains_url(message.content):
77
 		if not self.__contains_url(message.content):
74
 			return
78
 			return
75
 		join_age = message.created_at - message.author.joined_at
79
 		join_age = message.created_at - message.author.joined_at
76
-		join_age_str = self.__format_timedelta(join_age)
80
+		join_age_str = describe_timedelta(join_age)
77
 		if join_age < min_join_age:
81
 		if join_age < min_join_age:
78
 			context = URLSpamContext(message)
82
 			context = URLSpamContext(message)
79
 			needs_attention = False
83
 			needs_attention = False
165
 	def __contains_url(cls, text: str) -> bool:
169
 	def __contains_url(cls, text: str) -> bool:
166
 		p = re.compile(r'http(?:s)?://[^\s]+')
170
 		p = re.compile(r'http(?:s)?://[^\s]+')
167
 		return p.search(text) is not None
171
 		return p.search(text) is not None
168
-
169
-	@classmethod
170
-	def __format_timedelta(cls, timespan: timedelta) -> str:
171
-		parts = []
172
-		d = timespan.days
173
-		h = timespan.seconds // 3600
174
-		m = (timespan.seconds // 60) % 60
175
-		s = timespan.seconds % 60
176
-		if d > 0:
177
-			parts.append(f'{d} days')
178
-		if d > 0 or h > 0:
179
-			parts.append(f'{h} hours')
180
-		if d > 0 or h > 0 or m > 0:
181
-			parts.append(f'{m} minutes')
182
-		parts.append(f'{s} seconds')
183
-		# Limit the precision to the two most significant elements
184
-		while len(parts) > 2:
185
-			parts.pop(-1)
186
-		return ' '.join(parts)

+ 20
- 26
rbutils.py Parādīt failu

1
-from datetime import timedelta
2
 import re
1
 import re
2
+from datetime import timedelta
3
 
3
 
4
 def parse_timedelta(s: str) -> timedelta:
4
 def parse_timedelta(s: str) -> timedelta:
5
 	"""
5
 	"""
32
 	return timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)
32
 	return timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)
33
 
33
 
34
 def describe_timedelta(td: timedelta, max_components: int = 2) -> str:
34
 def describe_timedelta(td: timedelta, max_components: int = 2) -> str:
35
-	values = [
36
-		td.days,
37
-		td.seconds // 3600,
38
-		(td.seconds // 60) % 60,
39
-		td.seconds % 60,
40
-	]
41
-	units = [
42
-		'day' if values[0] == 1 else 'days',
43
-		'hour' if values[1] == 1 else 'hours',
44
-		'minute' if values[2] == 1 else 'minutes',
45
-		'second' if values[3] == 1 else 'seconds',
46
-	]
47
-	while len(values) > 1 and values[0] == 0:
48
-		values.pop(0)
49
-		units.pop(0)
50
-	if len(values) > max_components:
51
-		values = values[0:max_components]
52
-		units = units[0:max_components]
53
-	while len(values) > 1 and values[-1] == 0:
54
-		values.pop(-1)
55
-		units.pop(-1)
56
-	tokens = []
57
-	for i in range(len(values)):
58
-		tokens.append(f'{values[i]} {units[i]}')
59
-	return ' '.join(tokens)
35
+	"""
36
+	Formats a human-readable description of a time span. E.g. "3 days 2 hours".
37
+	"""
38
+	d = td.days
39
+	h = td.seconds // 3600
40
+	m = (td.seconds // 60) % 60
41
+	s = td.seconds % 60
42
+	components = []
43
+	if d != 0:
44
+		components.append('1 day' if d == 1 else f'{d} days')
45
+	if h != 0:
46
+		components.append('1 hour' if h == 1 else f'{h} hours')
47
+	if m != 0:
48
+		components.append('1 minute' if m == 1 else f'{m} minutes')
49
+	if s != 0 or len(components) == 0:
50
+		components.append('1 second' if s == 1 else f'{s} seconds')
51
+	if len(components) > max_components:
52
+		components = components[0:max_components]
53
+	return ' '.join(components)

+ 12
- 9
rocketbot.py Parādīt failu

18
 
18
 
19
 CURRENT_CONFIG_VERSION = 3
19
 CURRENT_CONFIG_VERSION = 3
20
 if (CONFIG.get('__config_version') or 0) < CURRENT_CONFIG_VERSION:
20
 if (CONFIG.get('__config_version') or 0) < CURRENT_CONFIG_VERSION:
21
-    # If you're getting this error, it means something changed in config.py's
22
-    # format. Consult config.py.sample and compare it to your own config.py.
23
-    # Rename/move any values as needed. When satisfied, update "__config_version"
24
-    # to the value in config.py.sample.
25
-    raise RuntimeError('config.py format may be outdated. Review ' +
26
-        'config.py.sample, update the "__config_version" field to ' +
27
-        f'{CURRENT_CONFIG_VERSION}, and try again.')
21
+	# If you're getting this error, it means something changed in config.py's
22
+	# format. Consult config.py.sample and compare it to your own config.py.
23
+	# Rename/move any values as needed. When satisfied, update "__config_version"
24
+	# to the value in config.py.sample.
25
+	raise RuntimeError('config.py format may be outdated. Review ' +
26
+		'config.py.sample, update the "__config_version" field to ' +
27
+		f'{CURRENT_CONFIG_VERSION}, and try again.')
28
 
28
 
29
 class Rocketbot(commands.Bot):
29
 class Rocketbot(commands.Bot):
30
+	"""
31
+	Bot subclass
32
+	"""
30
 	def __init__(self, command_prefix, **kwargs):
33
 	def __init__(self, command_prefix, **kwargs):
31
 		super().__init__(command_prefix, **kwargs)
34
 		super().__init__(command_prefix, **kwargs)
32
 
35
 
33
 intents = Intents.default()
36
 intents = Intents.default()
34
-intents.messages = True
35
-intents.members = True # To get join/leave events
37
+intents.messages = True  # pylint: disable=assigning-non-slot
38
+intents.members = True  # pylint: disable=assigning-non-slot
36
 bot = Rocketbot(command_prefix=CONFIG['command_prefix'], intents=intents)
39
 bot = Rocketbot(command_prefix=CONFIG['command_prefix'], intents=intents)
37
 
40
 
38
 # Core
41
 # Core

+ 31
- 15
rscollections.py Parādīt failu

1
-from abc import ABC, abstractmethod
2
-
3
 """
1
 """
4
 Subclasses of list, set, and dict with special behaviors.
2
 Subclasses of list, set, and dict with special behaviors.
5
 """
3
 """
6
 
4
 
5
+from abc import ABC, abstractmethod
6
+
7
 # Abstract collections
7
 # Abstract collections
8
 
8
 
9
 class AbstractMutableList(list, ABC):
9
 class AbstractMutableList(list, ABC):
247
 	however elements will only be discarded following the next mutating
247
 	however elements will only be discarded following the next mutating
248
 	operation. Call `self.purge_old_elements()` to force resizing.
248
 	operation. Call `self.purge_old_elements()` to force resizing.
249
 	"""
249
 	"""
250
-	def __init__(self, max_element_count: int, element_age, *args, **kwargs):
250
+	def __init__(self,
251
+			max_element_count: int,
252
+			element_age,
253
+			*args, **kwargs):
251
 		super().__init__(*args, **kwargs)
254
 		super().__init__(*args, **kwargs)
252
 		self.element_age = element_age
255
 		self.element_age = element_age
253
 		self.max_element_count = max_element_count
256
 		self.max_element_count = max_element_count
272
 		while len(self) > self.max_element_count:
275
 		while len(self) > self.max_element_count:
273
 			oldest_age = None
276
 			oldest_age = None
274
 			oldest_index = -1
277
 			oldest_index = -1
275
-			for i in range(len(self)):
276
-				elem = self[i]
278
+			for i, elem in enumerate(self):
277
 				age = self.element_age(i, elem)
279
 				age = self.element_age(i, elem)
278
 				if oldest_age is None or age < oldest_age:
280
 				if oldest_age is None or age < oldest_age:
279
 					oldest_age = age
281
 					oldest_age = age
281
 			self.pop(oldest_index)
283
 			self.pop(oldest_index)
282
 		self.is_culling = False
284
 		self.is_culling = False
283
 
285
 
284
-	def copy():
285
-		return SizeBoundList(max_element_count, element_age, super())
286
+	def copy(self):
287
+		return SizeBoundList(self.max_element_count, self.element_age, super())
286
 
288
 
287
 class SizeBoundSet(AbstractMutableSet):
289
 class SizeBoundSet(AbstractMutableSet):
288
 	"""
290
 	"""
303
 	however elements will only be discarded following the next mutating
305
 	however elements will only be discarded following the next mutating
304
 	operation. Call `self.purge_old_elements()` to force resizing.
306
 	operation. Call `self.purge_old_elements()` to force resizing.
305
 	"""
307
 	"""
306
-	def __init__(self, max_element_count: int, element_age, *args, **kwargs):
308
+	def __init__(self,
309
+			max_element_count: int,
310
+			element_age,
311
+			*args, **kwargs):
307
 		super().__init__(*args, **kwargs)
312
 		super().__init__(*args, **kwargs)
308
 		self.element_age = element_age
313
 		self.element_age = element_age
309
 		self.max_element_count = max_element_count
314
 		self.max_element_count = max_element_count
335
 			self.remove(oldest_elem)
340
 			self.remove(oldest_elem)
336
 		self.is_culling = False
341
 		self.is_culling = False
337
 
342
 
338
-	def copy():
339
-		return SizeBoundSet(max_element_count, element_age, super())
343
+	def copy(self):
344
+		return SizeBoundSet(self.max_element_count, self.element_age, super())
340
 
345
 
341
 class SizeBoundDict(AbstractMutableDict):
346
 class SizeBoundDict(AbstractMutableDict):
342
 	"""
347
 	"""
357
 	however elements will only be discarded following the next mutating
362
 	however elements will only be discarded following the next mutating
358
 	operation. Call `self.purge_old_elements()` to force resizing.
363
 	operation. Call `self.purge_old_elements()` to force resizing.
359
 	"""
364
 	"""
360
-	def __init__(self, max_element_count = 10, element_age = (lambda key, value: float(key)), *args, **kwargs):
365
+	def __init__(self,
366
+			max_element_count: int,
367
+			element_age,
368
+			*args, **kwargs):
361
 		super().__init__(*args, **kwargs)
369
 		super().__init__(*args, **kwargs)
362
 		self.element_age = element_age
370
 		self.element_age = element_age
363
 		self.max_element_count = max_element_count
371
 		self.max_element_count = max_element_count
389
 			del self[oldest_key]
397
 			del self[oldest_key]
390
 		self.is_culling = False
398
 		self.is_culling = False
391
 
399
 
392
-	def copy():
393
-		return SizeBoundDict(max_element_count, element_age, super())
400
+	def copy(self):
401
+		return SizeBoundDict(self.max_element_count, self.element_age, super())
394
 
402
 
395
 # Collections with limited age of elements
403
 # Collections with limited age of elements
396
 
404
 
436
 		min_age = None
444
 		min_age = None
437
 		max_age = None
445
 		max_age = None
438
 		ages = {}
446
 		ages = {}
439
-		for i in range(len(self)):
440
-			elem = self[i]
447
+		for i, elem in enumerate(self):
441
 			age = self.element_age(i, elem)
448
 			age = self.element_age(i, elem)
442
 			ages[i] = age
449
 			ages[i] = age
443
 			if min_age is None or age < min_age:
450
 			if min_age is None or age < min_age:
453
 				del self[i]
460
 				del self[i]
454
 		self.is_culling = False
461
 		self.is_culling = False
455
 
462
 
463
+	def copy(self):
464
+		return AgeBoundList(self.max_age, self.element_age, super())
465
+
456
 class AgeBoundSet(AbstractMutableSet):
466
 class AgeBoundSet(AbstractMutableSet):
457
 	"""
467
 	"""
458
 	Subclass of `set` that enforces a maximum "age" of elements.
468
 	Subclass of `set` that enforces a maximum "age" of elements.
511
 				self.remove(elem)
521
 				self.remove(elem)
512
 		self.is_culling = False
522
 		self.is_culling = False
513
 
523
 
524
+	def copy(self):
525
+		return AgeBoundSet(self.max_age, self.element_age, super())
526
+
514
 class AgeBoundDict(AbstractMutableDict):
527
 class AgeBoundDict(AbstractMutableDict):
515
 	"""
528
 	"""
516
 	Subclass of `dict` that enforces a maximum "age" of elements.
529
 	Subclass of `dict` that enforces a maximum "age" of elements.
568
 			if age < cutoff:
581
 			if age < cutoff:
569
 				del self[key]
582
 				del self[key]
570
 		self.is_culling = False
583
 		self.is_culling = False
584
+
585
+	def copy(self):
586
+		return AgeBoundDict(self.max_age, self.element_age, super())

+ 19
- 14
storage.py Parādīt failu

1
+"""
2
+Handles storage of persisted and non-persisted data for the bot.
3
+"""
1
 import json
4
 import json
2
 from os.path import exists
5
 from os.path import exists
3
-
4
 from discord import Guild
6
 from discord import Guild
5
 
7
 
6
 from config import CONFIG
8
 from config import CONFIG
7
 
9
 
8
 class ConfigKey:
10
 class ConfigKey:
11
+	"""
12
+	Common keys in persisted guild storage.
13
+	"""
9
 	WARNING_CHANNEL_ID = 'warning_channel_id'
14
 	WARNING_CHANNEL_ID = 'warning_channel_id'
10
 	WARNING_MENTION = 'warning_mention'
15
 	WARNING_MENTION = 'warning_mention'
11
 
16
 
47
 		cls.set_state_values(guild, { key: value })
52
 		cls.set_state_values(guild, { key: value })
48
 
53
 
49
 	@classmethod
54
 	@classmethod
50
-	def set_state_values(cls, guild: Guild, vars: dict) -> None:
55
+	def set_state_values(cls, guild: Guild, values: dict) -> None:
51
 		"""
56
 		"""
52
 		Merges in a set of key-value pairs into the transient state for the
57
 		Merges in a set of key-value pairs into the transient state for the
53
 		given guild. Any pairs with a value of `None` will be removed from the
58
 		given guild. Any pairs with a value of `None` will be removed from the
54
 		transient state.
59
 		transient state.
55
 		"""
60
 		"""
56
-		if vars is None or len(vars) == 0:
61
+		if values is None or len(values) == 0:
57
 			return
62
 			return
58
 		state: dict = cls.get_state(guild)
63
 		state: dict = cls.get_state(guild)
59
-		for key, value in vars.items():
64
+		for key, value in values.items():
60
 			if value is None:
65
 			if value is None:
61
 				del state[key]
66
 				del state[key]
62
 			else:
67
 			else:
105
 		cls.set_config_values(guild, { key: value })
110
 		cls.set_config_values(guild, { key: value })
106
 
111
 
107
 	@classmethod
112
 	@classmethod
108
-	def set_config_values(cls, guild: Guild, vars: dict) -> None:
113
+	def set_config_values(cls, guild: Guild, values: dict) -> None:
109
 		"""
114
 		"""
110
-		Merges the given `vars` dict with the saved config for the given guild
111
-		and writes it to disk. `vars` must be JSON-encodable or a `ValueError`
115
+		Merges the given `values` dict with the saved config for the given guild
116
+		and writes it to disk. `values` must be JSON-encodable or a `ValueError`
112
 		will be raised. Keys with associated values of `None` will be removed
117
 		will be raised. Keys with associated values of `None` will be removed
113
 		from the persisted config.
118
 		from the persisted config.
114
 		"""
119
 		"""
115
-		if vars is None or len(vars) == 0:
120
+		if values is None or len(values) == 0:
116
 			return
121
 			return
117
 		config: dict = cls.get_config(guild)
122
 		config: dict = cls.get_config(guild)
118
 		try:
123
 		try:
119
-			json.dumps(vars)
120
-		except:
121
-			raise ValueError(f'vars not JSON encodable - {vars}')
122
-		for key, value in vars.items():
124
+			json.dumps(values)
125
+		except Exception as e:
126
+			raise ValueError(f'values not JSON encodable - {values}') from e
127
+		for key, value in values.items():
123
 			if value is None:
128
 			if value is None:
124
 				del config[key]
129
 				del config[key]
125
 			else:
130
 			else:
134
 		path: str = cls.__guild_config_path(guild)
139
 		path: str = cls.__guild_config_path(guild)
135
 		cls.__trace(f'Saving config for guild {guild.id} to {path}')
140
 		cls.__trace(f'Saving config for guild {guild.id} to {path}')
136
 		cls.__trace(f'config = {config}')
141
 		cls.__trace(f'config = {config}')
137
-		with open(path, 'w') as file:
142
+		with open(path, 'w', encoding='utf8') as file:
138
 			# Pretty printing to make more legible for debugging
143
 			# Pretty printing to make more legible for debugging
139
 			# Sorting keys to help with diffs
144
 			# Sorting keys to help with diffs
140
 			json.dump(config, file, indent='\t', sort_keys=True)
145
 			json.dump(config, file, indent='\t', sort_keys=True)
151
 			cls.__trace(f'No config on disk for guild {guild.id}. Returning None.')
156
 			cls.__trace(f'No config on disk for guild {guild.id}. Returning None.')
152
 			return None
157
 			return None
153
 		cls.__trace(f'Loading config from disk for guild {guild.id}')
158
 		cls.__trace(f'Loading config from disk for guild {guild.id}')
154
-		with open(path, 'r') as file:
159
+		with open(path, 'r', encoding='utf8') as file:
155
 			config = json.load(file)
160
 			config = json.load(file)
156
 		cls.__trace('State loaded')
161
 		cls.__trace('State loaded')
157
 		return config
162
 		return config

Notiek ielāde…
Atcelt
Saglabāt