"""Rocketbot: a Discord bot that detects join raids.

When enough members join a guild within a configurable window, the bot posts a
warning listing them in the guild's warning channel. Moderators with ban
permission can then kick or ban every listed user by reacting to that message.
Per-guild settings are persisted in a local SQLite database.
"""
from config import config
from datetime import datetime, timedelta, timezone
import discord
from discord.ext import commands
from discord.ext.commands.context import Context
import sqlite3
import sys

# Path of the SQLite file holding per-guild settings.
DB_PATH = 'rocketbot.db'

# Discord rejects message content longer than this many characters.
DISCORD_MESSAGE_LIMIT = 2000


# -- Time helpers -----------------------------------------------------------

def utcNow():
    """Return the current time as a naive datetime in UTC."""
    return datetime.now(timezone.utc).replace(tzinfo=None)


def toNaiveUtc(dt):
    """Convert an aware datetime to naive UTC; pass naive values and None through.

    discord.py documents Member.joined_at as UTC (naive in 1.x, aware in 2.x),
    so normalizing to naive UTC makes it comparable with utcNow().
    """
    if dt is None or dt.tzinfo is None:
        return dt
    return dt.astimezone(timezone.utc).replace(tzinfo=None)


# -- Database ---------------------------------------------------------------

def runSQLBatch(batchFunction, dbPath=DB_PATH):
    """Open the database, run batchFunction(connection, cursor), commit, close.

    The connection is always closed. If batchFunction raises, nothing is
    committed and the exception propagates to the caller.
    """
    dbConnection = sqlite3.connect(dbPath)
    try:
        batchFunction(dbConnection, dbConnection.cursor())
        dbConnection.commit()
    finally:
        dbConnection.close()


def loadGuildSettings():
    """Populate in-memory guild contexts from every row of the guilds table."""
    def load(con, cur):
        cur.execute("""SELECT id, warningChannelId, warningUserId,
                joinWarningCount, joinWarningSeconds FROM guilds""")
        for (guildId, channelId, mentionId, joinCount, joinSeconds) in cur.fetchall():
            g = getOrCreateGuildContext(guildId, save=False)
            g.warningChannelId = channelId
            # The warningUserId column stores the id of the role to mention.
            g.warningMentionId = mentionId
            # NULL (or 0) in the database falls back to the configured default.
            g.joinWarningCount = joinCount or config['joinWarningCount']
            g.joinWarningSeconds = joinSeconds or config['joinWarningSeconds']
            print('Guild {0} channel id is {1}'.format(guildId, g.warningChannelId))
    runSQLBatch(load)


def createTables():
    """Create the guilds table if it does not already exist."""
    def makeTables(con, cur):
        cur.execute("""CREATE TABLE IF NOT EXISTS guilds (
                id INTEGER,
                warningChannelId INTEGER,
                warningUserId INTEGER,
                joinWarningCount INTEGER,
                joinWarningSeconds INTEGER,
                PRIMARY KEY(id ASC))""")
    runSQLBatch(makeTables)


def saveGuildContext(gc):
    """Insert or update the database row for guild context gc."""
    def save(con, cur):
        print('Saving guild context with id {0}'.format(gc.id))
        cur.execute('SELECT id FROM guilds WHERE id=?', (gc.id,))
        # Prefer the live discord objects' ids; fall back to the stored ids
        # when those objects have not been resolved yet.
        channelId = gc.warningChannel.id if gc.warningChannel is not None else gc.warningChannelId
        userId = gc.warningMention.id if gc.warningMention is not None else gc.warningMentionId
        exists = cur.fetchone() is not None
        print('Record exists' if exists else 'Record does not exist')
        if exists:
            cur.execute("""UPDATE guilds SET
                    warningChannelId=?,
                    warningUserId=?,
                    joinWarningCount=?,
                    joinWarningSeconds=?
                    WHERE id=?""", (
                channelId,
                userId,
                gc.joinWarningCount,
                gc.joinWarningSeconds,
                gc.id,
            ))
        else:
            cur.execute("""INSERT INTO guilds (
                    id, warningChannelId, warningUserId,
                    joinWarningCount, joinWarningSeconds)
                    VALUES (?, ?, ?, ?, ?)""", (
                gc.id,
                channelId,
                userId,
                gc.joinWarningCount,
                gc.joinWarningSeconds,
            ))
    runSQLBatch(save)


# -- Classes ----------------------------------------------------------------

class JoinRaid:
    """A group of JoinRecords with bulk kick/ban helpers.

    Not referenced elsewhere in this file.
    """

    def __init__(self, joins):
        self.joins = joins[:]

    async def addJoin(self, member):
        """Add a record for member unless one with the same user id exists."""
        for join in self.joins:
            if join.member.id == member.id:
                # TODO: Move to front of list
                return  # Already added
        self.joins.append(JoinRecord(member))

    async def kickAll(self):
        """Kick every member not already kicked or banned."""
        for join in self.joins:
            if join.isKicked or join.isBanned:
                continue
            await join.member.kick()
            join.isKicked = True

    async def banAll(self):
        """Ban every member not already banned."""
        for join in self.joins:
            if join.isBanned:
                continue
            await join.member.ban()
            join.isBanned = True


class JoinRecord:
    """One member join, plus whether the bot has since kicked or banned them."""

    def __init__(self, member):
        self.member = member
        # Naive UTC; falls back to "now" when discord.py has no join time.
        self.joinTime = toNaiveUtc(member.joined_at) or utcNow()
        self.isKicked = False
        self.isBanned = False

    def ageSeconds(self, referenceTime=None):
        """Seconds between the join and referenceTime (naive UTC; default: now)."""
        if referenceTime is None:
            referenceTime = utcNow()
        return (referenceTime - self.joinTime).total_seconds()


class GuildContext:
    """Per-guild raid-detection state and settings."""

    def __init__(self, id):
        self.id = id
        self.guild = None
        self.warningChannelId = None
        self.warningChannel = None
        self.warningMentionId = None
        self.warningMention = None
        # Joins within the last joinWarningSeconds (culled on each join).
        self.joins = []
        # Every join belonging to the current (or most recent) raid. Kept
        # separate from self.joins so early raiders stay actionable and
        # later, unrelated joins are never swept up by a kick/ban reaction.
        self.raidJoins = []
        self.joinWarningCount = config['joinWarningCount']
        self.joinWarningSeconds = config['joinWarningSeconds']
        self.isJoinRaidInProgress = False
        self.lastWarningMessage = None

    async def handleJoin(self, member):
        """Record a member join and re-evaluate raid state."""
        print('{0.guild.name}: {0.name} joined'.format(member))
        self.joins.append(JoinRecord(member))
        await self.__checkJoins()

    async def handleReactionAdd(self, message, member, emoji):
        """Kick or ban the raid's members when a moderator reacts to the warning.

        Only reactions by non-bot members with ban permission in the channel,
        on the latest warning message, using the configured kick/ban emoji,
        have any effect.
        """
        if member.bot:
            return
        print('User {0} added emoji {1}'.format(member, emoji))
        if not message.channel.permissions_for(member).ban_members:
            print('Reactor does not have ban permissions')
            return
        if self.lastWarningMessage is None or message.id != self.lastWarningMessage.id:
            print('Reacted to a non-warning message')
            return
        if emoji.name == config['kickEmoji']:
            print('Kicking these users:')
            for join in self.raidJoins:
                if join.isKicked or join.isBanned:
                    continue  # Already removed; kicking again would fail.
                try:
                    await join.member.kick(reason='Kicked by rocketbot for join raiding')
                except discord.HTTPException as e:
                    # Keep going so one failure doesn't spare the rest of the raid.
                    print('  Failed to kick {0}: {1}'.format(join.member.name, e))
                    continue
                join.isKicked = True
                print('  ' + join.member.name)
        elif emoji.name == config['banEmoji']:
            print('Banning these users:')
            for join in self.raidJoins:
                if join.isBanned:
                    continue
                try:
                    await join.member.ban(reason='Banned by rocketbot for join raiding', delete_message_days=0)
                except discord.HTTPException as e:
                    print('  Failed to ban {0}: {1}'.format(join.member.name, e))
                    continue
                join.isBanned = True
                print('  ' + join.member.name)
        else:
            print('Unhandled emoji. Doing nothing.')
            return
        await self.__updateLastJoinRaidMessage()

    async def handleSetWarningChannel(self, context):
        """Use the command's channel for raid warnings and persist it."""
        print('{0.guild.name}: Warning channel set to {0.channel.name}'.format(context))
        self.warningChannel = context.channel
        self.warningChannelId = context.channel.id
        saveGuildContext(self)

    async def handleSetWarningMention(self, context, role=None):
        """Set the role mentioned in raid warnings and persist it.

        role: a discord.Role. When None, nothing changes.
        """
        if role is None:
            print('{0.guild.name}: No warning role given'.format(context))
            return
        self.warningMention = role
        self.warningMentionId = role.id
        print('{0.guild.name}: Warning mention set to {1.name}'.format(context, role))
        saveGuildContext(self)

    async def handleSetRaidWarningRate(self, context, count, seconds):
        """Treat `count` joins within `seconds` seconds as a raid; persist it."""
        if count < 1 or seconds < 1:
            # A count of 0 would flag every join as a raid.
            await context.send('Join count and seconds must both be at least 1.')
            return
        self.joinWarningCount = count
        self.joinWarningSeconds = seconds
        print('{0.name}: Set rate as {1} joins per {2} seconds'.format(context.guild, self.joinWarningCount, self.joinWarningSeconds))
        saveGuildContext(self)

    async def __checkJoins(self):
        """Update raid state from the joins inside the detection window."""
        now = utcNow()
        # Cull first so self.joins is exactly the set being counted and listed.
        self.__cullOldJoins(now)
        recentJoinCount = self.__countRecentJoins(now)
        print('{0} join(s) in the past {1} seconds'.format(recentJoinCount, self.joinWarningSeconds))
        if recentJoinCount >= self.joinWarningCount:
            if not self.isJoinRaidInProgress:
                # Raid just started
                self.isJoinRaidInProgress = True
                self.raidJoins = list(self.joins)
                await self.__onJoinRaidBegin()
            else:
                # Continuing existing join raid
                self.__addNewJoinsToRaid()
                await self.__onJoinRaidUpdated()
        elif self.isJoinRaidInProgress:
            # Join raid just ended
            self.isJoinRaidInProgress = False
            await self.__onJoinRaidEnd()

    def __addNewJoinsToRaid(self):
        """Append joins from the window that the current raid isn't tracking yet."""
        tracked = {id(join) for join in self.raidJoins}
        self.raidJoins.extend(join for join in self.joins if id(join) not in tracked)

    def __countRecentJoins(self, now=None):
        """Count joins no older than joinWarningSeconds relative to now."""
        if now is None:
            now = utcNow()
        return sum(1 for join in self.joins if join.ageSeconds(now) <= self.joinWarningSeconds)

    def __cullOldJoins(self, now=None):
        """Drop joins older than joinWarningSeconds relative to now (in place)."""
        if now is None:
            now = utcNow()
        self.joins[:] = [join for join in self.joins if join.ageSeconds(now) <= self.joinWarningSeconds]

    def __joinRaidMessage(self):
        """Build the warning text for the current raid.

        Returns (message, canKick, canBan). canKick/canBan say whether any
        listed user can still be kicked/banned, which decides the reactions.
        """
        canKick = False
        canBan = False
        userLines = []
        for join in self.raidJoins:
            if join.isBanned:
                userLines.append('\n• ~~' + join.member.mention + '~~ - banned')
            elif join.isKicked:
                userLines.append('\n• ~~' + join.member.mention + '~~ - kicked')
                canBan = True
            else:
                userLines.append('\n• ' + join.member.mention)
                canKick = True
                canBan = True

        header = 'A join raid has been detected! It includes these users:'
        if self.warningMention is not None:
            header = self.warningMention.mention + ' ' + header
        footer = ''
        if canKick:
            footer += '\nTo kick all these users, react with :' + config['kickEmojiName'] + ':'
        if canBan:
            footer += '\nTo ban all these users, react with :' + config['banEmojiName'] + ':'

        body = ''.join(userLines)
        if len(header) + len(body) + len(footer) > DISCORD_MESSAGE_LIMIT:
            # Too many users to list: show as many as fit and summarize the
            # rest. 40 characters leaves ample room for the summary line.
            budget = DISCORD_MESSAGE_LIMIT - len(header) - len(footer) - 40
            shown = []
            used = 0
            for line in userLines:
                if used + len(line) > budget:
                    break
                shown.append(line)
                used += len(line)
            body = ''.join(shown) + '\n• ...and {0} more'.format(len(userLines) - len(shown))
        return (header + body + footer, canKick, canBan)

    async def __updateLastJoinRaidMessage(self):
        """Re-render the latest warning message to reflect current raid state."""
        if self.lastWarningMessage is None:
            print('No previous warning message to update')
            return
        (message, canKick, canBan) = self.__joinRaidMessage()
        try:
            await self.lastWarningMessage.edit(content=message)
        except discord.HTTPException as e:
            print('Failed to update warning message:', e)

    async def __onJoinRaidBegin(self):
        """Post a new warning message and add the kick/ban reactions to it."""
        print('A join raid has begun!')
        # A new raid supersedes the previous warning; reactions on the old
        # message must not act on this raid's members.
        self.lastWarningMessage = None
        if self.warningChannel is None and self.warningChannelId is not None and self.guild is not None:
            # The channel may not have been resolved at startup.
            self.warningChannel = self.guild.get_channel(self.warningChannelId)
        if self.warningChannel is None:
            print('No warning channel set')
            return
        (message, canKick, canBan) = self.__joinRaidMessage()
        try:
            self.lastWarningMessage = await self.warningChannel.send(message)
            if canKick:
                await self.lastWarningMessage.add_reaction(config['kickEmoji'])
            if canBan:
                await self.lastWarningMessage.add_reaction(config['banEmoji'])
        except discord.HTTPException as e:
            print('Failed to post join raid warning:', e)

    async def __onJoinRaidUpdated(self):
        print('Join raid still occurring')
        await self.__updateLastJoinRaidMessage()

    async def __onJoinRaidEnd(self):
        # raidJoins is kept so reactions on the last warning still target the
        # raid's members rather than whoever joins next.
        print('Join raid has ended')


# Guild id -> GuildContext for every guild seen so far.
guildIdToGuildContext = {}


def getOrCreateGuildContext(obj, save=True):
    """Return the GuildContext for obj, creating (and optionally saving) it.

    obj may be a guild id (int), a GuildContext, a discord.Guild, or a
    Message/Member/TextChannel/Context that belongs to a guild. Returns None
    for None, for unsupported types, or when obj has no guild (e.g. a DM).
    """
    if obj is None:
        return None
    guild = None
    if isinstance(obj, int):
        gid = obj
    elif isinstance(obj, GuildContext):
        return obj
    elif isinstance(obj, discord.Guild):
        gid = obj.id
        guild = obj
    elif isinstance(obj, (discord.Message, discord.Member, discord.TextChannel, Context)):
        guild = obj.guild
        if guild is None:
            print('No guild for', type(obj))
            return None
        gid = guild.id
    else:
        print('Unhandled datatype', type(obj))
        return None

    lookedUp = guildIdToGuildContext.get(gid)
    if lookedUp is not None:
        # Contexts loaded from the database start without a live guild.
        if lookedUp.guild is None and guild is not None:
            lookedUp.guild = guild
        return lookedUp
    g = GuildContext(gid)
    g.guild = guild
    guildIdToGuildContext[gid] = g
    if save:
        saveGuildContext(g)
    return g


createTables()
loadGuildSettings()

intents = discord.Intents.default()
intents.members = True  # To get join/leave events
bot = commands.Bot(command_prefix='$', intents=intents)


# -- Bot commands -----------------------------------------------------------

@commands.command()
async def hello(ctx):
    message = ctx.message
    print('Got message from {0.author.name} in {0.channel.id}'.format(message))
    await message.channel.send('Hello, {0.author.mention}!'.format(message))
    print('Replied "Hello!"')
bot.add_command(hello)


@commands.command()
@commands.has_permissions(manage_messages=True)
async def setraidwarningrate(ctx, count: int, seconds: int):
    g = getOrCreateGuildContext(ctx)
    if g is None:
        return
    await g.handleSetRaidWarningRate(ctx, count, seconds)
bot.add_command(setraidwarningrate)


@commands.command()
@commands.has_permissions(manage_messages=True)
async def setwarningchannel(ctx):
    g = getOrCreateGuildContext(ctx)
    if g is None:
        return
    await g.handleSetWarningChannel(ctx)
bot.add_command(setwarningchannel)


@commands.command()
@commands.has_permissions(manage_messages=True)
async def setwarningrole(ctx, role: discord.Role):
    # The Role converter accepts a role id, mention, or name.
    g = getOrCreateGuildContext(ctx)
    if g is None:
        return
    await g.handleSetWarningMention(ctx, role)
bot.add_command(setwarningrole)


# -- Bot events -------------------------------------------------------------

def restoreGuildObjects():
    """Attach live guild/channel/role objects to contexts loaded from the DB."""
    for guild in bot.guilds:
        g = guildIdToGuildContext.get(guild.id)
        if g is None:
            print('No record for', guild.id)
            continue
        g.guild = guild
        if g.warningChannelId is not None:
            g.warningChannel = guild.get_channel(g.warningChannelId)
            print('Recovered warning channel', g.warningChannel)
        if g.warningMentionId is not None:
            g.warningMention = guild.get_role(g.warningMentionId)
            print('Recovered warning mention', g.warningMention)


@bot.listen()
async def on_connect():
    # The guild cache may still be empty here on a fresh start; on_ready retries.
    restoreGuildObjects()


@bot.listen()
async def on_ready():
    # By on_ready the guild cache is populated, so stored ids can be resolved.
    restoreGuildObjects()


@bot.listen()
async def on_member_join(member):
    print('User {0.name} joined'.format(member))
    g = getOrCreateGuildContext(member)
    if g is None:
        print('No GuildContext for guild {0.guild.name}'.format(member))
        return
    await g.handleJoin(member)


@bot.listen()
async def on_member_remove(member):
    print('User {0.name} left'.format(member))


@bot.listen()
async def on_raw_reaction_add(payload: discord.RawReactionActionEvent):
    """Forward reactions on a guild's latest raid warning to its GuildContext."""
    if payload.guild_id is None:
        return  # DM reaction: no guild, nothing to moderate.
    member = payload.member
    if member is None or member.bot:
        return
    gc = guildIdToGuildContext.get(payload.guild_id)
    # Only the latest warning message is actionable; checking before
    # fetch_message avoids an API call for every reaction in the guild.
    if gc is None or gc.lastWarningMessage is None or gc.lastWarningMessage.id != payload.message_id:
        return
    guild = bot.get_guild(payload.guild_id)
    channel = guild.get_channel(payload.channel_id) if guild is not None else None
    if channel is None:
        return
    try:
        message = await channel.fetch_message(payload.message_id)
    except discord.HTTPException as e:
        print('Could not fetch reacted message:', e)
        return
    await gc.handleReactionAdd(message, member, payload.emoji)


print('Starting bot')
bot.run(config['clientToken'])
print('Bot done')