Adding all files

This commit is contained in:
2026-02-03 20:32:43 +02:00
parent 2588d10ba0
commit 77b70b600f
1457 changed files with 184865 additions and 0 deletions

View File

@@ -0,0 +1,26 @@
# SPDX-License-Identifier: MIT
"""disnake.ext.commands
~~~~~~~~~~~~~~~~~~~~~
An extension module to facilitate creation of bot commands.
:copyright: (c) 2015-2021 Rapptz, 2021-present Disnake Development
:license: MIT, see LICENSE for more details.
"""
from .base_core import *
from .bot import *
from .cog import *
from .context import *
from .converter import *
from .cooldowns import *
from .core import *
from .ctx_menus_core import *
from .custom_warnings import *
from .errors import *
from .flag_converter import *
from .flags import *
from .help import *
from .params import *
from .slash_core import *

View File

@@ -0,0 +1,37 @@
# SPDX-License-Identifier: MIT
from typing import TYPE_CHECKING, Any, Callable, Coroutine, TypeVar, Union
if TYPE_CHECKING:
from disnake import ApplicationCommandInteraction
from .cog import Cog
from .context import Context
from .errors import CommandError
T = TypeVar("T")
FuncT = TypeVar("FuncT", bound=Callable[..., Any])
Coro = Coroutine[Any, Any, T]
MaybeCoro = Union[T, Coro[T]]
CoroFunc = Callable[..., Coro[Any]]
Check = Union[
Callable[["Cog", "Context[Any]"], MaybeCoro[bool]], Callable[["Context[Any]"], MaybeCoro[bool]]
]
AppCheck = Union[
Callable[["Cog", "ApplicationCommandInteraction"], MaybeCoro[bool]],
Callable[["ApplicationCommandInteraction"], MaybeCoro[bool]],
]
Hook = Union[Callable[["Cog", "Context[Any]"], Coro[Any]], Callable[["Context[Any]"], Coro[Any]]]
Error = Union[
Callable[["Cog", "Context[Any]", "CommandError"], Coro[Any]],
Callable[["Context[Any]", "CommandError"], Coro[Any]],
]
# This is merely a tag type to avoid circular import issues.
# Yes, this is a terrible solution but ultimately it is the only solution.
class _BaseCommand:
__slots__ = ()

View File

@@ -0,0 +1,914 @@
# SPDX-License-Identifier: MIT
from __future__ import annotations
import asyncio
import datetime
import functools
from abc import ABC
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Tuple,
TypeVar,
Union,
cast,
overload,
)
from disnake.app_commands import ApplicationCommand
from disnake.enums import ApplicationCommandType
from disnake.flags import ApplicationInstallTypes, InteractionContextTypes
from disnake.permissions import Permissions
from disnake.utils import (
_generated,
_overload_with_permissions,
async_all,
iscoroutinefunction,
maybe_coroutine,
)
from .cooldowns import BucketType, CooldownMapping, MaxConcurrency
from .errors import CheckFailure, CommandError, CommandInvokeError, CommandOnCooldown
if TYPE_CHECKING:
from typing_extensions import Concatenate, ParamSpec, Self
from disnake.interactions import ApplicationCommandInteraction
from ._types import AppCheck, Coro, Error, Hook
from .cog import Cog
from .interaction_bot_base import InteractionBotBase
ApplicationCommandInteractionT = TypeVar(
"ApplicationCommandInteractionT", bound=ApplicationCommandInteraction, covariant=True
)
P = ParamSpec("P")
CommandCallback = Callable[..., Coro[Any]]
InteractionCommandCallback = Union[
Callable[Concatenate["CogT", ApplicationCommandInteractionT, P], Coro[Any]],
Callable[Concatenate[ApplicationCommandInteractionT, P], Coro[Any]],
]
__all__ = (
"InvokableApplicationCommand",
"default_member_permissions",
"install_types",
"contexts",
)
T = TypeVar("T")
AppCommandT = TypeVar("AppCommandT", bound="InvokableApplicationCommand")
CogT = TypeVar("CogT", bound="Cog")
HookT = TypeVar("HookT", bound="Hook")
ErrorT = TypeVar("ErrorT", bound="Error")
def _get_overridden_method(method):
return getattr(method.__func__, "__cog_special_method__", method)
def wrap_callback(coro):
@functools.wraps(coro)
async def wrapped(*args, **kwargs):
try:
ret = await coro(*args, **kwargs)
except CommandError:
raise
except asyncio.CancelledError:
return
except Exception as exc:
raise CommandInvokeError(exc) from exc
return ret
return wrapped
class InvokableApplicationCommand(ABC):
"""A base class that implements the protocol for a bot application command.
These are not created manually, instead they are created via the
decorator or functional interface.
The following classes implement this ABC:
- :class:`~.InvokableSlashCommand`
- :class:`~.InvokableMessageCommand`
- :class:`~.InvokableUserCommand`
Attributes
----------
name: :class:`str`
The name of the command.
qualified_name: :class:`str`
The full command name, including parent names in the case of slash subcommands or groups.
For example, the qualified name for ``/one two three`` would be ``one two three``.
body: :class:`.ApplicationCommand`
An object being registered in the API.
callback: :ref:`coroutine <coroutine>`
The coroutine that is executed when the command is called.
cog: Optional[:class:`Cog`]
The cog that this command belongs to. ``None`` if there isn't one.
checks: List[Callable[[:class:`.ApplicationCommandInteraction`], :class:`bool`]]
A list of predicates that verifies if the command could be executed
with the given :class:`.ApplicationCommandInteraction` as the sole parameter. If an exception
is necessary to be thrown to signal failure, then one inherited from
:exc:`.CommandError` should be used. Note that if the checks fail then
:exc:`.CheckFailure` exception is raised to the :func:`.on_slash_command_error`
event.
guild_ids: Optional[Tuple[:class:`int`, ...]]
The list of IDs of the guilds where the command is synced. ``None`` if this command is global.
auto_sync: :class:`bool`
Whether to automatically register the command.
extras: Dict[:class:`str`, Any]
A dict of user provided extras to attach to the command.
.. versionadded:: 2.5
"""
__original_kwargs__: Dict[str, Any]
body: ApplicationCommand
def __new__(cls, *args: Any, **kwargs: Any) -> Self:
self = super().__new__(cls)
# todo: refactor to not require None and change this to be based on the presence of a kwarg
self.__original_kwargs__ = {k: v for k, v in kwargs.items() if v is not None}
return self
def __init__(self, func: CommandCallback, *, name: Optional[str] = None, **kwargs: Any) -> None:
self.__command_flag__ = None
self._callback: CommandCallback = func
self.name: str = name or func.__name__
self.qualified_name: str = self.name
# Annotation parser needs this attribute because body doesn't exist at this moment.
# We will use this attribute later in order to set the allowed contexts.
self._guild_only: bool = kwargs.get("guild_only", False)
self.extras: Dict[str, Any] = kwargs.get("extras") or {}
if not isinstance(self.name, str):
raise TypeError("Name of a command must be a string.")
if "default_permission" in kwargs:
raise TypeError(
"`default_permission` is deprecated and will always be set to `True`. "
"See `default_member_permissions` and `contexts` instead."
)
# XXX: remove in next major/minor version
# the parameter was called `integration_types` in earlier stages of the user apps PR.
# since unknown kwargs unfortunately get silently ignored, at least try to warn users
# in this specific case
if "integration_types" in kwargs:
raise TypeError("`integration_types` has been renamed to `install_types`.")
try:
checks = func.__commands_checks__
checks.reverse()
except AttributeError:
checks = kwargs.get("checks", [])
self.checks: List[AppCheck] = checks
try:
cooldown = func.__commands_cooldown__
except AttributeError:
cooldown = kwargs.get("cooldown")
# TODO: Figure out how cooldowns even work with interactions
if cooldown is None:
buckets = CooldownMapping(cooldown, BucketType.default)
elif isinstance(cooldown, CooldownMapping):
buckets = cooldown
else:
raise TypeError("Cooldown must be a an instance of CooldownMapping or None.")
self._buckets: CooldownMapping = buckets
try:
max_concurrency = func.__commands_max_concurrency__
except AttributeError:
max_concurrency = kwargs.get("max_concurrency")
self._max_concurrency: Optional[MaxConcurrency] = max_concurrency
self.cog: Optional[Cog] = None
self.guild_ids: Optional[Tuple[int, ...]] = None
self.auto_sync: bool = True
self._before_invoke: Optional[Hook] = None
self._after_invoke: Optional[Hook] = None
# this should copy all attributes that can be changed after instantiation via decorators
def _ensure_assignment_on_copy(self, other: AppCommandT) -> AppCommandT:
other._before_invoke = self._before_invoke
other._after_invoke = self._after_invoke
if self.checks != other.checks:
other.checks = self.checks.copy()
if self._buckets.valid and not other._buckets.valid:
other._buckets = self._buckets.copy()
if self._max_concurrency != other._max_concurrency:
# _max_concurrency won't be None at this point
other._max_concurrency = cast("MaxConcurrency", self._max_concurrency).copy()
if (
# see https://github.com/DisnakeDev/disnake/pull/678#discussion_r938113624:
# if these are not equal, then either `self` had a decorator, or `other` got a
# value from `*_command_attrs`; we only want to copy in the former case
self.body._default_member_permissions != other.body._default_member_permissions
and self.body._default_member_permissions is not None
):
other.body._default_member_permissions = self.body._default_member_permissions
if (
self.body.install_types != other.body.install_types
and self.body.install_types is not None # see above
):
other.body.install_types = ApplicationInstallTypes._from_value(
self.body.install_types.value
)
if (
self.body.contexts != other.body.contexts
and self.body.contexts is not None # see above
):
other.body.contexts = InteractionContextTypes._from_value(self.body.contexts.value)
try:
other.on_error = self.on_error
except AttributeError:
pass
return other
def copy(self: AppCommandT) -> AppCommandT:
"""Create a copy of this application command.
Returns
-------
:class:`InvokableApplicationCommand`
A new instance of this application command.
"""
copy = type(self)(self.callback, **self.__original_kwargs__)
return self._ensure_assignment_on_copy(copy)
def _update_copy(self: AppCommandT, kwargs: Dict[str, Any]) -> AppCommandT:
if kwargs:
kw = kwargs.copy()
kw.update(self.__original_kwargs__)
copy = type(self)(self.callback, **kw)
return self._ensure_assignment_on_copy(copy)
else:
return self.copy()
def _apply_guild_only(self) -> None:
# If we have a `GuildCommandInteraction` annotation, set `contexts` and `install_types` accordingly.
# This matches the old pre-user-apps behavior.
if self._guild_only:
# n.b. this overwrites any user-specified parameter
# FIXME(3.0): this should raise if these were set elsewhere (except `*_command_attrs`) already
self.body.contexts = InteractionContextTypes(guild=True)
self.body.install_types = ApplicationInstallTypes(guild=True)
def _apply_defaults(self, bot: InteractionBotBase) -> None:
self.body._default_install_types = bot._default_install_types
self.body._default_contexts = bot._default_contexts
@property
def dm_permission(self) -> bool:
""":class:`bool`: Whether this command can be used in DMs."""
return self.body.dm_permission
@property
def default_member_permissions(self) -> Optional[Permissions]:
"""Optional[:class:`.Permissions`]: The default required member permissions for this command.
A member must have *all* these permissions to be able to invoke the command in a guild.
This is a default value, the set of users/roles that may invoke this command can be
overridden by moderators on a guild-specific basis, disregarding this setting.
If ``None`` is returned, it means everyone can use the command by default.
If an empty :class:`.Permissions` object is returned (that is, all permissions set to ``False``),
this means no one can use the command.
.. versionadded:: 2.5
"""
return self.body.default_member_permissions
@property
def install_types(self) -> Optional[ApplicationInstallTypes]:
"""Optional[:class:`.ApplicationInstallTypes`]: The installation types
where the command is available. Only available for global commands.
.. versionadded:: 2.10
"""
return self.body.install_types
@property
def contexts(self) -> Optional[InteractionContextTypes]:
"""Optional[:class:`.InteractionContextTypes`]: The interaction contexts
where the command can be used. Only available for global commands.
.. versionadded:: 2.10
"""
return self.body.contexts
@property
def callback(self) -> CommandCallback:
return self._callback
def add_check(self, func: AppCheck) -> None:
"""Adds a check to the application command.
This is the non-decorator interface to :func:`.app_check`.
Parameters
----------
func
The function that will be used as a check.
"""
self.checks.append(func)
def remove_check(self, func: AppCheck) -> None:
"""Removes a check from the application command.
This function is idempotent and will not raise an exception
if the function is not in the command's checks.
Parameters
----------
func
The function to remove from the checks.
"""
try:
self.checks.remove(func)
except ValueError:
pass
async def __call__(
self, interaction: ApplicationCommandInteraction, *args: Any, **kwargs: Any
) -> Any:
"""|coro|
Calls the internal callback that the application command holds.
.. note::
This bypasses all mechanisms -- including checks, converters,
invoke hooks, cooldowns, etc. You must take care to pass
the proper arguments and types to this function.
"""
if self.cog is not None:
return await self.callback(self.cog, interaction, *args, **kwargs)
else:
return await self.callback(interaction, *args, **kwargs)
def _prepare_cooldowns(self, inter: ApplicationCommandInteraction) -> None:
if self._buckets.valid:
dt = inter.created_at
current = dt.replace(tzinfo=datetime.timezone.utc).timestamp()
bucket = self._buckets.get_bucket(inter, current) # type: ignore
if bucket is not None: # pyright: ignore[reportUnnecessaryComparison]
retry_after = bucket.update_rate_limit(current)
if retry_after:
raise CommandOnCooldown(bucket, retry_after, self._buckets.type) # type: ignore
async def prepare(self, inter: ApplicationCommandInteraction) -> None:
inter.application_command = self
if not await self.can_run(inter):
raise CheckFailure(f"The check functions for command {self.qualified_name!r} failed.")
if self._max_concurrency is not None:
await self._max_concurrency.acquire(inter) # type: ignore
try:
self._prepare_cooldowns(inter)
await self.call_before_hooks(inter)
except Exception:
if self._max_concurrency is not None:
await self._max_concurrency.release(inter) # type: ignore
raise
def is_on_cooldown(self, inter: ApplicationCommandInteraction) -> bool:
"""Checks whether the application command is currently on cooldown.
Parameters
----------
inter: :class:`.ApplicationCommandInteraction`
The interaction with the application command currently being invoked.
Returns
-------
:class:`bool`
A boolean indicating if the application command is on cooldown.
"""
if not self._buckets.valid:
return False
bucket = self._buckets.get_bucket(inter) # type: ignore
dt = inter.created_at
current = dt.replace(tzinfo=datetime.timezone.utc).timestamp()
return bucket.get_tokens(current) == 0
def reset_cooldown(self, inter: ApplicationCommandInteraction) -> None:
"""Resets the cooldown on this application command.
Parameters
----------
inter: :class:`.ApplicationCommandInteraction`
The interaction with this application command
"""
if self._buckets.valid:
bucket = self._buckets.get_bucket(inter) # type: ignore
bucket.reset()
def get_cooldown_retry_after(self, inter: ApplicationCommandInteraction) -> float:
"""Retrieves the amount of seconds before this application command can be tried again.
Parameters
----------
inter: :class:`.ApplicationCommandInteraction`
The interaction with this application command.
Returns
-------
:class:`float`
The amount of time left on this command's cooldown in seconds.
If this is ``0.0`` then the command isn't on cooldown.
"""
if self._buckets.valid:
bucket = self._buckets.get_bucket(inter) # type: ignore
dt = inter.created_at
current = dt.replace(tzinfo=datetime.timezone.utc).timestamp()
return bucket.get_retry_after(current)
return 0.0
# This method isn't really usable in this class, but it's usable in subclasses.
async def invoke(self, inter: ApplicationCommandInteraction, *args: Any, **kwargs: Any) -> None:
await self.prepare(inter)
try:
await self(inter, *args, **kwargs)
except CommandError:
inter.command_failed = True
raise
except asyncio.CancelledError:
inter.command_failed = True
return
except Exception as exc:
inter.command_failed = True
raise CommandInvokeError(exc) from exc
finally:
if self._max_concurrency is not None:
await self._max_concurrency.release(inter) # type: ignore
await self.call_after_hooks(inter)
def error(self, coro: ErrorT) -> ErrorT:
"""A decorator that registers a coroutine as a local error handler.
A local error handler is an error event limited to a single application command.
Parameters
----------
coro: :ref:`coroutine <coroutine>`
The coroutine to register as the local error handler.
Raises
------
TypeError
The coroutine passed is not actually a coroutine.
"""
if not iscoroutinefunction(coro):
raise TypeError("The error handler must be a coroutine.")
self.on_error: Error = coro
return coro
def has_error_handler(self) -> bool:
"""Checks whether the application command has an error handler registered."""
return hasattr(self, "on_error")
async def _call_local_error_handler(
self, inter: ApplicationCommandInteraction, error: CommandError
) -> Any:
if not self.has_error_handler():
return
injected = wrap_callback(self.on_error)
if self.cog is not None:
return await injected(self.cog, inter, error)
else:
return await injected(inter, error)
async def _call_external_error_handlers(
self, inter: ApplicationCommandInteraction, error: CommandError
) -> None:
"""Overridden in subclasses"""
raise error
async def dispatch_error(
self, inter: ApplicationCommandInteraction, error: CommandError
) -> None:
if not await self._call_local_error_handler(inter, error):
await self._call_external_error_handlers(inter, error)
async def call_before_hooks(self, inter: ApplicationCommandInteraction) -> None:
# now that we're done preparing we can call the pre-command hooks
# first, call the command local hook:
cog = self.cog
if self._before_invoke is not None:
# should be cog if @commands.before_invoke is used
instance = getattr(self._before_invoke, "__self__", cog)
# __self__ only exists for methods, not functions
# however, if @command.before_invoke is used, it will be a function
if instance:
await self._before_invoke(instance, inter) # type: ignore
else:
await self._before_invoke(inter) # type: ignore
if inter.data.type is ApplicationCommandType.chat_input:
partial_attr_name = "slash_command"
elif inter.data.type is ApplicationCommandType.user:
partial_attr_name = "user_command"
elif inter.data.type is ApplicationCommandType.message:
partial_attr_name = "message_command"
else:
return
# call the cog local hook if applicable:
if cog is not None:
meth = getattr(cog, f"cog_before_{partial_attr_name}_invoke", None)
hook = _get_overridden_method(meth)
if hook is not None:
await hook(inter)
# call the bot global hook if necessary
hook = getattr(inter.bot, f"_before_{partial_attr_name}_invoke", None)
if hook is not None:
await hook(inter)
async def call_after_hooks(self, inter: ApplicationCommandInteraction) -> None:
cog = self.cog
if self._after_invoke is not None:
instance = getattr(self._after_invoke, "__self__", cog)
if instance:
await self._after_invoke(instance, inter) # type: ignore
else:
await self._after_invoke(inter) # type: ignore
if inter.data.type is ApplicationCommandType.chat_input:
partial_attr_name = "slash_command"
elif inter.data.type is ApplicationCommandType.user:
partial_attr_name = "user_command"
elif inter.data.type is ApplicationCommandType.message:
partial_attr_name = "message_command"
else:
return
# call the cog local hook if applicable:
if cog is not None:
meth = getattr(cog, f"cog_after_{partial_attr_name}_invoke", None)
hook = _get_overridden_method(meth)
if hook is not None:
await hook(inter)
# call the bot global hook if necessary
hook = getattr(inter.bot, f"_after_{partial_attr_name}_invoke", None)
if hook is not None:
await hook(inter)
def before_invoke(self, coro: HookT) -> HookT:
"""A decorator that registers a coroutine as a pre-invoke hook.
A pre-invoke hook is called directly before the command is called.
This pre-invoke hook takes a sole parameter, a :class:`.ApplicationCommandInteraction`.
Parameters
----------
coro: :ref:`coroutine <coroutine>`
The coroutine to register as the pre-invoke hook.
Raises
------
TypeError
The coroutine passed is not actually a coroutine.
"""
if not iscoroutinefunction(coro):
raise TypeError("The pre-invoke hook must be a coroutine.")
self._before_invoke = coro
return coro
def after_invoke(self, coro: HookT) -> HookT:
"""A decorator that registers a coroutine as a post-invoke hook.
A post-invoke hook is called directly after the command is called.
This post-invoke hook takes a sole parameter, a :class:`.ApplicationCommandInteraction`.
Parameters
----------
coro: :ref:`coroutine <coroutine>`
The coroutine to register as the post-invoke hook.
Raises
------
TypeError
The coroutine passed is not actually a coroutine.
"""
if not iscoroutinefunction(coro):
raise TypeError("The post-invoke hook must be a coroutine.")
self._after_invoke = coro
return coro
@property
def cog_name(self) -> Optional[str]:
"""Optional[:class:`str`]: The name of the cog this application command belongs to, if any."""
return type(self.cog).__cog_name__ if self.cog is not None else None
async def can_run(self, inter: ApplicationCommandInteraction) -> bool:
"""|coro|
Checks if the command can be executed by checking all the predicates
inside the :attr:`~Command.checks` attribute.
Parameters
----------
inter: :class:`.ApplicationCommandInteraction`
The interaction with the application command currently being invoked.
Raises
------
:class:`CommandError`
Any application command error that was raised during a check call will be propagated
by this function.
Returns
-------
:class:`bool`
A boolean indicating if the application command can be invoked.
"""
original = inter.application_command
inter.application_command = self
if inter.data.type is ApplicationCommandType.chat_input:
partial_attr_name = "slash_command"
elif inter.data.type is ApplicationCommandType.user:
partial_attr_name = "user_command"
elif inter.data.type is ApplicationCommandType.message:
partial_attr_name = "message_command"
else:
return True
try:
if inter.bot and not await inter.bot.application_command_can_run(inter):
raise CheckFailure(
f"The global check functions for command {self.qualified_name} failed."
)
cog = self.cog
if cog is not None:
meth = getattr(cog, f"cog_{partial_attr_name}_check", None)
local_check = _get_overridden_method(meth)
if local_check is not None:
ret = await maybe_coroutine(local_check, inter)
if not ret:
return False
predicates = self.checks
if not predicates:
# since we have no checks, then we just return True.
return True
return await async_all(predicate(inter) for predicate in predicates) # type: ignore
finally:
inter.application_command = original
@overload
@_generated
def default_member_permissions(
value: int = 0,
*,
add_reactions: bool = ...,
administrator: bool = ...,
attach_files: bool = ...,
ban_members: bool = ...,
change_nickname: bool = ...,
connect: bool = ...,
create_events: bool = ...,
create_forum_threads: bool = ...,
create_guild_expressions: bool = ...,
create_instant_invite: bool = ...,
create_private_threads: bool = ...,
create_public_threads: bool = ...,
deafen_members: bool = ...,
embed_links: bool = ...,
external_emojis: bool = ...,
external_stickers: bool = ...,
kick_members: bool = ...,
manage_channels: bool = ...,
manage_emojis: bool = ...,
manage_emojis_and_stickers: bool = ...,
manage_events: bool = ...,
manage_guild: bool = ...,
manage_guild_expressions: bool = ...,
manage_messages: bool = ...,
manage_nicknames: bool = ...,
manage_permissions: bool = ...,
manage_roles: bool = ...,
manage_threads: bool = ...,
manage_webhooks: bool = ...,
mention_everyone: bool = ...,
moderate_members: bool = ...,
move_members: bool = ...,
mute_members: bool = ...,
pin_messages: bool = ...,
priority_speaker: bool = ...,
read_message_history: bool = ...,
read_messages: bool = ...,
request_to_speak: bool = ...,
send_messages: bool = ...,
send_messages_in_threads: bool = ...,
send_polls: bool = ...,
send_tts_messages: bool = ...,
send_voice_messages: bool = ...,
speak: bool = ...,
start_embedded_activities: bool = ...,
stream: bool = ...,
use_application_commands: bool = ...,
use_embedded_activities: bool = ...,
use_external_apps: bool = ...,
use_external_emojis: bool = ...,
use_external_sounds: bool = ...,
use_external_stickers: bool = ...,
use_slash_commands: bool = ...,
use_soundboard: bool = ...,
use_voice_activation: bool = ...,
view_audit_log: bool = ...,
view_channel: bool = ...,
view_creator_monetization_analytics: bool = ...,
view_guild_insights: bool = ...,
) -> Callable[[T], T]: ...
@overload
@_generated
def default_member_permissions(
value: int = 0,
) -> Callable[[T], T]: ...
@_overload_with_permissions
def default_member_permissions(value: int = 0, **permissions: bool) -> Callable[[T], T]:
"""A decorator that sets default required member permissions for the application command.
Unlike :func:`~.has_permissions`, this decorator does not add any checks.
Instead, it prevents the command from being run by members without *all* required permissions,
if not overridden by moderators on a guild-specific basis.
See also the ``default_member_permissions`` parameter for application command decorators.
.. note::
This does not work with slash subcommands/groups.
.. versionadded:: 2.5
Example
-------
This would only allow members with :attr:`~.Permissions.manage_messages` *and*
:attr:`~.Permissions.view_audit_log` permissions to use the command by default,
however moderators can override this and allow/disallow specific users and
roles to use the command in their guilds regardless of this setting.
.. code-block:: python3
@bot.slash_command()
@commands.default_member_permissions(manage_messages=True, view_audit_log=True)
async def purge(inter, num: int):
...
Parameters
----------
value: :class:`int`
A raw permission bitfield of an integer representing the required permissions.
May be used instead of specifying kwargs.
**permissions: bool
The required permissions for a command. A member must have *all* these
permissions to be able to invoke the command.
Setting a permission to ``False`` does not affect the result.
"""
if isinstance(value, bool):
raise TypeError("`value` cannot be a bool value")
perms_value = Permissions(value, **permissions).value
def decorator(func: T) -> T:
from .slash_core import SubCommand, SubCommandGroup
if isinstance(func, InvokableApplicationCommand):
if isinstance(func, (SubCommand, SubCommandGroup)):
raise TypeError(
"Cannot set `default_member_permissions` on subcommands or subcommand groups"
)
if func.body._default_member_permissions is not None:
raise ValueError(
"Cannot set `default_member_permissions` in both parameter and decorator"
)
func.body._default_member_permissions = perms_value
else:
func.__default_member_permissions__ = perms_value # type: ignore
return func
return decorator
def install_types(*, guild: bool = False, user: bool = False) -> Callable[[T], T]:
"""A decorator that sets the installation types where the
application command is available.
See also the ``install_types`` parameter for application command decorators.
.. note::
This does not work with slash subcommands/groups.
.. versionadded:: 2.10
Parameters
----------
**params: bool
The installation types; see :class:`.ApplicationInstallTypes`.
Setting a parameter to ``False`` does not affect the result.
"""
def decorator(func: T) -> T:
from .slash_core import SubCommand, SubCommandGroup
install_types = ApplicationInstallTypes(guild=guild, user=user)
if isinstance(func, InvokableApplicationCommand):
if isinstance(func, (SubCommand, SubCommandGroup)):
raise TypeError("Cannot set `install_types` on subcommands or subcommand groups")
# special case - don't overwrite if `_guild_only` was set, since that takes priority
if not func._guild_only:
if func.body.install_types is not None:
raise ValueError("Cannot set `install_types` in both parameter and decorator")
func.body.install_types = install_types
else:
func.__install_types__ = install_types # type: ignore
return func
return decorator
def contexts(
*, guild: bool = False, bot_dm: bool = False, private_channel: bool = False
) -> Callable[[T], T]:
"""A decorator that sets the interaction contexts where the application command can be used.
See also the ``contexts`` parameter for application command decorators.
.. note::
This does not work with slash subcommands/groups.
.. versionadded:: 2.10
Parameters
----------
**params: bool
The interaction contexts; see :class:`.InteractionContextTypes`.
Setting a parameter to ``False`` does not affect the result.
"""
def decorator(func: T) -> T:
from .slash_core import SubCommand, SubCommandGroup
contexts = InteractionContextTypes(
guild=guild, bot_dm=bot_dm, private_channel=private_channel
)
if isinstance(func, InvokableApplicationCommand):
if isinstance(func, (SubCommand, SubCommandGroup)):
raise TypeError("Cannot set `contexts` on subcommands or subcommand groups")
# special case - don't overwrite if `_guild_only` was set, since that takes priority
if not func._guild_only:
if func.body._dm_permission is not None:
raise ValueError(
"Cannot use both `dm_permission` and `contexts` at the same time"
)
if func.body.contexts is not None:
raise ValueError("Cannot set `contexts` in both parameter and decorator")
func.body.contexts = contexts
else:
func.__contexts__ = contexts # type: ignore
return func
return decorator

View File

@@ -0,0 +1,563 @@
# SPDX-License-Identifier: MIT
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, Set, Union
import disnake
from .bot_base import BotBase, when_mentioned, when_mentioned_or
from .interaction_bot_base import InteractionBotBase
if TYPE_CHECKING:
import asyncio
import aiohttp
from typing_extensions import Self
from disnake.activity import BaseActivity
from disnake.client import GatewayParams
from disnake.enums import Status
from disnake.flags import (
ApplicationInstallTypes,
Intents,
InteractionContextTypes,
MemberCacheFlags,
)
from disnake.i18n import LocalizationProtocol
from disnake.mentions import AllowedMentions
from disnake.message import Message
from ._types import MaybeCoro
from .bot_base import PrefixType
from .flags import CommandSyncFlags
from .help import HelpCommand
__all__ = (
"when_mentioned",
"when_mentioned_or",
"BotBase",
"Bot",
"InteractionBot",
"AutoShardedBot",
"AutoShardedInteractionBot",
)
MISSING: Any = disnake.utils.MISSING
class Bot(BotBase, InteractionBotBase, disnake.Client):
"""Represents a discord bot.
This class is a subclass of :class:`disnake.Client` and as a result
anything that you can do with a :class:`disnake.Client` you can do with
this bot.
This class also subclasses :class:`.GroupMixin` to provide the functionality
to manage commands.
Parameters
----------
test_guilds: List[:class:`int`]
The list of IDs of the guilds where you're going to test your application commands.
Defaults to ``None``, which means global registration of commands across
all guilds.
.. versionadded:: 2.1
command_sync_flags: :class:`.CommandSyncFlags`
The command sync flags for the session. This is a way of
controlling when and how application commands will be synced with the Discord API.
If not given, defaults to :func:`CommandSyncFlags.default`.
.. versionadded:: 2.7
sync_commands: :class:`bool`
Whether to enable automatic synchronization of application commands in your code.
Defaults to ``True``, which means that commands in API are automatically synced
with the commands in your code.
.. versionadded:: 2.1
.. deprecated:: 2.7
Replaced with ``command_sync_flags``.
sync_commands_on_cog_unload: :class:`bool`
Whether to sync the application commands on cog unload / reload. Defaults to ``True``.
.. versionadded:: 2.1
.. deprecated:: 2.7
Replaced with ``command_sync_flags``.
sync_commands_debug: :class:`bool`
Whether to always show sync debug logs (uses ``INFO`` log level if it's enabled, prints otherwise).
If disabled, uses the default ``DEBUG`` log level which isn't shown unless the log level is changed manually.
Useful for tracking the commands being registered in the API.
Defaults to ``False``.
.. versionadded:: 2.1
.. versionchanged:: 2.4
Changes the log level of corresponding messages from ``DEBUG`` to ``INFO`` or ``print``\\s them,
instead of controlling whether they are enabled at all.
.. deprecated:: 2.7
Replaced with ``command_sync_flags``.
localization_provider: :class:`.LocalizationProtocol`
An implementation of :class:`.LocalizationProtocol` to use for localization of
application commands.
If not provided, the default :class:`.LocalizationStore` implementation is used.
.. versionadded:: 2.5
strict_localization: :class:`bool`
Whether to raise an exception when localizations for a specific key couldn't be found.
This is mainly useful for testing/debugging, consider disabling this eventually
as missing localized names will automatically fall back to the default/base name without it.
Only applicable if the ``localization_provider`` parameter is not provided.
Defaults to ``False``.
.. versionadded:: 2.5
default_install_types: Optional[:class:`.ApplicationInstallTypes`]
The default installation types where application commands will be available.
This applies to all commands added either through the respective decorators
or directly using :meth:`.add_slash_command` (etc.).
Any value set directly on the command, e.g. using the :func:`.install_types` decorator,
the ``install_types`` parameter, ``slash_command_attrs`` (etc.) at the cog-level, or from
the :class:`.GuildCommandInteraction` annotation, takes precedence over this default.
.. versionadded:: 2.10
default_contexts: Optional[:class:`.InteractionContextTypes`]
The default contexts where application commands will be usable.
This applies to all commands added either through the respective decorators
or directly using :meth:`.add_slash_command` (etc.).
Any value set directly on the command, e.g. using the :func:`.contexts` decorator,
the ``contexts`` parameter, ``slash_command_attrs`` (etc.) at the cog-level, or from
the :class:`.GuildCommandInteraction` annotation, takes precedence over this default.
.. versionadded:: 2.10
Attributes
----------
command_prefix
The command prefix is what the message content must contain initially
to have a command invoked. This prefix could either be a string to
indicate what the prefix should be, or a callable that takes in the bot
as its first parameter and :class:`disnake.Message` as its second
parameter and returns the prefix. This is to facilitate "dynamic"
command prefixes. This callable can be either a regular function or
a coroutine.
An empty string as the prefix always matches, enabling prefix-less
command invocation. While this may be useful in DMs it should be avoided
in servers, as it's likely to cause performance issues and unintended
command invocations.
The command prefix could also be an iterable of strings indicating that
multiple checks for the prefix should be used and the first one to
match will be the invocation prefix. You can get this prefix via
:attr:`.Context.prefix`. To avoid confusion empty iterables are not
allowed.
If the prefix is ``None``, the bot won't listen to any prefixes, and prefix
commands will not be processed. If you don't need prefix commands, consider
using :class:`InteractionBot` or :class:`AutoShardedInteractionBot` instead,
which are drop-in replacements, just without prefix command support.
This can be provided as a parameter at creation.
.. note::
When passing multiple prefixes be careful to not pass a prefix
that matches a longer prefix occurring later in the sequence. For
example, if the command prefix is ``('!', '!?')`` the ``'!?'``
prefix will never be matched to any message as the previous one
matches messages starting with ``!?``. This is especially important
when passing an empty string, it should always be last as no prefix
after it will be matched.
case_insensitive: :class:`bool`
Whether the commands should be case insensitive. Defaults to ``False``. This
attribute does not carry over to groups. You must set it to every group if
you require group commands to be case insensitive as well.
This can be provided as a parameter at creation.
description: :class:`str`
The content prefixed into the default help message.
This can be provided as a parameter at creation.
help_command: Optional[:class:`.HelpCommand`]
The help command implementation to use. This can be dynamically
set at runtime. To remove the help command pass ``None``. For more
information on implementing a help command, see :ref:`ext_commands_api_help_commands`.
This can be provided as a parameter at creation.
owner_id: Optional[:class:`int`]
The ID of the user that owns the bot. If this is not set and is then queried via
:meth:`.is_owner` then it is fetched automatically using
:meth:`~.Bot.application_info`.
This can be provided as a parameter at creation.
owner_ids: Optional[Collection[:class:`int`]]
The IDs of the users that own the bot. This is similar to :attr:`owner_id`.
If this is not set and the application is team based, then it is
fetched automatically using :meth:`~.Bot.application_info` (taking team roles into account).
For performance reasons it is recommended to use a :class:`set`
for the collection. You cannot set both ``owner_id`` and ``owner_ids``.
This can be provided as a parameter at creation.
.. versionadded:: 1.3
strip_after_prefix: :class:`bool`
Whether to strip whitespace characters after encountering the command
prefix. This allows for ``! hello`` and ``!hello`` to both work if
the ``command_prefix`` is set to ``!``. Defaults to ``False``.
This can be provided as a parameter at creation.
.. versionadded:: 1.7
reload: :class:`bool`
Whether to enable automatic extension reloading on file modification for debugging.
Whenever you save an extension with reloading enabled the file will be automatically
reloaded for you so you do not have to reload the extension manually. Defaults to ``False``
This can be provided as a parameter at creation.
.. versionadded:: 2.1
i18n: :class:`.LocalizationProtocol`
An implementation of :class:`.LocalizationProtocol` used for localization of
application commands.
.. versionadded:: 2.5
"""
if TYPE_CHECKING:
def __init__(
self,
command_prefix: Optional[
Union[PrefixType, Callable[[Self, Message], MaybeCoro[PrefixType]]]
] = None,
help_command: Optional[HelpCommand] = ...,
description: Optional[str] = None,
*,
strip_after_prefix: bool = False,
owner_id: Optional[int] = None,
owner_ids: Optional[Set[int]] = None,
reload: bool = False,
case_insensitive: bool = False,
command_sync_flags: CommandSyncFlags = ...,
sync_commands: bool = ...,
sync_commands_debug: bool = ...,
sync_commands_on_cog_unload: bool = ...,
test_guilds: Optional[Sequence[int]] = None,
default_install_types: Optional[ApplicationInstallTypes] = None,
default_contexts: Optional[InteractionContextTypes] = None,
asyncio_debug: bool = False,
loop: Optional[asyncio.AbstractEventLoop] = None,
shard_id: Optional[int] = None,
shard_count: Optional[int] = None,
enable_debug_events: bool = False,
enable_gateway_error_handler: bool = True,
gateway_params: Optional[GatewayParams] = None,
connector: Optional[aiohttp.BaseConnector] = None,
proxy: Optional[str] = None,
proxy_auth: Optional[aiohttp.BasicAuth] = None,
assume_unsync_clock: bool = True,
max_messages: Optional[int] = 1000,
application_id: Optional[int] = None,
heartbeat_timeout: float = 60.0,
guild_ready_timeout: float = 2.0,
allowed_mentions: Optional[AllowedMentions] = None,
activity: Optional[BaseActivity] = None,
status: Optional[Union[Status, str]] = None,
intents: Optional[Intents] = None,
chunk_guilds_at_startup: Optional[bool] = None,
member_cache_flags: Optional[MemberCacheFlags] = None,
localization_provider: Optional[LocalizationProtocol] = None,
strict_localization: bool = False,
) -> None: ...
class AutoShardedBot(BotBase, InteractionBotBase, disnake.AutoShardedClient):
"""Similar to :class:`.Bot`, except that it is inherited from
:class:`disnake.AutoShardedClient` instead.
"""
if TYPE_CHECKING:
def __init__(
self,
command_prefix: Optional[
Union[PrefixType, Callable[[Self, Message], MaybeCoro[PrefixType]]]
] = None,
help_command: Optional[HelpCommand] = ...,
description: Optional[str] = None,
*,
strip_after_prefix: bool = False,
owner_id: Optional[int] = None,
owner_ids: Optional[Set[int]] = None,
reload: bool = False,
case_insensitive: bool = False,
command_sync_flags: CommandSyncFlags = ...,
sync_commands: bool = ...,
sync_commands_debug: bool = ...,
sync_commands_on_cog_unload: bool = ...,
test_guilds: Optional[Sequence[int]] = None,
default_install_types: Optional[ApplicationInstallTypes] = None,
default_contexts: Optional[InteractionContextTypes] = None,
asyncio_debug: bool = False,
loop: Optional[asyncio.AbstractEventLoop] = None,
shard_ids: Optional[List[int]] = None, # instead of shard_id
shard_count: Optional[int] = None,
enable_debug_events: bool = False,
enable_gateway_error_handler: bool = True,
gateway_params: Optional[GatewayParams] = None,
connector: Optional[aiohttp.BaseConnector] = None,
proxy: Optional[str] = None,
proxy_auth: Optional[aiohttp.BasicAuth] = None,
assume_unsync_clock: bool = True,
max_messages: Optional[int] = 1000,
application_id: Optional[int] = None,
heartbeat_timeout: float = 60.0,
guild_ready_timeout: float = 2.0,
allowed_mentions: Optional[AllowedMentions] = None,
activity: Optional[BaseActivity] = None,
status: Optional[Union[Status, str]] = None,
intents: Optional[Intents] = None,
chunk_guilds_at_startup: Optional[bool] = None,
member_cache_flags: Optional[MemberCacheFlags] = None,
localization_provider: Optional[LocalizationProtocol] = None,
strict_localization: bool = False,
) -> None: ...
class InteractionBot(InteractionBotBase, disnake.Client):
"""Represents a discord bot for application commands only.
This class is a subclass of :class:`disnake.Client` and as a result
anything that you can do with a :class:`disnake.Client` you can do with
this bot.
This class also subclasses InteractionBotBase to provide the functionality
to manage application commands.
Parameters
----------
test_guilds: List[:class:`int`]
The list of IDs of the guilds where you're going to test your application commands.
Defaults to ``None``, which means global registration of commands across
all guilds.
.. versionadded:: 2.1
command_sync_flags: :class:`.CommandSyncFlags`
The command sync flags for the session. This is a way of
controlling when and how application commands will be synced with the Discord API.
If not given, defaults to :func:`CommandSyncFlags.default`.
.. versionadded:: 2.7
sync_commands: :class:`bool`
Whether to enable automatic synchronization of application commands in your code.
Defaults to ``True``, which means that commands in API are automatically synced
with the commands in your code.
.. versionadded:: 2.1
.. deprecated:: 2.7
Replaced with ``command_sync_flags``.
sync_commands_on_cog_unload: :class:`bool`
Whether to sync the application commands on cog unload / reload. Defaults to ``True``.
.. versionadded:: 2.1
.. deprecated:: 2.7
Replaced with ``command_sync_flags``.
sync_commands_debug: :class:`bool`
Whether to always show sync debug logs (uses ``INFO`` log level if it's enabled, prints otherwise).
If disabled, uses the default ``DEBUG`` log level which isn't shown unless the log level is changed manually.
Useful for tracking the commands being registered in the API.
Defaults to ``False``.
.. versionadded:: 2.1
.. versionchanged:: 2.4
Changes the log level of corresponding messages from ``DEBUG`` to ``INFO`` or ``print``\\s them,
instead of controlling whether they are enabled at all.
.. deprecated:: 2.7
Replaced with ``command_sync_flags``.
localization_provider: :class:`.LocalizationProtocol`
An implementation of :class:`.LocalizationProtocol` to use for localization of
application commands.
If not provided, the default :class:`.LocalizationStore` implementation is used.
.. versionadded:: 2.5
strict_localization: :class:`bool`
Whether to raise an exception when localizations for a specific key couldn't be found.
This is mainly useful for testing/debugging, consider disabling this eventually
as missing localized names will automatically fall back to the default/base name without it.
Only applicable if the ``localization_provider`` parameter is not provided.
Defaults to ``False``.
.. versionadded:: 2.5
default_install_types: Optional[:class:`.ApplicationInstallTypes`]
The default installation types where application commands will be available.
This applies to all commands added either through the respective decorators
or directly using :meth:`.add_slash_command` (etc.).
Any value set directly on the command, e.g. using the :func:`.install_types` decorator,
the ``install_types`` parameter, ``slash_command_attrs`` (etc.) at the cog-level, or from
the :class:`.GuildCommandInteraction` annotation, takes precedence over this default.
.. versionadded:: 2.10
default_contexts: Optional[:class:`.InteractionContextTypes`]
The default contexts where application commands will be usable.
This applies to all commands added either through the respective decorators
or directly using :meth:`.add_slash_command` (etc.).
Any value set directly on the command, e.g. using the :func:`.contexts` decorator,
the ``contexts`` parameter, ``slash_command_attrs`` (etc.) at the cog-level, or from
the :class:`.GuildCommandInteraction` annotation, takes precedence over this default.
.. versionadded:: 2.10
Attributes
----------
owner_id: Optional[:class:`int`]
The ID of the user that owns the bot. If this is not set and is then queried via
:meth:`.is_owner` then it is fetched automatically using
:meth:`~.Bot.application_info`.
This can be provided as a parameter at creation.
owner_ids: Optional[Collection[:class:`int`]]
The IDs of the users that own the bot. This is similar to :attr:`owner_id`.
If this is not set and the application is team based, then it is
fetched automatically using :meth:`~.Bot.application_info` (taking team roles into account).
For performance reasons it is recommended to use a :class:`set`
for the collection. You cannot set both ``owner_id`` and ``owner_ids``.
This can be provided as a parameter at creation.
reload: :class:`bool`
Whether to enable automatic extension reloading on file modification for debugging.
Whenever you save an extension with reloading enabled the file will be automatically
reloaded for you so you do not have to reload the extension manually. Defaults to ``False``
This can be provided as a parameter at creation.
.. versionadded:: 2.1
i18n: :class:`.LocalizationProtocol`
An implementation of :class:`.LocalizationProtocol` used for localization of
application commands.
.. versionadded:: 2.5
"""
if TYPE_CHECKING:
def __init__(
self,
*,
owner_id: Optional[int] = None,
owner_ids: Optional[Set[int]] = None,
reload: bool = False,
command_sync_flags: CommandSyncFlags = ...,
sync_commands: bool = ...,
sync_commands_debug: bool = ...,
sync_commands_on_cog_unload: bool = ...,
test_guilds: Optional[Sequence[int]] = None,
default_install_types: Optional[ApplicationInstallTypes] = None,
default_contexts: Optional[InteractionContextTypes] = None,
asyncio_debug: bool = False,
loop: Optional[asyncio.AbstractEventLoop] = None,
shard_id: Optional[int] = None,
shard_count: Optional[int] = None,
enable_debug_events: bool = False,
enable_gateway_error_handler: bool = True,
gateway_params: Optional[GatewayParams] = None,
connector: Optional[aiohttp.BaseConnector] = None,
proxy: Optional[str] = None,
proxy_auth: Optional[aiohttp.BasicAuth] = None,
assume_unsync_clock: bool = True,
max_messages: Optional[int] = 1000,
application_id: Optional[int] = None,
heartbeat_timeout: float = 60.0,
guild_ready_timeout: float = 2.0,
allowed_mentions: Optional[AllowedMentions] = None,
activity: Optional[BaseActivity] = None,
status: Optional[Union[Status, str]] = None,
intents: Optional[Intents] = None,
chunk_guilds_at_startup: Optional[bool] = None,
member_cache_flags: Optional[MemberCacheFlags] = None,
localization_provider: Optional[LocalizationProtocol] = None,
strict_localization: bool = False,
) -> None: ...
class AutoShardedInteractionBot(InteractionBotBase, disnake.AutoShardedClient):
"""Similar to :class:`.InteractionBot`, except that it is inherited from
:class:`disnake.AutoShardedClient` instead.
"""
if TYPE_CHECKING:
def __init__(
self,
*,
owner_id: Optional[int] = None,
owner_ids: Optional[Set[int]] = None,
reload: bool = False,
command_sync_flags: CommandSyncFlags = ...,
sync_commands: bool = ...,
sync_commands_debug: bool = ...,
sync_commands_on_cog_unload: bool = ...,
test_guilds: Optional[Sequence[int]] = None,
default_install_types: Optional[ApplicationInstallTypes] = None,
default_contexts: Optional[InteractionContextTypes] = None,
asyncio_debug: bool = False,
loop: Optional[asyncio.AbstractEventLoop] = None,
shard_ids: Optional[List[int]] = None, # instead of shard_id
shard_count: Optional[int] = None,
enable_debug_events: bool = False,
enable_gateway_error_handler: bool = True,
gateway_params: Optional[GatewayParams] = None,
connector: Optional[aiohttp.BaseConnector] = None,
proxy: Optional[str] = None,
proxy_auth: Optional[aiohttp.BasicAuth] = None,
assume_unsync_clock: bool = True,
max_messages: Optional[int] = 1000,
application_id: Optional[int] = None,
heartbeat_timeout: float = 60.0,
guild_ready_timeout: float = 2.0,
allowed_mentions: Optional[AllowedMentions] = None,
activity: Optional[BaseActivity] = None,
status: Optional[Union[Status, str]] = None,
intents: Optional[Intents] = None,
chunk_guilds_at_startup: Optional[bool] = None,
member_cache_flags: Optional[MemberCacheFlags] = None,
localization_provider: Optional[LocalizationProtocol] = None,
strict_localization: bool = False,
) -> None: ...

View File

@@ -0,0 +1,609 @@
# SPDX-License-Identifier: MIT
from __future__ import annotations
import collections
import collections.abc
import inspect
import logging
import sys
import traceback
import warnings
from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Optional, Type, TypeVar, Union
import disnake
from disnake.utils import iscoroutinefunction
from . import errors
from .common_bot_base import CommonBotBase
from .context import Context
from .core import GroupMixin
from .custom_warnings import MessageContentPrefixWarning
from .help import DefaultHelpCommand, HelpCommand
from .view import StringView
if TYPE_CHECKING:
from typing_extensions import Self
from disnake.message import Message
from ._types import Check, CoroFunc, MaybeCoro
__all__ = (
"when_mentioned",
"when_mentioned_or",
"BotBase",
)
MISSING: Any = disnake.utils.MISSING
T = TypeVar("T")
CFT = TypeVar("CFT", bound="CoroFunc")
CXT = TypeVar("CXT", bound="Context")
PrefixType = Union[str, Iterable[str]]
_log = logging.getLogger(__name__)
def when_mentioned(bot: BotBase, msg: Message) -> List[str]:
"""A callable that implements a command prefix equivalent to being mentioned.
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
"""
# bot.user will never be None when this is called
return [f"<@{bot.user.id}> ", f"<@!{bot.user.id}> "] # type: ignore
def when_mentioned_or(*prefixes: str) -> Callable[[BotBase, Message], List[str]]:
"""A callable that implements when mentioned or other prefixes provided.
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
Example
-------
.. code-block:: python3
bot = commands.Bot(command_prefix=commands.when_mentioned_or('!'))
.. note::
This callable returns another callable, so if this is done inside a custom
callable, you must call the returned callable, for example:
.. code-block:: python3
async def get_prefix(bot, message):
extras = await prefixes_for(message.guild) # returns a list
return commands.when_mentioned_or(*extras)(bot, message)
See Also
--------
:func:`.when_mentioned`
"""
def inner(bot: BotBase, msg: Message) -> List[str]:
r = list(prefixes)
r = when_mentioned(bot, msg) + r
return r
return inner
def _is_submodule(parent: str, child: str) -> bool:
return parent == child or child.startswith(parent + ".")
class _DefaultRepr:
def __repr__(self) -> str:
return "<default-help-command>"
_default: Any = _DefaultRepr()
class BotBase(CommonBotBase, GroupMixin):
def __init__(
self,
command_prefix: Optional[
Union[PrefixType, Callable[[Self, Message], MaybeCoro[PrefixType]]]
] = None,
help_command: Optional[HelpCommand] = _default,
description: Optional[str] = None,
*,
strip_after_prefix: bool = False,
**options: Any,
) -> None:
super().__init__(**options)
if not isinstance(self, disnake.Client):
raise RuntimeError("BotBase mixin must be used with disnake.Client") # noqa: TRY004
alternative = (
"AutoShardedInteractionBot"
if isinstance(self, disnake.AutoShardedClient)
else "InteractionBot"
)
if command_prefix is None:
disnake.utils.warn_deprecated(
"Using `command_prefix=None` is deprecated and will result in "
"an error in future versions. "
f"If you don't need any prefix functionality, consider using {alternative}.",
stacklevel=2,
)
elif (
# note: no need to check for empty iterables,
# as they won't be allowed by `get_prefix`
command_prefix is not when_mentioned and not self.intents.message_content
):
warnings.warn(
"Message Content intent is not enabled and a prefix is configured. "
"This may cause limited functionality for prefix commands. "
"If you want prefix commands, pass an intents object with message_content set to True. "
f"If you don't need any prefix functionality, consider using {alternative}. "
"Alternatively, set prefix to disnake.ext.commands.when_mentioned to silence this warning.",
MessageContentPrefixWarning,
stacklevel=2,
)
self.command_prefix = command_prefix
self._checks: List[Check] = []
self._check_once: List[Check] = []
self._before_invoke: Optional[CoroFunc] = None
self._after_invoke: Optional[CoroFunc] = None
self._help_command: Optional[HelpCommand] = None
self.description: str = inspect.cleandoc(description) if description else ""
self.strip_after_prefix: bool = strip_after_prefix
if help_command is _default:
self.help_command = DefaultHelpCommand()
else:
self.help_command = help_command
# internal helpers
async def on_command_error(self, context: Context, exception: errors.CommandError) -> None:
"""|coro|
The default command error handler provided by the bot.
This is for text commands only, and doesn't apply to application commands.
By default this prints to :data:`sys.stderr` however it could be
overridden to have a different implementation.
This only fires if you do not specify any listeners for command error.
"""
if self.extra_events.get("on_command_error", None):
return
command = context.command
if command and command.has_error_handler():
return
cog = context.cog
if cog and cog.has_error_handler():
return
print(f"Ignoring exception in command {context.command}:", file=sys.stderr)
traceback.print_exception(
type(exception), exception, exception.__traceback__, file=sys.stderr
)
# global check registration
def add_check(
self,
func: Check,
*,
call_once: bool = False,
) -> None:
"""Adds a global check to the bot.
This is for text commands only, and doesn't apply to application commands.
This is the non-decorator interface to :meth:`.check` and :meth:`.check_once`.
Parameters
----------
func
The function that was used as a global check.
call_once: :class:`bool`
If the function should only be called once per
:meth:`.invoke` call.
"""
if call_once:
self._check_once.append(func)
else:
self._checks.append(func)
def remove_check(
self,
func: Check,
*,
call_once: bool = False,
) -> None:
"""Removes a global check from the bot.
This is for text commands only, and doesn't apply to application commands.
This function is idempotent and will not raise an exception
if the function is not in the global checks.
Parameters
----------
func
The function to remove from the global checks.
call_once: :class:`bool`
If the function was added with ``call_once=True`` in
the :meth:`.Bot.add_check` call or using :meth:`.check_once`.
"""
check_list = self._check_once if call_once else self._checks
try:
check_list.remove(func)
except ValueError:
pass
def check(self, func: T) -> T:
"""A decorator that adds a global check to the bot.
This is for text commands only, and doesn't apply to application commands.
A global check is similar to a :func:`.check` that is applied
on a per command basis except it is run before any command checks
have been verified and applies to every command the bot has.
.. note::
This function can either be a regular function or a coroutine.
Similar to a command :func:`.check`\\, this takes a single parameter
of type :class:`.Context` and can only raise exceptions inherited from
:exc:`.CommandError`.
Example
-------
.. code-block:: python3
@bot.check
def check_commands(ctx):
return ctx.command.qualified_name in allowed_commands
"""
# T was used instead of Check to ensure the type matches on return
self.add_check(func) # type: ignore
return func
def check_once(self, func: CFT) -> CFT:
"""A decorator that adds a "call once" global check to the bot.
This is for text commands only, and doesn't apply to application commands.
Unlike regular global checks, this one is called only once
per :meth:`.invoke` call.
Regular global checks are called whenever a command is called
or :meth:`.Command.can_run` is called. This type of check
bypasses that and ensures that it's called only once, even inside
the default help command.
.. note::
When using this function the :class:`.Context` sent to a group subcommand
may only parse the parent command and not the subcommands due to it
being invoked once per :meth:`.Bot.invoke` call.
.. note::
This function can either be a regular function or a coroutine.
Similar to a command :func:`.check`\\, this takes a single parameter
of type :class:`.Context` and can only raise exceptions inherited from
:exc:`.CommandError`.
Example
-------
.. code-block:: python3
@bot.check_once
def whitelist(ctx):
return ctx.message.author.id in my_whitelist
"""
self.add_check(func, call_once=True)
return func
async def can_run(self, ctx: Context, *, call_once: bool = False) -> bool:
data = self._check_once if call_once else self._checks
if len(data) == 0:
return True
# type-checker doesn't distinguish between functions and methods
return await disnake.utils.async_all(f(ctx) for f in data) # type: ignore
def before_invoke(self, coro: CFT) -> CFT:
"""A decorator that registers a coroutine as a pre-invoke hook.
This is for text commands only, and doesn't apply to application commands.
A pre-invoke hook is called directly before the command is
called. This makes it a useful function to set up database
connections or any type of set up required.
This pre-invoke hook takes a sole parameter, a :class:`.Context`.
.. note::
The :meth:`~.Bot.before_invoke` and :meth:`~.Bot.after_invoke` hooks are
only called if all checks and argument parsing procedures pass
without error. If any check or argument parsing procedures fail
then the hooks are not called.
Parameters
----------
coro: :ref:`coroutine <coroutine>`
The coroutine to register as the pre-invoke hook.
Raises
------
TypeError
The coroutine passed is not actually a coroutine.
"""
if not iscoroutinefunction(coro):
raise TypeError("The pre-invoke hook must be a coroutine.")
self._before_invoke = coro
return coro
def after_invoke(self, coro: CFT) -> CFT:
"""A decorator that registers a coroutine as a post-invoke hook.
This is for text commands only, and doesn't apply to application commands.
A post-invoke hook is called directly after the command is
called. This makes it a useful function to clean-up database
connections or any type of clean up required.
This post-invoke hook takes a sole parameter, a :class:`.Context`.
.. note::
Similar to :meth:`~.Bot.before_invoke`\\, this is not called unless
checks and argument parsing procedures succeed. This hook is,
however, **always** called regardless of the internal command
callback raising an error (i.e. :exc:`.CommandInvokeError`\\).
This makes it ideal for clean-up scenarios.
Parameters
----------
coro: :ref:`coroutine <coroutine>`
The coroutine to register as the post-invoke hook.
Raises
------
TypeError
The coroutine passed is not actually a coroutine.
"""
if not iscoroutinefunction(coro):
raise TypeError("The post-invoke hook must be a coroutine.")
self._after_invoke = coro
return coro
# extensions
def _remove_module_references(self, name: str) -> None:
super()._remove_module_references(name)
# remove all the commands from the module
for cmd in self.all_commands.copy().values():
if cmd.module and _is_submodule(name, cmd.module):
if isinstance(cmd, GroupMixin):
cmd.recursively_remove_all_commands()
self.remove_command(cmd.name)
# help command stuff
@property
def help_command(self) -> Optional[HelpCommand]:
return self._help_command
@help_command.setter
def help_command(self, value: Optional[HelpCommand]) -> None:
if value is not None and not isinstance(value, HelpCommand):
raise TypeError("help_command must be a subclass of HelpCommand or None")
if self._help_command is not None:
self._help_command._remove_from_bot(self)
self._help_command = value
if value is not None:
value._add_to_bot(self)
# command processing
async def get_prefix(self, message: Message) -> Optional[Union[List[str], str]]:
"""|coro|
Retrieves the prefix the bot is listening to
with the message as a context.
Parameters
----------
message: :class:`disnake.Message`
The message context to get the prefix of.
Returns
-------
Optional[Union[List[:class:`str`], :class:`str`]]
A list of prefixes or a single prefix that the bot is
listening for. None if the bot isn't listening for prefixes.
"""
ret = self.command_prefix
if callable(ret):
ret = await disnake.utils.maybe_coroutine(ret, self, message)
if ret is None:
return None
if not isinstance(ret, str):
try:
ret = list(ret)
except TypeError:
# It's possible that a generator raised this exception. Don't
# replace it with our own error if that's the case.
if isinstance(ret, collections.abc.Iterable):
raise
raise TypeError(
"command_prefix must be plain string, iterable of strings, or callable "
f"returning either of these, not {ret.__class__.__name__}"
) from None
if not ret:
raise ValueError("Iterable command_prefix must contain at least one prefix")
return ret
async def get_context(self, message: Message, *, cls: Type[CXT] = Context) -> CXT:
"""|coro|
Returns the invocation context from the message.
This is a more low-level counter-part for :meth:`.process_commands`
to allow users more fine grained control over the processing.
The returned context is not guaranteed to be a valid invocation
context, :attr:`.Context.valid` must be checked to make sure it is.
If the context is not valid then it is not a valid candidate to be
invoked under :meth:`~.Bot.invoke`.
Parameters
----------
message: :class:`disnake.Message`
The message to get the invocation context from.
cls
The factory class that will be used to create the context.
By default, this is :class:`.Context`. Should a custom
class be provided, it must be similar enough to :class:`.Context`\'s
interface.
Returns
-------
:class:`.Context`
The invocation context. The type of this can change via the
``cls`` parameter.
"""
view = StringView(message.content)
ctx = cls(prefix=None, view=view, bot=self, message=message)
if message.author.id == self.user.id: # type: ignore
return ctx
prefix = await self.get_prefix(message)
invoked_prefix = prefix
if prefix is None:
return ctx
elif isinstance(prefix, str):
if not view.skip_string(prefix):
return ctx
else:
try:
# if the context class' __init__ consumes something from the view this
# will be wrong. That seems unreasonable though.
if message.content.startswith(tuple(prefix)):
invoked_prefix = disnake.utils.find(view.skip_string, prefix)
else:
return ctx
except TypeError:
if not isinstance(prefix, list):
raise TypeError(
"get_prefix must return either a string or a list of string, "
f"not {prefix.__class__.__name__}"
) from None
# It's possible a bad command_prefix got us here.
for value in prefix:
if not isinstance(value, str):
raise TypeError(
"Iterable command_prefix or list returned from get_prefix must "
f"contain only strings, not {value.__class__.__name__}"
) from None
# Getting here shouldn't happen
raise
if self.strip_after_prefix:
view.skip_ws()
invoker = view.get_word()
ctx.invoked_with = invoker
# type-checker fails to narrow invoked_prefix type.
ctx.prefix = invoked_prefix # type: ignore
ctx.command = self.all_commands.get(invoker)
return ctx
async def invoke(self, ctx: Context) -> None:
"""|coro|
Invokes the command given under the invocation context and
handles all the internal event dispatch mechanisms.
Parameters
----------
ctx: :class:`.Context`
The invocation context to invoke.
"""
if ctx.command is not None:
self.dispatch("command", ctx)
try:
if await self.can_run(ctx, call_once=True):
await ctx.command.invoke(ctx)
else:
raise errors.CheckFailure("The global check once functions failed.")
except errors.CommandError as exc:
await ctx.command.dispatch_error(ctx, exc)
else:
self.dispatch("command_completion", ctx)
elif ctx.invoked_with:
exc = errors.CommandNotFound(f'Command "{ctx.invoked_with}" is not found')
self.dispatch("command_error", ctx, exc)
async def process_commands(self, message: Message) -> None:
"""|coro|
This function processes the commands that have been registered
to the bot and other groups. Without this coroutine, none of the
commands will be triggered.
By default, this coroutine is called inside the :func:`.on_message`
event. If you choose to override the :func:`.on_message` event, then
you should invoke this coroutine as well.
This is built using other low level tools, and is equivalent to a
call to :meth:`~.Bot.get_context` followed by a call to :meth:`~.Bot.invoke`.
This also checks if the message's author is a bot and doesn't
call :meth:`~.Bot.get_context` or :meth:`~.Bot.invoke` if so.
Parameters
----------
message: :class:`disnake.Message`
The message to process commands for.
"""
if message.author.bot:
return
ctx = await self.get_context(message)
await self.invoke(ctx)
async def on_message(self, message) -> None:
await self.process_commands(message)

View File

@@ -0,0 +1,899 @@
# SPDX-License-Identifier: MIT
from __future__ import annotations
import inspect
import logging
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Dict,
Generator,
List,
Optional,
Tuple,
Type,
Union,
)
import disnake
import disnake.utils
from disnake.enums import Event
from ._types import _BaseCommand
from .base_core import InvokableApplicationCommand
from .ctx_menus_core import InvokableMessageCommand, InvokableUserCommand
from .slash_core import InvokableSlashCommand
if TYPE_CHECKING:
from typing_extensions import Self
from disnake.interactions import ApplicationCommandInteraction
from ._types import FuncT, MaybeCoro
from .bot import AutoShardedBot, AutoShardedInteractionBot, Bot, InteractionBot
from .context import Context
from .core import Command
AnyBot = Union[Bot, AutoShardedBot, InteractionBot, AutoShardedInteractionBot]
__all__ = (
"CogMeta",
"Cog",
)
MISSING: Any = disnake.utils.MISSING
_log = logging.getLogger(__name__)
def _cog_special_method(func: FuncT) -> FuncT:
func.__cog_special_method__ = None
return func
class CogMeta(type):
"""A metaclass for defining a cog.
Note that you should probably not use this directly. It is exposed
purely for documentation purposes along with making custom metaclasses to intermix
with other metaclasses such as the :class:`abc.ABCMeta` metaclass.
For example, to create an abstract cog mixin class, the following would be done.
.. code-block:: python3
import abc
class CogABCMeta(commands.CogMeta, abc.ABCMeta):
pass
class SomeMixin(metaclass=abc.ABCMeta):
pass
class SomeCogMixin(SomeMixin, commands.Cog, metaclass=CogABCMeta):
pass
.. note::
When passing an attribute of a metaclass that is documented below, note
that you must pass it as a keyword-only argument to the class creation
like the following example:
.. code-block:: python3
class MyCog(commands.Cog, name='My Cog'):
pass
Attributes
----------
name: :class:`str`
The cog name. By default, it is the name of the class with no modification.
description: :class:`str`
The cog description. By default, it is the cleaned docstring of the class.
.. versionadded:: 1.6
command_attrs: Dict[:class:`str`, Any]
A list of attributes to apply to every command inside this cog. The dictionary
is passed into the :class:`Command` options at ``__init__``.
If you specify attributes inside the command attribute in the class, it will
override the one specified inside this attribute. For example:
.. code-block:: python3
class MyCog(commands.Cog, command_attrs=dict(hidden=True)):
@commands.command()
async def foo(self, ctx):
pass # hidden -> True
@commands.command(hidden=False)
async def bar(self, ctx):
pass # hidden -> False
slash_command_attrs: Dict[:class:`str`, Any]
A list of attributes to apply to every slash command inside this cog. The dictionary
is passed into the options of every :class:`InvokableSlashCommand` at ``__init__``.
Usage of this kwarg is otherwise the same as with ``command_attrs``.
.. note:: This does not apply to instances of :class:`SubCommand` or :class:`SubCommandGroup`.
.. versionadded:: 2.5
user_command_attrs: Dict[:class:`str`, Any]
A list of attributes to apply to every user command inside this cog. The dictionary
is passed into the options of every :class:`InvokableUserCommand` at ``__init__``.
Usage of this kwarg is otherwise the same as with ``command_attrs``.
.. versionadded:: 2.5
message_command_attrs: Dict[:class:`str`, Any]
A list of attributes to apply to every message command inside this cog. The dictionary
is passed into the options of every :class:`InvokableMessageCommand` at ``__init__``.
Usage of this kwarg is otherwise the same as with ``command_attrs``.
.. versionadded:: 2.5
"""
__cog_name__: str
__cog_settings__: Dict[str, Any]
__cog_slash_settings__: Dict[str, Any]
__cog_user_settings__: Dict[str, Any]
__cog_message_settings__: Dict[str, Any]
__cog_commands__: List[Command]
__cog_app_commands__: List[InvokableApplicationCommand]
__cog_listeners__: List[Tuple[str, str]]
def __new__(cls: Type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta:
name, bases, attrs = args
attrs["__cog_name__"] = kwargs.pop("name", name)
attrs["__cog_settings__"] = kwargs.pop("command_attrs", {})
attrs["__cog_slash_settings__"] = kwargs.pop("slash_command_attrs", {})
attrs["__cog_user_settings__"] = kwargs.pop("user_command_attrs", {})
attrs["__cog_message_settings__"] = kwargs.pop("message_command_attrs", {})
description = kwargs.pop("description", None)
if description is None:
description = inspect.cleandoc(attrs.get("__doc__", ""))
attrs["__cog_description__"] = description
commands = {}
app_commands = {}
listeners = {}
no_bot_cog = (
"Commands or listeners must not start with cog_ or bot_ (in method {0.__name__}.{1})"
)
new_cls = super().__new__(cls, name, bases, attrs, **kwargs)
for base in reversed(new_cls.__mro__):
for elem, value in base.__dict__.items():
commands.pop(elem, None)
app_commands.pop(elem, None)
listeners.pop(elem, None)
is_static_method = isinstance(value, staticmethod)
if is_static_method:
value = value.__func__
if isinstance(value, _BaseCommand):
if is_static_method:
raise TypeError(
f"Command in method {base}.{elem!r} must not be staticmethod."
)
if elem.startswith(("cog_", "bot_")):
raise TypeError(no_bot_cog.format(base, elem))
commands[elem] = value
elif isinstance(value, InvokableApplicationCommand):
if is_static_method:
raise TypeError(
f"Application command in method {base}.{elem!r} must not be staticmethod."
)
if elem.startswith(("cog_", "bot_")):
raise TypeError(no_bot_cog.format(base, elem))
app_commands[elem] = value
elif disnake.utils.iscoroutinefunction(value):
if hasattr(value, "__cog_listener__"):
if elem.startswith(("cog_", "bot_")):
raise TypeError(no_bot_cog.format(base, elem))
listeners[elem] = value
new_cls.__cog_commands__ = list(commands.values()) # this will be copied in Cog.__new__
new_cls.__cog_app_commands__ = list(app_commands.values())
listeners_as_list = []
for listener in listeners.values():
for listener_name in listener.__cog_listener_names__:
# I use __name__ instead of just storing the value so I can inject
# the self attribute when the time comes to add them to the bot
listeners_as_list.append((listener_name, listener.__name__))
new_cls.__cog_listeners__ = listeners_as_list
return new_cls
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args)
@classmethod
def qualified_name(cls) -> str:
return cls.__cog_name__
class Cog(metaclass=CogMeta):
"""The base class that all cogs must inherit from.
A cog is a collection of commands, listeners, and optional state to
help group commands together. More information on them can be found on
the :ref:`ext_commands_cogs` page.
When inheriting from this class, the options shown in :class:`CogMeta`
are equally valid here.
"""
__cog_name__: ClassVar[str]
__cog_settings__: ClassVar[Dict[str, Any]]
__cog_commands__: ClassVar[List[Command]]
__cog_app_commands__: ClassVar[List[InvokableApplicationCommand]]
__cog_listeners__: ClassVar[List[Tuple[str, str]]]
def __new__(cls, *args: Any, **kwargs: Any) -> Self:
# For issue 426, we need to store a copy of the command objects
# since we modify them to inject `self` to them.
# To do this, we need to interfere with the Cog creation process.
self = super().__new__(cls)
cmd_attrs = cls.__cog_settings__
slash_cmd_attrs = cls.__cog_slash_settings__
user_cmd_attrs = cls.__cog_user_settings__
message_cmd_attrs = cls.__cog_message_settings__
# Either update the command with the cog provided defaults or copy it.
cog_app_commands: List[InvokableApplicationCommand] = []
for c in cls.__cog_app_commands__:
if isinstance(c, InvokableSlashCommand):
c = c._update_copy(slash_cmd_attrs)
elif isinstance(c, InvokableUserCommand):
c = c._update_copy(user_cmd_attrs)
elif isinstance(c, InvokableMessageCommand):
c = c._update_copy(message_cmd_attrs)
cog_app_commands.append(c)
self.__cog_app_commands__ = tuple(cog_app_commands) # type: ignore # overriding ClassVar
# Replace the old command objects with the new copies
for app_command in self.__cog_app_commands__:
setattr(self, app_command.callback.__name__, app_command)
self.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in cls.__cog_commands__) # type: ignore # overriding ClassVar
lookup = {cmd.qualified_name: cmd for cmd in self.__cog_commands__}
for command in self.__cog_commands__:
setattr(self, command.callback.__name__, command)
parent = command.parent
if parent is not None:
# Get the latest parent reference
parent = lookup[parent.qualified_name] # type: ignore
# Update our parent's reference to our self
parent.remove_command(command.name) # type: ignore
parent.add_command(command) # type: ignore
return self
def get_commands(self) -> List[Command]:
"""Returns a list of commands the cog has.
Returns
-------
List[:class:`.Command`]
A :class:`list` of :class:`.Command`\\s that are
defined inside this cog.
.. note::
This does not include subcommands.
"""
return [c for c in self.__cog_commands__ if c.parent is None]
def get_application_commands(self) -> List[InvokableApplicationCommand]:
"""Returns a list of application commands the cog has.
Returns
-------
List[:class:`.InvokableApplicationCommand`]
A :class:`list` of :class:`.InvokableApplicationCommand`\\s that are
defined inside this cog.
.. note::
This does not include subcommands.
"""
return list(self.__cog_app_commands__)
def get_slash_commands(self) -> List[InvokableSlashCommand]:
"""Returns a list of slash commands the cog has.
Returns
-------
List[:class:`.InvokableSlashCommand`]
A :class:`list` of :class:`.InvokableSlashCommand`\\s that are
defined inside this cog.
.. note::
This does not include subcommands.
"""
return [c for c in self.__cog_app_commands__ if isinstance(c, InvokableSlashCommand)]
def get_user_commands(self) -> List[InvokableUserCommand]:
"""Returns a list of user commands the cog has.
Returns
-------
List[:class:`.InvokableUserCommand`]
A :class:`list` of :class:`.InvokableUserCommand`\\s that are
defined inside this cog.
"""
return [c for c in self.__cog_app_commands__ if isinstance(c, InvokableUserCommand)]
def get_message_commands(self) -> List[InvokableMessageCommand]:
"""Returns a list of message commands the cog has.
Returns
-------
List[:class:`.InvokableMessageCommand`]
A :class:`list` of :class:`.InvokableMessageCommand`\\s that are
defined inside this cog.
"""
return [c for c in self.__cog_app_commands__ if isinstance(c, InvokableMessageCommand)]
@property
def qualified_name(self) -> str:
""":class:`str`: Returns the cog's specified name, not the class name."""
return self.__cog_name__
@property
def description(self) -> str:
""":class:`str`: Returns the cog's description, typically the cleaned docstring."""
return self.__cog_description__
@description.setter
def description(self, description: str) -> None:
self.__cog_description__ = description
def walk_commands(self) -> Generator[Command, None, None]:
"""An iterator that recursively walks through this cog's commands and subcommands.
Yields
------
Union[:class:`.Command`, :class:`.Group`]
A command or group from the cog.
"""
from .core import GroupMixin
for command in self.__cog_commands__:
if command.parent is None:
yield command
if isinstance(command, GroupMixin):
yield from command.walk_commands()
def get_listeners(self) -> List[Tuple[str, Callable[..., Any]]]:
"""Returns a :class:`list` of (name, function) listener pairs the cog has.
Returns
-------
List[Tuple[:class:`str`, :ref:`coroutine <coroutine>`]]
The listeners defined in this cog.
"""
return [(name, getattr(self, method_name)) for name, method_name in self.__cog_listeners__]
@classmethod
def _get_overridden_method(cls, method: FuncT) -> Optional[FuncT]:
"""Return None if the method is not overridden. Otherwise returns the overridden method."""
return getattr(method.__func__, "__cog_special_method__", method)
@classmethod
def listener(cls, name: Union[str, Event] = MISSING) -> Callable[[FuncT], FuncT]:
"""A decorator that marks a function as a listener.
This is the cog equivalent of :meth:`.Bot.listen`.
Parameters
----------
name: Union[:class:`str`, :class:`.Event`]
The name of the event being listened to. If not provided, it
defaults to the function's name.
Raises
------
TypeError
The function is not a coroutine function or a string or an :class:`.Event` enum member was not passed as
the name.
"""
if name is not MISSING and not isinstance(name, (str, Event)):
raise TypeError(
f"Cog.listener expected str or Enum but received {name.__class__.__name__!r} instead."
)
def decorator(func: FuncT) -> FuncT:
actual = func
if isinstance(actual, staticmethod):
actual = actual.__func__
if not disnake.utils.iscoroutinefunction(actual):
raise TypeError("Listener function must be a coroutine function.")
actual.__cog_listener__ = True
to_assign = (
actual.__name__
if name is MISSING
else (name if isinstance(name, str) else f"on_{name.value}")
)
try:
actual.__cog_listener_names__.append(to_assign)
except AttributeError:
actual.__cog_listener_names__ = [to_assign]
# we have to return `func` instead of `actual` because
# we need the type to be `staticmethod` for the metaclass
# to pick it up but the metaclass unfurls the function and
# thus the assignments need to be on the actual function
return func
return decorator
def has_error_handler(self) -> bool:
"""Whether the cog has an error handler.
.. versionadded:: 1.7
:return type: :class:`bool`
"""
return not hasattr(self.cog_command_error.__func__, "__cog_special_method__")
def has_slash_error_handler(self) -> bool:
"""Whether the cog has a slash command error handler.
:return type: :class:`bool`
"""
return not hasattr(self.cog_slash_command_error.__func__, "__cog_special_method__")
def has_user_error_handler(self) -> bool:
"""Whether the cog has a user command error handler.
:return type: :class:`bool`
"""
return not hasattr(self.cog_user_command_error.__func__, "__cog_special_method__")
def has_message_error_handler(self) -> bool:
"""Whether the cog has a message command error handler.
:return type: :class:`bool`
"""
return not hasattr(self.cog_message_command_error.__func__, "__cog_special_method__")
@_cog_special_method
async def cog_load(self) -> None:
"""A special method that is called as a task when the cog is added."""
pass
@_cog_special_method
def cog_unload(self) -> None:
"""A special method that is called when the cog gets removed.
This function **cannot** be a coroutine. It must be a regular
function.
Subclasses must replace this if they want special unloading behaviour.
"""
pass
@_cog_special_method
def bot_check_once(self, ctx: Context) -> MaybeCoro[bool]:
"""A special method that registers as a :meth:`.Bot.check_once`
check.
This is for text commands only, and doesn't apply to application commands.
This function **can** be a coroutine and must take a sole parameter,
``ctx``, to represent the :class:`.Context`.
"""
return True
@_cog_special_method
def bot_check(self, ctx: Context) -> MaybeCoro[bool]:
"""A special method that registers as a :meth:`.Bot.check`
check.
This is for text commands only, and doesn't apply to application commands.
This function **can** be a coroutine and must take a sole parameter,
``ctx``, to represent the :class:`.Context`.
"""
return True
@_cog_special_method
def bot_slash_command_check_once(self, inter: ApplicationCommandInteraction) -> MaybeCoro[bool]:
"""A special method that registers as a :meth:`.Bot.slash_command_check_once`
check.
This function **can** be a coroutine and must take a sole parameter,
``inter``, to represent the :class:`.ApplicationCommandInteraction`.
"""
return True
@_cog_special_method
def bot_slash_command_check(self, inter: ApplicationCommandInteraction) -> MaybeCoro[bool]:
"""A special method that registers as a :meth:`.Bot.slash_command_check`
check.
This function **can** be a coroutine and must take a sole parameter,
``inter``, to represent the :class:`.ApplicationCommandInteraction`.
"""
return True
@_cog_special_method
def bot_user_command_check_once(self, inter: ApplicationCommandInteraction) -> MaybeCoro[bool]:
"""Similar to :meth:`.Bot.slash_command_check_once` but for user commands."""
return True
@_cog_special_method
def bot_user_command_check(self, inter: ApplicationCommandInteraction) -> MaybeCoro[bool]:
"""Similar to :meth:`.Bot.slash_command_check` but for user commands."""
return True
@_cog_special_method
def bot_message_command_check_once(
self, inter: ApplicationCommandInteraction
) -> MaybeCoro[bool]:
"""Similar to :meth:`.Bot.slash_command_check_once` but for message commands."""
return True
@_cog_special_method
def bot_message_command_check(self, inter: ApplicationCommandInteraction) -> MaybeCoro[bool]:
"""Similar to :meth:`.Bot.slash_command_check` but for message commands."""
return True
@_cog_special_method
def cog_check(self, ctx: Context) -> MaybeCoro[bool]:
"""A special method that registers as a :func:`~.check`
for every text command and subcommand in this cog.
This is for text commands only, and doesn't apply to application commands.
This function **can** be a coroutine and must take a sole parameter,
``ctx``, to represent the :class:`.Context`.
"""
return True
@_cog_special_method
def cog_slash_command_check(self, inter: ApplicationCommandInteraction) -> MaybeCoro[bool]:
"""A special method that registers as a :func:`~.check`
for every slash command and subcommand in this cog.
This function **can** be a coroutine and must take a sole parameter,
``inter``, to represent the :class:`.ApplicationCommandInteraction`.
"""
return True
@_cog_special_method
def cog_user_command_check(self, inter: ApplicationCommandInteraction) -> MaybeCoro[bool]:
"""Similar to :meth:`.Cog.cog_slash_command_check` but for user commands."""
return True
@_cog_special_method
def cog_message_command_check(self, inter: ApplicationCommandInteraction) -> MaybeCoro[bool]:
"""Similar to :meth:`.Cog.cog_slash_command_check` but for message commands."""
return True
@_cog_special_method
async def cog_command_error(self, ctx: Context, error: Exception) -> None:
"""A special method that is called whenever an error
is dispatched inside this cog.
This is for text commands only, and doesn't apply to application commands.
This is similar to :func:`.on_command_error` except only applying
to the commands inside this cog.
This **must** be a coroutine.
Parameters
----------
ctx: :class:`.Context`
The invocation context where the error happened.
error: :class:`CommandError`
The error that was raised.
"""
pass
@_cog_special_method
async def cog_slash_command_error(
self, inter: ApplicationCommandInteraction, error: Exception
) -> None:
"""A special method that is called whenever an error
is dispatched inside this cog.
This is similar to :func:`.on_slash_command_error` except only applying
to the slash commands inside this cog.
This **must** be a coroutine.
Parameters
----------
inter: :class:`.ApplicationCommandInteraction`
The interaction where the error happened.
error: :class:`CommandError`
The error that was raised.
"""
pass
@_cog_special_method
async def cog_user_command_error(
self, inter: ApplicationCommandInteraction, error: Exception
) -> None:
"""Similar to :func:`cog_slash_command_error` but for user commands."""
pass
@_cog_special_method
async def cog_message_command_error(
self, inter: ApplicationCommandInteraction, error: Exception
) -> None:
"""Similar to :func:`cog_slash_command_error` but for message commands."""
pass
@_cog_special_method
async def cog_before_invoke(self, ctx: Context) -> None:
"""A special method that acts as a cog local pre-invoke hook,
similar to :meth:`.Command.before_invoke`.
This is for text commands only, and doesn't apply to application commands.
This **must** be a coroutine.
Parameters
----------
ctx: :class:`.Context`
The invocation context.
"""
pass
@_cog_special_method
async def cog_after_invoke(self, ctx: Context) -> None:
"""A special method that acts as a cog local post-invoke hook,
similar to :meth:`.Command.after_invoke`.
This is for text commands only, and doesn't apply to application commands.
This **must** be a coroutine.
Parameters
----------
ctx: :class:`.Context`
The invocation context.
"""
pass
@_cog_special_method
async def cog_before_slash_command_invoke(self, inter: ApplicationCommandInteraction) -> None:
"""A special method that acts as a cog local pre-invoke hook.
This is similar to :meth:`.Command.before_invoke` but for slash commands.
This **must** be a coroutine.
Parameters
----------
inter: :class:`.ApplicationCommandInteraction`
The interaction of the slash command.
"""
pass
@_cog_special_method
async def cog_after_slash_command_invoke(self, inter: ApplicationCommandInteraction) -> None:
"""A special method that acts as a cog local post-invoke hook.
This is similar to :meth:`.Command.after_invoke` but for slash commands.
This **must** be a coroutine.
Parameters
----------
inter: :class:`.ApplicationCommandInteraction`
The interaction of the slash command.
"""
pass
@_cog_special_method
async def cog_before_user_command_invoke(self, inter: ApplicationCommandInteraction) -> None:
"""Similar to :meth:`cog_before_slash_command_invoke` but for user commands."""
pass
@_cog_special_method
async def cog_after_user_command_invoke(self, inter: ApplicationCommandInteraction) -> None:
"""Similar to :meth:`cog_after_slash_command_invoke` but for user commands."""
pass
@_cog_special_method
async def cog_before_message_command_invoke(self, inter: ApplicationCommandInteraction) -> None:
"""Similar to :meth:`cog_before_slash_command_invoke` but for message commands."""
pass
@_cog_special_method
async def cog_after_message_command_invoke(self, inter: ApplicationCommandInteraction) -> None:
"""Similar to :meth:`cog_after_slash_command_invoke` but for message commands."""
pass
def _inject(self, bot: AnyBot) -> Self:
from .bot import AutoShardedInteractionBot, InteractionBot
cls = self.__class__
if (
isinstance(bot, (InteractionBot, AutoShardedInteractionBot))
and len(self.__cog_commands__) > 0
):
raise TypeError("@commands.command is not supported for interaction bots.")
# realistically, the only thing that can cause loading errors
# is essentially just the command loading, which raises if there are
# duplicates. When this condition is met, we want to undo all what
# we've added so far for some form of atomic loading.
for index, command in enumerate(self.__cog_commands__):
command.cog = self
if command.parent is None:
try:
bot.add_command(command) # type: ignore
except Exception:
# undo our additions
for to_undo in self.__cog_commands__[:index]:
if to_undo.parent is None:
bot.remove_command(to_undo.name) # type: ignore
raise
for index, command in enumerate(self.__cog_app_commands__):
command.cog = self
try:
if isinstance(command, InvokableSlashCommand):
bot.add_slash_command(command)
elif isinstance(command, InvokableUserCommand):
bot.add_user_command(command)
elif isinstance(command, InvokableMessageCommand):
bot.add_message_command(command)
except Exception:
# undo our additions
for to_undo in self.__cog_app_commands__[:index]:
if isinstance(to_undo, InvokableSlashCommand):
bot.remove_slash_command(to_undo.name)
elif isinstance(to_undo, InvokableUserCommand):
bot.remove_user_command(to_undo.name)
elif isinstance(to_undo, InvokableMessageCommand):
bot.remove_message_command(to_undo.name)
raise
if not hasattr(self.cog_load.__func__, "__cog_special_method__"):
bot.loop.create_task(disnake.utils.maybe_coroutine(self.cog_load))
# check if we're overriding the default
if cls.bot_check is not Cog.bot_check:
if isinstance(bot, (InteractionBot, AutoShardedInteractionBot)):
raise TypeError("Cog.bot_check is not supported for interaction bots.")
bot.add_check(self.bot_check)
if cls.bot_check_once is not Cog.bot_check_once:
if isinstance(bot, (InteractionBot, AutoShardedInteractionBot)):
raise TypeError("Cog.bot_check_once is not supported for interaction bots.")
bot.add_check(self.bot_check_once, call_once=True)
# Add application command checks
if cls.bot_slash_command_check is not Cog.bot_slash_command_check:
bot.add_app_command_check(self.bot_slash_command_check, slash_commands=True)
if cls.bot_user_command_check is not Cog.bot_user_command_check:
bot.add_app_command_check(self.bot_user_command_check, user_commands=True)
if cls.bot_message_command_check is not Cog.bot_message_command_check:
bot.add_app_command_check(self.bot_message_command_check, message_commands=True)
# Add app command one-off checks
if cls.bot_slash_command_check_once is not Cog.bot_slash_command_check_once:
bot.add_app_command_check(
self.bot_slash_command_check_once,
call_once=True,
slash_commands=True,
)
if cls.bot_user_command_check_once is not Cog.bot_user_command_check_once:
bot.add_app_command_check(
self.bot_user_command_check_once, call_once=True, user_commands=True
)
if cls.bot_message_command_check_once is not Cog.bot_message_command_check_once:
bot.add_app_command_check(
self.bot_message_command_check_once,
call_once=True,
message_commands=True,
)
# while Bot.add_listener can raise if it's not a coroutine,
# this precondition is already met by the listener decorator
# already, thus this should never raise.
# Outside of, memory errors and the like...
for name, method_name in self.__cog_listeners__:
bot.add_listener(getattr(self, method_name), name)
try:
if bot._command_sync_flags.sync_on_cog_actions:
bot._schedule_delayed_command_sync()
except NotImplementedError:
pass
return self
def _eject(self, bot: AnyBot) -> None:
cls = self.__class__
try:
for command in self.__cog_commands__:
if command.parent is None:
bot.remove_command(command.name) # type: ignore
for app_command in self.__cog_app_commands__:
if isinstance(app_command, InvokableSlashCommand):
bot.remove_slash_command(app_command.name)
elif isinstance(app_command, InvokableUserCommand):
bot.remove_user_command(app_command.name)
elif isinstance(app_command, InvokableMessageCommand):
bot.remove_message_command(app_command.name)
for name, method_name in self.__cog_listeners__:
bot.remove_listener(getattr(self, method_name), name)
if cls.bot_check is not Cog.bot_check:
bot.remove_check(self.bot_check) # type: ignore
if cls.bot_check_once is not Cog.bot_check_once:
bot.remove_check(self.bot_check_once, call_once=True) # type: ignore
# Remove application command checks
if cls.bot_slash_command_check is not Cog.bot_slash_command_check:
bot.remove_app_command_check(self.bot_slash_command_check, slash_commands=True)
if cls.bot_user_command_check is not Cog.bot_user_command_check:
bot.remove_app_command_check(self.bot_user_command_check, user_commands=True)
if cls.bot_message_command_check is not Cog.bot_message_command_check:
bot.remove_app_command_check(self.bot_message_command_check, message_commands=True)
# Remove app command one-off checks
if cls.bot_slash_command_check_once is not Cog.bot_slash_command_check_once:
bot.remove_app_command_check(
self.bot_slash_command_check_once,
call_once=True,
slash_commands=True,
)
if cls.bot_user_command_check_once is not Cog.bot_user_command_check_once:
bot.remove_app_command_check(
self.bot_user_command_check_once,
call_once=True,
user_commands=True,
)
if cls.bot_message_command_check_once is not Cog.bot_message_command_check_once:
bot.remove_app_command_check(
self.bot_message_command_check_once,
call_once=True,
message_commands=True,
)
finally:
try:
if bot._command_sync_flags.sync_on_cog_actions:
bot._schedule_delayed_command_sync()
except NotImplementedError:
pass
try:
self.cog_unload()
except Exception as e:
_log.error(
"An error occurred while unloading the %s cog.", self.qualified_name, exc_info=e
)

View File

@@ -0,0 +1,545 @@
# SPDX-License-Identifier: MIT
from __future__ import annotations
import asyncio
import collections.abc
import importlib.machinery
import importlib.util
import logging
import os
import sys
import time
import types
from typing import TYPE_CHECKING, Any, Dict, Generic, List, Mapping, Optional, Set, TypeVar, Union
import disnake
import disnake.utils
from . import errors
from .cog import Cog
if TYPE_CHECKING:
from ._types import CoroFunc
from .bot import AutoShardedBot, AutoShardedInteractionBot, Bot, InteractionBot
from .help import HelpCommand
AnyBot = Union[Bot, AutoShardedBot, InteractionBot, AutoShardedInteractionBot]
__all__ = ("CommonBotBase",)
_log = logging.getLogger(__name__)
CogT = TypeVar("CogT", bound="Cog")
CFT = TypeVar("CFT", bound="CoroFunc")
MISSING: Any = disnake.utils.MISSING
def _is_submodule(parent: str, child: str) -> bool:
return parent == child or child.startswith(parent + ".")
class CommonBotBase(Generic[CogT]):
if TYPE_CHECKING:
extra_events: Dict[str, List[CoroFunc]]
def __init__(
self,
*args: Any,
owner_id: Optional[int] = None,
owner_ids: Optional[Set[int]] = None,
reload: bool = False,
**kwargs: Any,
) -> None:
self.__cogs: Dict[str, Cog] = {}
self.__extensions: Dict[str, types.ModuleType] = {}
self._is_closed: bool = False
self.owner_id: Optional[int] = owner_id
self.owner_ids: Set[int] = owner_ids or set()
self.owner: Optional[disnake.User] = None
self.owners: Set[disnake.TeamMember] = set()
if self.owner_id and self.owner_ids:
raise TypeError("Both owner_id and owner_ids are set.")
if self.owner_ids and not isinstance(self.owner_ids, collections.abc.Collection):
raise TypeError(f"owner_ids must be a collection not {self.owner_ids.__class__!r}")
self.reload: bool = reload
super().__init__(*args, **kwargs)
# FIXME: make event name pos-only or remove entirely in v3.0
def dispatch(self, event_name: str, *args: Any, **kwargs: Any) -> None:
# super() will resolve to Client
super().dispatch(event_name, *args, **kwargs) # type: ignore
async def _fill_owners(self) -> None:
if self.owner_id or self.owner_ids:
return
app: disnake.AppInfo = await self.application_info() # type: ignore
if app.team:
self.owners = owners = {
member
for member in app.team.members
# these roles can access the bot token, consider them bot owners
if member.role in (disnake.TeamMemberRole.admin, disnake.TeamMemberRole.developer)
}
self.owner_ids = {m.id for m in owners}
else:
self.owner = app.owner
self.owner_id = app.owner.id
async def close(self) -> None:
self._is_closed = True
for extension in tuple(self.__extensions):
try:
self.unload_extension(extension)
except Exception as error:
error.__suppress_context__ = True
_log.error("Failed to unload extension %r", extension, exc_info=error)
for cog in tuple(self.__cogs):
try:
self.remove_cog(cog)
except Exception as error:
error.__suppress_context__ = True
_log.exception("Failed to remove cog %r", cog, exc_info=error)
await super().close() # type: ignore
@disnake.utils.copy_doc(disnake.Client.login)
async def login(self, token: str) -> None:
await super().login(token=token) # type: ignore
loop: asyncio.AbstractEventLoop = self.loop # type: ignore
if self.reload:
loop.create_task(self._watchdog())
# prefetch
loop.create_task(self._fill_owners())
async def is_owner(self, user: Union[disnake.User, disnake.Member]) -> bool:
"""|coro|
Checks if a :class:`~disnake.User` or :class:`~disnake.Member` is the owner of
this bot.
If :attr:`owner_id` and :attr:`owner_ids` are not set, they are fetched automatically
through the use of :meth:`~.Bot.application_info`.
.. versionchanged:: 1.3
The function also checks if the application is team-owned if
:attr:`owner_ids` is not set.
.. versionchanged:: 2.10
Also takes team roles into account; only team members with the :attr:`~disnake.TeamMemberRole.admin`
or :attr:`~disnake.TeamMemberRole.developer` roles are considered bot owners.
Parameters
----------
user: :class:`.abc.User`
The user to check for.
Returns
-------
:class:`bool`
Whether the user is the owner.
"""
if not self.owner_id and not self.owner_ids:
await self._fill_owners()
if self.owner_id:
return user.id == self.owner_id
else:
return user.id in self.owner_ids
def add_cog(self, cog: Cog, *, override: bool = False) -> None:
"""Adds a "cog" to the bot.
A cog is a class that has its own event listeners and commands.
This automatically re-syncs application commands, provided that
:attr:`command_sync_flags.sync_on_cog_actions <.CommandSyncFlags.sync_on_cog_actions>`
isn't disabled.
.. versionchanged:: 2.0
:exc:`.ClientException` is raised when a cog with the same name
is already loaded.
Parameters
----------
cog: :class:`.Cog`
The cog to register to the bot.
override: :class:`bool`
If a previously loaded cog with the same name should be ejected
instead of raising an error.
.. versionadded:: 2.0
Raises
------
TypeError
The cog does not inherit from :class:`.Cog`.
CommandError
An error happened during loading.
ClientException
A cog with the same name is already loaded.
"""
if not isinstance(cog, Cog):
raise TypeError("cogs must derive from Cog")
cog_name = cog.__cog_name__
existing = self.__cogs.get(cog_name)
if existing is not None:
if not override:
raise disnake.ClientException(f"Cog named {cog_name!r} already loaded")
self.remove_cog(cog_name)
# NOTE: Should be covariant
cog = cog._inject(self) # type: ignore
self.__cogs[cog_name] = cog
def get_cog(self, name: str) -> Optional[Cog]:
"""Gets the cog instance requested.
If the cog is not found, ``None`` is returned instead.
Parameters
----------
name: :class:`str`
The name of the cog you are requesting.
This is equivalent to the name passed via keyword
argument in class creation or the class name if unspecified.
Returns
-------
Optional[:class:`Cog`]
The cog that was requested. If not found, returns ``None``.
"""
return self.__cogs.get(name)
def remove_cog(self, name: str) -> Optional[Cog]:
"""Removes a cog from the bot and returns it.
All registered commands and event listeners that the
cog has registered will be removed as well.
If no cog is found then this method has no effect.
This automatically re-syncs application commands, provided that
:attr:`command_sync_flags.sync_on_cog_actions <.CommandSyncFlags.sync_on_cog_actions>`
isn't disabled.
Parameters
----------
name: :class:`str`
The name of the cog to remove.
Returns
-------
Optional[:class:`.Cog`]
The cog that was removed. Returns ``None`` if not found.
"""
cog = self.__cogs.pop(name, None)
if cog is None:
return
help_command: Optional[HelpCommand] = getattr(self, "_help_command", None)
if help_command and help_command.cog is cog:
help_command.cog = None
# NOTE: Should be covariant
cog._eject(self) # type: ignore
return cog
@property
def cogs(self) -> Mapping[str, Cog]:
"""Mapping[:class:`str`, :class:`Cog`]: A read-only mapping of cog name to cog."""
return types.MappingProxyType(self.__cogs)
# extensions
def _remove_module_references(self, name: str) -> None:
# find all references to the module
# remove the cogs registered from the module
for cogname, cog in self.__cogs.copy().items():
if _is_submodule(name, cog.__module__):
self.remove_cog(cogname)
# remove all the listeners from the module
for event_list in self.extra_events.copy().values():
remove = [
index
for index, event in enumerate(event_list)
if event.__module__ and _is_submodule(name, event.__module__)
]
for index in reversed(remove):
del event_list[index]
def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None:
try:
func = lib.teardown
except AttributeError:
pass
else:
try:
func(self)
except Exception as error:
error.__suppress_context__ = True
_log.error("Exception in extension finalizer %r", key, exc_info=error)
finally:
self.__extensions.pop(key, None)
sys.modules.pop(key, None)
name = lib.__name__
for module in list(sys.modules.keys()):
if _is_submodule(name, module):
del sys.modules[module]
def _load_from_module_spec(self, spec: importlib.machinery.ModuleSpec, key: str) -> None:
# precondition: key not in self.__extensions
lib = importlib.util.module_from_spec(spec)
sys.modules[key] = lib
try:
spec.loader.exec_module(lib) # type: ignore
except Exception as e:
del sys.modules[key]
raise errors.ExtensionFailed(key, e) from e
try:
setup = lib.setup
except AttributeError:
del sys.modules[key]
raise errors.NoEntryPointError(key) from None
try:
setup(self)
except Exception as e:
del sys.modules[key]
self._remove_module_references(lib.__name__)
self._call_module_finalizers(lib, key)
raise errors.ExtensionFailed(key, e) from e
else:
self.__extensions[key] = lib
def _resolve_name(self, name: str, package: Optional[str]) -> str:
try:
return importlib.util.resolve_name(name, package)
except ImportError as e:
raise errors.ExtensionNotFound(name) from e
def load_extension(self, name: str, *, package: Optional[str] = None) -> None:
"""Loads an extension.
An extension is a python module that contains commands, cogs, or
listeners.
An extension must have a global function, ``setup`` defined as
the entry point on what to do when the extension is loaded. This entry
point must have a single argument, the ``bot``.
Parameters
----------
name: :class:`str`
The extension name to load. It must be dot separated like
regular Python imports if accessing a sub-module. e.g.
``foo.test`` if you want to import ``foo/test.py``.
package: Optional[:class:`str`]
The package name to resolve relative imports with.
This is required when loading an extension using a relative path, e.g ``.foo.test``.
Defaults to ``None``.
.. versionadded:: 1.7
Raises
------
ExtensionNotFound
The extension could not be imported.
This is also raised if the name of the extension could not
be resolved using the provided ``package`` parameter.
ExtensionAlreadyLoaded
The extension is already loaded.
NoEntryPointError
The extension does not have a setup function.
ExtensionFailed
The extension or its setup function had an execution error.
"""
name = self._resolve_name(name, package)
if name in self.__extensions:
raise errors.ExtensionAlreadyLoaded(name)
spec = importlib.util.find_spec(name)
if spec is None:
raise errors.ExtensionNotFound(name)
self._load_from_module_spec(spec, name)
def unload_extension(self, name: str, *, package: Optional[str] = None) -> None:
"""Unloads an extension.
When the extension is unloaded, all commands, listeners, and cogs are
removed from the bot and the module is un-imported.
The extension can provide an optional global function, ``teardown``,
to do miscellaneous clean-up if necessary. This function takes a single
parameter, the ``bot``, similar to ``setup`` from
:meth:`~.Bot.load_extension`.
Parameters
----------
name: :class:`str`
The extension name to unload. It must be dot separated like
regular Python imports if accessing a sub-module. e.g.
``foo.test`` if you want to import ``foo/test.py``.
package: Optional[:class:`str`]
The package name to resolve relative imports with.
This is required when unloading an extension using a relative path, e.g ``.foo.test``.
Defaults to ``None``.
.. versionadded:: 1.7
Raises
------
ExtensionNotFound
The name of the extension could not
be resolved using the provided ``package`` parameter.
ExtensionNotLoaded
The extension was not loaded.
"""
name = self._resolve_name(name, package)
lib = self.__extensions.get(name)
if lib is None:
raise errors.ExtensionNotLoaded(name)
self._remove_module_references(lib.__name__)
self._call_module_finalizers(lib, name)
def reload_extension(self, name: str, *, package: Optional[str] = None) -> None:
"""Atomically reloads an extension.
This replaces the extension with the same extension, only refreshed. This is
equivalent to a :meth:`unload_extension` followed by a :meth:`load_extension`
except done in an atomic way. That is, if an operation fails mid-reload then
the bot will roll-back to the prior working state.
Parameters
----------
name: :class:`str`
The extension name to reload. It must be dot separated like
regular Python imports if accessing a sub-module. e.g.
``foo.test`` if you want to import ``foo/test.py``.
package: Optional[:class:`str`]
The package name to resolve relative imports with.
This is required when reloading an extension using a relative path, e.g ``.foo.test``.
Defaults to ``None``.
.. versionadded:: 1.7
Raises
------
ExtensionNotLoaded
The extension was not loaded.
ExtensionNotFound
The extension could not be imported.
This is also raised if the name of the extension could not
be resolved using the provided ``package`` parameter.
NoEntryPointError
The extension does not have a setup function.
ExtensionFailed
The extension setup function had an execution error.
"""
name = self._resolve_name(name, package)
lib = self.__extensions.get(name)
if lib is None:
raise errors.ExtensionNotLoaded(name)
# get the previous module states from sys modules
modules = {
name: module
for name, module in sys.modules.items()
if _is_submodule(lib.__name__, name)
}
try:
# Unload and then load the module...
self._remove_module_references(lib.__name__)
self._call_module_finalizers(lib, name)
self.load_extension(name)
except Exception:
# if the load failed, the remnants should have been
# cleaned from the load_extension function call
# so let's load it from our old compiled library.
lib.setup(self)
self.__extensions[name] = lib
# revert sys.modules back to normal and raise back to caller
sys.modules.update(modules)
raise
def load_extensions(self, path: str) -> None:
"""Loads all extensions in a directory.
.. versionadded:: 2.4
Parameters
----------
path: :class:`str`
The path to search for extensions
"""
for extension in disnake.utils.search_directory(path):
self.load_extension(extension)
@property
def extensions(self) -> Mapping[str, types.ModuleType]:
"""Mapping[:class:`str`, :class:`py:types.ModuleType`]: A read-only mapping of extension name to extension."""
return types.MappingProxyType(self.__extensions)
async def _watchdog(self) -> None:
"""|coro|
Starts the bot watchdog which will watch currently loaded extensions
and reload them when they're modified.
"""
if isinstance(self, disnake.Client):
await self.wait_until_ready()
reload_log = logging.getLogger(__name__)
if isinstance(self, disnake.Client):
is_closed = self.is_closed
else:
is_closed = lambda: False
reload_log.info("WATCHDOG: Watching extensions")
last = time.time()
while not is_closed():
t = time.time()
extensions = set()
for name, module in self.extensions.items():
file = module.__file__
if file and os.stat(file).st_mtime > last:
extensions.add(name)
if extensions:
try:
self.i18n.reload() # type: ignore
except Exception as e:
reload_log.exception(e)
for name in extensions:
try:
self.reload_extension(name)
except errors.ExtensionError as e:
reload_log.exception(e)
else:
reload_log.info(f"WATCHDOG: Reloaded '{name}'")
await asyncio.sleep(1)
last = t

View File

@@ -0,0 +1,387 @@
# SPDX-License-Identifier: MIT
from __future__ import annotations
import inspect
import re
from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar, Union
import disnake.abc
import disnake.utils
from disnake import ApplicationCommandInteraction
from disnake.message import Message
if TYPE_CHECKING:
from typing_extensions import ParamSpec
from disnake.channel import DMChannel, GroupChannel
from disnake.guild import Guild, GuildMessageable
from disnake.member import Member
from disnake.state import ConnectionState
from disnake.user import ClientUser, User
from disnake.voice_client import VoiceProtocol
from .bot import AutoShardedBot, Bot
from .cog import Cog
from .core import Command
from .view import StringView
__all__ = ("Context", "GuildContext")
MISSING: Any = disnake.utils.MISSING
T = TypeVar("T")
BotT = TypeVar("BotT", bound="Union[Bot, AutoShardedBot]")
CogT = TypeVar("CogT", bound="Cog")
if TYPE_CHECKING:
P = ParamSpec("P")
else:
P = TypeVar("P")
class Context(disnake.abc.Messageable, Generic[BotT]):
"""Represents the context in which a command is being invoked under.
This class contains a lot of meta data to help you understand more about
the invocation context. This class is not created manually and is instead
passed around to commands as the first parameter.
This class implements the :class:`.abc.Messageable` ABC.
Attributes
----------
message: :class:`.Message`
The message that triggered the command being executed.
bot: :class:`.Bot`
The bot that contains the command being executed.
args: :class:`list`
The list of transformed arguments that were passed into the command.
If this is accessed during the :func:`.on_command_error` event
then this list could be incomplete.
kwargs: :class:`dict`
A dictionary of transformed arguments that were passed into the command.
Similar to :attr:`args`\\, if this is accessed in the
:func:`.on_command_error` event then this dict could be incomplete.
current_parameter: Optional[:class:`inspect.Parameter`]
The parameter that is currently being inspected and converted.
This is only of use for within converters.
.. versionadded:: 2.0
prefix: Optional[:class:`str`]
The prefix that was used to invoke the command.
command: Optional[:class:`Command`]
The command that is being invoked currently.
invoked_with: Optional[:class:`str`]
The command name that triggered this invocation. Useful for finding out
which alias called the command.
invoked_parents: List[:class:`str`]
The command names of the parents that triggered this invocation. Useful for
finding out which aliases called the command.
For example in commands ``?a b c test``, the invoked parents are ``['a', 'b', 'c']``.
.. versionadded:: 1.7
invoked_subcommand: Optional[:class:`Command`]
The subcommand that was invoked.
If no valid subcommand was invoked then this is equal to ``None``.
subcommand_passed: Optional[:class:`str`]
The string that was attempted to call a subcommand. This does not have
to point to a valid registered subcommand and could just point to a
nonsense string. If nothing was passed to attempt a call to a
subcommand then this is set to ``None``.
command_failed: :class:`bool`
Whether the command failed to be parsed, checked, or invoked.
"""
def __init__(
self,
*,
message: Message,
bot: BotT,
view: StringView,
args: List[Any] = MISSING,
kwargs: Dict[str, Any] = MISSING,
prefix: Optional[str] = None,
command: Optional[Command] = None,
invoked_with: Optional[str] = None,
invoked_parents: List[str] = MISSING,
invoked_subcommand: Optional[Command] = None,
subcommand_passed: Optional[str] = None,
command_failed: bool = False,
current_parameter: Optional[inspect.Parameter] = None,
) -> None:
self.message: Message = message
self.bot: BotT = bot
self.args: List[Any] = args or []
self.kwargs: Dict[str, Any] = kwargs or {}
self.prefix: Optional[str] = prefix
self.command: Optional[Command] = command
self.view: StringView = view
self.invoked_with: Optional[str] = invoked_with
self.invoked_parents: List[str] = invoked_parents or []
self.invoked_subcommand: Optional[Command] = invoked_subcommand
self.subcommand_passed: Optional[str] = subcommand_passed
self.command_failed: bool = command_failed
self.current_parameter: Optional[inspect.Parameter] = current_parameter
self._state: ConnectionState = self.message._state
async def invoke(self, command: Command[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs) -> T:
"""|coro|
Calls a command with the arguments given.
This is useful if you want to just call the callback that a
:class:`.Command` holds internally.
.. note::
This does not handle converters, checks, cooldowns, pre-invoke,
or after-invoke hooks in any matter. It calls the internal callback
directly as-if it was a regular function.
You must take care in passing the proper arguments when
using this function.
Parameters
----------
command: :class:`.Command`
The command that is going to be called.
*args
The arguments to use.
**kwargs
The keyword arguments to use.
Raises
------
TypeError
The command argument to invoke is missing.
"""
return await command(self, *args, **kwargs)
async def reinvoke(self, *, call_hooks: bool = False, restart: bool = True) -> None:
"""|coro|
Calls the command again.
This is similar to :meth:`.invoke` except that it bypasses
checks, cooldowns, and error handlers.
.. note::
If you want to bypass :exc:`.UserInputError` derived exceptions,
it is recommended to use the regular :meth:`.invoke`
as it will work more naturally. After all, this will end up
using the old arguments the user has used and will thus just
fail again.
Parameters
----------
call_hooks: :class:`bool`
Whether to call the before and after invoke hooks.
restart: :class:`bool`
Whether to start the call chain from the very beginning
or where we left off (i.e. the command that caused the error).
The default is to start where we left off.
Raises
------
ValueError
The context to reinvoke is not valid.
"""
cmd = self.command
view = self.view
if cmd is None:
raise ValueError("This context is not valid.")
# some state to revert to when we're done
index, previous = view.index, view.previous
invoked_with = self.invoked_with
invoked_subcommand = self.invoked_subcommand
invoked_parents = self.invoked_parents
subcommand_passed = self.subcommand_passed
if restart:
to_call = cmd.root_parent or cmd
view.index = len(self.prefix or "")
view.previous = 0
self.invoked_parents = []
self.invoked_with = view.get_word() # advance to get the root command
else:
to_call = cmd
try:
await to_call.reinvoke(self, call_hooks=call_hooks)
finally:
self.command = cmd
view.index = index
view.previous = previous
self.invoked_with = invoked_with
self.invoked_subcommand = invoked_subcommand
self.invoked_parents = invoked_parents
self.subcommand_passed = subcommand_passed
@property
def valid(self) -> bool:
""":class:`bool`: Whether the invocation context is valid to be invoked with."""
return self.prefix is not None and self.command is not None
async def _get_channel(self) -> disnake.abc.Messageable:
return self.channel
@property
def clean_prefix(self) -> str:
""":class:`str`: The cleaned up invoke prefix. i.e. mentions are ``@name`` instead of ``<@id>``.
.. versionadded:: 2.0
"""
if self.prefix is None:
return ""
user = self.me
# this breaks if the prefix mention is not the bot itself but I
# consider this to be an *incredibly* strange use case. I'd rather go
# for this common use case rather than waste performance for the
# odd one.
pattern = re.compile(rf"<@!?{user.id}>")
return pattern.sub("@" + user.display_name.replace("\\", r"\\"), self.prefix)
@property
def cog(self) -> Optional[Cog]:
"""Optional[:class:`.Cog`]: Returns the cog associated with this context's command. Returns ``None`` if it does not exist."""
if self.command is None:
return None
return self.command.cog
@disnake.utils.cached_property
def guild(self) -> Optional[Guild]:
"""Optional[:class:`.Guild`]: Returns the guild associated with this context's command. Returns ``None`` if not available."""
return self.message.guild
@disnake.utils.cached_property
def channel(self) -> Union[GuildMessageable, DMChannel, GroupChannel]:
"""Union[:class:`.abc.Messageable`]: Returns the channel associated with this context's command.
Shorthand for :attr:`.Message.channel`.
"""
return self.message.channel
@disnake.utils.cached_property
def author(self) -> Union[User, Member]:
"""Union[:class:`~disnake.User`, :class:`.Member`]:
Returns the author associated with this context's command. Shorthand for :attr:`.Message.author`
"""
return self.message.author
@disnake.utils.cached_property
def me(self) -> Union[Member, ClientUser]:
"""Union[:class:`.Member`, :class:`.ClientUser`]:
Similar to :attr:`.Guild.me` except it may return the :class:`.ClientUser` in private message contexts.
"""
# bot.user will never be None at this point.
return self.guild.me if self.guild is not None else self.bot.user
@property
def voice_client(self) -> Optional[VoiceProtocol]:
r"""Optional[:class:`.VoiceProtocol`]: A shortcut to :attr:`.Guild.voice_client`\, if applicable."""
g = self.guild
return g.voice_client if g else None
async def send_help(self, *args: Any) -> Any:
"""|coro|
Shows the help command for the specified entity if given.
The entity can be a command or a cog.
If no entity is given, then it'll show help for the
entire bot.
If the entity is a string, then it looks up whether it's a
:class:`Cog` or a :class:`Command`.
.. note::
Due to the way this function works, instead of returning
something similar to :meth:`~.commands.HelpCommand.command_not_found`
this returns :class:`None` on bad input or no help command.
Parameters
----------
entity: Optional[Union[:class:`Command`, :class:`Cog`, :class:`str`]]
The entity to show help for.
Returns
-------
Any
The result of the help command, if any.
"""
from .core import Command, Group, wrap_callback
from .errors import CommandError
bot = self.bot
cmd = bot.help_command
if cmd is None:
return None
cmd = cmd.copy()
cmd.context = self
if len(args) == 0:
await cmd.prepare_help_command(self, None)
mapping = cmd.get_bot_mapping()
injected = wrap_callback(cmd.send_bot_help)
try:
return await injected(mapping)
except CommandError as e:
await cmd.on_help_command_error(self, e)
return None
entity = args[0]
if isinstance(entity, str):
entity = bot.get_cog(entity) or bot.get_command(entity)
if entity is None:
return None
if not hasattr(entity, "qualified_name"):
# if we're here then it's not a cog, group, or command.
return None
await cmd.prepare_help_command(self, entity.qualified_name)
try:
if hasattr(entity, "__cog_commands__"):
injected = wrap_callback(cmd.send_cog_help)
return await injected(entity)
elif isinstance(entity, Group):
injected = wrap_callback(cmd.send_group_help)
return await injected(entity)
elif isinstance(entity, Command):
injected = wrap_callback(cmd.send_command_help)
return await injected(entity)
else:
return None
except CommandError as e:
await cmd.on_help_command_error(self, e)
@disnake.utils.copy_doc(Message.reply)
async def reply(self, content: Optional[str] = None, **kwargs: Any) -> Message:
return await self.message.reply(content, **kwargs)
class GuildContext(Context):
"""A Context subclass meant for annotation
No runtime behavior is changed but annotations are modified
to seem like the context may never be invoked in a DM.
"""
guild: Guild
channel: GuildMessageable
author: Member
me: Member
AnyContext = Union[Context, ApplicationCommandInteraction]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,391 @@
# SPDX-License-Identifier: MIT
from __future__ import annotations
import asyncio
import time
from collections import deque
from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, Optional
from disnake.enums import Enum
from disnake.member import Member
from .errors import MaxConcurrencyReached
if TYPE_CHECKING:
from typing_extensions import Self
from ...message import Message
__all__ = (
"BucketType",
"Cooldown",
"CooldownMapping",
"DynamicCooldownMapping",
"MaxConcurrency",
)
class BucketType(Enum):
"""Specifies a type of bucket for, e.g. a cooldown."""
default = 0
"""The default bucket operates on a global basis."""
user = 1
"""The user bucket operates on a per-user basis."""
guild = 2
"""The guild bucket operates on a per-guild basis."""
channel = 3
"""The channel bucket operates on a per-channel basis."""
member = 4
"""The member bucket operates on a per-member basis."""
category = 5
"""The category bucket operates on a per-category basis."""
role = 6
"""The role bucket operates on a per-role basis.
.. versionadded:: 1.3
"""
def get_key(self, msg: Message) -> Any:
if self is BucketType.user:
return msg.author.id
elif self is BucketType.guild:
return (msg.guild or msg.author).id
elif self is BucketType.channel:
return msg.channel.id
elif self is BucketType.member:
return ((msg.guild and msg.guild.id), msg.author.id)
elif self is BucketType.category:
return (msg.channel.category or msg.channel).id # type: ignore
elif self is BucketType.role:
# if author is not a Member we are in a private-channel context; returning its id
# yields the same result as for a guild with only the @everyone role
return (
msg.author.top_role if msg.guild and isinstance(msg.author, Member) else msg.channel
).id
def __call__(self, msg: Message) -> Any:
return self.get_key(msg)
class Cooldown:
"""Represents a cooldown for a command.
Attributes
----------
rate: :class:`int`
The total number of tokens available per :attr:`per` seconds.
per: :class:`float`
The length of the cooldown period in seconds.
"""
__slots__ = ("rate", "per", "_window", "_tokens", "_last")
def __init__(self, rate: float, per: float) -> None:
self.rate: int = int(rate)
self.per: float = float(per)
self._window: float = 0.0
self._tokens: int = self.rate
self._last: float = 0.0
def get_tokens(self, current: Optional[float] = None) -> int:
"""Returns the number of available tokens before rate limiting is applied.
Parameters
----------
current: Optional[:class:`float`]
The time in seconds since Unix epoch to calculate tokens at.
If not supplied then :func:`time.time()` is used.
Returns
-------
:class:`int`
The number of tokens available before the cooldown is to be applied.
"""
if not current:
current = time.time()
tokens = self._tokens
if current > self._window + self.per:
tokens = self.rate
return tokens
def get_retry_after(self, current: Optional[float] = None) -> float:
"""Returns the time in seconds until the cooldown will be reset.
Parameters
----------
current: Optional[:class:`float`]
The current time in seconds since Unix epoch.
If not supplied, then :func:`time.time()` is used.
Returns
-------
:class:`float`
The number of seconds to wait before this cooldown will be reset.
"""
current = current or time.time()
tokens = self.get_tokens(current)
if tokens == 0:
return self.per - (current - self._window)
return 0.0
def update_rate_limit(self, current: Optional[float] = None) -> Optional[float]:
"""Updates the cooldown rate limit.
Parameters
----------
current: Optional[:class:`float`]
The time in seconds since Unix epoch to update the rate limit at.
If not supplied, then :func:`time.time()` is used.
Returns
-------
Optional[:class:`float`]
The retry-after time in seconds if rate limited.
"""
current = current or time.time()
self._last = current
self._tokens = self.get_tokens(current)
# first token used means that we start a new rate limit window
if self._tokens == self.rate:
self._window = current
# check if we are rate limited
if self._tokens == 0:
return self.per - (current - self._window)
# we're not so decrement our tokens
self._tokens -= 1
def reset(self) -> None:
"""Reset the cooldown to its initial state."""
self._tokens = self.rate
self._last = 0.0
def copy(self) -> Cooldown:
"""Creates a copy of this cooldown.
Returns
-------
:class:`Cooldown`
A new instance of this cooldown.
"""
return Cooldown(self.rate, self.per)
def __repr__(self) -> str:
return f"<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>"
class CooldownMapping:
def __init__(
self,
original: Optional[Cooldown],
type: Callable[[Message], Any],
) -> None:
if not callable(type):
raise TypeError("Cooldown type must be a BucketType or callable")
self._cache: Dict[Any, Cooldown] = {}
self._cooldown: Optional[Cooldown] = original
self._type: Callable[[Message], Any] = type
def copy(self) -> CooldownMapping:
ret = CooldownMapping(self._cooldown, self._type)
ret._cache = self._cache.copy()
return ret
@property
def valid(self) -> bool:
return self._cooldown is not None
@property
def type(self) -> Callable[[Message], Any]:
return self._type
@classmethod
def from_cooldown(cls, rate: float, per: float, type) -> Self:
return cls(Cooldown(rate, per), type)
def _bucket_key(self, msg: Message) -> Any:
return self._type(msg)
def _verify_cache_integrity(self, current: Optional[float] = None) -> None:
# we want to delete all cache objects that haven't been used
# in a cooldown window. e.g. if we have a command that has a
# cooldown of 60s and it has not been used in 60s then that key should be deleted
current = current or time.time()
dead_keys = [k for k, v in self._cache.items() if current > v._last + v.per]
for k in dead_keys:
del self._cache[k]
def _is_default(self) -> bool:
# This method can be overridden in subclasses
return self._type is BucketType.default
def create_bucket(self, message: Message) -> Cooldown:
return self._cooldown.copy() # type: ignore
def get_bucket(self, message: Message, current: Optional[float] = None) -> Cooldown:
if self._is_default():
return self._cooldown # type: ignore
self._verify_cache_integrity(current)
key = self._bucket_key(message)
if key not in self._cache:
bucket = self.create_bucket(message)
if bucket is not None: # pyright: ignore[reportUnnecessaryComparison]
self._cache[key] = bucket
else:
bucket = self._cache[key]
return bucket
def update_rate_limit(
self, message: Message, current: Optional[float] = None
) -> Optional[float]:
bucket = self.get_bucket(message, current)
return bucket.update_rate_limit(current)
class DynamicCooldownMapping(CooldownMapping):
def __init__(
self, factory: Callable[[Message], Cooldown], type: Callable[[Message], Any]
) -> None:
super().__init__(None, type)
self._factory: Callable[[Message], Cooldown] = factory
def copy(self) -> DynamicCooldownMapping:
ret = DynamicCooldownMapping(self._factory, self._type)
ret._cache = self._cache.copy()
return ret
@property
def valid(self) -> bool:
return True
def _is_default(self) -> bool:
# In dynamic mappings even default bucket types may have custom behavior
return False
def create_bucket(self, message: Message) -> Cooldown:
return self._factory(message)
class _Semaphore:
"""A custom version of a semaphore.
If you're wondering why asyncio.Semaphore isn't being used,
it's because it doesn't expose the internal value. This internal
value is necessary because I need to support both `wait=True` and
`wait=False`.
An asyncio.Queue could have been used to do this as well -- but it is
not as inefficient since internally that uses two queues and is a bit
overkill for what is basically a counter.
"""
__slots__ = ("value", "loop", "_waiters")
def __init__(self, number: int) -> None:
self.value: int = number
self.loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
self._waiters: Deque[asyncio.Future] = deque()
def __repr__(self) -> str:
return f"<_Semaphore value={self.value} waiters={len(self._waiters)}>"
def locked(self) -> bool:
return self.value == 0
def is_active(self) -> bool:
return len(self._waiters) > 0
def wake_up(self) -> None:
while self._waiters:
future = self._waiters.popleft()
if not future.done():
future.set_result(None)
return
async def acquire(self, *, wait: bool = False) -> bool:
if not wait and self.value <= 0:
# signal that we're not acquiring
return False
while self.value <= 0:
future = self.loop.create_future()
self._waiters.append(future)
try:
await future
except Exception:
future.cancel()
if self.value > 0 and not future.cancelled():
self.wake_up()
raise
self.value -= 1
return True
def release(self) -> None:
self.value += 1
self.wake_up()
class MaxConcurrency:
__slots__ = ("number", "per", "wait", "_mapping")
def __init__(self, number: int, *, per: BucketType, wait: bool) -> None:
self._mapping: Dict[Any, _Semaphore] = {}
self.per: BucketType = per
self.number: int = number
self.wait: bool = wait
if number <= 0:
raise ValueError("max_concurrency 'number' cannot be less than 1")
if not isinstance(per, BucketType):
raise TypeError(f"max_concurrency 'per' must be of type BucketType not {type(per)!r}")
def copy(self) -> Self:
return self.__class__(self.number, per=self.per, wait=self.wait)
def __repr__(self) -> str:
return f"<MaxConcurrency per={self.per!r} number={self.number} wait={self.wait}>"
def get_key(self, message: Message) -> Any:
return self.per.get_key(message)
async def acquire(self, message: Message) -> None:
key = self.get_key(message)
try:
sem = self._mapping[key]
except KeyError:
self._mapping[key] = sem = _Semaphore(self.number)
acquired = await sem.acquire(wait=self.wait)
if not acquired:
raise MaxConcurrencyReached(self.number, self.per)
async def release(self, message: Message) -> None:
# Technically there's no reason for this function to be async
# But it might be more useful in the future
key = self.get_key(message)
try:
sem = self._mapping[key]
except KeyError:
# ...? peculiar
return
else:
sem.release()
if sem.value >= self.number and not sem.is_active():
del self._mapping[key]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,466 @@
# SPDX-License-Identifier: MIT
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, Tuple, Union
from disnake.app_commands import MessageCommand, UserCommand
from disnake.flags import ApplicationInstallTypes, InteractionContextTypes
from disnake.i18n import Localized
from disnake.permissions import Permissions
from disnake.utils import iscoroutinefunction
from .base_core import InvokableApplicationCommand, _get_overridden_method
from .errors import CommandError
from .params import safe_call
if TYPE_CHECKING:
from typing_extensions import ParamSpec
from disnake.i18n import LocalizedOptional
from disnake.interactions import (
ApplicationCommandInteraction,
MessageCommandInteraction,
UserCommandInteraction,
)
from .base_core import CogT, InteractionCommandCallback
P = ParamSpec("P")
__all__ = ("InvokableUserCommand", "InvokableMessageCommand", "user_command", "message_command")
class InvokableUserCommand(InvokableApplicationCommand):
"""A class that implements the protocol for a bot user command (context menu).
These are not created manually, instead they are created via the
decorator or functional interface.
Attributes
----------
name: :class:`str`
The name of the user command.
qualified_name: :class:`str`
The full command name, equivalent to :attr:`.name` for this type of command.
body: :class:`.UserCommand`
An object being registered in the API.
callback: :ref:`coroutine <coroutine>`
The coroutine that is executed when the user command is called.
cog: Optional[:class:`Cog`]
The cog that this user command belongs to. ``None`` if there isn't one.
checks: List[Callable[[:class:`.ApplicationCommandInteraction`], :class:`bool`]]
A list of predicates that verifies if the command could be executed
with the given :class:`.ApplicationCommandInteraction` as the sole parameter. If an exception
is necessary to be thrown to signal failure, then one inherited from
:exc:`.CommandError` should be used. Note that if the checks fail then
:exc:`.CheckFailure` exception is raised to the :func:`.on_user_command_error`
event.
guild_ids: Optional[Tuple[:class:`int`, ...]]
The list of IDs of the guilds where the command is synced. ``None`` if this command is global.
auto_sync: :class:`bool`
Whether to automatically register the command.
extras: Dict[:class:`str`, Any]
A dict of user provided extras to attach to the command.
.. note::
This object may be copied by the library.
.. versionadded:: 2.5
"""
def __init__(
self,
func: InteractionCommandCallback[CogT, UserCommandInteraction, P],
*,
name: LocalizedOptional = None,
dm_permission: Optional[bool] = None, # deprecated
default_member_permissions: Optional[Union[Permissions, int]] = None,
nsfw: Optional[bool] = None,
install_types: Optional[ApplicationInstallTypes] = None,
contexts: Optional[InteractionContextTypes] = None,
guild_ids: Optional[Sequence[int]] = None,
auto_sync: Optional[bool] = None,
**kwargs: Any,
) -> None:
name_loc = Localized._cast(name, False)
super().__init__(func, name=name_loc.string, **kwargs)
self.guild_ids: Optional[Tuple[int, ...]] = None if guild_ids is None else tuple(guild_ids)
self.auto_sync: bool = True if auto_sync is None else auto_sync
try:
default_member_permissions = func.__default_member_permissions__
except AttributeError:
pass
try:
install_types = func.__install_types__
except AttributeError:
pass
try:
contexts = func.__contexts__
except AttributeError:
pass
self.body = UserCommand(
name=name_loc._upgrade(self.name),
dm_permission=dm_permission,
default_member_permissions=default_member_permissions,
nsfw=nsfw,
install_types=install_types,
contexts=contexts,
)
self._apply_guild_only()
async def _call_external_error_handlers(
self, inter: ApplicationCommandInteraction, error: CommandError
) -> None:
stop_propagation = False
cog = self.cog
try:
if cog is not None:
local = _get_overridden_method(cog.cog_user_command_error)
if local is not None:
stop_propagation = await local(inter, error)
# User has an option to cancel the global error handler by returning True
finally:
if not stop_propagation:
inter.bot.dispatch("user_command_error", inter, error)
async def __call__(
self,
interaction: ApplicationCommandInteraction,
target: Any = None,
*args: Any,
**kwargs: Any,
) -> None:
# the target may just not be passed in
args = (target or interaction.target, *args)
if self.cog is not None:
await safe_call(self.callback, self.cog, interaction, *args, **kwargs)
else:
await safe_call(self.callback, interaction, *args, **kwargs)
class InvokableMessageCommand(InvokableApplicationCommand):
"""A class that implements the protocol for a bot message command (context menu).
These are not created manually, instead they are created via the
decorator or functional interface.
Attributes
----------
name: :class:`str`
The name of the message command.
qualified_name: :class:`str`
The full command name, equivalent to :attr:`.name` for this type of command.
body: :class:`.MessageCommand`
An object being registered in the API.
callback: :ref:`coroutine <coroutine>`
The coroutine that is executed when the message command is called.
cog: Optional[:class:`Cog`]
The cog that this message command belongs to. ``None`` if there isn't one.
checks: List[Callable[[:class:`.ApplicationCommandInteraction`], :class:`bool`]]
A list of predicates that verifies if the command could be executed
with the given :class:`.ApplicationCommandInteraction` as the sole parameter. If an exception
is necessary to be thrown to signal failure, then one inherited from
:exc:`.CommandError` should be used. Note that if the checks fail then
:exc:`.CheckFailure` exception is raised to the :func:`.on_message_command_error`
event.
guild_ids: Optional[Tuple[:class:`int`, ...]]
The list of IDs of the guilds where the command is synced. ``None`` if this command is global.
auto_sync: :class:`bool`
Whether to automatically register the command.
extras: Dict[:class:`str`, Any]
A dict of user provided extras to attach to the command.
.. note::
This object may be copied by the library.
.. versionadded:: 2.5
"""
def __init__(
self,
func: InteractionCommandCallback[CogT, MessageCommandInteraction, P],
*,
name: LocalizedOptional = None,
dm_permission: Optional[bool] = None, # deprecated
default_member_permissions: Optional[Union[Permissions, int]] = None,
nsfw: Optional[bool] = None,
install_types: Optional[ApplicationInstallTypes] = None,
contexts: Optional[InteractionContextTypes] = None,
guild_ids: Optional[Sequence[int]] = None,
auto_sync: Optional[bool] = None,
**kwargs: Any,
) -> None:
name_loc = Localized._cast(name, False)
super().__init__(func, name=name_loc.string, **kwargs)
self.guild_ids: Optional[Tuple[int, ...]] = None if guild_ids is None else tuple(guild_ids)
self.auto_sync: bool = True if auto_sync is None else auto_sync
try:
default_member_permissions = func.__default_member_permissions__
except AttributeError:
pass
try:
install_types = func.__install_types__
except AttributeError:
pass
try:
contexts = func.__contexts__
except AttributeError:
pass
self.body = MessageCommand(
name=name_loc._upgrade(self.name),
dm_permission=dm_permission,
default_member_permissions=default_member_permissions,
nsfw=nsfw,
install_types=install_types,
contexts=contexts,
)
self._apply_guild_only()
async def _call_external_error_handlers(
self, inter: ApplicationCommandInteraction, error: CommandError
) -> None:
stop_propagation = False
cog = self.cog
try:
if cog is not None:
local = _get_overridden_method(cog.cog_message_command_error)
if local is not None:
stop_propagation = await local(inter, error)
# User has an option to cancel the global error handler by returning True
finally:
if not stop_propagation:
inter.bot.dispatch("message_command_error", inter, error)
async def __call__(
self,
interaction: ApplicationCommandInteraction,
target: Any = None,
*args: Any,
**kwargs: Any,
) -> None:
# the target may just not be passed in
args = (target or interaction.target, *args)
if self.cog is not None:
await safe_call(self.callback, self.cog, interaction, *args, **kwargs)
else:
await safe_call(self.callback, interaction, *args, **kwargs)
def user_command(
*,
name: LocalizedOptional = None,
dm_permission: Optional[bool] = None, # deprecated
default_member_permissions: Optional[Union[Permissions, int]] = None,
nsfw: Optional[bool] = None,
install_types: Optional[ApplicationInstallTypes] = None,
contexts: Optional[InteractionContextTypes] = None,
guild_ids: Optional[Sequence[int]] = None,
auto_sync: Optional[bool] = None,
extras: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Callable[[InteractionCommandCallback[CogT, UserCommandInteraction, P]], InvokableUserCommand]:
"""A shortcut decorator that builds a user command.
Parameters
----------
name: Optional[Union[:class:`str`, :class:`.Localized`]]
The name of the user command (defaults to the function name).
.. versionchanged:: 2.5
Added support for localizations.
dm_permission: :class:`bool`
Whether this command can be used in DMs.
Defaults to ``True``.
.. deprecated:: 2.10
Use ``contexts`` instead.
This is equivalent to the :attr:`.InteractionContextTypes.bot_dm` flag.
default_member_permissions: Optional[Union[:class:`.Permissions`, :class:`int`]]
The default required permissions for this command.
See :attr:`.ApplicationCommand.default_member_permissions` for details.
.. versionadded:: 2.5
nsfw: :class:`bool`
Whether this command is :ddocs:`age-restricted <interactions/application-commands#agerestricted-commands>`.
Defaults to ``False``.
.. versionadded:: 2.8
install_types: Optional[:class:`.ApplicationInstallTypes`]
The installation types where the command is available.
Defaults to :attr:`.ApplicationInstallTypes.guild` only.
Only available for global commands.
See :ref:`app_command_contexts` for details.
.. versionadded:: 2.10
contexts: Optional[:class:`.InteractionContextTypes`]
The interaction contexts where the command can be used.
Only available for global commands.
See :ref:`app_command_contexts` for details.
.. versionadded:: 2.10
auto_sync: :class:`bool`
Whether to automatically register the command. Defaults to ``True``.
guild_ids: Sequence[:class:`int`]
If specified, the client will register the command in these guilds.
Otherwise, this command will be registered globally.
extras: Dict[:class:`str`, Any]
A dict of user provided extras to attach to the command.
.. note::
This object may be copied by the library.
.. versionadded:: 2.5
Returns
-------
Callable[..., :class:`InvokableUserCommand`]
A decorator that converts the provided method into an InvokableUserCommand and returns it.
"""
def decorator(
func: InteractionCommandCallback[CogT, UserCommandInteraction, P],
) -> InvokableUserCommand:
if not iscoroutinefunction(func):
raise TypeError(f"<{func.__qualname__}> must be a coroutine function")
if hasattr(func, "__command_flag__"):
raise TypeError("Callback is already a command.")
if guild_ids and not all(isinstance(guild_id, int) for guild_id in guild_ids):
raise ValueError("guild_ids must be a sequence of int.")
return InvokableUserCommand(
func,
name=name,
dm_permission=dm_permission,
default_member_permissions=default_member_permissions,
nsfw=nsfw,
install_types=install_types,
contexts=contexts,
guild_ids=guild_ids,
auto_sync=auto_sync,
extras=extras,
**kwargs,
)
return decorator
def message_command(
*,
name: LocalizedOptional = None,
dm_permission: Optional[bool] = None, # deprecated
default_member_permissions: Optional[Union[Permissions, int]] = None,
nsfw: Optional[bool] = None,
install_types: Optional[ApplicationInstallTypes] = None,
contexts: Optional[InteractionContextTypes] = None,
guild_ids: Optional[Sequence[int]] = None,
auto_sync: Optional[bool] = None,
extras: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Callable[
[InteractionCommandCallback[CogT, MessageCommandInteraction, P]],
InvokableMessageCommand,
]:
"""A shortcut decorator that builds a message command.
Parameters
----------
name: Optional[Union[:class:`str`, :class:`.Localized`]]
The name of the message command (defaults to the function name).
.. versionchanged:: 2.5
Added support for localizations.
dm_permission: :class:`bool`
Whether this command can be used in DMs.
Defaults to ``True``.
.. deprecated:: 2.10
Use ``contexts`` instead.
This is equivalent to the :attr:`.InteractionContextTypes.bot_dm` flag.
default_member_permissions: Optional[Union[:class:`.Permissions`, :class:`int`]]
The default required permissions for this command.
See :attr:`.ApplicationCommand.default_member_permissions` for details.
.. versionadded:: 2.5
nsfw: :class:`bool`
Whether this command is :ddocs:`age-restricted <interactions/application-commands#agerestricted-commands>`.
Defaults to ``False``.
.. versionadded:: 2.8
install_types: Optional[:class:`.ApplicationInstallTypes`]
The installation types where the command is available.
Defaults to :attr:`.ApplicationInstallTypes.guild` only.
Only available for global commands.
See :ref:`app_command_contexts` for details.
.. versionadded:: 2.10
contexts: Optional[:class:`.InteractionContextTypes`]
The interaction contexts where the command can be used.
Only available for global commands.
See :ref:`app_command_contexts` for details.
.. versionadded:: 2.10
auto_sync: :class:`bool`
Whether to automatically register the command. Defaults to ``True``.
guild_ids: Sequence[:class:`int`]
If specified, the client will register the command in these guilds.
Otherwise, this command will be registered globally.
extras: Dict[:class:`str`, Any]
A dict of user provided extras to attach to the command.
.. note::
This object may be copied by the library.
.. versionadded:: 2.5
Returns
-------
Callable[..., :class:`InvokableMessageCommand`]
A decorator that converts the provided method into an InvokableMessageCommand and then returns it.
"""
def decorator(
func: InteractionCommandCallback[CogT, MessageCommandInteraction, P],
) -> InvokableMessageCommand:
if not iscoroutinefunction(func):
raise TypeError(f"<{func.__qualname__}> must be a coroutine function")
if hasattr(func, "__command_flag__"):
raise TypeError("Callback is already a command.")
if guild_ids and not all(isinstance(guild_id, int) for guild_id in guild_ids):
raise ValueError("guild_ids must be a sequence of int.")
return InvokableMessageCommand(
func,
name=name,
dm_permission=dm_permission,
default_member_permissions=default_member_permissions,
nsfw=nsfw,
install_types=install_types,
contexts=contexts,
guild_ids=guild_ids,
auto_sync=auto_sync,
extras=extras,
**kwargs,
)
return decorator

View File

@@ -0,0 +1,11 @@
# SPDX-License-Identifier: MIT
from disnake import DiscordWarning
__all__ = ("MessageContentPrefixWarning",)
class MessageContentPrefixWarning(DiscordWarning):
"""Warning for invalid prefixes without message content."""
pass

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,615 @@
# SPDX-License-Identifier: MIT
from __future__ import annotations
import inspect
import re
import sys
from dataclasses import dataclass, field
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterator,
List,
Literal,
Optional,
Pattern,
Set,
Tuple,
Union,
get_args,
get_origin,
)
from disnake.utils import MISSING, maybe_coroutine, resolve_annotation
from .converter import run_converters
from .errors import (
BadFlagArgument,
CommandError,
MissingFlagArgument,
MissingRequiredFlag,
TooManyFlags,
)
from .view import StringView
__all__ = (
"Flag",
"flag",
"FlagConverter",
)
if TYPE_CHECKING:
from typing_extensions import Self
from .context import Context
@dataclass
class Flag:
"""Represents a flag parameter for :class:`FlagConverter`.
The :func:`.flag` function helps
create these flag objects, but it is not necessary to
do so. These cannot be constructed manually.
Attributes
----------
name: :class:`str`
The name of the flag.
aliases: List[:class:`str`]
The aliases of the flag name.
attribute: :class:`str`
The attribute in the class that corresponds to this flag.
default: Any
The default value of the flag, if available.
annotation: Any
The underlying evaluated annotation of the flag.
max_args: :class:`int`
The maximum number of arguments the flag can accept.
A negative value indicates an unlimited amount of arguments.
override: :class:`bool`
Whether multiple given values overrides the previous value.
"""
name: str = MISSING
aliases: List[str] = field(default_factory=list)
attribute: str = MISSING
annotation: Any = MISSING
default: Any = MISSING
max_args: int = MISSING
override: bool = MISSING
cast_to_dict: bool = False
@property
def required(self) -> bool:
""":class:`bool`: Whether the flag is required.
A required flag has no default value.
"""
return self.default is MISSING
def flag(
*,
name: str = MISSING,
aliases: List[str] = MISSING,
default: Any = MISSING,
max_args: int = MISSING,
override: bool = MISSING,
) -> Any:
"""Override default functionality and parameters of the underlying :class:`FlagConverter`
class attributes.
Parameters
----------
name: :class:`str`
The flag name. If not given, defaults to the attribute name.
aliases: List[:class:`str`]
Aliases to the flag name. If not given no aliases are set.
default: Any
The default parameter. This could be either a value or a callable that takes
:class:`Context` as its sole parameter. If not given then it defaults to
the default value given to the attribute.
max_args: :class:`int`
The maximum number of arguments the flag can accept.
A negative value indicates an unlimited amount of arguments.
The default value depends on the annotation given.
override: :class:`bool`
Whether multiple given values overrides the previous value. The default
value depends on the annotation given.
"""
return Flag(name=name, aliases=aliases, default=default, max_args=max_args, override=override)
def validate_flag_name(name: str, forbidden: Set[str]) -> None:
if not name:
raise ValueError("flag names should not be empty")
for ch in name:
if ch.isspace():
raise ValueError(f"flag name {name!r} cannot have spaces")
if ch == "\\":
raise ValueError(f"flag name {name!r} cannot have backslashes")
if ch in forbidden:
raise ValueError(f"flag name {name!r} cannot have any of {forbidden!r} within them")
def get_flags(
namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[str, Any]
) -> Dict[str, Flag]:
annotations = namespace.get("__annotations__", {})
case_insensitive = namespace["__commands_flag_case_insensitive__"]
flags: Dict[str, Flag] = {}
cache: Dict[str, Any] = {}
names: Set[str] = set()
for name, annotation in annotations.items():
flag = namespace.pop(name, MISSING)
if isinstance(flag, Flag):
flag.annotation = annotation
else:
flag = Flag(name=name, annotation=annotation, default=flag)
flag.attribute = name
if flag.name is MISSING:
flag.name = name
annotation = flag.annotation = resolve_annotation(flag.annotation, globals, locals, cache)
if (
flag.default is MISSING
and hasattr(annotation, "__commands_is_flag__")
and annotation._can_be_constructible()
):
flag.default = annotation._construct_default
if flag.aliases is MISSING:
flag.aliases = []
# Add sensible defaults based off of the type annotation
# <type> -> (max_args=1)
# List[str] -> (max_args=-1)
# Tuple[int, ...] -> (max_args=1)
# Dict[K, V] -> (max_args=-1, override=True)
# Union[str, int] -> (max_args=1)
# Optional[str] -> (default=None, max_args=1)
try:
origin = annotation.__origin__
except AttributeError:
# A regular type hint
if flag.max_args is MISSING:
flag.max_args = 1
else:
if origin is Union:
# typing.Union
if flag.max_args is MISSING:
flag.max_args = 1
if annotation.__args__[-1] is type(None) and flag.default is MISSING:
# typing.Optional
flag.default = None
elif origin is tuple:
# typing.Tuple
# tuple parsing is e.g. `flag: peter 20`
# for Tuple[str, int] would give you flag: ('peter', 20)
if flag.max_args is MISSING:
flag.max_args = 1
elif origin is list:
# typing.List
if flag.max_args is MISSING:
flag.max_args = -1
elif origin is dict:
# typing.Dict[K, V]
# Equivalent to:
# typing.List[typing.Tuple[K, V]]
flag.cast_to_dict = True
if flag.max_args is MISSING:
flag.max_args = -1
if flag.override is MISSING:
flag.override = True
elif origin is Literal:
if flag.max_args is MISSING:
flag.max_args = 1
else:
raise TypeError(
f"Unsupported typing annotation {annotation!r} for {flag.name!r} flag"
)
if flag.override is MISSING:
flag.override = False
# Validate flag names are unique
name = flag.name.casefold() if case_insensitive else flag.name
if name in names:
raise TypeError(f"{flag.name!r} flag conflicts with previous flag or alias.")
else:
names.add(name)
for alias in flag.aliases:
# Validate alias is unique
alias = alias.casefold() if case_insensitive else alias
if alias in names:
raise TypeError(
f"{flag.name!r} flag alias {alias!r} conflicts with previous flag or alias."
)
else:
names.add(alias)
flags[flag.name] = flag
return flags
class FlagsMeta(type):
if TYPE_CHECKING:
__commands_is_flag__: bool
__commands_flags__: Dict[str, Flag]
__commands_flag_aliases__: Dict[str, str]
__commands_flag_regex__: Pattern[str]
__commands_flag_case_insensitive__: bool
__commands_flag_delimiter__: str
__commands_flag_prefix__: str
def __new__(
cls,
name: str,
bases: Tuple[type, ...],
attrs: Dict[str, Any],
*,
case_insensitive: bool = MISSING,
delimiter: str = MISSING,
prefix: str = MISSING,
) -> Self:
attrs["__commands_is_flag__"] = True
try:
global_ns = sys.modules[attrs["__module__"]].__dict__
except KeyError:
global_ns = {}
frame = inspect.currentframe()
try:
if frame is None:
local_ns = {}
else:
if frame.f_back is None:
local_ns = frame.f_locals
else:
local_ns = frame.f_back.f_locals
finally:
del frame
flags: Dict[str, Flag] = {}
aliases: Dict[str, str] = {}
for base in reversed(bases):
if base.__dict__.get("__commands_is_flag__", False):
flags.update(base.__dict__["__commands_flags__"])
aliases.update(base.__dict__["__commands_flag_aliases__"])
if case_insensitive is MISSING:
attrs["__commands_flag_case_insensitive__"] = base.__dict__[
"__commands_flag_case_insensitive__"
]
if delimiter is MISSING:
attrs["__commands_flag_delimiter__"] = base.__dict__[
"__commands_flag_delimiter__"
]
if prefix is MISSING:
attrs["__commands_flag_prefix__"] = base.__dict__["__commands_flag_prefix__"]
if case_insensitive is not MISSING:
attrs["__commands_flag_case_insensitive__"] = case_insensitive
if delimiter is not MISSING:
attrs["__commands_flag_delimiter__"] = delimiter
if prefix is not MISSING:
attrs["__commands_flag_prefix__"] = prefix
case_insensitive = attrs.setdefault("__commands_flag_case_insensitive__", False)
delimiter = attrs.setdefault("__commands_flag_delimiter__", ":")
prefix = attrs.setdefault("__commands_flag_prefix__", "")
for flag_name, flag in get_flags(attrs, global_ns, local_ns).items():
flags[flag_name] = flag
aliases.update(dict.fromkeys(flag.aliases, flag_name))
forbidden = set(delimiter).union(prefix)
for flag_name in flags:
validate_flag_name(flag_name, forbidden)
for alias_name in aliases:
validate_flag_name(alias_name, forbidden)
regex_flags = 0
if case_insensitive:
flags = {key.casefold(): value for key, value in flags.items()}
aliases = {key.casefold(): value.casefold() for key, value in aliases.items()}
regex_flags = re.IGNORECASE
keys = [re.escape(k) for k in flags]
keys.extend(re.escape(a) for a in aliases)
keys = sorted(keys, key=lambda t: len(t), reverse=True)
joined = "|".join(keys)
pattern = re.compile(
f"(({re.escape(prefix)})(?P<flag>{joined}){re.escape(delimiter)})", regex_flags
)
attrs["__commands_flag_regex__"] = pattern
attrs["__commands_flags__"] = flags
attrs["__commands_flag_aliases__"] = aliases
return type.__new__(cls, name, bases, attrs)
async def tuple_convert_all(
ctx: Context, argument: str, flag: Flag, converter: Any
) -> Tuple[Any, ...]:
view = StringView(argument)
results = []
param: inspect.Parameter = ctx.current_parameter # type: ignore
while not view.eof:
view.skip_ws()
if view.eof:
break
word = view.get_quoted_word()
if word is None:
break
try:
converted = await run_converters(ctx, converter, word, param)
except CommandError:
raise
except Exception as e:
raise BadFlagArgument(flag) from e
else:
results.append(converted)
return tuple(results)
async def tuple_convert_flag(
ctx: Context, argument: str, flag: Flag, converters: Any
) -> Tuple[Any, ...]:
view = StringView(argument)
results = []
param: inspect.Parameter = ctx.current_parameter # type: ignore
for converter in converters:
view.skip_ws()
if view.eof:
break
word = view.get_quoted_word()
if word is None:
break
try:
converted = await run_converters(ctx, converter, word, param)
except CommandError:
raise
except Exception as e:
raise BadFlagArgument(flag) from e
else:
results.append(converted)
if len(results) != len(converters):
raise BadFlagArgument(flag)
return tuple(results)
async def convert_flag(ctx: Context, argument: str, flag: Flag, annotation: Any = None) -> Any:
param: inspect.Parameter = ctx.current_parameter # type: ignore
annotation = annotation or flag.annotation
if origin := get_origin(annotation):
args = get_args(annotation)
if origin is tuple:
if args[-1] is Ellipsis:
return await tuple_convert_all(ctx, argument, flag, args[0])
else:
return await tuple_convert_flag(ctx, argument, flag, args)
elif origin is list:
# typing.List[x]
annotation = args[0]
return await convert_flag(ctx, argument, flag, annotation)
elif origin is Union and args[-1] is type(None):
# typing.Optional[x]
annotation = Union[args[:-1]] # type: ignore
return await run_converters(ctx, annotation, argument, param)
elif origin is dict:
# typing.Dict[K, V] -> typing.Tuple[K, V]
return await tuple_convert_flag(ctx, argument, flag, args)
try:
return await run_converters(ctx, annotation, argument, param)
except CommandError:
raise
except Exception as e:
raise BadFlagArgument(flag) from e
class FlagConverter(metaclass=FlagsMeta):
"""A converter that allows for a user-friendly flag syntax.
The flags are defined using :pep:`526` type annotations similar
to the :mod:`dataclasses` Python module. For more information on
how this converter works, check the appropriate
:ref:`documentation <ext_commands_flag_converter>`.
.. collapse:: operations
.. describe:: iter(x)
Returns an iterator of ``(flag_name, flag_value)`` pairs. This allows it
to be, for example, constructed as a dict or a list of pairs.
Note that aliases are not shown.
.. versionadded:: 2.0
Parameters
----------
case_insensitive: :class:`bool`
A class parameter to toggle case insensitivity of the flag parsing.
If ``True`` then flags are parsed in a case insensitive manner.
Defaults to ``False``.
prefix: :class:`str`
The prefix that all flags must be prefixed with. By default
there is no prefix.
delimiter: :class:`str`
The delimiter that separates a flag's argument from the flag's name.
By default this is ``:``.
"""
@classmethod
def get_flags(cls) -> Dict[str, Flag]:
"""Dict[:class:`str`, :class:`Flag`]: A mapping of flag name to flag object this converter has."""
return cls.__commands_flags__.copy()
@classmethod
def _can_be_constructible(cls) -> bool:
return all(not flag.required for flag in cls.__commands_flags__.values())
def __iter__(self) -> Iterator[Tuple[str, Any]]:
for flag in self.__class__.__commands_flags__.values():
yield (flag.name, getattr(self, flag.attribute))
@classmethod
async def _construct_default(cls, ctx: Context) -> Self:
self = cls.__new__(cls)
flags = cls.__commands_flags__
for flag in flags.values():
if callable(flag.default):
default = await maybe_coroutine(flag.default, ctx)
setattr(self, flag.attribute, default)
else:
setattr(self, flag.attribute, flag.default)
return self
def __repr__(self) -> str:
pairs = " ".join(
[
f"{flag.attribute}={getattr(self, flag.attribute)!r}"
for flag in self.get_flags().values()
]
)
return f"<{self.__class__.__name__} {pairs}>"
@classmethod
def parse_flags(cls, argument: str) -> Dict[str, List[str]]:
result: Dict[str, List[str]] = {}
flags = cls.__commands_flags__
aliases = cls.__commands_flag_aliases__
last_position = 0
last_flag: Optional[Flag] = None
case_insensitive = cls.__commands_flag_case_insensitive__
for match in cls.__commands_flag_regex__.finditer(argument):
begin, end = match.span(0)
key = match.group("flag")
if case_insensitive:
key = key.casefold()
if key in aliases:
key = aliases[key]
flag = flags.get(key)
if last_position and last_flag is not None:
value = argument[last_position : begin - 1].lstrip()
if not value:
raise MissingFlagArgument(last_flag)
try:
values = result[last_flag.name]
except KeyError:
result[last_flag.name] = [value]
else:
values.append(value)
last_position = end
last_flag = flag
# Add the remaining string to the last available flag
if last_position and last_flag is not None:
value = argument[last_position:].strip()
if not value:
raise MissingFlagArgument(last_flag)
try:
values = result[last_flag.name]
except KeyError:
result[last_flag.name] = [value]
else:
values.append(value)
# Verification of values will come at a later stage
return result
@classmethod
async def convert(cls, ctx: Context, argument: str) -> Self:
"""|coro|
The method that actually converters an argument to the flag mapping.
Parameters
----------
cls: Type[:class:`FlagConverter`]
The flag converter class.
ctx: :class:`Context`
The invocation context.
argument: :class:`str`
The argument to convert from.
Raises
------
FlagError
A flag related parsing error.
CommandError
A command related error.
Returns
-------
:class:`FlagConverter`
The flag converter instance with all flags parsed.
"""
arguments = cls.parse_flags(argument)
flags = cls.__commands_flags__
self = cls.__new__(cls)
for name, flag in flags.items():
try:
values = arguments[name]
except KeyError:
if flag.required:
raise MissingRequiredFlag(flag) from None
else:
if callable(flag.default):
default = await maybe_coroutine(flag.default, ctx)
setattr(self, flag.attribute, default)
else:
setattr(self, flag.attribute, flag.default)
continue
if flag.max_args > 0 and len(values) > flag.max_args:
if flag.override:
values = values[-flag.max_args :]
else:
raise TooManyFlags(flag, values)
# Special case:
if flag.max_args == 1:
value = await convert_flag(ctx, values[0], flag)
setattr(self, flag.attribute, value)
continue
# Another special case, tuple parsing.
# Tuple parsing is basically converting arguments within the flag
# So, given flag: hello 20 as the input and Tuple[str, int] as the type hint
# We would receive ('hello', 20) as the resulting value
# This uses the same whitespace and quoting rules as regular parameters.
values = [await convert_flag(ctx, value, flag) for value in values]
if flag.cast_to_dict:
values = dict(values)
setattr(self, flag.attribute, values)
return self

View File

@@ -0,0 +1,181 @@
# SPDX-License-Identifier: MIT
from __future__ import annotations
from typing import TYPE_CHECKING, NoReturn, overload
from disnake.flags import BaseFlags, alias_flag_value, all_flags_value, flag_value
from disnake.utils import _generated
if TYPE_CHECKING:
from typing_extensions import Self
__all__ = ("CommandSyncFlags",)
class CommandSyncFlags(BaseFlags):
"""Controls the library's application command syncing policy.
This allows for finer grained control over what commands are synced automatically and in what cases.
To construct an object you can pass keyword arguments denoting the flags
to enable or disable.
If command sync is disabled (see the docs of :attr:`sync_commands` for more info), other options will have no effect.
.. versionadded:: 2.7
.. collapse:: operations
.. describe:: x == y
Checks if two CommandSyncFlags instances are equal.
.. describe:: x != y
Checks if two CommandSyncFlags instances are not equal.
.. describe:: x <= y
Checks if a CommandSyncFlags instance is a subset of another CommandSyncFlags instance.
.. describe:: x >= y
Checks if a CommandSyncFlags instance is a superset of another CommandSyncFlags instance.
.. describe:: x < y
Checks if a CommandSyncFlags instance is a strict subset of another CommandSyncFlags instance.
.. describe:: x > y
Checks if a CommandSyncFlags instance is a strict superset of another CommandSyncFlags instance.
.. describe:: x | y, x |= y
Returns a new CommandSyncFlags instance with all enabled flags from both x and y.
(Using ``|=`` will update in place).
.. describe:: x & y, x &= y
Returns a new CommandSyncFlags instance with only flags enabled on both x and y.
(Using ``&=`` will update in place).
.. describe:: x ^ y, x ^= y
Returns a new CommandSyncFlags instance with only flags enabled on one of x or y, but not both.
(Using ``^=`` will update in place).
.. describe:: ~x
Returns a new CommandSyncFlags instance with all flags from x inverted.
.. describe:: hash(x)
Return the flag's hash.
.. describe:: iter(x)
Returns an iterator of ``(name, value)`` pairs. This allows it
to be, for example, constructed as a dict or a list of pairs.
Note that aliases are not shown.
Additionally supported are a few operations on class attributes.
.. describe:: CommandSyncFlags.y | CommandSyncFlags.z, CommandSyncFlags(y=True) | CommandSyncFlags.z
Returns a CommandSyncFlags instance with all provided flags enabled.
.. describe:: ~CommandSyncFlags.y
Returns a CommandSyncFlags instance with all flags except ``y`` inverted from their default value.
Attributes
----------
value: :class:`int`
The raw value. You should query flags via the properties
rather than using this raw value.
"""
__slots__ = ()
@overload
@_generated
def __init__(
self,
*,
allow_command_deletion: bool = ...,
sync_commands: bool = ...,
sync_commands_debug: bool = ...,
sync_global_commands: bool = ...,
sync_guild_commands: bool = ...,
sync_on_cog_actions: bool = ...,
) -> None: ...
@overload
@_generated
def __init__(self: NoReturn) -> None: ...
def __init__(self, **kwargs: bool) -> None:
self.value = all_flags_value(self.VALID_FLAGS)
for key, value in kwargs.items():
if key not in self.VALID_FLAGS:
raise TypeError(f"{key!r} is not a valid flag name.")
setattr(self, key, value)
@classmethod
def all(cls) -> Self:
"""A factory method that creates a :class:`CommandSyncFlags` with everything enabled."""
self = cls.__new__(cls)
self.value = all_flags_value(cls.VALID_FLAGS)
return self
@classmethod
def none(cls) -> Self:
"""A factory method that creates a :class:`CommandSyncFlags` with everything disabled."""
self = cls.__new__(cls)
self.value = self.DEFAULT_VALUE
return self
@classmethod
def default(cls) -> Self:
"""A factory method that creates a :class:`CommandSyncFlags` with the default settings.
The default is all flags enabled except for :attr:`sync_commands_debug`.
"""
self = cls.all()
self.sync_commands_debug = False
return self
@property
def _sync_enabled(self) -> bool:
return self.sync_global_commands or self.sync_guild_commands
@alias_flag_value
def sync_commands(self) -> int:
""":class:`bool`: Whether to sync global and guild app commands.
This controls the :attr:`sync_global_commands` and :attr:`sync_guild_commands` attributes.
Note that it is possible for sync to be enabled for guild *or* global commands yet this will return ``False``.
"""
return 1 << 3 | 1 << 4
@flag_value
def sync_commands_debug(self) -> int:
""":class:`bool`: Whether or not to show app command sync debug messages."""
return 1 << 0
@flag_value
def sync_on_cog_actions(self) -> int:
""":class:`bool`: Whether or not to sync app commands on cog load, unload, or reload."""
return 1 << 1
@flag_value
def allow_command_deletion(self) -> int:
""":class:`bool`: Whether to allow commands to be deleted by automatic command sync.
Current implementation of commands sync of renamed commands means that a rename of a command *will* result
in the old one being deleted and a new command being created.
"""
return 1 << 2
@flag_value
def sync_global_commands(self) -> int:
""":class:`bool`: Whether to sync global commands."""
return 1 << 3
@flag_value
def sync_guild_commands(self) -> int:
""":class:`bool`: Whether to sync per-guild commands."""
return 1 << 4

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,896 @@
# SPDX-License-Identifier: MIT
from __future__ import annotations
import asyncio
import inspect
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)
from disnake import utils
from disnake.app_commands import Option, SlashCommand
from disnake.enums import OptionType
from disnake.flags import ApplicationInstallTypes, InteractionContextTypes
from disnake.i18n import Localized
from disnake.interactions import ApplicationCommandInteraction
from disnake.permissions import Permissions
from .base_core import InvokableApplicationCommand, _get_overridden_method
from .errors import CommandError, CommandInvokeError
from .params import call_param_func, classify_autocompleter, expand_params
if TYPE_CHECKING:
from disnake.app_commands import Choices
from disnake.i18n import LocalizedOptional
from .base_core import CommandCallback
MISSING = utils.MISSING
__all__ = ("InvokableSlashCommand", "SubCommandGroup", "SubCommand", "slash_command")
SlashCommandT = TypeVar("SlashCommandT", bound="InvokableSlashCommand")
def _autocomplete(
self: Union[SubCommand, InvokableSlashCommand], option_name: str
) -> Callable[[Callable], Callable]:
for option in self.body.options:
if option.name == option_name:
option.autocomplete = True
break
else: # nobreak
raise ValueError(f"Option '{option_name}' doesn't exist in '{self.qualified_name}'")
def decorator(func: Callable) -> Callable:
classify_autocompleter(func)
self.autocompleters[option_name] = func
return func
return decorator
async def _call_autocompleter(
self: Union[InvokableSlashCommand, SubCommand],
param: str,
inter: ApplicationCommandInteraction,
user_input: str,
) -> Optional[Choices]:
autocomp = self.autocompleters.get(param)
if autocomp is None:
return None
if not callable(autocomp):
return autocomp
try:
requires_cog_param = autocomp.__has_cog_param__
except AttributeError:
requires_cog_param = False
cog = self.root_parent.cog if isinstance(self, SubCommand) else self.cog
filled = inter.filled_options
del filled[inter.data.focused_option.name]
try:
if requires_cog_param:
choices = autocomp(cog, inter, user_input, **filled)
else:
choices = autocomp(inter, user_input, **filled)
except TypeError:
if requires_cog_param:
choices = autocomp(cog, inter, user_input)
else:
choices = autocomp(inter, user_input)
if inspect.isawaitable(choices):
return await choices
return choices
_INVALID_SUB_KWARGS = frozenset(
{"dm_permission", "default_member_permissions", "install_types", "contexts"}
)
# this is just a helpful message for users trying to set specific
# top-level-only fields on subcommands or groups
def _check_invalid_sub_kwargs(func: CommandCallback, kwargs: Dict[str, Any]) -> None:
invalid_keys = kwargs.keys() & _INVALID_SUB_KWARGS
for decorator_key in [
"__default_member_permissions__",
"__install_types__",
"__contexts__",
]:
if hasattr(func, decorator_key):
invalid_keys.add(decorator_key.strip("_"))
if invalid_keys:
msg = f"Cannot set {utils.humanize_list(list(invalid_keys), 'or')} on subcommands or subcommand groups"
raise TypeError(msg)
class SubCommandGroup(InvokableApplicationCommand):
"""A class that implements the protocol for a bot slash command group.
These are not created manually, instead they are created via the
decorator or functional interface.
Attributes
----------
name: :class:`str`
The name of the group.
qualified_name: :class:`str`
The full command name, including parent names in the case of slash subcommands or groups.
For example, the qualified name for ``/one two three`` would be ``one two three``.
parent: :class:`InvokableSlashCommand`
The parent command this group belongs to.
.. versionadded:: 2.6
option: :class:`.Option`
API representation of this subcommand.
callback: :ref:`coroutine <coroutine>`
The coroutine that is executed when the command group is invoked.
cog: Optional[:class:`Cog`]
The cog that this group belongs to. ``None`` if there isn't one.
checks: List[Callable[[:class:`.ApplicationCommandInteraction`], :class:`bool`]]
A list of predicates that verifies if the group could be executed
with the given :class:`.ApplicationCommandInteraction` as the sole parameter. If an exception
is necessary to be thrown to signal failure, then one inherited from
:exc:`.CommandError` should be used. Note that if the checks fail then
:exc:`.CheckFailure` exception is raised to the :func:`.on_slash_command_error`
event.
extras: Dict[:class:`str`, Any]
A dict of user provided extras to attach to the subcommand group.
.. note::
This object may be copied by the library.
.. versionadded:: 2.5
"""
def __init__(
self,
func: CommandCallback,
parent: InvokableSlashCommand,
*,
name: LocalizedOptional = None,
**kwargs: Any,
) -> None:
name_loc = Localized._cast(name, False)
super().__init__(func, name=name_loc.string, **kwargs)
self.parent: InvokableSlashCommand = parent
self.children: Dict[str, SubCommand] = {}
# while subcommand groups don't have a description, parse the docstring regardless to
# retrieve the localization key, if any
docstring = utils.parse_docstring(func)
self.option = Option(
name=name_loc._upgrade(self.name, key=docstring["localization_key_name"]),
description="-",
type=OptionType.sub_command_group,
options=[],
)
self.qualified_name: str = f"{parent.qualified_name} {self.name}"
_check_invalid_sub_kwargs(func, kwargs)
@property
def root_parent(self) -> InvokableSlashCommand:
""":class:`InvokableSlashCommand`: Returns the slash command containing this group.
This is mainly for consistency with :class:`SubCommand`, and is equivalent to :attr:`parent`.
.. versionadded:: 2.6
"""
return self.parent
@property
def parents(self) -> Tuple[InvokableSlashCommand]:
"""Tuple[:class:`InvokableSlashCommand`]: Returns all parents of this group.
.. versionadded:: 2.6
"""
return (self.parent,)
@property
def body(self) -> Option:
return self.option
def sub_command(
self,
name: LocalizedOptional = None,
description: LocalizedOptional = None,
options: Optional[list] = None,
connectors: Optional[dict] = None,
extras: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Callable[[CommandCallback], SubCommand]:
"""A decorator that creates a subcommand in the subcommand group.
Parameters are the same as in :class:`InvokableSlashCommand.sub_command`
Returns
-------
Callable[..., :class:`SubCommand`]
A decorator that converts the provided method into a SubCommand, adds it to the bot, then returns it.
"""
def decorator(func: CommandCallback) -> SubCommand:
new_func = SubCommand(
func,
self,
name=name,
description=description,
options=options,
connectors=connectors,
extras=extras,
**kwargs,
)
self.children[new_func.name] = new_func
self.option.options.append(new_func.option)
return new_func
return decorator
class SubCommand(InvokableApplicationCommand):
"""A class that implements the protocol for a bot slash subcommand.
These are not created manually, instead they are created via the
decorator or functional interface.
Attributes
----------
name: :class:`str`
The name of the subcommand.
qualified_name: :class:`str`
The full command name, including parent names in the case of slash subcommands or groups.
For example, the qualified name for ``/one two three`` would be ``one two three``.
parent: Union[:class:`InvokableSlashCommand`, :class:`SubCommandGroup`]
The parent command or group this subcommand belongs to.
.. versionadded:: 2.6
option: :class:`.Option`
API representation of this subcommand.
callback: :ref:`coroutine <coroutine>`
The coroutine that is executed when the subcommand is called.
cog: Optional[:class:`Cog`]
The cog that this subcommand belongs to. ``None`` if there isn't one.
checks: List[Callable[[:class:`.ApplicationCommandInteraction`], :class:`bool`]]
A list of predicates that verifies if the subcommand could be executed
with the given :class:`.ApplicationCommandInteraction` as the sole parameter. If an exception
is necessary to be thrown to signal failure, then one inherited from
:exc:`.CommandError` should be used. Note that if the checks fail then
:exc:`.CheckFailure` exception is raised to the :func:`.on_slash_command_error`
event.
connectors: Dict[:class:`str`, :class:`str`]
A mapping of option names to function parameter names, mainly for internal processes.
extras: Dict[:class:`str`, Any]
A dict of user provided extras to attach to the subcommand.
.. note::
This object may be copied by the library.
.. versionadded:: 2.5
"""
def __init__(
self,
func: CommandCallback,
parent: Union[InvokableSlashCommand, SubCommandGroup],
*,
name: LocalizedOptional = None,
description: LocalizedOptional = None,
options: Optional[list] = None,
connectors: Optional[Dict[str, str]] = None,
**kwargs: Any,
) -> None:
name_loc = Localized._cast(name, False)
super().__init__(func, name=name_loc.string, **kwargs)
self.parent: Union[InvokableSlashCommand, SubCommandGroup] = parent
self.connectors: Dict[str, str] = connectors or {}
self.autocompleters: Dict[str, Any] = kwargs.get("autocompleters", {})
if options is None:
options = expand_params(self)
self.docstring = utils.parse_docstring(func)
desc_loc = Localized._cast(description, False)
self.option = Option(
name=name_loc._upgrade(self.name, key=self.docstring["localization_key_name"]),
description=desc_loc._upgrade(
self.docstring["description"] or "-", key=self.docstring["localization_key_desc"]
),
type=OptionType.sub_command,
options=options,
)
self.qualified_name = f"{parent.qualified_name} {self.name}"
_check_invalid_sub_kwargs(func, kwargs)
@property
def root_parent(self) -> InvokableSlashCommand:
""":class:`InvokableSlashCommand`: Returns the top-level slash command containing this subcommand,
even if the parent is a :class:`SubCommandGroup`.
.. versionadded:: 2.6
"""
return self.parent.parent if isinstance(self.parent, SubCommandGroup) else self.parent
@property
def parents(
self,
) -> Union[Tuple[InvokableSlashCommand], Tuple[SubCommandGroup, InvokableSlashCommand]]:
"""Union[Tuple[:class:`InvokableSlashCommand`], Tuple[:class:`SubCommandGroup`, :class:`InvokableSlashCommand`]]:
Returns all parents of this subcommand.
For example, the parents of the ``c`` subcommand in ``/a b c`` are ``(b, a)``.
.. versionadded:: 2.6
"""
# here I'm not using 'self.parent.parents + (self.parent,)' because it causes typing issues
if isinstance(self.parent, SubCommandGroup):
return (self.parent, self.parent.parent)
return (self.parent,)
@property
def description(self) -> str:
""":class:`str`: The slash sub command's description. Shorthand for :attr:`self.body.description <.Option.description>`."""
return self.body.description
@property
def body(self) -> Option:
""":class:`.Option`: The API representation for this slash sub command. Shorthand for :attr:`.SubCommand.option`"""
return self.option
async def _call_autocompleter(
self, param: str, inter: ApplicationCommandInteraction, user_input: str
) -> Optional[Choices]:
return await _call_autocompleter(self, param, inter, user_input)
async def invoke(self, inter: ApplicationCommandInteraction, *args: Any, **kwargs: Any) -> None:
for k, v in self.connectors.items():
if k in kwargs:
kwargs[v] = kwargs.pop(k)
await self.prepare(inter)
try:
await call_param_func(self.callback, inter, self.cog, **kwargs)
except CommandError:
inter.command_failed = True
raise
except asyncio.CancelledError:
inter.command_failed = True
return
except Exception as exc:
inter.command_failed = True
raise CommandInvokeError(exc) from exc
finally:
if self._max_concurrency is not None:
await self._max_concurrency.release(inter) # type: ignore
await self.call_after_hooks(inter)
def autocomplete(self, option_name: str) -> Callable[[Callable], Callable]:
"""A decorator that registers an autocomplete function for the specified option.
Parameters
----------
option_name: :class:`str`
The name of the slash command option.
"""
return _autocomplete(self, option_name)
class InvokableSlashCommand(InvokableApplicationCommand):
"""A class that implements the protocol for a bot slash command.
These are not created manually, instead they are created via the
decorator or functional interface.
Attributes
----------
name: :class:`str`
The name of the command.
qualified_name: :class:`str`
The full command name, including parent names in the case of slash subcommands or groups.
For example, the qualified name for ``/one two three`` would be ``one two three``.
body: :class:`.SlashCommand`
An object being registered in the API.
callback: :ref:`coroutine <coroutine>`
The coroutine that is executed when the command is called.
cog: Optional[:class:`Cog`]
The cog that this command belongs to. ``None`` if there isn't one.
checks: List[Callable[[:class:`.ApplicationCommandInteraction`], :class:`bool`]]
A list of predicates that verifies if the command could be executed
with the given :class:`.ApplicationCommandInteraction` as the sole parameter. If an exception
is necessary to be thrown to signal failure, then one inherited from
:exc:`.CommandError` should be used. Note that if the checks fail then
:exc:`.CheckFailure` exception is raised to the :func:`.on_slash_command_error`
event.
guild_ids: Optional[Tuple[:class:`int`, ...]]
The list of IDs of the guilds where the command is synced. ``None`` if this command is global.
connectors: Dict[:class:`str`, :class:`str`]
A mapping of option names to function parameter names, mainly for internal processes.
auto_sync: :class:`bool`
Whether to automatically register the command.
extras: Dict[:class:`str`, Any]
A dict of user provided extras to attach to the command.
.. note::
This object may be copied by the library.
.. versionadded:: 2.5
parent: ``None``
This exists for consistency with :class:`SubCommand` and :class:`SubCommandGroup`. Always ``None``.
.. versionadded:: 2.6
"""
def __init__(
self,
func: CommandCallback,
*,
name: LocalizedOptional = None,
description: LocalizedOptional = None,
options: Optional[List[Option]] = None,
dm_permission: Optional[bool] = None, # deprecated
default_member_permissions: Optional[Union[Permissions, int]] = None,
nsfw: Optional[bool] = None,
install_types: Optional[ApplicationInstallTypes] = None,
contexts: Optional[InteractionContextTypes] = None,
guild_ids: Optional[Sequence[int]] = None,
connectors: Optional[Dict[str, str]] = None,
auto_sync: Optional[bool] = None,
**kwargs: Any,
) -> None:
name_loc = Localized._cast(name, False)
super().__init__(func, name=name_loc.string, **kwargs)
self.parent = None
self.connectors: Dict[str, str] = connectors or {}
self.children: Dict[str, Union[SubCommand, SubCommandGroup]] = {}
self.auto_sync: bool = True if auto_sync is None else auto_sync
self.guild_ids: Optional[Tuple[int, ...]] = None if guild_ids is None else tuple(guild_ids)
self.autocompleters: Dict[str, Any] = kwargs.get("autocompleters", {})
if options is None:
options = expand_params(self)
self.docstring = utils.parse_docstring(func)
desc_loc = Localized._cast(description, False)
try:
default_member_permissions = func.__default_member_permissions__
except AttributeError:
pass
try:
install_types = func.__install_types__
except AttributeError:
pass
try:
contexts = func.__contexts__
except AttributeError:
pass
self.body: SlashCommand = SlashCommand(
name=name_loc._upgrade(self.name, key=self.docstring["localization_key_name"]),
description=desc_loc._upgrade(
self.docstring["description"] or "-", key=self.docstring["localization_key_desc"]
),
options=options or [],
dm_permission=dm_permission,
default_member_permissions=default_member_permissions,
nsfw=nsfw,
install_types=install_types,
contexts=contexts,
)
self._apply_guild_only()
@property
def root_parent(self) -> None:
"""``None``: This is for consistency with :class:`SubCommand` and :class:`SubCommandGroup`.
.. versionadded:: 2.6
"""
return None
@property
def parents(self) -> Tuple[()]:
"""Tuple[()]: This is mainly for consistency with :class:`SubCommand`, and is equivalent to an empty tuple.
.. versionadded:: 2.6
"""
return ()
def _ensure_assignment_on_copy(self, other: SlashCommandT) -> SlashCommandT:
super()._ensure_assignment_on_copy(other)
if self.connectors != other.connectors:
other.connectors = self.connectors.copy()
if self.autocompleters != other.autocompleters:
other.autocompleters = self.autocompleters.copy()
if self.children != other.children:
other.children = self.children.copy()
# update parents...
for child in other.children.values():
child.parent = other
if self.description != other.description and "description" not in other.__original_kwargs__:
# Allows overriding the default description cog-wide.
other.body.description = self.description
if self.options != other.options:
other.body.options = self.options
return other
@property
def description(self) -> str:
""":class:`str`: The slash command's description. Shorthand for :attr:`self.body.description <.SlashCommand.description>`."""
return self.body.description
@property
def options(self) -> List[Option]:
"""List[:class:`.Option`]: The list of options the slash command has. Shorthand for :attr:`self.body.options <.SlashCommand.options>`."""
return self.body.options
def sub_command(
self,
name: LocalizedOptional = None,
description: LocalizedOptional = None,
options: Optional[list] = None,
connectors: Optional[dict] = None,
extras: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Callable[[CommandCallback], SubCommand]:
"""A decorator that creates a subcommand under the base command.
Parameters
----------
name: Optional[Union[:class:`str`, :class:`.Localized`]]
The name of the subcommand (defaults to function name).
.. versionchanged:: 2.5
Added support for localizations.
description: Optional[Union[:class:`str`, :class:`.Localized`]]
The description of the subcommand.
.. versionchanged:: 2.5
Added support for localizations.
options: List[:class:`.Option`]
the options of the subcommand for registration in API
connectors: Dict[:class:`str`, :class:`str`]
which function param states for each option. If the name
of an option already matches the corresponding function param,
you don't have to specify the connectors. Connectors template:
``{"option-name": "param_name", ...}``
extras: Dict[:class:`str`, Any]
A dict of user provided extras to attach to the subcommand.
.. note::
This object may be copied by the library.
.. versionadded:: 2.5
Returns
-------
Callable[..., :class:`SubCommand`]
A decorator that converts the provided method into a :class:`SubCommand`, adds it to the bot, then returns it.
"""
def decorator(func: CommandCallback) -> SubCommand:
if len(self.children) == 0 and len(self.body.options) > 0:
self.body.options = []
new_func = SubCommand(
func,
self,
name=name,
description=description,
options=options,
connectors=connectors,
extras=extras,
**kwargs,
)
self.children[new_func.name] = new_func
self.body.options.append(new_func.option)
return new_func
return decorator
def sub_command_group(
self,
name: LocalizedOptional = None,
extras: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Callable[[CommandCallback], SubCommandGroup]:
"""A decorator that creates a subcommand group under the base command.
Parameters
----------
name: Optional[Union[:class:`str`, :class:`.Localized`]]
The name of the subcommand group (defaults to function name).
.. versionchanged:: 2.5
Added support for localizations.
extras: Dict[:class:`str`, Any]
A dict of user provided extras to attach to the subcommand group.
.. note::
This object may be copied by the library.
.. versionadded:: 2.5
Returns
-------
Callable[..., :class:`SubCommandGroup`]
A decorator that converts the provided method into a :class:`SubCommandGroup`, adds it to the bot, then returns it.
"""
def decorator(func: CommandCallback) -> SubCommandGroup:
if len(self.children) == 0 and len(self.body.options) > 0:
self.body.options = []
new_func = SubCommandGroup(
func,
self,
name=name,
extras=extras,
**kwargs,
)
self.children[new_func.name] = new_func
self.body.options.append(new_func.option)
return new_func
return decorator
def autocomplete(self, option_name: str) -> Callable[[Callable], Callable]:
"""A decorator that registers an autocomplete function for the specified option.
Parameters
----------
option_name: :class:`str`
The name of the slash command option.
"""
return _autocomplete(self, option_name)
async def _call_external_error_handlers(
self, inter: ApplicationCommandInteraction, error: CommandError
) -> None:
stop_propagation = False
cog = self.cog
try:
if cog is not None:
local = _get_overridden_method(cog.cog_slash_command_error)
if local is not None:
stop_propagation = await local(inter, error)
# User has an option to cancel the global error handler by returning True
finally:
if not stop_propagation:
inter.bot.dispatch("slash_command_error", inter, error)
async def _call_autocompleter(
self, param: str, inter: ApplicationCommandInteraction, user_input: str
) -> Optional[Choices]:
return await _call_autocompleter(self, param, inter, user_input)
async def _call_relevant_autocompleter(self, inter: ApplicationCommandInteraction) -> None:
chain, _ = inter.data._get_chain_and_kwargs()
if len(chain) == 0:
subcmd = None
elif len(chain) == 1:
subcmd = self.children.get(chain[0])
elif len(chain) == 2:
group = self.children.get(chain[0])
if not isinstance(group, SubCommandGroup):
raise AssertionError("the first subcommand is not a SubCommandGroup instance")
subcmd = group.children.get(chain[1])
else:
raise ValueError("Command chain is too long")
focused_option = inter.data.focused_option
if subcmd is None or isinstance(subcmd, SubCommandGroup):
call_autocompleter = self._call_autocompleter
else:
call_autocompleter = subcmd._call_autocompleter
choices = await call_autocompleter(focused_option.name, inter, focused_option.value)
if choices is not None:
await inter.response.autocomplete(choices=choices)
async def invoke_children(self, inter: ApplicationCommandInteraction) -> None:
chain, kwargs = inter.data._get_chain_and_kwargs()
if len(chain) == 0:
group = None
subcmd = None
elif len(chain) == 1:
group = None
subcmd = self.children.get(chain[0])
elif len(chain) == 2:
group = self.children.get(chain[0])
if not isinstance(group, SubCommandGroup):
raise AssertionError("the first subcommand is not a SubCommandGroup instance")
subcmd = group.children.get(chain[1])
else:
raise ValueError("Command chain is too long")
if group is not None:
try:
await group.invoke(inter)
except CommandError as exc:
if not await group._call_local_error_handler(inter, exc):
raise
if subcmd is not None:
try:
await subcmd.invoke(inter, **kwargs)
except CommandError as exc:
if not await subcmd._call_local_error_handler(inter, exc):
raise
async def invoke(self, inter: ApplicationCommandInteraction) -> None:
await self.prepare(inter)
try:
if len(self.children) > 0:
await self(inter)
await self.invoke_children(inter)
else:
kwargs = inter.filled_options
for k, v in self.connectors.items():
if k in kwargs:
kwargs[v] = kwargs.pop(k)
await call_param_func(self.callback, inter, self.cog, **kwargs)
except CommandError:
inter.command_failed = True
raise
except asyncio.CancelledError:
inter.command_failed = True
return
except Exception as exc:
inter.command_failed = True
raise CommandInvokeError(exc) from exc
finally:
if self._max_concurrency is not None:
await self._max_concurrency.release(inter) # type: ignore
await self.call_after_hooks(inter)
def slash_command(
*,
name: LocalizedOptional = None,
description: LocalizedOptional = None,
dm_permission: Optional[bool] = None, # deprecated
default_member_permissions: Optional[Union[Permissions, int]] = None,
nsfw: Optional[bool] = None,
install_types: Optional[ApplicationInstallTypes] = None,
contexts: Optional[InteractionContextTypes] = None,
options: Optional[List[Option]] = None,
guild_ids: Optional[Sequence[int]] = None,
connectors: Optional[Dict[str, str]] = None,
auto_sync: Optional[bool] = None,
extras: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Callable[[CommandCallback], InvokableSlashCommand]:
"""A decorator that builds a slash command.
Parameters
----------
auto_sync: :class:`bool`
Whether to automatically register the command. Defaults to ``True``.
name: Optional[Union[:class:`str`, :class:`.Localized`]]
The name of the slash command (defaults to function name).
.. versionchanged:: 2.5
Added support for localizations.
description: Optional[Union[:class:`str`, :class:`.Localized`]]
The description of the slash command. It will be visible in Discord.
.. versionchanged:: 2.5
Added support for localizations.
nsfw: :class:`bool`
Whether this command is :ddocs:`age-restricted <interactions/application-commands#agerestricted-commands>`.
Defaults to ``False``.
.. versionadded:: 2.8
install_types: Optional[:class:`.ApplicationInstallTypes`]
The installation types where the command is available.
Defaults to :attr:`.ApplicationInstallTypes.guild` only.
Only available for global commands.
See :ref:`app_command_contexts` for details.
.. versionadded:: 2.10
contexts: Optional[:class:`.InteractionContextTypes`]
The interaction contexts where the command can be used.
Only available for global commands.
See :ref:`app_command_contexts` for details.
.. versionadded:: 2.10
options: List[:class:`.Option`]
The list of slash command options. The options will be visible in Discord.
This is the old way of specifying options. Consider using :ref:`param_syntax` instead.
dm_permission: :class:`bool`
Whether this command can be used in DMs.
Defaults to ``True``.
.. deprecated:: 2.10
Use ``contexts`` instead.
This is equivalent to the :attr:`.InteractionContextTypes.bot_dm` flag.
default_member_permissions: Optional[Union[:class:`.Permissions`, :class:`int`]]
The default required permissions for this command.
See :attr:`.ApplicationCommand.default_member_permissions` for details.
.. versionadded:: 2.5
guild_ids: List[:class:`int`]
If specified, the client will register the command in these guilds.
Otherwise, this command will be registered globally.
connectors: Dict[:class:`str`, :class:`str`]
Binds function names to option names. If the name
of an option already matches the corresponding function param,
you don't have to specify the connectors. Connectors template:
``{"option-name": "param_name", ...}``.
If you're using :ref:`param_syntax`, you don't need to specify this.
extras: Dict[:class:`str`, Any]
A dict of user provided extras to attach to the command.
.. note::
This object may be copied by the library.
.. versionadded:: 2.5
Returns
-------
Callable[..., :class:`InvokableSlashCommand`]
A decorator that converts the provided method into an InvokableSlashCommand and returns it.
"""
def decorator(func: CommandCallback) -> InvokableSlashCommand:
if not utils.iscoroutinefunction(func):
raise TypeError(f"<{func.__qualname__}> must be a coroutine function")
if hasattr(func, "__command_flag__"):
raise TypeError("Callback is already a command.")
if guild_ids and not all(isinstance(guild_id, int) for guild_id in guild_ids):
raise ValueError("guild_ids must be a sequence of int.")
return InvokableSlashCommand(
func,
name=name,
description=description,
options=options,
dm_permission=dm_permission,
default_member_permissions=default_member_permissions,
nsfw=nsfw,
install_types=install_types,
contexts=contexts,
guild_ids=guild_ids,
connectors=connectors,
auto_sync=auto_sync,
extras=extras,
**kwargs,
)
return decorator

View File

@@ -0,0 +1,174 @@
# SPDX-License-Identifier: MIT
from typing import Optional
from .errors import ExpectedClosingQuoteError, InvalidEndOfQuotedStringError, UnexpectedQuoteError
# map from opening quotes to closing quotes
_quotes = {
'"': '"',
"": "", # noqa: RUF001
"": "", # noqa: RUF001
"": "",
"": "",
"": "",
"": "",
"": "",
"": "",
"": "",
"": "",
"": "", # noqa: RUF001
"": "",
"«": "»",
"": "", # noqa: RUF001
"": "",
"": "",
}
_all_quotes = set(_quotes.keys()) | set(_quotes.values())
class StringView:
def __init__(self, buffer: str) -> None:
self.index = 0
self.buffer = buffer
self.end = len(buffer)
self.previous = 0
@property
def current(self) -> Optional[str]:
return None if self.eof else self.buffer[self.index]
@property
def eof(self) -> bool:
return self.index >= self.end
def undo(self) -> None:
self.index = self.previous
def skip_ws(self) -> bool:
pos = 0
while not self.eof:
try:
current = self.buffer[self.index + pos]
if not current.isspace():
break
pos += 1
except IndexError:
break
self.previous = self.index
self.index += pos
return self.previous != self.index
def skip_string(self, string: str) -> bool:
strlen = len(string)
if self.buffer[self.index : self.index + strlen] == string:
self.previous = self.index
self.index += strlen
return True
return False
def read_rest(self) -> str:
result = self.buffer[self.index :]
self.previous = self.index
self.index = self.end
return result
def read(self, n: int) -> str:
result = self.buffer[self.index : self.index + n]
self.previous = self.index
self.index += n
return result
def get(self) -> Optional[str]:
try:
result = self.buffer[self.index + 1]
except IndexError:
result = None
self.previous = self.index
self.index += 1
return result
def get_word(self) -> str:
pos = 0
while not self.eof:
try:
current = self.buffer[self.index + pos]
if current.isspace():
break
pos += 1
except IndexError:
break
self.previous = self.index
result = self.buffer[self.index : self.index + pos]
self.index += pos
return result
def get_quoted_word(self) -> Optional[str]:
current = self.current
if current is None:
return None
close_quote = _quotes.get(current)
is_quoted = bool(close_quote)
if is_quoted:
result = []
_escaped_quotes = (current, close_quote)
else:
result = [current]
_escaped_quotes = _all_quotes
while not self.eof:
current = self.get()
if not current:
if is_quoted:
# unexpected EOF
raise ExpectedClosingQuoteError(str(close_quote))
return "".join(result)
# currently we accept strings in the format of "hello world"
# to embed a quote inside the string you must escape it: "a \"world\""
if current == "\\":
next_char = self.get()
if not next_char:
# string ends with \ and no character after it
if is_quoted:
# if we're quoted then we're expecting a closing quote
raise ExpectedClosingQuoteError(str(close_quote))
# if we aren't then we just let it through
return "".join(result)
if next_char in _escaped_quotes:
# escaped quote
result.append(next_char)
else:
# different escape character, ignore it
self.undo()
result.append(current)
continue
if not is_quoted and current in _all_quotes:
# we aren't quoted
raise UnexpectedQuoteError(current)
# closing quote
if is_quoted and current == close_quote:
next_char = self.get()
valid_eof = not next_char or next_char.isspace()
if not valid_eof:
raise InvalidEndOfQuotedStringError(str(next_char))
# we're quoted so it's okay
return "".join(result)
if current.isspace() and not is_quoted:
# end of word found
return "".join(result)
result.append(current)
def __repr__(self) -> str:
return (
f"<StringView pos: {self.index} prev: {self.previous} end: {self.end} eof: {self.eof}>"
)

View File

@@ -0,0 +1,14 @@
# SPDX-License-Identifier: MIT
import typing as t
from mypy.plugin import Plugin
# FIXME: properly deprecate this in the future
class DisnakePlugin(Plugin):
"""Custom mypy plugin; no-op as of version 2.9."""
def plugin(version: str) -> t.Type[Plugin]:
return DisnakePlugin

View File

@@ -0,0 +1,799 @@
# SPDX-License-Identifier: MIT
from __future__ import annotations
import asyncio
import datetime
import inspect
import sys
import traceback
from collections.abc import Sequence
from typing import (
TYPE_CHECKING,
Any,
Callable,
Coroutine,
Generic,
List,
Optional,
Protocol,
Type,
TypeVar,
Union,
cast,
get_origin,
overload,
)
import aiohttp
import disnake
from disnake.backoff import ExponentialBackoff
from disnake.utils import MISSING, iscoroutinefunction, utcnow
if TYPE_CHECKING:
from typing_extensions import Concatenate, ParamSpec, Self
P = ParamSpec("P")
else:
P = TypeVar("P")
__all__ = ("loop",)
T = TypeVar("T")
_func = Callable[..., Coroutine[Any, Any, Any]]
LF = TypeVar("LF", bound=_func)
FT = TypeVar("FT", bound=_func)
ET = TypeVar("ET", bound=Callable[[Any, BaseException], Coroutine[Any, Any, Any]])
class SleepHandle:
__slots__ = ("future", "loop", "handle")
def __init__(self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop) -> None:
self.loop = loop
self.future: asyncio.Future[bool] = loop.create_future()
relative_delta = disnake.utils.compute_timedelta(dt)
self.handle = loop.call_later(relative_delta, self.future.set_result, True)
def recalculate(self, dt: datetime.datetime) -> None:
self.handle.cancel()
relative_delta = disnake.utils.compute_timedelta(dt)
self.handle = self.loop.call_later(relative_delta, self.future.set_result, True)
def wait(self) -> asyncio.Future[bool]:
return self.future
def done(self) -> bool:
return self.future.done()
def cancel(self) -> None:
self.handle.cancel()
self.future.cancel()
class Loop(Generic[LF]):
"""A background task helper that abstracts the loop and reconnection logic for you.
The main interface to create this is through :func:`loop`.
"""
def __init__(
self,
coro: LF,
*,
seconds: float = 0,
minutes: float = 0,
hours: float = 0,
time: Union[datetime.time, Sequence[datetime.time]] = MISSING,
count: Optional[int] = None,
reconnect: bool = True,
loop: asyncio.AbstractEventLoop = MISSING,
) -> None:
""".. note:
If you overwrite ``__init__`` arguments, make sure to redefine .clone too.
"""
self.coro: LF = coro
self.reconnect: bool = reconnect
self.loop: asyncio.AbstractEventLoop = loop
self.count: Optional[int] = count
self._current_loop = 0
self._handle: SleepHandle = MISSING
self._task: asyncio.Task[None] = MISSING
self._injected: Any = None
self._valid_exception = (
OSError,
disnake.GatewayNotFound,
disnake.ConnectionClosed,
aiohttp.ClientError,
asyncio.TimeoutError,
)
self._before_loop = None
self._after_loop = None
self._is_being_cancelled = False
self._has_failed = False
self._stop_next_iteration = False
if self.count is not None and self.count <= 0:
raise ValueError("count must be greater than 0 or None.")
self.change_interval(seconds=seconds, minutes=minutes, hours=hours, time=time)
self._last_iteration_failed = False
self._last_iteration: datetime.datetime = MISSING
self._next_iteration = None
if not iscoroutinefunction(self.coro):
raise TypeError(f"Expected coroutine function, not {type(self.coro).__name__!r}.")
async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> None:
coro = getattr(self, "_" + name)
if coro is None:
return
if self._injected is not None:
await coro(self._injected, *args, **kwargs)
else:
await coro(*args, **kwargs)
def _try_sleep_until(self, dt: datetime.datetime) -> asyncio.Future[bool]:
self._handle = SleepHandle(dt=dt, loop=self.loop)
return self._handle.wait()
async def _loop(self, *args: Any, **kwargs: Any) -> None:
backoff = ExponentialBackoff()
await self._call_loop_function("before_loop")
self._last_iteration_failed = False
if self._time is not MISSING:
# the time index should be prepared every time the internal loop is started
self._prepare_time_index()
self._next_iteration = self._get_next_sleep_time()
else:
self._next_iteration = utcnow()
try:
await self._try_sleep_until(self._next_iteration)
while True:
if not self._last_iteration_failed:
self._last_iteration = self._next_iteration
self._next_iteration = self._get_next_sleep_time()
try:
await self.coro(*args, **kwargs)
self._last_iteration_failed = False
except self._valid_exception:
self._last_iteration_failed = True
if not self.reconnect:
raise
await asyncio.sleep(backoff.delay())
else:
await self._try_sleep_until(self._next_iteration)
if self._stop_next_iteration:
return
now = utcnow()
if now > self._next_iteration:
self._next_iteration = now
if self._time is not MISSING:
self._prepare_time_index(now)
self._current_loop += 1
if self._current_loop == self.count:
break
except asyncio.CancelledError:
self._is_being_cancelled = True
raise
except Exception as exc:
self._has_failed = True
await self._call_loop_function("error", exc)
raise
finally:
await self._call_loop_function("after_loop")
self._handle.cancel()
self._is_being_cancelled = False
self._current_loop = 0
self._stop_next_iteration = False
self._has_failed = False
def __get__(self, obj: T, objtype: Type[T]) -> Self:
if obj is None:
return self
clone = self.clone()
clone._injected = obj
setattr(obj, self.coro.__name__, clone)
return clone
def clone(self) -> Self:
instance = type(self)(
self.coro,
seconds=self._seconds,
hours=self._hours,
minutes=self._minutes,
time=self._time,
count=self.count,
reconnect=self.reconnect,
loop=self.loop,
)
instance._before_loop = self._before_loop
instance._after_loop = self._after_loop
instance._error = self._error
instance._injected = self._injected
return instance
@property
def seconds(self) -> Optional[float]:
"""Optional[:class:`float`]: Read-only value for the number of seconds
between each iteration. ``None`` if an explicit ``time`` value was passed instead.
.. versionadded:: 2.0
"""
if self._seconds is not MISSING:
return self._seconds
@property
def minutes(self) -> Optional[float]:
"""Optional[:class:`float`]: Read-only value for the number of minutes
between each iteration. ``None`` if an explicit ``time`` value was passed instead.
.. versionadded:: 2.0
"""
if self._minutes is not MISSING:
return self._minutes
@property
def hours(self) -> Optional[float]:
"""Optional[:class:`float`]: Read-only value for the number of hours
between each iteration. ``None`` if an explicit ``time`` value was passed instead.
.. versionadded:: 2.0
"""
if self._hours is not MISSING:
return self._hours
@property
def time(self) -> Optional[List[datetime.time]]:
"""Optional[List[:class:`datetime.time`]]: Read-only list for the exact times this loop runs at.
``None`` if relative times were passed instead.
.. versionadded:: 2.0
"""
if self._time is not MISSING:
return self._time.copy()
@property
def current_loop(self) -> int:
""":class:`int`: The current iteration of the loop."""
return self._current_loop
@property
def next_iteration(self) -> Optional[datetime.datetime]:
"""Optional[:class:`datetime.datetime`]: When the next iteration of the loop will occur.
.. versionadded:: 1.3
"""
if self._task is MISSING:
return None
elif (self._task and self._task.done()) or self._stop_next_iteration:
return None
return self._next_iteration
async def __call__(self, *args: Any, **kwargs: Any) -> Any:
"""|coro|
Calls the internal callback that the task holds.
.. versionadded:: 1.6
Parameters
----------
*args
The arguments to use.
**kwargs
The keyword arguments to use.
"""
if self._injected is not None:
args = (self._injected, *args)
return await self.coro(*args, **kwargs)
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
"""Starts the internal task in the event loop.
Parameters
----------
*args
The arguments to use.
**kwargs
The keyword arguments to use.
Raises
------
RuntimeError
A task has already been launched and is running.
Returns
-------
:class:`asyncio.Task`
The task that has been created.
"""
if self._task is not MISSING and not self._task.done():
raise RuntimeError("Task is already launched and is not completed.")
if self._injected is not None:
args = (self._injected, *args)
if self.loop is MISSING:
self.loop = disnake.utils.get_event_loop()
self._task = self.loop.create_task(self._loop(*args, **kwargs))
return self._task
def stop(self) -> None:
"""Gracefully stops the task from running.
Unlike :meth:`cancel`\\, this allows the task to finish its
current iteration before gracefully exiting.
.. note::
If the internal function raises an error that can be
handled before finishing then it will retry until
it succeeds.
If this is undesirable, either remove the error handling
before stopping via :meth:`clear_exception_types` or
use :meth:`cancel` instead.
.. versionadded:: 1.2
"""
if self._task is not MISSING and not self._task.done():
self._stop_next_iteration = True
def _can_be_cancelled(self) -> bool:
return bool(not self._is_being_cancelled and self._task and not self._task.done())
def cancel(self) -> None:
"""Cancels the internal task, if it is running."""
if self._can_be_cancelled():
self._task.cancel()
def restart(self, *args: Any, **kwargs: Any) -> None:
"""A convenience method to restart the internal task.
.. note::
Due to the way this function works, the task is not
returned like :meth:`start`.
Parameters
----------
*args
The arguments to use.
**kwargs
The keyword arguments to use.
"""
def restart_when_over(fut: Any, *, args: Any = args, kwargs: Any = kwargs) -> None:
self._task.remove_done_callback(restart_when_over)
self.start(*args, **kwargs)
if self._can_be_cancelled():
self._task.add_done_callback(restart_when_over)
self._task.cancel()
def add_exception_type(self, *exceptions: Type[BaseException]) -> None:
"""Adds exception types to be handled during the reconnect logic.
By default the exception types handled are those handled by
:meth:`disnake.Client.connect`\\, which includes a lot of internet disconnection
errors.
This function is useful if you're interacting with a 3rd party library that
raises its own set of exceptions.
Parameters
----------
*exceptions: Type[:class:`BaseException`]
An argument list of exception classes to handle.
Raises
------
TypeError
An exception passed is either not a class or not inherited from :class:`BaseException`.
"""
for exc in exceptions:
if not inspect.isclass(exc):
raise TypeError(f"{exc!r} must be a class.")
if not issubclass(exc, BaseException):
raise TypeError(f"{exc!r} must inherit from BaseException.")
self._valid_exception = (*self._valid_exception, *exceptions)
def clear_exception_types(self) -> None:
"""Removes all exception types that are handled.
.. note::
This operation obviously cannot be undone!
"""
self._valid_exception = ()
def remove_exception_type(self, *exceptions: Type[BaseException]) -> bool:
"""Removes exception types from being handled during the reconnect logic.
Parameters
----------
*exceptions: Type[:class:`BaseException`]
An argument list of exception classes to handle.
Returns
-------
:class:`bool`
Whether all exceptions were successfully removed.
"""
old_length = len(self._valid_exception)
self._valid_exception = tuple(x for x in self._valid_exception if x not in exceptions)
return len(self._valid_exception) == old_length - len(exceptions)
def get_task(self) -> Optional[asyncio.Task[None]]:
"""Fetches the internal task or ``None`` if there isn't one running.
:return type: Optional[:class:`asyncio.Task`]
"""
return self._task if self._task is not MISSING else None
def is_being_cancelled(self) -> bool:
"""Whether the task is being cancelled.
:return type: :class:`bool`
"""
return self._is_being_cancelled
def failed(self) -> bool:
"""Whether the internal task has failed.
.. versionadded:: 1.2
:return type: :class:`bool`
"""
return self._has_failed
def is_running(self) -> bool:
"""Check if the task is currently running.
.. versionadded:: 1.4
:return type: :class:`bool`
"""
return not bool(self._task.done()) if self._task is not MISSING else False
async def _error(self, *args: Any) -> None:
exception: Exception = args[-1]
print(
f"Unhandled exception in internal background task {self.coro.__name__!r}.",
file=sys.stderr,
)
traceback.print_exception(
type(exception), exception, exception.__traceback__, file=sys.stderr
)
def before_loop(self, coro: FT) -> FT:
"""A decorator that registers a coroutine to be called before the loop starts running.
This is useful if you want to wait for some bot state before the loop starts,
such as :meth:`disnake.Client.wait_until_ready`.
The coroutine must take no arguments (except ``self`` in a class context).
Parameters
----------
coro: :ref:`coroutine <coroutine>`
The coroutine to register before the loop runs.
Raises
------
TypeError
The function was not a coroutine.
"""
if not iscoroutinefunction(coro):
raise TypeError(f"Expected coroutine function, received {coro.__class__.__name__!r}.")
self._before_loop = coro
return coro
def after_loop(self, coro: FT) -> FT:
"""A decorator that register a coroutine to be called after the loop finished running.
The coroutine must take no arguments (except ``self`` in a class context).
.. note::
This coroutine is called even during cancellation. If it is desirable
to tell apart whether something was cancelled or not, check to see
whether :meth:`is_being_cancelled` is ``True`` or not.
Parameters
----------
coro: :ref:`coroutine <coroutine>`
The coroutine to register after the loop finishes.
Raises
------
TypeError
The function was not a coroutine.
"""
if not iscoroutinefunction(coro):
raise TypeError(f"Expected coroutine function, received {coro.__class__.__name__!r}.")
self._after_loop = coro
return coro
def error(self, coro: ET) -> ET:
"""A decorator that registers a coroutine to be called if the task encounters an unhandled exception.
The coroutine must take only one argument the exception raised (except ``self`` in a class context).
By default this prints to :data:`sys.stderr` however it could be
overridden to have a different implementation.
.. versionadded:: 1.4
Parameters
----------
coro: :ref:`coroutine <coroutine>`
The coroutine to register in the event of an unhandled exception.
Raises
------
TypeError
The function was not a coroutine.
"""
if not iscoroutinefunction(coro):
raise TypeError(f"Expected coroutine function, received {coro.__class__.__name__!r}.")
self._error = coro # type: ignore
return coro
def _get_next_sleep_time(self) -> datetime.datetime:
if self._sleep is not MISSING:
return self._last_iteration + datetime.timedelta(seconds=self._sleep)
if self._time_index >= len(self._time):
self._time_index = 0
if self._current_loop == 0:
# if we're at the last index on the first iteration, we need to sleep until tomorrow
return datetime.datetime.combine(
utcnow() + datetime.timedelta(days=1),
self._time[0],
)
next_time = self._time[self._time_index]
if self._current_loop == 0:
self._time_index += 1
if next_time > utcnow().timetz():
return datetime.datetime.combine(utcnow(), next_time)
else:
return datetime.datetime.combine(
utcnow() + datetime.timedelta(days=1),
next_time,
)
next_date = self._last_iteration
if next_time < next_date.timetz():
next_date += datetime.timedelta(days=1)
self._time_index += 1
return datetime.datetime.combine(next_date, next_time)
def _prepare_time_index(self, now: datetime.datetime = MISSING) -> None:
# now kwarg should be a datetime.datetime representing the time "now"
# to calculate the next time index from
# pre-condition: self._time is set
time_now = (now if now is not MISSING else utcnow().replace(microsecond=0)).timetz()
for idx, time in enumerate(self._time):
if time >= time_now:
self._time_index = idx
break
else:
self._time_index = 0
def _get_time_parameter(
self,
time: Union[datetime.time, Sequence[datetime.time]],
*,
dt: Type[datetime.time] = datetime.time,
utc: datetime.timezone = datetime.timezone.utc,
) -> List[datetime.time]:
if isinstance(time, dt):
inner = time if time.tzinfo is not None else time.replace(tzinfo=utc)
return [inner]
if not isinstance(time, Sequence):
raise TypeError(
f"Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead."
)
if not time:
raise ValueError("time parameter must not be an empty sequence.")
ret: List[datetime.time] = []
for index, t in enumerate(time):
if not isinstance(t, dt):
raise TypeError(
f"Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead."
)
ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc))
ret = sorted(set(ret)) # de-dupe and sort times
return ret
def change_interval(
self,
*,
seconds: float = 0,
minutes: float = 0,
hours: float = 0,
time: Union[datetime.time, Sequence[datetime.time]] = MISSING,
) -> None:
"""Changes the interval for the sleep time.
.. versionadded:: 1.2
Parameters
----------
seconds: :class:`float`
The number of seconds between every iteration.
minutes: :class:`float`
The number of minutes between every iteration.
hours: :class:`float`
The number of hours between every iteration.
time: Union[:class:`datetime.time`, Sequence[:class:`datetime.time`]]
The exact times to run this loop at. Either a non-empty list or a single
value of :class:`datetime.time` should be passed.
This cannot be used in conjunction with the relative time parameters.
.. versionadded:: 2.0
.. note::
Duplicate times will be ignored, and only run once.
Raises
------
ValueError
An invalid value was given.
TypeError
An invalid value for the ``time`` parameter was passed, or the
``time`` parameter was passed in conjunction with relative time parameters.
"""
if time is MISSING:
seconds = seconds or 0
minutes = minutes or 0
hours = hours or 0
sleep = seconds + (minutes * 60.0) + (hours * 3600.0)
if sleep < 0:
raise ValueError("Total number of seconds cannot be less than zero.")
self._sleep = sleep
self._seconds = float(seconds)
self._hours = float(hours)
self._minutes = float(minutes)
self._time: List[datetime.time] = MISSING
else:
if any((seconds, minutes, hours)):
raise TypeError("Cannot mix explicit time with relative time")
self._time = self._get_time_parameter(time)
self._sleep = self._seconds = self._minutes = self._hours = MISSING
# `_last_iteration` can be missing if `change_interval` gets called in `before_loop` or
# before the event loop ticks after `start()`
if self.is_running() and self._last_iteration is not MISSING:
if self._time is not MISSING:
# prepare the next time index starting from after the last iteration
self._prepare_time_index(now=self._last_iteration)
self._next_iteration = self._get_next_sleep_time()
if not self._handle.done():
# the loop is sleeping, recalculate based on new interval
self._handle.recalculate(self._next_iteration)
T_co = TypeVar("T_co", covariant=True)
L_co = TypeVar("L_co", bound=Loop, covariant=True)
class Object(Protocol[T_co, P]):
def __new__(cls) -> T_co: ...
def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None: ...
@overload
def loop(
*,
seconds: float = ...,
minutes: float = ...,
hours: float = ...,
time: Union[datetime.time, Sequence[datetime.time]] = ...,
count: Optional[int] = None,
reconnect: bool = True,
loop: asyncio.AbstractEventLoop = ...,
) -> Callable[[LF], Loop[LF]]: ...
@overload
def loop(
cls: Type[Object[L_co, Concatenate[LF, P]]], *_: P.args, **kwargs: P.kwargs
) -> Callable[[LF], L_co]: ...
def loop(
cls: Type[Object[L_co, Concatenate[LF, P]]] = Loop[Any],
**kwargs: Any,
) -> Callable[[LF], L_co]:
"""A decorator that schedules a task in the background for you with
optional reconnect logic. The decorator returns a :class:`Loop`.
Parameters
----------
cls: Type[:class:`Loop`]
The loop subclass to create an instance of. If provided, the following parameters
described below do not apply. Instead, this decorator will accept the same keywords
as the passed cls does.
.. versionadded:: 2.6
seconds: :class:`float`
The number of seconds between every iteration.
minutes: :class:`float`
The number of minutes between every iteration.
hours: :class:`float`
The number of hours between every iteration.
time: Union[:class:`datetime.time`, Sequence[:class:`datetime.time`]]
The exact times to run this loop at. Either a non-empty list or a single
value of :class:`datetime.time` should be passed. Timezones are supported.
If no timezone is given for the times, it is assumed to represent UTC time.
This cannot be used in conjunction with the relative time parameters.
.. note::
Duplicate times will be ignored, and only run once.
.. versionadded:: 2.0
count: Optional[:class:`int`]
The number of loops to do, ``None`` if it should be an
infinite loop.
reconnect: :class:`bool`
Whether to handle errors and restart the task
using an exponential back-off algorithm similar to the
one used in :meth:`disnake.Client.connect`.
loop: :class:`asyncio.AbstractEventLoop`
The loop to use to register the task, if not given
defaults to the current event loop or creates a new one
if there is none.
Raises
------
ValueError
An invalid value was given.
TypeError
The function was not a coroutine, the ``cls`` parameter was not a subclass of ``Loop``,
an invalid value for the ``time`` parameter was passed,
or ``time`` parameter was passed in conjunction with relative time parameters.
"""
if (origin := get_origin(cls)) is not None:
cls = origin
if not isinstance(cls, type) or not issubclass(cls, Loop):
raise TypeError(f"cls argument must be a subclass of Loop, got {cls!r}")
def decorator(func: LF) -> L_co:
if not iscoroutinefunction(func):
raise TypeError("decorated function must be a coroutine")
return cast("Type[L_co]", cls)(func, **kwargs)
return decorator