| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221 |
- """
- General utility functions.
- """
- import re
- import sys
- import traceback
- from datetime import datetime, timedelta
- from typing import Any, Optional, Union
-
- import discord
- from discord import Guild, Interaction, Permissions
- from discord.app_commands import Transformer
- from discord.ext.commands import BadArgument, Cog
-
-
def dump_stacktrace(e: BaseException) -> None:
    """
    Writes an exception to stderr, followed by its formatted traceback.

    The exception's str() form is printed first; the traceback module then
    prints the standard "Traceback ..." report for the same exception.
    """
    exc_type = type(e)
    tb = e.__traceback__
    print(e, file=sys.stderr)
    traceback.print_exception(exc_type, e, tb)
-
# Whole-string shape of a timespan: one or more <digits><letter> pairs.
_TIMESPAN_REGEX: re.Pattern = re.compile(r'(?:[0-9]+[a-zA-Z])+')
# A single <digits><letter> component. Any letter is captured so that an
# unknown unit is reported instead of being silently skipped.
_TIMESPAN_COMPONENT_REGEX: re.Pattern = re.compile(r'([0-9]+)([a-zA-Z])')
# Lowercase unit letter -> timedelta keyword argument.
_TIMESPAN_UNITS = {'d': 'days', 'h': 'hours', 'm': 'minutes', 's': 'seconds'}

def timedelta_from_str(s: str) -> timedelta:
    """
    Parses a timespan.

    Format examples:
    "30m"
    "10s"
    "90d"
    "1h30m"
    "73d18h22m52s"

    Units are case-insensitive ("30M" == "30m"). If a unit appears more than
    once, its values are added together ("1h1h" == 2 hours).

    Parameters
    ----------
    s : str
        string to parse

    Returns
    -------
    timedelta

    Raises
    ------
    ValueError
        if parsing fails: malformed input, an unknown unit letter, or a span
        too large for timedelta
    """
    # fullmatch (rather than match with '$') also rejects a trailing newline.
    if _TIMESPAN_REGEX.fullmatch(s) is None:
        raise ValueError(f'Illegal timespan value "{s}". Examples: 30s, 5m, 1h30m, 30d')
    totals: dict[str, int] = {}
    for m in _TIMESPAN_COMPONENT_REGEX.finditer(s):
        unit = m.group(2).lower()
        field = _TIMESPAN_UNITS.get(unit)
        if field is None:
            raise ValueError(f'Invalid unit "{unit}". Valid units: "s"=seconds, "m"=minutes, "h"=hours, "d"=days')
        totals[field] = totals.get(field, 0) + int(m.group(1))
    try:
        return timedelta(**totals)
    except OverflowError as e:
        # Re-raise as ValueError so callers only need to handle one parse-failure type.
        raise ValueError(f'Timespan "{s}" is too large') from e
-
def str_from_timedelta(td: timedelta) -> str:
    """
    Encodes a timedelta as a compact str. E.g. "3d2h"

    Non-negative output uses the same <number><unit> syntax that
    timedelta_from_str accepts. Zero-valued units are omitted, a zero span is
    "0s", and sub-second precision is dropped. A negative span is encoded as
    "-" followed by its magnitude (e.g. "-30s"). Without this, timedelta's
    normalized form (days=-1, seconds=86370) would produce "-1d23h59m30s",
    which reads as a much larger negative span.
    """
    negative: bool = td < timedelta(0)
    td = abs(td)
    d: int = td.days
    h: int = td.seconds // 3600
    m: int = (td.seconds // 60) % 60
    s: int = td.seconds % 60
    components: list[str] = []
    if d != 0:
        components.append(f'{d}d')
    if h != 0:
        components.append(f'{h}h')
    if m != 0:
        components.append(f'{m}m')
    if s != 0:
        components.append(f'{s}s')
    if not components:
        # Also covers negative spans shorter than a second: "0s", never "-0s".
        return '0s'
    return ('-' if negative else '') + ''.join(components)
-
def describe_timedelta(td: timedelta, max_components: int = 2) -> str:
    """
    Formats a human-readable description of a time span. E.g. "3 days 2 hours".

    Only the largest non-zero units are kept, at most max_components of them.
    Smaller units are truncated, not rounded. Sub-second precision is ignored
    and a zero span is "0 seconds". A negative span is described by its
    magnitude prefixed with "-" (e.g. "-1 minute 30 seconds"). This replaces
    timedelta's normalized negative-days form, which would read
    "-1 days 23 hours" for -90 seconds.
    """
    negative: bool = td < timedelta(0)
    td = abs(td)
    d: int = td.days
    h: int = td.seconds // 3600
    m: int = (td.seconds // 60) % 60
    s: int = td.seconds % 60
    components: list[str] = []
    if d != 0:
        components.append('1 day' if d == 1 else f'{d} days')
    if h != 0:
        components.append('1 hour' if h == 1 else f'{h} hours')
    if m != 0:
        components.append('1 minute' if m == 1 else f'{m} minutes')
    if s != 0:
        components.append('1 second' if s == 1 else f'{s} seconds')
    if not components:
        components.append('0 seconds')
        negative = False  # sub-second negative spans read as "0 seconds", not "-0 seconds"
    shown: list[str] = components[:max_components]
    return ('-' if negative and shown else '') + ' '.join(shown)
-
def _old_first_command_group(cog: Cog) -> Optional[discord.ext.commands.Group]:
    """
    Finds a prefix-command Group attached to a cog.

    Attribute names are visited in dir() order. The first attribute value that
    is a discord.ext.commands.Group is returned, or None if there is none.
    Attributes are fetched lazily, so scanning stops at the first hit.
    """
    attribute_values = (getattr(cog, attr_name) for attr_name in dir(cog))
    return next(
        (value for value in attribute_values
         if isinstance(value, discord.ext.commands.Group)),
        None,
    )
-
def first_command_group(cog: Cog) -> Optional[discord.app_commands.Group]:
    """
    Finds a slash-command Group attached to a cog.

    Attribute names are visited in dir() order. The first attribute value that
    is a discord.app_commands.Group is returned, or None if there is none.
    Attributes are fetched lazily, so scanning stops at the first hit.
    """
    attribute_values = (getattr(cog, attr_name) for attr_name in dir(cog))
    return next(
        (value for value in attribute_values
         if isinstance(value, discord.app_commands.Group)),
        None,
    )
-
def bot_log(guild: Optional[Guild], cog_class: Optional[type], message: Any) -> None:
    """
    Prints one log line to stdout in the form
    "[<local time>|<cog class name>|<guild name>] <message>".

    A missing (falsy) cog_class or guild is shown as "-". The timestamp is the
    naive local time from datetime.now(), formatted to whole seconds.
    """
    timestamp = datetime.now().strftime("%Y-%m-%dT%H:%M:%S")  # local
    cog_label = cog_class.__name__ if cog_class else '-'
    guild_label = guild.name if guild else '-'
    print(f'[{timestamp}|{cog_label}|{guild_label}] {message!s}')
-
# Characters accepted as string delimiters by str_from_quoted_str.
__QUOTE_CHARS: str = '\'"'
# A bare Discord ID: 17-20 decimal digits. '\Z' (unlike '$') rejects a trailing newline.
__ID_REGEX: re.Pattern = re.compile(r'^[0-9]{17,20}\Z')
# Any user mention (<@id>, or the deprecated nickname form <@!id>) or role mention (<@&id>).
__MENTION_REGEX: re.Pattern = re.compile(r'^<@[!&]?([0-9]{17,20})>\Z')
# A user mention in either the current <@id> form or the deprecated <@!id> nickname form.
__USER_MENTION_REGEX: re.Pattern = re.compile(r'^<@!?([0-9]{17,20})>\Z')
# A role mention: <@&id>.
__ROLE_MENTION_REGEX: re.Pattern = re.compile(r'^<@&([0-9]{17,20})>\Z')
-
def is_user_id(val: str) -> bool:
    """Returns True if val looks like a bare user/role ID (matches __ID_REGEX)."""
    # A successful match object is always truthy, so bool() is a presence test.
    return bool(__ID_REGEX.match(val))
-
def is_mention(val: str) -> bool:
    """Returns True if val is a user or role mention (matches __MENTION_REGEX)."""
    # A successful match object is always truthy, so bool() is a presence test.
    return bool(__MENTION_REGEX.match(val))
-
def is_role_mention(val: str) -> bool:
    """Returns True if val is a role mention (matches __ROLE_MENTION_REGEX)."""
    # A successful match object is always truthy, so bool() is a presence test.
    return bool(__ROLE_MENTION_REGEX.match(val))
-
def is_user_mention(val: str) -> bool:
    """Returns True if val is a user mention (matches __USER_MENTION_REGEX)."""
    # A successful match object is always truthy, so bool() is a presence test.
    return bool(__USER_MENTION_REGEX.match(val))
-
def user_id_from_mention(mention: str) -> str:
    """
    Returns the numeric ID captured from a user mention, as a str.

    Raises ValueError if mention does not match __USER_MENTION_REGEX.
    """
    match = __USER_MENTION_REGEX.match(mention)
    if match is None:
        raise ValueError(f'"{mention}" is not an @ user mention')
    return match.group(1)
-
def mention_from_user_id(user_id: Union[str, int]) -> str:
    """Builds a Markdown user mention of the form <@!ID> from a user id."""
    return '<@!{}>'.format(user_id)
-
def mention_from_role_id(role_id: Union[str, int]) -> str:
    """Builds a Markdown role mention of the form <@&ID> from a role id."""
    return '<@&{}>'.format(role_id)
-
def str_from_quoted_str(val: str) -> str:
    """
    Removes the leading and trailing quotes from a string.

    val must be at least two characters long, start with one of the characters
    in __QUOTE_CHARS, and end with that same character. So 'abc' and "abc" are
    accepted, but a mismatched pair such as 'abc" is not.

    Raises ValueError if val is not quoted this way.
    """
    if len(val) < 2 or val[0] not in __QUOTE_CHARS or val[-1] != val[0]:
        raise ValueError(f'Not a quoted string: {val}')
    return val[1:-1]
-
def blockquote_markdown(markdown: str) -> str:
    """
    Encloses some Markdown in a blockquote by prefixing every line
    (split on '\\n') with '> '.
    """
    return '\n'.join('> ' + line for line in markdown.split('\n'))
-
def indent_markdown(markdown: str) -> str:
    """
    Indents a block of Markdown by one level by prefixing every line
    (split on '\\n') with the indent string.
    """
    # NOTE(review): the indent string is reproduced as it appears in this copy
    # of the file (one space); confirm the intended width.
    return '\n'.join(' ' + line for line in markdown.split('\n'))
-
# Either an http(s) URL already enclosed in <...> (group 1 unset; left alone) or a
# bare http(s) URL not preceded by '<' (group 1). Consuming wrapped URLs as whole
# matches stops a URL embedded inside them (e.g. in a query string) from being
# wrapped a second time.
_WRAPPED_OR_BARE_URL_REGEX: re.Pattern = re.compile(
    r'<https?://[^\s<>]+>|(?<!<)(https?://[^\s<>]+)')
# Characters that usually end the surrounding sentence rather than the URL.
_URL_TRAILING_PUNCTUATION: str = '.,:;!?\'"'

def _wrap_url_match(match: re.Match) -> str:
    """Returns the replacement text for one match of _WRAPPED_OR_BARE_URL_REGEX."""
    url = match.group(1)
    if url is None:
        return match.group(0)  # already wrapped
    # Never trim into the scheme: keep at least one character after '://'.
    min_length = url.index('://') + 4
    end = len(url)
    while end > min_length:
        last = url[end - 1]
        if last in _URL_TRAILING_PUNCTUATION:
            end -= 1
        elif last == ')' and url.count('(', 0, end) < url.count(')', 0, end):
            # An unbalanced ')' closes a markdown link "[text](url)" or a
            # parenthetical. A balanced one, as in ".../Foo_(bar)", is part of the URL.
            end -= 1
        else:
            break
    return f'<{url[:end]}>{url[end:]}'

def suppress_markdown_url_previews(markdown: str) -> str:
    """
    Finds URLs in markdown and encloses them in <...> to suppress the preview.

    URLs that are already enclosed in <...> are left unchanged. Trailing
    sentence punctuation and an unbalanced closing parenthesis are kept outside
    the brackets, so "see https://a.com." becomes "see <https://a.com>." and
    "[text](https://a.com)" becomes "[text](<https://a.com>)".
    """
    return _WRAPPED_OR_BARE_URL_REGEX.sub(_wrap_url_match, markdown)
-
def format_bytes(size: int) -> str:
    """
    Formats a size in bytes as a human-readable description (e.g. "3.2 KiB").

    Sizes below 1 KiB are shown as a whole number of bytes. Larger sizes use the
    smallest binary unit (KiB, MiB, GiB) whose rounded value stays below 1024.
    The value gets one decimal place when its rounded form is below 10, and none
    otherwise; GiB is the largest unit. Negative sizes are treated as 0.
    """
    if size < 0:
        size = 0
    if size < 1024:
        return f"{size:,} bytes"
    # The thresholds are compared against the value as it will be *displayed*:
    # 9.95 rounds to "10.0" and 1023.5 rounds to "1024". Comparing raw values
    # would print e.g. 1,048,575 bytes as "1,024 KiB" instead of "1.0 MiB".
    for unit, factor in (("KiB", 1024), ("MiB", 1024 ** 2)):
        value = size / factor
        if value < 9.95:
            return f"{value:,.1f} {unit}"
        if value < 1023.5:
            return f"{value:,.0f} {unit}"
    value = size / 1024 ** 3
    if value < 9.95:
        return f"{value:,.1f} GiB"
    return f"{value:,.0f} GiB"
-
# Permission set with only "Manage Messages" enabled. The name suggests it is the
# baseline for moderator-only features; confirm at the call sites.
MOD_PERMISSIONS: Permissions = Permissions(manage_messages=True)
-
class TimeDeltaTransformer(Transformer):
    """
    app_commands Transformer that converts an argument such as "1h30m" into a
    timedelta using timedelta_from_str.
    """

    async def transform(self, interaction: Interaction, value: Any) -> timedelta:
        """
        Parses str(value) as a timespan.

        Raises
        ------
        BadArgument
            if the value cannot be parsed. The originating ValueError is
            attached as __cause__.
        """
        try:
            return timedelta_from_str(str(value))
        except ValueError as e:
            print("Invalid time delta:", e)
            # Explicit chaining marks this as a deliberate translation, not a
            # secondary failure inside the handler.
            raise BadArgument(str(e)) from e

    @property
    def _error_display_name(self) -> str:
        # Type name for this transformer; presumably shown by discord.py in
        # conversion error messages — confirm against the installed version.
        return 'timedelta'
|