Sfoglia il codice sorgente

Breaking things out into cogs and helper classes

pull/1/head
Rocketsoup 4 anni fa
parent
commit
94670930d1
10 ha cambiato i file con 997 aggiunte e 745 eliminazioni
  1. 1
    1
      .gitignore
  2. 0
    0
      cogs/__init__.py
  3. 17
    0
      cogs/base.py
  4. 37
    0
      cogs/config.py
  5. 17
    0
      cogs/general.py
  6. 225
    0
      cogs/joinraid.py
  7. 1
    0
      config.py.sample
  8. 606
    744
      rocketbot.py
  9. 17
    0
      state/test.json
  10. 76
    0
      storage.py

+ 1
- 1
.gitignore Vedi File

@@ -118,5 +118,5 @@ dmypy.json
118 118
 .DS_Store
119 119
 
120 120
 # Contains secrets so exclude
121
-config.py
121
+/config.py
122 122
 rocketbot.db

+ 0
- 0
cogs/__init__.py Vedi File


+ 17
- 0
cogs/base.py Vedi File

@@ -0,0 +1,17 @@
1
+from discord import Guild
2
+from discord.ext import commands
3
+
4
+from storage import Storage
5
+
6
+class BaseCog(commands.Cog):
7
+	def __init__(self, bot):
8
+		self.bot = bot
9
+
10
+	def save_setting(guild: Guild, name: str, value):
11
+		state: dict = Storage.state_for_guild(guild)
12
+		state[name] = value
13
+		Storage.save_guild_state(guild)
14
+
15
+	def warn(guild: Guild):
16
+		state: dict = Storage.state_for_guild(guild)
17
+		

+ 37
- 0
cogs/config.py Vedi File

@@ -0,0 +1,37 @@
1
+from discord import Guild, TextChannel
2
+from discord.ext import commands
3
+from storage import Storage
4
+
5
+class ConfigCog(commands.Cog):
6
+	"""
7
+	Cog for handling general bot configuration.
8
+	"""
9
+	def __init__(self, bot):
10
+		self.bot = bot
11
+
12
+	@commands.group(
13
+		brief='Manages general bot configuration'
14
+	)
15
+	@commands.has_permissions(ban_members=True)
16
+	@commands.guild_only()
17
+	async def config(self, context: commands.Context):
18
+		'Command group'
19
+		if context.invoked_subcommand is None:
20
+			await context.send_help()
21
+
22
+	@config.command(
23
+		name='setoutputchannel',
24
+		brief='Sets the channel where mod communications occur',
25
+		description='Run this command in the channel where bot messages ' +
26
+			'intended for server moderators should be sent. Other bot messages ' +
27
+			'may still be posted in the channel a command was invoked in. If ' +
28
+			'no output channel is set, mod-related messages will not be posted!',
29
+	)
30
+	async def config_setoutputchannel(self, context: commands.Context):
31
+		'Command handler'
32
+		guild: Guild = context.guild
33
+		channel: TextChannel = context.channel
34
+		state: dict = Storage.state_for_guild(context.guild)
35
+		state['output_channel_id'] = context.channel.id
36
+		Storage.save_guild_state(context.guild)
37
+		await context.message.reply(f'Output channel set to {channel.name}', mention_author=False)

+ 17
- 0
cogs/general.py Vedi File

@@ -0,0 +1,17 @@
1
+from discord.ext import commands
2
+
3
+class GeneralCog(commands.Cog):
4
+	def __init__(self, bot: commands.Bot):
5
+		self.bot = bot
6
+		self.is_connected = False
7
+		self.is_ready = False
8
+
9
+	@commands.Cog.listener()
10
+	async def on_connect(self):
11
+		print('on_connect')
12
+		self.is_connected = True
13
+
14
+	@commands.Cog.listener()
15
+	async def on_ready(self):
16
+		print('on_ready')
17
+		self.is_ready = True

+ 225
- 0
cogs/joinraid.py Vedi File

@@ -0,0 +1,225 @@
1
+from discord import Guild, Intents, Member, Message, PartialEmoji, RawReactionActionEvent
2
+from discord.ext import commands
3
+from storage import Storage
4
+
5
+class JoinRecord:
6
+    """
7
+    Data object containing details about a guild join event.
8
+    """
9
+    def __init__(self, member: Member):
10
+        self.member = member
11
+        self.join_time = member.joined_at or datetime.now()
12
+        self.is_kicked = False
13
+        self.is_banned = False
14
+
15
+    def age_seconds(self, now: datetime) -> float:
16
+        """
17
+        Returns the age of this join in seconds from the given "now" time.
18
+        """
19
+        a = now - self.join_time
20
+        return float(a.total_seconds())
21
+
22
+class RaidPhase:
23
+    """
24
+    Enum of phases in a JoinRaid. Phases progress monotonically.
25
+    """
26
+    NONE = 0
27
+    JUST_STARTED = 1
28
+    CONTINUING = 2
29
+    ENDED = 3
30
+
31
+class JoinRaid:
32
+    """
33
+    Tracks recent joins to a guild to detect join raids, where a large number of automated users
34
+    all join at the same time.
35
+    """
36
+    def __init__(self):
37
+        self.joins = []
38
+        self.phase = RaidPhase.NONE
39
+        # datetime when the raid started, or None.
40
+        self.raid_start_time = None
41
+        # Message posted to Discord to warn of the raid. Convenience property managed
42
+        # by caller. Ignored by this class.
43
+        self.warning_message = None
44
+
45
+    def handle_join(self,
46
+            member: Member,
47
+            now: datetime,
48
+            max_age_seconds: float,
49
+            max_join_count: int) -> None:
50
+        """
51
+        Processes a new member join to a guild and detects join raids. Updates
52
+        self.phase and self.raid_start_time properties.
53
+        """
54
+        # Check for existing record for this user
55
+        print(f'handle_join({member.name}) start')
56
+        join: JoinRecord = None
57
+        i: int = 0
58
+        while i < len(self.joins):
59
+            elem = self.joins[i]
60
+            if elem.member.id == member.id:
61
+                print(f'Member {member.name} already in join list at index {i}. Removing.')
62
+                join = self.joins.pop(i)
63
+                join.join_time = now
64
+                break
65
+            i += 1
66
+        # Add new record to end
67
+        self.joins.append(join or JoinRecord(member))
68
+        # Check raid status and do upkeep
69
+        self.__process_joins(now, max_age_seconds, max_join_count)
70
+        print(f'handle_join({member.name}) end')
71
+
72
+    def __process_joins(self,
73
+            now: datetime,
74
+            max_age_seconds: float,
75
+            max_join_count: int) -> None:
76
+        """
77
+        Processes self.joins after each addition, detects raids, updates self.phase,
78
+        and throws out unneeded records.
79
+        """
80
+        print('__process_joins {')
81
+        i: int = 0
82
+        recent_count: int = 0
83
+        should_cull: bool = self.phase == RaidPhase.NONE
84
+        while i < len(self.joins):
85
+            join: JoinRecord = self.joins[i]
86
+            age: float = join.age_seconds(now)
87
+            is_old: bool = age > max_age_seconds
88
+            if not is_old:
89
+                recent_count += 1
90
+                print(f'- {i}. {join.member.name} is {age}s old - recent_count={recent_count}')
91
+            if is_old and should_cull:
92
+                self.joins.pop(i)
93
+                print(f'- {i}. {join.member.name} is {age}s old - too old, removing')
94
+            else:
95
+                print(f'- {i}. {join.member.name} is {age}s old - moving on to next')
96
+                i += 1
97
+        is_raid = recent_count > max_join_count
98
+        print(f'- is_raid {is_raid}')
99
+        if is_raid:
100
+            if self.phase == RaidPhase.NONE:
101
+                self.phase = RaidPhase.JUST_STARTED
102
+                self.raid_start_time = now
103
+                print('- Phase moved to JUST_STARTED. Recording raid start time.')
104
+            elif self.phase == RaidPhase.JUST_STARTED:
105
+                self.phase = RaidPhase.CONTINUING
106
+                print('- Phase moved to CONTINUING.')
107
+        elif self.phase == self.phase in (RaidPhase.JUST_STARTED, RaidPhase.CONTINUING):
108
+            self.phase = RaidPhase.ENDED
109
+            print('- Phase moved to ENDED.')
110
+
111
+        # Undo join add if the raid is over
112
+        if self.phase == RaidPhase.ENDED and len(self.joins) > 0:
113
+            last = self.joins.pop(-1)
114
+            print(f'- Popping last join for {last.member.name}')
115
+        print('} __process_joins')
116
+
117
+    async def kick_all(self,
118
+            reason: str = "Part of join raid") -> list[Member]:
119
+        """
120
+        Kicks all users in this join raid. Skips users who have already been
121
+        flagged as having been kicked or banned. Returns a List of Members
122
+        who were newly kicked.
123
+        """
124
+        kicks = []
125
+        for join in self.joins:
126
+            if join.is_kicked or join.is_banned:
127
+                continue
128
+            await join.member.kick(reason=reason)
129
+            join.is_kicked = True
130
+            kicks.append(join.member)
131
+        self.phase = RaidPhase.ENDED
132
+        return kicks
133
+
134
+    async def ban_all(self,
135
+            reason: str = "Part of join raid",
136
+            delete_message_days: int = 0) -> list[Member]:
137
+        """
138
+        Bans all users in this join raid. Skips users who have already been
139
+        flagged as having been banned. Users who were previously kicked can
140
+        still be banned. Returns a List of Members who were newly banned.
141
+        """
142
+        bans = []
143
+        for join in self.joins:
144
+            if join.is_banned:
145
+                continue
146
+            await join.member.ban(reason=reason, delete_message_days=delete_message_days)
147
+            join.is_banned = True
148
+            bans.append(join.member)
149
+        self.phase = RaidPhase.ENDED
150
+        return bans
151
+
152
+class JoinRaidCog(commands.Cog):
153
+	"""
154
+	Cog for monitoring member joins and detecting potential bot raids.
155
+	"""
156
+	def __init__(self, bot):
157
+		self.bot = bot
158
+
159
+	@commands.group(
160
+		brief='Manages join raid detection and handling',
161
+	)
162
+	@commands.has_permissions(ban_members=True)
163
+	@commands.guild_only()
164
+	async def joinraid(self, context: commands.Context):
165
+		'Command group'
166
+		if context.invoked_subcommand is None:
167
+			await context.send_help()
168
+
169
+	@joinraid.command(
170
+		name='enable',
171
+		brief='Enables join raid detection',
172
+		description='Join raid detection is off by default.',
173
+	)
174
+	async def joinraid_enable(self, context: commands.Context):
175
+		'Command handler'
176
+		# TODO
177
+		pass
178
+
179
+	@joinraid.command(
180
+		name='disable',
181
+		brief='Disables join raid detection',
182
+		description='Join raid detection is off by default.',
183
+	)
184
+	async def joinraid_disable(self, context: commands.Context):
185
+		'Command handler'
186
+		# TODO
187
+		pass
188
+
189
+	@joinraid.command(
190
+		name='setrate',
191
+		brief='Sets the rate of joins which triggers a warning to mods',
192
+		description='Each time a member joins, the join records from the ' +
193
+			'previous _x_ seconds are counted up, where _x_ is the number of ' +
194
+			'seconds configured by this command. If that count meets or ' +
195
+			'exceeds the maximum join count configured by this command then ' +
196
+			'a raid is detected and a warning is issued to the mods.',
197
+		usage='<join_count> <seconds>',
198
+	)
199
+	async def joinraid_setrate(self, context: commands.Context,
200
+			joinCount: int,
201
+			seconds: int):
202
+		'Command handler'
203
+		# TODO
204
+		pass
205
+
206
+	@joinraid.command(
207
+		name='getrate',
208
+		brief='Shows the rate of joins which triggers a warning to mods',
209
+	)
210
+	async def joinraid_getrate(self, context: commands.Context):
211
+		'Command handler'
212
+		# TODO
213
+		pass
214
+
215
+	@commands.Cog.listener()
216
+	async def on_raw_reaction_add(self, payload: RawReactionActionEvent):
217
+		'Event handler'
218
+		# TODO
219
+		pass
220
+
221
+	@commands.Cog.listener()
222
+	async def on_member_join(self, member: Member) -> None:
223
+		'Event handler'
224
+		# TODO
225
+		pass

+ 1
- 0
config.py.sample Vedi File

@@ -12,4 +12,5 @@ CONFIG = {
12 12
 	'kickEmoji': '👢',
13 13
 	'banEmojiName': 'no_entry_sign',
14 14
 	'banEmoji': '🚫',
15
+	'statePath': 'state/',
15 16
 }

+ 606
- 744
rocketbot.py
File diff soppresso perché troppo grande
Vedi File


+ 17
- 0
state/test.json Vedi File

@@ -0,0 +1,17 @@
1
+{
2
+	"alpha": 123,
3
+	"bravo": 3.1415926535,
4
+	"charlie": true,
5
+	"delta": "Some text",
6
+	"echo": null,
7
+	"foxtrot": [
8
+		1,
9
+		2,
10
+		3
11
+	],
12
+	"golf": {
13
+		"a": 1,
14
+		"b": 2,
15
+		"c": 3
16
+	}
17
+}

+ 76
- 0
storage.py Vedi File

@@ -0,0 +1,76 @@
1
+import json
2
+from os.path import exists
3
+
4
+from discord import Guild
5
+
6
+from config import CONFIG
7
+
8
+class Storage:
9
+	"""
10
+	Static class for managing persisted bot state.
11
+	"""
12
+
13
+	# discord.Guild.id -> dict
14
+	__guild_id_to_state = {}
15
+
16
+	def state_for_guild(guild: Guild) -> dict:
17
+		"""
18
+		Returns the state for the given guild, loading from disk if necessary.
19
+		Always returns a dict.
20
+		"""
21
+		state: dict = __guild_id_to_state[guild.id]
22
+		if state is not None:
23
+			# Already in memory
24
+			return state
25
+		# Load from disk if possible
26
+		__trace(f'No loaded state for guild {guild.id}. Attempting to load from disk.')
27
+		state = __load_guild_state(guild)
28
+		__guild_id_to_state[guild.id] = state
29
+		return state
30
+
31
+	def save_guild_state(guild: Guild) -> NoneType:
32
+		"""
33
+		Saves state for the given guild to disk, if any exists.
34
+		"""
35
+		state: dict = __guild_id_to_state[guild.id]
36
+		if state is None:
37
+			# Nothing to save
38
+			return
39
+		__save_guild_state(guild, state)
40
+
41
+	def __save_guild_state(guild: Guild, state: dict) -> NoneType:
42
+		"""
43
+		Saves state for a guild to a JSON file on disk.
44
+		"""
45
+		path: str = __guild_path(guild)
46
+		__trace('Saving state for guild {guild.id} to {path}')
47
+		with open(path, 'w') as file:
48
+			# Pretty printing to make more legible for debugging
49
+			# Sorting keys to help with diffs
50
+			json.dump(state, file, indent='\t', sort_keys=True)
51
+		__trace('State saved')
52
+
53
+	def __load_guild_state(guild: Guild) -> dict:
54
+		"""
55
+		Loads state for a guild from a JSON file on disk.
56
+		"""
57
+		path: str = __guild_path(guild)
58
+		if not exists(path):
59
+			__trace(f'No state on disk for guild {guild.id}. Returning {{}}.')
60
+			return {}
61
+		__trace('Loading state from disk for guild {guild.id}')
62
+		with open(path, 'r') as file:
63
+			state = json.load(file)
64
+		__trace('State loaded')
65
+		return state
66
+
67
+	def __guild_path(guild: Guild) -> str:
68
+		"""
69
+		Returns the JSON file path where guild state should be written.
70
+		"""
71
+		config_value: str = CONFIG['statePath']
72
+		path: str = config_value if config_value.endswith('/') else f'{config_value}/'
73
+		return f'{path}guild_{guild.id}.json'
74
+
75
+	def __trace(message: str) -> NoneType:
76
+		print(f'{Storage.__name__}: {message}')

Loading…
Annulla
Salva