瀏覽代碼

cog_refactor (#1)

tags/1.0
Rocketsoup 4 年之前
父節點
當前提交
e746ae2f78
共有 10 個檔案被更改,包括 841 行新增754 行删除
  1. 3
    3
      .gitignore
  2. 0
    0
      cogs/__init__.py
  3. 49
    0
      cogs/basecog.py
  4. 96
    0
      cogs/configcog.py
  5. 39
    0
      cogs/generalcog.py
  6. 523
    0
      cogs/joinraidcog.py
  7. 1
    0
      config.py.sample
  8. 二進制
      rocketbot.db.sample
  9. 13
    751
      rocketbot.py
  10. 117
    0
      storage.py

+ 3
- 3
.gitignore 查看文件

@@ -117,6 +117,6 @@ dmypy.json
117 117
 # Mac!!
118 118
 .DS_Store
119 119
 
120
-# Contains secrets so exclude
121
-config.py
122
-rocketbot.db
120
+# Rocketbot stuff
121
+/config.py
122
+/state/**

+ 0
- 0
cogs/__init__.py 查看文件


+ 49
- 0
cogs/basecog.py 查看文件

@@ -0,0 +1,49 @@
1
+from discord import Guild, Message, TextChannel
2
+from discord.ext import commands
3
+
4
+from storage import StateKey, Storage
5
+import json
6
+
7
+class BaseCog(commands.Cog):
8
+	def __init__(self, bot):
9
+		self.bot = bot
10
+
11
+	@classmethod
12
+	async def warn(cls, guild: Guild, message: str) -> Message:
13
+		"""
14
+		Sends a warning message to the configured warning channel for the
15
+		given guild. If no warning channel is configured no action is taken.
16
+		Returns the Message if successful or None if not.
17
+		"""
18
+		channel_id = Storage.get_state_value(guild, StateKey.WARNING_CHANNEL_ID)
19
+		if channel_id is None:
20
+			cls.guild_trace(guild, 'No warning channel set! No warning issued.')
21
+			return None
22
+		channel: TextChannel = guild.get_channel(channel_id)
23
+		if channel is None:
24
+			cls.guild_trace(guild, 'Configured warning channel does not exist!')
25
+			return None
26
+		mention: str = Storage.get_state_value(guild, StateKey.WARNING_MENTION)
27
+		text: str = message
28
+		if mention is not None:
29
+			text = f'{mention} {text}'
30
+		msg: Message = await channel.send(text)
31
+		return msg
32
+
33
+	@classmethod
34
+	async def update_warn(cls, warn_message: Message, new_text: str) -> None:
35
+		"""
36
+		Updates the text of a previously posted `warn`. Includes configured
37
+		mentions if necessary.
38
+		"""
39
+		text: str = new_text
40
+		mention: str = Storage.get_state_value(
41
+			warn_message.guild,
42
+			StateKey.WARNING_MENTION)
43
+		if mention is not None:
44
+			text = f'{mention} {text}'
45
+		await warn_message.edit(content=text)
46
+
47
+	@classmethod
48
+	def guild_trace(cls, guild: Guild, message: str) -> None:
49
+		print(f'[guild {guild.id}|{guild.name}] {message}')

+ 96
- 0
cogs/configcog.py 查看文件

@@ -0,0 +1,96 @@
1
+from discord import Guild, TextChannel
2
+from discord.ext import commands
3
+from storage import StateKey, Storage
4
+from cogs.basecog import BaseCog
5
+
6
+class ConfigCog(BaseCog):
7
+	"""
8
+	Cog for handling general bot configuration.
9
+	"""
10
+	def __init__(self, bot):
11
+		self.bot = bot
12
+
13
+	@commands.group(
14
+		brief='Manages general bot configuration'
15
+	)
16
+	@commands.has_permissions(ban_members=True)
17
+	@commands.guild_only()
18
+	async def config(self, context: commands.Context):
19
+		'Command group'
20
+		if context.invoked_subcommand is None:
21
+			await context.send_help()
22
+
23
+	@config.command(
24
+		name='setwarningchannel',
25
+		brief='Sets the channel where mod warnings are posted',
26
+		description='Run this command in the channel where bot messages ' +
27
+			'intended for server moderators should be sent. Other bot messages ' +
28
+			'may still be posted in the channel a command was invoked in. If ' +
29
+			'no output channel is set, mod-related messages will not be posted!',
30
+	)
31
+	async def config_setwarningchannel(self, context: commands.Context) -> None:
32
+		'Command handler'
33
+		guild: Guild = context.guild
34
+		channel: TextChannel = context.channel
35
+		Storage.set_state_value(guild, StateKey.WARNING_CHANNEL_ID,
36
+			context.channel.id)
37
+		await context.message.reply(
38
+			f'Warning channel updated to {channel.mention}.',
39
+			mention_author=False)
40
+
41
+	@config.command(
42
+		name='getwarningchannel',
43
+		brief='Shows the channel where mod warnings are posted',
44
+	)
45
+	async def config_getwarningchannel(self, context: commands.Context) -> None:
46
+		'Command handler'
47
+		guild: Guild = context.guild
48
+		channel_id = Storage.get_state_value(guild, StateKey.WARNING_CHANNEL_ID)
49
+		if channel_id is None:
50
+			await context.message.reply(
51
+				'No warning channel is configured.',
52
+				mention_author=False)
53
+		else:
54
+			channel = guild.get_channel(channel_id)
55
+			await context.message.reply(
56
+				f'Warning channel is configured as {channel.mention}.',
57
+				mention_author=False)
58
+
59
+	@config.command(
60
+		name='setwarningmention',
61
+		brief='Sets a user/role to mention in warning messages',
62
+		usage='<@user|@role>',
63
+		description='Configures an role or other prefix to include at the ' +
64
+			'beginning of warning messages. If the intent is to get the ' +
65
+			'attention of certain users, be sure to specify a properly ' +
66
+			'formed @ tag, not just the name of the user/role.'
67
+	)
68
+	async def config_setwarningmention(self, context: commands.Context, mention: str = None) -> None:
69
+		guild: Guild = context.guild
70
+		Storage.set_state_value(guild, StateKey.WARNING_MENTION, mention)
71
+		if mention is None:
72
+			await context.message.reply(
73
+				'Warning messages will not tag anyone.',
74
+				mention_author=False)
75
+		else:
76
+			await context.message.reply(
77
+				f'Warning messages will now tag {mention}.',
78
+				mention_author=False)
79
+
80
+	@config.command(
81
+		name='getwarningmention',
82
+		brief='Shows the user/role to mention in warning messages',
83
+		description='Shows the text, if any, that will be prefixed on any ' +
84
+			'warning messages.'
85
+	)
86
+	async def config_getwarningmention(self, context: commands.Context) -> None:
87
+		guild: Guild = context.guild
88
+		mention: str = Storage.get_state_value(guild, StateKey.WARNING_MENTION)
89
+		if mention is None:
90
+			await context.message.reply(
91
+				'No warning mention configured.',
92
+				mention_author=False)
93
+		else:
94
+			await context.message.reply(
95
+				f'Warning messages will tag {mention}',
96
+				mention_author=False)

+ 39
- 0
cogs/generalcog.py 查看文件

@@ -0,0 +1,39 @@
1
+from discord.ext import commands
2
+from cogs.basecog import BaseCog
3
+from storage import StateKey, Storage
4
+
5
+class GeneralCog(BaseCog):
6
+	def __init__(self, bot: commands.Bot):
7
+		self.bot = bot
8
+		self.is_connected = False
9
+		self.is_ready = False
10
+
11
+	@commands.Cog.listener()
12
+	async def on_connect(self):
13
+		print('on_connect')
14
+		self.is_connected = True
15
+
16
+	@commands.Cog.listener()
17
+	async def on_ready(self):
18
+		print('on_ready')
19
+		self.is_ready = True
20
+
21
+	@commands.command(
22
+		brief='Posts a test warning in the configured warning channel.'
23
+	)
24
+	@commands.has_permissions(ban_members=True)
25
+	@commands.guild_only()
26
+	async def testwarn(self, context):
27
+		if Storage.get_state_value(context.guild, StateKey.WARNING_CHANNEL_ID) is None:
28
+			await context.message.reply(
29
+				'No warning channel set!',
30
+				mention_author=False)
31
+		else:
32
+			await self.warn(context.guild,
33
+				f'Test warning message (requested by {context.author.name})')
34
+
35
+	@commands.command()
36
+	async def hello(self, context):
37
+		await context.message.reply(
38
+			f'Hey, {context.author.name}!',
39
+		 	mention_author=False)

+ 523
- 0
cogs/joinraidcog.py 查看文件

@@ -0,0 +1,523 @@
1
+from discord import Guild, Intents, Member, Message, PartialEmoji, RawReactionActionEvent
2
+from discord.ext import commands
3
+from storage import Storage
4
+from cogs.basecog import BaseCog
5
+from config import CONFIG
6
+from datetime import datetime
7
+
8
+class JoinRecord:
9
+    """
10
+    Data object containing details about a single guild join event.
11
+    """
12
+    def __init__(self, member: Member):
13
+        self.member = member
14
+        self.join_time = member.joined_at or datetime.now()
15
+        # These flags only track whether this bot has kicked/banned
16
+        self.is_kicked = False
17
+        self.is_banned = False
18
+
19
+    def age_seconds(self, now: datetime) -> float:
20
+        """
21
+        Returns the age of this join in seconds from the given "now" time.
22
+        """
23
+        a = now - self.join_time
24
+        return float(a.total_seconds())
25
+
26
+class RaidPhase:
27
+    """
28
+    Enum of phases in a JoinRaidRecord. Phases progress monotonically.
29
+    """
30
+    NONE = 0
31
+    JUST_STARTED = 1
32
+    CONTINUING = 2
33
+    ENDED = 3
34
+
35
+class JoinRaidRecord:
36
+    """
37
+    Tracks recent joins to a guild to detect join raids, where a large number
38
+    of automated users all join at the same time. Manages list of joins to not
39
+    grow unbounded.
40
+    """
41
+    def __init__(self):
42
+        self.joins = []
43
+        self.phase = RaidPhase.NONE
44
+        # datetime when the raid started, or None.
45
+        self.raid_start_time = None
46
+        # Message posted to Discord to warn of the raid. Convenience property
47
+        # managed by caller. Ignored by this class.
48
+        self.warning_message = None
49
+
50
+    def handle_join(self,
51
+            member: Member,
52
+            now: datetime,
53
+            max_join_count: int,
54
+            max_age_seconds: float) -> None:
55
+        """
56
+        Processes a new member join to a guild and detects join raids. Updates
57
+        self.phase and self.raid_start_time properties.
58
+        """
59
+        # Check for existing record for this user
60
+        print(f'handle_join({member.name}) start')
61
+        join: JoinRecord = None
62
+        i: int = 0
63
+        while i < len(self.joins):
64
+            elem = self.joins[i]
65
+            if elem.member.id == member.id:
66
+                print(f'Member {member.name} already in join list at index {i}. Removing.')
67
+                join = self.joins.pop(i)
68
+                join.join_time = now
69
+                break
70
+            i += 1
71
+        # Add new record to end
72
+        self.joins.append(join or JoinRecord(member))
73
+        # Check raid status and do upkeep
74
+        self.__process_joins(now, max_age_seconds, max_join_count)
75
+        print(f'handle_join({member.name}) end')
76
+
77
+    def __process_joins(self,
78
+            now: datetime,
79
+            max_age_seconds: float,
80
+            max_join_count: int) -> None:
81
+        """
82
+        Processes self.joins after each addition, detects raids, updates
83
+        self.phase, and throws out unneeded records.
84
+        """
85
+        print('__process_joins {')
86
+        i: int = 0
87
+        recent_count: int = 0
88
+        should_cull: bool = self.phase == RaidPhase.NONE
89
+        while i < len(self.joins):
90
+            join: JoinRecord = self.joins[i]
91
+            age: float = join.age_seconds(now)
92
+            is_old: bool = age > max_age_seconds
93
+            if not is_old:
94
+                recent_count += 1
95
+                print(f'- {i}. {join.member.name} is {age}s old - recent_count={recent_count}')
96
+            if is_old and should_cull:
97
+                self.joins.pop(i)
98
+                print(f'- {i}. {join.member.name} is {age}s old - too old, removing')
99
+            else:
100
+                print(f'- {i}. {join.member.name} is {age}s old - moving on to next')
101
+                i += 1
102
+        is_raid = recent_count > max_join_count
103
+        print(f'- is_raid {is_raid}')
104
+        if is_raid:
105
+            if self.phase == RaidPhase.NONE:
106
+                self.phase = RaidPhase.JUST_STARTED
107
+                self.raid_start_time = now
108
+                print('- Phase moved to JUST_STARTED. Recording raid start time.')
109
+            elif self.phase == RaidPhase.JUST_STARTED:
110
+                self.phase = RaidPhase.CONTINUING
111
+                print('- Phase moved to CONTINUING.')
112
+        elif self.phase == self.phase in (RaidPhase.JUST_STARTED, RaidPhase.CONTINUING):
113
+            self.phase = RaidPhase.ENDED
114
+            print('- Phase moved to ENDED.')
115
+
116
+        # Undo join add if the raid is over
117
+        if self.phase == RaidPhase.ENDED and len(self.joins) > 0:
118
+            last = self.joins.pop(-1)
119
+            print(f'- Popping last join for {last.member.name}')
120
+        print('} __process_joins')
121
+
122
+    async def kick_all(self,
123
+            reason: str = "Part of join raid") -> list[Member]:
124
+        """
125
+        Kicks all users in this join raid. Skips users who have already been
126
+        flagged as having been kicked or banned. Returns a List of Members
127
+        who were newly kicked.
128
+        """
129
+        kicks = []
130
+        for join in self.joins:
131
+            if join.is_kicked or join.is_banned:
132
+                continue
133
+            await join.member.kick(reason=reason)
134
+            join.is_kicked = True
135
+            kicks.append(join.member)
136
+        self.phase = RaidPhase.ENDED
137
+        return kicks
138
+
139
+    async def ban_all(self,
140
+            reason: str = "Part of join raid",
141
+            delete_message_days: int = 0) -> list[Member]:
142
+        """
143
+        Bans all users in this join raid. Skips users who have already been
144
+        flagged as having been banned. Users who were previously kicked can
145
+        still be banned. Returns a List of Members who were newly banned.
146
+        """
147
+        bans = []
148
+        for join in self.joins:
149
+            if join.is_banned:
150
+                continue
151
+            await join.member.ban(
152
+                reason=reason,
153
+                delete_message_days=delete_message_days)
154
+            join.is_banned = True
155
+            bans.append(join.member)
156
+        self.phase = RaidPhase.ENDED
157
+        return bans
158
+
159
+class GuildContext:
160
+    """
161
+    Logic and state for a single guild serviced by the bot.
162
+    """
163
+    def __init__(self, guild_id: int):
164
+        self.guild_id = guild_id
165
+        self.join_warning_count = CONFIG['joinWarningCount']
166
+        self.join_warning_seconds = CONFIG['joinWarningSeconds']
167
+        # Non-persisted runtime state
168
+        self.current_raid = JoinRaidRecord()
169
+        self.all_raids = [ self.current_raid ] # periodically culled of old ones
170
+
171
+    # Events
172
+
173
+    async def handle_join(self, member: Member) -> None:
174
+        """
175
+        Event handler for all joins to this guild.
176
+        """
177
+        now = member.joined_at
178
+        raid = self.current_raid
179
+        raid.handle_join(
180
+            member,
181
+            now=now,
182
+            max_age_seconds = self.join_warning_seconds,
183
+            max_join_count = self.join_warning_count)
184
+        self.__trace(f'raid phase: {raid.phase}')
185
+        if raid.phase == RaidPhase.JUST_STARTED:
186
+            await self.__on_join_raid_begin(raid)
187
+        elif raid.phase == RaidPhase.CONTINUING:
188
+            await self.__on_join_raid_updated(raid)
189
+        elif raid.phase == RaidPhase.ENDED:
190
+            self.__start_new_raid(member)
191
+            await self.__on_join_raid_end(raid)
192
+        self.__cull_old_raids(now)
193
+
194
+    def reset_raid(self, now: datetime):
195
+        """
196
+        Retires self.current_raid and creates a new empty one.
197
+        """
198
+        self.current_raid = JoinRaidRecord()
199
+        self.all_raids.append(self.current_raid)
200
+        self.__cull_old_raids(now)
201
+
202
+    def find_raid_for_message_id(self, message_id: int) -> JoinRaidRecord:
203
+        """
204
+        Retrieves a JoinRaidRecord instance for the given raid warning message.
205
+        Returns None if not found.
206
+        """
207
+        for raid in self.all_raids:
208
+            if raid.warning_message.id == message_id:
209
+                return raid
210
+        return None
211
+
212
+    def __cull_old_raids(self, now: datetime):
213
+        """
214
+        Gets rid of old JoinRaidRecord records from self.all_raids that are too
215
+        old to still be useful.
216
+        """
217
+        i: int = 0
218
+        while i < len(self.all_raids):
219
+            raid = self.all_raids[i]
220
+            if raid == self.current_raid:
221
+                i += 1
222
+                continue
223
+            age_seconds = float((raid.raid_start_time - now).total_seconds())
224
+            if age_seconds > 86400.0:
225
+                self.__trace('Culling old raid')
226
+                self.all_raids.pop(i)
227
+            else:
228
+                i += 1
229
+
230
+    def __trace(self, message):
231
+        """
232
+        Debugging trace.
233
+        """
234
+        print(f'{self.guild_id}: {message}')
235
+
236
+class JoinRaidCog(BaseCog):
237
+    """
238
+    Cog for monitoring member joins and detecting potential bot raids.
239
+    """
240
+    MIN_JOIN_COUNT = 2
241
+
242
+    STATE_KEY_RAID_COUNT = 'joinraid_count'
243
+    STATE_KEY_RAID_SECONDS = 'joinraid_seconds'
244
+    STATE_KEY_ENABLED = 'joinraid_enabled'
245
+
246
+    def __init__(self, bot):
247
+        self.bot = bot
248
+        self.guild_id_to_context = {}  # Guild.id -> GuildContext
249
+
250
+    # -- Config -------------------------------------------------------------
251
+
252
+    def __get_raid_rate(self, guild: Guild) -> tuple:
253
+        """
254
+        Returns the join rate configured for this guild.
255
+        """
256
+        count: int = Storage.get_state_value(guild, self.STATE_KEY_RAID_COUNT) \
257
+            or CONFIG['joinWarningCount']
258
+        seconds: float = Storage.get_state_value(guild, self.STATE_KEY_RAID_SECONDS) \
259
+            or CONFIG['joinWarningSeconds']
260
+        return (count, seconds)
261
+
262
+    def __is_enabled(self, guild: Guild) -> bool:
263
+        """
264
+        Returns whether join raid detection is enabled in this guild.
265
+        """
266
+        return Storage.get_state_value(guild, self.STATE_KEY_ENABLED) or False
267
+
268
+    # -- Commands -----------------------------------------------------------
269
+
270
+    @commands.group(
271
+        brief='Manages join raid detection and handling',
272
+    )
273
+    @commands.has_permissions(ban_members=True)
274
+    @commands.guild_only()
275
+    async def joinraid(self, context: commands.Context):
276
+        'Command group'
277
+        if context.invoked_subcommand is None:
278
+            await context.send_help()
279
+
280
+    @joinraid.command(
281
+        name='enable',
282
+        brief='Enables join raid detection',
283
+        description='Join raid detection is off by default.',
284
+    )
285
+    async def joinraid_enable(self, context: commands.Context):
286
+        'Command handler'
287
+        guild = context.guild
288
+        Storage.set_state_value(guild, self.STATE_KEY_ENABLED, True)
289
+        # TODO: Startup tracking if necessary
290
+        await context.message.reply(
291
+            '✅ ' + self.__describe_raid_settings(guild, force_enabled_status=True),
292
+            mention_author=False)
293
+
294
+    @joinraid.command(
295
+        name='disable',
296
+        brief='Disables join raid detection',
297
+        description='Join raid detection is off by default.',
298
+    )
299
+    async def joinraid_disable(self, context: commands.Context):
300
+        'Command handler'
301
+        guild = context.guild
302
+        Storage.set_state_value(guild, self.STATE_KEY_ENABLED, False)
303
+        # TODO: Tear down tracking if necessary
304
+        await context.message.reply(
305
+            '✅ ' + self.__describe_raid_settings(guild, force_enabled_status=True),
306
+            mention_author=False)
307
+
308
+    @joinraid.command(
309
+        name='setrate',
310
+        brief='Sets the rate of joins which triggers a warning to mods',
311
+        description='Each time a member joins, the join records from the ' +
312
+            'previous _x_ seconds are counted up, where _x_ is the number of ' +
313
+            'seconds configured by this command. If that count meets or ' +
314
+            'exceeds the maximum join count configured by this command then ' +
315
+            'a raid is detected and a warning is issued to the mods.',
316
+        usage='<join_count:int> <seconds:float>',
317
+    )
318
+    async def joinraid_setrate(self, context: commands.Context,
319
+            join_count: int,
320
+            seconds: float):
321
+        'Command handler'
322
+        guild = context.guild
323
+        if join_count < self.MIN_JOIN_COUNT:
324
+            await context.message.reply(
325
+                f'⚠️ `join_count` must be >= {self.MIN_JOIN_COUNT}',
326
+                mention_author=False)
327
+            return
328
+        if seconds <= 0:
329
+            await context.message.reply(
330
+                f'⚠️ `seconds` must be > 0',
331
+                mention_author=False)
332
+            return
333
+        Storage.set_state_values(guild, {
334
+            self.STATE_KEY_RAID_COUNT: join_count,
335
+            self.STATE_KEY_RAID_SECONDS: seconds,
336
+        })
337
+
338
+        await context.message.reply(
339
+            '✅ ' + self.__describe_raid_settings(guild, force_rate_status=True),
340
+            mention_author=False)
341
+
342
+    @joinraid.command(
343
+        name='getrate',
344
+        brief='Shows the rate of joins which triggers a warning to mods',
345
+    )
346
+    async def joinraid_getrate(self, context: commands.Context):
347
+        'Command handler'
348
+        await context.message.reply(
349
+            'ℹ️ ' + self.__describe_raid_settings(context.guild, force_rate_status=True),
350
+            mention_author=False)
351
+
352
+    # -- Listeners ----------------------------------------------------------
353
+
354
+    @commands.Cog.listener()
355
+    async def on_raw_reaction_add(self, payload: RawReactionActionEvent):
356
+        'Event handler'
357
+        if payload.user_id == self.bot.user.id:
358
+            # Ignore bot's own reactions
359
+            return
360
+        member: Member = payload.member
361
+        if member is None:
362
+            return
363
+        guild: Guild = self.bot.get_guild(payload.guild_id)
364
+        if guild is None:
365
+            # Possibly a DM
366
+            return
367
+        channel: GuildChannel = guild.get_channel(payload.channel_id)
368
+        if channel is None:
369
+            # Possibly a DM
370
+            return
371
+        message: Message = await channel.fetch_message(payload.message_id)
372
+        if message is None:
373
+            # Message deleted?
374
+            return
375
+        if message.author.id != self.bot.user.id:
376
+            # Bot didn't author this
377
+            return
378
+        if not member.permissions_in(channel).ban_members:
379
+            # Not a mod
380
+            # TODO: Remove reaction?
381
+            return
382
+        gc: GuildContext = self.__get_guild_context(guild)
383
+        raid: JoinRaidRecord = gc.find_raid_for_message_id(payload.message_id)
384
+        if raid is None:
385
+            # Either not a warning message or one we stopped tracking
386
+            return
387
+        emoji: PartialEmoji = payload.emoji
388
+        if emoji.name == CONFIG['kickEmoji']:
389
+            await raid.kick_all()
390
+            gc.reset_raid(message.created_at)
391
+            await self.__update_raid_warning(guild, raid)
392
+        elif emoji.name == CONFIG['banEmoji']:
393
+            await raid.ban_all()
394
+            gc.reset_raid(message.created_at)
395
+            await self.__update_raid_warning(guild, raid)
396
+
397
+    @commands.Cog.listener()
398
+    async def on_member_join(self, member: Member) -> None:
399
+        'Event handler'
400
+        guild: Guild = member.guild
401
+        if not self.__is_enabled(guild):
402
+            return
403
+        (count, seconds) = self.__get_raid_rate(guild)
404
+        now = member.joined_at
405
+        gc: GuildContext = self.__get_guild_context(guild)
406
+        raid: JoinRaidRecord = gc.current_raid
407
+        raid.handle_join(member, now, count, seconds)
408
+        if raid.phase == RaidPhase.JUST_STARTED:
409
+            await self.__post_raid_warning(guild, raid)
410
+        elif raid.phase == RaidPhase.CONTINUING:
411
+            await self.__update_raid_warning(guild, raid)
412
+        elif raid.phase == RaidPhase.ENDED:
413
+            # First join that occurred too late to be part of last raid. Join
414
+            # not added. Start a new raid record and add it there.
415
+            gc.reset_raid(now)
416
+            gc.current_raid.handle_join(member, now, count, seconds)
417
+
418
+    # -- Misc ---------------------------------------------------------------
419
+
420
+    def __describe_raid_settings(self,
421
+            guild: Guild,
422
+            force_enabled_status=False,
423
+            force_rate_status=False) -> str:
424
+        """
425
+        Creates a Discord message describing the current join raid settings.
426
+        """
427
+        enabled = self.__is_enabled(guild)
428
+        (count, seconds) = self.__get_raid_rate(guild)
429
+
430
+        sentences = []
431
+
432
+        if enabled or force_rate_status:
433
+            sentences.append(f'Join raids will be detected at {count} or more joins per {seconds} seconds.')
434
+
435
+        if enabled and force_enabled_status:
436
+                sentences.append('Raid detection enabled.')
437
+        elif not enabled:
438
+            sentences.append('Raid detection disabled.')
439
+
440
+        tips = []
441
+        if enabled or force_rate_status:
442
+            tips.append('• Use `setrate` subcommand to change detection threshold')
443
+        if enabled:
444
+            tips.append('• Use `disable` subcommand to disable detection.')
445
+        else:
446
+            tips.append('• Use `enable` subcommand to enable detection.')
447
+
448
+        message = ''
449
+        message += ' '.join(sentences)
450
+        if len(tips) > 0:
451
+            message += '\n\n' + ('\n'.join(tips))
452
+        return message
453
+
454
+    def __get_guild_context(self, guild: Guild) -> GuildContext:
455
+        """
456
+        Looks up the GuildContext for the given Guild or creates a new one if
457
+        one does not yet exist.
458
+        """
459
+        gc: GuildContext = self.guild_id_to_context.get(guild.id)
460
+        if gc is not None:
461
+            return gc
462
+        gc = GuildContext(guild.id)
463
+        self.guild_id_to_context[guild.id] = gc
464
+        return gc
465
+
466
+    async def __post_raid_warning(self, guild: Guild, raid: JoinRaidRecord) -> None:
467
+        """
468
+        Posts a warning message about the given raid.
469
+        """
470
+        (message, can_kick, can_ban) = self.__describe_raid(raid)
471
+        raid.warning_message = await self.warn(guild, message)
472
+        if can_kick:
473
+            await raid.warning_message.add_reaction(CONFIG['kickEmoji'])
474
+        if can_ban:
475
+            await raid.warning_message.add_reaction(CONFIG['banEmoji'])
476
+
477
+    async def __update_raid_warning(self, guild: Guild, raid: JoinRaidRecord) -> None:
478
+        """
479
+        Updates the existing warning message for a raid.
480
+        """
481
+        if raid.warning_message is None:
482
+            return
483
+        (message, can_kick, can_ban) = self.__describe_raid(raid)
484
+        await self.update_warn(raid.warning_message, message)
485
+        if not can_kick:
486
+            await raid.warning_message.clear_reaction(CONFIG['kickEmoji'])
487
+        if not can_ban:
488
+            await raid.warning_message.clear_reaction(CONFIG['banEmoji'])
489
+
490
+    def __describe_raid(self, raid: JoinRaidRecord) -> tuple:
491
+        """
492
+        Creates a Discord warning message with details about the given raid.
493
+        Returns a tuple containing the message text, a flag if any users can
494
+        still be kicked, and a flag if anyone can still be banned.
495
+        """
496
+        message = '🚨 **JOIN RAID DETECTED** 🚨'
497
+        message += '\nThe following members joined in close succession:\n'
498
+        any_kickable = False
499
+        any_bannable = False
500
+        for join in raid.joins:
501
+            message += '\n• '
502
+            if join.is_banned:
503
+                message += '~~' + join.member.mention + '~~ - banned'
504
+            elif join.is_kicked:
505
+                message += '~~' + join.member.mention + '~~ - kicked'
506
+                any_bannable = True
507
+            else:
508
+                message += join.member.mention
509
+                any_bannable = True
510
+                any_kickable = True
511
+        message += '\n_(list updates automatically)_'
512
+
513
+        message += '\n'
514
+        if any_kickable:
515
+            message += f'\nReact to this message with {CONFIG["kickEmoji"]} to kick all these users.'
516
+        else:
517
+            message += '\nNo users left to kick.'
518
+        if any_bannable:
519
+            message += f'\nReact to this message with {CONFIG["banEmoji"]} to ban all these users.'
520
+        else:
521
+            message += '\nNo users left to ban.'
522
+
523
+        return (message, any_kickable, any_bannable)

+ 1
- 0
config.py.sample 查看文件

@@ -12,4 +12,5 @@ CONFIG = {
12 12
 	'kickEmoji': '👢',
13 13
 	'banEmojiName': 'no_entry_sign',
14 14
 	'banEmoji': '🚫',
15
+	'statePath': 'state/',
15 16
 }

二進制
rocketbot.db.sample 查看文件


+ 13
- 751
rocketbot.py 查看文件

@@ -5,761 +5,23 @@ the sqlite database rocketbot.db (copy rocketbot.db.sample for a blank database)
5 5
 Author: Ian Albert (@rocketsoup)
6 6
 Date: 2021-11-11
7 7
 """
8
-from datetime import datetime
9
-import sqlite3
10
-import sys
11
-
12
-from discord import Guild, Intents, Member, Message, PartialEmoji, RawReactionActionEvent
13
-from discord.abc import GuildChannel
8
+from discord import Intents
14 9
 from discord.ext import commands
15
-from discord.ext.commands.context import Context
16 10
 
17 11
 from config import CONFIG
12
+from cogs.configcog import ConfigCog
13
+from cogs.generalcog import GeneralCog
14
+from cogs.joinraidcog import JoinRaidCog
18 15
 
19
-if sys.version_info.major < 3:
20
-    raise Exception('Requires Python 3+')
21
-
22
-# -- Classes ----------------------------------------------------------------
23
-
24
-class RaidPhase:
25
-    """
26
-    Enum of phases in a JoinRaid. Phases progress monotonically.
27
-    """
28
-    NONE = 0
29
-    JUST_STARTED = 1
30
-    CONTINUING = 2
31
-    ENDED = 3
32
-
33
-class JoinRaid:
34
-    """
35
-    Tracks recent joins to a guild to detect join raids, where a large number of automated users
36
-    all join at the same time.
37
-    """
38
-    def __init__(self):
39
-        self.joins = []
40
-        self.phase = RaidPhase.NONE
41
-        # datetime when the raid started, or None.
42
-        self.raid_start_time = None
43
-        # Message posted to Discord to warn of the raid. Convenience property managed
44
-        # by caller. Ignored by this class.
45
-        self.warning_message = None
46
-
47
-    def handle_join(self,
48
-            member: Member,
49
-            now: datetime,
50
-            max_age_seconds: float,
51
-            max_join_count: int) -> None:
52
-        """
53
-        Processes a new member join to a guild and detects join raids. Updates
54
-        self.phase and self.raid_start_time properties.
55
-        """
56
-        # Check for existing record for this user
57
-        print(f'handle_join({member.name}) start')
58
-        join: JoinRecord = None
59
-        i: int = 0
60
-        while i < len(self.joins):
61
-            elem = self.joins[i]
62
-            if elem.member.id == member.id:
63
-                print(f'Member {member.name} already in join list at index {i}. Removing.')
64
-                join = self.joins.pop(i)
65
-                join.join_time = now
66
-                break
67
-            i += 1
68
-        # Add new record to end
69
-        self.joins.append(join or JoinRecord(member))
70
-        # Check raid status and do upkeep
71
-        self.__process_joins(now, max_age_seconds, max_join_count)
72
-        print(f'handle_join({member.name}) end')
73
-
74
-    def __process_joins(self,
75
-            now: datetime,
76
-            max_age_seconds: float,
77
-            max_join_count: int) -> None:
78
-        """
79
-        Processes self.joins after each addition, detects raids, updates self.phase,
80
-        and throws out unneeded records.
81
-        """
82
-        print('__process_joins {')
83
-        i: int = 0
84
-        recent_count: int = 0
85
-        should_cull: bool = self.phase == RaidPhase.NONE
86
-        while i < len(self.joins):
87
-            join: JoinRecord = self.joins[i]
88
-            age: float = join.age_seconds(now)
89
-            is_old: bool = age > max_age_seconds
90
-            if not is_old:
91
-                recent_count += 1
92
-                print(f'- {i}. {join.member.name} is {age}s old - recent_count={recent_count}')
93
-            if is_old and should_cull:
94
-                self.joins.pop(i)
95
-                print(f'- {i}. {join.member.name} is {age}s old - too old, removing')
96
-            else:
97
-                print(f'- {i}. {join.member.name} is {age}s old - moving on to next')
98
-                i += 1
99
-        is_raid = recent_count > max_join_count
100
-        print(f'- is_raid {is_raid}')
101
-        if is_raid:
102
-            if self.phase == RaidPhase.NONE:
103
-                self.phase = RaidPhase.JUST_STARTED
104
-                self.raid_start_time = now
105
-                print('- Phase moved to JUST_STARTED. Recording raid start time.')
106
-            elif self.phase == RaidPhase.JUST_STARTED:
107
-                self.phase = RaidPhase.CONTINUING
108
-                print('- Phase moved to CONTINUING.')
109
-        elif self.phase == self.phase in (RaidPhase.JUST_STARTED, RaidPhase.CONTINUING):
110
-            self.phase = RaidPhase.ENDED
111
-            print('- Phase moved to ENDED.')
112
-
113
-        # Undo join add if the raid is over
114
-        if self.phase == RaidPhase.ENDED and len(self.joins) > 0:
115
-            last = self.joins.pop(-1)
116
-            print(f'- Popping last join for {last.member.name}')
117
-        print('} __process_joins')
118
-
119
-    async def kick_all(self,
120
-            reason: str = "Part of join raid") -> list[Member]:
121
-        """
122
-        Kicks all users in this join raid. Skips users who have already been
123
-        flagged as having been kicked or banned. Returns a List of Members
124
-        who were newly kicked.
125
-        """
126
-        kicks = []
127
-        for join in self.joins:
128
-            if join.is_kicked or join.is_banned:
129
-                continue
130
-            await join.member.kick(reason=reason)
131
-            join.is_kicked = True
132
-            kicks.append(join.member)
133
-        self.phase = RaidPhase.ENDED
134
-        return kicks
135
-
136
-    async def ban_all(self,
137
-            reason: str = "Part of join raid",
138
-            delete_message_days: int = 0) -> list[Member]:
139
-        """
140
-        Bans all users in this join raid. Skips users who have already been
141
-        flagged as having been banned. Users who were previously kicked can
142
-        still be banned. Returns a List of Members who were newly banned.
143
-        """
144
-        bans = []
145
-        for join in self.joins:
146
-            if join.is_banned:
147
-                continue
148
-            await join.member.ban(reason=reason, delete_message_days=delete_message_days)
149
-            join.is_banned = True
150
-            bans.append(join.member)
151
-        self.phase = RaidPhase.ENDED
152
-        return bans
153
-
154
-class JoinRecord:
155
-    """
156
-    Data object containing details about a guild join event.
157
-    """
158
-    def __init__(self, member: Member):
159
-        self.member = member
160
-        self.join_time = member.joined_at or datetime.now()
161
-        self.is_kicked = False
162
-        self.is_banned = False
163
-
164
-    def age_seconds(self, now: datetime) -> float:
165
-        """
166
-        Returns the age of this join in seconds from the given "now" time.
167
-        """
168
-        a = now - self.join_time
169
-        return float(a.total_seconds())
170
-
171
-class GuildContext:
172
-    """
173
-    Logic and state for a single guild serviced by the bot.
174
-    """
175
-    def __init__(self, guild_id: int):
176
-        self.guild_id = guild_id
177
-        self.guild = None # Resolved later
178
-        # Config populated during load
179
-        self.warning_channel_id = None
180
-        self.warning_channel = None
181
-        self.warning_mention = None
182
-        self.join_warning_count = CONFIG['joinWarningCount']
183
-        self.join_warning_seconds = CONFIG['joinWarningSeconds']
184
-        # Non-persisted runtime state
185
-        self.current_raid = JoinRaid()
186
-        self.all_raids = [ self.current_raid ] # periodically culled of old ones
187
-
188
-    # Commands
189
-
190
-    async def command_hello(self, message: Message) -> None:
191
-        """
192
-        Command handler
193
-        """
194
-        await message.channel.send(f'Hey there, {message.author.mention}!')
195
-
196
-    async def command_testwarn(self, context: Context) -> None:
197
-        """
198
-        Command handler
199
-        """
200
-        if self.warning_channel is None:
201
-            self.__trace('No warning channel set!')
202
-            await context.message.channel.send('No warning channel set on this guild! Type ' +
203
-                f'`{bot.command_prefix}{setwarningchannel.__name__}` in the channel you ' +
204
-                'want warnings to be posted.')
205
-            return
206
-        await self.__warn('Test warning. This is only a test.')
207
-
208
-    async def command_setwarningchannel(self, context: Context):
209
-        """
210
-        Command handler
211
-        """
212
-        self.__trace(f'Warning channel set to {context.channel.name}')
213
-        self.warning_channel = context.channel
214
-        self.warning_channel_id = context.channel.id
215
-        save_guild_context(self)
216
-        await self.__warn('Warning messages will now be sent to ' + self.warning_channel.mention)
217
-
218
-    async def command_setwarningmention(self, _context: Context, mention: str):
219
-        """
220
-        Command handler
221
-        """
222
-        self.__trace('set warning mention')
223
-        m = mention if mention is not None and len(mention) > 0 else None
224
-        self.warning_mention = m
225
-        save_guild_context(self)
226
-        if m is None:
227
-            await self.__warn('Warning messages will not mention anyone')
228
-        else:
229
-            await self.__warn('Warning messages will now mention ' + m)
230
-
231
-    async def command_setraidwarningrate(self, _context: Context, count: int, seconds: int):
232
-        """
233
-        Command handler
234
-        """
235
-        self.join_warning_count = count
236
-        self.join_warning_seconds = seconds
237
-        save_guild_context(self)
238
-        await self.__warn(f'Maximum join rate set to {count} joins per {seconds} seconds')
239
-
240
-    # Events
241
-
242
-    async def handle_join(self, member: Member) -> None:
243
-        """
244
-        Event handler for all joins to this guild.
245
-        """
246
-        print(f'{member.guild.name}: {member.name} joined')
247
-        now = member.joined_at
248
-        raid = self.current_raid
249
-        raid.handle_join(
250
-            member,
251
-            now=now,
252
-            max_age_seconds = self.join_warning_seconds,
253
-            max_join_count = self.join_warning_count)
254
-        self.__trace(f'raid phase: {raid.phase}')
255
-        if raid.phase == RaidPhase.JUST_STARTED:
256
-            await self.__on_join_raid_begin(raid)
257
-        elif raid.phase == RaidPhase.CONTINUING:
258
-            await self.__on_join_raid_updated(raid)
259
-        elif raid.phase == RaidPhase.ENDED:
260
-            self.__start_new_raid(member)
261
-            await self.__on_join_raid_end(raid)
262
-        self.__cull_old_raids(now)
263
-
264
-    def __start_new_raid(self, member: Member = None):
265
-        """
266
-        Retires self.current_raid and creates a new empty one. If `member` is passed, it will be
267
-        added to the new self.current_raid after it is created.
268
-        """
269
-        self.current_raid = JoinRaid()
270
-        self.all_raids.append(self.current_raid)
271
-        if member is not None:
272
-            self.current_raid.handle_join(
273
-                member,
274
-                member.joined_at,
275
-                max_age_seconds = self.join_warning_seconds,
276
-                max_join_count = self.join_warning_count)
277
-
278
-    async def handle_reaction_add(self, message, member, emoji):
279
-        """
280
-        Handles all message reaction events to see if they need to be acted on.
281
-        """
282
-        if member.id == bot.user.id:
283
-            # It's-a me, Rocketbot!
284
-            return
285
-        if message.author.id != bot.user.id:
286
-            # The message the user is reacting to wasn't authored by me. Ignore.
287
-            return
288
-        self.__trace(f'User {member} added emoji {emoji}')
289
-        if not member.permissions_in(message.channel).ban_members:
290
-            self.__trace('Reactor does not have ban permissions. Ignoring.')
291
-            return
292
-        if emoji.name == CONFIG['kickEmoji']:
293
-            await self.__kick_all_in_raid_message(message)
294
-        elif emoji.name == CONFIG['banEmoji']:
295
-            await self.__ban_all_in_raid_message(message)
296
-        else:
297
-            print('Unhandled emoji. Ignoring.')
298
-            return
299
-
300
-    async def __kick_all_in_raid_message(self, message: Message):
301
-        """
302
-        Kicks all the users mentioned in the given raid warning message. Users who were already
303
-        kicked or banned will be skipped.
304
-        """
305
-        raid = self.__find_raid_for_message(message)
306
-        if raid is None:
307
-            await message.reply("This is either not a raid warning or it's too old and I don't " +
308
-                "have a record for it anymore. Sorry!")
309
-            return
310
-        self.__trace('Kicking...')
311
-        members = await raid.kick_all()
312
-        msg = 'Kicked these members:'
313
-        for member in members:
314
-            msg += f'\n\t{member.name}'
315
-        if len(members) == 0:
316
-            msg += '\n\t-none-'
317
-        self.__trace(msg)
318
-        self.__start_new_raid()
319
-        await self.__update_join_raid_message(raid)
320
-
321
-    async def __ban_all_in_raid_message(self, message: Message):
322
-        """
323
-        Bans all the users mentioned in the given raid warning message. Users who were already
324
-        banned will be skipped.
325
-        """
326
-        raid = self.__find_raid_for_message(message)
327
-        if raid is None:
328
-            await message.reply("This is either not a raid warning or it's too old and I don't " +
329
-                "have a record for it anymore. Sorry!")
330
-            return
331
-        self.__trace('Banning...')
332
-        members = await raid.ban_all()
333
-        msg = 'Banned these members:'
334
-        for member in members:
335
-            msg += f'\n\t{member.name}'
336
-        if len(members) == 0:
337
-            msg += '\n\t-none-'
338
-        self.__trace(msg)
339
-        self.__start_new_raid()
340
-        await self.__update_join_raid_message(raid)
341
-
342
-    def __find_raid_for_message(self, message: Message) -> JoinRaid:
343
-        """
344
-        Retrieves a JoinRaid instance for the given raid warning message. Returns None if not found.
345
-        """
346
-        for raid in self.all_raids:
347
-            if raid.warning_message.id == message.id:
348
-                return raid
349
-        return None
350
-
351
-    def __cull_old_raids(self, now: datetime):
352
-        """
353
-        Gets rid of old JoinRaid records from self.all_raids that are too old to still be useful.
354
-        """
355
-        i: int = 0
356
-        while i < len(self.all_raids):
357
-            raid = self.all_raids[i]
358
-            if raid == self.current_raid:
359
-                i += 1
360
-                continue
361
-            age_seconds = float((raid.raid_start_time - now).total_seconds())
362
-            if age_seconds > 86400.0:
363
-                self.__trace('Culling old raid')
364
-                self.all_raids.pop(i)
365
-            else:
366
-                i += 1
367
-
368
-    def __join_raid_message(self, raid: JoinRaid):
369
-        """
370
-        Returns a 3-element tuple containing a text message appropriate for posting in
371
-        Discord, a flag of whether any of the mentioned users can be kicked, and a flag
372
-        of whether any of the mentioned users can be banned.
373
-        """
374
-        message = ''
375
-        if self.warning_mention is not None:
376
-            message = self.warning_mention + ' '
377
-        message += '**RAID JOIN DETECTED!** It includes these users:\n'
378
-        can_kick = False
379
-        can_ban = False
380
-        for join in raid.joins:
381
-            message += '\n• '
382
-            if join.is_banned:
383
-                message += '~~' + join.member.mention + '~~ - banned'
384
-            elif join.is_kicked:
385
-                message += '~~' + join.member.mention + '~~ - kicked'
386
-                can_ban = True
387
-            else:
388
-                message += join.member.mention
389
-                can_kick = True
390
-                can_ban = True
391
-        message += '\n'
392
-        if can_kick:
393
-            message += '\nTo kick all these users, react with :' + CONFIG['kickEmojiName'] + ':'
394
-        else:
395
-            message += '\nNo kickable users remain'
396
-        if can_ban:
397
-            message += '\nTo ban all these users, react with :' + CONFIG['banEmojiName'] + ':'
398
-        else:
399
-            message += '\nNo bannable users remain'
400
-        return (message, can_kick, can_ban)
401
-
402
-    async def __update_join_raid_message(self, raid: JoinRaid):
403
-        """
404
-        Updates an existing join raid warning message with updated data.
405
-        """
406
-        if raid.warning_message is None:
407
-            self.__trace('No raid warning message to update')
408
-            return
409
-        (message, can_kick, can_ban) = self.__join_raid_message(raid)
410
-        await raid.warning_message.edit(content=message)
411
-        if not can_kick:
412
-            await raid.warning_message.clear_reaction(CONFIG['kickEmoji'])
413
-        if not can_ban:
414
-            await raid.warning_message.clear_reaction(CONFIG['banEmoji'])
415
-
416
-    async def __on_join_raid_begin(self, raid):
417
-        """
418
-        Event triggered when the first member joins that triggers the raid detection.
419
-        """
420
-        self.__trace('A join raid has begun!')
421
-        if self.warning_channel is None:
422
-            self.__trace('NO WARNING CHANNEL SET')
423
-            return
424
-        (message, can_kick, can_ban) = self.__join_raid_message(raid)
425
-        raid.warning_message = await self.warning_channel.send(message)
426
-        if can_kick:
427
-            await raid.warning_message.add_reaction(CONFIG['kickEmoji'])
428
-        if can_ban:
429
-            await raid.warning_message.add_reaction(CONFIG['banEmoji'])
430
-
431
-    async def __on_join_raid_updated(self, raid):
432
-        """
433
-        Event triggered for each subsequent member join after the first one that triggered the
434
-        raid detection.
435
-        """
436
-        self.__trace('Join raid still occurring')
437
-        await self.__update_join_raid_message(raid)
438
-
439
-    async def __on_join_raid_end(self, _raid):
440
-        """
441
-        Event triggered when the first member joins who is not part of the most recent raid.
442
-        """
443
-        self.__trace('Join raid has ended')
444
-
445
-    async def __warn(self, message):
446
-        """
447
-        Posts a warning message in the configured warning channel.
448
-        """
449
-        if self.warning_channel is None:
450
-            self.__trace('NO WARNING CHANNEL SET. Warning message not posted.\n' + message)
451
-            return None
452
-        m = message
453
-        if self.warning_mention is not None:
454
-            m = self.warning_mention + ' ' + m
455
-        return await self.warning_channel.send(m)
456
-
457
-    def __trace(self, message):
458
-        """
459
-        Debugging trace.
460
-        """
461
-        print(f'{self.guild.name}: {message}')
462
-
463
-# lookup for int(Guild.guild_id) --> GuildContext
464
-guild_id_to_guild_context = {}
465
-
466
-def get_or_create_guild_context(val, save=True):
467
-    """
468
-    Retrieves a cached GuildContext instance by its Guild id or Guild object
469
-    itself. If no GuildContext record exists for the Guild, one is created
470
-    and cached (and saved to the database unless `save=False`).
471
-    """
472
-    gid = None
473
-    guild = None
474
-    if val is None:
475
-        return None
476
-    if isinstance(val, int):
477
-        gid = val
478
-    elif isinstance(val, Guild):
479
-        gid = val.id
480
-        guild = val
481
-    if gid is None:
482
-        print('Unhandled datatype', type(val))
483
-        return None
484
-    looked_up = guild_id_to_guild_context.get(gid)
485
-    if looked_up is not None:
486
-        return looked_up
487
-    gc = GuildContext(gid)
488
-    gc.guild = guild or gc.guild
489
-    guild_id_to_guild_context[gid] = gc
490
-    if save:
491
-        save_guild_context(gc)
492
-    return gc
493
-
494
-# -- Database ---------------------------------------------------------------
495
-
496
-def run_sql_batch(batch_function):
497
-    """
498
-    Performs an SQL transaction. After a connection is opened, the passed
499
-    function is invoked with the sqlite3.Connection and sqlite3.Cursor
500
-    passed as arguments. Once the passed function finishes, the connection
501
-    is closed.
502
-    """
503
-    db_connection: sqlite3.Connection = sqlite3.connect('rocketbot.db')
504
-    db_cursor: sqlite3.Cursor = db_connection.cursor()
505
-    batch_function(db_connection, db_cursor)
506
-    db_connection.commit()
507
-    db_connection.close()
508
-
509
-def load_guild_settings():
510
-    """
511
-    Populates the GuildContext cache with records from the database.
512
-    """
513
-    def load(_con, cur):
514
-        """
515
-        SQL
516
-        """
517
-        for row in cur.execute("""SELECT * FROM guilds"""):
518
-            guild_id = row[0]
519
-            gc = get_or_create_guild_context(guild_id, save=False)
520
-            gc.warning_channel_id = row[1]
521
-            gc.warning_mention = row[2]
522
-            gc.join_warning_count = row[3] or CONFIG['joinWarningCount']
523
-            gc.join_warning_seconds = row[4] or CONFIG['joinWarningSeconds']
524
-            print(f'Guild {guild_id} channel id is {gc.warning_channel_id}')
525
-    run_sql_batch(load)
526
-
527
-def create_tables():
528
-    """
529
-    Creates all database tables.
530
-    """
531
-    def make_tables(_con, cur):
532
-        """
533
-        SQL
534
-        """
535
-        cur.execute("""CREATE TABLE guilds (
536
-            guildId INTEGER,
537
-            warningChannelId INTEGER,
538
-            warningMention TEXT,
539
-            joinWarningCount INTEGER,
540
-            joinWarningSeconds INTEGER,
541
-            PRIMARY KEY(guildId ASC))""")
542
-    run_sql_batch(make_tables)
543
-
544
-def save_guild_context(gc: GuildContext):
545
-    """
546
-    Saves the state of a GuildContext record to the database.
547
-    """
548
-    def save(_con, cur):
549
-        """
550
-        SQL
551
-        """
552
-        print(f'Saving guild context with id {gc.guild_id}')
553
-        cur.execute("""
554
-            SELECT guildId
555
-            FROM guilds
556
-            WHERE guildId=?
557
-            """, (
558
-                gc.guild_id,
559
-            ))
560
-        channel_id = gc.warning_channel.id if gc.warning_channel is not None \
561
-                else gc.warning_channel_id
562
-        exists = cur.fetchone() is not None
563
-        if exists:
564
-            print('Updating existing guild record in db')
565
-            cur.execute("""
566
-                UPDATE guilds
567
-                SET warningChannelId=?,
568
-                    warningMention=?,
569
-                    joinWarningCount=?,
570
-                    joinWarningSeconds=?
571
-                WHERE guildId=?
572
-                """, (
573
-                    channel_id,
574
-                    gc.warning_mention,
575
-                    gc.join_warning_count,
576
-                    gc.join_warning_seconds,
577
-                    gc.guild_id,
578
-                ))
579
-        else:
580
-            print('Creating new guild record in db')
581
-            cur.execute("""
582
-                INSERT INTO guilds (
583
-                    guildId,
584
-                    warningChannelId,
585
-                    warningMention,
586
-                    joinWarningCount,
587
-                    joinWarningSeconds)
588
-                VALUES (?, ?, ?, ?, ?)
589
-                """, (
590
-                    gc.guild_id,
591
-                    channel_id,
592
-                    gc.warning_mention,
593
-                    gc.join_warning_count,
594
-                    gc.join_warning_seconds,
595
-                ))
596
-    run_sql_batch(save)
597
-
598
-# -- Main (1) ---------------------------------------------------------------
599
-
600
-load_guild_settings()
16
+class Rocketbot(commands.Bot):
17
+    def __init__(self, command_prefix, **kwargs):
18
+        super().__init__(command_prefix, **kwargs)
601 19
 
602 20
 intents = Intents.default()
603 21
 intents.members = True # To get join/leave events
604
-bot = commands.Bot(command_prefix=CONFIG['commandPrefix'], intents=intents)
605
-
606
-# -- Bot commands -----------------------------------------------------------
607
-
608
-@bot.command(
609
-    brief='Simply replies to the invoker with a hello message in the same channel.'
610
-)
611
-async def hello(ctx: Context):
612
-    """
613
-    Command handler
614
-    """
615
-    gc: GuildContext = get_or_create_guild_context(ctx.guild)
616
-    if gc is None:
617
-        return
618
-    message = ctx.message
619
-    await gc.command_hello(message)
620
-
621
-@bot.command(
622
-    brief='Posts a test warning message in the configured warning channel.',
623
-    help="""If no warning channel is configured, the bot will reply in the channel the command was
624
-    issued to notify no warning channel is set. If a warning mention is configured, the test 
625
-    warning will tag the configured person/role."""
626
-)
627
-@commands.has_permissions(manage_messages=True)
628
-async def testwarn(ctx: Context):
629
-    """
630
-    Command handler
631
-    """
632
-    gc: GuildContext = get_or_create_guild_context(ctx.guild)
633
-    if gc is None:
634
-        return
635
-    await gc.command_testwarn(ctx)
636
-
637
-@bot.command(
638
-    brief='Sets the threshold for detecting a join raid.',
639
-    usage='<count> <seconds>',
640
-    help="""The raid threshold is expressed as number of joins within a given number of seconds.
641
-    Each time a member joins, the number of joins in the previous _x_ seconds is counted, and if
642
-    that count, _y_, equals or exceeds the count configured by this command, a raid is detected."""
643
-)
644
-@commands.has_permissions(manage_messages=True)
645
-async def setraidwarningrate(ctx: Context, count: int, seconds: int):
646
-    """
647
-    Command handler
648
-    """
649
-    gc: GuildContext = get_or_create_guild_context(ctx.guild)
650
-    if gc is None:
651
-        return
652
-    await gc.command_setraidwarningrate(ctx, count, seconds)
653
-
654
-@bot.command(
655
-    brief='Sets the current channel as the destination for bot warning messages.'
656
-)
657
-@commands.has_permissions(manage_messages=True)
658
-async def setwarningchannel(ctx: Context):
659
-    """
660
-    Command handler
661
-    """
662
-    gc: GuildContext = get_or_create_guild_context(ctx.guild)
663
-    if gc is None:
664
-        return
665
-    await gc.command_setwarningchannel(ctx)
666
-
667
-@bot.command(
668
-    brief='Sets an optional mention to include in every warning message.',
669
-    usage='<mention>',
670
-    help="""The argument provided to this command will be included verbatim, so if the intent is
671
-    to tag a user or role, the argument must be a tag, not merely the name of the user/role."""
672
-)
673
-@commands.has_permissions(manage_messages=True)
674
-async def setwarningmention(ctx: Context, mention: str):
675
-    """
676
-    Command handler
677
-    """
678
-    gc: GuildContext = get_or_create_guild_context(ctx.guild)
679
-    if gc is None:
680
-        return
681
-    await gc.command_setwarningmention(ctx, mention)
682
-
683
-# -- Bot events -------------------------------------------------------------
684
-
685
-is_connected = False
686
-@bot.listen()
687
-async def on_connect():
688
-    """
689
-    Discord event handler
690
-    """
691
-    global is_connected
692
-    print('Connected')
693
-    is_connected = True
694
-    if is_connected and is_ready:
695
-        await populate_guilds()
696
-
697
-is_ready = False
698
-@bot.listen()
699
-async def on_ready():
700
-    """
701
-    Discord event handler
702
-    """
703
-    global is_ready
704
-    print('Ready')
705
-    is_ready = True
706
-    if is_connected and is_ready:
707
-        await populate_guilds()
708
-
709
-async def populate_guilds():
710
-    """
711
-    Called after both on_ready and on_connect are done. May be called more than once!
712
-    """
713
-    for guild in bot.guilds:
714
-        gc = guild_id_to_guild_context.get(guild.id)
715
-        if gc is None:
716
-            print(f'No GuildContext for {guild.id}')
717
-            continue
718
-        gc.guild = guild
719
-        if gc.warning_channel_id is not None:
720
-            gc.warning_channel = guild.get_channel(gc.warning_channel_id)
721
-            if gc.warning_channel is not None:
722
-                print(f'Recovered warning channel {gc.warning_channel}')
723
-            else:
724
-                print(f'Could not find channel with id {gc.warning_channel_id} in ' +
725
-                    f'guild {guild.name}')
726
-                for channel in await guild.fetch_channels():
727
-                    print(f'\t{channel.name} ({channel.id})')
728
-
729
-@bot.listen()
730
-async def on_member_join(member: Member) -> None:
731
-    """
732
-    Discord event handler
733
-    """
734
-    print(f'User {member.name} joined {member.guild.name}')
735
-    gc: GuildContext = get_or_create_guild_context(member.guild)
736
-    if gc is None:
737
-        print(f'No GuildContext for guild {member.guild.name}')
738
-        return
739
-    await gc.handle_join(member)
740
-
741
-@bot.listen()
742
-async def on_member_remove(member: Member) -> None:
743
-    """
744
-    Discord event handler
745
-    """
746
-    print(f'User {member.name} left {member.guild.name}')
747
-
748
-@bot.listen()
749
-async def on_raw_reaction_add(payload: RawReactionActionEvent) -> None:
750
-    """
751
-    Discord event handler
752
-    """
753
-    guild: Guild = bot.get_guild(payload.guild_id)
754
-    channel: GuildChannel = guild.get_channel(payload.channel_id)
755
-    message: Message = await channel.fetch_message(payload.message_id)
756
-    member: Member = payload.member
757
-    emoji: PartialEmoji = payload.emoji
758
-    gc: GuildContext = get_or_create_guild_context(guild)
759
-    await gc.handle_reaction_add(message, member, emoji)
760
-
761
-# -- Main -------------------------------------------------------------------
762
-
763
-print('Starting bot')
764
-bot.run(CONFIG['clientToken'])
765
-print('Bot done')
22
+bot = Rocketbot(command_prefix=CONFIG['commandPrefix'], intents=intents)
23
+bot.add_cog(GeneralCog(bot))
24
+bot.add_cog(ConfigCog(bot))
25
+bot.add_cog(JoinRaidCog(bot))
26
+bot.run(CONFIG['clientToken'], bot=True, reconnect=True)
27
+print('\nBot aborted')

+ 117
- 0
storage.py 查看文件

@@ -0,0 +1,117 @@
1
+import json
2
+from os.path import exists
3
+
4
+from discord import Guild
5
+
6
+from config import CONFIG
7
+
8
+class StateKey:
9
+	WARNING_CHANNEL_ID = 'warning_channel_id'
10
+	WARNING_MENTION = 'warning_mention'
11
+
12
+class Storage:
13
+	"""
14
+	Static class for managing persisted bot state.
15
+	"""
16
+
17
+	# discord.Guild.id -> dict
18
+	__guild_id_to_state = {}
19
+
20
+	@classmethod
21
+	def get_state(cls, guild: Guild) -> dict:
22
+		"""
23
+		Returns all persisted state for the given guild.
24
+		"""
25
+		state: dict = cls.__guild_id_to_state.get(guild.id)
26
+		if state is not None:
27
+			# Already in memory
28
+			return state
29
+		# Load from disk if possible
30
+		cls.__trace(f'No loaded state for guild {guild.id}. Attempting to ' +
31
+			'load from disk.')
32
+		state = cls.__load_guild_state(guild)
33
+		if state is None:
34
+			return {}
35
+		cls.__guild_id_to_state[guild.id] = state
36
+		return state
37
+
38
+	@classmethod
39
+	def get_state_value(cls, guild: Guild, key: str):
40
+		"""
41
+		Returns a persisted state value stored under the given key. Returns
42
+		None if not present.
43
+		"""
44
+		return cls.get_state(guild).get(key)
45
+
46
+	@classmethod
47
+	def set_state_value(cls, guild: Guild, key: str, value) -> None:
48
+		"""
49
+		Adds the given key-value pair to the persisted state for the given
50
+		Guild. If `value` is `None` the key will be removed from persisted
51
+		state.
52
+		"""
53
+		cls.set_state_values(guild, { key: value })
54
+
55
+	@classmethod
56
+	def set_state_values(cls, guild: Guild, vars: dict) -> None:
57
+		"""
58
+		Merges the given `vars` dict with the saved state for the given guild
59
+		and saves it to disk. `vars` must be JSON-encodable or a ValueError will
60
+		be raised. Keys with associated values of `None` will be removed from the
61
+		state.
62
+		"""
63
+		if vars is None or len(vars) == 0:
64
+			return
65
+		state: dict = cls.get_state(guild)
66
+		try:
67
+			json.dumps(vars)
68
+		except:
69
+			raise ValueError(f'vars not JSON encodable - {vars}')
70
+		for key, value in vars.items():
71
+			if value is None:
72
+				del state[key]
73
+			else:
74
+				state[key] = value
75
+		cls.__save_guild_state(guild, state)
76
+
77
+	@classmethod
78
+	def __save_guild_state(cls, guild: Guild, state: dict) -> None:
79
+		"""
80
+		Saves state for a guild to a JSON file on disk.
81
+		"""
82
+		path: str = cls.__guild_path(guild)
83
+		cls.__trace(f'Saving state for guild {guild.id} to {path}')
84
+		cls.__trace(f'state = {state}')
85
+		with open(path, 'w') as file:
86
+			# Pretty printing to make more legible for debugging
87
+			# Sorting keys to help with diffs
88
+			json.dump(state, file, indent='\t', sort_keys=True)
89
+		cls.__trace('State saved')
90
+
91
+	@classmethod
92
+	def __load_guild_state(cls, guild: Guild) -> dict:
93
+		"""
94
+		Loads state for a guild from a JSON file on disk, or None if not found.
95
+		"""
96
+		path: str = cls.__guild_path(guild)
97
+		if not exists(path):
98
+			cls.__trace(f'No state on disk for guild {guild.id}. Returning None.')
99
+			return None
100
+		cls.__trace(f'Loading state from disk for guild {guild.id}')
101
+		with open(path, 'r') as file:
102
+			state = json.load(file)
103
+		cls.__trace('State loaded')
104
+		return state
105
+
106
+	@classmethod
107
+	def __guild_path(cls, guild: Guild) -> str:
108
+		"""
109
+		Returns the JSON file path where guild state should be written.
110
+		"""
111
+		config_value: str = CONFIG['statePath']
112
+		path: str = config_value if config_value.endswith('/') else f'{config_value}/'
113
+		return f'{path}guild_{guild.id}.json'
114
+
115
+	@classmethod
116
+	def __trace(cls, message: str) -> None:
117
+		print(f'{Storage.__name__}: {message}')

Loading…
取消
儲存