| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296 |
import sqlite3
import sys
from datetime import datetime, timedelta, timezone

import discord
from discord.ext import commands
from discord.ext.commands.context import Context

from config import config
-
- # -- Database ---------------------------------------------------------------
-
def runSQLBatch(batchFunction, dbPath='rocketbot.db'):
    """Run *batchFunction* on a fresh SQLite connection, then commit and close.

    Args:
        batchFunction: callable invoked as ``batchFunction(connection, cursor)``.
        dbPath: database file to open (defaults to the bot's database).

    The transaction is committed only if *batchFunction* returns normally.
    The connection is always closed, so if *batchFunction* raises, its
    uncommitted changes are discarded and the exception propagates.
    """
    dbConnection = sqlite3.connect(dbPath)
    try:
        dbCursor = dbConnection.cursor()
        batchFunction(dbConnection, dbCursor)
        dbConnection.commit()
    finally:
        # close() does not commit; on error this drops the partial batch.
        dbConnection.close()
-
def loadGuildSettings():
    """Populate in-memory GuildContexts from the rows of the ``guilds`` table.

    For each stored row, fetch (without re-saving) the GuildContext for that
    guild id and copy the persisted warning channel id, warning mention id
    and join-rate settings onto it. A NULL/0 join-rate value falls back to
    the corresponding ``config`` default.

    NOTE(review): raises sqlite3.OperationalError if the guilds table does
    not exist yet; the caller must make sure it has been created
    (see createTables).
    """
    def load(con, cur):
        # Name the columns explicitly so unpacking does not depend on the
        # table's physical column order.
        rows = cur.execute("""SELECT id, warningChannelId, warningUserId,
            joinWarningCount, joinWarningSeconds FROM guilds""")
        for guildId, channelId, mentionId, warnCount, warnSeconds in rows:
            g = getOrCreateGuildContext(guildId, save=False)
            g.warningChannelId = channelId
            # The column is called warningUserId, but saveGuildContext
            # stores the GuildContext's warning mention id in it.
            g.warningMentionId = mentionId
            g.joinWarningCount = warnCount or config['joinWarningCount']
            g.joinWarningSeconds = warnSeconds or config['joinWarningSeconds']
            print('Guild {0} channel id is {1}'.format(guildId, g.warningChannelId))
    runSQLBatch(load)
-
def createTables():
    """Create the ``guilds`` table that persists per-guild settings.

    Idempotent: if the table already exists it is left untouched, so this
    is safe to call on every startup.
    """
    def makeTables(con, cur):
        # warningUserId holds the GuildContext's warning mention id.
        cur.execute("""CREATE TABLE IF NOT EXISTS guilds (
            id INTEGER,
            warningChannelId INTEGER,
            warningUserId INTEGER,
            joinWarningCount INTEGER,
            joinWarningSeconds INTEGER,
            PRIMARY KEY(id ASC))""")
    runSQLBatch(makeTables)
-
def saveGuildContext(gc):
    """Write *gc*'s settings to its row in the ``guilds`` table.

    The row matching ``gc.id`` is updated if present; otherwise a new row is
    inserted. For the warning channel and mention, the id of the resolved
    discord object is stored when one is set, else the raw id held on *gc*.
    """
    def save(con, cur):
        print('Saving guild context with id {0}'.format(gc.id))
        cur.execute('SELECT id FROM guilds WHERE id=?', (gc.id,))
        # Resolved discord objects take precedence over raw stored ids.
        if gc.warningChannel is None:
            channelId = gc.warningChannelId
        else:
            channelId = gc.warningChannel.id
        if gc.warningMention is None:
            userId = gc.warningMentionId
        else:
            userId = gc.warningMention.id
        rowExists = cur.fetchone() is not None
        print('Record exists' if rowExists else 'Record does not exist')
        if not rowExists:
            insertValues = (
                gc.id,
                channelId,
                userId,
                gc.joinWarningCount,
                gc.joinWarningSeconds,
            )
            cur.execute("""INSERT INTO guilds (
                id,
                warningChannelId,
                warningUserId,
                joinWarningCount,
                joinWarningSeconds)
                VALUES (?, ?, ?, ?, ?)""", insertValues)
            return
        updateValues = (
            channelId,
            userId,
            gc.joinWarningCount,
            gc.joinWarningSeconds,
            gc.id,
        )
        cur.execute("""UPDATE guilds SET
            warningChannelId=?,
            warningUserId=?,
            joinWarningCount=?,
            joinWarningSeconds=?
            WHERE id=?""", updateValues)
    runSQLBatch(save)
-
- # -- Classes ----------------------------------------------------------------
-
class JoinRecord:
    """One member join, remembered for join-raid detection.

    Attributes:
        member: the member that joined.
        joinTime: timezone-aware UTC datetime of the join.
        isKicked / isBanned: moderation flags, initially False.
    """

    def __init__(self, member):
        self.member = member
        joinTime = member.joined_at
        if joinTime is None:
            joinTime = datetime.now(timezone.utc)
        elif joinTime.tzinfo is None:
            # discord.py documents joined_at as UTC; older releases return
            # it naive, newer ones aware. Normalize to aware UTC so ages
            # compare correctly regardless of the server's local timezone.
            joinTime = joinTime.replace(tzinfo=timezone.utc)
        self.joinTime = joinTime
        self.isKicked = False
        self.isBanned = False

    def ageSeconds(self, referenceTime=None):
        """Return the seconds elapsed from the join to *referenceTime*.

        referenceTime defaults to the current time, evaluated on each call
        (the old ``datetime.now()`` default was frozen at import time).
        A naive referenceTime is interpreted as local time, which is what
        ``datetime.now()`` returns.
        """
        if referenceTime is None:
            referenceTime = datetime.now(timezone.utc)
        elif referenceTime.tzinfo is None:
            referenceTime = referenceTime.astimezone()
        return (referenceTime - self.joinTime).total_seconds()
-
class GuildContext:
    """Per-guild settings and join-raid detection state.

    Holds one guild's persisted settings (warning channel / mention ids and
    the join-rate threshold), the discord objects those ids resolve to once
    known, and the recent joins used to detect a join raid.
    """

    # Discord rejects message content longer than this many characters, so
    # raid warnings are truncated to fit.
    MAX_MESSAGE_LENGTH = 2000

    def __init__(self, id):
        # `id` is the guild id; join-rate defaults come from config.
        self.id = id
        self.guild = None  # discord.Guild once resolved
        self.warningChannelId = None
        self.warningChannel = None
        self.warningMentionId = None
        self.warningMention = None
        self.joins = []  # JoinRecords still inside the detection window
        self.joinWarningCount = config['joinWarningCount']
        self.joinWarningSeconds = config['joinWarningSeconds']
        self.isJoinRaidInProgress = False
        self.lastWarningMessage = None  # message returned by the last warning send

    async def handleJoin(self, member):
        """Record *member*'s join and re-evaluate whether a raid is under way."""
        print('{0.guild.name}: {0.name} joined'.format(member))
        if self.guild is None:
            # Contexts created by getOrCreateGuildContext start without a guild.
            self.guild = member.guild
        self.joins.append(JoinRecord(member))
        await self.__checkJoins()

    async def handleSetWarningChannel(self, context):
        """Make the channel the command was issued in the raid warning channel."""
        print('{0.guild.name}: Warning channel set to {0.channel.name}'.format(context))
        if self.guild is None:
            self.guild = context.guild
        self.warningChannel = context.channel
        self.warningChannelId = context.channel.id
        saveGuildContext(self)

    async def handleSetWarningMention(self, context):
        """Set the warning mention to the first role mentioned in the command message.

        e.g. ``$setwarningrole @Moderators``. Does nothing if the message
        mentions no role. (``context.args`` cannot be used for this: its
        first element is the Context itself.)
        """
        roles = context.message.role_mentions
        if not roles:
            print('{0.guild.name}: No role mentioned'.format(context))
            return
        role = roles[0]
        if self.guild is None:
            self.guild = context.guild
        self.warningMention = role
        self.warningMentionId = role.id
        print('{0.guild.name}: Warning mention set to {1.name}'.format(context, role))
        saveGuildContext(self)

    async def handleSetRaidWarningRate(self, context, count, seconds):
        """Treat *count* or more joins within *seconds* seconds as a join raid."""
        if self.guild is None:
            # Without this, self.guild.name below crashed for guilds that had
            # no stored record at startup.
            self.guild = context.guild
        self.joinWarningCount = count
        self.joinWarningSeconds = seconds
        print('{0.name}: Set rate as {1} joins per {2} seconds'.format(context.guild, self.joinWarningCount, self.joinWarningSeconds))
        saveGuildContext(self)

    async def __checkJoins(self):
        """Update raid state from the joins inside the current window.

        A raid is only considered over when a later join finds fewer than
        joinWarningCount joins in the window; there is no timer.
        """
        # NOTE(review): naive local time; JoinRecord.ageSeconds must interpret
        # it consistently with member.joined_at (UTC) — confirm.
        now = datetime.now()
        # Drop expired joins first so the count and the warning message both
        # cover exactly the joins inside the window.
        self.__cullOldJoins(now)
        recentJoinCount = self.__countRecentJoins(now)
        print('{0} join(s) in the past {1} seconds'.format(recentJoinCount, self.joinWarningSeconds))
        if recentJoinCount >= self.joinWarningCount:
            if not self.isJoinRaidInProgress:
                # Raid just started
                self.isJoinRaidInProgress = True
                await self.__onJoinRaidBegin()
            else:
                # Continuing existing join raid
                await self.__onJoinRaidUpdated()
        elif self.isJoinRaidInProgress:
            # Join raid just ended
            self.isJoinRaidInProgress = False
            await self.__onJoinRaidEnd()

    def __countRecentJoins(self, now=None):
        """Return how many joins are at most joinWarningSeconds old at *now*."""
        if now is None:
            now = datetime.now()
        return sum(1 for joinRecord in self.joins
                   if joinRecord.ageSeconds(now) <= self.joinWarningSeconds)

    def __cullOldJoins(self, now=None):
        """Remove joins older than joinWarningSeconds at *now* (in place)."""
        if now is None:
            now = datetime.now()
        # Slice assignment keeps the same list object for any other holders.
        self.joins[:] = [joinRecord for joinRecord in self.joins
                         if joinRecord.ageSeconds(now) <= self.joinWarningSeconds]

    def __buildWarningMessage(self):
        """Return the raid warning text listing the joined members.

        If listing every member would exceed MAX_MESSAGE_LENGTH, list as many
        as fit and end with a count of the ones left out.
        """
        header = 'A join raid has been detected! It includes these users:'
        lines = ['\n• ' + join.member.mention for join in self.joins]
        message = header + ''.join(lines)
        if len(message) <= self.MAX_MESSAGE_LENGTH:
            return message
        parts = [header]
        length = len(header)
        for index, line in enumerate(lines):
            # Keep room for the note that would follow if we stop after this line.
            trailerAfter = '\n…and {0} more'.format(len(lines) - index - 1)
            if length + len(line) + len(trailerAfter) > self.MAX_MESSAGE_LENGTH:
                parts.append('\n…and {0} more'.format(len(lines) - index))
                break
            parts.append(line)
            length += len(line)
        return ''.join(parts)

    async def __onJoinRaidBegin(self):
        """Post a new raid warning to the warning channel, if one is set."""
        print('A join raid has begun!')
        if self.warningChannel is None:
            print('No warning channel set')
            return
        # TODO: Mention mod role
        self.lastWarningMessage = await self.warningChannel.send(self.__buildWarningMessage())

    async def __onJoinRaidUpdated(self):
        """Edit the current raid warning to list the latest joins."""
        print('Join raid still occurring')
        if self.lastWarningMessage is None:
            return
        await self.lastWarningMessage.edit(content=self.__buildWarningMessage())

    async def __onJoinRaidEnd(self):
        """Forget the finished raid's warning so a later raid never edits it."""
        print('Join raid has ended')
        self.lastWarningMessage = None
-
# Every GuildContext created so far, keyed by guild id.
guildIdToGuildContext = dict()
-
def getOrCreateGuildContext(obj, save=True):
    """Return the GuildContext for the guild identified by *obj*, creating it if needed.

    *obj* may be a guild id (int), a GuildContext (returned unchanged), a
    discord.Guild, or a Message / Member / TextChannel / command Context that
    belongs to a guild.

    Returns None for None, for a Message or Context without a guild (e.g. a
    direct message), and for unsupported types. A newly created context is
    cached in guildIdToGuildContext and, when *save* is true, written to the
    database.
    """
    if obj is None:
        return None
    if isinstance(obj, GuildContext):
        return obj
    if isinstance(obj, int):
        gid = obj
    elif isinstance(obj, discord.Guild):
        gid = obj.id
    elif isinstance(obj, (discord.Message, discord.Member, discord.TextChannel, Context)):
        if obj.guild is None:
            # Direct-message messages/contexts have no guild; previously this
            # raised AttributeError on `.guild.id`.
            print('No guild for', type(obj))
            return None
        gid = obj.guild.id
    else:
        print('Unhandled datatype', type(obj))
        return None
    existing = guildIdToGuildContext.get(gid)
    if existing is not None:
        return existing
    g = GuildContext(gid)
    guildIdToGuildContext[gid] = g
    if save:
        saveGuildContext(g)
    return g
-
# Load persisted guild settings. On a fresh database the guilds table does
# not exist yet, so create it instead of crashing at startup.
try:
    loadGuildSettings()
except sqlite3.OperationalError as err:
    if 'no such table' not in str(err):
        raise
    print('No guilds table found; creating tables')
    createTables()

intents = discord.Intents.default()
intents.members = True  # To get join/leave events
bot = commands.Bot(command_prefix='$', intents=intents)
-
- # -- Bot commands -----------------------------------------------------------
-
# Greet the invoking user in the channel the command came from.
@commands.command()
async def hello(ctx):
    msg = ctx.message
    print('Got message from {0.author.name} in {0.channel.id}'.format(msg))
    reply = 'Hello, {0.author.mention}!'.format(msg)
    await msg.channel.send(reply)
    print('Replied "Hello!"')


bot.add_command(hello)
-
# Set the join-raid threshold (count joins within seconds); delegates to
# GuildContext.handleSetRaidWarningRate.
@commands.command()
@commands.has_permissions(manage_messages=True)
async def setraidwarningrate(ctx, count: int, seconds: int):
    guildContext = getOrCreateGuildContext(ctx)
    if guildContext is not None:
        await guildContext.handleSetRaidWarningRate(ctx, count, seconds)


bot.add_command(setraidwarningrate)
-
# Use the current channel for raid warnings; delegates to
# GuildContext.handleSetWarningChannel.
@commands.command()
@commands.has_permissions(manage_messages=True)
async def setwarningchannel(ctx):
    guildContext = getOrCreateGuildContext(ctx)
    if guildContext is not None:
        await guildContext.handleSetWarningChannel(ctx)


bot.add_command(setwarningchannel)
-
# Choose what raid warnings should mention; delegates to
# GuildContext.handleSetWarningMention.
@commands.command()
@commands.has_permissions(manage_messages=True)
async def setwarningrole(ctx):
    guildContext = getOrCreateGuildContext(ctx)
    if guildContext is not None:
        await guildContext.handleSetWarningMention(ctx)


bot.add_command(setwarningrole)
-
- # -- Bot events -------------------------------------------------------------
-
@bot.listen('on_ready')
@bot.listen('on_connect')
async def on_connect():
    """Attach each known guild and resolve its stored channel/role ids.

    Also registered for on_ready: discord.py fires on_connect before the
    client is fully prepared, so bot.guilds may still be empty at that point
    and nothing would be recovered. Re-running is harmless; the same
    objects are looked up again.
    """
    for guild in bot.guilds:
        g = guildIdToGuildContext.get(guild.id)
        if g is None:
            print('No record for', guild.id)
            continue
        g.guild = guild
        if g.warningChannelId is not None:
            g.warningChannel = guild.get_channel(g.warningChannelId)
            print('Recovered warning channel', g.warningChannel)
        if g.warningMentionId is not None:
            g.warningMention = guild.get_role(g.warningMentionId)
            print('Recovered warning mention', g.warningMention)
-
@bot.listen()
async def on_ready():
    """Ready-event listener; intentionally does nothing."""
-
@bot.listen()
async def on_member_join(member):
    """Pass each new member to its guild's join-raid detector."""
    print('User {0.name} joined'.format(member))
    guildContext = getOrCreateGuildContext(member)
    if guildContext is not None:
        await guildContext.handleJoin(member)
    else:
        print('No GuildContext for guild {0.guild.name}'.format(member))
-
@bot.listen()
async def on_member_remove(member):
    """Log that a member left the guild."""
    note = 'User {0.name} left'.format(member)
    print(note)
-
print('Starting bot')
token = config['clientToken']
# run() blocks; 'Bot done' is printed only after it returns.
bot.run(token)
print('Bot done')
|