Skip to content

Commit

Permalink
Better callable: Callable[[Arg('x', int), VarArg(str)], int] now a …
Browse files Browse the repository at this point in the history
…thing you can do (#2607)

Implements an experimental feature to allow Callable to have any kind of signature an actual function definition does.

This should enable better typing of callbacks &c.

Initial discussion: python/typing#239
Proposal, v. similar to this impl: python/typing#264
Relevant typeshed PR: python/typeshed#793
  • Loading branch information
sixolet authored and JukkaL committed May 2, 2017
1 parent 058a8a6 commit ddf03d1
Show file tree
Hide file tree
Showing 19 changed files with 660 additions and 81 deletions.
36 changes: 36 additions & 0 deletions extensions/mypy_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from mypy_extensions import TypedDict
"""

from typing import Any

# NOTE: This module must support Python 2.7 in addition to Python 3.x

import sys
Expand Down Expand Up @@ -92,6 +94,40 @@ class Point2D(TypedDict):
syntax forms work for Python 2.7 and 3.2+
"""

# Argument constructors for making more-detailed Callables. These all just
# return their type argument, to make them complete noops in terms of the
# `typing` module.


def Arg(type=Any, name=None):
"""A normal positional argument"""
return type


def DefaultArg(type=Any, name=None):
"""A positional argument with a default value"""
return type


def NamedArg(type=Any, name=None):
"""A keyword-only argument"""
return type


def DefaultNamedArg(type=Any, name=None):
"""A keyword-only argument with a default value"""
return type


def VarArg(type=Any):
"""A *args-style variadic positional argument"""
return type


def KwArg(type=Any):
"""A **kwargs-style variadic keyword argument"""
return type


# Return type that indicates a function does not return
class NoReturn: pass
72 changes: 65 additions & 7 deletions mypy/exprtotype.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,38 @@

from mypy.nodes import (
Expression, NameExpr, MemberExpr, IndexExpr, TupleExpr,
ListExpr, StrExpr, BytesExpr, UnicodeExpr, EllipsisExpr,
get_member_expr_fullname
ListExpr, StrExpr, BytesExpr, UnicodeExpr, EllipsisExpr, CallExpr,
ARG_POS, ARG_NAMED, get_member_expr_fullname
)
from mypy.fastparse import parse_type_comment
from mypy.types import Type, UnboundType, TypeList, EllipsisType
from mypy.types import (
Type, UnboundType, TypeList, EllipsisType, AnyType, Optional, CallableArgument,
)


class TypeTranslationError(Exception):
"""Exception raised when an expression is not valid as a type."""


def expr_to_unanalyzed_type(expr: Expression) -> Type:
def _extract_argument_name(expr: Expression) -> Optional[str]:
if isinstance(expr, NameExpr) and expr.name == 'None':
return None
elif isinstance(expr, StrExpr):
return expr.value
elif isinstance(expr, UnicodeExpr):
return expr.value
else:
raise TypeTranslationError()


def expr_to_unanalyzed_type(expr: Expression, _parent: Optional[Expression] = None) -> Type:
"""Translate an expression to the corresponding type.
The result is not semantically analyzed. It can be UnboundType or TypeList.
Raise TypeTranslationError if the expression cannot represent a type.
"""
# The `parent` paremeter is used in recursive calls to provide context for
# understanding whether an CallableArgument is ok.
if isinstance(expr, NameExpr):
name = expr.name
return UnboundType(name, line=expr.line, column=expr.column)
Expand All @@ -29,22 +44,65 @@ def expr_to_unanalyzed_type(expr: Expression) -> Type:
else:
raise TypeTranslationError()
elif isinstance(expr, IndexExpr):
base = expr_to_unanalyzed_type(expr.base)
base = expr_to_unanalyzed_type(expr.base, expr)
if isinstance(base, UnboundType):
if base.args:
raise TypeTranslationError()
if isinstance(expr.index, TupleExpr):
args = expr.index.items
else:
args = [expr.index]
base.args = [expr_to_unanalyzed_type(arg) for arg in args]
base.args = [expr_to_unanalyzed_type(arg, expr) for arg in args]
if not base.args:
base.empty_tuple_index = True
return base
else:
raise TypeTranslationError()
elif isinstance(expr, CallExpr) and isinstance(_parent, ListExpr):
c = expr.callee
names = []
# Go through the dotted member expr chain to get the full arg
# constructor name to look up
while True:
if isinstance(c, NameExpr):
names.append(c.name)
break
elif isinstance(c, MemberExpr):
names.append(c.name)
c = c.expr
else:
raise TypeTranslationError()
arg_const = '.'.join(reversed(names))

# Go through the constructor args to get its name and type.
name = None
default_type = AnyType(implicit=True)
typ = default_type # type: Type
for i, arg in enumerate(expr.args):
if expr.arg_names[i] is not None:
if expr.arg_names[i] == "name":
if name is not None:
# Two names
raise TypeTranslationError()
name = _extract_argument_name(arg)
continue
elif expr.arg_names[i] == "type":
if typ is not default_type:
# Two types
raise TypeTranslationError()
typ = expr_to_unanalyzed_type(arg, expr)
continue
else:
raise TypeTranslationError()
elif i == 0:
typ = expr_to_unanalyzed_type(arg, expr)
elif i == 1:
name = _extract_argument_name(arg)
else:
raise TypeTranslationError()
return CallableArgument(typ, name, arg_const, expr.line, expr.column)
elif isinstance(expr, ListExpr):
return TypeList([expr_to_unanalyzed_type(t) for t in expr.items],
return TypeList([expr_to_unanalyzed_type(t, expr) for t in expr.items],
line=expr.line, column=expr.column)
elif isinstance(expr, (StrExpr, BytesExpr, UnicodeExpr)):
# Parse string literal type.
Expand Down
102 changes: 83 additions & 19 deletions mypy/fastparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
StarExpr, YieldFromExpr, NonlocalDecl, DictionaryComprehension,
SetComprehension, ComplexExpr, EllipsisExpr, YieldExpr, Argument,
AwaitExpr, TempNode, Expression, Statement,
ARG_POS, ARG_OPT, ARG_STAR, ARG_NAMED, ARG_NAMED_OPT, ARG_STAR2
ARG_POS, ARG_OPT, ARG_STAR, ARG_NAMED, ARG_NAMED_OPT, ARG_STAR2,
check_arg_names,
)
from mypy.types import (
Type, CallableType, AnyType, UnboundType, TupleType, TypeList, EllipsisType,
CallableArgument,
)
from mypy import defaults
from mypy import experiments
Expand Down Expand Up @@ -444,24 +446,12 @@ def make_argument(arg: ast3.arg, default: Optional[ast3.expr], kind: int) -> Arg
new_args.append(make_argument(args.kwarg, None, ARG_STAR2))
names.append(args.kwarg)

seen_names = set() # type: Set[str]
for name in names:
if name.arg in seen_names:
self.fail("duplicate argument '{}' in function definition".format(name.arg),
name.lineno, name.col_offset)
break
seen_names.add(name.arg)
def fail_arg(msg: str, arg: ast3.arg) -> None:
self.fail(msg, arg.lineno, arg.col_offset)

return new_args
check_arg_names([name.arg for name in names], names, fail_arg)

def stringify_name(self, n: ast3.AST) -> str:
if isinstance(n, ast3.Name):
return n.id
elif isinstance(n, ast3.Attribute):
sv = self.stringify_name(n.value)
if sv is not None:
return "{}.{}".format(sv, n.attr)
return None # Can't do it.
return new_args

# ClassDef(identifier name,
# expr* bases,
Expand All @@ -474,7 +464,7 @@ def visit_ClassDef(self, n: ast3.ClassDef) -> ClassDef:
metaclass_arg = find(lambda x: x.arg == 'metaclass', n.keywords)
metaclass = None
if metaclass_arg:
metaclass = self.stringify_name(metaclass_arg.value)
metaclass = stringify_name(metaclass_arg.value)
if metaclass is None:
metaclass = '<error>' # To be reported later

Expand Down Expand Up @@ -965,6 +955,21 @@ class TypeConverter(ast3.NodeTransformer): # type: ignore # typeshed PR #931
def __init__(self, errors: Errors, line: int = -1) -> None:
self.errors = errors
self.line = line
self.node_stack = [] # type: List[ast3.AST]

def visit(self, node: ast3.AST) -> Type:
"""Modified visit -- keep track of the stack of nodes"""
self.node_stack.append(node)
try:
return super().visit(node)
finally:
self.node_stack.pop()

def parent(self) -> ast3.AST:
"""Return the AST node above the one we are processing"""
if len(self.node_stack) < 2:
return None
return self.node_stack[-2]

def fail(self, msg: str, line: int, column: int) -> None:
self.errors.report(line, column, msg)
Expand All @@ -985,6 +990,55 @@ def visit_NoneType(self, n: Any) -> Type:
def translate_expr_list(self, l: Sequence[ast3.AST]) -> List[Type]:
return [self.visit(e) for e in l]

def visit_Call(self, e: ast3.Call) -> Type:
# Parse the arg constructor
if not isinstance(self.parent(), ast3.List):
return self.generic_visit(e)
f = e.func
constructor = stringify_name(f)
if not constructor:
self.fail("Expected arg constructor name", e.lineno, e.col_offset)
name = None # type: Optional[str]
default_type = AnyType(implicit=True)
typ = default_type # type: Type
for i, arg in enumerate(e.args):
if i == 0:
typ = self.visit(arg)
elif i == 1:
name = self._extract_argument_name(arg)
else:
self.fail("Too many arguments for argument constructor",
f.lineno, f.col_offset)
for k in e.keywords:
value = k.value
if k.arg == "name":
if name is not None:
self.fail('"{}" gets multiple values for keyword argument "name"'.format(
constructor), f.lineno, f.col_offset)
name = self._extract_argument_name(value)
elif k.arg == "type":
if typ is not default_type:
self.fail('"{}" gets multiple values for keyword argument "type"'.format(
constructor), f.lineno, f.col_offset)
typ = self.visit(value)
else:
self.fail(
'Unexpected argument "{}" for argument constructor'.format(k.arg),
value.lineno, value.col_offset)
return CallableArgument(typ, name, constructor, e.lineno, e.col_offset)

def translate_argument_list(self, l: Sequence[ast3.AST]) -> TypeList:
return TypeList([self.visit(e) for e in l], line=self.line)

def _extract_argument_name(self, n: ast3.expr) -> str:
if isinstance(n, ast3.Str):
return n.s.strip()
elif isinstance(n, ast3.NameConstant) and str(n.value) == 'None':
return None
self.fail('Expected string literal for argument name, got {}'.format(
type(n).__name__), self.line, 0)
return None

def visit_Name(self, n: ast3.Name) -> Type:
return UnboundType(n.id, line=self.line)

Expand Down Expand Up @@ -1036,4 +1090,14 @@ def visit_Ellipsis(self, n: ast3.Ellipsis) -> Type:

# List(expr* elts, expr_context ctx)
def visit_List(self, n: ast3.List) -> Type:
return TypeList(self.translate_expr_list(n.elts), line=self.line)
return self.translate_argument_list(n.elts)


def stringify_name(n: ast3.AST) -> Optional[str]:
if isinstance(n, ast3.Name):
return n.id
elif isinstance(n, ast3.Attribute):
sv = stringify_name(n.value)
if sv is not None:
return "{}.{}".format(sv, n.attr)
return None # Can't do it.
12 changes: 5 additions & 7 deletions mypy/fastparse2.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
UnaryExpr, LambdaExpr, ComparisonExpr, DictionaryComprehension,
SetComprehension, ComplexExpr, EllipsisExpr, YieldExpr, Argument,
Expression, Statement, BackquoteExpr, PrintStmt, ExecStmt,
ARG_POS, ARG_OPT, ARG_STAR, ARG_NAMED, ARG_STAR2, OverloadPart,
ARG_POS, ARG_OPT, ARG_STAR, ARG_NAMED, ARG_STAR2, OverloadPart, check_arg_names,
)
from mypy.types import (
Type, CallableType, AnyType, UnboundType, EllipsisType
Expand Down Expand Up @@ -439,12 +439,10 @@ def get_type(i: int) -> Optional[Type]:
new_args.append(Argument(Var(n.kwarg), typ, None, ARG_STAR2))
names.append(n.kwarg)

seen_names = set() # type: Set[str]
for name in names:
if name in seen_names:
self.fail("duplicate argument '{}' in function definition".format(name), line, 0)
break
seen_names.add(name)
# We don't have any context object to give, but we have closed around the line num
def fail_arg(msg: str, arg: None) -> None:
self.fail(msg, line, 0)
check_arg_names(names, [None] * len(names), fail_arg)

return new_args, decompose_stmts

Expand Down
3 changes: 3 additions & 0 deletions mypy/indirection.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ def visit_unbound_type(self, t: types.UnboundType) -> Set[str]:
def visit_type_list(self, t: types.TypeList) -> Set[str]:
return self._visit(*t.items)

def visit_callable_argument(self, t: types.CallableArgument) -> Set[str]:
return self._visit(t.typ)

def visit_any(self, t: types.AnyType) -> Set[str]:
return set()

Expand Down
10 changes: 5 additions & 5 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
ARG_OPT: "DefaultArg",
ARG_NAMED: "NamedArg",
ARG_NAMED_OPT: "DefaultNamedArg",
ARG_STAR: "StarArg",
ARG_STAR: "VarArg",
ARG_STAR2: "KwArg",
}

Expand Down Expand Up @@ -214,15 +214,15 @@ def format(self, typ: Type, verbosity: int = 0) -> str:
verbosity = max(verbosity - 1, 0))))
else:
constructor = ARG_CONSTRUCTOR_NAMES[arg_kind]
if arg_kind in (ARG_STAR, ARG_STAR2):
if arg_kind in (ARG_STAR, ARG_STAR2) or arg_name is None:
arg_strings.append("{}({})".format(
constructor,
strip_quotes(self.format(arg_type))))
else:
arg_strings.append("{}('{}', {})".format(
arg_strings.append("{}({}, {})".format(
constructor,
arg_name,
strip_quotes(self.format(arg_type))))
strip_quotes(self.format(arg_type)),
repr(arg_name)))

return 'Callable[[{}], {}]'.format(", ".join(arg_strings), return_type)
else:
Expand Down
Loading

0 comments on commit ddf03d1

Please sign in to comment.