| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192 |
- """
- Handles storage of persisted and non-persisted data for the bot.
- """
- import json
- from datetime import datetime, timezone, timedelta
- from os.path import exists
- from typing import Any, Optional
- from discord import Guild
-
- from config import CONFIG
- from rocketbot.collections import AgeBoundDict
-
-
class ConfigKey:
    """
    Well-known key names used in each guild's persisted configuration.

    Use these constants with `Storage.get_config_value` /
    `Storage.set_config_value` rather than repeating the raw strings at each
    call site.
    """
    # Config key: 'warning_channel_id'
    WARNING_CHANNEL_ID: str = 'warning_channel_id'
    # Config key: 'warning_mention'
    WARNING_MENTION: str = 'warning_mention'
-
class Storage:
    """
    Static class for managing persisted bot configuration and transient state
    on a per-guild basis.

    Transient state lives only in memory and is lost when the bot restarts.
    Persisted configuration is cached in memory and written to one JSON file
    per guild inside the directory named by `CONFIG['config_path']`.

    All data is kept in class-level dicts keyed by `Guild.id`, and every
    public method is a classmethod.
    """

    # -- Transient state management -----------------------------------------

    # discord.Guild.id -> transient key/value state (memory only)
    __guild_id_to_state: dict[int, dict[str, Any]] = {}

    @classmethod
    def get_state(cls, guild: Guild) -> dict[str, Any]:
        """
        Returns transient state for the given guild. This state is not preserved
        if the bot is restarted.

        The returned dict is the live stored object, so mutating it updates the
        guild's state directly. An empty dict is created and registered the
        first time a guild is seen.
        """
        return cls.__guild_id_to_state.setdefault(guild.id, {})

    @classmethod
    def get_state_value(cls, guild: Guild, key: str) -> Optional[Any]:
        """
        Returns a state value for the given guild and key, or `None` if not set.
        """
        return cls.get_state(guild).get(key)

    @classmethod
    def set_state_value(cls, guild: Guild, key: str, value: Optional[Any]) -> None:
        """
        Updates a transient value associated with the given guild and key name.
        A value of `None` removes any previous value for that key; removing a
        key that was never set is a no-op.
        """
        cls.set_state_values(guild, {key: value})

    @classmethod
    def set_state_values(cls, guild: Guild, values: Optional[dict[str, Optional[Any]]]) -> None:
        """
        Merges in a set of key-value pairs into the transient state for the
        given guild. Any pairs with a value of `None` will be removed from the
        transient state (keys that are not present are silently ignored).

        Does nothing if `values` is `None` or empty.
        """
        if not values:
            return
        state = cls.get_state(guild)
        for key, value in values.items():
            if value is None:
                # pop() instead of del: clearing a key that was never set must
                # not raise KeyError.
                state.pop(key, None)
            else:
                state[key] = value
        # No re-registration needed: `state` is the same dict object that
        # get_state() stored, so the in-place updates above are already live.

    # -- Persisted configuration management ---------------------------------

    # discord.Guild.id -> persisted config dict (in-memory cache of the JSON file)
    __guild_id_to_config: dict[int, dict[str, Any]] = {}

    @classmethod
    def get_config(cls, guild: Guild) -> dict[str, Any]:
        """
        Returns all persisted configuration for the given guild.

        The first call for a guild loads its JSON file from disk (or starts
        from an empty dict if there is no file) and caches the result; later
        calls return the cached dict without touching disk. The returned dict
        is the live cached object.
        """
        config: Optional[dict[str, Any]] = cls.__guild_id_to_config.get(guild.id)
        if config is not None:
            # Already in memory
            return config
        # Load from disk if possible
        cls.__trace(f'No loaded config for guild {guild.id}. Attempting to ' +
                    'load from disk.')
        config = cls.__read_guild_config(guild)
        if config is None:
            config = {}
        cls.__guild_id_to_config[guild.id] = config
        return config

    @classmethod
    def get_config_value(cls, guild: Guild, key: str) -> Optional[Any]:
        """
        Returns a persisted guild config value stored under the given key.
        Returns `None` if not present.
        """
        return cls.get_config(guild).get(key)

    @classmethod
    def set_config_value(cls, guild: Guild, key: str, value: Optional[Any]) -> None:
        """
        Adds/updates the given key-value pair to the persisted config for the
        given Guild. If `value` is `None` the key will be removed from persisted
        config (a no-op if the key was never set).
        """
        cls.set_config_values(guild, {key: value})

    @classmethod
    def set_config_values(cls, guild: Guild, values: Optional[dict[str, Optional[Any]]]) -> None:
        """
        Merges the given `values` dict with the saved config for the given guild
        and writes it to disk. Keys with associated values of `None` will be
        removed from the persisted config (absent keys are silently ignored).

        Does nothing, and writes nothing, if `values` is `None` or empty.

        Raises:
            ValueError: if `values` is not JSON-encodable. The check runs
                before the config is loaded or modified, so a rejected call
                changes neither the cached config nor the file on disk.
        """
        if not values:
            return
        try:
            json.dumps(values)
        except (TypeError, ValueError, RecursionError) as e:
            # TypeError: unsupported type; ValueError: circular reference;
            # RecursionError: pathologically deep nesting.
            raise ValueError(f'values not JSON encodable - {values}') from e
        config = cls.get_config(guild)
        for key, value in values.items():
            if value is None:
                # pop() instead of del: removing an unset key must not raise.
                config.pop(key, None)
            else:
                config[key] = value
        cls.__write_guild_config(guild, config)

    @classmethod
    def get_bot_messages(cls, guild: Guild) -> AgeBoundDict[int, Any, datetime, timedelta]:
        """
        Returns all the bot messages for a guild.

        The collection is stored in transient state under the 'bot_messages'
        key and is created on first access as an `AgeBoundDict` constructed
        with `timedelta(seconds=600)` and a timestamp function returning
        `v.message_sent_at()`, or a time ~1000 days in the future when that
        call returns a falsy value. NOTE(review): presumably entries expire
        600 s after their timestamp, so the far-future fallback keeps unsent
        messages alive; confirm against `AgeBoundDict`.
        """
        bm = cls.get_state_value(guild, 'bot_messages')
        if bm is None:
            far_future = datetime.now(timezone.utc) + timedelta(days=1000)
            bm = AgeBoundDict(timedelta(seconds=600),
                              lambda k, v: v.message_sent_at() or far_future)
            cls.set_state_value(guild, 'bot_messages', bm)
        return bm

    @classmethod
    def __write_guild_config(cls, guild: Guild, config: dict[str, Any]) -> None:
        """
        Saves config for a guild to a JSON file on disk.

        Also sets `config['_guild_name']` to the guild's current name, which
        mutates the caller's (cached) dict. This is purely so the JSON files
        are easier to identify.

        The JSON is first written to a sibling `<path>.tmp` file and then moved
        over the real path with `os.replace`. As a result, an exception or
        crash part-way through the dump cannot leave a truncated config file
        behind.
        """
        import os  # local import: only needed for the atomic replace below

        path: str = cls.__guild_config_path(guild)
        cls.__trace(f'Saving config for guild {guild.id} to {path}')
        cls.__trace(f'config = {config}')
        config['_guild_name'] = guild.name  # Just for making JSON files easier to identify
        temp_path = f'{path}.tmp'
        with open(temp_path, 'w', encoding='utf8') as file:
            # Pretty printing to make more legible for debugging
            # Sorting keys to help with diffs
            json.dump(config, file, indent='\t', sort_keys=True)
        # Same directory, so this is a rename that replaces the old file in
        # one step.
        os.replace(temp_path, path)
        cls.__trace('State saved')

    @classmethod
    def __read_guild_config(cls, guild: Guild) -> Optional[dict[str, Any]]:
        """
        Loads config for a guild from a JSON file on disk, or `None` if not
        found.

        A file that exists but does not contain valid JSON makes `json.load`
        raise (`json.JSONDecodeError`).
        """
        path: str = cls.__guild_config_path(guild)
        if not exists(path):
            cls.__trace(f'No config on disk for guild {guild.id}. Returning None.')
            return None
        cls.__trace(f'Loading config from disk for guild {guild.id}')
        with open(path, 'r', encoding='utf8') as file:
            config = json.load(file)
        cls.__trace('State loaded')
        return config

    @classmethod
    def __guild_config_path(cls, guild: Guild) -> str:
        """
        Returns the JSON file path where guild config should be written:
        `<CONFIG['config_path']>/guild_<guild.id>.json`. A trailing '/' is
        added to the configured directory if it lacks one.
        """
        config_value: str = CONFIG['config_path']
        path: str = config_value if config_value.endswith('/') else f'{config_value}/'
        return f'{path}guild_{guild.id}.json'

    @classmethod
    def __trace(cls, message: Any) -> None:
        """
        Debug tracing hook; currently a no-op. Uncomment the print to enable.
        """
        # print(f'{cls.__name__}: {str(message)}')
        pass
|