AirflowのOperatorの引数をテストするときにはまったこと

どうも、こんにちは。GCPの踊り子です。

Cloud ComposerあらためAirflowでDAGの中で利用されている、Operatorに必要な引数が渡されているかをテストするときにハマったことがあったので、ブログにしておきます。

from unittest.mock import call

from airflow.operators.python_operator import PythonOperator


def dummy():
    pass


class AnyInstanceOf:
    "A helper object that compares equal to every instance of the specified class."

    def __init__(self, cls):
        self.cls = cls

    def __eq__(self, other):
        return isinstance(other, self.cls)

    def __ne__(self, other):
        return not isinstance(other, self.cls)

    def __repr__(self):
        return f"<ANY {self.cls.__name__}>"


def test_example(mocker):
    m = mocker.spy(PythonOperator, "__init__")

    PythonOperator(
        task_id="run_after_loop",
        python_callable=dummy,
    )

    assert m.call_args_list == [
        call(
            AnyInstanceOf(PythonOperator),
            task_id="run_after_loop",
        )
    ]

実装としては PythonOperaotr に渡された引数をテストしたいケースです。ただしこれはこのままでは動きません。

        hook = getattr(self, "_hook_apply_defaults", None)
        if hook:
            args, kwargs = hook(**kwargs, default_args=default_args)
            default_args = kwargs.pop("default_args", {})

        if not hasattr(self, "_BaseOperator__init_kwargs"):
            self._BaseOperator__init_kwargs = {}
        self._BaseOperator__from_mapped = instantiated_from_mapped

        result = func(self, **kwargs, default_args=default_args)

        # Store the args passed to init -- we need them to support task.map serialzation!
        self._BaseOperator__init_kwargs.update(kwargs)  # type: ignore

        # Set upstream task defined by XComArgs passed to template fields of the operator.
        # BUT: only do this _ONCE_, not once for each class in the hierarchy
>       if not instantiated_from_mapped and func == self.__init__.__wrapped__:  # type: ignore[misc]
E       AttributeError: 'function' object has no attribute '__wrapped__'

普通にspyするとwrapのメソッドがないので、spyすることができません。そこで、mockにpatchを当てることで対応します。

    m = mocker.spy(PythonOperator, "__init__")
    mocker.patch.object(m, "__wrapped__", create=True)

こうする感じですね。こうすることで、Operatorについても問題なくspyすることができました。元々は__new__ をmockして、DummyOperator を返却するような実装をしていました。

def side_effect(*args, **kwargs):
    return EmptyOperator(
        task_id=kwargs["task_id"],
        start_date=DateTime(2023, 6, 13, 1, 0, 0, tzinfo=timezone("UTC")),
    )


def test_create_dag(mocker):
    python_operator_mock = mocker.patch.object(
        common_distribution_report.PythonOperator,
        "__new__",
        side_effect=side_effect,
    )

これでもうまくいっていたのですが、どうもpytestが__new__のpatchを許容していないようで、インタープリタを再起動するか、テストごとにプロセスを分けないと、mockの状態が解除されないという挙動をしていたので、ハマりポイントでした。

インターネット探してもあまり情報がなくて、みんなOperationのテストどうしてんだろ?と思いつつも、うまく動いて最高!!!