Skip to content

Commit

Permalink
Make overloads support classmethod and staticmethod (#5224)
Browse files Browse the repository at this point in the history
* Move 'is_class' and 'is_static' into FuncBase

This commit moves the `is_class` and `is_static` fields into FuncBase.
It also cleans up the list of flags so they don't repeat the
'is_property' entry, which is now present in `FUNCBASE_FLAGS`.

The high-level plan is to modify the `is_class` and `is_static` fields
in OverloadedFuncDef for use later in mypy.

* Make semantic analysis phase record class/static methods with overloads

This commit adjusts the semantic analysis phase to detect and record
when an overload appears to be a classmethod or staticmethod.

* Broaden class/static method checks to catch overloads

This commit modifies mypy to use the `is_static` and `is_class` fields
of OverloadedFuncDef as appropriate.

I found the code snippets to modify by asking PyCharm for all instances
of code using those two fields and modified the surrounding code as
appropriate.

* Add support for overloaded classmethods in attrs/dataclasses

Both the attrs and dataclasses plugins manually patch classmethods -- we
do the same for overloads.

* Respond to code review

This commit:

1. Updates astdiff.py and adds a case to one of the fine-grained
   dependency test files.

2. Adds some helper methods to FunctionLike.

3. Performs a few misc cleanups.

* Respond to code review; add tests for self types
  • Loading branch information
Michael0x2a authored and ilevkivskyi committed Jun 16, 2018
1 parent e66d53b commit 29889c8
Show file tree
Hide file tree
Showing 16 changed files with 677 additions and 30 deletions.
16 changes: 6 additions & 10 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1289,12 +1289,10 @@ def check_override(self, override: FunctionLike, original: FunctionLike,
# this could be unsafe with reverse operator methods.
fail = True

if isinstance(original, CallableType) and isinstance(override, CallableType):
if (isinstance(original.definition, FuncItem) and
isinstance(override.definition, FuncItem)):
if ((original.definition.is_static or original.definition.is_class) and
not (override.definition.is_static or override.definition.is_class)):
fail = True
if isinstance(original, FunctionLike) and isinstance(override, FunctionLike):
if ((original.is_classmethod() or original.is_staticmethod()) and
not (override.is_classmethod() or override.is_staticmethod())):
fail = True

if fail:
emitted_msg = False
Expand Down Expand Up @@ -3911,8 +3909,6 @@ def is_untyped_decorator(typ: Optional[Type]) -> bool:
def is_static(func: Union[FuncBase, Decorator]) -> bool:
if isinstance(func, Decorator):
return is_static(func.func)
elif isinstance(func, OverloadedFuncDef):
return any(is_static(item) for item in func.items)
elif isinstance(func, FuncItem):
elif isinstance(func, FuncBase):
return func.is_static
return False
assert False, "Unexpected func type: {}".format(type(func))
3 changes: 2 additions & 1 deletion mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,8 @@ def analyze_class_attribute_access(itype: Instance,
return handle_partial_attribute_type(t, is_lvalue, msg, symnode)
if not is_method and (isinstance(t, TypeVarType) or get_type_vars(t)):
msg.fail(messages.GENERIC_INSTANCE_VAR_CLASS_ACCESS, context)
is_classmethod = is_decorated and cast(Decorator, node.node).func.is_class
is_classmethod = ((is_decorated and cast(Decorator, node.node).func.is_class)
or (isinstance(node.node, FuncBase) and node.node.is_class))
return add_class_tvars(t, itype, is_classmethod, builtin_type, original_type)
elif isinstance(node.node, Var):
not_ready_callback(name, context)
Expand Down
8 changes: 7 additions & 1 deletion mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
TypeInfo, Context, MypyFile, op_methods, FuncDef, reverse_type_aliases,
ARG_POS, ARG_OPT, ARG_NAMED, ARG_NAMED_OPT, ARG_STAR, ARG_STAR2,
ReturnStmt, NameExpr, Var, CONTRAVARIANT, COVARIANT, SymbolNode,
CallExpr, Expression
CallExpr, Expression, OverloadedFuncDef,
)

# Constants that represent simple type checker error message, i.e. messages
Expand Down Expand Up @@ -942,6 +942,12 @@ def incompatible_typevar_value(self,
self.format(typ)),
context)

def overload_inconsistently_applies_decorator(self, decorator: str, context: Context) -> None:
self.fail(
'Overload does not consistently use the "@{}" '.format(decorator)
+ 'decorator on all function signatures.',
context)

def overloaded_signatures_overlap(self, index1: int, index2: int, context: Context) -> None:
self.fail('Overloaded function signatures {} and {} overlap with '
'incompatible return types'.format(index1, index2), context)
Expand Down
27 changes: 15 additions & 12 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,13 +370,20 @@ def __str__(self) -> str:
return 'ImportedName(%s)' % self.target_fullname


FUNCBASE_FLAGS = [
'is_property', 'is_class', 'is_static',
]


class FuncBase(Node):
"""Abstract base class for function-like nodes"""

__slots__ = ('type',
'unanalyzed_type',
'info',
'is_property',
'is_class', # Uses "@classmethod"
'is_static', # USes "@staticmethod"
'_fullname',
)

Expand All @@ -391,6 +398,8 @@ def __init__(self) -> None:
# TODO: Type should be Optional[TypeInfo]
self.info = cast(TypeInfo, None)
self.is_property = False
self.is_class = False
self.is_static = False
# Name with module prefix
# TODO: Type should be Optional[str]
self._fullname = cast(str, None)
Expand Down Expand Up @@ -436,8 +445,8 @@ def serialize(self) -> JsonDict:
'items': [i.serialize() for i in self.items],
'type': None if self.type is None else self.type.serialize(),
'fullname': self._fullname,
'is_property': self.is_property,
'impl': None if self.impl is None else self.impl.serialize()
'impl': None if self.impl is None else self.impl.serialize(),
'flags': get_flags(self, FUNCBASE_FLAGS),
}

@classmethod
Expand All @@ -451,7 +460,7 @@ def deserialize(cls, data: JsonDict) -> 'OverloadedFuncDef':
if data.get('type') is not None:
res.type = mypy.types.deserialize_type(data['type'])
res._fullname = data['fullname']
res.is_property = data['is_property']
set_flags(res, data['flags'])
# NOTE: res.info will be set in the fixup phase.
return res

Expand Down Expand Up @@ -481,9 +490,9 @@ def set_line(self, target: Union[Context, int], column: Optional[int] = None) ->
self.variable.set_line(self.line, self.column)


FUNCITEM_FLAGS = [
FUNCITEM_FLAGS = FUNCBASE_FLAGS + [
'is_overload', 'is_generator', 'is_coroutine', 'is_async_generator',
'is_awaitable_coroutine', 'is_static', 'is_class',
'is_awaitable_coroutine',
]


Expand All @@ -503,8 +512,6 @@ class FuncItem(FuncBase):
'is_coroutine', # Defined using 'async def' syntax?
'is_async_generator', # Is an async def generator?
'is_awaitable_coroutine', # Decorated with '@{typing,asyncio}.coroutine'?
'is_static', # Uses @staticmethod?
'is_class', # Uses @classmethod?
'expanded', # Variants of function with type variables with values expanded
)

Expand All @@ -525,8 +532,6 @@ def __init__(self,
self.is_coroutine = False
self.is_async_generator = False
self.is_awaitable_coroutine = False
self.is_static = False
self.is_class = False
self.expanded = [] # type: List[FuncItem]

self.min_args = 0
Expand All @@ -547,7 +552,7 @@ def is_dynamic(self) -> bool:


FUNCDEF_FLAGS = FUNCITEM_FLAGS + [
'is_decorated', 'is_conditional', 'is_abstract', 'is_property',
'is_decorated', 'is_conditional', 'is_abstract',
]


Expand All @@ -561,7 +566,6 @@ class FuncDef(FuncItem, SymbolNode, Statement):
'is_decorated',
'is_conditional',
'is_abstract',
'is_property',
'original_def',
)

Expand All @@ -575,7 +579,6 @@ def __init__(self,
self.is_decorated = False
self.is_conditional = False # Defined conditionally (within block)?
self.is_abstract = False
self.is_property = False
# Original conditional definition
self.original_def = None # type: Union[None, FuncDef, Var, Decorator]

Expand Down
10 changes: 10 additions & 0 deletions mypy/plugins/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,16 @@ def _add_init(ctx: 'mypy.plugin.ClassDefContext', attributes: List[Attribute],
func_type = stmt.func.type
if isinstance(func_type, CallableType):
func_type.arg_types[0] = ctx.api.class_type(ctx.cls.info)
if isinstance(stmt, OverloadedFuncDef) and stmt.is_class:
func_type = stmt.type
if isinstance(func_type, Overloaded):
class_type = ctx.api.class_type(ctx.cls.info)
for item in func_type.items():
item.arg_types[0] = class_type
if stmt.impl is not None:
assert isinstance(stmt.impl, Decorator)
if isinstance(stmt.impl.func.type, CallableType):
stmt.impl.func.type.arg_types[0] = class_type


class MethodAdder:
Expand Down
14 changes: 12 additions & 2 deletions mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from mypy.nodes import (
ARG_OPT, ARG_POS, MDEF, Argument, AssignmentStmt, CallExpr,
Context, Decorator, Expression, FuncDef, JsonDict, NameExpr,
SymbolTableNode, TempNode, TypeInfo, Var,
OverloadedFuncDef, SymbolTableNode, TempNode, TypeInfo, Var,
)
from mypy.plugin import ClassDefContext
from mypy.plugins.common import _add_method, _get_decorator_bool_argument
from mypy.types import (
CallableType, Instance, NoneTyp, TypeVarDef, TypeVarType,
CallableType, Instance, NoneTyp, Overloaded, TypeVarDef, TypeVarType,
)

# The set of decorators that generate dataclasses.
Expand Down Expand Up @@ -95,6 +95,16 @@ def transform(self) -> None:
func_type = stmt.func.type
if isinstance(func_type, CallableType):
func_type.arg_types[0] = self._ctx.api.class_type(self._ctx.cls.info)
if isinstance(stmt, OverloadedFuncDef) and stmt.is_class:
func_type = stmt.type
if isinstance(func_type, Overloaded):
class_type = ctx.api.class_type(ctx.cls.info)
for item in func_type.items():
item.arg_types[0] = class_type
if stmt.impl is not None:
assert isinstance(stmt.impl, Decorator)
if isinstance(stmt.impl.func.type, CallableType):
stmt.impl.func.type.arg_types[0] = class_type

# Add an eq method, but only if the class doesn't already have one.
if decorator_arguments['eq'] and info.get('__eq__') is None:
Expand Down
31 changes: 31 additions & 0 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,37 @@ def _visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:
# redefinitions already.
return

# We know this is an overload def -- let's handle classmethod and staticmethod
class_status = []
static_status = []
for item in defn.items:
if isinstance(item, Decorator):
inner = item.func
elif isinstance(item, FuncDef):
inner = item
else:
assert False, "The 'item' variable is an unexpected type: {}".format(type(item))
class_status.append(inner.is_class)
static_status.append(inner.is_static)

if defn.impl is not None:
if isinstance(defn.impl, Decorator):
inner = defn.impl.func
elif isinstance(defn.impl, FuncDef):
inner = defn.impl
else:
assert False, "Unexpected impl type: {}".format(type(defn.impl))
class_status.append(inner.is_class)
static_status.append(inner.is_static)

if len(set(class_status)) != 1:
self.msg.overload_inconsistently_applies_decorator('classmethod', defn)
elif len(set(static_status)) != 1:
self.msg.overload_inconsistently_applies_decorator('staticmethod', defn)
else:
defn.is_class = class_status[0]
defn.is_static = static_status[0]

if self.type and not self.is_func_scope():
self.type.names[defn.name()] = SymbolTableNode(MDEF, defn,
typ=defn.type)
Expand Down
6 changes: 3 additions & 3 deletions mypy/server/astdiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class level -- these are handled at attribute level (say, 'mod.Cls.method'

from mypy.nodes import (
SymbolTable, TypeInfo, Var, SymbolNode, Decorator, TypeVarExpr,
OverloadedFuncDef, FuncItem, MODULE_REF, TYPE_ALIAS, UNBOUND_IMPORTED, TVAR
FuncBase, OverloadedFuncDef, FuncItem, MODULE_REF, TYPE_ALIAS, UNBOUND_IMPORTED, TVAR
)
from mypy.types import (
Type, TypeVisitor, UnboundType, AnyType, NoneTyp, UninhabitedType,
Expand Down Expand Up @@ -167,13 +167,13 @@ def snapshot_definition(node: Optional[SymbolNode],
The representation is nested tuples and dicts. Only externally
visible attributes are included.
"""
if isinstance(node, (OverloadedFuncDef, FuncItem)):
if isinstance(node, FuncBase):
# TODO: info
if node.type:
signature = snapshot_type(node.type)
else:
signature = snapshot_untyped_signature(node)
return ('Func', common, node.is_property, signature)
return ('Func', common, node.is_property, node.is_class, node.is_static, signature)
elif isinstance(node, Var):
return ('Var', common, snapshot_optional_type(node.type))
elif isinstance(node, Decorator):
Expand Down
4 changes: 4 additions & 0 deletions mypy/strconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,10 @@ def visit_overloaded_func_def(self, o: 'mypy.nodes.OverloadedFuncDef') -> str:
a.insert(0, o.type)
if o.impl:
a.insert(0, o.impl)
if o.is_static:
a.insert(-1, 'Static')
if o.is_class:
a.insert(-1, 'Class')
return self.dump(a, o)

def visit_class_def(self, o: 'mypy.nodes.ClassDef') -> str:
Expand Down
3 changes: 3 additions & 0 deletions mypy/treetransform.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ def visit_overloaded_func_def(self, node: OverloadedFuncDef) -> OverloadedFuncDe
new._fullname = node._fullname
new.type = self.optional_type(node.type)
new.info = node.info
new.is_static = node.is_static
new.is_class = node.is_class
new.is_property = node.is_property
if node.impl:
new.impl = cast(OverloadPart, node.impl.accept(self))
return new
Expand Down
20 changes: 19 additions & 1 deletion mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from mypy import experiments
from mypy.nodes import (
INVARIANT, SymbolNode, ARG_POS, ARG_OPT, ARG_STAR, ARG_STAR2, ARG_NAMED, ARG_NAMED_OPT,
FuncDef
FuncBase, FuncDef,
)
from mypy.sharedparse import argument_elide_name
from mypy.util import IdMapper
Expand Down Expand Up @@ -645,6 +645,12 @@ def with_name(self, name: str) -> 'FunctionLike': pass
@abstractmethod
def get_name(self) -> Optional[str]: pass

@abstractmethod
def is_classmethod(self) -> bool: pass

@abstractmethod
def is_staticmethod(self) -> bool: pass


FormalArgument = NamedTuple('FormalArgument', [
('name', Optional[str]),
Expand Down Expand Up @@ -828,6 +834,12 @@ def with_name(self, name: str) -> 'CallableType':
def get_name(self) -> Optional[str]:
return self.name

def is_classmethod(self) -> bool:
return isinstance(self.definition, FuncBase) and self.definition.is_class

def is_staticmethod(self) -> bool:
return isinstance(self.definition, FuncBase) and self.definition.is_static

def max_fixed_args(self) -> int:
n = len(self.arg_types)
if self.is_var_arg:
Expand Down Expand Up @@ -1046,6 +1058,12 @@ def with_name(self, name: str) -> 'Overloaded':
def get_name(self) -> Optional[str]:
return self._items[0].name

def is_classmethod(self) -> bool:
return self._items[0].is_classmethod()

def is_staticmethod(self) -> bool:
return self._items[0].is_staticmethod()

def accept(self, visitor: 'TypeVisitor[T]') -> T:
return visitor.visit_overloaded(self)

Expand Down
32 changes: 32 additions & 0 deletions test-data/unit/check-attr.test
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,38 @@ a = A.new()
reveal_type(a.foo) # E: Revealed type is 'def () -> builtins.int'
[builtins fixtures/classmethod.pyi]

[case testAttrsOtherOverloads]
import attr
from typing import overload, Union

@attr.s
class A:
a = attr.ib()
b = attr.ib(default=3)

@classmethod
def other(cls) -> str:
return "..."

@overload
@classmethod
def foo(cls, x: int) -> int: ...

@overload
@classmethod
def foo(cls, x: str) -> str: ...

@classmethod
def foo(cls, x: Union[int, str]) -> Union[int, str]:
reveal_type(cls) # E: Revealed type is 'def (a: Any, b: Any =) -> __main__.A'
reveal_type(cls.other()) # E: Revealed type is 'builtins.str'
return x

reveal_type(A.foo(3)) # E: Revealed type is 'builtins.int'
reveal_type(A.foo("foo")) # E: Revealed type is 'builtins.str'

[builtins fixtures/classmethod.pyi]

[case testAttrsDefaultDecorator]
import attr
@attr.s
Expand Down
Loading

0 comments on commit 29889c8

Please sign in to comment.