Experimental Discord bot written in Python
Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

joinraidcog.py 19KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523
from datetime import datetime

from discord import Guild, Intents, Member, Message, NotFound, PartialEmoji, RawReactionActionEvent
from discord.abc import GuildChannel
from discord.ext import commands

from storage import Storage
from cogs.basecog import BaseCog
from config import CONFIG
  7. class JoinRecord:
  8. """
  9. Data object containing details about a single guild join event.
  10. """
  11. def __init__(self, member: Member):
  12. self.member = member
  13. self.join_time = member.joined_at or datetime.now()
  14. # These flags only track whether this bot has kicked/banned
  15. self.is_kicked = False
  16. self.is_banned = False
  17. def age_seconds(self, now: datetime) -> float:
  18. """
  19. Returns the age of this join in seconds from the given "now" time.
  20. """
  21. a = now - self.join_time
  22. return float(a.total_seconds())
  23. class RaidPhase:
  24. """
  25. Enum of phases in a JoinRaidRecord. Phases progress monotonically.
  26. """
  27. NONE = 0
  28. JUST_STARTED = 1
  29. CONTINUING = 2
  30. ENDED = 3
  31. class JoinRaidRecord:
  32. """
  33. Tracks recent joins to a guild to detect join raids, where a large number
  34. of automated users all join at the same time. Manages list of joins to not
  35. grow unbounded.
  36. """
  37. def __init__(self):
  38. self.joins = []
  39. self.phase = RaidPhase.NONE
  40. # datetime when the raid started, or None.
  41. self.raid_start_time = None
  42. # Message posted to Discord to warn of the raid. Convenience property
  43. # managed by caller. Ignored by this class.
  44. self.warning_message = None
  45. def handle_join(self,
  46. member: Member,
  47. now: datetime,
  48. max_join_count: int,
  49. max_age_seconds: float) -> None:
  50. """
  51. Processes a new member join to a guild and detects join raids. Updates
  52. self.phase and self.raid_start_time properties.
  53. """
  54. # Check for existing record for this user
  55. print(f'handle_join({member.name}) start')
  56. join: JoinRecord = None
  57. i: int = 0
  58. while i < len(self.joins):
  59. elem = self.joins[i]
  60. if elem.member.id == member.id:
  61. print(f'Member {member.name} already in join list at index {i}. Removing.')
  62. join = self.joins.pop(i)
  63. join.join_time = now
  64. break
  65. i += 1
  66. # Add new record to end
  67. self.joins.append(join or JoinRecord(member))
  68. # Check raid status and do upkeep
  69. self.__process_joins(now, max_age_seconds, max_join_count)
  70. print(f'handle_join({member.name}) end')
  71. def __process_joins(self,
  72. now: datetime,
  73. max_age_seconds: float,
  74. max_join_count: int) -> None:
  75. """
  76. Processes self.joins after each addition, detects raids, updates
  77. self.phase, and throws out unneeded records.
  78. """
  79. print('__process_joins {')
  80. i: int = 0
  81. recent_count: int = 0
  82. should_cull: bool = self.phase == RaidPhase.NONE
  83. while i < len(self.joins):
  84. join: JoinRecord = self.joins[i]
  85. age: float = join.age_seconds(now)
  86. is_old: bool = age > max_age_seconds
  87. if not is_old:
  88. recent_count += 1
  89. print(f'- {i}. {join.member.name} is {age}s old - recent_count={recent_count}')
  90. if is_old and should_cull:
  91. self.joins.pop(i)
  92. print(f'- {i}. {join.member.name} is {age}s old - too old, removing')
  93. else:
  94. print(f'- {i}. {join.member.name} is {age}s old - moving on to next')
  95. i += 1
  96. is_raid = recent_count > max_join_count
  97. print(f'- is_raid {is_raid}')
  98. if is_raid:
  99. if self.phase == RaidPhase.NONE:
  100. self.phase = RaidPhase.JUST_STARTED
  101. self.raid_start_time = now
  102. print('- Phase moved to JUST_STARTED. Recording raid start time.')
  103. elif self.phase == RaidPhase.JUST_STARTED:
  104. self.phase = RaidPhase.CONTINUING
  105. print('- Phase moved to CONTINUING.')
  106. elif self.phase == self.phase in (RaidPhase.JUST_STARTED, RaidPhase.CONTINUING):
  107. self.phase = RaidPhase.ENDED
  108. print('- Phase moved to ENDED.')
  109. # Undo join add if the raid is over
  110. if self.phase == RaidPhase.ENDED and len(self.joins) > 0:
  111. last = self.joins.pop(-1)
  112. print(f'- Popping last join for {last.member.name}')
  113. print('} __process_joins')
  114. async def kick_all(self,
  115. reason: str = "Part of join raid") -> list[Member]:
  116. """
  117. Kicks all users in this join raid. Skips users who have already been
  118. flagged as having been kicked or banned. Returns a List of Members
  119. who were newly kicked.
  120. """
  121. kicks = []
  122. for join in self.joins:
  123. if join.is_kicked or join.is_banned:
  124. continue
  125. await join.member.kick(reason=reason)
  126. join.is_kicked = True
  127. kicks.append(join.member)
  128. self.phase = RaidPhase.ENDED
  129. return kicks
  130. async def ban_all(self,
  131. reason: str = "Part of join raid",
  132. delete_message_days: int = 0) -> list[Member]:
  133. """
  134. Bans all users in this join raid. Skips users who have already been
  135. flagged as having been banned. Users who were previously kicked can
  136. still be banned. Returns a List of Members who were newly banned.
  137. """
  138. bans = []
  139. for join in self.joins:
  140. if join.is_banned:
  141. continue
  142. await join.member.ban(
  143. reason=reason,
  144. delete_message_days=delete_message_days)
  145. join.is_banned = True
  146. bans.append(join.member)
  147. self.phase = RaidPhase.ENDED
  148. return bans
  149. class GuildContext:
  150. """
  151. Logic and state for a single guild serviced by the bot.
  152. """
  153. def __init__(self, guild_id: int):
  154. self.guild_id = guild_id
  155. self.join_warning_count = CONFIG['joinWarningCount']
  156. self.join_warning_seconds = CONFIG['joinWarningSeconds']
  157. # Non-persisted runtime state
  158. self.current_raid = JoinRaidRecord()
  159. self.all_raids = [ self.current_raid ] # periodically culled of old ones
  160. # Events
  161. async def handle_join(self, member: Member) -> None:
  162. """
  163. Event handler for all joins to this guild.
  164. """
  165. now = member.joined_at
  166. raid = self.current_raid
  167. raid.handle_join(
  168. member,
  169. now=now,
  170. max_age_seconds = self.join_warning_seconds,
  171. max_join_count = self.join_warning_count)
  172. self.__trace(f'raid phase: {raid.phase}')
  173. if raid.phase == RaidPhase.JUST_STARTED:
  174. await self.__on_join_raid_begin(raid)
  175. elif raid.phase == RaidPhase.CONTINUING:
  176. await self.__on_join_raid_updated(raid)
  177. elif raid.phase == RaidPhase.ENDED:
  178. self.__start_new_raid(member)
  179. await self.__on_join_raid_end(raid)
  180. self.__cull_old_raids(now)
  181. def reset_raid(self, now: datetime):
  182. """
  183. Retires self.current_raid and creates a new empty one.
  184. """
  185. self.current_raid = JoinRaidRecord()
  186. self.all_raids.append(self.current_raid)
  187. self.__cull_old_raids(now)
  188. def find_raid_for_message_id(self, message_id: int) -> JoinRaidRecord:
  189. """
  190. Retrieves a JoinRaidRecord instance for the given raid warning message.
  191. Returns None if not found.
  192. """
  193. for raid in self.all_raids:
  194. if raid.warning_message.id == message_id:
  195. return raid
  196. return None
  197. def __cull_old_raids(self, now: datetime):
  198. """
  199. Gets rid of old JoinRaidRecord records from self.all_raids that are too
  200. old to still be useful.
  201. """
  202. i: int = 0
  203. while i < len(self.all_raids):
  204. raid = self.all_raids[i]
  205. if raid == self.current_raid:
  206. i += 1
  207. continue
  208. age_seconds = float((raid.raid_start_time - now).total_seconds())
  209. if age_seconds > 86400.0:
  210. self.__trace('Culling old raid')
  211. self.all_raids.pop(i)
  212. else:
  213. i += 1
  214. def __trace(self, message):
  215. """
  216. Debugging trace.
  217. """
  218. print(f'{self.guild_id}: {message}')
  219. class JoinRaidCog(BaseCog):
  220. """
  221. Cog for monitoring member joins and detecting potential bot raids.
  222. """
  223. MIN_JOIN_COUNT = 2
  224. STATE_KEY_RAID_COUNT = 'joinraid_count'
  225. STATE_KEY_RAID_SECONDS = 'joinraid_seconds'
  226. STATE_KEY_ENABLED = 'joinraid_enabled'
  227. def __init__(self, bot):
  228. self.bot = bot
  229. self.guild_id_to_context = {} # Guild.id -> GuildContext
  230. # -- Config -------------------------------------------------------------
  231. def __get_raid_rate(self, guild: Guild) -> tuple:
  232. """
  233. Returns the join rate configured for this guild.
  234. """
  235. count: int = Storage.get_state_value(guild, self.STATE_KEY_RAID_COUNT) \
  236. or CONFIG['joinWarningCount']
  237. seconds: float = Storage.get_state_value(guild, self.STATE_KEY_RAID_SECONDS) \
  238. or CONFIG['joinWarningSeconds']
  239. return (count, seconds)
  240. def __is_enabled(self, guild: Guild) -> bool:
  241. """
  242. Returns whether join raid detection is enabled in this guild.
  243. """
  244. return Storage.get_state_value(guild, self.STATE_KEY_ENABLED) or False
  245. # -- Commands -----------------------------------------------------------
  246. @commands.group(
  247. brief='Manages join raid detection and handling',
  248. )
  249. @commands.has_permissions(ban_members=True)
  250. @commands.guild_only()
  251. async def joinraid(self, context: commands.Context):
  252. 'Command group'
  253. if context.invoked_subcommand is None:
  254. await context.send_help()
  255. @joinraid.command(
  256. name='enable',
  257. brief='Enables join raid detection',
  258. description='Join raid detection is off by default.',
  259. )
  260. async def joinraid_enable(self, context: commands.Context):
  261. 'Command handler'
  262. guild = context.guild
  263. Storage.set_state_value(guild, self.STATE_KEY_ENABLED, True)
  264. # TODO: Startup tracking if necessary
  265. await context.message.reply(
  266. '✅ ' + self.__describe_raid_settings(guild, force_enabled_status=True),
  267. mention_author=False)
  268. @joinraid.command(
  269. name='disable',
  270. brief='Disables join raid detection',
  271. description='Join raid detection is off by default.',
  272. )
  273. async def joinraid_disable(self, context: commands.Context):
  274. 'Command handler'
  275. guild = context.guild
  276. Storage.set_state_value(guild, self.STATE_KEY_ENABLED, False)
  277. # TODO: Tear down tracking if necessary
  278. await context.message.reply(
  279. '✅ ' + self.__describe_raid_settings(guild, force_enabled_status=True),
  280. mention_author=False)
  281. @joinraid.command(
  282. name='setrate',
  283. brief='Sets the rate of joins which triggers a warning to mods',
  284. description='Each time a member joins, the join records from the ' +
  285. 'previous _x_ seconds are counted up, where _x_ is the number of ' +
  286. 'seconds configured by this command. If that count meets or ' +
  287. 'exceeds the maximum join count configured by this command then ' +
  288. 'a raid is detected and a warning is issued to the mods.',
  289. usage='<join_count:int> <seconds:float>',
  290. )
  291. async def joinraid_setrate(self, context: commands.Context,
  292. join_count: int,
  293. seconds: float):
  294. 'Command handler'
  295. guild = context.guild
  296. if join_count < self.MIN_JOIN_COUNT:
  297. await context.message.reply(
  298. f'⚠️ `join_count` must be >= {self.MIN_JOIN_COUNT}',
  299. mention_author=False)
  300. return
  301. if seconds <= 0:
  302. await context.message.reply(
  303. f'⚠️ `seconds` must be > 0',
  304. mention_author=False)
  305. return
  306. Storage.set_state_values(guild, {
  307. self.STATE_KEY_RAID_COUNT: join_count,
  308. self.STATE_KEY_RAID_SECONDS: seconds,
  309. })
  310. await context.message.reply(
  311. '✅ ' + self.__describe_raid_settings(guild, force_rate_status=True),
  312. mention_author=False)
  313. @joinraid.command(
  314. name='getrate',
  315. brief='Shows the rate of joins which triggers a warning to mods',
  316. )
  317. async def joinraid_getrate(self, context: commands.Context):
  318. 'Command handler'
  319. await context.message.reply(
  320. 'ℹ️ ' + self.__describe_raid_settings(context.guild, force_rate_status=True),
  321. mention_author=False)
  322. # -- Listeners ----------------------------------------------------------
  323. @commands.Cog.listener()
  324. async def on_raw_reaction_add(self, payload: RawReactionActionEvent):
  325. 'Event handler'
  326. if payload.user_id == self.bot.user.id:
  327. # Ignore bot's own reactions
  328. return
  329. member: Member = payload.member
  330. if member is None:
  331. return
  332. guild: Guild = self.bot.get_guild(payload.guild_id)
  333. if guild is None:
  334. # Possibly a DM
  335. return
  336. channel: GuildChannel = guild.get_channel(payload.channel_id)
  337. if channel is None:
  338. # Possibly a DM
  339. return
  340. message: Message = await channel.fetch_message(payload.message_id)
  341. if message is None:
  342. # Message deleted?
  343. return
  344. if message.author.id != self.bot.user.id:
  345. # Bot didn't author this
  346. return
  347. if not member.permissions_in(channel).ban_members:
  348. # Not a mod
  349. # TODO: Remove reaction?
  350. return
  351. gc: GuildContext = self.__get_guild_context(guild)
  352. raid: JoinRaidRecord = gc.find_raid_for_message_id(payload.message_id)
  353. if raid is None:
  354. # Either not a warning message or one we stopped tracking
  355. return
  356. emoji: PartialEmoji = payload.emoji
  357. if emoji.name == CONFIG['kickEmoji']:
  358. await raid.kick_all()
  359. gc.reset_raid(message.created_at)
  360. await self.__update_raid_warning(guild, raid)
  361. elif emoji.name == CONFIG['banEmoji']:
  362. await raid.ban_all()
  363. gc.reset_raid(message.created_at)
  364. await self.__update_raid_warning(guild, raid)
  365. @commands.Cog.listener()
  366. async def on_member_join(self, member: Member) -> None:
  367. 'Event handler'
  368. guild: Guild = member.guild
  369. if not self.__is_enabled(guild):
  370. return
  371. (count, seconds) = self.__get_raid_rate(guild)
  372. now = member.joined_at
  373. gc: GuildContext = self.__get_guild_context(guild)
  374. raid: JoinRaidRecord = gc.current_raid
  375. raid.handle_join(member, now, count, seconds)
  376. if raid.phase == RaidPhase.JUST_STARTED:
  377. await self.__post_raid_warning(guild, raid)
  378. elif raid.phase == RaidPhase.CONTINUING:
  379. await self.__update_raid_warning(guild, raid)
  380. elif raid.phase == RaidPhase.ENDED:
  381. # First join that occurred too late to be part of last raid. Join
  382. # not added. Start a new raid record and add it there.
  383. gc.reset_raid(now)
  384. gc.current_raid.handle_join(member, now, count, seconds)
  385. # -- Misc ---------------------------------------------------------------
  386. def __describe_raid_settings(self,
  387. guild: Guild,
  388. force_enabled_status=False,
  389. force_rate_status=False) -> str:
  390. """
  391. Creates a Discord message describing the current join raid settings.
  392. """
  393. enabled = self.__is_enabled(guild)
  394. (count, seconds) = self.__get_raid_rate(guild)
  395. sentences = []
  396. if enabled or force_rate_status:
  397. sentences.append(f'Join raids will be detected at {count} or more joins per {seconds} seconds.')
  398. if enabled and force_enabled_status:
  399. sentences.append('Raid detection enabled.')
  400. elif not enabled:
  401. sentences.append('Raid detection disabled.')
  402. tips = []
  403. if enabled or force_rate_status:
  404. tips.append('• Use `setrate` subcommand to change detection threshold')
  405. if enabled:
  406. tips.append('• Use `disable` subcommand to disable detection.')
  407. else:
  408. tips.append('• Use `enable` subcommand to enable detection.')
  409. message = ''
  410. message += ' '.join(sentences)
  411. if len(tips) > 0:
  412. message += '\n\n' + ('\n'.join(tips))
  413. return message
  414. def __get_guild_context(self, guild: Guild) -> GuildContext:
  415. """
  416. Looks up the GuildContext for the given Guild or creates a new one if
  417. one does not yet exist.
  418. """
  419. gc: GuildContext = self.guild_id_to_context.get(guild.id)
  420. if gc is not None:
  421. return gc
  422. gc = GuildContext(guild.id)
  423. self.guild_id_to_context[guild.id] = gc
  424. return gc
  425. async def __post_raid_warning(self, guild: Guild, raid: JoinRaidRecord) -> None:
  426. """
  427. Posts a warning message about the given raid.
  428. """
  429. (message, can_kick, can_ban) = self.__describe_raid(raid)
  430. raid.warning_message = await self.warn(guild, message)
  431. if can_kick:
  432. await raid.warning_message.add_reaction(CONFIG['kickEmoji'])
  433. if can_ban:
  434. await raid.warning_message.add_reaction(CONFIG['banEmoji'])
  435. async def __update_raid_warning(self, guild: Guild, raid: JoinRaidRecord) -> None:
  436. """
  437. Updates the existing warning message for a raid.
  438. """
  439. if raid.warning_message is None:
  440. return
  441. (message, can_kick, can_ban) = self.__describe_raid(raid)
  442. await self.update_warn(raid.warning_message, message)
  443. if not can_kick:
  444. await raid.warning_message.clear_reaction(CONFIG['kickEmoji'])
  445. if not can_ban:
  446. await raid.warning_message.clear_reaction(CONFIG['banEmoji'])
  447. def __describe_raid(self, raid: JoinRaidRecord) -> tuple:
  448. """
  449. Creates a Discord warning message with details about the given raid.
  450. Returns a tuple containing the message text, a flag if any users can
  451. still be kicked, and a flag if anyone can still be banned.
  452. """
  453. message = '🚨 **JOIN RAID DETECTED** 🚨'
  454. message += '\nThe following members joined in close succession:\n'
  455. any_kickable = False
  456. any_bannable = False
  457. for join in raid.joins:
  458. message += '\n• '
  459. if join.is_banned:
  460. message += '~~' + join.member.mention + '~~ - banned'
  461. elif join.is_kicked:
  462. message += '~~' + join.member.mention + '~~ - kicked'
  463. any_bannable = True
  464. else:
  465. message += join.member.mention
  466. any_bannable = True
  467. any_kickable = True
  468. message += '\n_(list updates automatically)_'
  469. message += '\n'
  470. if any_kickable:
  471. message += f'\nReact to this message with {CONFIG["kickEmoji"]} to kick all these users.'
  472. else:
  473. message += '\nNo users left to kick.'
  474. if any_bannable:
  475. message += f'\nReact to this message with {CONFIG["banEmoji"]} to ban all these users.'
  476. else:
  477. message += '\nNo users left to ban.'
  478. return (message, any_kickable, any_bannable)