Python 3.10 version
Here is a Python 3.10 version of Brent’s answer, using structural pattern matching (`match`/`case`) and the more concise `X | Y` union type syntax:
from dataclasses import dataclass
@dataclass
class Point:
x: float
y: float
@dataclass
class Circle:
x: float
y: float
r: float
@dataclass
class Rectangle:
x: float
y: float
w: float
h: float
# Union of the three shape dataclasses defined above. It is used as the
# parameter annotation of print_shape, so a type checker can reject
# arguments that are none of these classes.
Shape = Point | Circle | Rectangle
def print_shape(shape: Shape):
match shape:
case Point(x, y):
print(f"Point {x} {y}")
case Circle(x, y, r):
print(f"Circle {x} {y} {r}")
case Rectangle(x, y, w, h):
print(f"Rectangle {x} {y} {w} {h}")
# Exercise each member of the Shape union once.
print_shape(Point(1, 2))
print_shape(Circle(3, 5, 7))
print_shape(Rectangle(11, 13, 17, 19))
# An int is not part of Shape, so mypy flags this call. At runtime no case in
# print_shape matches 4, so this call prints nothing.
print_shape(4) # mypy type error
You can even define recursive types:
from __future__ import annotations
from dataclasses import dataclass
@dataclass
class Branch:
value: int
left: Tree
right: Tree
# A Tree is either a Branch node or None (the empty tree). Branch's left and
# right fields are annotated with this alias, which makes the type recursive.
Tree = Branch | None
def contains(tree: Tree, value: int):
match tree:
case None:
return False
case Branch(x, left, right):
return x == value or contains(left, value) or contains(right, value)
tree = Branch(
    1,
    Branch(2, None, None),
    Branch(3, None, Branch(4, None, None)),
)
# Every stored value is found, checked in the order 1, 2, 3, 4 ...
assert all(contains(tree, needle) for needle in (1, 2, 3, 4))
# ... and a value that is not stored is not.
assert not contains(tree, 5)
Note the need for `from __future__ import annotations`
in order to annotate `Branch`'s fields with `Tree`, a type that hasn’t been defined yet at that point.
Exhaustiveness checking for ADTs can be enforced with mypy
using `typing.assert_never()`, which is available
in Python 3.11+ and, for older versions of Python, in the typing-extensions
backport (`typing_extensions.assert_never`):
def print_shape(shape: Shape):
    """Print *shape*, with a type-checked exhaustiveness guard.

    The final wildcard case binds whatever the earlier cases did not handle
    and passes it to ``assert_never``. Because Rectangle has no case of its
    own here, mypy reports an error at that call. At runtime a Rectangle
    would reach the wildcard case.

    NOTE(review): ``assert_never`` is not imported in this snippet. It needs
    ``from typing import assert_never`` (Python 3.11+) or
    ``from typing_extensions import assert_never`` on older versions.
    """
    match shape:
        case Point(x, y):
            print(f"Point {x} {y}")
        case Circle(x, y, r):
            print(f"Circle {x} {y} {r}")
        case _ as unreachable:
            # mypy will throw a type checking error
            # because Rectangle is not covered in the match.
            assert_never(unreachable)