"""
Rocketbot Discord bot. Relies on a configured config.py (copy config.py.sample for a template)
and the sqlite database rocketbot.db (copy rocketbot.db.sample for a blank database).

Author: Ian Albert (@rocketsoup)
Date: 2021-11-11
"""
from datetime import datetime
import sqlite3
import sys

from discord import (
    Guild,
    HTTPException,
    Intents,
    Member,
    Message,
    PartialEmoji,
    RawReactionActionEvent,
)
from discord.abc import GuildChannel
from discord.ext import commands
from discord.ext.commands.context import Context

from config import CONFIG

if sys.version_info.major < 3:
    raise Exception('Requires Python 3+')


def _utcnow() -> datetime:
    """
    Returns the current time as a naive UTC datetime.

    Assumes discord.py 1.x (this file uses Member.permissions_in, which 2.x removed), where
    Member.joined_at is a naive UTC datetime. Fallback "now" values must use the same
    convention so age arithmetic does not mix local and UTC times.
    """
    return datetime.utcnow()


# -- Classes ----------------------------------------------------------------

class RaidPhase:
    """
    Enum of phases in a JoinRaid. Phases progress monotonically.
    """
    NONE = 0
    JUST_STARTED = 1
    CONTINUING = 2
    ENDED = 3


class JoinRaid:
    """
    Tracks recent joins to a guild to detect join raids, where a large number
    of automated users all join at the same time.
    """
    def __init__(self):
        self.joins = []
        self.phase = RaidPhase.NONE
        # datetime when the raid started, or None.
        self.raid_start_time = None
        # Message posted to Discord to warn of the raid. Convenience property managed
        # by caller. Ignored by this class.
        self.warning_message = None

    def handle_join(self,
            member: Member,
            now: datetime,
            max_age_seconds: float,
            max_join_count: int) -> None:
        """
        Processes a new member join to a guild and detects join raids.
        Updates self.phase and self.raid_start_time properties.

        member: the member who just joined.
        now: the time of the join, used to compute the age of every tracked join.
        max_age_seconds: joins older than this are not counted toward a raid.
        max_join_count: a raid is detected when more than this many recent joins exist.
        """
        print(f'handle_join({member.name}) start')
        # Check for an existing record for this user (e.g. they left and rejoined).
        join = None
        for i, elem in enumerate(self.joins):
            if elem.member.id == member.id:
                print(f'Member {member.name} already in join list at index {i}. Removing.')
                # Safe to mutate: we stop iterating immediately afterwards.
                join = self.joins.pop(i)
                join.join_time = now
                break
        # Add record (reused or new) to the end so joins stay in chronological order.
        self.joins.append(join if join is not None else JoinRecord(member))
        # Check raid status and do upkeep
        self.__process_joins(now, max_age_seconds, max_join_count)
        print(f'handle_join({member.name}) end')

    def __process_joins(self,
            now: datetime,
            max_age_seconds: float,
            max_join_count: int) -> None:
        """
        Processes self.joins after each addition, detects raids, updates
        self.phase, and throws out unneeded records.
        """
        print('__process_joins {')
        i: int = 0
        recent_count: int = 0
        # Old joins are discarded only while no raid is in progress. Once a raid is
        # detected, every participant is kept so they can all be kicked/banned together.
        should_cull: bool = self.phase == RaidPhase.NONE
        while i < len(self.joins):
            join: JoinRecord = self.joins[i]
            age: float = join.age_seconds(now)
            is_old: bool = age > max_age_seconds
            if not is_old:
                recent_count += 1
                print(f'- {i}. {join.member.name} is {age}s old - recent_count={recent_count}')
            if is_old and should_cull:
                # Popping shifts the next element into slot i, so do not advance i.
                self.joins.pop(i)
                print(f'- {i}. {join.member.name} is {age}s old - too old, removing')
            else:
                print(f'- {i}. {join.member.name} is {age}s old - moving on to next')
                i += 1
        is_raid = recent_count > max_join_count
        print(f'- is_raid {is_raid}')
        if is_raid:
            if self.phase == RaidPhase.NONE:
                self.phase = RaidPhase.JUST_STARTED
                self.raid_start_time = now
                print('- Phase moved to JUST_STARTED. Recording raid start time.')
            elif self.phase == RaidPhase.JUST_STARTED:
                self.phase = RaidPhase.CONTINUING
                print('- Phase moved to CONTINUING.')
        elif self.phase in (RaidPhase.JUST_STARTED, RaidPhase.CONTINUING):
            self.phase = RaidPhase.ENDED
            print('- Phase moved to ENDED.')
        # Undo join add if the raid is over: the newest joiner is not part of the
        # ended raid (the caller moves them into a fresh JoinRaid).
        if self.phase == RaidPhase.ENDED and len(self.joins) > 0:
            last = self.joins.pop(-1)
            print(f'- Popping last join for {last.member.name}')
        print('} __process_joins')

    async def kick_all(self, reason: str = "Part of join raid") -> list[Member]:
        """
        Kicks all users in this join raid. Skips users who have already been
        flagged as having been kicked or banned. Members whose kick fails
        (e.g. they already left, or the bot lacks permission) are logged and
        skipped so the rest of the raid is still processed.

        Returns a List of Members who were newly kicked.
        """
        kicks = []
        for join in self.joins:
            if join.is_kicked or join.is_banned:
                continue
            try:
                await join.member.kick(reason=reason)
            except HTTPException as e:
                print(f'Failed to kick {join.member.name}: {e}')
                continue
            join.is_kicked = True
            kicks.append(join.member)
        self.phase = RaidPhase.ENDED
        return kicks

    async def ban_all(self,
            reason: str = "Part of join raid",
            delete_message_days: int = 0) -> list[Member]:
        """
        Bans all users in this join raid. Skips users who have already been
        flagged as having been banned. Users who were previously kicked can
        still be banned. Members whose ban fails are logged and skipped so the
        rest of the raid is still processed.

        Returns a List of Members who were newly banned.
        """
        bans = []
        for join in self.joins:
            if join.is_banned:
                continue
            try:
                await join.member.ban(reason=reason, delete_message_days=delete_message_days)
            except HTTPException as e:
                print(f'Failed to ban {join.member.name}: {e}')
                continue
            join.is_banned = True
            bans.append(join.member)
        self.phase = RaidPhase.ENDED
        return bans


class JoinRecord:
    """
    Data object containing details about a guild join event.
    """
    def __init__(self, member: Member):
        self.member = member
        # joined_at can be None; fall back to naive UTC to match its convention.
        self.join_time = member.joined_at or _utcnow()
        self.is_kicked = False
        self.is_banned = False

    def age_seconds(self, now: datetime) -> float:
        """
        Returns the age of this join in seconds from the given "now" time.
        """
        return float((now - self.join_time).total_seconds())


class GuildContext:
    """
    Logic and state for a single guild serviced by the bot.
    """
    def __init__(self, guild_id: int):
        self.guild_id = guild_id
        self.guild = None  # Resolved later
        # Config populated during load
        self.warning_channel_id = None
        self.warning_channel = None
        self.warning_mention = None
        self.join_warning_count = CONFIG['joinWarningCount']
        self.join_warning_seconds = CONFIG['joinWarningSeconds']
        # Non-persisted runtime state
        self.current_raid = JoinRaid()
        self.all_raids = [self.current_raid]  # periodically culled of old ones

    # Commands

    async def command_hello(self, message: Message) -> None:
        """
        Command handler
        """
        await message.channel.send(f'Hey there, {message.author.mention}!')

    async def command_testwarn(self, context: Context) -> None:
        """
        Command handler
        """
        if self.warning_channel is None:
            self.__trace('No warning channel set!')
            # `setwarningchannel` is a commands.Command; `.name` is its documented name attribute.
            await context.message.channel.send('No warning channel set on this guild! Type ' +
                f'`{bot.command_prefix}{setwarningchannel.name}` in the channel you ' +
                'want warnings to be posted.')
            return
        await self.__warn('Test warning. This is only a test.')

    async def command_setwarningchannel(self, context: Context):
        """
        Command handler
        """
        self.__trace(f'Warning channel set to {context.channel.name}')
        self.warning_channel = context.channel
        self.warning_channel_id = context.channel.id
        save_guild_context(self)
        await self.__warn('Warning messages will now be sent to ' + self.warning_channel.mention)

    async def command_setwarningmention(self, _context: Context, mention: str):
        """
        Command handler
        """
        self.__trace('set warning mention')
        m = mention if mention is not None and len(mention) > 0 else None
        self.warning_mention = m
        save_guild_context(self)
        if m is None:
            await self.__warn('Warning messages will not mention anyone')
        else:
            await self.__warn('Warning messages will now mention ' + m)

    async def command_setraidwarningrate(self, _context: Context, count: int, seconds: int):
        """
        Command handler
        """
        self.join_warning_count = count
        self.join_warning_seconds = seconds
        save_guild_context(self)
        await self.__warn(f'Maximum join rate set to {count} joins per {seconds} seconds')

    # Events

    async def handle_join(self, member: Member) -> None:
        """
        Event handler for all joins to this guild.
        """
        print(f'{member.guild.name}: {member.name} joined')
        # joined_at may be None; age arithmetic downstream requires a datetime.
        now = member.joined_at or _utcnow()
        raid = self.current_raid
        raid.handle_join(
            member,
            now=now,
            max_age_seconds=self.join_warning_seconds,
            max_join_count=self.join_warning_count)
        self.__trace(f'raid phase: {raid.phase}')
        if raid.phase == RaidPhase.JUST_STARTED:
            await self.__on_join_raid_begin(raid)
        elif raid.phase == RaidPhase.CONTINUING:
            await self.__on_join_raid_updated(raid)
        elif raid.phase == RaidPhase.ENDED:
            # The newest joiner was popped from the ended raid; seed the next one with them.
            self.__start_new_raid(member)
            await self.__on_join_raid_end(raid)
        self.__cull_old_raids(now)

    def __start_new_raid(self, member: Member = None):
        """
        Retires self.current_raid and creates a new empty one. If `member` is
        passed, it will be added to the new self.current_raid after it is created.
        """
        self.current_raid = JoinRaid()
        self.all_raids.append(self.current_raid)
        if member is not None:
            self.current_raid.handle_join(
                member,
                member.joined_at or _utcnow(),
                max_age_seconds=self.join_warning_seconds,
                max_join_count=self.join_warning_count)

    async def handle_reaction_add(self, message, member, emoji):
        """
        Handles all message reaction events to see if they need to be acted on.
        Only reactions by members with ban permission on the bot's own messages
        are considered.
        """
        if member.id == bot.user.id:
            # It's-a me, Rocketbot!
            return
        if message.author.id != bot.user.id:
            # The message the user is reacting to wasn't authored by me. Ignore.
            return
        self.__trace(f'User {member} added emoji {emoji}')
        if not member.permissions_in(message.channel).ban_members:
            self.__trace('Reactor does not have ban permissions. Ignoring.')
            return
        if emoji.name == CONFIG['kickEmoji']:
            await self.__kick_all_in_raid_message(message)
        elif emoji.name == CONFIG['banEmoji']:
            await self.__ban_all_in_raid_message(message)
        else:
            print('Unhandled emoji. Ignoring.')
            return

    async def __kick_all_in_raid_message(self, message: Message):
        """
        Kicks all the users mentioned in the given raid warning message. Users
        who were already kicked or banned will be skipped.
        """
        raid = self.__find_raid_for_message(message)
        if raid is None:
            await message.reply("This is either not a raid warning or it's too old and I don't " +
                "have a record for it anymore. Sorry!")
            return
        self.__trace('Kicking...')
        members = await raid.kick_all()
        names = ''.join(f'\n\t{member.name}' for member in members) or '\n\t-none-'
        self.__trace('Kicked these members:' + names)
        # Only retire the current raid if it is the one acted on; acting on an older
        # raid must not discard joins currently being tracked.
        if raid is self.current_raid:
            self.__start_new_raid()
        await self.__update_join_raid_message(raid)

    async def __ban_all_in_raid_message(self, message: Message):
        """
        Bans all the users mentioned in the given raid warning message. Users
        who were already banned will be skipped.
        """
        raid = self.__find_raid_for_message(message)
        if raid is None:
            await message.reply("This is either not a raid warning or it's too old and I don't " +
                "have a record for it anymore. Sorry!")
            return
        self.__trace('Banning...')
        members = await raid.ban_all()
        names = ''.join(f'\n\t{member.name}' for member in members) or '\n\t-none-'
        self.__trace('Banned these members:' + names)
        # Only retire the current raid if it is the one acted on; acting on an older
        # raid must not discard joins currently being tracked.
        if raid is self.current_raid:
            self.__start_new_raid()
        await self.__update_join_raid_message(raid)

    def __find_raid_for_message(self, message: Message) -> JoinRaid:
        """
        Retrieves a JoinRaid instance for the given raid warning message.
        Returns None if not found.
        """
        for raid in self.all_raids:
            # Raids that never posted a warning (including, usually, the current
            # one) have no message to match against.
            if raid.warning_message is not None and raid.warning_message.id == message.id:
                return raid
        return None

    def __cull_old_raids(self, now: datetime):
        """
        Gets rid of JoinRaid records from self.all_raids that are too old to
        still be useful (started more than a day before `now`), as well as
        retired records that never became a raid. The current raid is always kept.
        """
        kept = []
        for raid in self.all_raids:
            if raid is self.current_raid:
                kept.append(raid)
                continue
            if raid.raid_start_time is None:
                # Never reached JUST_STARTED, so no warning message was ever posted
                # for it and nothing can refer back to it.
                self.__trace('Culling raid that never started')
                continue
            age_seconds = float((now - raid.raid_start_time).total_seconds())
            if age_seconds > 86400.0:
                self.__trace('Culling old raid')
                continue
            kept.append(raid)
        self.all_raids = kept

    def __join_raid_message(self, raid: JoinRaid):
        """
        Returns a 3-element tuple containing a text message appropriate for
        posting in Discord, a flag of whether any of the mentioned users can be
        kicked, and a flag of whether any of the mentioned users can be banned.
        """
        # NOTE(review): very large raids could exceed Discord's message length limit.
        message = ''
        if self.warning_mention is not None:
            message = self.warning_mention + ' '
        message += '**RAID JOIN DETECTED!** It includes these users:\n'
        can_kick = False
        can_ban = False
        for join in raid.joins:
            message += '\n• '
            if join.is_banned:
                message += '~~' + join.member.mention + '~~ - banned'
            elif join.is_kicked:
                message += '~~' + join.member.mention + '~~ - kicked'
                can_ban = True
            else:
                message += join.member.mention
                can_kick = True
                can_ban = True
        message += '\n'
        if can_kick:
            message += '\nTo kick all these users, react with :' + CONFIG['kickEmojiName'] + ':'
        else:
            message += '\nNo kickable users remain'
        if can_ban:
            message += '\nTo ban all these users, react with :' + CONFIG['banEmojiName'] + ':'
        else:
            message += '\nNo bannable users remain'
        return (message, can_kick, can_ban)

    async def __update_join_raid_message(self, raid: JoinRaid):
        """
        Updates an existing join raid warning message with updated data, and
        removes the kick/ban reactions once no user remains for that action.
        """
        if raid.warning_message is None:
            self.__trace('No raid warning message to update')
            return
        (message, can_kick, can_ban) = self.__join_raid_message(raid)
        await raid.warning_message.edit(content=message)
        if not can_kick:
            await raid.warning_message.clear_reaction(CONFIG['kickEmoji'])
        if not can_ban:
            await raid.warning_message.clear_reaction(CONFIG['banEmoji'])

    async def __on_join_raid_begin(self, raid):
        """
        Event triggered when the first member joins that triggers the raid detection.
        """
        self.__trace('A join raid has begun!')
        if self.warning_channel is None:
            self.__trace('NO WARNING CHANNEL SET')
            return
        (message, can_kick, can_ban) = self.__join_raid_message(raid)
        raid.warning_message = await self.warning_channel.send(message)
        if can_kick:
            await raid.warning_message.add_reaction(CONFIG['kickEmoji'])
        if can_ban:
            await raid.warning_message.add_reaction(CONFIG['banEmoji'])

    async def __on_join_raid_updated(self, raid):
        """
        Event triggered for each subsequent member join after the first one that
        triggered the raid detection.
        """
        self.__trace('Join raid still occurring')
        await self.__update_join_raid_message(raid)

    async def __on_join_raid_end(self, _raid):
        """
        Event triggered when the first member joins who is not part of the most
        recent raid.
        """
        self.__trace('Join raid has ended')

    async def __warn(self, message):
        """
        Posts a warning message in the configured warning channel, prefixed with
        the configured warning mention if any. Returns the sent Message, or None
        if no warning channel is configured.
        """
        if self.warning_channel is None:
            self.__trace('NO WARNING CHANNEL SET. Warning message not posted.\n' + message)
            return None
        m = message
        if self.warning_mention is not None:
            m = self.warning_mention + ' ' + m
        return await self.warning_channel.send(m)

    def __trace(self, message):
        """
        Debugging trace. Prefixes with the guild name, or the guild id if the
        Guild object has not been resolved yet.
        """
        label = self.guild.name if self.guild is not None else self.guild_id
        print(f'{label}: {message}')


# lookup for int(Guild.guild_id) --> GuildContext
guild_id_to_guild_context = {}


def get_or_create_guild_context(val, save=True):
    """
    Retrieves a cached GuildContext instance by its Guild id or Guild object
    itself. If no GuildContext record exists for the Guild, one is created and
    cached (and saved to the database unless `save=False`). Returns None for
    None or unsupported argument types.
    """
    gid = None
    guild = None
    if val is None:
        return None
    if isinstance(val, int):
        gid = val
    elif isinstance(val, Guild):
        gid = val.id
        guild = val
    if gid is None:
        print('Unhandled datatype', type(val))
        return None
    looked_up = guild_id_to_guild_context.get(gid)
    if looked_up is not None:
        # Contexts loaded from the database start without a Guild object; fill it
        # in the first time one becomes available.
        if looked_up.guild is None and guild is not None:
            looked_up.guild = guild
        return looked_up
    gc = GuildContext(gid)
    gc.guild = guild or gc.guild
    guild_id_to_guild_context[gid] = gc
    if save:
        save_guild_context(gc)
    return gc


# -- Database ---------------------------------------------------------------

def run_sql_batch(batch_function):
    """
    Performs an SQL transaction. After a connection is opened, the passed
    function is invoked with the sqlite3.Connection and sqlite3.Cursor passed
    as arguments. If the function returns normally the transaction is
    committed. The connection is always closed, even if the function raises
    (in which case uncommitted changes are discarded).
    """
    db_connection: sqlite3.Connection = sqlite3.connect('rocketbot.db')
    try:
        db_cursor: sqlite3.Cursor = db_connection.cursor()
        batch_function(db_connection, db_cursor)
        db_connection.commit()
    finally:
        db_connection.close()


def load_guild_settings():
    """
    Populates the GuildContext cache with records from the database.
    """
    def load(_con, cur):
        """
        SQL
        """
        for row in cur.execute("""SELECT * FROM guilds"""):
            guild_id = row[0]
            gc = get_or_create_guild_context(guild_id, save=False)
            gc.warning_channel_id = row[1]
            gc.warning_mention = row[2]
            gc.join_warning_count = row[3] or CONFIG['joinWarningCount']
            gc.join_warning_seconds = row[4] or CONFIG['joinWarningSeconds']
            print(f'Guild {guild_id} channel id is {gc.warning_channel_id}')
    run_sql_batch(load)


def create_tables():
    """
    Creates all database tables.
    """
    def make_tables(_con, cur):
        """
        SQL
        """
        cur.execute("""CREATE TABLE guilds (
            guildId INTEGER,
            warningChannelId INTEGER,
            warningMention TEXT,
            joinWarningCount INTEGER,
            joinWarningSeconds INTEGER,
            PRIMARY KEY(guildId ASC))""")
    run_sql_batch(make_tables)


def save_guild_context(gc: GuildContext):
    """
    Saves the state of a GuildContext record to the database, updating the
    existing row for its guild id or inserting a new one.
    """
    def save(_con, cur):
        """
        SQL
        """
        print(f'Saving guild context with id {gc.guild_id}')
        cur.execute("""
            SELECT guildId
            FROM guilds
            WHERE guildId=?
            """, (
                gc.guild_id,
            ))
        channel_id = gc.warning_channel.id if gc.warning_channel is not None \
            else gc.warning_channel_id
        exists = cur.fetchone() is not None
        if exists:
            print('Updating existing guild record in db')
            cur.execute("""
                UPDATE guilds
                SET warningChannelId=?,
                    warningMention=?,
                    joinWarningCount=?,
                    joinWarningSeconds=?
                WHERE guildId=?
                """, (
                    channel_id,
                    gc.warning_mention,
                    gc.join_warning_count,
                    gc.join_warning_seconds,
                    gc.guild_id,
                ))
        else:
            print('Creating new guild record in db')
            cur.execute("""
                INSERT INTO guilds (
                    guildId,
                    warningChannelId,
                    warningMention,
                    joinWarningCount,
                    joinWarningSeconds)
                VALUES (?, ?, ?, ?, ?)
                """, (
                    gc.guild_id,
                    channel_id,
                    gc.warning_mention,
                    gc.join_warning_count,
                    gc.join_warning_seconds,
                ))
    run_sql_batch(save)


# -- Main (1) ---------------------------------------------------------------

load_guild_settings()
intents = Intents.default()
intents.members = True  # To get join/leave events
bot = commands.Bot(command_prefix=CONFIG['commandPrefix'], intents=intents)


# -- Bot commands -----------------------------------------------------------

@bot.command(
    brief='Simply replies to the invoker with a hello message in the same channel.'
)
async def hello(ctx: Context):
    """
    Command handler
    """
    gc: GuildContext = get_or_create_guild_context(ctx.guild)
    if gc is None:
        return
    message = ctx.message
    await gc.command_hello(message)


@bot.command(
    brief='Posts a test warning message in the configured warning channel.',
    help=("If no warning channel is configured, the bot will reply in the channel the command "
          "was issued to notify no warning channel is set. If a warning mention is configured, "
          "the test warning will tag the configured person/role.")
)
@commands.has_permissions(manage_messages=True)
async def testwarn(ctx: Context):
    """
    Command handler
    """
    gc: GuildContext = get_or_create_guild_context(ctx.guild)
    if gc is None:
        return
    await gc.command_testwarn(ctx)


@bot.command(
    brief='Sets the threshold for detecting a join raid.',
    # NOTE(review): usage string looks truncated (perhaps '<count> <seconds>') — confirm.
    usage=' ',
    # The detector triggers when the recent join count is strictly greater than `count`
    # (JoinRaid: recent_count > max_join_count), so the help text says "exceeds".
    help=("The raid threshold is expressed as number of joins within a given number of "
          "seconds. Each time a member joins, the number of joins in the previous _x_ seconds "
          "is counted, and if that count, _y_, exceeds the count configured by this command, "
          "a raid is detected.")
)
@commands.has_permissions(manage_messages=True)
async def setraidwarningrate(ctx: Context, count: int, seconds: int):
    """
    Command handler
    """
    gc: GuildContext = get_or_create_guild_context(ctx.guild)
    if gc is None:
        return
    await gc.command_setraidwarningrate(ctx, count, seconds)


@bot.command(
    brief='Sets the current channel as the destination for bot warning messages.'
)
@commands.has_permissions(manage_messages=True)
async def setwarningchannel(ctx: Context):
    """
    Command handler
    """
    gc: GuildContext = get_or_create_guild_context(ctx.guild)
    if gc is None:
        return
    await gc.command_setwarningchannel(ctx)


@bot.command(
    brief='Sets an optional mention to include in every warning message.',
    # NOTE(review): usage string looks truncated (perhaps '<mention>') — confirm.
    usage='',
    help=("The argument provided to this command will be included verbatim, so if the intent "
          "is to tag a user or role, the argument must be a tag, not merely the name of the "
          "user/role.")
)
@commands.has_permissions(manage_messages=True)
async def setwarningmention(ctx: Context, mention: str):
    """
    Command handler
    """
    gc: GuildContext = get_or_create_guild_context(ctx.guild)
    if gc is None:
        return
    await gc.command_setwarningmention(ctx, mention)


# -- Bot events -------------------------------------------------------------

# populate_guilds() needs both the connect and ready events to have fired.
is_connected = False
is_ready = False


@bot.listen()
async def on_connect():
    """
    Discord event handler
    """
    global is_connected
    print('Connected')
    is_connected = True
    if is_connected and is_ready:
        await populate_guilds()


@bot.listen()
async def on_ready():
    """
    Discord event handler
    """
    global is_ready
    print('Ready')
    is_ready = True
    if is_connected and is_ready:
        await populate_guilds()


async def populate_guilds():
    """
    Called after both on_ready and on_connect are done. May be called more than once!
    Attaches Guild objects and warning channels to GuildContexts loaded from the database.
    """
    for guild in bot.guilds:
        gc = guild_id_to_guild_context.get(guild.id)
        if gc is None:
            print(f'No GuildContext for {guild.id}')
            continue
        gc.guild = guild
        if gc.warning_channel_id is not None:
            gc.warning_channel = guild.get_channel(gc.warning_channel_id)
            if gc.warning_channel is not None:
                print(f'Recovered warning channel {gc.warning_channel}')
            else:
                print(f'Could not find channel with id {gc.warning_channel_id} in ' +
                    f'guild {guild.name}')
                for channel in await guild.fetch_channels():
                    print(f'\t{channel.name} ({channel.id})')


@bot.listen()
async def on_member_join(member: Member) -> None:
    """
    Discord event handler
    """
    print(f'User {member.name} joined {member.guild.name}')
    gc: GuildContext = get_or_create_guild_context(member.guild)
    if gc is None:
        print(f'No GuildContext for guild {member.guild.name}')
        return
    await gc.handle_join(member)


@bot.listen()
async def on_member_remove(member: Member) -> None:
    """
    Discord event handler
    """
    print(f'User {member.name} left {member.guild.name}')


@bot.listen()
async def on_raw_reaction_add(payload: RawReactionActionEvent) -> None:
    """
    Discord event handler
    """
    if payload.guild_id is None or payload.member is None:
        # Reactions outside a guild (DMs) carry no guild or member; nothing to moderate.
        return
    guild: Guild = bot.get_guild(payload.guild_id)
    if guild is None:
        return
    channel: GuildChannel = guild.get_channel(payload.channel_id)
    if channel is None:
        return
    message: Message = await channel.fetch_message(payload.message_id)
    member: Member = payload.member
    emoji: PartialEmoji = payload.emoji
    gc: GuildContext = get_or_create_guild_context(guild)
    await gc.handle_reaction_add(message, member, emoji)


# -- Main -------------------------------------------------------------------

print('Starting bot')
bot.run(CONFIG['clientToken'])
print('Bot done')