| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398 |
- from config import config
- from datetime import datetime, timedelta
- import discord
- from discord.ext import commands
- from discord.ext.commands.context import Context
- import sqlite3
- import sys
-
- # -- Database ---------------------------------------------------------------
-
def runSQLBatch(batchFunction, dbPath='rocketbot.db'):
    """Run ``batchFunction`` on a fresh SQLite connection and commit it.

    Opens ``dbPath`` and calls ``batchFunction(connection, cursor)``. If that
    returns normally, the changes are committed. The connection is always
    closed, even when ``batchFunction`` raises. Closing without a commit
    discards the uncommitted changes, and the exception propagates.

    :param batchFunction: callable taking ``(sqlite3.Connection, sqlite3.Cursor)``.
    :param dbPath: database file to open; defaults to the bot's ``rocketbot.db``.
    :return: whatever ``batchFunction`` returned.
    """
    dbConnection = sqlite3.connect(dbPath)
    try:
        dbCursor = dbConnection.cursor()
        result = batchFunction(dbConnection, dbCursor)
        dbConnection.commit()
        return result
    finally:
        # Previously an exception in batchFunction leaked the connection.
        dbConnection.close()
-
def loadGuildSettings():
    """Populate the in-memory GuildContext registry from the ``guilds`` table.

    On a brand-new database the table does not exist yet. Previously the
    SELECT then raised "no such table" at startup. Now the table is created
    first with ``createTables``. A NULL or 0 stored rate falls back to the
    defaults in ``config``.
    """
    tableExists = False

    def checkTable(con, cur):
        nonlocal tableExists
        cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='guilds'")
        tableExists = cur.fetchone() is not None

    def load(con, cur):
        # Name the columns instead of relying on SELECT * column order.
        # The warning-mention id is stored in the warningUserId column.
        query = """SELECT id, warningChannelId, warningUserId,
            joinWarningCount, joinWarningSeconds FROM guilds"""
        for guildId, channelId, mentionId, warnCount, warnSeconds in cur.execute(query):
            g = getOrCreateGuildContext(guildId, save=False)
            g.warningChannelId = channelId
            g.warningMentionId = mentionId
            g.joinWarningCount = warnCount or config['joinWarningCount']
            g.joinWarningSeconds = warnSeconds or config['joinWarningSeconds']
            print('Guild {0} channel id is {1}'.format(guildId, g.warningChannelId))

    runSQLBatch(checkTable)
    if not tableExists:
        print('No guilds table found; creating it')
        createTables()
    runSQLBatch(load)
-
def createTables():
    """Create the ``guilds`` settings table unless it already exists.

    ``IF NOT EXISTS`` makes this safe to call on every startup. Previously a
    second call raised because the table was already there. The
    ``warningUserId`` column holds the id of the warning mention.
    """
    def makeTables(con, cur):
        cur.execute("""CREATE TABLE IF NOT EXISTS guilds (
            id INTEGER,
            warningChannelId INTEGER,
            warningUserId INTEGER,
            joinWarningCount INTEGER,
            joinWarningSeconds INTEGER,
            PRIMARY KEY(id ASC))""")
    runSQLBatch(makeTables)
-
def saveGuildContext(gc):
    """Write one GuildContext's settings to the ``guilds`` table.

    The ids of the live channel/mention objects win over the cached ids.
    An existing row for ``gc.id`` is updated; otherwise a new row is inserted.
    """
    updateSql = ('UPDATE guilds SET\n'
                 'warningChannelId=?,\n'
                 'warningUserId=?,\n'
                 'joinWarningCount=?,\n'
                 'joinWarningSeconds=?\n'
                 'WHERE id=?')
    insertSql = ('INSERT INTO guilds (\n'
                 'id,\n'
                 'warningChannelId,\n'
                 'warningUserId,\n'
                 'joinWarningCount,\n'
                 'joinWarningSeconds)\n'
                 'VALUES (?, ?, ?, ?, ?)')

    def writeRow(con, cur):
        print('Saving guild context with id {0}'.format(gc.id))
        cur.execute('SELECT id FROM guilds WHERE id=?', (gc.id,))
        if gc.warningChannel != None:
            channelId = gc.warningChannel.id
        else:
            channelId = gc.warningChannelId
        if gc.warningMention != None:
            mentionId = gc.warningMention.id
        else:
            mentionId = gc.warningMentionId
        rowFound = cur.fetchone() is not None
        if rowFound:
            print('Record exists')
            cur.execute(updateSql, (channelId, mentionId, gc.joinWarningCount, gc.joinWarningSeconds, gc.id))
        else:
            print('Record does not exist')
            cur.execute(insertSql, (gc.id, channelId, mentionId, gc.joinWarningCount, gc.joinWarningSeconds))

    runSQLBatch(writeRow)
-
- # -- Classes ----------------------------------------------------------------
-
class JoinRaid:
    """The set of joins belonging to a single join raid.

    Each entry needs ``member``, ``isKicked`` and ``isBanned``. The raid can
    kick or ban every tracked member, and flags each one as it is handled.
    """

    def __init__(self, joins):
        # Take a shallow copy so later changes to the caller's sequence don't leak in.
        self.joins = joins[:]

    async def addJoin(self, member):
        """Start tracking ``member`` unless a record with the same id exists."""
        alreadyTracked = any(record.member.id == member.id for record in self.joins)
        if alreadyTracked:
            # TODO: Move to front of list
            return
        self.joins.append(JoinRecord(member))

    async def kickAll(self):
        """Kick every member that has been neither kicked nor banned yet."""
        for record in self.joins:
            if not (record.isKicked or record.isBanned):
                await record.member.kick()
                record.isKicked = True

    async def banAll(self):
        """Ban every member that has not been banned yet (kicked ones included)."""
        for record in self.joins:
            if not record.isBanned:
                await record.member.ban()
                record.isBanned = True
-
class JoinRecord:
    """One member join, timestamped for join-rate (raid) detection."""

    def __init__(self, member):
        from datetime import timezone
        self.member = member
        joinedAt = member.joined_at
        if joinedAt is None:
            self.joinTime = datetime.now()
        else:
            # The rest of this file measures ages against naive local
            # datetime.now(). discord.py documents joined_at as UTC: naive in
            # 1.x, timezone-aware in 2.x. Using it raw skewed every age by the
            # host's UTC offset (or raised TypeError for an aware value), so
            # convert it to naive local time here.
            if joinedAt.tzinfo is None:
                joinedAt = joinedAt.replace(tzinfo=timezone.utc)
            self.joinTime = joinedAt.astimezone().replace(tzinfo=None)
        self.isKicked = False
        self.isBanned = False

    def ageSeconds(self, referenceTime=None):
        """Return the seconds elapsed between this join and ``referenceTime``.

        :param referenceTime: naive local datetime. Defaults to the current
            time, taken at call time. The old ``datetime.now()`` default was
            evaluated once, when the class was defined.
        """
        if referenceTime is None:
            referenceTime = datetime.now()
        return (referenceTime - self.joinTime).total_seconds()
-
class GuildContext:
    """Join-raid detection state and settings for one guild.

    Records member joins and counts how many fall inside a sliding window of
    ``joinWarningSeconds``. When that count reaches ``joinWarningCount``, it
    posts (and later edits) a warning message in the warning channel. When a
    member with ban permission reacts to the warning with the configured
    kick/ban emoji, it kicks or bans the listed joiners.
    """

    def __init__(self, id):
        self.id = id  # guild id, primary key of the `guilds` table
        self.guild = None  # discord.Guild once known; None until a caller supplies it
        self.warningChannelId = None
        self.warningChannel = None
        self.warningMentionId = None  # persisted in the warningUserId column
        self.warningMention = None
        self.joins = []  # JoinRecords in arrival order
        self.joinWarningCount = config['joinWarningCount']
        self.joinWarningSeconds = config['joinWarningSeconds']
        self.isJoinRaidInProgress = False
        self.lastWarningMessage = None  # the warning message reactions are accepted on

    async def handleJoin(self, member):
        """Record ``member``'s join and re-evaluate the raid state."""
        print('{0.guild.name}: {0.name} joined'.format(member))
        self.joins.append(JoinRecord(member))
        await self.__checkJoins()

    async def handleReactionAdd(self, message, member, emoji):
        """Kick or ban the listed joiners when an authorized member reacts to the warning.

        The reaction is ignored when it comes from a bot, from a reactor
        without ``ban_members`` in the channel, or on any message other than
        the current warning. Unknown emoji are also ignored. Members already
        handled are skipped. A failed kick/ban is logged without aborting the
        rest of the list.
        """
        if member.bot:
            return
        print('User {0} added emoji {1}'.format(member, emoji))
        if not member.permissions_in(message.channel).ban_members:
            print('Reactor does not have ban permissions')
            return
        if self.lastWarningMessage is None or message.id != self.lastWarningMessage.id:
            print('Reacted to a non-warning message')
            return
        if emoji.name == config['kickEmoji']:
            await self.__kickJoins()
        elif emoji.name == config['banEmoji']:
            await self.__banJoins()
        else:
            print('Unhandled emoji. Doing nothing.')
            return
        await self.__updateLastJoinRaidMessage()

    async def __kickJoins(self):
        """Kick every listed member who is not already kicked or banned."""
        print('Kicking these users:')
        for join in self.joins:
            if join.isKicked or join.isBanned:
                continue  # already removed; kicking again would fail
            try:
                await join.member.kick(reason='Kicked by rocketbot for join raiding')
            except discord.HTTPException as e:
                # E.g. the member already left, or the bot lacks permission.
                print('Could not kick {0}: {1}'.format(join.member.name, e))
                continue
            join.isKicked = True
            print(' ' + join.member.name)

    async def __banJoins(self):
        """Ban every listed member who is not already banned (kicked ones included)."""
        print('Banning these users:')
        for join in self.joins:
            if join.isBanned:
                continue
            try:
                await join.member.ban(reason='Banned by rocketbot for join raiding', delete_message_days=0)
            except discord.HTTPException as e:
                print('Could not ban {0}: {1}'.format(join.member.name, e))
                continue
            join.isBanned = True
            print(' ' + join.member.name)

    async def handleSetWarningChannel(self, context):
        """Make the command's channel the warning channel and persist it."""
        print('{0.guild.name}: Warning channel set to {0.channel.name}'.format(context))
        self.warningChannel = context.channel
        self.warningChannelId = context.channel.id
        saveGuildContext(self)

    async def handleSetWarningMention(self, context):
        """Stub for setting the warning mention: currently only logs the argument."""
        # NOTE(review): in discord.py the first entry of Context.args appears to
        # be the Context itself rather than the user's role argument - confirm
        # before implementing this.
        if len(context.args) < 1:
            return
        roleArg = context.args[0]
        print('{0.guild.name}: {1}'.format(context, roleArg))

    async def handleSetRaidWarningRate(self, context, count, seconds):
        """Store a threshold of ``count`` joins per ``seconds`` and persist it."""
        self.joinWarningCount = count
        self.joinWarningSeconds = seconds
        # Contexts created from a bare guild id have no Guild object attached.
        guildName = self.guild.name if self.guild is not None else self.id
        print('{0}: Set rate as {1} joins per {2} seconds'.format(guildName, self.joinWarningCount, self.joinWarningSeconds))
        saveGuildContext(self)

    async def __checkJoins(self):
        """Fire raid begin/update/end transitions, then drop joins outside the window."""
        now = datetime.now()
        recentJoinCount = self.__countRecentJoins(now)
        print('{0} join(s) in the past {1} seconds'.format(recentJoinCount, self.joinWarningSeconds))
        if recentJoinCount >= self.joinWarningCount:
            if not self.isJoinRaidInProgress:
                self.isJoinRaidInProgress = True
                await self.__onJoinRaidBegin()
            else:
                await self.__onJoinRaidUpdated()
        elif self.isJoinRaidInProgress:
            self.isJoinRaidInProgress = False
            await self.__onJoinRaidEnd()
        self.__cullOldJoins(now)

    def __countRecentJoins(self, now=None):
        """Count joins no older than ``joinWarningSeconds`` at ``now`` (default: call time)."""
        if now is None:
            # Evaluated per call; the old default was frozen at class definition.
            now = datetime.now()
        return sum(1 for joinRecord in self.joins
                   if joinRecord.ageSeconds(now) <= self.joinWarningSeconds)

    def __cullOldJoins(self, now=None):
        """Remove joins older than ``joinWarningSeconds`` at ``now`` (default: call time)."""
        if now is None:
            now = datetime.now()
        # Slice assignment keeps the same list object, as the old in-place pops did.
        self.joins[:] = [j for j in self.joins if j.ageSeconds(now) <= self.joinWarningSeconds]

    def __joinRaidMessage(self):
        """Build the warning text plus whether kicking/banning is still possible.

        :return: ``(text, canKick, canBan)``
        """
        # TODO: Mention mod role
        lines = ['A join raid has been detected! It includes these users:']
        canKick = False
        canBan = False
        for join in self.joins:
            if join.isBanned:
                lines.append('• ~~' + join.member.mention + '~~ - banned')
            elif join.isKicked:
                lines.append('• ~~' + join.member.mention + '~~ - kicked')
                canBan = True
            else:
                lines.append('• ' + join.member.mention)
                canKick = True
                canBan = True
        if canKick:
            lines.append('To kick all these users, react with :' + config['kickEmojiName'] + ':')
        if canBan:
            lines.append('To ban all these users, react with :' + config['banEmojiName'] + ':')
        return ('\n'.join(lines), canKick, canBan)

    async def __updateLastJoinRaidMessage(self):
        """Re-render the current warning message, if there is one."""
        if self.lastWarningMessage is None:
            print('No previous warning message to update')
            return
        message, _, _ = self.__joinRaidMessage()
        await self.lastWarningMessage.edit(content=message)

    async def __onJoinRaidBegin(self):
        """Post a new warning in the warning channel and add the action reactions."""
        print('A join raid has begun!')
        if self.warningChannel is None:
            print('No warning channel set')
            return
        message, canKick, canBan = self.__joinRaidMessage()
        self.lastWarningMessage = await self.warningChannel.send(message)
        if canKick:
            await self.lastWarningMessage.add_reaction(config['kickEmoji'])
        if canBan:
            await self.lastWarningMessage.add_reaction(config['banEmoji'])

    async def __onJoinRaidUpdated(self):
        """Refresh the warning while the raid continues."""
        print('Join raid still occurring')
        await self.__updateLastJoinRaidMessage()

    async def __onJoinRaidEnd(self):
        """Stop accepting reactions on the finished raid's warning.

        Reactions act on ``self.joins``. Once the raid is over, that list
        holds later, unrelated joiners, so a late kick/ban reaction on the
        old warning would punish them instead of the raiders.
        """
        print('Join raid has ended')
        self.lastWarningMessage = None
-
# Every known GuildContext, keyed by guild id.
guildIdToGuildContext = {}

def getOrCreateGuildContext(obj, save=True):
    """Return the GuildContext for ``obj``'s guild, creating it if needed.

    :param obj: one of:
        - a guild id (int);
        - a GuildContext, which is returned unchanged;
        - a ``discord.Guild``;
        - a Message, Member, TextChannel or commands Context, whose
          ``.guild`` is used.
    :param save: whether to persist a newly created context with
        ``saveGuildContext``.
    :return: the context, or None for None input, unsupported types, or
        objects without a guild (e.g. a direct message). Previously the
        last case raised AttributeError.
    """
    if obj is None:
        return None
    if isinstance(obj, GuildContext):
        return obj
    guild = None
    if isinstance(obj, int):
        gid = obj
    elif isinstance(obj, discord.Guild):
        guild = obj
        gid = guild.id
    elif isinstance(obj, (discord.Message, discord.Member, discord.TextChannel, discord.ext.commands.context.Context)):
        guild = obj.guild
        if guild is None:
            print('No guild for', type(obj))
            return None
        gid = guild.id
    else:
        print('Unhandled datatype', type(obj))
        return None
    existing = guildIdToGuildContext.get(gid)
    if existing is not None:
        # Contexts created from a bare id (e.g. loaded from the database)
        # have no Guild yet; attach one as soon as a caller supplies it.
        if existing.guild is None and guild is not None:
            existing.guild = guild
        return existing
    g = GuildContext(gid)
    g.guild = guild
    guildIdToGuildContext[gid] = g
    if save:
        saveGuildContext(g)
    return g
-
# Restore saved per-guild settings into the registry before the bot starts.
loadGuildSettings()

# The members intent is what delivers member join/leave events.
intents = discord.Intents.default()
intents.members = True
bot = commands.Bot(intents=intents, command_prefix='$')
-
- # -- Bot commands -----------------------------------------------------------
-
@commands.command()
async def hello(ctx):
    """Reply in the same channel with a greeting that mentions the author."""
    msg = ctx.message
    print('Got message from {0.author.name} in {0.channel.id}'.format(msg))
    greeting = 'Hello, {0.author.mention}!'.format(msg)
    await msg.channel.send(greeting)
    print('Replied "Hello!"')
bot.add_command(hello)
-
@commands.command()
@commands.has_permissions(manage_messages=True)
async def setraidwarningrate(ctx, count: int, seconds: int):
    """Set how many joins within how many seconds count as a join raid.

    Both values must be at least 1:
    - A stored 0 is replaced by the config default the next time settings
      are loaded from the database.
    - Negative values make the rate meaningless.

    Invalid input gets a reply, and nothing is created or saved.
    """
    if count < 1 or seconds < 1:
        await ctx.send('Both the join count and the number of seconds must be at least 1.')
        return
    g = getOrCreateGuildContext(ctx)
    if g is None:
        return
    await g.handleSetRaidWarningRate(ctx, count, seconds)
bot.add_command(setraidwarningrate)
-
@commands.command()
@commands.has_permissions(manage_messages=True)
async def setwarningchannel(ctx):
    """Use the channel this command is invoked in for join-raid warnings."""
    guildContext = getOrCreateGuildContext(ctx)
    if guildContext is not None:
        await guildContext.handleSetWarningChannel(ctx)
bot.add_command(setwarningchannel)
-
@commands.command()
@commands.has_permissions(manage_messages=True)
async def setwarningrole(ctx, role: str):
    """Forward the set-warning-mention request to the guild's context.

    ``role`` is parsed by the command framework but not read here; the
    handler inspects ``ctx.args`` itself.
    """
    guildContext = getOrCreateGuildContext(ctx)
    if guildContext is not None:
        await guildContext.handleSetWarningMention(ctx)
bot.add_command(setwarningrole)
-
- # -- Bot events -------------------------------------------------------------
-
@bot.listen('on_ready')
async def on_connect():
    """Attach live Guild, channel and role objects to the registered GuildContexts.

    This is registered for ``on_ready`` rather than ``on_connect``.
    discord.py fires on_connect as soon as the gateway connects, before the
    guild cache is filled, so ``bot.guilds`` could be empty there and
    nothing got recovered. on_ready may fire more than once; re-running
    this only refreshes the stored objects. The function keeps its old name
    so existing references still resolve.
    """
    for guild in bot.guilds:
        g = guildIdToGuildContext.get(guild.id)
        if g is None:
            print('No record for', guild.id)
            continue
        g.guild = guild
        if g.warningChannelId is not None:
            g.warningChannel = guild.get_channel(g.warningChannelId)
            print('Recovered warning channel', g.warningChannel)
        if g.warningMentionId is not None:
            g.warningMention = guild.get_role(g.warningMentionId)
            print('Recovered warning mention', g.warningMention)
-
@bot.listen()
async def on_ready():
    """Ready-event listener that deliberately does nothing."""
-
@bot.listen()
async def on_member_join(member):
    """Hand each newly joined member to its guild's raid detector."""
    print('User {0.name} joined'.format(member))
    guildContext = getOrCreateGuildContext(member)
    if guildContext is None:
        print('No GuildContext for guild {0.guild.name}'.format(member))
    else:
        await guildContext.handleJoin(member)
-
@bot.listen()
async def on_member_remove(member):
    """Log a member leaving; no other handling yet."""
    leaverName = member.name
    print('User {} left'.format(leaverName))
-
@bot.listen()
async def on_raw_reaction_add(payload: discord.RawReactionActionEvent):
    """Route a reaction on a guild's current raid warning to that guild's context.

    Uses the raw event so reactions on uncached messages are still seen.
    The following are ignored before any API request is made:
    - reactions outside guilds (no ``guild_id``/``member``), which
      previously crashed on ``None.get_channel``;
    - guilds without a GuildContext;
    - any message other than the guild's latest warning.

    Previously every reaction anywhere triggered a ``fetch_message`` call
    and could create and save a new context. The context already holds
    the warning Message, so it is passed on directly.
    """
    # payload.message_id: int
    # payload.user_id: int
    # payload.guild_id: int, or None outside guilds
    # payload.emoji: PartialEmoji
    # payload.member: Member, or None outside guilds
    if payload.guild_id is None or payload.member is None:
        return
    gc = guildIdToGuildContext.get(payload.guild_id)
    if gc is None:
        return  # a context that was never created cannot have posted a warning
    warning = gc.lastWarningMessage
    if warning is None or warning.id != payload.message_id:
        return
    await gc.handleReactionAdd(warning, payload.member, payload.emoji)
-
# Entry point: bot.run does not return until the bot shuts down.
print('Starting bot')
clientToken = config['clientToken']
bot.run(clientToken)
print('Bot done')
|