Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bpo-46032: Check types in singledispatch's register() at declaration time #30050

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions Lib/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,7 @@ def _compose_mro(cls, types):
# Remove entries which are already present in the __mro__ or unrelated.
def is_related(typ):
return (typ not in bases and hasattr(typ, '__mro__')
and not isinstance(typ, GenericAlias)
and issubclass(cls, typ))
types = [n for n in types if is_related(n)]
# Remove entries which are strict bases of other entries (they will end up
Expand Down Expand Up @@ -841,9 +842,13 @@ def _is_union_type(cls):
from typing import get_origin, Union
return get_origin(cls) in {Union, types.UnionType}

def _is_valid_union_type(cls):
def _is_valid_dispatch_type(cls):
if isinstance(cls, type) and not isinstance(cls, GenericAlias):
return True
from typing import get_args
return _is_union_type(cls) and all(isinstance(arg, type) for arg in get_args(cls))
return (_is_union_type(cls) and
all(isinstance(arg, type) and not isinstance(arg, GenericAlias)
for arg in get_args(cls)))

def register(cls, func=None):
"""generic_func.register(cls, func) -> func
Expand All @@ -852,9 +857,15 @@ def register(cls, func=None):

"""
nonlocal cache_token
if func is None:
if isinstance(cls, type) or _is_valid_union_type(cls):
if _is_valid_dispatch_type(cls):
if func is None:
return lambda f: register(cls, f)
else:
if func is not None:
raise TypeError(
f"Invalid first argument to `register()`. "
f"{cls!r} is not a class or union type."
)
ann = getattr(cls, '__annotations__', {})
if not ann:
raise TypeError(
Expand All @@ -867,7 +878,7 @@ def register(cls, func=None):
# only import typing if annotation parsing is necessary
from typing import get_type_hints
argname, cls = next(iter(get_type_hints(func).items()))
if not isinstance(cls, type) and not _is_valid_union_type(cls):
if not _is_valid_dispatch_type(cls):
if _is_union_type(cls):
raise TypeError(
f"Invalid annotation for {argname!r}. "
Expand Down
68 changes: 68 additions & 0 deletions Lib/test/test_functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2722,6 +2722,74 @@ def _(arg: int | float):
self.assertEqual(f(1), "types.UnionType")
self.assertEqual(f(1.0), "types.UnionType")

def test_register_genericalias(self):
@functools.singledispatch
def f(arg):
return "default"

with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
f.register(list[int], lambda arg: "types.GenericAlias")
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
f.register(typing.List[int], lambda arg: "typing.GenericAlias")
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
f.register(list[int] | str, lambda arg: "types.UnionTypes(types.GenericAlias)")
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
f.register(typing.List[float] | bytes, lambda arg: "typing.Union[typing.GenericAlias]")
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
f.register(typing.Any, lambda arg: "typing.Any")

self.assertEqual(f([1]), "default")
self.assertEqual(f([1.0]), "default")
self.assertEqual(f(""), "default")
self.assertEqual(f(b""), "default")

def test_register_genericalias_decorator(self):
@functools.singledispatch
def f(arg):
return "default"

with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
f.register(list[int])
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
f.register(typing.List[int])
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
f.register(list[int] | str)
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
f.register(typing.List[int] | str)
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
f.register(typing.Any)

def test_register_genericalias_annotation(self):
@functools.singledispatch
def f(arg):
return "default"

with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
@f.register
def _(arg: list[int]):
return "types.GenericAlias"
with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
@f.register
def _(arg: typing.List[float]):
return "typing.GenericAlias"
with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
@f.register
def _(arg: list[int] | str):
return "types.UnionType(types.GenericAlias)"
with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
@f.register
def _(arg: typing.List[float] | bytes):
return "typing.Union[typing.GenericAlias]"
with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
@f.register
def _(arg: typing.Any):
return "typing.Any"

self.assertEqual(f([1]), "default")
self.assertEqual(f([1.0]), "default")
self.assertEqual(f(""), "default")
self.assertEqual(f(b""), "default")


class CachedCostItem:
_cost = 1
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
The ``registry()`` method of :func:`functools.singledispatch` functions
checks now the first argument or the first parameter annotation and raises a
TypeError if it is not supported. Previously unsupported "types" were
ignored (e.g. ``typing.List[int]``) or caused an error at calling time (e.g.
``list[int]``).