Переглянути джерело

PyCharm linter stuff. URLs now handled better in edit log message diffs

master
Rocketsoup 2 місяці тому
джерело
коміт
f064428bd2

+ 3
- 0
.gitignore Переглянути файл

@@ -120,3 +120,6 @@ dmypy.json
120 120
 # Rocketbot stuff
121 121
 /config.py
122 122
 /config/**
123
+
124
+# PyCharm
125
+.idea

+ 5
- 0
README.md Переглянути файл

@@ -2,6 +2,11 @@
2 2
 
3 3
 Experimental Discord bot written in Python.
4 4
 
5
+## Requirements
6
+
7
+* Written for Python 3.9
8
+* Install dependencies with `pip3.9 install -r requirements.txt`
9
+
5 10
 ## Usage
6 11
 
7 12
 * To see the list of commands, type `$rb_help`.

+ 1
- 0
requirements.txt Переглянути файл

@@ -0,0 +1 @@
1
+discord.py == 2.3.2

+ 2
- 1
rocketbot/botmessage.py Переглянути файл

@@ -317,7 +317,8 @@ class BotMessage:
317 317
 		s += self.text
318 318
 
319 319
 		if self.quote:
320
-			s += f'\n\n> {self.quote}'
320
+			quoted = '\n> '.join(self.quote.splitlines())
321
+			s += f'\n\n> {quoted}'
321 322
 
322 323
 		if len(self.__reactions) > 0:
323 324
 			s += '\n\nAvailable actions:'

+ 8
- 8
rocketbot/cogs/autokickcog.py Переглянути файл

@@ -9,7 +9,6 @@ from config import CONFIG
9 9
 from rocketbot.cogs.basecog import BaseCog, BotMessage, CogSetting
10 10
 from rocketbot.collections import AgeBoundDict
11 11
 from rocketbot.storage import Storage
12
-from rocketbot.utils import bot_log
13 12
 
14 13
 class AutoKickContext:
15 14
 	"""
@@ -83,7 +82,7 @@ class AutoKickCog(BaseCog, name='Auto Kick'):
83 82
 			self.log(guild, f'New member {member.name} status is {member.status}')
84 83
 			self.status_check_members.append(StatusCheckContext(member))
85 84
 			return
86
-		self.__kick_or_ban_if_needed(member)
85
+		await self.__kick_or_ban_if_needed(member)
87 86
 
88 87
 	@tasks.loop(seconds=5.0)
89 88
 	async def status_check_timer(self):
@@ -148,10 +147,11 @@ class AutoKickCog(BaseCog, name='Auto Kick'):
148 147
 	@staticmethod
149 148
 	def ordinal(val: int):
150 149
 		'Formats an integer with an ordinal suffix (English only)'
151
-		if val % 10 == 1:
152
-			return f'{val}st'
153
-		if val % 10 == 2:
154
-			return f'{val}nd'
155
-		if val % 10 == 3:
156
-			return f'{val}rd'
150
+		if val % 100 < 10 or val % 100 > 20:
151
+			if val % 10 == 1:
152
+				return f'{val}st'
153
+			if val % 10 == 2:
154
+				return f'{val}nd'
155
+			if val % 10 == 3:
156
+				return f'{val}rd'
157 157
 		return f'{val}th'

+ 11
- 6
rocketbot/cogs/basecog.py Переглянути файл

@@ -1,8 +1,10 @@
1 1
 """
2 2
 Base cog class and helper classes.
3 3
 """
4
-from datetime import datetime, timedelta
5
-from discord import Guild, Member, Message, RawReactionActionEvent
4
+from datetime import datetime, timedelta, timezone
5
+from typing import Optional
6
+
7
+from discord import Guild, Member, Message, RawReactionActionEvent, TextChannel
6 8
 from discord.abc import GuildChannel
7 9
 from discord.ext import commands
8 10
 
@@ -158,7 +160,7 @@ class BaseCog(commands.Cog):
158 160
 	def __bot_messages(cls, guild: Guild) -> AgeBoundDict:
159 161
 		bm = Storage.get_state_value(guild, 'bot_messages')
160 162
 		if bm is None:
161
-			far_future = datetime.utcnow() + timedelta(days=1000)
163
+			far_future = datetime.now(timezone.utc) + timedelta(days=1000)
162 164
 			bm = AgeBoundDict(timedelta(seconds=600),
163 165
 				lambda k, v : v.message_sent_at() or far_future)
164 166
 			Storage.set_state_value(guild, 'bot_messages', bm)
@@ -204,10 +206,13 @@ class BaseCog(commands.Cog):
204 206
 			# Can't use this reaction with this message
205 207
 			return
206 208
 
207
-		channel: GuildChannel = guild.get_channel(payload.channel_id) or await guild.fetch_channel(payload.channel_id)
208
-		if channel is None:
209
+		g_channel: GuildChannel = guild.get_channel(payload.channel_id) or await guild.fetch_channel(payload.channel_id)
210
+		if g_channel is None:
209 211
 			# Possibly a DM
210 212
 			return
213
+		if not isinstance(g_channel, TextChannel):
214
+			return
215
+		channel: TextChannel = g_channel
211 216
 		member: Member = payload.member
212 217
 		if member is None:
213 218
 			return
@@ -237,7 +242,7 @@ class BaseCog(commands.Cog):
237 242
 	# Helpers
238 243
 
239 244
 	@classmethod
240
-	def log(cls, guild: Guild, message) -> None:
245
+	def log(cls, guild: Optional[Guild], message) -> None:
241 246
 		"""
242 247
 		Writes a message to the console. Intended for significant events only.
243 248
 		"""

+ 5
- 5
rocketbot/cogs/configcog.py Переглянути файл

@@ -18,7 +18,7 @@ class ConfigCog(BaseCog, name='Configuration'):
18 18
 	@commands.has_permissions(ban_members=True)
19 19
 	@commands.guild_only()
20 20
 	async def config(self, context: commands.Context):
21
-		'General guild configuration command group'
21
+		"""General guild configuration command group"""
22 22
 		if context.invoked_subcommand is None:
23 23
 			await context.send_help()
24 24
 
@@ -31,7 +31,7 @@ class ConfigCog(BaseCog, name='Configuration'):
31 31
 			'will not be posted!',
32 32
 	)
33 33
 	async def setwarningchannel(self, context: commands.Context) -> None:
34
-		'Command handler'
34
+		"""Command handler"""
35 35
 		guild: Guild = context.guild
36 36
 		channel: TextChannel = context.channel
37 37
 		Storage.set_config_value(guild, ConfigKey.WARNING_CHANNEL_ID,
@@ -46,7 +46,7 @@ class ConfigCog(BaseCog, name='Configuration'):
46 46
 			'warnings will be posted.',
47 47
 	)
48 48
 	async def getwarningchannel(self, context: commands.Context) -> None:
49
-		'Command handler'
49
+		"""Command handler"""
50 50
 		guild: Guild = context.guild
51 51
 		channel_id = Storage.get_config_value(guild, ConfigKey.WARNING_CHANNEL_ID)
52 52
 		if channel_id is None:
@@ -70,7 +70,7 @@ class ConfigCog(BaseCog, name='Configuration'):
70 70
 	async def setwarningmention(self,
71 71
 			context: commands.Context,
72 72
 			mention: str = None) -> None:
73
-		'Command handler'
73
+		"""Command handler"""
74 74
 		guild: Guild = context.guild
75 75
 		Storage.set_config_value(guild, ConfigKey.WARNING_MENTION, mention)
76 76
 		if mention is None:
@@ -88,7 +88,7 @@ class ConfigCog(BaseCog, name='Configuration'):
88 88
 			'warning messages.'
89 89
 	)
90 90
 	async def getwarningmention(self, context: commands.Context) -> None:
91
-		'Command handler'
91
+		"""Command handler"""
92 92
 		guild: Guild = context.guild
93 93
 		mention: str = Storage.get_config_value(guild, ConfigKey.WARNING_MENTION)
94 94
 		if mention is None:

+ 1
- 1
rocketbot/cogs/crosspostcog.py Переглянути файл

@@ -172,7 +172,7 @@ class CrossPostCog(BaseCog, name='Crosspost Detection'):
172 172
 			message_type: int = BotMessage.TYPE_INFO if self.was_warned_recently(context.member) \
173 173
 				else BotMessage.TYPE_MOD_WARNING
174 174
 			message = BotMessage(context.member.guild, '', message_type, context)
175
-			message.quote = discordutils.remove_markdown(first_spam_message.clean_content)
175
+			message.quote = discordutils.remove_markdown(first_spam_message.clean_content())
176 176
 			self.record_warning(context.member)
177 177
 		if context.is_autobanned:
178 178
 			text = f'User {context.member.mention} auto banned for ' + \

+ 5
- 3
rocketbot/cogs/generalcog.py Переглянути файл

@@ -2,7 +2,9 @@
2 2
 Cog for handling most ungrouped commands and basic behaviors.
3 3
 """
4 4
 import re
5
-from datetime import datetime, timedelta
5
+from datetime import datetime, timedelta, timezone
6
+from typing import Optional
7
+
6 8
 from discord import Message
7 9
 from discord.errors import DiscordException
8 10
 from discord.ext import commands
@@ -122,7 +124,7 @@ class GeneralCog(BaseCog, name='General'):
122 124
 				f'{CONFIG["failure_emoji"]} age must be a timespan, like "30s", "10m", "1h30m"',
123 125
 				mention_author=False)
124 126
 			return
125
-		cutoff: datetime = datetime.utcnow() - age_delta
127
+		cutoff: datetime = datetime.now(timezone.utc) - age_delta
126 128
 		def predicate(message: Message) -> bool:
127 129
 			return str(message.author.id) == member_id and message.created_at >= cutoff
128 130
 		deleted_messages = []
@@ -137,7 +139,7 @@ class GeneralCog(BaseCog, name='General'):
137 139
 			f'messages by <@!{member_id}> from the past {describe_timedelta(age_delta)}.',
138 140
 			mention_author=False)
139 141
 
140
-	def __parse_member_id(self, arg: str) -> str:
142
+	def __parse_member_id(self, arg: str) -> Optional[str]:
141 143
 		p = re.compile('^<@!?([0-9]+)>$')
142 144
 		m = p.match(arg)
143 145
 		if m:

+ 6
- 6
rocketbot/cogs/joinagecog.py Переглянути файл

@@ -1,6 +1,6 @@
1 1
 import weakref
2 2
 
3
-from datetime import datetime, timedelta
3
+from datetime import datetime, timedelta, timezone
4 4
 from discord import Guild, Member
5 5
 from discord.ext import commands
6 6
 
@@ -47,7 +47,7 @@ class JoinAgeCog(BaseCog, name='Join Age'):
47 47
 	@commands.has_permissions(ban_members=True)
48 48
 	@commands.guild_only()
49 49
 	async def joinage(self, context: commands.Context):
50
-		'Join age tracking'
50
+		"""Join age tracking"""
51 51
 		if context.invoked_subcommand is None:
52 52
 			await context.send_help()
53 53
 
@@ -58,16 +58,16 @@ class JoinAgeCog(BaseCog, name='Join Age'):
58 58
 		usage='<time_period>'
59 59
 	)
60 60
 	async def search(self, context: commands.Context, timespan: str):
61
-		'Command handler'
61
+		"""Command handler"""
62 62
 		guild: Guild = context.guild
63 63
 		recent_joins: AgeBoundList = Storage.get_state_value(guild, self.STATE_KEY_RECENT_JOINS)
64 64
 		if recent_joins is None:
65 65
 			max_age: timedelta = timedelta(seconds=self.get_guild_setting(guild, self.SETTING_JOIN_TIME))
66
-			recent_joins = AgeBoundList(max_age, lambda i, member : member.joined_at)
66
+			recent_joins = AgeBoundList(max_age, lambda i, member0 : member0.joined_at)
67 67
 			Storage.set_state_value(guild, self.STATE_KEY_RECENT_JOINS, recent_joins)
68 68
 		results: list = []
69 69
 		ts: timedelta = timedelta_from_str(timespan)
70
-		cutoff: datetime = datetime.utcnow() - ts
70
+		cutoff: datetime = datetime.now(timezone.utc) - ts
71 71
 		for member in recent_joins:
72 72
 			if member.joined_at > cutoff:
73 73
 				results.append(member)
@@ -106,7 +106,7 @@ class JoinAgeCog(BaseCog, name='Join Age'):
106 106
 
107 107
 	@commands.Cog.listener()
108 108
 	async def on_member_join(self, member: Member) -> None:
109
-		'Event handler'
109
+		"""Event handler"""
110 110
 		guild: Guild = member.guild
111 111
 		if not self.get_guild_setting(guild, self.SETTING_ENABLED):
112 112
 			return

+ 3
- 3
rocketbot/cogs/joinraidcog.py Переглянути файл

@@ -22,7 +22,7 @@ class JoinRaidContext:
22 22
 		self.warning_message_ref = None
23 23
 
24 24
 	def last_join_time(self) -> datetime:
25
-		'Returns when the most recent member join was, in UTC'
25
+		"""Returns when the most recent member join was, in UTC"""
26 26
 		return self.join_members[-1].joined_at
27 27
 
28 28
 class JoinRaidCog(BaseCog, name='Join Raids'):
@@ -62,7 +62,7 @@ class JoinRaidCog(BaseCog, name='Join Raids'):
62 62
 	@commands.has_permissions(ban_members=True)
63 63
 	@commands.guild_only()
64 64
 	async def joinraid(self, context: commands.Context):
65
-		'Join raid detection command group'
65
+		"""Join raid detection command group"""
66 66
 		if context.invoked_subcommand is None:
67 67
 			await context.send_help()
68 68
 
@@ -92,7 +92,7 @@ class JoinRaidCog(BaseCog, name='Join Raids'):
92 92
 
93 93
 	@commands.Cog.listener()
94 94
 	async def on_member_join(self, member: Member) -> None:
95
-		'Event handler'
95
+		"""Event handler"""
96 96
 		guild: Guild = member.guild
97 97
 		if not self.get_guild_setting(guild, self.SETTING_ENABLED):
98 98
 			return

+ 30
- 9
rocketbot/cogs/logcog.py Переглянути файл

@@ -1,21 +1,17 @@
1 1
 """
2 2
 Cog for detecting large numbers of guild joins in a short period of time.
3 3
 """
4
-import weakref
5 4
 from collections.abc import Sequence
6
-from datetime import datetime, timedelta
5
+from datetime import datetime
7 6
 from discord import AuditLogAction, AuditLogEntry, Emoji, Guild, GuildSticker, Invite, Member, Message, RawBulkMessageDeleteEvent, RawMessageDeleteEvent, RawMessageUpdateEvent, Role, Thread, User
8 7
 from discord.abc import GuildChannel
9 8
 from discord.ext import commands
10 9
 from discord.utils import escape_markdown
11
-from typing import List, Optional, Tuple, Union
10
+from typing import Optional, Tuple, Union
12 11
 import difflib
13
-import traceback
12
+import re
14 13
 
15
-from config import CONFIG
16
-from rocketbot.cogs.basecog import BaseCog, BotMessage, BotMessageReaction, CogSetting
17
-from rocketbot.collections import AgeBoundList
18
-from rocketbot.storage import Storage
14
+from rocketbot.cogs.basecog import BaseCog, BotMessage, CogSetting
19 15
 
20 16
 class LoggingCog(BaseCog, name='Logging'):
21 17
 	"""
@@ -457,7 +453,7 @@ class LoggingCog(BaseCog, name='Logging'):
457 453
 			# Most likely an embed being asynchronously populated by server
458 454
 			return
459 455
 		if content_changed:
460
-			(before_markdown, after_markdown) = self.__diff(self.__quote_markdown(before.content), \
456
+			(before_markdown, after_markdown) = self.__diff(self.__quote_markdown(before.content),
461 457
 													  self.__quote_markdown(after.content))
462 458
 		else:
463 459
 			before_markdown = self.__quote_markdown(before.content)
@@ -722,6 +718,20 @@ class LoggingCog(BaseCog, name='Logging'):
722 718
 		return f'**{user.name}** ({user.display_name} {user.id})'
723 719
 
724 720
 	def __diff(self, a: str, b: str) -> Tuple[str, str]:
721
+		# URLs don't work well in the diffs. Replace them with private use characters, one per unique URL.
722
+		preserved_sequences = []
723
+		def sub_token(match: re.Match) -> str:
724
+			seq = match.group(0)
725
+			sequence_index = len(preserved_sequences)
726
+			if seq in preserved_sequences:
727
+				sequence_index = preserved_sequences.index(seq)
728
+			else:
729
+				preserved_sequences.append(seq)
730
+			return chr(0xe000 + sequence_index)
731
+		url_regex = r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
732
+		a = re.sub(url_regex, sub_token, a)
733
+		b = re.sub(url_regex, sub_token, b)
734
+
725 735
 		deletion_start = '~~'
726 736
 		deletion_end = '~~'
727 737
 		addition_start = '**'
@@ -758,4 +768,15 @@ class LoggingCog(BaseCog, name='Logging'):
758 768
 			markdown_a += deletion_end
759 769
 		if b_open:
760 770
 			markdown_b += addition_end
771
+
772
+		# Sub URLs back in
773
+		def unsub_token(match: re.Match) -> str:
774
+			char = match.group(0)
775
+			index = ord(char) - 0xe000
776
+			if 0 <= index < len(preserved_sequences):
777
+				return preserved_sequences[index]
778
+			return char
779
+		markdown_a = re.sub(r'[\ue000-\uefff]', unsub_token, markdown_a)
780
+		markdown_b = re.sub(r'[\ue000-\uefff]', unsub_token, markdown_b)
781
+
761 782
 		return (markdown_a, markdown_b)

+ 7
- 5
rocketbot/cogs/patterncog.py Переглянути файл

@@ -3,6 +3,8 @@ Cog for matching messages against guild-configurable criteria and taking
3 3
 automated actions on them.
4 4
 """
5 5
 from datetime import datetime
6
+from typing import Optional
7
+
6 8
 from discord import Guild, Member, Message, utils as discordutils
7 9
 from discord.ext import commands
8 10
 
@@ -60,19 +62,19 @@ class PatternCog(BaseCog, name='Pattern Matching'):
60 62
 	def __save_patterns(cls,
61 63
 			guild: Guild,
62 64
 			patterns: dict[str, PatternStatement]) -> None:
63
-		to_save: list[dict] = list(map(PatternStatement.to_json, patterns.values()))
65
+		to_save: list[dict] = list(map(lambda ps: ps.to_json(), patterns.values()))
64 66
 		cls.set_guild_setting(guild, cls.SETTING_PATTERNS, to_save)
65 67
 
66 68
 	@classmethod
67
-	def __get_last_matched(cls, guild: Guild, name: str) -> datetime:
68
-		last_matched: dict[name, datetime] = Storage.get_state_value(guild, 'PatternCog.last_matched')
69
+	def __get_last_matched(cls, guild: Guild, name: str) -> Optional[datetime]:
70
+		last_matched: dict[str, datetime] = Storage.get_state_value(guild, 'PatternCog.last_matched')
69 71
 		if last_matched:
70 72
 			return last_matched.get(name)
71 73
 		return None
72 74
 
73 75
 	@classmethod
74 76
 	def __set_last_matched(cls, guild: Guild, name: str, time: datetime) -> None:
75
-		last_matched: dict[name, datetime] = Storage.get_state_value(guild, 'PatternCog.last_matched')
77
+		last_matched: dict[str, datetime] = Storage.get_state_value(guild, 'PatternCog.last_matched')
76 78
 		if last_matched is None:
77 79
 			last_matched = {}
78 80
 			Storage.set_state_value(guild, 'PatternCog.last_matched', last_matched)
@@ -156,7 +158,7 @@ class PatternCog(BaseCog, name='Pattern Matching'):
156 158
 				type=message_type,
157 159
 				context=context)
158 160
 			self.record_warning(message.author)
159
-			bm.quote = discordutils.remove_markdown(message.clean_content)
161
+			bm.quote = discordutils.remove_markdown(message.clean_content())
160 162
 			await bm.set_reactions(BotMessageReaction.standard_set(
161 163
 				did_delete=context.is_deleted,
162 164
 				did_kick=context.is_kicked,

+ 9
- 9
rocketbot/cogs/urlspamcog.py Переглянути файл

@@ -33,7 +33,7 @@ class URLSpamCog(BaseCog, name='URL Spam'):
33 33
 	SETTING_ACTION = CogSetting('action', str,
34 34
 			brief='action to take on spam',
35 35
 			description='The action to take on detected URL spam.',
36
-			enum_values=set(['nothing', 'modwarn', 'delete', 'kick', 'ban']))
36
+			enum_values={'nothing', 'modwarn', 'delete', 'kick', 'ban'})
37 37
 	SETTING_JOIN_AGE = CogSetting('joinage', float,
38 38
 			brief='seconds since member joined',
39 39
 			description='The minimum seconds since the user joined the ' + \
@@ -48,8 +48,8 @@ class URLSpamCog(BaseCog, name='URL Spam'):
48 48
 			brief='action to take on deceptive link markdown',
49 49
 			description='The action to take on chat messages with links ' + \
50 50
 				'where the text looks like a different URL than the actual link.',
51
-			enum_values=set(['nothing', 'modwarn', 'modwarndelete', \
52
-				'chatwarn', 'chatwarndelete', 'delete', 'kick', 'ban']))
51
+			enum_values={'nothing', 'modwarn', 'modwarndelete',
52
+				'chatwarn', 'chatwarndelete', 'delete', 'kick', 'ban'})
53 53
 
54 54
 	def __init__(self, bot):
55 55
 		super().__init__(bot)
@@ -79,8 +79,8 @@ class URLSpamCog(BaseCog, name='URL Spam'):
79 79
 			return
80 80
 		if not self.get_guild_setting(message.guild, self.SETTING_ENABLED):
81 81
 			return
82
-		await self.check_message_recency(message);
83
-		await self.check_deceptive_links(message);
82
+		await self.check_message_recency(message)
83
+		await self.check_deceptive_links(message)
84 84
 
85 85
 	async def check_message_recency(self, message: Message):
86 86
 		'Checks if the message was sent too recently by a new user'
@@ -132,7 +132,7 @@ class URLSpamCog(BaseCog, name='URL Spam'):
132 132
 					f'{join_age_str} after joining.',
133 133
 					type = BotMessage.TYPE_MOD_WARNING if needs_attention else BotMessage.TYPE_INFO,
134 134
 					context = context)
135
-			bm.quote = discordutils.remove_markdown(message.clean_content)
135
+			bm.quote = discordutils.remove_markdown(message.clean_content())
136 136
 			await bm.set_reactions(BotMessageReaction.standard_set(
137 137
 				did_delete=context.is_deleted,
138 138
 				did_kick=context.is_kicked,
@@ -197,7 +197,7 @@ class URLSpamCog(BaseCog, name='URL Spam'):
197 197
 		# Strip markdown that can safely contain URL sequences
198 198
 		content = re.sub(r'`[^`]+`', '', content)  # `inline code`
199 199
 		content = re.sub(r'```.+?```', '', content, re.DOTALL)  # ``` code block ```
200
-		matches = re.findall(r'\[([^\]]+)\]\(([^\)]+)\)', content)
200
+		matches = re.findall(r'\[([^]]+)]\(([^)]+)\)', content)
201 201
 		for match in matches:
202 202
 			original_label: str = match[0].strip()
203 203
 			original_link: str = match[1].strip()
@@ -246,7 +246,7 @@ class URLSpamCog(BaseCog, name='URL Spam'):
246 246
 		port_pattern = '(?::[0-9]+)?'
247 247
 		path_pattern = r'(?:/[^ \]\)]*)?'
248 248
 		pattern = r'^' + host_pattern + port_pattern + path_pattern + '$'
249
-		return re.match(pattern, s, re.IGNORECASE) != None
249
+		return re.match(pattern, s, re.IGNORECASE) is not None
250 250
 
251 251
 	async def on_mod_react(self,
252 252
 			bot_message: BotMessage,
@@ -291,5 +291,5 @@ class URLSpamCog(BaseCog, name='URL Spam'):
291 291
 
292 292
 	@classmethod
293 293
 	def __contains_url(cls, text: str) -> bool:
294
-		p = re.compile(r'http(?:s)?://[^\s]+')
294
+		p = re.compile(r'https?://\S+')
295 295
 		return p.search(text) is not None

+ 5
- 3
rocketbot/cogs/usernamecog.py Переглянути файл

@@ -1,6 +1,8 @@
1 1
 """
2 2
 Cog for detecting username patterns.
3 3
 """
4
+from typing import Optional
5
+
4 6
 from discord import Guild, Member
5 7
 from discord.ext import commands
6 8
 
@@ -14,9 +16,9 @@ class UsernamePatternContext:
14 16
 	"""
15 17
 	def __init__(self, member: Member) -> None:
16 18
 		self.member: Member = member
17
-		self.kicked_by: Member = None
18
-		self.banned_by: Member = None
19
-		self.ignored_by: Member = None
19
+		self.kicked_by: Optional[Member] = None
20
+		self.banned_by: Optional[Member] = None
21
+		self.ignored_by: Optional[Member] = None
20 22
 
21 23
 	def reactions(self) -> list[BotMessageReaction]:
22 24
 		"""

+ 44
- 45
rocketbot/cogsetting.py Переглянути файл

@@ -1,16 +1,24 @@
1 1
 """
2 2
 A guild configuration setting available for editing via bot commands.
3 3
 """
4
-from types import coroutine
5
-from typing import Any, Optional, Type
4
+from typing import Any, Callable, Coroutine, Optional, Type
6 5
 
7 6
 from discord.ext import commands
8
-from discord.ext.commands import Bot, Cog, Command, Context, Group
7
+from discord.ext.commands import Bot, Command, Context, Group, Cog
9 8
 
10 9
 from config import CONFIG
11 10
 from rocketbot.storage import Storage
12 11
 from rocketbot.utils import first_command_group
13 12
 
13
+def _fix_command(command: Command) -> None:
14
+	"""
15
+	HACK: Fixes bug in discord.py 2.3.2 where it's requiring the user to
16
+	supply the context argument. This removes that argument from the list.
17
+	"""
18
+	params = command.params
19
+	del params['context']
20
+	command.params = params
21
+
14 22
 class CogSetting:
15 23
 	"""
16 24
 	Describes a configuration setting for a guild that can be edited by the
@@ -20,7 +28,7 @@ class CogSetting:
20 28
 	"""
21 29
 	def __init__(self,
22 30
 			name: str,
23
-			datatype: Type,
31
+			datatype: Optional[Type],
24 32
 			brief: Optional[str] = None,
25 33
 			description: Optional[str] = None,
26 34
 			usage: Optional[str] = None,
@@ -90,14 +98,14 @@ class CogSetting:
90 98
 
91 99
 	def __make_getter_command(self, cog: Cog) -> Command:
92 100
 		setting: CogSetting = self
93
-		async def getter(cog: Cog, context: Context) -> None:
101
+		async def getter(cog0: Cog, context: Context) -> None:
94 102
 			setting_name = setting.name
95
-			if context.command.parent:
103
+			if isinstance(context.command.parent, Group):
96 104
 				setting_name = f'{context.command.parent.name}.{setting_name}'
97
-			key = f'{cog.__class__.__name__}.{setting.name}'
105
+			key = f'{cog0.__class__.__name__}.{setting.name}'
98 106
 			value = Storage.get_config_value(context.guild, key)
99 107
 			if value is None:
100
-				value = cog.get_cog_default(setting.name)
108
+				value = cog0.get_cog_default(setting.name)
101 109
 				await context.message.reply(
102 110
 					f'{CONFIG["info_emoji"]} `{setting_name}` is using default of `{value}`',
103 111
 					mention_author=False)
@@ -115,12 +123,12 @@ class CogSetting:
115 123
 				commands.guild_only(),
116 124
 			])
117 125
 		command.cog = cog
118
-		self.__fix_command(command)
126
+		_fix_command(command)
119 127
 		return command
120 128
 
121 129
 	def __make_setter_command(self, cog: Cog) -> Command:
122 130
 		setting: CogSetting = self
123
-		async def setter_common(cog: Cog, context: Context, new_value) -> None:
131
+		async def setter_common(cog0: Cog, context: Context, new_value) -> None:
124 132
 			try:
125 133
 				setting.validate_value(new_value)
126 134
 			except ValueError as ve:
@@ -129,26 +137,26 @@ class CogSetting:
129 137
 					mention_author=False)
130 138
 				return
131 139
 			setting_name = setting.name
132
-			if context.command.parent:
140
+			if isinstance(context.command.parent, Group):
133 141
 				setting_name = f'{context.command.parent.name}.{setting_name}'
134
-			key = f'{cog.__class__.__name__}.{setting.name}'
142
+			key = f'{cog0.__class__.__name__}.{setting.name}'
135 143
 			Storage.set_config_value(context.guild, key, new_value)
136 144
 			await context.message.reply(
137 145
 				f'{CONFIG["success_emoji"]} `{setting_name}` is now set to `{new_value}`',
138 146
 				mention_author=False)
139
-			await cog.on_setting_updated(context.guild, setting)
140
-			cog.log(context.guild, f'{context.author.name} set {key} to {new_value}')
141
-
142
-		async def setter_int(cog, context, new_value: int):
143
-			await setter_common(cog, context, new_value)
144
-		async def setter_float(cog, context, new_value: float):
145
-			await setter_common(cog, context, new_value)
146
-		async def setter_str(cog, context, new_value: str):
147
-			await setter_common(cog, context, new_value)
148
-		async def setter_bool(cog, context, new_value: bool):
149
-			await setter_common(cog, context, new_value)
150
-
151
-		setter: coroutine = None
147
+			await cog0.on_setting_updated(context.guild, setting)
148
+			cog0.log(context.guild, f'{context.author.name} set {key} to {new_value}')
149
+
150
+		async def setter_int(cog1, context, new_value: int):
151
+			await setter_common(cog1, context, new_value)
152
+		async def setter_float(cog2, context, new_value: float):
153
+			await setter_common(cog2, context, new_value)
154
+		async def setter_str(cog3, context, new_value: str):
155
+			await setter_common(cog3, context, new_value)
156
+		async def setter_bool(cog4, context, new_value: bool):
157
+			await setter_common(cog4, context, new_value)
158
+
159
+		setter: Callable[[Cog, Context, Any], Coroutine]
152 160
 		if setting.datatype == int:
153 161
 			setter = setter_int
154 162
 		elif setting.datatype == float:
@@ -173,19 +181,19 @@ class CogSetting:
173 181
 		# HACK: Passing `cog` in init gets ignored and set to `None` so set after.
174 182
 		# This ensures the callback is passed the cog as `self` argument.
175 183
 		command.cog = cog
176
-		self.__fix_command(command)
184
+		_fix_command(command)
177 185
 		return command
178 186
 
179 187
 	def __make_enable_command(self, cog: Cog) -> Command:
180 188
 		setting: CogSetting = self
181
-		async def enabler(cog: Cog, context: Context) -> None:
182
-			key = f'{cog.__class__.__name__}.{setting.name}'
189
+		async def enabler(cog0: Cog, context: Context) -> None:
190
+			key = f'{cog0.__class__.__name__}.{setting.name}'
183 191
 			Storage.set_config_value(context.guild, key, True)
184 192
 			await context.message.reply(
185 193
 				f'{CONFIG["success_emoji"]} {setting.brief.capitalize()} enabled.',
186 194
 				mention_author=False)
187
-			await cog.on_setting_updated(context.guild, setting)
188
-			cog.log(context.guild, f'{context.author.name} enabled {cog.__class__.__name__}')
195
+			await cog0.on_setting_updated(context.guild, setting)
196
+			cog0.log(context.guild, f'{context.author.name} enabled {cog0.__class__.__name__}')
189 197
 
190 198
 		command = Command(
191 199
 			enabler,
@@ -197,19 +205,19 @@ class CogSetting:
197 205
 				commands.guild_only(),
198 206
 			])
199 207
 		command.cog = cog
200
-		self.__fix_command(command)
208
+		_fix_command(command)
201 209
 		return command
202 210
 
203 211
 	def __make_disable_command(self, cog: Cog) -> Command:
204 212
 		setting: CogSetting = self
205
-		async def disabler(cog: Cog, context: Context) -> None:
206
-			key = f'{cog.__class__.__name__}.{setting.name}'
213
+		async def disabler(cog0: Cog, context: Context) -> None:
214
+			key = f'{cog0.__class__.__name__}.{setting.name}'
207 215
 			Storage.set_config_value(context.guild, key, False)
208 216
 			await context.message.reply(
209 217
 				f'{CONFIG["success_emoji"]} {setting.brief.capitalize()} disabled.',
210 218
 				mention_author=False)
211
-			await cog.on_setting_updated(context.guild, setting)
212
-			cog.log(context.guild, f'{context.author.name} disabled {cog.__class__.__name__}')
219
+			await cog0.on_setting_updated(context.guild, setting)
220
+			cog0.log(context.guild, f'{context.author.name} disabled {cog0.__class__.__name__}')
213 221
 
214 222
 		command = Command(
215 223
 			disabler,
@@ -221,18 +229,9 @@ class CogSetting:
221 229
 				commands.guild_only(),
222 230
 			])
223 231
 		command.cog = cog
224
-		self.__fix_command(command)
232
+		_fix_command(command)
225 233
 		return command
226 234
 
227
-	def __fix_command(self, command: Command) -> None:
228
-		"""
229
-		HACK: Fixes bug in discord.py 2.3.2 where it's requiring the user to
230
-		supply the context argument. This removes that argument from the list.
231
-		"""
232
-		params = command.params
233
-		del params['context']
234
-		command.params = params
235
-
236 235
 	@classmethod
237 236
 	def set_up_all(cls, cog: Cog, bot: Bot, settings: list) -> None:
238 237
 		"""

+ 1
- 1
rocketbot/pattern.py Переглянути файл

@@ -4,7 +4,7 @@ to take on them.
4 4
 """
5 5
 import re
6 6
 from abc import ABCMeta, abstractmethod
7
-from datetime, timezone import datetime
7
+from datetime import datetime, timezone
8 8
 from typing import Any
9 9
 
10 10
 from discord import Message, utils as discordutils

+ 11
- 11
rocketbot/utils.py Переглянути файл

@@ -78,8 +78,8 @@ def describe_timedelta(td: timedelta, max_components: int = 2) -> str:
78 78
 		components = components[0:max_components]
79 79
 	return ' '.join(components)
80 80
 
81
-def first_command_group(cog: Cog) -> Group:
82
-	'Returns the first command Group found in a cog.'
81
+def first_command_group(cog: Cog) -> Optional[Group]:
82
+	"""Returns the first command Group found in a cog."""
83 83
 	for member_name in dir(cog):
84 84
 		member = getattr(cog, member_name)
85 85
 		if isinstance(member, Group):
@@ -87,7 +87,7 @@ def first_command_group(cog: Cog) -> Group:
87 87
 	return None
88 88
 
89 89
 def bot_log(guild: Optional[Guild], cog_class: Optional[Type], message: Any) -> None:
90
-	'Logs a message to stdout with time, cog, and guild info.'
90
+	"""Logs a message to stdout with time, cog, and guild info."""
91 91
 	now: datetime = datetime.now() # local
92 92
 	s = f'[{now.strftime("%Y-%m-%dT%H:%M:%S")}|'
93 93
 	s += f'{cog_class.__name__}|' if cog_class else '-|'
@@ -102,38 +102,38 @@ __USER_MENTION_REGEX: re.Pattern = re.compile('^<@!([0-9]{17,20})>$')
102 102
 __ROLE_MENTION_REGEX: re.Pattern = re.compile('^<@&([0-9]{17,20})>$')
103 103
 
104 104
 def is_user_id(val: str) -> bool:
105
-	'Tests if a string is in user/role ID format.'
105
+	"""Tests if a string is in user/role ID format."""
106 106
 	return __ID_REGEX.match(val) is not None
107 107
 
108 108
 def is_mention(val: str) -> bool:
109
-	'Tests if a string is a user or role mention.'
109
+	"""Tests if a string is a user or role mention."""
110 110
 	return __MENTION_REGEX.match(val) is not None
111 111
 
112 112
 def is_role_mention(val: str) -> bool:
113
-	'Tests if a string is a role mention.'
113
+	"""Tests if a string is a role mention."""
114 114
 	return __ROLE_MENTION_REGEX.match(val) is not None
115 115
 
116 116
 def is_user_mention(val: str) -> bool:
117
-	'Tests if a string is a user mention.'
117
+	"""Tests if a string is a user mention."""
118 118
 	return __USER_MENTION_REGEX.match(val) is not None
119 119
 
120 120
 def user_id_from_mention(mention: str) -> str:
121
-	'Extracts the user ID from a mention. Raises a ValueError if malformed.'
121
+	"""Extracts the user ID from a mention. Raises a ValueError if malformed."""
122 122
 	m = __USER_MENTION_REGEX.match(mention)
123 123
 	if m:
124 124
 		return m.group(1)
125 125
 	raise ValueError(f'"{mention}" is not an @ user mention')
126 126
 
127 127
 def mention_from_user_id(user_id: Union[str, int]) -> str:
128
-	'Returns a markdown user mention from a user id.'
128
+	"""Returns a markdown user mention from a user id."""
129 129
 	return f'<@!{user_id}>'
130 130
 
131 131
 def mention_from_role_id(role_id: Union[str, int]) -> str:
132
-	'Returns a markdown role mention from a role id.'
132
+	"""Returns a markdown role mention from a role id."""
133 133
 	return f'<@&{role_id}>'
134 134
 
135 135
 def str_from_quoted_str(val: str) -> str:
136
-	'Removes the leading and trailing quotes from a string.'
136
+	"""Removes the leading and trailing quotes from a string."""
137 137
 	if len(val) < 2 or val[0:1] not in __QUOTE_CHARS or val[-1:] not in __QUOTE_CHARS:
138 138
 		raise ValueError(f'Not a quoted string: {val}')
139 139
 	return val[1:-1]

Завантаження…
Відмінити
Зберегти