ソースを参照

Refactored to use cogs. Joinraid seems to be working.

pull/1/head
Rocketsoup 4年前
コミット
5c12900ea5
11個のファイルの変更760行の追加960行の削除
  1. 0
    47
      cogs/base.py
  2. 49
    0
      cogs/basecog.py
  3. 0
    36
      cogs/config.py
  4. 96
    0
      cogs/configcog.py
  5. 0
    24
      cogs/general.py
  6. 39
    0
      cogs/generalcog.py
  7. 0
    232
      cogs/joinraid.py
  8. 523
    0
      cogs/joinraidcog.py
  9. バイナリ
      rocketbot.db.sample
  10. 8
    608
      rocketbot.py
  11. 45
    13
      storage.py

+ 0
- 47
cogs/base.py ファイルの表示

@@ -1,47 +0,0 @@
1
-from discord import Guild, Message, TextChannel
2
-from discord.ext import commands
3
-
4
-from storage import StateKey, Storage
5
-
6
-class BaseCog(commands.Cog):
7
-	def __init__(self, bot):
8
-		self.bot = bot
9
-
10
-	@classmethod
11
-	def save_setting(cls, guild: Guild, name: str, value):
12
-		"""
13
-		Saves one value to a guild's persisted state. The given value must
14
-		be a JSON-encodable type.
15
-		"""
16
-		if value is not None and not isinstance(value, (bool, int, float, str, list, dict)):
17
-			raise Exception(f'value for key {name} is not supported JSON type! {type(value)}')
18
-		state: dict = Storage.state_for_guild(guild)
19
-		if value is None:
20
-			del state[name]
21
-		else:
22
-			state[name] = value
23
-		Storage.save_guild_state(guild)
24
-
25
-	@classmethod
26
-	async def warn(cls, guild: Guild, message: str) -> bool:
27
-		"""
28
-		Sends a warning message to the configured warning channel for the
29
-		given guild. If no warning channel is configured no action is taken.
30
-		Returns True if the message was sent successfully or False if the
31
-		warning channel could not be found.
32
-		"""
33
-		state: dict = Storage.state_for_guild(guild)
34
-		channel_id = state.get(StateKey.WARNING_CHANNEL_ID)
35
-		if channel_id is None:
36
-			cls.guild_trace(guild, 'No warning channel set! No warning issued.')
37
-			return False
38
-		channel: TextChannel = guild.get_channel(channel_id)
39
-		if channel is None:
40
-			cls.guild_trace(guild, 'Configured warning channel no longer exists!')
41
-			return False
42
-		message: Message = await channel.send(message)
43
-		return message is not None
44
-
45
-	@classmethod
46
-	def guild_trace(cls, guild: Guild, message: str) -> None:
47
-		print(f'[guild {guild.id}|{guild.name}] {message}')

+ 49
- 0
cogs/basecog.py ファイルの表示

@@ -0,0 +1,49 @@
1
+from discord import Guild, Message, TextChannel
2
+from discord.ext import commands
3
+
4
+from storage import StateKey, Storage
5
+import json
6
+
7
+class BaseCog(commands.Cog):
8
+	def __init__(self, bot):
9
+		self.bot = bot
10
+
11
+	@classmethod
12
+	async def warn(cls, guild: Guild, message: str) -> Message:
13
+		"""
14
+		Sends a warning message to the configured warning channel for the
15
+		given guild. If no warning channel is configured no action is taken.
16
+		Returns the Message if successful or None if not.
17
+		"""
18
+		channel_id = Storage.get_state_value(guild, StateKey.WARNING_CHANNEL_ID)
19
+		if channel_id is None:
20
+			cls.guild_trace(guild, 'No warning channel set! No warning issued.')
21
+			return None
22
+		channel: TextChannel = guild.get_channel(channel_id)
23
+		if channel is None:
24
+			cls.guild_trace(guild, 'Configured warning channel does not exist!')
25
+			return None
26
+		mention: str = Storage.get_state_value(guild, StateKey.WARNING_MENTION)
27
+		text: str = message
28
+		if mention is not None:
29
+			text = f'{mention} {text}'
30
+		msg: Message = await channel.send(text)
31
+		return msg
32
+
33
+	@classmethod
34
+	async def update_warn(cls, warn_message: Message, new_text: str) -> None:
35
+		"""
36
+		Updates the text of a previously posted `warn`. Includes configured
37
+		mentions if necessary.
38
+		"""
39
+		text: str = new_text
40
+		mention: str = Storage.get_state_value(
41
+			warn_message.guild,
42
+			StateKey.WARNING_MENTION)
43
+		if mention is not None:
44
+			text = f'{mention} {text}'
45
+		await warn_message.edit(content=text)
46
+
47
+	@classmethod
48
+	def guild_trace(cls, guild: Guild, message: str) -> None:
49
+		print(f'[guild {guild.id}|{guild.name}] {message}')

+ 0
- 36
cogs/config.py ファイルの表示

@@ -1,36 +0,0 @@
1
-from discord import Guild, TextChannel
2
-from discord.ext import commands
3
-from storage import StateKey, Storage
4
-from cogs.base import BaseCog
5
-
6
-class ConfigCog(BaseCog):
7
-	"""
8
-	Cog for handling general bot configuration.
9
-	"""
10
-	def __init__(self, bot):
11
-		self.bot = bot
12
-
13
-	@commands.group(
14
-		brief='Manages general bot configuration'
15
-	)
16
-	@commands.has_permissions(ban_members=True)
17
-	@commands.guild_only()
18
-	async def config(self, context: commands.Context):
19
-		'Command group'
20
-		if context.invoked_subcommand is None:
21
-			await context.send_help()
22
-
23
-	@config.command(
24
-		name='setwarningchannel',
25
-		brief='Sets the channel where mod warnings are posted',
26
-		description='Run this command in the channel where bot messages ' +
27
-			'intended for server moderators should be sent. Other bot messages ' +
28
-			'may still be posted in the channel a command was invoked in. If ' +
29
-			'no output channel is set, mod-related messages will not be posted!',
30
-	)
31
-	async def config_setwarningchannel(self, context: commands.Context):
32
-		'Command handler'
33
-		guild: Guild = context.guild
34
-		channel: TextChannel = context.channel
35
-		self.save_setting(guild, StateKey.WARNING_CHANNEL_ID, context.channel.id)
36
-		await context.message.reply(f'Warning channel set to {channel.name}', mention_author=False)

+ 96
- 0
cogs/configcog.py ファイルの表示

@@ -0,0 +1,96 @@
1
+from discord import Guild, TextChannel
2
+from discord.ext import commands
3
+from storage import StateKey, Storage
4
+from cogs.basecog import BaseCog
5
+
6
+class ConfigCog(BaseCog):
7
+	"""
8
+	Cog for handling general bot configuration.
9
+	"""
10
+	def __init__(self, bot):
11
+		self.bot = bot
12
+
13
+	@commands.group(
14
+		brief='Manages general bot configuration'
15
+	)
16
+	@commands.has_permissions(ban_members=True)
17
+	@commands.guild_only()
18
+	async def config(self, context: commands.Context):
19
+		'Command group'
20
+		if context.invoked_subcommand is None:
21
+			await context.send_help()
22
+
23
+	@config.command(
24
+		name='setwarningchannel',
25
+		brief='Sets the channel where mod warnings are posted',
26
+		description='Run this command in the channel where bot messages ' +
27
+			'intended for server moderators should be sent. Other bot messages ' +
28
+			'may still be posted in the channel a command was invoked in. If ' +
29
+			'no output channel is set, mod-related messages will not be posted!',
30
+	)
31
+	async def config_setwarningchannel(self, context: commands.Context) -> None:
32
+		'Command handler'
33
+		guild: Guild = context.guild
34
+		channel: TextChannel = context.channel
35
+		Storage.set_state_value(guild, StateKey.WARNING_CHANNEL_ID,
36
+			context.channel.id)
37
+		await context.message.reply(
38
+			f'Warning channel updated to {channel.mention}.',
39
+			mention_author=False)
40
+
41
+	@config.command(
42
+		name='getwarningchannel',
43
+		brief='Shows the channel where mod warnings are posted',
44
+	)
45
+	async def config_getwarningchannel(self, context: commands.Context) -> None:
46
+		'Command handler'
47
+		guild: Guild = context.guild
48
+		channel_id = Storage.get_state_value(guild, StateKey.WARNING_CHANNEL_ID)
49
+		if channel_id is None:
50
+			await context.message.reply(
51
+				'No warning channel is configured.',
52
+				mention_author=False)
53
+		else:
54
+			channel = guild.get_channel(channel_id)
55
+			await context.message.reply(
56
+				f'Warning channel is configured as {channel.mention}.',
57
+				mention_author=False)
58
+
59
+	@config.command(
60
+		name='setwarningmention',
61
+		brief='Sets a user/role to mention in warning messages',
62
+		usage='<@user|@role>',
63
+		description='Configures an role or other prefix to include at the ' +
64
+			'beginning of warning messages. If the intent is to get the ' +
65
+			'attention of certain users, be sure to specify a properly ' +
66
+			'formed @ tag, not just the name of the user/role.'
67
+	)
68
+	async def config_setwarningmention(self, context: commands.Context, mention: str = None) -> None:
69
+		guild: Guild = context.guild
70
+		Storage.set_state_value(guild, StateKey.WARNING_MENTION, mention)
71
+		if mention is None:
72
+			await context.message.reply(
73
+				'Warning messages will not tag anyone.',
74
+				mention_author=False)
75
+		else:
76
+			await context.message.reply(
77
+				f'Warning messages will now tag {mention}.',
78
+				mention_author=False)
79
+
80
+	@config.command(
81
+		name='getwarningmention',
82
+		brief='Shows the user/role to mention in warning messages',
83
+		description='Shows the text, if any, that will be prefixed on any ' +
84
+			'warning messages.'
85
+	)
86
+	async def config_getwarningmention(self, context: commands.Context) -> None:
87
+		guild: Guild = context.guild
88
+		mention: str = Storage.get_state_value(guild, StateKey.WARNING_MENTION)
89
+		if mention is None:
90
+			await context.message.reply(
91
+				'No warning mention configured.',
92
+				mention_author=False)
93
+		else:
94
+			await context.message.reply(
95
+				f'Warning messages will tag {mention}',
96
+				mention_author=False)

+ 0
- 24
cogs/general.py ファイルの表示

@@ -1,24 +0,0 @@
1
-from discord.ext import commands
2
-from cogs.base import BaseCog
3
-
4
-class GeneralCog(BaseCog):
5
-	def __init__(self, bot: commands.Bot):
6
-		self.bot = bot
7
-		self.is_connected = False
8
-		self.is_ready = False
9
-
10
-	@commands.Cog.listener()
11
-	async def on_connect(self):
12
-		print('on_connect')
13
-		self.is_connected = True
14
-
15
-	@commands.Cog.listener()
16
-	async def on_ready(self):
17
-		print('on_ready')
18
-		self.is_ready = True
19
-
20
-	@commands.command()
21
-	@commands.has_permissions(ban_members=True)
22
-	@commands.guild_only()
23
-	async def testwarn(self, context):
24
-		await self.warn(context.guild, 'Test warning')

+ 39
- 0
cogs/generalcog.py ファイルの表示

@@ -0,0 +1,39 @@
1
+from discord.ext import commands
2
+from cogs.basecog import BaseCog
3
+from storage import StateKey, Storage
4
+
5
+class GeneralCog(BaseCog):
6
+	def __init__(self, bot: commands.Bot):
7
+		self.bot = bot
8
+		self.is_connected = False
9
+		self.is_ready = False
10
+
11
+	@commands.Cog.listener()
12
+	async def on_connect(self):
13
+		print('on_connect')
14
+		self.is_connected = True
15
+
16
+	@commands.Cog.listener()
17
+	async def on_ready(self):
18
+		print('on_ready')
19
+		self.is_ready = True
20
+
21
+	@commands.command(
22
+		brief='Posts a test warning in the configured warning channel.'
23
+	)
24
+	@commands.has_permissions(ban_members=True)
25
+	@commands.guild_only()
26
+	async def testwarn(self, context):
27
+		if Storage.get_state_value(context.guild, StateKey.WARNING_CHANNEL_ID) is None:
28
+			await context.message.reply(
29
+				'No warning channel set!',
30
+				mention_author=False)
31
+		else:
32
+			await self.warn(context.guild,
33
+				f'Test warning message (requested by {context.author.name})')
34
+
35
+	@commands.command()
36
+	async def hello(self, context):
37
+		await context.message.reply(
38
+			f'Hey, {context.author.name}!',
39
+		 	mention_author=False)

+ 0
- 232
cogs/joinraid.py ファイルの表示

@@ -1,232 +0,0 @@
1
-from discord import Guild, Intents, Member, Message, PartialEmoji, RawReactionActionEvent
2
-from discord.ext import commands
3
-from storage import Storage
4
-from cogs.base import BaseCog
5
-
6
-class JoinRecord:
7
-    """
8
-    Data object containing details about a guild join event.
9
-    """
10
-    def __init__(self, member: Member):
11
-        self.member = member
12
-        self.join_time = member.joined_at or datetime.now()
13
-        self.is_kicked = False
14
-        self.is_banned = False
15
-
16
-    def age_seconds(self, now: datetime) -> float:
17
-        """
18
-        Returns the age of this join in seconds from the given "now" time.
19
-        """
20
-        a = now - self.join_time
21
-        return float(a.total_seconds())
22
-
23
-class RaidPhase:
24
-    """
25
-    Enum of phases in a JoinRaid. Phases progress monotonically.
26
-    """
27
-    NONE = 0
28
-    JUST_STARTED = 1
29
-    CONTINUING = 2
30
-    ENDED = 3
31
-
32
-class JoinRaid:
33
-    """
34
-    Tracks recent joins to a guild to detect join raids, where a large number of automated users
35
-    all join at the same time.
36
-    """
37
-    def __init__(self):
38
-        self.joins = []
39
-        self.phase = RaidPhase.NONE
40
-        # datetime when the raid started, or None.
41
-        self.raid_start_time = None
42
-        # Message posted to Discord to warn of the raid. Convenience property managed
43
-        # by caller. Ignored by this class.
44
-        self.warning_message = None
45
-
46
-    def handle_join(self,
47
-            member: Member,
48
-            now: datetime,
49
-            max_age_seconds: float,
50
-            max_join_count: int) -> None:
51
-        """
52
-        Processes a new member join to a guild and detects join raids. Updates
53
-        self.phase and self.raid_start_time properties.
54
-        """
55
-        # Check for existing record for this user
56
-        print(f'handle_join({member.name}) start')
57
-        join: JoinRecord = None
58
-        i: int = 0
59
-        while i < len(self.joins):
60
-            elem = self.joins[i]
61
-            if elem.member.id == member.id:
62
-                print(f'Member {member.name} already in join list at index {i}. Removing.')
63
-                join = self.joins.pop(i)
64
-                join.join_time = now
65
-                break
66
-            i += 1
67
-        # Add new record to end
68
-        self.joins.append(join or JoinRecord(member))
69
-        # Check raid status and do upkeep
70
-        self.__process_joins(now, max_age_seconds, max_join_count)
71
-        print(f'handle_join({member.name}) end')
72
-
73
-    def __process_joins(self,
74
-            now: datetime,
75
-            max_age_seconds: float,
76
-            max_join_count: int) -> None:
77
-        """
78
-        Processes self.joins after each addition, detects raids, updates self.phase,
79
-        and throws out unneeded records.
80
-        """
81
-        print('__process_joins {')
82
-        i: int = 0
83
-        recent_count: int = 0
84
-        should_cull: bool = self.phase == RaidPhase.NONE
85
-        while i < len(self.joins):
86
-            join: JoinRecord = self.joins[i]
87
-            age: float = join.age_seconds(now)
88
-            is_old: bool = age > max_age_seconds
89
-            if not is_old:
90
-                recent_count += 1
91
-                print(f'- {i}. {join.member.name} is {age}s old - recent_count={recent_count}')
92
-            if is_old and should_cull:
93
-                self.joins.pop(i)
94
-                print(f'- {i}. {join.member.name} is {age}s old - too old, removing')
95
-            else:
96
-                print(f'- {i}. {join.member.name} is {age}s old - moving on to next')
97
-                i += 1
98
-        is_raid = recent_count > max_join_count
99
-        print(f'- is_raid {is_raid}')
100
-        if is_raid:
101
-            if self.phase == RaidPhase.NONE:
102
-                self.phase = RaidPhase.JUST_STARTED
103
-                self.raid_start_time = now
104
-                print('- Phase moved to JUST_STARTED. Recording raid start time.')
105
-            elif self.phase == RaidPhase.JUST_STARTED:
106
-                self.phase = RaidPhase.CONTINUING
107
-                print('- Phase moved to CONTINUING.')
108
-        elif self.phase == self.phase in (RaidPhase.JUST_STARTED, RaidPhase.CONTINUING):
109
-            self.phase = RaidPhase.ENDED
110
-            print('- Phase moved to ENDED.')
111
-
112
-        # Undo join add if the raid is over
113
-        if self.phase == RaidPhase.ENDED and len(self.joins) > 0:
114
-            last = self.joins.pop(-1)
115
-            print(f'- Popping last join for {last.member.name}')
116
-        print('} __process_joins')
117
-
118
-    async def kick_all(self,
119
-            reason: str = "Part of join raid") -> list[Member]:
120
-        """
121
-        Kicks all users in this join raid. Skips users who have already been
122
-        flagged as having been kicked or banned. Returns a List of Members
123
-        who were newly kicked.
124
-        """
125
-        kicks = []
126
-        for join in self.joins:
127
-            if join.is_kicked or join.is_banned:
128
-                continue
129
-            await join.member.kick(reason=reason)
130
-            join.is_kicked = True
131
-            kicks.append(join.member)
132
-        self.phase = RaidPhase.ENDED
133
-        return kicks
134
-
135
-    async def ban_all(self,
136
-            reason: str = "Part of join raid",
137
-            delete_message_days: int = 0) -> list[Member]:
138
-        """
139
-        Bans all users in this join raid. Skips users who have already been
140
-        flagged as having been banned. Users who were previously kicked can
141
-        still be banned. Returns a List of Members who were newly banned.
142
-        """
143
-        bans = []
144
-        for join in self.joins:
145
-            if join.is_banned:
146
-                continue
147
-            await join.member.ban(reason=reason, delete_message_days=delete_message_days)
148
-            join.is_banned = True
149
-            bans.append(join.member)
150
-        self.phase = RaidPhase.ENDED
151
-        return bans
152
-
153
-class JoinRaidCog(BaseCog):
154
-	"""
155
-	Cog for monitoring member joins and detecting potential bot raids.
156
-	"""
157
-	def __init__(self, bot):
158
-		self.bot = bot
159
-
160
-	@commands.group(
161
-		brief='Manages join raid detection and handling',
162
-	)
163
-	@commands.has_permissions(ban_members=True)
164
-	@commands.guild_only()
165
-	async def joinraid(self, context: commands.Context):
166
-		'Command group'
167
-		if context.invoked_subcommand is None:
168
-			await context.send_help()
169
-
170
-	@joinraid.command(
171
-		name='enable',
172
-		brief='Enables join raid detection',
173
-		description='Join raid detection is off by default.',
174
-	)
175
-	async def joinraid_enable(self, context: commands.Context):
176
-		'Command handler'
177
-		# TODO
178
-		pass
179
-
180
-	@joinraid.command(
181
-		name='disable',
182
-		brief='Disables join raid detection',
183
-		description='Join raid detection is off by default.',
184
-	)
185
-	async def joinraid_disable(self, context: commands.Context):
186
-		'Command handler'
187
-		# TODO
188
-		pass
189
-
190
-	@joinraid.command(
191
-		name='setrate',
192
-		brief='Sets the rate of joins which triggers a warning to mods',
193
-		description='Each time a member joins, the join records from the ' +
194
-			'previous _x_ seconds are counted up, where _x_ is the number of ' +
195
-			'seconds configured by this command. If that count meets or ' +
196
-			'exceeds the maximum join count configured by this command then ' +
197
-			'a raid is detected and a warning is issued to the mods.',
198
-		usage='<join_count> <seconds>',
199
-	)
200
-	async def joinraid_setrate(self, context: commands.Context,
201
-			joinCount: int,
202
-			seconds: int):
203
-		'Command handler'
204
-		# TODO
205
-		pass
206
-
207
-	@joinraid.command(
208
-		name='getrate',
209
-		brief='Shows the rate of joins which triggers a warning to mods',
210
-	)
211
-	async def joinraid_getrate(self, context: commands.Context):
212
-		'Command handler'
213
-		# TODO
214
-		pass
215
-
216
-	@commands.Cog.listener()
217
-	async def on_raw_reaction_add(self, payload: RawReactionActionEvent):
218
-		'Event handler'
219
-		# TODO
220
-		pass
221
-
222
-	@commands.Cog.listener()
223
-	async def on_member_join(self, member: Member) -> None:
224
-		'Event handler'
225
-		# TODO
226
-		pass
227
-
228
-	async def __send_warning(self) -> None:
229
-		config = self.bot.get_cog('ConfigCog')
230
-		if config is None:
231
-			return
232
-		config.

+ 523
- 0
cogs/joinraidcog.py ファイルの表示

@@ -0,0 +1,523 @@
1
+from discord import Guild, Intents, Member, Message, PartialEmoji, RawReactionActionEvent
2
+from discord.ext import commands
3
+from storage import Storage
4
+from cogs.basecog import BaseCog
5
+from config import CONFIG
6
+from datetime import datetime
7
+
8
+class JoinRecord:
9
+    """
10
+    Data object containing details about a single guild join event.
11
+    """
12
+    def __init__(self, member: Member):
13
+        self.member = member
14
+        self.join_time = member.joined_at or datetime.now()
15
+        # These flags only track whether this bot has kicked/banned
16
+        self.is_kicked = False
17
+        self.is_banned = False
18
+
19
+    def age_seconds(self, now: datetime) -> float:
20
+        """
21
+        Returns the age of this join in seconds from the given "now" time.
22
+        """
23
+        a = now - self.join_time
24
+        return float(a.total_seconds())
25
+
26
+class RaidPhase:
27
+    """
28
+    Enum of phases in a JoinRaidRecord. Phases progress monotonically.
29
+    """
30
+    NONE = 0
31
+    JUST_STARTED = 1
32
+    CONTINUING = 2
33
+    ENDED = 3
34
+
35
+class JoinRaidRecord:
36
+    """
37
+    Tracks recent joins to a guild to detect join raids, where a large number
38
+    of automated users all join at the same time. Manages list of joins to not
39
+    grow unbounded.
40
+    """
41
+    def __init__(self):
42
+        self.joins = []
43
+        self.phase = RaidPhase.NONE
44
+        # datetime when the raid started, or None.
45
+        self.raid_start_time = None
46
+        # Message posted to Discord to warn of the raid. Convenience property
47
+        # managed by caller. Ignored by this class.
48
+        self.warning_message = None
49
+
50
+    def handle_join(self,
51
+            member: Member,
52
+            now: datetime,
53
+            max_join_count: int,
54
+            max_age_seconds: float) -> None:
55
+        """
56
+        Processes a new member join to a guild and detects join raids. Updates
57
+        self.phase and self.raid_start_time properties.
58
+        """
59
+        # Check for existing record for this user
60
+        print(f'handle_join({member.name}) start')
61
+        join: JoinRecord = None
62
+        i: int = 0
63
+        while i < len(self.joins):
64
+            elem = self.joins[i]
65
+            if elem.member.id == member.id:
66
+                print(f'Member {member.name} already in join list at index {i}. Removing.')
67
+                join = self.joins.pop(i)
68
+                join.join_time = now
69
+                break
70
+            i += 1
71
+        # Add new record to end
72
+        self.joins.append(join or JoinRecord(member))
73
+        # Check raid status and do upkeep
74
+        self.__process_joins(now, max_age_seconds, max_join_count)
75
+        print(f'handle_join({member.name}) end')
76
+
77
+    def __process_joins(self,
78
+            now: datetime,
79
+            max_age_seconds: float,
80
+            max_join_count: int) -> None:
81
+        """
82
+        Processes self.joins after each addition, detects raids, updates
83
+        self.phase, and throws out unneeded records.
84
+        """
85
+        print('__process_joins {')
86
+        i: int = 0
87
+        recent_count: int = 0
88
+        should_cull: bool = self.phase == RaidPhase.NONE
89
+        while i < len(self.joins):
90
+            join: JoinRecord = self.joins[i]
91
+            age: float = join.age_seconds(now)
92
+            is_old: bool = age > max_age_seconds
93
+            if not is_old:
94
+                recent_count += 1
95
+                print(f'- {i}. {join.member.name} is {age}s old - recent_count={recent_count}')
96
+            if is_old and should_cull:
97
+                self.joins.pop(i)
98
+                print(f'- {i}. {join.member.name} is {age}s old - too old, removing')
99
+            else:
100
+                print(f'- {i}. {join.member.name} is {age}s old - moving on to next')
101
+                i += 1
102
+        is_raid = recent_count > max_join_count
103
+        print(f'- is_raid {is_raid}')
104
+        if is_raid:
105
+            if self.phase == RaidPhase.NONE:
106
+                self.phase = RaidPhase.JUST_STARTED
107
+                self.raid_start_time = now
108
+                print('- Phase moved to JUST_STARTED. Recording raid start time.')
109
+            elif self.phase == RaidPhase.JUST_STARTED:
110
+                self.phase = RaidPhase.CONTINUING
111
+                print('- Phase moved to CONTINUING.')
112
+        elif self.phase == self.phase in (RaidPhase.JUST_STARTED, RaidPhase.CONTINUING):
113
+            self.phase = RaidPhase.ENDED
114
+            print('- Phase moved to ENDED.')
115
+
116
+        # Undo join add if the raid is over
117
+        if self.phase == RaidPhase.ENDED and len(self.joins) > 0:
118
+            last = self.joins.pop(-1)
119
+            print(f'- Popping last join for {last.member.name}')
120
+        print('} __process_joins')
121
+
122
+    async def kick_all(self,
123
+            reason: str = "Part of join raid") -> list[Member]:
124
+        """
125
+        Kicks all users in this join raid. Skips users who have already been
126
+        flagged as having been kicked or banned. Returns a List of Members
127
+        who were newly kicked.
128
+        """
129
+        kicks = []
130
+        for join in self.joins:
131
+            if join.is_kicked or join.is_banned:
132
+                continue
133
+            await join.member.kick(reason=reason)
134
+            join.is_kicked = True
135
+            kicks.append(join.member)
136
+        self.phase = RaidPhase.ENDED
137
+        return kicks
138
+
139
+    async def ban_all(self,
140
+            reason: str = "Part of join raid",
141
+            delete_message_days: int = 0) -> list[Member]:
142
+        """
143
+        Bans all users in this join raid. Skips users who have already been
144
+        flagged as having been banned. Users who were previously kicked can
145
+        still be banned. Returns a List of Members who were newly banned.
146
+        """
147
+        bans = []
148
+        for join in self.joins:
149
+            if join.is_banned:
150
+                continue
151
+            await join.member.ban(
152
+                reason=reason,
153
+                delete_message_days=delete_message_days)
154
+            join.is_banned = True
155
+            bans.append(join.member)
156
+        self.phase = RaidPhase.ENDED
157
+        return bans
158
+
159
+class GuildContext:
160
+    """
161
+    Logic and state for a single guild serviced by the bot.
162
+    """
163
+    def __init__(self, guild_id: int):
164
+        self.guild_id = guild_id
165
+        self.join_warning_count = CONFIG['joinWarningCount']
166
+        self.join_warning_seconds = CONFIG['joinWarningSeconds']
167
+        # Non-persisted runtime state
168
+        self.current_raid = JoinRaidRecord()
169
+        self.all_raids = [ self.current_raid ] # periodically culled of old ones
170
+
171
+    # Events
172
+
173
+    async def handle_join(self, member: Member) -> None:
174
+        """
175
+        Event handler for all joins to this guild.
176
+        """
177
+        now = member.joined_at
178
+        raid = self.current_raid
179
+        raid.handle_join(
180
+            member,
181
+            now=now,
182
+            max_age_seconds = self.join_warning_seconds,
183
+            max_join_count = self.join_warning_count)
184
+        self.__trace(f'raid phase: {raid.phase}')
185
+        if raid.phase == RaidPhase.JUST_STARTED:
186
+            await self.__on_join_raid_begin(raid)
187
+        elif raid.phase == RaidPhase.CONTINUING:
188
+            await self.__on_join_raid_updated(raid)
189
+        elif raid.phase == RaidPhase.ENDED:
190
+            self.__start_new_raid(member)
191
+            await self.__on_join_raid_end(raid)
192
+        self.__cull_old_raids(now)
193
+
194
+    def reset_raid(self, now: datetime):
195
+        """
196
+        Retires self.current_raid and creates a new empty one.
197
+        """
198
+        self.current_raid = JoinRaidRecord()
199
+        self.all_raids.append(self.current_raid)
200
+        self.__cull_old_raids(now)
201
+
202
+    def find_raid_for_message_id(self, message_id: int) -> JoinRaidRecord:
203
+        """
204
+        Retrieves a JoinRaidRecord instance for the given raid warning message.
205
+        Returns None if not found.
206
+        """
207
+        for raid in self.all_raids:
208
+            if raid.warning_message.id == message_id:
209
+                return raid
210
+        return None
211
+
212
+    def __cull_old_raids(self, now: datetime):
213
+        """
214
+        Gets rid of old JoinRaidRecord records from self.all_raids that are too
215
+        old to still be useful.
216
+        """
217
+        i: int = 0
218
+        while i < len(self.all_raids):
219
+            raid = self.all_raids[i]
220
+            if raid == self.current_raid:
221
+                i += 1
222
+                continue
223
+            age_seconds = float((raid.raid_start_time - now).total_seconds())
224
+            if age_seconds > 86400.0:
225
+                self.__trace('Culling old raid')
226
+                self.all_raids.pop(i)
227
+            else:
228
+                i += 1
229
+
230
+    def __trace(self, message):
231
+        """
232
+        Debugging trace.
233
+        """
234
+        print(f'{self.guild_id}: {message}')
235
+
236
+class JoinRaidCog(BaseCog):
237
+    """
238
+    Cog for monitoring member joins and detecting potential bot raids.
239
+    """
240
+    MIN_JOIN_COUNT = 2
241
+
242
+    STATE_KEY_RAID_COUNT = 'joinraid_count'
243
+    STATE_KEY_RAID_SECONDS = 'joinraid_seconds'
244
+    STATE_KEY_ENABLED = 'joinraid_enabled'
245
+
246
+    def __init__(self, bot):
247
+        self.bot = bot
248
+        self.guild_id_to_context = {}  # Guild.id -> GuildContext
249
+
250
+    # -- Config -------------------------------------------------------------
251
+
252
+    def __get_raid_rate(self, guild: Guild) -> tuple:
253
+        """
254
+        Returns the join rate configured for this guild.
255
+        """
256
+        count: int = Storage.get_state_value(guild, self.STATE_KEY_RAID_COUNT) \
257
+            or CONFIG['joinWarningCount']
258
+        seconds: float = Storage.get_state_value(guild, self.STATE_KEY_RAID_SECONDS) \
259
+            or CONFIG['joinWarningSeconds']
260
+        return (count, seconds)
261
+
262
+    def __is_enabled(self, guild: Guild) -> bool:
263
+        """
264
+        Returns whether join raid detection is enabled in this guild.
265
+        """
266
+        return Storage.get_state_value(guild, self.STATE_KEY_ENABLED) or False
267
+
268
+    # -- Commands -----------------------------------------------------------
269
+
270
+    @commands.group(
271
+        brief='Manages join raid detection and handling',
272
+    )
273
+    @commands.has_permissions(ban_members=True)
274
+    @commands.guild_only()
275
+    async def joinraid(self, context: commands.Context):
276
+        'Command group'
277
+        if context.invoked_subcommand is None:
278
+            await context.send_help()
279
+
280
+    @joinraid.command(
281
+        name='enable',
282
+        brief='Enables join raid detection',
283
+        description='Join raid detection is off by default.',
284
+    )
285
+    async def joinraid_enable(self, context: commands.Context):
286
+        'Command handler'
287
+        guild = context.guild
288
+        Storage.set_state_value(guild, self.STATE_KEY_ENABLED, True)
289
+        # TODO: Startup tracking if necessary
290
+        await context.message.reply(
291
+            '✅ ' + self.__describe_raid_settings(guild, force_enabled_status=True),
292
+            mention_author=False)
293
+
294
+    @joinraid.command(
295
+        name='disable',
296
+        brief='Disables join raid detection',
297
+        description='Join raid detection is off by default.',
298
+    )
299
+    async def joinraid_disable(self, context: commands.Context):
300
+        'Command handler'
301
+        guild = context.guild
302
+        Storage.set_state_value(guild, self.STATE_KEY_ENABLED, False)
303
+        # TODO: Tear down tracking if necessary
304
+        await context.message.reply(
305
+            '✅ ' + self.__describe_raid_settings(guild, force_enabled_status=True),
306
+            mention_author=False)
307
+
308
+    @joinraid.command(
309
+        name='setrate',
310
+        brief='Sets the rate of joins which triggers a warning to mods',
311
+        description='Each time a member joins, the join records from the ' +
312
+            'previous _x_ seconds are counted up, where _x_ is the number of ' +
313
+            'seconds configured by this command. If that count meets or ' +
314
+            'exceeds the maximum join count configured by this command then ' +
315
+            'a raid is detected and a warning is issued to the mods.',
316
+        usage='<join_count:int> <seconds:float>',
317
+    )
318
+    async def joinraid_setrate(self, context: commands.Context,
319
+            join_count: int,
320
+            seconds: float):
321
+        'Command handler'
322
+        guild = context.guild
323
+        if join_count < self.MIN_JOIN_COUNT:
324
+            await context.message.reply(
325
+                f'⚠️ `join_count` must be >= {self.MIN_JOIN_COUNT}',
326
+                mention_author=False)
327
+            return
328
+        if seconds <= 0:
329
+            await context.message.reply(
330
+                f'⚠️ `seconds` must be > 0',
331
+                mention_author=False)
332
+            return
333
+        Storage.set_state_values(guild, {
334
+            self.STATE_KEY_RAID_COUNT: join_count,
335
+            self.STATE_KEY_RAID_SECONDS: seconds,
336
+        })
337
+
338
+        await context.message.reply(
339
+            '✅ ' + self.__describe_raid_settings(guild, force_rate_status=True),
340
+            mention_author=False)
341
+
342
+    @joinraid.command(
343
+        name='getrate',
344
+        brief='Shows the rate of joins which triggers a warning to mods',
345
+    )
346
+    async def joinraid_getrate(self, context: commands.Context):
347
+        'Command handler'
348
+        await context.message.reply(
349
+            'ℹ️ ' + self.__describe_raid_settings(context.guild, force_rate_status=True),
350
+            mention_author=False)
351
+
352
+    # -- Listeners ----------------------------------------------------------
353
+
354
+    @commands.Cog.listener()
355
+    async def on_raw_reaction_add(self, payload: RawReactionActionEvent):
356
+        'Event handler'
357
+        if payload.user_id == self.bot.user.id:
358
+            # Ignore bot's own reactions
359
+            return
360
+        member: Member = payload.member
361
+        if member is None:
362
+            return
363
+        guild: Guild = self.bot.get_guild(payload.guild_id)
364
+        if guild is None:
365
+            # Possibly a DM
366
+            return
367
+        channel: GuildChannel = guild.get_channel(payload.channel_id)
368
+        if channel is None:
369
+            # Possibly a DM
370
+            return
371
+        message: Message = await channel.fetch_message(payload.message_id)
372
+        if message is None:
373
+            # Message deleted?
374
+            return
375
+        if message.author.id != self.bot.user.id:
376
+            # Bot didn't author this
377
+            return
378
+        if not member.permissions_in(channel).ban_members:
379
+            # Not a mod
380
+            # TODO: Remove reaction?
381
+            return
382
+        gc: GuildContext = self.__get_guild_context(guild)
383
+        raid: JoinRaidRecord = gc.find_raid_for_message_id(payload.message_id)
384
+        if raid is None:
385
+            # Either not a warning message or one we stopped tracking
386
+            return
387
+        emoji: PartialEmoji = payload.emoji
388
+        if emoji.name == CONFIG['kickEmoji']:
389
+            await raid.kick_all()
390
+            gc.reset_raid(message.created_at)
391
+            await self.__update_raid_warning(guild, raid)
392
+        elif emoji.name == CONFIG['banEmoji']:
393
+            await raid.ban_all()
394
+            gc.reset_raid(message.created_at)
395
+            await self.__update_raid_warning(guild, raid)
396
+
397
+    @commands.Cog.listener()
398
+    async def on_member_join(self, member: Member) -> None:
399
+        'Event handler'
400
+        guild: Guild = member.guild
401
+        if not self.__is_enabled(guild):
402
+            return
403
+        (count, seconds) = self.__get_raid_rate(guild)
404
+        now = member.joined_at
405
+        gc: GuildContext = self.__get_guild_context(guild)
406
+        raid: JoinRaidRecord = gc.current_raid
407
+        raid.handle_join(member, now, count, seconds)
408
+        if raid.phase == RaidPhase.JUST_STARTED:
409
+            await self.__post_raid_warning(guild, raid)
410
+        elif raid.phase == RaidPhase.CONTINUING:
411
+            await self.__update_raid_warning(guild, raid)
412
+        elif raid.phase == RaidPhase.ENDED:
413
+            # First join that occurred too late to be part of last raid. Join
414
+            # not added. Start a new raid record and add it there.
415
+            gc.reset_raid(now)
416
+            gc.current_raid.handle_join(member, now, count, seconds)
417
+
418
+    # -- Misc ---------------------------------------------------------------
419
+
420
+    def __describe_raid_settings(self,
421
+            guild: Guild,
422
+            force_enabled_status=False,
423
+            force_rate_status=False) -> str:
424
+        """
425
+        Creates a Discord message describing the current join raid settings.
426
+        """
427
+        enabled = self.__is_enabled(guild)
428
+        (count, seconds) = self.__get_raid_rate(guild)
429
+
430
+        sentences = []
431
+
432
+        if enabled or force_rate_status:
433
+            sentences.append(f'Join raids will be detected at {count} or more joins per {seconds} seconds.')
434
+
435
+        if enabled and force_enabled_status:
436
+                sentences.append('Raid detection enabled.')
437
+        elif not enabled:
438
+            sentences.append('Raid detection disabled.')
439
+
440
+        tips = []
441
+        if enabled or force_rate_status:
442
+            tips.append('• Use `setrate` subcommand to change detection threshold')
443
+        if enabled:
444
+            tips.append('• Use `disable` subcommand to disable detection.')
445
+        else:
446
+            tips.append('• Use `enable` subcommand to enable detection.')
447
+
448
+        message = ''
449
+        message += ' '.join(sentences)
450
+        if len(tips) > 0:
451
+            message += '\n\n' + ('\n'.join(tips))
452
+        return message
453
+
454
+    def __get_guild_context(self, guild: Guild) -> GuildContext:
455
+        """
456
+        Looks up the GuildContext for the given Guild or creates a new one if
457
+        one does not yet exist.
458
+        """
459
+        gc: GuildContext = self.guild_id_to_context.get(guild.id)
460
+        if gc is not None:
461
+            return gc
462
+        gc = GuildContext(guild.id)
463
+        self.guild_id_to_context[guild.id] = gc
464
+        return gc
465
+
466
+    async def __post_raid_warning(self, guild: Guild, raid: JoinRaidRecord) -> None:
467
+        """
468
+        Posts a warning message about the given raid.
469
+        """
470
+        (message, can_kick, can_ban) = self.__describe_raid(raid)
471
+        raid.warning_message = await self.warn(guild, message)
472
+        if can_kick:
473
+            await raid.warning_message.add_reaction(CONFIG['kickEmoji'])
474
+        if can_ban:
475
+            await raid.warning_message.add_reaction(CONFIG['banEmoji'])
476
+
477
+    async def __update_raid_warning(self, guild: Guild, raid: JoinRaidRecord) -> None:
478
+        """
479
+        Updates the existing warning message for a raid.
480
+        """
481
+        if raid.warning_message is None:
482
+            return
483
+        (message, can_kick, can_ban) = self.__describe_raid(raid)
484
+        await self.update_warn(raid.warning_message, message)
485
+        if not can_kick:
486
+            await raid.warning_message.clear_reaction(CONFIG['kickEmoji'])
487
+        if not can_ban:
488
+            await raid.warning_message.clear_reaction(CONFIG['banEmoji'])
489
+
490
+    def __describe_raid(self, raid: JoinRaidRecord) -> tuple:
491
+        """
492
+        Creates a Discord warning message with details about the given raid.
493
+        Returns a tuple containing the message text, a flag if any users can
494
+        still be kicked, and a flag if anyone can still be banned.
495
+        """
496
+        message = '🚨 **JOIN RAID DETECTED** 🚨'
497
+        message += '\nThe following members joined in close succession:\n'
498
+        any_kickable = False
499
+        any_bannable = False
500
+        for join in raid.joins:
501
+            message += '\n• '
502
+            if join.is_banned:
503
+                message += '~~' + join.member.mention + '~~ - banned'
504
+            elif join.is_kicked:
505
+                message += '~~' + join.member.mention + '~~ - kicked'
506
+                any_bannable = True
507
+            else:
508
+                message += join.member.mention
509
+                any_bannable = True
510
+                any_kickable = True
511
+        message += '\n_(list updates automatically)_'
512
+
513
+        message += '\n'
514
+        if any_kickable:
515
+            message += f'\nReact to this message with {CONFIG["kickEmoji"]} to kick all these users.'
516
+        else:
517
+            message += '\nNo users left to kick.'
518
+        if any_bannable:
519
+            message += f'\nReact to this message with {CONFIG["banEmoji"]} to ban all these users.'
520
+        else:
521
+            message += '\nNo users left to ban.'
522
+
523
+        return (message, any_kickable, any_bannable)

バイナリ
rocketbot.db.sample ファイルの表示


+ 8
- 608
rocketbot.py ファイルの表示

@@ -5,623 +5,23 @@ the sqlite database rocketbot.db (copy rocketbot.db.sample for a blank database)
5 5
 Author: Ian Albert (@rocketsoup)
6 6
 Date: 2021-11-11
7 7
 """
8
-from datetime import datetime
9
-import sqlite3
10
-import sys
11
-
12
-from discord import Guild, Intents, Member, Message, PartialEmoji, RawReactionActionEvent
13
-from discord.abc import GuildChannel
8
+from discord import Intents
14 9
 from discord.ext import commands
15
-from discord.ext.commands.context import Context
16 10
 
17 11
 from config import CONFIG
18
-from cogs.config import ConfigCog
19
-from cogs.general import GeneralCog
12
+from cogs.configcog import ConfigCog
13
+from cogs.generalcog import GeneralCog
14
+from cogs.joinraidcog import JoinRaidCog
20 15
 
21 16
 class Rocketbot(commands.Bot):
22 17
     def __init__(self, command_prefix, **kwargs):
23 18
         super().__init__(command_prefix, **kwargs)
24 19
 
25
-bot = Rocketbot(CONFIG['commandPrefix'])
20
+intents = Intents.default()
21
+intents.members = True # To get join/leave events
22
+bot = Rocketbot(command_prefix=CONFIG['commandPrefix'], intents=intents)
26 23
 bot.add_cog(GeneralCog(bot))
27 24
 bot.add_cog(ConfigCog(bot))
25
+bot.add_cog(JoinRaidCog(bot))
28 26
 bot.run(CONFIG['clientToken'], bot=True, reconnect=True)
29 27
 print('\nBot aborted')
30
-
31
-# -- Classes ----------------------------------------------------------------
32
-
33
-# class GuildContext:
34
-#     """
35
-#     Logic and state for a single guild serviced by the bot.
36
-#     """
37
-#     def __init__(self, guild_id: int):
38
-#         self.guild_id = guild_id
39
-#         self.guild = None # Resolved later
40
-#         # Config populated during load
41
-#         self.warning_channel_id = None
42
-#         self.warning_channel = None
43
-#         self.warning_mention = None
44
-#         self.join_warning_count = CONFIG['joinWarningCount']
45
-#         self.join_warning_seconds = CONFIG['joinWarningSeconds']
46
-#         # Non-persisted runtime state
47
-#         self.current_raid = JoinRaid()
48
-#         self.all_raids = [ self.current_raid ] # periodically culled of old ones
49
-
50
-#     # Commands
51
-
52
-#     async def command_hello(self, message: Message) -> None:
53
-#         """
54
-#         Command handler
55
-#         """
56
-#         await message.channel.send(f'Hey there, {message.author.mention}!')
57
-
58
-#     async def command_testwarn(self, context: Context) -> None:
59
-#         """
60
-#         Command handler
61
-#         """
62
-#         if self.warning_channel is None:
63
-#             self.__trace('No warning channel set!')
64
-#             await context.message.channel.send('No warning channel set on this guild! Type ' +
65
-#                 f'`{bot.command_prefix}{setwarningchannel.__name__}` in the channel you ' +
66
-#                 'want warnings to be posted.')
67
-#             return
68
-#         await self.__warn('Test warning. This is only a test.')
69
-
70
-#     async def command_setwarningchannel(self, context: Context):
71
-#         """
72
-#         Command handler
73
-#         """
74
-#         self.__trace(f'Warning channel set to {context.channel.name}')
75
-#         self.warning_channel = context.channel
76
-#         self.warning_channel_id = context.channel.id
77
-#         save_guild_context(self)
78
-#         await self.__warn('Warning messages will now be sent to ' + self.warning_channel.mention)
79
-
80
-#     async def command_setwarningmention(self, _context: Context, mention: str):
81
-#         """
82
-#         Command handler
83
-#         """
84
-#         self.__trace('set warning mention')
85
-#         m = mention if mention is not None and len(mention) > 0 else None
86
-#         self.warning_mention = m
87
-#         save_guild_context(self)
88
-#         if m is None:
89
-#             await self.__warn('Warning messages will not mention anyone')
90
-#         else:
91
-#             await self.__warn('Warning messages will now mention ' + m)
92
-
93
-#     async def command_setraidwarningrate(self, _context: Context, count: int, seconds: int):
94
-#         """
95
-#         Command handler
96
-#         """
97
-#         self.join_warning_count = count
98
-#         self.join_warning_seconds = seconds
99
-#         save_guild_context(self)
100
-#         await self.__warn(f'Maximum join rate set to {count} joins per {seconds} seconds')
101
-
102
-#     # Events
103
-
104
-#     async def handle_join(self, member: Member) -> None:
105
-#         """
106
-#         Event handler for all joins to this guild.
107
-#         """
108
-#         print(f'{member.guild.name}: {member.name} joined')
109
-#         now = member.joined_at
110
-#         raid = self.current_raid
111
-#         raid.handle_join(
112
-#             member,
113
-#             now=now,
114
-#             max_age_seconds = self.join_warning_seconds,
115
-#             max_join_count = self.join_warning_count)
116
-#         self.__trace(f'raid phase: {raid.phase}')
117
-#         if raid.phase == RaidPhase.JUST_STARTED:
118
-#             await self.__on_join_raid_begin(raid)
119
-#         elif raid.phase == RaidPhase.CONTINUING:
120
-#             await self.__on_join_raid_updated(raid)
121
-#         elif raid.phase == RaidPhase.ENDED:
122
-#             self.__start_new_raid(member)
123
-#             await self.__on_join_raid_end(raid)
124
-#         self.__cull_old_raids(now)
125
-
126
-#     def __start_new_raid(self, member: Member = None):
127
-#         """
128
-#         Retires self.current_raid and creates a new empty one. If `member` is passed, it will be
129
-#         added to the new self.current_raid after it is created.
130
-#         """
131
-#         self.current_raid = JoinRaid()
132
-#         self.all_raids.append(self.current_raid)
133
-#         if member is not None:
134
-#             self.current_raid.handle_join(
135
-#                 member,
136
-#                 member.joined_at,
137
-#                 max_age_seconds = self.join_warning_seconds,
138
-#                 max_join_count = self.join_warning_count)
139
-
140
-#     async def handle_reaction_add(self, message, member, emoji):
141
-#         """
142
-#         Handles all message reaction events to see if they need to be acted on.
143
-#         """
144
-#         if member.id == bot.user.id:
145
-#             # It's-a me, Rocketbot!
146
-#             return
147
-#         if message.author.id != bot.user.id:
148
-#             # The message the user is reacting to wasn't authored by me. Ignore.
149
-#             return
150
-#         self.__trace(f'User {member} added emoji {emoji}')
151
-#         if not member.permissions_in(message.channel).ban_members:
152
-#             self.__trace('Reactor does not have ban permissions. Ignoring.')
153
-#             return
154
-#         if emoji.name == CONFIG['kickEmoji']:
155
-#             await self.__kick_all_in_raid_message(message)
156
-#         elif emoji.name == CONFIG['banEmoji']:
157
-#             await self.__ban_all_in_raid_message(message)
158
-#         else:
159
-#             print('Unhandled emoji. Ignoring.')
160
-#             return
161
-
162
-#     async def __kick_all_in_raid_message(self, message: Message):
163
-#         """
164
-#         Kicks all the users mentioned in the given raid warning message. Users who were already
165
-#         kicked or banned will be skipped.
166
-#         """
167
-#         raid = self.__find_raid_for_message(message)
168
-#         if raid is None:
169
-#             await message.reply("This is either not a raid warning or it's too old and I don't " +
170
-#                 "have a record for it anymore. Sorry!")
171
-#             return
172
-#         self.__trace('Kicking...')
173
-#         members = await raid.kick_all()
174
-#         msg = 'Kicked these members:'
175
-#         for member in members:
176
-#             msg += f'\n\t{member.name}'
177
-#         if len(members) == 0:
178
-#             msg += '\n\t-none-'
179
-#         self.__trace(msg)
180
-#         self.__start_new_raid()
181
-#         await self.__update_join_raid_message(raid)
182
-
183
-#     async def __ban_all_in_raid_message(self, message: Message):
184
-#         """
185
-#         Bans all the users mentioned in the given raid warning message. Users who were already
186
-#         banned will be skipped.
187
-#         """
188
-#         raid = self.__find_raid_for_message(message)
189
-#         if raid is None:
190
-#             await message.reply("This is either not a raid warning or it's too old and I don't " +
191
-#                 "have a record for it anymore. Sorry!")
192
-#             return
193
-#         self.__trace('Banning...')
194
-#         members = await raid.ban_all()
195
-#         msg = 'Banned these members:'
196
-#         for member in members:
197
-#             msg += f'\n\t{member.name}'
198
-#         if len(members) == 0:
199
-#             msg += '\n\t-none-'
200
-#         self.__trace(msg)
201
-#         self.__start_new_raid()
202
-#         await self.__update_join_raid_message(raid)
203
-
204
-#     def __find_raid_for_message(self, message: Message) -> JoinRaid:
205
-#         """
206
-#         Retrieves a JoinRaid instance for the given raid warning message. Returns None if not found.
207
-#         """
208
-#         for raid in self.all_raids:
209
-#             if raid.warning_message.id == message.id:
210
-#                 return raid
211
-#         return None
212
-
213
-#     def __cull_old_raids(self, now: datetime):
214
-#         """
215
-#         Gets rid of old JoinRaid records from self.all_raids that are too old to still be useful.
216
-#         """
217
-#         i: int = 0
218
-#         while i < len(self.all_raids):
219
-#             raid = self.all_raids[i]
220
-#             if raid == self.current_raid:
221
-#                 i += 1
222
-#                 continue
223
-#             age_seconds = float((raid.raid_start_time - now).total_seconds())
224
-#             if age_seconds > 86400.0:
225
-#                 self.__trace('Culling old raid')
226
-#                 self.all_raids.pop(i)
227
-#             else:
228
-#                 i += 1
229
-
230
-#     def __join_raid_message(self, raid: JoinRaid):
231
-#         """
232
-#         Returns a 3-element tuple containing a text message appropriate for posting in
233
-#         Discord, a flag of whether any of the mentioned users can be kicked, and a flag
234
-#         of whether any of the mentioned users can be banned.
235
-#         """
236
-#         message = ''
237
-#         if self.warning_mention is not None:
238
-#             message = self.warning_mention + ' '
239
-#         message += '**RAID JOIN DETECTED!** It includes these users:\n'
240
-#         can_kick = False
241
-#         can_ban = False
242
-#         for join in raid.joins:
243
-#             message += '\n• '
244
-#             if join.is_banned:
245
-#                 message += '~~' + join.member.mention + '~~ - banned'
246
-#             elif join.is_kicked:
247
-#                 message += '~~' + join.member.mention + '~~ - kicked'
248
-#                 can_ban = True
249
-#             else:
250
-#                 message += join.member.mention
251
-#                 can_kick = True
252
-#                 can_ban = True
253
-#         message += '\n'
254
-#         if can_kick:
255
-#             message += '\nTo kick all these users, react with :' + CONFIG['kickEmojiName'] + ':'
256
-#         else:
257
-#             message += '\nNo kickable users remain'
258
-#         if can_ban:
259
-#             message += '\nTo ban all these users, react with :' + CONFIG['banEmojiName'] + ':'
260
-#         else:
261
-#             message += '\nNo bannable users remain'
262
-#         return (message, can_kick, can_ban)
263
-
264
-#     async def __update_join_raid_message(self, raid: JoinRaid):
265
-#         """
266
-#         Updates an existing join raid warning message with updated data.
267
-#         """
268
-#         if raid.warning_message is None:
269
-#             self.__trace('No raid warning message to update')
270
-#             return
271
-#         (message, can_kick, can_ban) = self.__join_raid_message(raid)
272
-#         await raid.warning_message.edit(content=message)
273
-#         if not can_kick:
274
-#             await raid.warning_message.clear_reaction(CONFIG['kickEmoji'])
275
-#         if not can_ban:
276
-#             await raid.warning_message.clear_reaction(CONFIG['banEmoji'])
277
-
278
-#     async def __on_join_raid_begin(self, raid):
279
-#         """
280
-#         Event triggered when the first member joins that triggers the raid detection.
281
-#         """
282
-#         self.__trace('A join raid has begun!')
283
-#         if self.warning_channel is None:
284
-#             self.__trace('NO WARNING CHANNEL SET')
285
-#             return
286
-#         (message, can_kick, can_ban) = self.__join_raid_message(raid)
287
-#         raid.warning_message = await self.warning_channel.send(message)
288
-#         if can_kick:
289
-#             await raid.warning_message.add_reaction(CONFIG['kickEmoji'])
290
-#         if can_ban:
291
-#             await raid.warning_message.add_reaction(CONFIG['banEmoji'])
292
-
293
-#     async def __on_join_raid_updated(self, raid):
294
-#         """
295
-#         Event triggered for each subsequent member join after the first one that triggered the
296
-#         raid detection.
297
-#         """
298
-#         self.__trace('Join raid still occurring')
299
-#         await self.__update_join_raid_message(raid)
300
-
301
-#     async def __on_join_raid_end(self, _raid):
302
-#         """
303
-#         Event triggered when the first member joins who is not part of the most recent raid.
304
-#         """
305
-#         self.__trace('Join raid has ended')
306
-
307
-#     async def __warn(self, message):
308
-#         """
309
-#         Posts a warning message in the configured warning channel.
310
-#         """
311
-#         if self.warning_channel is None:
312
-#             self.__trace('NO WARNING CHANNEL SET. Warning message not posted.\n' + message)
313
-#             return None
314
-#         m = message
315
-#         if self.warning_mention is not None:
316
-#             m = self.warning_mention + ' ' + m
317
-#         return await self.warning_channel.send(m)
318
-
319
-#     def __trace(self, message):
320
-#         """
321
-#         Debugging trace.
322
-#         """
323
-#         print(f'{self.guild.name}: {message}')
324
-
325
-# # lookup for int(Guild.guild_id) --> GuildContext
326
-# guild_id_to_guild_context = {}
327
-
328
-# def get_or_create_guild_context(val, save=True):
329
-#     """
330
-#     Retrieves a cached GuildContext instance by its Guild id or Guild object
331
-#     itself. If no GuildContext record exists for the Guild, one is created
332
-#     and cached (and saved to the database unless `save=False`).
333
-#     """
334
-#     gid = None
335
-#     guild = None
336
-#     if val is None:
337
-#         return None
338
-#     if isinstance(val, int):
339
-#         gid = val
340
-#     elif isinstance(val, Guild):
341
-#         gid = val.id
342
-#         guild = val
343
-#     if gid is None:
344
-#         print('Unhandled datatype', type(val))
345
-#         return None
346
-#     looked_up = guild_id_to_guild_context.get(gid)
347
-#     if looked_up is not None:
348
-#         return looked_up
349
-#     gc = GuildContext(gid)
350
-#     gc.guild = guild or gc.guild
351
-#     guild_id_to_guild_context[gid] = gc
352
-#     if save:
353
-#         save_guild_context(gc)
354
-#     return gc
355
-
356
-# # -- Database ---------------------------------------------------------------
357
-
358
-# def run_sql_batch(batch_function):
359
-#     """
360
-#     Performs an SQL transaction. After a connection is opened, the passed
361
-#     function is invoked with the sqlite3.Connection and sqlite3.Cursor
362
-#     passed as arguments. Once the passed function finishes, the connection
363
-#     is closed.
364
-#     """
365
-#     db_connection: sqlite3.Connection = sqlite3.connect('rocketbot.db')
366
-#     db_cursor: sqlite3.Cursor = db_connection.cursor()
367
-#     batch_function(db_connection, db_cursor)
368
-#     db_connection.commit()
369
-#     db_connection.close()
370
-
371
-# def load_guild_settings():
372
-#     """
373
-#     Populates the GuildContext cache with records from the database.
374
-#     """
375
-#     def load(_con, cur):
376
-#         """
377
-#         SQL
378
-#         """
379
-#         for row in cur.execute("""SELECT * FROM guilds"""):
380
-#             guild_id = row[0]
381
-#             gc = get_or_create_guild_context(guild_id, save=False)
382
-#             gc.warning_channel_id = row[1]
383
-#             gc.warning_mention = row[2]
384
-#             gc.join_warning_count = row[3] or CONFIG['joinWarningCount']
385
-#             gc.join_warning_seconds = row[4] or CONFIG['joinWarningSeconds']
386
-#             print(f'Guild {guild_id} channel id is {gc.warning_channel_id}')
387
-#     run_sql_batch(load)
388
-
389
-# def create_tables():
390
-#     """
391
-#     Creates all database tables.
392
-#     """
393
-#     def make_tables(_con, cur):
394
-#         """
395
-#         SQL
396
-#         """
397
-#         cur.execute("""CREATE TABLE guilds (
398
-#             guildId INTEGER,
399
-#             warningChannelId INTEGER,
400
-#             warningMention TEXT,
401
-#             joinWarningCount INTEGER,
402
-#             joinWarningSeconds INTEGER,
403
-#             PRIMARY KEY(guildId ASC))""")
404
-#     run_sql_batch(make_tables)
405
-
406
-# def save_guild_context(gc: GuildContext):
407
-#     """
408
-#     Saves the state of a GuildContext record to the database.
409
-#     """
410
-#     def save(_con, cur):
411
-#         """
412
-#         SQL
413
-#         """
414
-#         print(f'Saving guild context with id {gc.guild_id}')
415
-#         cur.execute("""
416
-#             SELECT guildId
417
-#             FROM guilds
418
-#             WHERE guildId=?
419
-#             """, (
420
-#                 gc.guild_id,
421
-#             ))
422
-#         channel_id = gc.warning_channel.id if gc.warning_channel is not None \
423
-#                 else gc.warning_channel_id
424
-#         exists = cur.fetchone() is not None
425
-#         if exists:
426
-#             print('Updating existing guild record in db')
427
-#             cur.execute("""
428
-#                 UPDATE guilds
429
-#                 SET warningChannelId=?,
430
-#                     warningMention=?,
431
-#                     joinWarningCount=?,
432
-#                     joinWarningSeconds=?
433
-#                 WHERE guildId=?
434
-#                 """, (
435
-#                     channel_id,
436
-#                     gc.warning_mention,
437
-#                     gc.join_warning_count,
438
-#                     gc.join_warning_seconds,
439
-#                     gc.guild_id,
440
-#                 ))
441
-#         else:
442
-#             print('Creating new guild record in db')
443
-#             cur.execute("""
444
-#                 INSERT INTO guilds (
445
-#                     guildId,
446
-#                     warningChannelId,
447
-#                     warningMention,
448
-#                     joinWarningCount,
449
-#                     joinWarningSeconds)
450
-#                 VALUES (?, ?, ?, ?, ?)
451
-#                 """, (
452
-#                     gc.guild_id,
453
-#                     channel_id,
454
-#                     gc.warning_mention,
455
-#                     gc.join_warning_count,
456
-#                     gc.join_warning_seconds,
457
-#                 ))
458
-#     run_sql_batch(save)
459
-
460
-# # -- Main (1) ---------------------------------------------------------------
461
-
462
-# load_guild_settings()
463
-
464
-# intents = Intents.default()
465
-# intents.members = True # To get join/leave events
466
-# bot = commands.Bot(command_prefix=CONFIG['commandPrefix'], intents=intents)
467
-
468
-# # -- Bot commands -----------------------------------------------------------
469
-
470
-# @bot.command(
471
-#     brief='Simply replies to the invoker with a hello message in the same channel.'
472
-# )
473
-# async def hello(ctx: Context):
474
-#     """
475
-#     Command handler
476
-#     """
477
-#     gc: GuildContext = get_or_create_guild_context(ctx.guild)
478
-#     if gc is None:
479
-#         return
480
-#     message = ctx.message
481
-#     await gc.command_hello(message)
482
-
483
-# @bot.command(
484
-#     brief='Posts a test warning message in the configured warning channel.',
485
-#     help="""If no warning channel is configured, the bot will reply in the channel the command was
486
-#     issued to notify no warning channel is set. If a warning mention is configured, the test 
487
-#     warning will tag the configured person/role."""
488
-# )
489
-# @commands.has_permissions(manage_messages=True)
490
-# async def testwarn(ctx: Context):
491
-#     """
492
-#     Command handler
493
-#     """
494
-#     gc: GuildContext = get_or_create_guild_context(ctx.guild)
495
-#     if gc is None:
496
-#         return
497
-#     await gc.command_testwarn(ctx)
498
-
499
-# @bot.command(
500
-#     brief='Sets the threshold for detecting a join raid.',
501
-#     usage='<count> <seconds>',
502
-#     help="""The raid threshold is expressed as number of joins within a given number of seconds.
503
-#     Each time a member joins, the number of joins in the previous _x_ seconds is counted, and if
504
-#     that count, _y_, equals or exceeds the count configured by this command, a raid is detected."""
505
-# )
506
-# @commands.has_permissions(manage_messages=True)
507
-# async def setraidwarningrate(ctx: Context, count: int, seconds: int):
508
-#     """
509
-#     Command handler
510
-#     """
511
-#     gc: GuildContext = get_or_create_guild_context(ctx.guild)
512
-#     if gc is None:
513
-#         return
514
-#     await gc.command_setraidwarningrate(ctx, count, seconds)
515
-
516
-# @bot.command(
517
-#     brief='Sets the current channel as the destination for bot warning messages.'
518
-# )
519
-# @commands.has_permissions(manage_messages=True)
520
-# async def setwarningchannel(ctx: Context):
521
-#     """
522
-#     Command handler
523
-#     """
524
-#     gc: GuildContext = get_or_create_guild_context(ctx.guild)
525
-#     if gc is None:
526
-#         return
527
-#     await gc.command_setwarningchannel(ctx)
528
-
529
-# @bot.command(
530
-#     brief='Sets an optional mention to include in every warning message.',
531
-#     usage='<mention>',
532
-#     help="""The argument provided to this command will be included verbatim, so if the intent is
533
-#     to tag a user or role, the argument must be a tag, not merely the name of the user/role."""
534
-# )
535
-# @commands.has_permissions(manage_messages=True)
536
-# async def setwarningmention(ctx: Context, mention: str):
537
-#     """
538
-#     Command handler
539
-#     """
540
-#     gc: GuildContext = get_or_create_guild_context(ctx.guild)
541
-#     if gc is None:
542
-#         return
543
-#     await gc.command_setwarningmention(ctx, mention)
544
-
545
-# # -- Bot events -------------------------------------------------------------
546
-
547
-# is_connected = False
548
-# @bot.listen()
549
-# async def on_connect():
550
-#     """
551
-#     Discord event handler
552
-#     """
553
-#     global is_connected
554
-#     print('Connected')
555
-#     is_connected = True
556
-#     if is_connected and is_ready:
557
-#         await populate_guilds()
558
-
559
-# is_ready = False
560
-# @bot.listen()
561
-# async def on_ready():
562
-#     """
563
-#     Discord event handler
564
-#     """
565
-#     global is_ready
566
-#     print('Ready')
567
-#     is_ready = True
568
-#     if is_connected and is_ready:
569
-#         await populate_guilds()
570
-
571
-# async def populate_guilds():
572
-#     """
573
-#     Called after both on_ready and on_connect are done. May be called more than once!
574
-#     """
575
-#     for guild in bot.guilds:
576
-#         gc = guild_id_to_guild_context.get(guild.id)
577
-#         if gc is None:
578
-#             print(f'No GuildContext for {guild.id}')
579
-#             continue
580
-#         gc.guild = guild
581
-#         if gc.warning_channel_id is not None:
582
-#             gc.warning_channel = guild.get_channel(gc.warning_channel_id)
583
-#             if gc.warning_channel is not None:
584
-#                 print(f'Recovered warning channel {gc.warning_channel}')
585
-#             else:
586
-#                 print(f'Could not find channel with id {gc.warning_channel_id} in ' +
587
-#                     f'guild {guild.name}')
588
-#                 for channel in await guild.fetch_channels():
589
-#                     print(f'\t{channel.name} ({channel.id})')
590
-
591
-# @bot.listen()
592
-# async def on_member_join(member: Member) -> None:
593
-#     """
594
-#     Discord event handler
595
-#     """
596
-#     print(f'User {member.name} joined {member.guild.name}')
597
-#     gc: GuildContext = get_or_create_guild_context(member.guild)
598
-#     if gc is None:
599
-#         print(f'No GuildContext for guild {member.guild.name}')
600
-#         return
601
-#     await gc.handle_join(member)
602
-
603
-# @bot.listen()
604
-# async def on_member_remove(member: Member) -> None:
605
-#     """
606
-#     Discord event handler
607
-#     """
608
-#     print(f'User {member.name} left {member.guild.name}')
609
-
610
-# @bot.listen()
611
-# async def on_raw_reaction_add(payload: RawReactionActionEvent) -> None:
612
-#     """
613
-#     Discord event handler
614
-#     """
615
-#     guild: Guild = bot.get_guild(payload.guild_id)
616
-#     channel: GuildChannel = guild.get_channel(payload.channel_id)
617
-#     message: Message = await channel.fetch_message(payload.message_id)
618
-#     member: Member = payload.member
619
-#     emoji: PartialEmoji = payload.emoji
620
-#     gc: GuildContext = get_or_create_guild_context(guild)
621
-#     await gc.handle_reaction_add(message, member, emoji)
622
-
623
-# # -- Main -------------------------------------------------------------------
624
-
625
-# print('Starting bot')
626
-# bot.run(CONFIG['clientToken'])
627
-# print('Bot done')

+ 45
- 13
storage.py ファイルの表示

@@ -7,6 +7,7 @@ from config import CONFIG
7 7
 
8 8
 class StateKey:
9 9
 	WARNING_CHANNEL_ID = 'warning_channel_id'
10
+	WARNING_MENTION = 'warning_mention'
10 11
 
11 12
 class Storage:
12 13
 	"""
@@ -17,30 +18,60 @@ class Storage:
17 18
 	__guild_id_to_state = {}
18 19
 
19 20
 	@classmethod
20
-	def state_for_guild(cls, guild: Guild) -> dict:
21
+	def get_state(cls, guild: Guild) -> dict:
21 22
 		"""
22
-		Returns the state for the given guild, loading from disk if necessary.
23
-		Always returns a dict.
23
+		Returns all persisted state for the given guild.
24 24
 		"""
25 25
 		state: dict = cls.__guild_id_to_state.get(guild.id)
26 26
 		if state is not None:
27 27
 			# Already in memory
28 28
 			return state
29 29
 		# Load from disk if possible
30
-		cls.__trace(f'No loaded state for guild {guild.id}. Attempting to load from disk.')
30
+		cls.__trace(f'No loaded state for guild {guild.id}. Attempting to ' +
31
+			'load from disk.')
31 32
 		state = cls.__load_guild_state(guild)
33
+		if state is None:
34
+			return {}
32 35
 		cls.__guild_id_to_state[guild.id] = state
33 36
 		return state
34 37
 
35 38
 	@classmethod
36
-	def save_guild_state(cls, guild: Guild) -> None:
39
+	def get_state_value(cls, guild: Guild, key: str):
37 40
 		"""
38
-		Saves state for the given guild to disk, if any exists.
41
+		Returns a persisted state value stored under the given key. Returns
42
+		None if not present.
39 43
 		"""
40
-		state: dict = cls.__guild_id_to_state.get(guild.id)
41
-		if state is None:
42
-			# Nothing to save
44
+		return cls.get_state(guild).get(key)
45
+
46
+	@classmethod
47
+	def set_state_value(cls, guild: Guild, key: str, value) -> None:
48
+		"""
49
+		Adds the given key-value pair to the persisted state for the given
50
+		Guild. If `value` is `None` the key will be removed from persisted
51
+		state.
52
+		"""
53
+		cls.set_state_values(guild, { key: value })
54
+
55
+	@classmethod
56
+	def set_state_values(cls, guild: Guild, vars: dict) -> None:
57
+		"""
58
+		Merges the given `vars` dict with the saved state for the given guild
59
+		and saves it to disk. `vars` must be JSON-encodable or a ValueError will
60
+		be raised. Keys with associated values of `None` will be removed from the
61
+		state.
62
+		"""
63
+		if vars is None or len(vars) == 0:
43 64
 			return
65
+		state: dict = cls.get_state(guild)
66
+		try:
67
+			json.dumps(vars)
68
+		except:
69
+			raise ValueError(f'vars not JSON encodable - {vars}')
70
+		for key, value in vars.items():
71
+			if value is None:
72
+				del state[key]
73
+			else:
74
+				state[key] = value
44 75
 		cls.__save_guild_state(guild, state)
45 76
 
46 77
 	@classmethod
@@ -49,7 +80,8 @@ class Storage:
49 80
 		Saves state for a guild to a JSON file on disk.
50 81
 		"""
51 82
 		path: str = cls.__guild_path(guild)
52
-		cls.__trace('Saving state for guild {guild.id} to {path}')
83
+		cls.__trace(f'Saving state for guild {guild.id} to {path}')
84
+		cls.__trace(f'state = {state}')
53 85
 		with open(path, 'w') as file:
54 86
 			# Pretty printing to make more legible for debugging
55 87
 			# Sorting keys to help with diffs
@@ -59,12 +91,12 @@ class Storage:
59 91
 	@classmethod
60 92
 	def __load_guild_state(cls, guild: Guild) -> dict:
61 93
 		"""
62
-		Loads state for a guild from a JSON file on disk.
94
+		Loads state for a guild from a JSON file on disk, or None if not found.
63 95
 		"""
64 96
 		path: str = cls.__guild_path(guild)
65 97
 		if not exists(path):
66
-			cls.__trace(f'No state on disk for guild {guild.id}. Returning {{}}.')
67
-			return {}
98
+			cls.__trace(f'No state on disk for guild {guild.id}. Returning None.')
99
+			return None
68 100
 		cls.__trace(f'Loading state from disk for guild {guild.id}')
69 101
 		with open(path, 'r') as file:
70 102
 			state = json.load(file)

読み込み中…
キャンセル
保存