PEP: 9999
Title: Variadic Generics
Author: Mark Mendoza <mendoza.mark.a at gmail.com>, Matthew Rahtz <mrahtz at google.com>, Vincent Siles <vsiles at fb.com>
Sponsor: TODO
Status: Draft
Type: Standards Track
Content-Type: text/x-rst
Created: 16-Sep-2020
Python-Version: 3.10
Post-History: 07-Oct-2020

Abstract

ListVariadic is a variadic TypeVar. That is, just as a TypeVar enables parameterization with a single type, ListVariadic enables parameterization through an arbitrary number of types. Its primary function is to enable parameterization of tensor-like structures (in numerical libraries such as NumPy, TensorFlow, PyTorch, etc.) with the shape of the tensor.

Motivation

There are three motivations for variadic type variables:

  1. Parameter-transforming functions
  2. Arbitrary-length inhomogeneous tuples
  3. Annotation of tensor shapes in numerical computing libraries

We discuss each of these in turn below.

Parameter-transforming Functions

PEP 612 [5] introduced ParamSpec, enabling parameter types of one callable to be forwarded to another callable. However, in many cases we actually wish to transform parameter types before using them elsewhere in the signature.

Consider, for example, the signature of map for a particular choice of function and iterables:

from typing import Callable, Iterable, List

def func(a: int, b: str) -> float: ...
iter1: List[int]
iter2: List[str]

def map(func: Callable[[int, str], float],
        iter1: Iterable[int],
        iter2: Iterable[str]) -> Iterable[float]: ...

Note that the parameter types of the Callable become the types of the individual Iterable arguments.

A similar example is zip:

from typing import Iterable, List, Tuple

iter1: List[int]
iter2: List[str]

def zip(iter1: Iterable[int],
        iter2: Iterable[str]) -> Iterable[Tuple[int, str]]: ...

Neither of these signatures can be specified in the general case using existing typing mechanisms.
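
The closest approximation with existing mechanisms is one overload per number of arguments, which covers only the arities that are written out. The following is a minimal sketch of that pattern; the name zip_fixed and the choice of two overloads are illustrative assumptions and not part of this proposal.

```python
from typing import Iterable, Iterator, Tuple, TypeVar, overload

T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")

# One overload per arity. A call with four iterables would not match
# any overload, so the pattern never covers the general case.
@overload
def zip_fixed(a: Iterable[T1], b: Iterable[T2]) -> Iterator[Tuple[T1, T2]]: ...
@overload
def zip_fixed(a: Iterable[T1], b: Iterable[T2],
              c: Iterable[T3]) -> Iterator[Tuple[T1, T2, T3]]: ...

def zip_fixed(*iterables):
    # The runtime behaviour is simply that of the built-in zip.
    return zip(*iterables)

pairs = list(zip_fixed([1, 2], ["a", "b"]))
```

Each additional arity needs another overload and another TypeVar, which is the repetition a variadic type variable is meant to remove.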

Arbitrary-length Inhomogeneous Tuples

PEP 484 [6] allows us to type inhomogeneous tuples of fixed length...

from typing import Tuple

t: Tuple[int, str] = (0, 'a')

...and homogeneous tuples of arbitrary length...

from typing import Tuple

def f(t: Tuple[int, ...]): ...

f((1, 2))     # Valid
f((1, 2, 3))  # Also valid

...but not inhomogeneous tuples of arbitrary length:

def duple(x):
  return x, x

duple((1, 'A'))     # Should be valid
duple((1, 'A', 2))  # Should also be valid

Again, the signature of duple here cannot be specified using existing typing mechanisms.

Tensor Shapes

In the context of numerical computation with libraries such as NumPy and TensorFlow, the shape of arguments is often just as important as the argument type. For example, consider the following function which converts a batch [1] of videos to grayscale:

def to_gray(videos: Tensor): ...

From the signature alone, it is not obvious what shape of tensor [2] we should pass for the videos argument. Possibilities include, for example,

batch × time × height × width × channels

and

time × batch × channels × height × width. [3]

Ideally, we should have some way of making the required shape clear in the signature itself. Multiple proposals [7] [8] [9] have suggested the use of standard generics syntax for this purpose. We would write:

def to_gray(videos: Tensor[Time, Batch, Height, Width, Channels]): ...

In order to support this usage, Tensor must be variadic in its shape, because tensors can be of arbitrary rank.

Specification

In order to support the above use-cases, we introduce a variadic version of TypeVar called ListVariadic, along with two new type-transforming functions, Map and Concatenate. These are described in detail below.

ListVariadic

In the same way that a TypeVar is a stand-in for a single type, a ListVariadic is a stand-in for an arbitrary number of types in an ordered list.

ListVariadic is created in a similar way to TypeVar:

from typing import ListVariadic

Ts = ListVariadic('Ts')

ListVariadic can be used in the same contexts as TypeVar. For example, in function signatures:

from typing import ListVariadic, Tuple

Ts = ListVariadic('Ts')

def identity(x: Tuple[Ts]) -> Tuple[Ts]: ...
def args_to_tuple(*args: Ts) -> Tuple[Ts]: ...
def duple(x: Tuple[Ts]) -> Tuple[Tuple[Ts], Tuple[Ts]]: ...

identity((1, 'a'))     # Tuple[int, str]
args_to_tuple(1, 'a')  # Tuple[int, str]
duple((1, 'a'))        # Tuple[Tuple[int, str], Tuple[int, str]]
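
ListVariadic does not exist in the typing module today, so the signatures above cannot be executed as written. The following minimal sketch shows only the runtime behaviour those signatures describe; the function bodies are assumptions, and the intended static types appear in comments.

```python
def identity(x):
    # Intended type: Tuple[Ts] -> Tuple[Ts]
    return x

def args_to_tuple(*args):
    # Intended type: (*args: Ts) -> Tuple[Ts]
    return args

def duple(x):
    # Intended type: Tuple[Ts] -> Tuple[Tuple[Ts], Tuple[Ts]]
    return x, x

same = identity((1, 'a'))
packed = args_to_tuple(1, 'a')
doubled = duple((1, 'a'))
```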

In class/method signatures:

from typing import Generic, ListVariadic

Shape = ListVariadic('Shape')
class Height: pass
class Width: pass

class Tensor(Generic[Shape]):

    def __abs__(self) -> Tensor[Shape]: ...

x: Tensor[Height, Width] = Tensor()  # Tensor[Height, Width]
y = abs(x)                           # Tensor[Height, Width]

Note that when a ListVariadic is used as a return type, it must be used to parameterize some other generic type such as Tuple. It cannot appear on its own:

from typing import Generic, ListVariadic, Tuple

Ts = ListVariadic('Ts')

def foo(x: Tuple[Ts]) -> Ts: ...         # Invalid
def bar(x: Tuple[Ts]) -> Tuple[Ts]: ...  # OK!

class MyVariadic(Generic[Ts]): ...
def baz(x: Tuple[Ts]) -> MyVariadic[Ts]: ...   # Also OK!

Also note that a ListVariadic can only refer to a flat list of types. However, see Concatenate below for a mechanism to join multiple ListVariadic instances together.

ListVariadic with a bound

ListVariadic supports a bound argument that has a similar purpose to the bound argument to TypeVar. bound takes a single type, and constrains every type in the ListVariadic to be a subtype of the type specified:

from typing import ListVariadic, Tuple

class Employee: ...
class Manager(Employee): ...
Ts = ListVariadic('Ts', bound=Employee)
def foo(x: Tuple[Ts]): ...

foo((Employee(), Employee())) # OK!
foo((Employee(), Manager()))  # Also OK!

class Pigeon: ...
foo((Employee(), Pigeon()))   # Invalid
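
One way to picture the bound is as a check applied to every element. The following runtime sketch is an analogy only: satisfies_bound is a hypothetical helper, and a type checker would perform the equivalent check statically on the inferred types rather than on values.

```python
class Employee: ...
class Manager(Employee): ...
class Pigeon: ...

def satisfies_bound(values, bound):
    """Return True if the type of every element is a subclass of `bound`."""
    return all(issubclass(type(value), bound) for value in values)

both_employees = satisfies_bound((Employee(), Employee()), Employee)
with_manager = satisfies_bound((Employee(), Manager()), Employee)
with_pigeon = satisfies_bound((Employee(), Pigeon()), Employee)
```

A single element that falls outside the bound, such as the Pigeon, is enough to reject the whole tuple.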

Variance

Consider a type Animal and a subclass Cat. A generic T is covariant in its type parameter if T[Cat] is considered a subclass of T[Animal]. Conversely, T is contravariant in its type if T[Animal] is a subclass of T[Cat]. If there is no subclass relationship between T[Animal] and T[Cat] at all, then T is invariant in its type.

To keep the scope of this PEP limited, variadic generics as defined in this PEP are always invariant. That is, given Ts = ListVariadic('Ts') and a generic Foo[Ts], Foo[Animal, Cat] has no subclass relationship to Foo[Animal, Animal]. We leave specification of other forms of variance in variadic generics for a future PEP.

Map

To enable typing of functions such as map and zip, we introduce Map, which is analogous to map, but for types:

from typing import List, ListVariadic, Map, Tuple

ArgTs = ListVariadic('ArgTs')

def args_to_tuples(*args: ArgTs) -> Map[Tuple, ArgTs]: ...

args_to_tuples(1)       # Tuple[int]
args_to_tuples(1, 'a')  # Tuple[Tuple[int], Tuple[str]]

Map can only be used in the context of function/method signatures. Its first argument should be a generic type, and its second argument should be a ListVariadic.

Map behaves differently depending on where it is being used.

In the context of argument types, Map can only be used as the type of *args or **kwargs, and specifies that the type of the Nth argument [4] is an instance of the generic type parameterized by the Nth type in the ListVariadic. For example, consider:

def foo(*args: Map[Tuple, ArgTs]): ...

Here, *args effectively expands to

arg1: Tuple[T1], arg2: Tuple[T2], ...

where T1, T2, etc. are the individual types in ArgTs.

In the context of return types, Map behaves differently depending on the length of the ListVariadic.

For example, consider:

def foo(*args: ArgTs) -> Map[Tuple, ArgTs]: ...

Here, the return type expands to

Tuple[Tuple[T1], Tuple[T2], ...]
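
The two expansions can be modelled at runtime by treating a ListVariadic as a tuple of concrete types. This is a sketch under that assumption: expand_map is a hypothetical helper, not a proposed API, and it models only the multi-type expansion shown above.

```python
from typing import Iterable, Tuple

def expand_map(generic, types):
    """Parameterize `generic` with each type in `types`, preserving order."""
    return tuple(generic[t] for t in types)

arg_ts = (int, str)  # stands in for ArgTs bound to [int, str]

# Argument position: one parameterized generic per positional argument.
per_argument = expand_map(Tuple, arg_ts)

# Return position: the same expansion, wrapped in an outer Tuple.
return_type = Tuple[per_argument]

iterable_args = expand_map(Iterable, arg_ts)
```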

Map allows us to specify the signature of map as:

from typing import Callable, Iterable, ListVariadic, Map, TypeVar

ArgTs = ListVariadic('ArgTs')
ReturnT = TypeVar('ReturnT')

def map(func: Callable[[ArgTs], ReturnT],
        *iterables: Map[Iterable, ArgTs]) -> Iterable[ReturnT]: ...

def func(a: int, b: str) -> float: ...
# iter1 must be type Iterable[int], and
# iter2 must be type Iterable[str]
map(func, iter1, iter2)

Similarly, we can specify the signature of zip as:

from typing import Iterable, List, ListVariadic, Map, Tuple

ArgTs = ListVariadic('ArgTs')

def zip(*iterables: Map[Iterable, ArgTs]) -> Iterable[Tuple[ArgTs]]: ...

l1: List[int]
l2: List[str]
zip(l1, l2)  # Iterable[Tuple[int, str]]

Accessing Individual Types

Map allows us to operate on types in a bulk fashion. For situations where we require access to each individual type, overloads can be used with individual TypeVar instances in place of the ListVariadic:

from typing import overload, Generic, ListVariadic, TypeVar

Shape = ListVariadic('Shape')
Axis1 = TypeVar('Axis1')
Axis2 = TypeVar('Axis2')
Axis3 = TypeVar('Axis3')

class Tensor(Generic[Shape]): ...

@overload
class Tensor(Generic[Axis1, Axis2]):

  def transpose(self) -> Tensor[Axis2, Axis1]: ...

@overload
class Tensor(Generic[Axis1, Axis2, Axis3]):

  def transpose(self) -> Tensor[Axis3, Axis2, Axis1]: ...
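
The same per-arity idea can be tried today with function overloads on ordinary tuples. The sketch below is an analogue rather than the class-level syntax used above; reverse_axes and the use of plain tuples in place of Tensor are assumptions made for illustration.

```python
from typing import Tuple, TypeVar, overload

A = TypeVar("A")
B = TypeVar("B")
C = TypeVar("C")

# Each overload names the individual types, so the return type can
# reorder them explicitly.
@overload
def reverse_axes(shape: Tuple[A, B]) -> Tuple[B, A]: ...
@overload
def reverse_axes(shape: Tuple[A, B, C]) -> Tuple[C, B, A]: ...

def reverse_axes(shape):
    # Reverse the order of the elements, whatever the length.
    return shape[::-1]

rank2 = reverse_axes(("height", "width"))
rank3 = reverse_axes(("batch", "height", "width"))
```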

Combining Variadics and Non-variadics

ListVariadic can be used with both regular TypeVar instances and other instances of ListVariadic:

from typing import Generic, ListVariadic, TypeVar

T = TypeVar('T')
Ts1 = ListVariadic('Ts1')
Ts2 = ListVariadic('Ts2')

class Bar(Generic[T, Ts1]):
  ...

class Baz(Generic[Ts1, Ts2]):
  ...

In these cases, when instantiating the generic with specific types, the types belonging to each ListVariadic must be made explicit with an extra pair of square brackets:

from typing import Generic, ListVariadic, TypeVar

T = TypeVar('T')
Ts1 = ListVariadic('Ts1')
Ts2 = ListVariadic('Ts2')

class Bar(Generic[T, Ts1]):
  ...

class Baz(Generic[Ts1, Ts2]):
  ...

bar: Bar[int, [float, str]]
baz: Baz[[int, float], [str]]
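
The bracketing rule can be pictured as pairing each declared parameter with exactly one argument, where a variadic parameter receives a bracketed list and a regular TypeVar receives a single type. The following sketch models that pairing at runtime; bind_arguments and its (name, is_variadic) encoding are hypothetical and chosen only for illustration.

```python
def bind_arguments(parameters, arguments):
    """Pair each declared parameter with its type argument.

    `parameters` is a list of (name, is_variadic) pairs. A variadic
    parameter must receive a bracketed list; a regular one must not.
    """
    if len(parameters) != len(arguments):
        raise TypeError("wrong number of type arguments")
    bound = {}
    for (name, is_variadic), argument in zip(parameters, arguments):
        if is_variadic != isinstance(argument, list):
            raise TypeError(f"{name}: bracketed list expected only for variadics")
        bound[name] = tuple(argument) if is_variadic else argument
    return bound

# Bar is declared as Generic[T, Ts1]; Baz as Generic[Ts1, Ts2].
bar_binding = bind_arguments([("T", False), ("Ts1", True)], [int, [float, str]])
baz_binding = bind_arguments([("Ts1", True), ("Ts2", True)], [[int, float], [str]])
```

With the brackets in place, each type is assigned to exactly one parameter and no guessing is needed.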

Concatenate

In some cases, we may want to form a new list of types by combining existing types and instances of TypeVar and ListVariadic. This is enabled by Concatenate:

from typing import Concatenate, Generic, ListVariadic

Shape = ListVariadic('Shape')
class Batch: pass
class Height: pass
class Width: pass

class Tensor(Generic[Shape]): ...

def add_batch(x: Tensor[Shape]) -> Tensor[Concatenate[Batch, Shape]]: ...

x: Tensor[Height, Width]
add_batch(x)  # Tensor[Batch, Height, Width]

Concatenate takes an arbitrary number of arguments, where each argument is a type, a TypeVar, or a ListVariadic, and is equivalent to all types passed as arguments concatenated in a flat list.
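
The flattening can be modelled at runtime by representing a bound ListVariadic as a tuple of types. This is a sketch under that assumption; concatenate is a hypothetical helper that mirrors the description above, not the proposed typing construct.

```python
class Batch: ...
class Height: ...
class Width: ...

def concatenate(*parts):
    """Join single types and tuples of types into one flat tuple."""
    flat = []
    for part in parts:
        if isinstance(part, tuple):
            flat.extend(part)   # the types bound to a ListVariadic
        else:
            flat.append(part)   # a single type or a bound TypeVar
    return tuple(flat)

shape = (Height, Width)  # stands in for Shape bound to [Height, Width]
with_batch = concatenate(Batch, shape)
repeated = concatenate(shape, shape)
```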

Concatenate with parameter types

Consider the following example:

Ts1 = ListVariadic('Ts1')
Ts2 = ListVariadic('Ts2')
def foo(x: Tuple[Concatenate[Ts1, Ts2]]): ...
t: Tuple[int, int, str] = (1, 1, 'a')
foo(t)

Note that there is no way for us to determine which types to assign to which ListVariadic. Is Ts1 int, int and Ts2 str, or is Ts1 int and Ts2 int, str? For this reason, when used in a parameter list, Concatenate can take at most one ListVariadic argument.
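
The ambiguity can be made concrete by enumerating every way of dividing the argument types between two adjacent ListVariadic instances. In this sketch, possible_splits is a hypothetical helper; the splits that leave one side empty additionally assume that an empty ListVariadic is permitted, which this document does not state.

```python
def possible_splits(types):
    """Every way to divide `types` between two adjacent ListVariadics."""
    return [(types[:i], types[i:]) for i in range(len(types) + 1)]

splits = possible_splits((int, int, str))
```

Both candidates named above, ((int, int), (str,)) and ((int,), (int, str)), appear in the result, so the argument type alone cannot select one of them.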

Concatenate in class declarations

When Concatenate is used in class declarations, the same issue applies:

class Foo(Generic[Concatenate[Ts1, Ts2]]): ...
foo: Foo[int, int, str] = Foo()

Therefore Concatenate in class declarations behaves similarly to Concatenate in parameter lists.

Concatenate with return type

However, the same concern does not apply when Concatenate is used in the return type. Because type variables (both TypeVar and ListVariadic) must appear more than once in a function signature to be of any use, every ListVariadic in the return type must already have been uniquely determined by the parameter types in the signature. For example:

Ts1 = ListVariadic('Ts1')
Ts2 = ListVariadic('Ts2')
def foo(x: Tuple[Ts1], y: Tuple[Ts2]) -> Tuple[Concatenate[Ts1, Ts2]]: ...

Therefore, there are no restrictions on the arguments to Concatenate when used in the return type.

Rationale

TODO

Backwards Compatibility

TODO

Reference Implementation

TODO

Rejected Ideas

TODO

Open Issues

Naming

This document currently instantiates a variadic parameter using ListVariadic() because that's how it works in Pyre, the first type checker to implement support for variadic generics.

Another option would be, say, TypeVar('Ts', variadic=True), which has the advantage of being more intuitive - but maybe there are factors we're not aware of. What are your opinions?

Type-transforming function bracket style

In the current draft of this document, we've used square brackets to contain the arguments to type-transforming functions such as Map. This is consistent with e.g. Union, but since they are functions, we could also consider round brackets. Thoughts?

Access to individual parts of the list of types

Do we need to have a way of accessing individual types of a ListVariadic? For example, do we need support for something like the following?

def foo(t: Tuple[Ts]):
  x: Ts[0] = t[0]

Or are overloads enough? (See the Accessing Individual Types section.)

Nesting of Concatenate and Map

Should it be possible to nest Concatenate and Map? That is, should it be possible to write Concatenate[A, Concatenate[B, Ts]]? By default, for simplicity, we lean towards "no", but are there cases where this would be useful?

Do we need Concatenate?

We currently lean towards making concatenation of types explicit using Concatenate:

def add_batch(x: Tensor[Shape]) -> Tensor[Concatenate[Batch, Shape]]: ...
def foo(x: Tuple[Concatenate[T, Ts]]): ...
def foo(x: Tuple[Ts1], y: Tuple[Ts2]) -> Tuple[Concatenate[Ts1, Ts2]]: ...

However, a simpler alternative would be to omit Concatenate altogether in these signatures, making the concatenation and conversion to a flat list of types implicit:

def add_batch(x: Tensor[Shape]) -> Tensor[Batch, Shape]: ...
def foo(x: Tuple[T, Ts]): ...
def foo(x: Tuple[Ts1], y: Tuple[Ts2]) -> Tuple[Ts1, Ts2]: ...

Integer parameterization

The examples of this PEP have parameterised tensor types using the semantic meaning of each axis, e.g. Tensor[Batch, Time]. However, we may also wish to parameterize using the actual value of each part of the shape, such as Tensor[Literal[64], Literal[64]].

There are two open questions related to such parameterisation.

1. Should we include examples of this? On one hand, it should clearly be valid to parameterize with literal types. On the other hand, it seems wise to discourage the use of integer parameterization in general: omitting the semantic meaning of each axis may increase the probability of errors later on in code.

2. Should we propose a syntactic sugar? Typing Literal is cumbersome; ideally, we could write Tensor[64, 64] as syntactic sugar for Tensor[Literal[64], Literal[64]]. Should we include a proposal for this in the PEP? The counter-argument is that, again, it may encourage something that we would like to discourage the general use of.

Integer generics

Consider a function such as np.tile:

x = np.zeros((3,))      # A tensor of shape [3]
y = np.tile(x, reps=2)  # y is now shape [6]
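
The shape arithmetic involved can be written out in plain Python for the one-dimensional case. This is a sketch of the arithmetic only, not of np.tile itself; tile_shape is a hypothetical helper.

```python
def tile_shape(shape, reps):
    """Shape of a 1-D tensor after it is repeated `reps` times."""
    (n,) = shape          # unpack the single dimension
    return (reps * n,)

tiled = tile_shape((3,), 2)
```

The output length depends on the value of reps and on the value of the input dimension, which is why a type standing in for a particular integer would be needed to express it.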

With this PEP making it possible to parameterise tensor types using the tensors' shapes, intuitively, we would specify the signature of such a function as:

@overload
def tile(A: Tensor[N], reps: Literal[2]) -> Tensor[2*N]: ...
# ...and other overloads for different values of `reps`

N here feels sort of like a type variable. However, type variables stand in for types, whereas here we want N to stand in for a particular value. N should be some sort of 'integer type variable'.

(Note that N could not be created as simply TypeVar('N', bound=int). This would state that N could stand for an int or any subtype of int. For our signature above, we would need N to stand for any instance of type int.)

We should support variadicity, too, to support e.g. np.zeros:

def zeros(shape: Shape) -> Tensor[Shape]: ...
x = zeros((2, 3))  # Should be Tensor[2, 3]

Do we need to discuss 'integer type variables' in this PEP, or can it wait until a future PEP?

Parameterizing the length of variadic generics

Should we be able to parameterize exactly how many items a ListVariadic contains? Consider e.g. reduction operations, which behave as:

x = np.zeros((2, 3, 5))
reduce_sum(x, axis=0)    # Shape (3, 5)
reduce_sum(x, axis=1)    # Shape (2, 5)
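
The effect of a reduction on the shape can be sketched in plain Python, independent of any numerical library. Here reduced_shape is a hypothetical helper that models only the shape, not the summation.

```python
def reduced_shape(shape, axis):
    """Shape after a reduction removes the dimension at position `axis`."""
    return shape[:axis] + shape[axis + 1:]

shape = (2, 3, 5)
over_axis_0 = reduced_shape(shape, 0)
over_axis_1 = reduced_shape(shape, 1)
```

In both cases a tensor of rank N goes in and a shape of rank N-1 comes out, with the removed position determined by the value of axis.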

To compactly specify the signature of these operations, we could write something like:

# Tensor of rank N goes in, tensor of rank N-1 comes out
def reduce_sum(x: Tensor[Shape[N]], axis: int) -> Tensor[Shape[N-1]]: ...

Is this important? Or are we just going to specify the signature of such functions using overloading up to some maximum rank?

@overload
def reduce_sum(x: Tensor[A, B], axis: Literal[0]) -> Tensor[B]: ...

@overload
def reduce_sum(x: Tensor[A, B], axis: Literal[1]) -> Tensor[A]: ...

...

Footnotes

[1] 'Batch' is machine learning parlance for 'a number of'.
[2] We use the term 'tensor' to refer to a matrix with an arbitrary number of dimensions. For example, a vector is a 1D tensor. In NumPy, the corresponding class is the ndarray; in TensorFlow, the Tensor; and so on.
[3] If the shape begins with 'batch × time', then videos[0][1] would select the second frame of the first video. If the shape begins with 'time × batch', then videos[1][0] would select the same frame.
[4] In the case of **kwargs, we mean the Nth argument as it appears in the function definition, not the Nth keyword argument specified in the function call.

References

[5] PEP 612, "Parameter Specification Variables": https://www.python.org/dev/peps/pep-0612
[6] PEP 484, "Type Hints": https://www.python.org/dev/peps/pep-0484
[7] Static typing of Python numeric stack: https://paper.dropbox.com/doc/Static-typing-of-Python-numeric-stack-summary-6ZQzTkgN6e0oXko8fEWwN
[8] Ideas for array shape typing in Python: https://docs.google.com/document/d/1vpMse4c6DrWH5rq2tQSx3qwP_m_0lyn-Ij4WHqQqRHY/edit
[9] Shape annotation syntax proposal: https://docs.google.com/document/d/1But-hjet8-djv519HEKvBN6Ik2lW3yu0ojZo6pG9osY/edit
[10] Type system improvements: https://paper.dropbox.com/doc/Type-system-improvements-HHOkniMG9WcCgS0LzXZAe
[11] Static typing of Python numeric stack: https://paper.dropbox.com/doc/Static-typing-of-Python-numeric-stack-summary-6ZQzTkgN6e0oXko8fEWwN
[12] Variadic Type Variables for Decorators and Tensors: https://github.com/facebook/pyre-check/blob/ae85c0c6e99e3bbfc92ec55104bfdc5b9b3097b2/docs/Variadic_Type_Variables_for_Decorators_and_Tensors.pdf

Acknowledgements

Thank you to Alfonso Castaño for feedback on early versions of this PEP.

Resources

Variadic generics were discussed among a number of possible improvements to the type system at PyCon 2019. Ivan Levkivskyi collected notes on these discussions in the Type system improvements [10] and Static typing of Python numeric stack [11] documents.

Expanding on these ideas, Mark Mendoza and Vincent Siles gave a presentation on Variadic Type Variables for Decorators and Tensors [12] at the 2019 Python Typing Summit.