Experimental Discord bot written in Python
Nelze vybrat více než 25 témat. Téma musí začínat písmenem nebo číslem, může obsahovat pomlčky („-“) a může být dlouhé až 35 znaků.

rocketbot.py 14KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398
  1. from config import config
  2. from datetime import datetime, timedelta
  3. import discord
  4. from discord.ext import commands
  5. from discord.ext.commands.context import Context
  6. import sqlite3
  7. import sys
  8. # -- Database ---------------------------------------------------------------
  def runSQLBatch(batchFunction):
      """Run a batch of SQL work against the bot's SQLite database.

      Opens a connection to ``rocketbot.db``, calls
      ``batchFunction(connection, cursor)``, commits, and closes the
      connection.

      The connection is closed even when ``batchFunction`` (or the commit)
      raises. In that case nothing is committed and the exception propagates
      to the caller.
      """
      dbConnection = sqlite3.connect('rocketbot.db')
      try:
          dbCursor = dbConnection.cursor()
          batchFunction(dbConnection, dbCursor)
          dbConnection.commit()
      finally:
          # Always release the file handle; closing without a commit discards
          # any pending changes.
          dbConnection.close()
  def loadGuildSettings():
      """Populate the in-memory guild contexts from the ``guilds`` table.

      Each stored row is applied to the context returned by
      ``getOrCreateGuildContext`` (called with ``save=False`` so loading does
      not write back to the database). Falsy stored thresholds fall back to
      the values in ``config``.
      """
      def applyStoredRows(connection, cursor):
          for record in cursor.execute("""SELECT * FROM guilds"""):
              guildId = record[0]
              context = getOrCreateGuildContext(guildId, save=False)
              context.warningChannelId = record[1]
              context.warningMentionId = record[2]
              storedCount = record[3]
              context.joinWarningCount = (
                  storedCount if storedCount else config['joinWarningCount'])
              storedSeconds = record[4]
              context.joinWarningSeconds = (
                  storedSeconds if storedSeconds else config['joinWarningSeconds'])
              print('Guild {0} channel id is {1}'.format(
                  guildId, context.warningChannelId))

      runSQLBatch(applyStoredRows)
  def createTables():
      """Create the ``guilds`` table if it does not exist yet.

      The statement uses ``IF NOT EXISTS`` so calling this more than once
      (for example on every start-up) is harmless instead of raising
      ``sqlite3.OperationalError``.
      """
      def makeTables(con, cur):
          cur.execute("""CREATE TABLE IF NOT EXISTS guilds (
              id INTEGER,
              warningChannelId INTEGER,
              warningUserId INTEGER,
              joinWarningCount INTEGER,
              joinWarningSeconds INTEGER,
              PRIMARY KEY(id ASC))""")

      runSQLBatch(makeTables)
  def saveGuildContext(gc):
      """Persist one guild context to the ``guilds`` table.

      Updates the existing row for ``gc.id`` when there is one, otherwise
      inserts a new row. The channel and mention ids are taken from the
      resolved objects when present, else from the stored raw ids.
      """
      def writeRow(connection, cursor):
          print('Saving guild context with id {0}'.format(gc.id))
          cursor.execute('SELECT id FROM guilds WHERE id=?', (gc.id,))

          if gc.warningChannel is None:
              channelId = gc.warningChannelId
          else:
              channelId = gc.warningChannel.id
          if gc.warningMention is None:
              userId = gc.warningMentionId
          else:
              userId = gc.warningMention.id

          rowFound = cursor.fetchone() is not None
          if rowFound:
              print('Record exists')
          else:
              print('Record does not exist')

          if rowFound:
              statement = """UPDATE guilds SET
                  warningChannelId=?,
                  warningUserId=?,
                  joinWarningCount=?,
                  joinWarningSeconds=?
                  WHERE id=?"""
              values = (
                  channelId,
                  userId,
                  gc.joinWarningCount,
                  gc.joinWarningSeconds,
                  gc.id,
              )
          else:
              statement = """INSERT INTO guilds (
                  id,
                  warningChannelId,
                  warningUserId,
                  joinWarningCount,
                  joinWarningSeconds)
                  VALUES (?, ?, ?, ?, ?)"""
              values = (
                  gc.id,
                  channelId,
                  userId,
                  gc.joinWarningCount,
                  gc.joinWarningSeconds,
              )
          cursor.execute(statement, values)

      runSQLBatch(writeRow)
  73. # -- Classes ----------------------------------------------------------------
  class JoinRaid:
      """A group of join records that are treated as one raid."""

      def __init__(self, joins):
          # Keep a private copy so later changes to the caller's sequence do
          # not affect this raid.
          self.joins = joins[:]

      async def addJoin(self, member):
          """Add ``member`` to the raid unless a record for them already exists."""
          alreadyTracked = any(
              record.member.id == member.id for record in self.joins)
          if alreadyTracked:
              # TODO: Move to front of list
              return
          self.joins.append(JoinRecord(member))

      async def kickAll(self):
          """Kick every member that has not already been kicked or banned."""
          for record in self.joins:
              if not (record.isKicked or record.isBanned):
                  await record.member.kick()
                  record.isKicked = True

      async def banAll(self):
          """Ban every member that has not already been banned."""
          for record in self.joins:
              if not record.isBanned:
                  await record.member.ban()
                  record.isBanned = True
  class JoinRecord:
      """One member's join, plus whether the bot has since kicked or banned them.

      ``joinTime`` is a naive datetime. discord.py 1.x documents
      ``Member.joined_at`` as naive UTC, so the fallback used when
      ``joined_at`` is missing is also taken in UTC to keep both cases
      comparable.
      NOTE(review): callers passing ``referenceTime`` should pass naive UTC
      as well - confirm at the call sites.
      """

      def __init__(self, member):
          self.member = member
          # joined_at can be None; fall back to "now" in the same (UTC) frame.
          self.joinTime = member.joined_at or datetime.utcnow()
          self.isKicked = False
          self.isBanned = False

      def ageSeconds(self, referenceTime=None):
          """Return the seconds elapsed between the join and ``referenceTime``.

          ``referenceTime`` defaults to the current UTC time, evaluated on
          every call. (A ``datetime.now()`` default in the signature would be
          evaluated once, when the class is defined, and never advance.)
          """
          if referenceTime is None:
              referenceTime = datetime.utcnow()
          return (referenceTime - self.joinTime).total_seconds()
  class GuildContext:
      """Per-guild state and event handling for join-raid detection.

      Holds the guild's warning channel/mention settings, the list of recent
      ``JoinRecord`` objects, the raid thresholds and the last warning message
      posted, and reacts to joins, reactions and configuration commands.
      """

      def __init__(self, id):
          """Create a context for guild ``id`` with thresholds from ``config``."""
          # NOTE: `id` shadows the builtin; the parameter name is kept so
          # existing callers are unaffected.
          self.id = id
          self.guild = None
          self.warningChannelId = None
          self.warningChannel = None
          self.warningMentionId = None
          self.warningMention = None
          self.joins = []
          self.joinWarningCount = config['joinWarningCount']
          self.joinWarningSeconds = config['joinWarningSeconds']
          self.isJoinRaidInProgress = False
          self.lastWarningMessage = None

      async def handleJoin(self, member):
          """Record a new member join and re-evaluate the raid state."""
          print('{0.guild.name}: {0.name} joined'.format(member))
          self.joins.append(JoinRecord(member))
          await self.__checkJoins()

      async def handleReactionAdd(self, message, member, emoji):
          """Kick or ban the tracked joiners when a moderator reacts to the warning.

          Ignored when the reactor is a bot, lacks the ban permission in the
          message's channel, or reacted to anything other than the last
          warning message. Members already kicked or banned are skipped so a
          second reaction does not try to remove them again.
          """
          if member.bot:
              return
          print('User {0} added emoji {1}'.format(member, emoji))
          if not member.permissions_in(message.channel).ban_members:
              print('Reactor does not have ban permissions')
              return
          if self.lastWarningMessage is None or message.id != self.lastWarningMessage.id:
              print('Reacted to a non-warning message')
              return
          if emoji.name == config['kickEmoji']:
              print('Kicking these users:')
              for join in self.joins:
                  if join.isKicked or join.isBanned:
                      # Already removed from the guild; nothing left to kick.
                      continue
                  await join.member.kick(reason='Kicked by rocketbot for join raiding')
                  join.isKicked = True
                  print(' ' + join.member.name)
          elif emoji.name == config['banEmoji']:
              print('Banning these users:')
              for join in self.joins:
                  if join.isBanned:
                      continue
                  await join.member.ban(reason='Banned by rocketbot for join raiding', delete_message_days=0)
                  join.isBanned = True
                  print(' ' + join.member.name)
          else:
              print('Unhandled emoji. Doing nothing.')
              return
          await self.__updateLastJoinRaidMessage()

      async def handleSetWarningChannel(self, context):
          """Use the channel the command was issued in for warnings, and persist it."""
          print('{0.guild.name}: Warning channel set to {0.channel.name}'.format(context))
          self.warningChannel = context.channel
          self.warningChannelId = context.channel.id
          saveGuildContext(self)

      async def handleSetWarningMention(self, context):
          """Log the first command argument; nothing is stored yet.

          NOTE(review): this handler only prints and never sets
          ``warningMention``/``warningMentionId`` - looks unfinished.
          """
          if len(context.args) < 1:
              return
          roleArg = context.args[0]
          print('{0.guild.name}: {1}'.format(context, roleArg))

      async def handleSetRaidWarningRate(self, context, count, seconds):
          """Set the raid threshold to ``count`` joins per ``seconds`` and persist it."""
          self.joinWarningCount = count
          self.joinWarningSeconds = seconds
          print('{0.name}: Set rate as {1} joins per {2} seconds'.format(self.guild, self.joinWarningCount, self.joinWarningSeconds))
          saveGuildContext(self)

      async def __checkJoins(self):
          """Compare the recent-join count to the threshold and fire raid callbacks."""
          # discord.py 1.x documents Member.joined_at as naive UTC, so ages
          # must be measured against UTC "now"; local time would skew every
          # age by the host's UTC offset.
          now = datetime.utcnow()
          recentJoinCount = self.__countRecentJoins(now)
          print('{0} join(s) in the past {1} seconds'.format(recentJoinCount, self.joinWarningSeconds))
          if recentJoinCount >= self.joinWarningCount:
              # In join raid
              if not self.isJoinRaidInProgress:
                  # Raid just started
                  self.isJoinRaidInProgress = True
                  await self.__onJoinRaidBegin()
              else:
                  # Continuing existing join raid
                  await self.__onJoinRaidUpdated()
          else:
              # No join raid
              if self.isJoinRaidInProgress:
                  # Join raid just ended
                  self.isJoinRaidInProgress = False
                  await self.__onJoinRaidEnd()
          self.__cullOldJoins(now)

      def __countRecentJoins(self, now=None):
          """Return how many tracked joins are at most ``joinWarningSeconds`` old.

          ``now`` defaults to the current UTC time, evaluated per call.
          """
          if now is None:
              now = datetime.utcnow()
          return sum(
              1 for joinRecord in self.joins
              if joinRecord.ageSeconds(now) <= self.joinWarningSeconds)

      def __cullOldJoins(self, now=None):
          """Drop, in place, every tracked join older than ``joinWarningSeconds``.

          ``now`` defaults to the current UTC time, evaluated per call.
          """
          if now is None:
              now = datetime.utcnow()
          # Slice assignment keeps the same list object for anyone holding it.
          self.joins[:] = [
              joinRecord for joinRecord in self.joins
              if joinRecord.ageSeconds(now) <= self.joinWarningSeconds]

      def __joinRaidMessage(self):
          """Build the warning text.

          Returns a ``(message, canKick, canBan)`` tuple; the flags say
          whether any listed member can still be kicked or banned.
          """
          # TODO: Mention mod role
          parts = ['A join raid has been detected! It includes these users:']
          canKick = False
          canBan = False
          for join in self.joins:
              parts.append('\n• ')
              if join.isBanned:
                  parts.append('~~' + join.member.mention + '~~ - banned')
              elif join.isKicked:
                  parts.append('~~' + join.member.mention + '~~ - kicked')
                  canBan = True
              else:
                  parts.append(join.member.mention)
                  canKick = True
                  canBan = True
          if canKick:
              parts.append('\nTo kick all these users, react with :' + config['kickEmojiName'] + ':')
          if canBan:
              parts.append('\nTo ban all these users, react with :' + config['banEmojiName'] + ':')
          return (''.join(parts), canKick, canBan)

      async def __updateLastJoinRaidMessage(self):
          """Rewrite the last warning message with the current member statuses."""
          if self.lastWarningMessage is None:
              print('No previous warning message to update')
              return
          (message, canKick, canBan) = self.__joinRaidMessage()
          await self.lastWarningMessage.edit(content=message)

      async def __onJoinRaidBegin(self):
          """Post the warning message and add the kick/ban reactions to it."""
          print('A join raid has begun!')
          if self.warningChannel is None:
              print('No warning channel set')
              return
          (message, canKick, canBan) = self.__joinRaidMessage()
          self.lastWarningMessage = await self.warningChannel.send(message)
          if canKick:
              await self.lastWarningMessage.add_reaction(config['kickEmoji'])
          if canBan:
              await self.lastWarningMessage.add_reaction(config['banEmoji'])

      async def __onJoinRaidUpdated(self):
          """Refresh the warning message while the raid continues."""
          print('Join raid still occurring')
          await self.__updateLastJoinRaidMessage()

      async def __onJoinRaidEnd(self):
          """Log that the raid has ended; no other action is taken."""
          print('Join raid has ended')
  # Cache of GuildContext objects keyed by guild id.
  guildIdToGuildContext = {}

  def getOrCreateGuildContext(obj, save=True):
      """Return the GuildContext that ``obj`` belongs to, creating it if needed.

      ``obj`` may be a guild id (int), a ``GuildContext`` (returned as is), a
      ``discord.Guild``, or a message / member / text channel / command
      context that carries a ``.guild``.

      Returns ``None`` when ``obj`` is ``None``, is of an unsupported type, or
      is not attached to a guild (e.g. a direct message). A newly created
      context is written to the database unless ``save`` is false.
      """
      gid = None
      guild = None
      if obj is None:
          return None
      if isinstance(obj, int):
          gid = obj
      elif isinstance(obj, GuildContext):
          return obj
      elif isinstance(obj, discord.Guild):
          gid = obj.id
          guild = obj
      elif isinstance(obj, (discord.Message, discord.Member, discord.TextChannel, discord.ext.commands.context.Context)):
          guild = obj.guild
          if guild is None:
              # Direct messages have no guild; avoid AttributeError on .id.
              print('Object is not attached to a guild', type(obj))
              return None
          gid = guild.id
      if gid is None:
          print('Unhandled datatype', type(obj))
          return None
      lookedUp = guildIdToGuildContext.get(gid)
      if lookedUp is not None:
          # Contexts created from a bare id have no guild object yet; fill it
          # in the first time one becomes available.
          if lookedUp.guild is None and guild is not None:
              lookedUp.guild = guild
          return lookedUp
      g = GuildContext(gid)
      g.guild = guild or g.guild
      guildIdToGuildContext[gid] = g
      if save:
          saveGuildContext(g)
      return g
  # Load persisted guild settings. On a fresh database the `guilds` table does
  # not exist yet and the SELECT raises sqlite3.OperationalError; create the
  # schema in that case instead of crashing at start-up. If the error had
  # another cause, createTables() raises as well and the failure still surfaces.
  try:
      loadGuildSettings()
  except sqlite3.OperationalError:
      createTables()

  intents = discord.Intents.default()
  intents.members = True # To get join/leave events
  bot = commands.Bot(command_prefix='$', intents=intents)
  275. # -- Bot commands -----------------------------------------------------------
  @commands.command()
  async def hello(ctx):
      """Reply to the invoking user with a greeting in the same channel."""
      incoming = ctx.message
      print('Got message from {0.author.name} in {0.channel.id}'.format(incoming))
      greeting = 'Hello, {0.author.mention}!'.format(incoming)
      await incoming.channel.send(greeting)
      print('Replied "Hello!"')

  bot.add_command(hello)
  @commands.command()
  @commands.has_permissions(manage_messages=True)
  async def setraidwarningrate(ctx, count: int, seconds: int):
      """Set the join-raid threshold to ``count`` joins within ``seconds`` seconds.

      Both values must be at least 1. A zero would be stored but replaced by
      the configured default the next time settings are loaded, and a
      non-positive count would flag every single join as a raid.
      """
      if count < 1 or seconds < 1:
          await ctx.send('Both count and seconds must be at least 1.')
          return
      g = getOrCreateGuildContext(ctx)
      if g is None:
          return
      await g.handleSetRaidWarningRate(ctx, count, seconds)

  bot.add_command(setraidwarningrate)
  @commands.command()
  @commands.has_permissions(manage_messages=True)
  async def setwarningchannel(ctx):
      """Make the channel this command was sent in the guild's warning channel."""
      guildContext = getOrCreateGuildContext(ctx)
      if guildContext is not None:
          await guildContext.handleSetWarningChannel(ctx)

  bot.add_command(setwarningchannel)
  @commands.command()
  @commands.has_permissions(manage_messages=True)
  async def setwarningrole(ctx, role: str):
      """Forward the command to the guild context's warning-mention handler.

      ``role`` is required by the command signature but is not passed on; the
      handler receives only ``ctx``.
      """
      guildContext = getOrCreateGuildContext(ctx)
      if guildContext is not None:
          await guildContext.handleSetWarningMention(ctx)

  bot.add_command(setwarningrole)
  307. # -- Bot events -------------------------------------------------------------
  def _restoreGuildReferences():
      """Resolve stored channel/role ids into live objects for every known guild."""
      for guild in bot.guilds:
          g = guildIdToGuildContext.get(guild.id)
          if g is None:
              print('No record for', guild.id)
              continue
          g.guild = guild
          if g.warningChannelId is not None:
              g.warningChannel = guild.get_channel(g.warningChannelId)
              print('Recovered warning channel', g.warningChannel)
          if g.warningMentionId is not None:
              g.warningMention = guild.get_role(g.warningMentionId)
              print('Recovered warning mention', g.warningMention)

  @bot.listen()
  async def on_connect():
      # Intentionally empty: per the discord.py docs the client is not fully
      # prepared at this point, so channel/role lookups are deferred to
      # on_ready.
      pass

  @bot.listen()
  async def on_ready():
      # Guild data is prepared by now, so get_channel/get_role can resolve the
      # stored ids. on_ready may fire more than once; the restore is
      # idempotent.
      _restoreGuildReferences()
  @bot.listen()
  async def on_member_join(member):
      """Log the join and hand it to the member's guild context."""
      print('User {0.name} joined'.format(member))
      guildContext = getOrCreateGuildContext(member)
      if guildContext is not None:
          await guildContext.handleJoin(member)
          return
      print('No GuildContext for guild {0.guild.name}'.format(member))
  @bot.listen()
  async def on_member_remove(member):
      """Log that a member left; no state is updated."""
      departedName = member.name
      print('User {0} left'.format(departedName))
  @bot.listen()
  async def on_raw_reaction_add(payload: discord.RawReactionActionEvent):
      """Forward a guild reaction to the owning guild context.

      Payload fields used: ``guild_id``, ``channel_id``, ``message_id``,
      ``member`` and ``emoji``. Reactions outside a guild, or on a guild or
      channel the client cannot resolve, are ignored instead of raising
      ``AttributeError`` on ``None``.
      """
      if payload.guild_id is None or payload.member is None:
          # Reaction in a direct message: no guild and no member to check.
          return
      guild = bot.get_guild(payload.guild_id)
      if guild is None:
          return
      channel = guild.get_channel(payload.channel_id)
      if channel is None:
          return
      # NOTE(review): this fetches the message for every reaction in every
      # channel, although only reactions to the warning message matter.
      message = await channel.fetch_message(payload.message_id)
      gc = getOrCreateGuildContext(guild)
      if gc is None:
          return
      await gc.handleReactionAdd(message, payload.member, payload.emoji)
  # Start the bot with the token from config and log when bot.run() returns.
  print('Starting bot')
  _clientToken = config['clientToken']
  bot.run(_clientToken)
  print('Bot done')