from discord import Guild, Intents, Member, Message, NotFound, PartialEmoji, RawReactionActionEvent
from discord.ext import commands
from storage import Storage
from cogs.basecog import BaseCog
from config import CONFIG
from datetime import datetime


class JoinRecord:
    """
    Data object containing details about a single guild join event.
    """
    def __init__(self, member: Member):
        self.member = member
        # NOTE(review): the fallback datetime.now() is naive local time; confirm
        # it matches the timezone convention used by member.joined_at.
        self.join_time = member.joined_at or datetime.now()
        # These flags only track whether this bot has kicked/banned
        self.is_kicked = False
        self.is_banned = False

    def age_seconds(self, now: datetime) -> float:
        """
        Returns the age of this join in seconds from the given "now" time.
        """
        a = now - self.join_time
        return float(a.total_seconds())


class RaidPhase:
    """
    Enum of phases in a JoinRaidRecord. Phases progress monotonically.
    """
    NONE = 0
    JUST_STARTED = 1
    CONTINUING = 2
    ENDED = 3


class JoinRaidRecord:
    """
    Tracks recent joins to a guild to detect join raids, where a large number
    of automated users all join at the same time. Manages list of joins to not
    grow unbounded.
    """
    def __init__(self):
        self.joins = []
        self.phase = RaidPhase.NONE
        # datetime when the raid started, or None.
        self.raid_start_time = None
        # Message posted to Discord to warn of the raid. Convenience property
        # managed by caller. Ignored by this class.
        self.warning_message = None

    def handle_join(self, member: Member, now: datetime, max_join_count: int, max_age_seconds: float) -> None:
        """
        Processes a new member join to a guild and detects join raids.
        Updates self.phase and self.raid_start_time properties.

        A raid is detected when at least `max_join_count` recorded joins
        (including this one) are no older than `max_age_seconds` relative to
        `now`. If the member is already in the list, their existing record is
        moved to the end and its join time is reset to `now`.
        """
        # Check for existing record for this user
        print(f'handle_join({member.name}) start')
        join: JoinRecord = None
        i: int = 0
        while i < len(self.joins):
            elem = self.joins[i]
            if elem.member.id == member.id:
                print(f'Member {member.name} already in join list at index {i}. Removing.')
                join = self.joins.pop(i)
                join.join_time = now
                break
            i += 1
        # Add new record to end
        self.joins.append(join or JoinRecord(member))
        # Check raid status and do upkeep
        self.__process_joins(now, max_age_seconds, max_join_count)
        print(f'handle_join({member.name}) end')

    def __process_joins(self, now: datetime, max_age_seconds: float, max_join_count: int) -> None:
        """
        Processes self.joins after each addition, detects raids, updates
        self.phase, and throws out unneeded records.
        """
        print('__process_joins {')
        i: int = 0
        recent_count: int = 0
        # Old joins are only discarded while no raid is in progress, so that a
        # raid's full member list is kept for kicking/banning.
        should_cull: bool = self.phase == RaidPhase.NONE
        while i < len(self.joins):
            join: JoinRecord = self.joins[i]
            age: float = join.age_seconds(now)
            is_old: bool = age > max_age_seconds
            if not is_old:
                recent_count += 1
                print(f'- {i}. {join.member.name} is {age}s old - recent_count={recent_count}')
            if is_old and should_cull:
                # Popping shifts the next element into index i; don't advance.
                self.joins.pop(i)
                print(f'- {i}. {join.member.name} is {age}s old - too old, removing')
            else:
                print(f'- {i}. {join.member.name} is {age}s old - moving on to next')
                i += 1
        # "count or more" joins in the window is a raid, matching the wording
        # shown to users by the cog's settings/help text.
        is_raid = recent_count >= max_join_count
        print(f'- is_raid {is_raid}')
        if is_raid:
            if self.phase == RaidPhase.NONE:
                self.phase = RaidPhase.JUST_STARTED
                self.raid_start_time = now
                print('- Phase moved to JUST_STARTED. Recording raid start time.')
            elif self.phase == RaidPhase.JUST_STARTED:
                self.phase = RaidPhase.CONTINUING
                print('- Phase moved to CONTINUING.')
        elif self.phase in (RaidPhase.JUST_STARTED, RaidPhase.CONTINUING):
            self.phase = RaidPhase.ENDED
            print('- Phase moved to ENDED.')
        # Undo join add if the raid is over: the join that ended the raid is
        # not part of it. The caller is expected to track it in a new record.
        if self.phase == RaidPhase.ENDED and len(self.joins) > 0:
            last = self.joins.pop(-1)
            print(f'- Popping last join for {last.member.name}')
        print('} __process_joins')

    async def kick_all(self, reason: str = "Part of join raid") -> list[Member]:
        """
        Kicks all users in this join raid. Skips users who have already been
        flagged as having been kicked or banned. Returns a List of Members who
        were newly kicked. Moves this raid to the ENDED phase.
        """
        kicks = []
        for join in self.joins:
            if join.is_kicked or join.is_banned:
                continue
            await join.member.kick(reason=reason)
            join.is_kicked = True
            kicks.append(join.member)
        self.phase = RaidPhase.ENDED
        return kicks

    async def ban_all(self, reason: str = "Part of join raid", delete_message_days: int = 0) -> list[Member]:
        """
        Bans all users in this join raid. Skips users who have already been
        flagged as having been banned. Users who were previously kicked can
        still be banned. Returns a List of Members who were newly banned.
        Moves this raid to the ENDED phase.
        """
        bans = []
        for join in self.joins:
            if join.is_banned:
                continue
            await join.member.ban(
                reason=reason,
                delete_message_days=delete_message_days)
            join.is_banned = True
            bans.append(join.member)
        self.phase = RaidPhase.ENDED
        return bans


class GuildContext:
    """
    Logic and state for a single guild serviced by the bot.
    """
    def __init__(self, guild_id: int):
        self.guild_id = guild_id
        self.join_warning_count = CONFIG['joinWarningCount']
        self.join_warning_seconds = CONFIG['joinWarningSeconds']
        # Non-persisted runtime state
        self.current_raid = JoinRaidRecord()
        self.all_raids = [ self.current_raid ]  # periodically culled of old ones

    # Events

    async def handle_join(self, member: Member) -> JoinRaidRecord:
        """
        Event handler for all joins to this guild, using the CONFIG default
        join rate captured at construction.

        If the join arrives too late to belong to the current raid (phase
        ENDED), the current raid is retired and the join is tracked in a fresh
        record instead. Returns the raid record the join was evaluated
        against, so the caller can react to its phase (e.g. post or update a
        warning).
        """
        now = member.joined_at
        raid = self.current_raid
        raid.handle_join(
            member,
            now=now,
            max_age_seconds = self.join_warning_seconds,
            max_join_count = self.join_warning_count)
        self.__trace(f'raid phase: {raid.phase}')
        if raid.phase == RaidPhase.ENDED:
            # The ending join was removed from the old raid; re-track it in a
            # new record so it counts toward any future raid.
            self.reset_raid(now)
            self.current_raid.handle_join(
                member,
                now=now,
                max_age_seconds = self.join_warning_seconds,
                max_join_count = self.join_warning_count)
        self.__cull_old_raids(now)
        return raid

    def reset_raid(self, now: datetime):
        """
        Retires self.current_raid and creates a new empty one. Also culls
        retired raids that are too old relative to `now`.
        """
        self.current_raid = JoinRaidRecord()
        self.all_raids.append(self.current_raid)
        self.__cull_old_raids(now)

    def find_raid_for_message_id(self, message_id: int) -> JoinRaidRecord:
        """
        Retrieves a JoinRaidRecord instance for the given raid warning message.
        Returns None if not found. Raids that never posted a warning are
        skipped.
        """
        for raid in self.all_raids:
            warning = raid.warning_message
            if warning is not None and warning.id == message_id:
                return raid
        return None

    def __cull_old_raids(self, now: datetime):
        """
        Gets rid of old JoinRaidRecord records from self.all_raids that are
        too old to still be useful. The current raid is never culled. Retired
        raids that never started (no start time) are culled as well, since
        they have no warning message to act on.
        """
        i: int = 0
        while i < len(self.all_raids):
            raid = self.all_raids[i]
            if raid is self.current_raid:
                i += 1
                continue
            start = raid.raid_start_time
            # Age is measured from the raid start forward to `now`.
            if start is None or float((now - start).total_seconds()) > 86400.0:
                self.__trace('Culling old raid')
                self.all_raids.pop(i)
            else:
                i += 1

    def __trace(self, message):
        """
        Debugging trace.
        """
        print(f'{self.guild_id}: {message}')


class JoinRaidCog(BaseCog):
    """
    Cog for monitoring member joins and detecting potential bot raids.
    """
    MIN_JOIN_COUNT = 2

    STATE_KEY_RAID_COUNT = 'joinraid_count'
    STATE_KEY_RAID_SECONDS = 'joinraid_seconds'
    STATE_KEY_ENABLED = 'joinraid_enabled'

    def __init__(self, bot):
        super().__init__(bot)
        self.guild_id_to_context = {}  # Guild.id -> GuildContext

    # -- Config -------------------------------------------------------------

    def __get_raid_rate(self, guild: Guild) -> tuple:
        """
        Returns the join rate configured for this guild as a
        (join count, seconds) tuple, falling back to CONFIG defaults.
        """
        count: int = Storage.get_state_value(guild, self.STATE_KEY_RAID_COUNT) \
            or CONFIG['joinWarningCount']
        seconds: float = Storage.get_state_value(guild, self.STATE_KEY_RAID_SECONDS) \
            or CONFIG['joinWarningSeconds']
        return (count, seconds)

    def __is_enabled(self, guild: Guild) -> bool:
        """
        Returns whether join raid detection is enabled in this guild.
        """
        return Storage.get_state_value(guild, self.STATE_KEY_ENABLED) or False

    # -- Commands -----------------------------------------------------------

    @commands.group(
        brief='Manages join raid detection and handling',
    )
    @commands.has_permissions(ban_members=True)
    @commands.guild_only()
    async def joinraid(self, context: commands.Context):
        'Command group'
        if context.invoked_subcommand is None:
            await context.send_help()

    @joinraid.command(
        name='enable',
        brief='Enables join raid detection',
        description='Join raid detection is off by default.',
    )
    async def joinraid_enable(self, context: commands.Context):
        'Command handler'
        guild = context.guild
        Storage.set_state_value(guild, self.STATE_KEY_ENABLED, True)
        # TODO: Startup tracking if necessary
        await context.message.reply(
            '✅ ' + self.__describe_raid_settings(guild, force_enabled_status=True),
            mention_author=False)

    @joinraid.command(
        name='disable',
        brief='Disables join raid detection',
        description='Join raid detection is off by default.',
    )
    async def joinraid_disable(self, context: commands.Context):
        'Command handler'
        guild = context.guild
        Storage.set_state_value(guild, self.STATE_KEY_ENABLED, False)
        # TODO: Tear down tracking if necessary
        await context.message.reply(
            '✅ ' + self.__describe_raid_settings(guild, force_enabled_status=True),
            mention_author=False)

    @joinraid.command(
        name='setrate',
        brief='Sets the rate of joins which triggers a warning to mods',
        description='Each time a member joins, the join records from the ' +
            'previous _x_ seconds are counted up, where _x_ is the number of ' +
            'seconds configured by this command. If that count meets or ' +
            'exceeds the maximum join count configured by this command then ' +
            'a raid is detected and a warning is issued to the mods.',
        usage='<join_count> <seconds>',
    )
    async def joinraid_setrate(self, context: commands.Context,
            join_count: int,
            seconds: float):
        'Command handler'
        guild = context.guild
        if join_count < self.MIN_JOIN_COUNT:
            await context.message.reply(
                f'⚠️ `join_count` must be >= {self.MIN_JOIN_COUNT}',
                mention_author=False)
            return
        if seconds <= 0:
            await context.message.reply(
                '⚠️ `seconds` must be > 0',
                mention_author=False)
            return
        Storage.set_state_values(guild, {
            self.STATE_KEY_RAID_COUNT: join_count,
            self.STATE_KEY_RAID_SECONDS: seconds,
        })

        await context.message.reply(
            '✅ ' + self.__describe_raid_settings(guild, force_rate_status=True),
            mention_author=False)

    @joinraid.command(
        name='getrate',
        brief='Shows the rate of joins which triggers a warning to mods',
    )
    async def joinraid_getrate(self, context: commands.Context):
        'Command handler'
        await context.message.reply(
            'ℹ️ ' + self.__describe_raid_settings(context.guild, force_rate_status=True),
            mention_author=False)

    # -- Listeners ----------------------------------------------------------

    @commands.Cog.listener()
    async def on_raw_reaction_add(self, payload: RawReactionActionEvent):
        'Event handler'
        if payload.user_id == self.bot.user.id:
            # Ignore bot's own reactions
            return
        member: Member = payload.member
        if member is None:
            return
        guild: Guild = self.bot.get_guild(payload.guild_id)
        if guild is None:
            # Possibly a DM
            return
        channel = guild.get_channel(payload.channel_id)
        if channel is None:
            # Possibly a DM
            return
        try:
            message: Message = await channel.fetch_message(payload.message_id)
        except NotFound:
            # Message deleted before it could be fetched
            return
        if message.author.id != self.bot.user.id:
            # Bot didn't author this
            return
        if not member.permissions_in(channel).ban_members:
            # Not a mod
            # TODO: Remove reaction?
            return
        gc: GuildContext = self.__get_guild_context(guild)
        raid: JoinRaidRecord = gc.find_raid_for_message_id(payload.message_id)
        if raid is None:
            # Either not a warning message or one we stopped tracking
            return
        emoji: PartialEmoji = payload.emoji
        if emoji.name == CONFIG['kickEmoji']:
            await raid.kick_all()
        elif emoji.name == CONFIG['banEmoji']:
            await raid.ban_all()
        else:
            return
        # Only retire the current raid if it is the one acted on; acting on an
        # older raid must not discard tracking of joins in progress.
        if raid is gc.current_raid:
            gc.reset_raid(message.created_at)
        await self.__update_raid_warning(guild, raid)

    @commands.Cog.listener()
    async def on_member_join(self, member: Member) -> None:
        'Event handler'
        guild: Guild = member.guild
        if not self.__is_enabled(guild):
            return
        (count, seconds) = self.__get_raid_rate(guild)
        now = member.joined_at
        gc: GuildContext = self.__get_guild_context(guild)
        raid: JoinRaidRecord = gc.current_raid
        raid.handle_join(member, now, count, seconds)
        if raid.phase == RaidPhase.JUST_STARTED:
            await self.__post_raid_warning(guild, raid)
        elif raid.phase == RaidPhase.CONTINUING:
            await self.__update_raid_warning(guild, raid)
        elif raid.phase == RaidPhase.ENDED:
            # First join that occurred too late to be part of last raid. Join
            # not added. Start a new raid record and add it there.
            gc.reset_raid(now)
            new_raid: JoinRaidRecord = gc.current_raid
            new_raid.handle_join(member, now, count, seconds)
            # With a low enough threshold this single join can start a new
            # raid by itself; make sure the mods are warned about it.
            if new_raid.phase == RaidPhase.JUST_STARTED:
                await self.__post_raid_warning(guild, new_raid)

    # -- Misc ---------------------------------------------------------------

    def __describe_raid_settings(self, guild: Guild, force_enabled_status=False, force_rate_status=False) -> str:
        """
        Creates a Discord message describing the current join raid settings.
        """
        enabled = self.__is_enabled(guild)
        (count, seconds) = self.__get_raid_rate(guild)

        sentences = []
        if enabled or force_rate_status:
            sentences.append(f'Join raids will be detected at {count} or more joins per {seconds} seconds.')
        if enabled and force_enabled_status:
            sentences.append('Raid detection enabled.')
        elif not enabled:
            sentences.append('Raid detection disabled.')

        tips = []
        if enabled or force_rate_status:
            tips.append('• Use `setrate` subcommand to change detection threshold')
        if enabled:
            tips.append('• Use `disable` subcommand to disable detection.')
        else:
            tips.append('• Use `enable` subcommand to enable detection.')

        message = ''
        message += ' '.join(sentences)
        if len(tips) > 0:
            message += '\n\n' + ('\n'.join(tips))
        return message

    def __get_guild_context(self, guild: Guild) -> GuildContext:
        """
        Looks up the GuildContext for the given Guild or creates a new one if
        one does not yet exist.
        """
        gc: GuildContext = self.guild_id_to_context.get(guild.id)
        if gc is not None:
            return gc
        gc = GuildContext(guild.id)
        self.guild_id_to_context[guild.id] = gc
        return gc

    async def __post_raid_warning(self, guild: Guild, raid: JoinRaidRecord) -> None:
        """
        Posts a warning message about the given raid and adds the kick/ban
        reaction buttons that are still applicable.
        """
        (message, can_kick, can_ban) = self.__describe_raid(raid)
        raid.warning_message = await self.warn(guild, message)
        if can_kick:
            await raid.warning_message.add_reaction(CONFIG['kickEmoji'])
        if can_ban:
            await raid.warning_message.add_reaction(CONFIG['banEmoji'])

    async def __update_raid_warning(self, guild: Guild, raid: JoinRaidRecord) -> None:
        """
        Updates the existing warning message for a raid, removing reaction
        buttons that no longer apply. Does nothing if no warning was posted.
        """
        if raid.warning_message is None:
            return
        (message, can_kick, can_ban) = self.__describe_raid(raid)
        await self.update_warn(raid.warning_message, message)
        if not can_kick:
            await raid.warning_message.clear_reaction(CONFIG['kickEmoji'])
        if not can_ban:
            await raid.warning_message.clear_reaction(CONFIG['banEmoji'])

    def __describe_raid(self, raid: JoinRaidRecord) -> tuple:
        """
        Creates a Discord warning message with details about the given raid.
        Returns a tuple containing the message text, a flag if any users can
        still be kicked, and a flag if anyone can still be banned.
        """
        message = '🚨 **JOIN RAID DETECTED** 🚨'
        message += '\nThe following members joined in close succession:\n'
        any_kickable = False
        any_bannable = False
        for join in raid.joins:
            message += '\n• '
            if join.is_banned:
                message += '~~' + join.member.mention + '~~ - banned'
            elif join.is_kicked:
                message += '~~' + join.member.mention + '~~ - kicked'
                any_bannable = True
            else:
                message += join.member.mention
                any_bannable = True
                any_kickable = True
        message += '\n_(list updates automatically)_'

        message += '\n'
        if any_kickable:
            message += f'\nReact to this message with {CONFIG["kickEmoji"]} to kick all these users.'
        else:
            message += '\nNo users left to kick.'
        if any_bannable:
            message += f'\nReact to this message with {CONFIG["banEmoji"]} to ban all these users.'
        else:
            message += '\nNo users left to ban.'

        return (message, any_kickable, any_bannable)