"""Discord bot that watches for "join raids" (many members joining quickly).

Per-guild settings (warning channel, warning role, join-rate threshold) are
persisted in a local SQLite database and restored when the bot starts.
"""

from datetime import datetime, timedelta, timezone
import sqlite3
import sys

import discord
from discord.ext import commands
from discord.ext.commands.context import Context

from config import config

# -- Database ---------------------------------------------------------------

# SQLite file holding the persisted per-guild settings.
DATABASE_FILE = 'rocketbot.db'


def runSQLBatch(batchFunction):
    """Run ``batchFunction(connection, cursor)`` against the bot database.

    The batch is committed only if ``batchFunction`` returns normally. The
    connection is closed in every case, so an exception raised by the batch
    no longer leaks the connection (closing without a commit discards the
    pending changes).
    """
    dbConnection = sqlite3.connect(DATABASE_FILE)
    try:
        dbCursor = dbConnection.cursor()
        batchFunction(dbConnection, dbCursor)
        dbConnection.commit()
    finally:
        dbConnection.close()


def loadGuildSettings():
    """Create or update the in-memory GuildContext of every stored guild."""
    def load(con, cur):
        # Columns are named explicitly so the row indexes below do not depend
        # on the physical column order of the table.
        cur.execute("""SELECT id, warningChannelId, warningUserId,
                              joinWarningCount, joinWarningSeconds
                       FROM guilds""")
        for row in cur.fetchall():
            guildId = row[0]
            # save=False: the record was just read from the database, so there
            # is no point in writing it straight back.
            g = getOrCreateGuildContext(guildId, save=False)
            g.warningChannelId = row[1]
            g.warningMentionId = row[2]
            # NULL (or 0) in the database falls back to the configured default.
            g.joinWarningCount = row[3] or config['joinWarningCount']
            g.joinWarningSeconds = row[4] or config['joinWarningSeconds']
            print('Guild {0} channel id is {1}'.format(guildId, g.warningChannelId))
    runSQLBatch(load)


def createTables():
    """Create the ``guilds`` table if it does not exist yet.

    The statement is idempotent, so this is safe to call on every start-up.
    """
    def makeTables(con, cur):
        cur.execute("""CREATE TABLE IF NOT EXISTS guilds (
            id INTEGER,
            warningChannelId INTEGER,
            warningUserId INTEGER,
            joinWarningCount INTEGER,
            joinWarningSeconds INTEGER,
            PRIMARY KEY(id ASC))""")
    runSQLBatch(makeTables)


def saveGuildContext(gc):
    """Insert or update the database row for the GuildContext ``gc``."""
    def save(con, cur):
        print('Saving guild context with id {0}'.format(gc.id))
        # Prefer the id of the resolved discord object; fall back to the raw
        # id loaded from the database when the object is not resolved.
        if gc.warningChannel is not None:
            channelId = gc.warningChannel.id
        else:
            channelId = gc.warningChannelId
        if gc.warningMention is not None:
            userId = gc.warningMention.id
        else:
            userId = gc.warningMentionId

        cur.execute('SELECT id FROM guilds WHERE id=?', (gc.id,))
        exists = cur.fetchone() is not None
        print('Record exists' if exists else 'Record does not exist')
        if exists:
            cur.execute("""UPDATE guilds
                SET warningChannelId=?, warningUserId=?,
                    joinWarningCount=?, joinWarningSeconds=?
                WHERE id=?""", (
                channelId,
                userId,
                gc.joinWarningCount,
                gc.joinWarningSeconds,
                gc.id,
            ))
        else:
            cur.execute("""INSERT INTO guilds (
                    id, warningChannelId, warningUserId,
                    joinWarningCount, joinWarningSeconds)
                VALUES (?, ?, ?, ?, ?)""", (
                gc.id,
                channelId,
                userId,
                gc.joinWarningCount,
                gc.joinWarningSeconds,
            ))
    runSQLBatch(save)


# -- Classes ----------------------------------------------------------------

def _toAwareUtc(moment, naiveIsUtc):
    """Return ``moment`` as a timezone-aware UTC datetime.

    A naive ``moment`` is taken to be UTC when ``naiveIsUtc`` is true;
    otherwise it is interpreted as system local time, which is what
    ``datetime.astimezone`` does for naive values.
    """
    if moment.tzinfo is None and naiveIsUtc:
        return moment.replace(tzinfo=timezone.utc)
    return moment.astimezone(timezone.utc)


class JoinRecord:
    """A member together with the time at which they joined."""

    def __init__(self, member):
        self.member = member
        joinedAt = member.joined_at
        if joinedAt is None:
            self.joinTime = datetime.now(timezone.utc)
        else:
            # joined_at is UTC; older discord.py versions return it naive.
            # Normalising it here keeps the later subtraction from mixing
            # naive and aware values.
            self.joinTime = _toAwareUtc(joinedAt, naiveIsUtc=True)
        self.isKicked = False
        self.isBanned = False

    def ageSeconds(self, referenceTime=None):
        """Return the seconds elapsed between the join and ``referenceTime``.

        ``referenceTime`` defaults to the current time, evaluated on each
        call. A naive ``referenceTime`` is interpreted as system local time.
        """
        if referenceTime is None:
            referenceTime = datetime.now(timezone.utc)
        else:
            referenceTime = _toAwareUtc(referenceTime, naiveIsUtc=False)
        return (referenceTime - self.joinTime).total_seconds()


class GuildContext:
    """Settings and join-raid tracking state for a single guild."""

    def __init__(self, id):
        self.id = id
        self.guild = None
        # The *Id fields hold the persisted ids; the matching object fields
        # hold the resolved discord objects once they are available.
        self.warningChannelId = None
        self.warningChannel = None
        self.warningMentionId = None
        self.warningMention = None
        self.joins = []
        self.joinWarningCount = config['joinWarningCount']
        self.joinWarningSeconds = config['joinWarningSeconds']
        self.isJoinRaidInProgress = False
        self.lastWarningMessage = None

    async def handleJoin(self, member):
        """Record that ``member`` joined and re-evaluate the raid state."""
        print('{0.guild.name}: {0.name} joined'.format(member))
        self.joins.append(JoinRecord(member))
        await self.__checkJoins()

    async def handleSetWarningChannel(self, context):
        """Use the channel the command was sent in for warnings and save it."""
        print('{0.guild.name}: Warning channel set to {0.channel.name}'.format(context))
        self.warningChannel = context.channel
        self.warningChannelId = context.channel.id
        saveGuildContext(self)

    async def handleSetWarningMention(self, context):
        """Log the first command argument; the role is not stored yet."""
        # NOTE(review): context.args holds the arguments passed to the command
        # callback, so index 0 is presumably the context itself rather than a
        # role - confirm before finishing this feature.
        if not context.args:
            return
        roleArg = context.args[0]
        print('{0.guild.name}: {1}'.format(context, roleArg))

    async def handleSetRaidWarningRate(self, context, count, seconds):
        """Set the raid threshold to ``count`` joins per ``seconds`` and save."""
        self.joinWarningCount = count
        self.joinWarningSeconds = seconds
        # self.guild may not be bound yet (e.g. a context created before the
        # bot was ready), so fall back to the guild of the invoking command
        # instead of failing before the settings are saved.
        guild = self.guild if self.guild is not None else context.guild
        print('{0.name}: Set rate as {1} joins per {2} seconds'.format(
            guild, self.joinWarningCount, self.joinWarningSeconds))
        saveGuildContext(self)

    async def __checkJoins(self):
        """Update the raid state from the recent joins, then drop old ones."""
        now = datetime.now(timezone.utc)
        recentJoinCount = self.__countRecentJoins(now)
        print('{0} join(s) in the past {1} seconds'.format(
            recentJoinCount, self.joinWarningSeconds))
        if recentJoinCount >= self.joinWarningCount:
            # In join raid
            if not self.isJoinRaidInProgress:
                # Raid just started
                self.isJoinRaidInProgress = True
                await self.__onJoinRaidBegin()
            else:
                # Continuing existing join raid
                await self.__onJoinRaidUpdated()
        else:
            # No join raid
            if self.isJoinRaidInProgress:
                # Join raid just ended
                self.isJoinRaidInProgress = False
                await self.__onJoinRaidEnd()
            self.__cullOldJoins(now)

    def __countRecentJoins(self, now=None):
        """Count joins no older than ``joinWarningSeconds`` at ``now``."""
        if now is None:
            now = datetime.now(timezone.utc)
        return sum(
            1 for joinRecord in self.joins
            if joinRecord.ageSeconds(now) <= self.joinWarningSeconds
        )

    def __cullOldJoins(self, now=None):
        """Remove joins older than ``joinWarningSeconds`` at ``now``."""
        if now is None:
            now = datetime.now(timezone.utc)
        # Slice assignment filters in place, keeping the same list object.
        self.joins[:] = [
            joinRecord for joinRecord in self.joins
            if joinRecord.ageSeconds(now) <= self.joinWarningSeconds
        ]

    def __buildRaidMessage(self):
        """Return the warning text listing every tracked join."""
        lines = ['A join raid has been detected! It includes these users:']
        lines.extend('• ' + join.member.mention for join in self.joins)
        return '\n'.join(lines)

    async def __onJoinRaidBegin(self):
        """Post the initial raid warning to the warning channel, if set."""
        print('A join raid has begun!')
        if self.warningChannel is None:
            print('No warning channel set')
            return
        # TODO: Mention mod role
        self.lastWarningMessage = await self.warningChannel.send(
            self.__buildRaidMessage())

    async def __onJoinRaidUpdated(self):
        """Edit the previously posted warning to list the current joins."""
        print('Join raid still occurring')
        if self.lastWarningMessage is None:
            return
        await self.lastWarningMessage.edit(content=self.__buildRaidMessage())

    async def __onJoinRaidEnd(self):
        """Log that the raid ended."""
        print('Join raid has ended')


# Maps guild id -> GuildContext for every guild seen so far.
guildIdToGuildContext = {}


def getOrCreateGuildContext(obj, save=True):
    """Return the GuildContext that ``obj`` belongs to, creating it if needed.

    ``obj`` may be a guild id, a GuildContext (returned unchanged), a
    ``discord.Guild``, or a message, member, text channel or command context.
    Returns ``None`` for ``None``, for unsupported types and for objects that
    have no guild. A newly created context is written to the database unless
    ``save`` is false.
    """
    if obj is None:
        return None
    if isinstance(obj, GuildContext):
        return obj

    gid = None
    guild = None
    if isinstance(obj, int):
        gid = obj
    elif isinstance(obj, discord.Guild):
        guild = obj
    elif isinstance(obj, (discord.Message, discord.Member,
                          discord.TextChannel, Context)):
        guild = obj.guild
        if guild is None:
            # e.g. a message or command context that has no guild attached.
            print('No guild associated with', type(obj))
            return None
    if guild is not None:
        gid = guild.id
    if gid is None:
        print('Unhandled datatype', type(obj))
        return None

    g = guildIdToGuildContext.get(gid)
    if g is None:
        g = GuildContext(gid)
        guildIdToGuildContext[gid] = g
        if save:
            saveGuildContext(g)
    # Bind the guild object as soon as one is known so that code reading
    # g.guild does not find None.
    if g.guild is None and guild is not None:
        g.guild = guild
    return g


# Make sure the schema exists before reading from it; without this the first
# run against a fresh database fails because the guilds table is missing.
createTables()
loadGuildSettings()

intents = discord.Intents.default()
intents.members = True  # To get join/leave events
bot = commands.Bot(command_prefix='$', intents=intents)


# -- Bot commands -----------------------------------------------------------

@commands.command()
async def hello(ctx):
    """Reply with a greeting that mentions the author."""
    message = ctx.message
    print('Got message from {0.author.name} in {0.channel.id}'.format(message))
    await message.channel.send('Hello, {0.author.mention}!'.format(message))
    print('Replied "Hello!"')
bot.add_command(hello)


@commands.command()
@commands.has_permissions(manage_messages=True)
async def setraidwarningrate(ctx, count: int, seconds: int):
    """Set the raid threshold for this guild to count joins per seconds."""
    g = getOrCreateGuildContext(ctx)
    if g is None:
        return
    await g.handleSetRaidWarningRate(ctx, count, seconds)
bot.add_command(setraidwarningrate)


@commands.command()
@commands.has_permissions(manage_messages=True)
async def setwarningchannel(ctx):
    """Send raid warnings for this guild to the current channel."""
    g = getOrCreateGuildContext(ctx)
    if g is None:
        return
    await g.handleSetWarningChannel(ctx)
bot.add_command(setwarningchannel)


@commands.command()
@commands.has_permissions(manage_messages=True)
async def setwarningrole(ctx):
    """Forward the command to the guild's warning-mention handler."""
    g = getOrCreateGuildContext(ctx)
    if g is None:
        return
    await g.handleSetWarningMention(ctx)
bot.add_command(setwarningrole)


# -- Bot events -------------------------------------------------------------

def _bindGuildObjects():
    """Attach live guild, channel and role objects to the loaded contexts."""
    for guild in bot.guilds:
        g = guildIdToGuildContext.get(guild.id)
        if g is None:
            print('No record for', guild.id)
            continue
        g.guild = guild
        if g.warningChannelId is not None:
            g.warningChannel = guild.get_channel(g.warningChannelId)
            print('Recovered warning channel', g.warningChannel)
        if g.warningMentionId is not None:
            g.warningMention = guild.get_role(g.warningMentionId)
            print('Recovered warning mention', g.warningMention)


@bot.listen()
async def on_connect():
    # Nothing to do here: the guild cache is not fully populated when this
    # event fires, so persisted ids are resolved in on_ready instead.
    pass


@bot.listen()
async def on_ready():
    # The guild cache is filled in by the time on_ready fires. The event can
    # fire more than once; re-binding is harmless.
    _bindGuildObjects()


@bot.listen()
async def on_member_join(member):
    print('User {0.name} joined'.format(member))
    g = getOrCreateGuildContext(member)
    if g is None:
        print('No GuildContext for guild {0.guild.name}'.format(member))
        return
    await g.handleJoin(member)


@bot.listen()
async def on_member_remove(member):
    print('User {0.name} left'.format(member))


print('Starting bot')
bot.run(config['clientToken'])
print('Bot done')