Experimental Discord bot written in Python
您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

storage.py 5.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. """
  2. Handles storage of persisted and non-persisted data for the bot.
  3. """
  4. import json
  5. from os.path import exists
  6. from typing import Any, Optional
  7. from discord import Guild
  8. from config import CONFIG
  9. class ConfigKey:
  10. """
  11. Common keys in persisted guild storage.
  12. """
  13. WARNING_CHANNEL_ID = 'warning_channel_id'
  14. WARNING_MENTION = 'warning_mention'
  15. class Storage:
  16. """
  17. Static class for managing persisted bot configuration and transient state
  18. on a per-guild basis.
  19. """
  20. # -- Transient state management -----------------------------------------
  21. __guild_id_to_state: dict[int, dict[str, Any]] = {}
  22. @classmethod
  23. def get_state(cls, guild: Guild) -> dict[str, Any]:
  24. """
  25. Returns transient state for the given guild. This state is not preserved
  26. if the bot is restarted.
  27. """
  28. state: dict[str, Any] = cls.__guild_id_to_state.get(guild.id)
  29. if state is None:
  30. state = {}
  31. cls.__guild_id_to_state[guild.id] = state
  32. return state
  33. @classmethod
  34. def get_state_value(cls, guild: Guild, key: str) -> Optional[Any]:
  35. """
  36. Returns a state value for the given guild and key, or `None` if not set.
  37. """
  38. return cls.get_state(guild).get(key)
  39. @classmethod
  40. def set_state_value(cls, guild: Guild, key: str, value: Optional[Any]) -> None:
  41. """
  42. Updates a transient value associated with the given guild and key name.
  43. A value of `None` removes any previous value for that key.
  44. """
  45. cls.set_state_values(guild, { key: value })
  46. @classmethod
  47. def set_state_values(cls, guild: Guild, values: Optional[dict[str, Optional[Any]]]) -> None:
  48. """
  49. Merges in a set of key-value pairs into the transient state for the
  50. given guild. Any pairs with a value of `None` will be removed from the
  51. transient state.
  52. """
  53. if values is None or len(values) == 0:
  54. return
  55. state: dict[str, Any] = cls.get_state(guild)
  56. for key, value in values.items():
  57. if value is None:
  58. del state[key]
  59. else:
  60. state[key] = value
  61. # XXX: Superstitious. Should update by ref already but saw weirdness once.
  62. cls.__guild_id_to_state[guild.id] = state
  63. # -- Persisted configuration management ---------------------------------
  64. # discord.Guild.id -> dict
  65. __guild_id_to_config: dict[int, dict[str, Any]] = {}
  66. @classmethod
  67. def get_config(cls, guild: Guild) -> dict[str, Any]:
  68. """
  69. Returns all persisted configuration for the given guild.
  70. """
  71. config: dict[str, Any] = cls.__guild_id_to_config.get(guild.id)
  72. if config is not None:
  73. # Already in memory
  74. return config
  75. # Load from disk if possible
  76. cls.__trace(f'No loaded config for guild {guild.id}. Attempting to ' +
  77. 'load from disk.')
  78. config = cls.__read_guild_config(guild)
  79. if config is None:
  80. return {}
  81. cls.__guild_id_to_config[guild.id] = config
  82. return config
  83. @classmethod
  84. def get_config_value(cls, guild: Guild, key: str) -> Optional[Any]:
  85. """
  86. Returns a persisted guild config value stored under the given key.
  87. Returns `None` if not present.
  88. """
  89. return cls.get_config(guild).get(key)
  90. @classmethod
  91. def set_config_value(cls, guild: Guild, key: str, value: Optional[Any]) -> None:
  92. """
  93. Adds/updates the given key-value pair to the persisted config for the
  94. given Guild. If `value` is `None` the key will be removed from persisted
  95. config.
  96. """
  97. cls.set_config_values(guild, { key: value })
  98. @classmethod
  99. def set_config_values(cls, guild: Guild, values: Optional[dict[str, Optional[Any]]]) -> None:
  100. """
  101. Merges the given `values` dict with the saved config for the given guild
  102. and writes it to disk. `values` must be JSON-encodable or a `ValueError`
  103. will be raised. Keys with associated values of `None` will be removed
  104. from the persisted config.
  105. """
  106. if values is None or len(values) == 0:
  107. return
  108. config: dict[str, Any] = cls.get_config(guild)
  109. try:
  110. json.dumps(values)
  111. except Exception as e:
  112. raise ValueError(f'values not JSON encodable - {values}') from e
  113. for key, value in values.items():
  114. if value is None:
  115. del config[key]
  116. else:
  117. config[key] = value
  118. cls.__write_guild_config(guild, config)
  119. @classmethod
  120. def __write_guild_config(cls, guild: Guild, config: dict[str, Any]) -> None:
  121. """
  122. Saves config for a guild to a JSON file on disk.
  123. """
  124. path: str = cls.__guild_config_path(guild)
  125. cls.__trace(f'Saving config for guild {guild.id} to {path}')
  126. cls.__trace(f'config = {config}')
  127. with open(path, 'w', encoding='utf8') as file:
  128. # Pretty printing to make more legible for debugging
  129. # Sorting keys to help with diffs
  130. json.dump(config, file, indent='\t', sort_keys=True)
  131. cls.__trace('State saved')
  132. @classmethod
  133. def __read_guild_config(cls, guild: Guild) -> Optional[dict[str, Any]]:
  134. """
  135. Loads config for a guild from a JSON file on disk, or `None` if not
  136. found.
  137. """
  138. path: str = cls.__guild_config_path(guild)
  139. if not exists(path):
  140. cls.__trace(f'No config on disk for guild {guild.id}. Returning None.')
  141. return None
  142. cls.__trace(f'Loading config from disk for guild {guild.id}')
  143. with open(path, 'r', encoding='utf8') as file:
  144. config = json.load(file)
  145. cls.__trace('State loaded')
  146. return config
  147. @classmethod
  148. def __guild_config_path(cls, guild: Guild) -> str:
  149. """
  150. Returns the JSON file path where guild config should be written.
  151. """
  152. config_value: str = CONFIG['config_path']
  153. path: str = config_value if config_value.endswith('/') else f'{config_value}/'
  154. return f'{path}guild_{guild.id}.json'
  155. @classmethod
  156. def __trace(cls, message: Any) -> None:
  157. # print(f'{cls.__name__}: {str(message)}')
  158. pass