How can I define algebraic data types in Python?

Python 3.10 version

Here is a Python 3.10 version of Brent’s answer that uses structural pattern matching (`match`/`case`) and the more concise `X | Y` union-type syntax:

from dataclasses import dataclass

@dataclass
class Point:
    x: float
    y: float

@dataclass
class Circle:
    x: float
    y: float
    r: float

@dataclass
class Rectangle:
    x: float
    y: float
    w: float
    h: float

Shape = Point | Circle | Rectangle

def print_shape(shape: Shape):
    match shape:
        case Point(x, y):
            print(f"Point {x} {y}")
        case Circle(x, y, r):
            print(f"Circle {x} {y} {r}")
        case Rectangle(x, y, w, h):
            print(f"Rectangle {x} {y} {w} {h}")

# Exercise every variant in turn: one printed line per shape.
for example in (Point(1, 2), Circle(3, 5, 7), Rectangle(11, 13, 17, 19)):
    print_shape(example)
# mypy type error: int is not a Shape. At runtime no case matches, so nothing is printed.
print_shape(4)

You can even define recursive types:

from __future__ import annotations
from dataclasses import dataclass

@dataclass
class Branch:
    value: int
    left: Tree
    right: Tree

Tree = Branch | None

def contains(tree: Tree, value: int):
    match tree:
        case None:
            return False
        case Branch(x, left, right):
            return x == value or contains(left, value) or contains(right, value)

# Sample tree: root 1 with children 2 and 3; node 3 has a right child 4.
tree = Branch(1, Branch(2, None, None), Branch(3, None, Branch(4, None, None)))

# Every stored value is found...
for present in (1, 2, 3, 4):
    assert contains(tree, present)
# ...and a value that was never inserted is not.
assert not contains(tree, 5)

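Note the need for `from __future__ import annotations`: `Branch` annotates its fields with `Tree`, which is not defined until after the class. The future import postpones evaluation of annotations (PEP 563), so this forward reference is allowed. Without it, you would have to quote the annotation as `"Tree"`.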

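Exhaustiveness checking for ADTs can be enforced with mypy using `assert_never()`. Import it from `typing` on Python 3.11+, or from the `typing-extensions` backport on older versions. If a `match` statement misses a variant, mypy reports an error at the `assert_never` call. At runtime, reaching that call raises an `AssertionError`.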

def print_shape(shape: Shape):
    match shape:
        case Point(x, y):
            print(f"Point {x} {y}")
        case Circle(x, y, r):
            print(f"Circle {x} {y} {r}")
        case _ as unreachable:
            # mypy will throw a type checking error
            # because Rectangle is not covered in the match.
            assert_never(unreachable)

Leave a Comment