フューチャー技術ブログ

Pythonの関数で入力と出力の型を束縛させたい 〜デコレータ編〜

Pythonの型定義についてちょっと苦戦したのでメモ。

Functional Python Programmingという本のサンプルで次のようなサンプルがありました。

from collections.abc import Callable
from functools import wraps

def nullable(function: Callable[[float], float]) -> Callable[[float | None], float | None]:
@wraps(function)
def null_wrapper(value: float | None) -> float | None:
return None if value is None else function(value)
return null_wrapper

これはfloatを受け取ってfloatを返す数値計算の関数があった場合に、入力値がNoneの場合は例外を投げずにNoneを返せるようになるデコレータの実装です。

でもよくよく見ると、結果の関数の型は入力も出力も「flaotかもしれないしNoneかもしれない」という情報しかなく、flaot のときに Noneが返っても合法となってしまい、元の実装を厳密に反映した型になってません。

通常の関数の場合

通常の関数の場合はこういうのを実現するのは難しくはありません。まずは型変数を使う場合。mypy playgroundを使うと簡単にテストできます。

from typing import TypeVar

T = TypeVar('T', float, None)

def double(value: T) -> T:
if value is None:
return None
else:
return value * 2

double(10) * 2
double(None) * 2 # これの返り値はNoneだからエラーになる

コメント通りエラーが出ることがわかります。overload()を使う方法もあります。こちらの方が長いのですが simpleでPythonicな気がします。

from typing import overload

@overload
def double(value: None) -> None:
...

@overload
def double(value: float) -> float:
...

def double(value: float | None):
if value is None:
return None
else:
return value * 2

double(2) * 2
double(None) * 2 # これの返り値はNoneだからエラーになる

デコレータだとうまくいかない

さて、デコレータとして素直に実装しようとするとエラーが出てしまいました。 @wraps があるとエラーになってしまいます。

from collections.abc import Callable
from typing import TypeVar
from functools import wraps

T = TypeVar("T", float, None)

def nullable(function: Callable[[float], float]) -> Callable[[T], T]:
#@wraps(function) # コメント外すとエラー
def null_wrapper(value: T) -> T:
return None if value is None else function(value)
return null_wrapper

@nullable
def double(v: float) -> float:
return v * 2

double(None) * 2 # エラーになるべき: OK
double(1.4) * 2. # 通るべき: OK

mypyはどうも Callable[[T], T] | Callable[[None], None]という定義はうまく扱えないらしい、という @moriyoshit さん情報。@functools.wrapsは、デコレータが付与された関数の名前とdocstringを新しく作った関数に適用することで利用者の便宜をはかるためのデコレータです。これがうまく処理できない。

error: Argument 1 to "__call__" of "_Wrapper" has incompatible type
"Callable[[T], T]"; expected "Callable[[Never], Never]" [arg-type]

@functools.wrapsが内部で使っているfunctools.update_wrapper()を使いながら、キャストで無理やりエラーを押さえ込むとうまくいきました。

from collections.abc import Callable
from typing import TypeVar
from functools import update_wrapper

T = TypeVar("T", float, None)

def nullable(function: Callable[[float], float]) -> Callable[[T], T]:
def null_wrapper(value: T) -> T:
return None if value is None else function(value)
update_wrapper(cast(Callable[[Never], Never], null_wrapper), function)
return null_wrapper

まあたまにこういうハックが必要になったりしますよね。

まあ後で言われて気づきましたが、どうせ実行時にならないと型は決まらないから、デコレータに関してはこのような細かい型定義をしようとしてもあまりご利益はなさそうです。絶対に None が入らない入力を渡した場合に Noneチェックをいちいちしなくても良いというのがメリットですが・・・

まとめ

Pythonの関数、特にデコレータにおいて入力と出力の型を厳密に束縛する方法について記載しました。

  • 通常の関数では、TypeVar や overload を活用することで、入力型に応じて出力型が変化するような精密な型定義が可能
  • デコレータで同様の型束縛を実現しようとすると、特に @functools.wraps を使用する際に型エラーが発生しやすく、update_wrapper とキャストを用いたハック的な回避策が必要になる場合がある