"""
Rocketbot Discord bot. Relies on a configured config.py (copy config.py.sample for a template)
and the sqlite database rocketbot.db (copy rocketbot.db.sample for a blank database).

Author: Ian Albert (@rocketsoup)
Date: 2021-11-11
"""
from datetime import datetime
import sqlite3
import sys

from discord import Guild, Intents, Member, Message, PartialEmoji, RawReactionActionEvent
from discord.abc import GuildChannel
from discord.ext import commands
from discord.ext.commands.context import Context

from config import CONFIG
from cogs.config import ConfigCog
from cogs.general import GeneralCog


class Rocketbot(commands.Bot):
    """
    The Rocketbot command bot. Currently adds no behavior of its own beyond
    commands.Bot; every constructor argument is forwarded unchanged.
    """

    def __init__(self, command_prefix, **kwargs):
        """
        command_prefix: passed straight through to commands.Bot.
        kwargs: any other commands.Bot options (e.g. `intents`), forwarded as-is.
        """
        super().__init__(command_prefix, **kwargs)


# Intents.default() does not include the privileged members intent, which is
# needed to receive member join/leave events (the legacy implementation in the
# commented-out code enabled it "To get join/leave events"). Being privileged,
# it must also be enabled for the bot in the Discord developer portal.
intents = Intents.default()
intents.members = True

bot = Rocketbot(CONFIG['commandPrefix'], intents=intents)
bot.add_cog(GeneralCog(bot))
bot.add_cog(ConfigCog(bot))
# Blocks until the bot shuts down. `bot=True` was dropped: it is already the
# default, and the keyword no longer exists in discord.py 2.0.
bot.run(CONFIG['clientToken'], reconnect=True)
print('\nBot aborted')

# -- Classes ---------------------------------------------------------------- # class GuildContext: # """ # Logic and state for a single guild serviced by the bot. # """ # def __init__(self, guild_id: int): # self.guild_id = guild_id # self.guild = None # Resolved later # # Config populated during load # self.warning_channel_id = None # self.warning_channel = None # self.warning_mention = None # self.join_warning_count = CONFIG['joinWarningCount'] # self.join_warning_seconds = CONFIG['joinWarningSeconds'] # # Non-persisted runtime state # self.current_raid = JoinRaid() # self.all_raids = [ self.current_raid ] # periodically culled of old ones # # Commands # async def command_hello(self, message: Message) -> None: # """ # Command handler # """ # await message.channel.send(f'Hey there, {message.author.mention}!') # async def command_testwarn(self, context: Context) -> None: # """ # Command handler # """ # if self.warning_channel is None: # self.__trace('No warning channel set!') # await context.message.channel.send('No warning channel set on this guild!
# Type ' + # f'`{bot.command_prefix}{setwarningchannel.__name__}` in the channel you ' + # 'want warnings to be posted.') # return # await self.__warn('Test warning. This is only a test.') # async def command_setwarningchannel(self, context: Context): # """ # Command handler # """ # self.__trace(f'Warning channel set to {context.channel.name}') # self.warning_channel = context.channel # self.warning_channel_id = context.channel.id # save_guild_context(self) # await self.__warn('Warning messages will now be sent to ' + self.warning_channel.mention) # async def command_setwarningmention(self, _context: Context, mention: str): # """ # Command handler # """ # self.__trace('set warning mention') # m = mention if mention is not None and len(mention) > 0 else None # self.warning_mention = m # save_guild_context(self) # if m is None: # await self.__warn('Warning messages will not mention anyone') # else: # await self.__warn('Warning messages will now mention ' + m) # async def command_setraidwarningrate(self, _context: Context, count: int, seconds: int): # """ # Command handler # """ # self.join_warning_count = count # self.join_warning_seconds = seconds # save_guild_context(self) # await self.__warn(f'Maximum join rate set to {count} joins per {seconds} seconds') # # Events # async def handle_join(self, member: Member) -> None: # """ # Event handler for all joins to this guild.
# """ # print(f'{member.guild.name}: {member.name} joined') # now = member.joined_at # raid = self.current_raid # raid.handle_join( # member, # now=now, # max_age_seconds = self.join_warning_seconds, # max_join_count = self.join_warning_count) # self.__trace(f'raid phase: {raid.phase}') # if raid.phase == RaidPhase.JUST_STARTED: # await self.__on_join_raid_begin(raid) # elif raid.phase == RaidPhase.CONTINUING: # await self.__on_join_raid_updated(raid) # elif raid.phase == RaidPhase.ENDED: # self.__start_new_raid(member) # await self.__on_join_raid_end(raid) # self.__cull_old_raids(now) # def __start_new_raid(self, member: Member = None): # """ # Retires self.current_raid and creates a new empty one. If `member` is passed, it will be # added to the new self.current_raid after it is created. # """ # self.current_raid = JoinRaid() # self.all_raids.append(self.current_raid) # if member is not None: # self.current_raid.handle_join( # member, # member.joined_at, # max_age_seconds = self.join_warning_seconds, # max_join_count = self.join_warning_count) # async def handle_reaction_add(self, message, member, emoji): # """ # Handles all message reaction events to see if they need to be acted on. # """ # if member.id == bot.user.id: # # It's-a me, Rocketbot! # return # if message.author.id != bot.user.id: # # The message the user is reacting to wasn't authored by me. Ignore. # return # self.__trace(f'User {member} added emoji {emoji}') # if not member.permissions_in(message.channel).ban_members: # self.__trace('Reactor does not have ban permissions. Ignoring.') # return # if emoji.name == CONFIG['kickEmoji']: # await self.__kick_all_in_raid_message(message) # elif emoji.name == CONFIG['banEmoji']: # await self.__ban_all_in_raid_message(message) # else: # print('Unhandled emoji. Ignoring.') # return # async def __kick_all_in_raid_message(self, message: Message): # """ # Kicks all the users mentioned in the given raid warning message. 
# Users who were already # kicked or banned will be skipped. # """ # raid = self.__find_raid_for_message(message) # if raid is None: # await message.reply("This is either not a raid warning or it's too old and I don't " + # "have a record for it anymore. Sorry!") # return # self.__trace('Kicking...') # members = await raid.kick_all() # msg = 'Kicked these members:' # for member in members: # msg += f'\n\t{member.name}' # if len(members) == 0: # msg += '\n\t-none-' # self.__trace(msg) # self.__start_new_raid() # await self.__update_join_raid_message(raid) # async def __ban_all_in_raid_message(self, message: Message): # """ # Bans all the users mentioned in the given raid warning message. Users who were already # banned will be skipped. # """ # raid = self.__find_raid_for_message(message) # if raid is None: # await message.reply("This is either not a raid warning or it's too old and I don't " + # "have a record for it anymore. Sorry!") # return # self.__trace('Banning...') # members = await raid.ban_all() # msg = 'Banned these members:' # for member in members: # msg += f'\n\t{member.name}' # if len(members) == 0: # msg += '\n\t-none-' # self.__trace(msg) # self.__start_new_raid() # await self.__update_join_raid_message(raid) # def __find_raid_for_message(self, message: Message) -> JoinRaid: # """ # Retrieves a JoinRaid instance for the given raid warning message. Returns None if not found. # """ # for raid in self.all_raids: # if raid.warning_message.id == message.id: # return raid # return None # def __cull_old_raids(self, now: datetime): # """ # Gets rid of old JoinRaid records from self.all_raids that are too old to still be useful.
# """ # i: int = 0 # while i < len(self.all_raids): # raid = self.all_raids[i] # if raid == self.current_raid: # i += 1 # continue # age_seconds = float((raid.raid_start_time - now).total_seconds()) # if age_seconds > 86400.0: # self.__trace('Culling old raid') # self.all_raids.pop(i) # else: # i += 1 # def __join_raid_message(self, raid: JoinRaid): # """ # Returns a 3-element tuple containing a text message appropriate for posting in # Discord, a flag of whether any of the mentioned users can be kicked, and a flag # of whether any of the mentioned users can be banned. # """ # message = '' # if self.warning_mention is not None: # message = self.warning_mention + ' ' # message += '**RAID JOIN DETECTED!** It includes these users:\n' # can_kick = False # can_ban = False # for join in raid.joins: # message += '\n• ' # if join.is_banned: # message += '~~' + join.member.mention + '~~ - banned' # elif join.is_kicked: # message += '~~' + join.member.mention + '~~ - kicked' # can_ban = True # else: # message += join.member.mention # can_kick = True # can_ban = True # message += '\n' # if can_kick: # message += '\nTo kick all these users, react with :' + CONFIG['kickEmojiName'] + ':' # else: # message += '\nNo kickable users remain' # if can_ban: # message += '\nTo ban all these users, react with :' + CONFIG['banEmojiName'] + ':' # else: # message += '\nNo bannable users remain' # return (message, can_kick, can_ban) # async def __update_join_raid_message(self, raid: JoinRaid): # """ # Updates an existing join raid warning message with updated data. 
# """ # if raid.warning_message is None: # self.__trace('No raid warning message to update') # return # (message, can_kick, can_ban) = self.__join_raid_message(raid) # await raid.warning_message.edit(content=message) # if not can_kick: # await raid.warning_message.clear_reaction(CONFIG['kickEmoji']) # if not can_ban: # await raid.warning_message.clear_reaction(CONFIG['banEmoji']) # async def __on_join_raid_begin(self, raid): # """ # Event triggered when the first member joins that triggers the raid detection. # """ # self.__trace('A join raid has begun!') # if self.warning_channel is None: # self.__trace('NO WARNING CHANNEL SET') # return # (message, can_kick, can_ban) = self.__join_raid_message(raid) # raid.warning_message = await self.warning_channel.send(message) # if can_kick: # await raid.warning_message.add_reaction(CONFIG['kickEmoji']) # if can_ban: # await raid.warning_message.add_reaction(CONFIG['banEmoji']) # async def __on_join_raid_updated(self, raid): # """ # Event triggered for each subsequent member join after the first one that triggered the # raid detection. # """ # self.__trace('Join raid still occurring') # await self.__update_join_raid_message(raid) # async def __on_join_raid_end(self, _raid): # """ # Event triggered when the first member joins who is not part of the most recent raid. # """ # self.__trace('Join raid has ended') # async def __warn(self, message): # """ # Posts a warning message in the configured warning channel. # """ # if self.warning_channel is None: # self.__trace('NO WARNING CHANNEL SET. Warning message not posted.\n' + message) # return None # m = message # if self.warning_mention is not None: # m = self.warning_mention + ' ' + m # return await self.warning_channel.send(m) # def __trace(self, message): # """ # Debugging trace. 
# """ # print(f'{self.guild.name}: {message}') # # lookup for int(Guild.guild_id) --> GuildContext # guild_id_to_guild_context = {} # def get_or_create_guild_context(val, save=True): # """ # Retrieves a cached GuildContext instance by its Guild id or Guild object # itself. If no GuildContext record exists for the Guild, one is created # and cached (and saved to the database unless `save=False`). # """ # gid = None # guild = None # if val is None: # return None # if isinstance(val, int): # gid = val # elif isinstance(val, Guild): # gid = val.id # guild = val # if gid is None: # print('Unhandled datatype', type(val)) # return None # looked_up = guild_id_to_guild_context.get(gid) # if looked_up is not None: # return looked_up # gc = GuildContext(gid) # gc.guild = guild or gc.guild # guild_id_to_guild_context[gid] = gc # if save: # save_guild_context(gc) # return gc # # -- Database --------------------------------------------------------------- # def run_sql_batch(batch_function): # """ # Performs an SQL transaction. After a connection is opened, the passed # function is invoked with the sqlite3.Connection and sqlite3.Cursor # passed as arguments. Once the passed function finishes, the connection # is closed. # """ # db_connection: sqlite3.Connection = sqlite3.connect('rocketbot.db') # db_cursor: sqlite3.Cursor = db_connection.cursor() # batch_function(db_connection, db_cursor) # db_connection.commit() # db_connection.close() # def load_guild_settings(): # """ # Populates the GuildContext cache with records from the database. 
# """ # def load(_con, cur): # """ # SQL # """ # for row in cur.execute("""SELECT * FROM guilds"""): # guild_id = row[0] # gc = get_or_create_guild_context(guild_id, save=False) # gc.warning_channel_id = row[1] # gc.warning_mention = row[2] # gc.join_warning_count = row[3] or CONFIG['joinWarningCount'] # gc.join_warning_seconds = row[4] or CONFIG['joinWarningSeconds'] # print(f'Guild {guild_id} channel id is {gc.warning_channel_id}') # run_sql_batch(load) # def create_tables(): # """ # Creates all database tables. # """ # def make_tables(_con, cur): # """ # SQL # """ # cur.execute("""CREATE TABLE guilds ( # guildId INTEGER, # warningChannelId INTEGER, # warningMention TEXT, # joinWarningCount INTEGER, # joinWarningSeconds INTEGER, # PRIMARY KEY(guildId ASC))""") # run_sql_batch(make_tables) # def save_guild_context(gc: GuildContext): # """ # Saves the state of a GuildContext record to the database. # """ # def save(_con, cur): # """ # SQL # """ # print(f'Saving guild context with id {gc.guild_id}') # cur.execute(""" # SELECT guildId # FROM guilds # WHERE guildId=? # """, ( # gc.guild_id, # )) # channel_id = gc.warning_channel.id if gc.warning_channel is not None \ # else gc.warning_channel_id # exists = cur.fetchone() is not None # if exists: # print('Updating existing guild record in db') # cur.execute(""" # UPDATE guilds # SET warningChannelId=?, # warningMention=?, # joinWarningCount=?, # joinWarningSeconds=? # WHERE guildId=? # """, ( # channel_id, # gc.warning_mention, # gc.join_warning_count, # gc.join_warning_seconds, # gc.guild_id, # )) # else: # print('Creating new guild record in db') # cur.execute(""" # INSERT INTO guilds ( # guildId, # warningChannelId, # warningMention, # joinWarningCount, # joinWarningSeconds) # VALUES (?, ?, ?, ?, ?) 
# """, ( # gc.guild_id, # channel_id, # gc.warning_mention, # gc.join_warning_count, # gc.join_warning_seconds, # )) # run_sql_batch(save) # # -- Main (1) --------------------------------------------------------------- # load_guild_settings() # intents = Intents.default() # intents.members = True # To get join/leave events # bot = commands.Bot(command_prefix=CONFIG['commandPrefix'], intents=intents) # # -- Bot commands ----------------------------------------------------------- # @bot.command( # brief='Simply replies to the invoker with a hello message in the same channel.' # ) # async def hello(ctx: Context): # """ # Command handler # """ # gc: GuildContext = get_or_create_guild_context(ctx.guild) # if gc is None: # return # message = ctx.message # await gc.command_hello(message) # @bot.command( # brief='Posts a test warning message in the configured warning channel.', # help="""If no warning channel is configured, the bot will reply in the channel the command was # issued to notify no warning channel is set. If a warning mention is configured, the test # warning will tag the configured person/role.""" # ) # @commands.has_permissions(manage_messages=True) # async def testwarn(ctx: Context): # """ # Command handler # """ # gc: GuildContext = get_or_create_guild_context(ctx.guild) # if gc is None: # return # await gc.command_testwarn(ctx) # @bot.command( # brief='Sets the threshold for detecting a join raid.', # usage=' ', # help="""The raid threshold is expressed as number of joins within a given number of seconds. 
# Each time a member joins, the number of joins in the previous _x_ seconds is counted, and if # that count, _y_, equals or exceeds the count configured by this command, a raid is detected.""" # ) # @commands.has_permissions(manage_messages=True) # async def setraidwarningrate(ctx: Context, count: int, seconds: int): # """ # Command handler # """ # gc: GuildContext = get_or_create_guild_context(ctx.guild) # if gc is None: # return # await gc.command_setraidwarningrate(ctx, count, seconds) # @bot.command( # brief='Sets the current channel as the destination for bot warning messages.' # ) # @commands.has_permissions(manage_messages=True) # async def setwarningchannel(ctx: Context): # """ # Command handler # """ # gc: GuildContext = get_or_create_guild_context(ctx.guild) # if gc is None: # return # await gc.command_setwarningchannel(ctx) # @bot.command( # brief='Sets an optional mention to include in every warning message.', # usage='', # help="""The argument provided to this command will be included verbatim, so if the intent is # to tag a user or role, the argument must be a tag, not merely the name of the user/role.""" # ) # @commands.has_permissions(manage_messages=True) # async def setwarningmention(ctx: Context, mention: str): # """ # Command handler # """ # gc: GuildContext = get_or_create_guild_context(ctx.guild) # if gc is None: # return # await gc.command_setwarningmention(ctx, mention) # # -- Bot events ------------------------------------------------------------- # is_connected = False # @bot.listen() # async def on_connect(): # """ # Discord event handler # """ # global is_connected # print('Connected') # is_connected = True # if is_connected and is_ready: # await populate_guilds() # is_ready = False # @bot.listen() # async def on_ready(): # """ # Discord event handler # """ # global is_ready # print('Ready') # is_ready = True # if is_connected and is_ready: # await populate_guilds() # async def populate_guilds(): # """ # Called after both on_ready and 
# on_connect are done. May be called more than once! # """ # for guild in bot.guilds: # gc = guild_id_to_guild_context.get(guild.id) # if gc is None: # print(f'No GuildContext for {guild.id}') # continue # gc.guild = guild # if gc.warning_channel_id is not None: # gc.warning_channel = guild.get_channel(gc.warning_channel_id) # if gc.warning_channel is not None: # print(f'Recovered warning channel {gc.warning_channel}') # else: # print(f'Could not find channel with id {gc.warning_channel_id} in ' + # f'guild {guild.name}') # for channel in await guild.fetch_channels(): # print(f'\t{channel.name} ({channel.id})') # @bot.listen() # async def on_member_join(member: Member) -> None: # """ # Discord event handler # """ # print(f'User {member.name} joined {member.guild.name}') # gc: GuildContext = get_or_create_guild_context(member.guild) # if gc is None: # print(f'No GuildContext for guild {member.guild.name}') # return # await gc.handle_join(member) # @bot.listen() # async def on_member_remove(member: Member) -> None: # """ # Discord event handler # """ # print(f'User {member.name} left {member.guild.name}') # @bot.listen() # async def on_raw_reaction_add(payload: RawReactionActionEvent) -> None: # """ # Discord event handler # """ # guild: Guild = bot.get_guild(payload.guild_id) # channel: GuildChannel = guild.get_channel(payload.channel_id) # message: Message = await channel.fetch_message(payload.message_id) # member: Member = payload.member # emoji: PartialEmoji = payload.emoji # gc: GuildContext = get_or_create_guild_context(guild) # await gc.handle_reaction_add(message, member, emoji) # # -- Main ------------------------------------------------------------------- # print('Starting bot') # bot.run(CONFIG['clientToken']) # print('Bot done')