Experimental Discord bot written in Python
Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

storage.py 3.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. import json
  2. from os.path import exists
  3. from discord import Guild
  4. from config import CONFIG
  5. class StateKey:
  6. WARNING_CHANNEL_ID = 'warning_channel_id'
  7. WARNING_MENTION = 'warning_mention'
  8. class Storage:
  9. """
  10. Static class for managing persisted bot state.
  11. """
  12. # discord.Guild.id -> dict
  13. __guild_id_to_state = {}
  14. @classmethod
  15. def get_state(cls, guild: Guild) -> dict:
  16. """
  17. Returns all persisted state for the given guild.
  18. """
  19. state: dict = cls.__guild_id_to_state.get(guild.id)
  20. if state is not None:
  21. # Already in memory
  22. return state
  23. # Load from disk if possible
  24. cls.__trace(f'No loaded state for guild {guild.id}. Attempting to ' +
  25. 'load from disk.')
  26. state = cls.__load_guild_state(guild)
  27. if state is None:
  28. return {}
  29. cls.__guild_id_to_state[guild.id] = state
  30. return state
  31. @classmethod
  32. def get_state_value(cls, guild: Guild, key: str):
  33. """
  34. Returns a persisted state value stored under the given key. Returns
  35. None if not present.
  36. """
  37. return cls.get_state(guild).get(key)
  38. @classmethod
  39. def set_state_value(cls, guild: Guild, key: str, value) -> None:
  40. """
  41. Adds the given key-value pair to the persisted state for the given
  42. Guild. If `value` is `None` the key will be removed from persisted
  43. state.
  44. """
  45. cls.set_state_values(guild, { key: value })
  46. @classmethod
  47. def set_state_values(cls, guild: Guild, vars: dict) -> None:
  48. """
  49. Merges the given `vars` dict with the saved state for the given guild
  50. and saves it to disk. `vars` must be JSON-encodable or a ValueError will
  51. be raised. Keys with associated values of `None` will be removed from the
  52. state.
  53. """
  54. if vars is None or len(vars) == 0:
  55. return
  56. state: dict = cls.get_state(guild)
  57. try:
  58. json.dumps(vars)
  59. except:
  60. raise ValueError(f'vars not JSON encodable - {vars}')
  61. for key, value in vars.items():
  62. if value is None:
  63. del state[key]
  64. else:
  65. state[key] = value
  66. cls.__save_guild_state(guild, state)
  67. @classmethod
  68. def __save_guild_state(cls, guild: Guild, state: dict) -> None:
  69. """
  70. Saves state for a guild to a JSON file on disk.
  71. """
  72. path: str = cls.__guild_path(guild)
  73. cls.__trace(f'Saving state for guild {guild.id} to {path}')
  74. cls.__trace(f'state = {state}')
  75. with open(path, 'w') as file:
  76. # Pretty printing to make more legible for debugging
  77. # Sorting keys to help with diffs
  78. json.dump(state, file, indent='\t', sort_keys=True)
  79. cls.__trace('State saved')
  80. @classmethod
  81. def __load_guild_state(cls, guild: Guild) -> dict:
  82. """
  83. Loads state for a guild from a JSON file on disk, or None if not found.
  84. """
  85. path: str = cls.__guild_path(guild)
  86. if not exists(path):
  87. cls.__trace(f'No state on disk for guild {guild.id}. Returning None.')
  88. return None
  89. cls.__trace(f'Loading state from disk for guild {guild.id}')
  90. with open(path, 'r') as file:
  91. state = json.load(file)
  92. cls.__trace('State loaded')
  93. return state
  94. @classmethod
  95. def __guild_path(cls, guild: Guild) -> str:
  96. """
  97. Returns the JSON file path where guild state should be written.
  98. """
  99. config_value: str = CONFIG['statePath']
  100. path: str = config_value if config_value.endswith('/') else f'{config_value}/'
  101. return f'{path}guild_{guild.id}.json'
  102. @classmethod
  103. def __trace(cls, message: str) -> None:
  104. print(f'{Storage.__name__}: {message}')