どうも、こんにちは。GCPの踊り子です。
Cloud ComposerあらためAirflowでDAGの中で利用されている、Operatorに必要な引数が渡されているかをテストするときにハマったことがあったので、ブログにしておきます。
from unittest.mock import call
from airflow.operators.python_operator import PythonOperator
def dummy():
pass
class AnyInstanceOf:
"A helper object that compares equal to every instance of the specified class."
def __init__(self, cls):
self.cls = cls
def __eq__(self, other):
return isinstance(other, self.cls)
def __ne__(self, other):
return not isinstance(other, self.cls)
def __repr__(self):
return f"<ANY {self.cls.__name__}>"
def test_example(mocker):
m = mocker.spy(PythonOperator, "__init__")
PythonOperator(
task_id="run_after_loop",
python_callable=dummy,
)
assert m.call_args_list == [
call(
AnyInstanceOf(PythonOperator),
task_id="run_after_loop",
)
]
実装としては PythonOperaotr
に渡された引数をテストしたいケースです。ただしこれはこのままでは動きません。
hook = getattr(self, "_hook_apply_defaults", None)
if hook:
args, kwargs = hook(**kwargs, default_args=default_args)
default_args = kwargs.pop("default_args", {})
if not hasattr(self, "_BaseOperator__init_kwargs"):
self._BaseOperator__init_kwargs = {}
self._BaseOperator__from_mapped = instantiated_from_mapped
result = func(self, **kwargs, default_args=default_args)
# Store the args passed to init -- we need them to support task.map serialzation!
self._BaseOperator__init_kwargs.update(kwargs) # type: ignore
# Set upstream task defined by XComArgs passed to template fields of the operator.
# BUT: only do this _ONCE_, not once for each class in the hierarchy
> if not instantiated_from_mapped and func == self.__init__.__wrapped__: # type: ignore[misc]
E AttributeError: 'function' object has no attribute '__wrapped__'
普通にspyするとwrapのメソッドがないので、spyすることができません。そこで、mockにpatchを当てることで対応します。
m = mocker.spy(PythonOperator, "__init__")
mocker.patch.object(m, "__wrapped__", create=True)
こうする感じですね。こうすることで、Operatorについても問題なくspyすることができました。元々は__new__
をmockして、DummyOperator
を返却するような実装をしていました。
def side_effect(*args, **kwargs):
return EmptyOperator(
task_id=kwargs["task_id"],
start_date=DateTime(2023, 6, 13, 1, 0, 0, tzinfo=timezone("UTC")),
)
def test_create_dag(mocker):
python_operator_mock = mocker.patch.object(
common_distribution_report.PythonOperator,
"__new__",
side_effect=side_effect,
)
これでもうまくいっていたのですが、どうもpytestが__new__
のpatchを許容していないようで、インタープリタを再起動するか、テストごとにプロセスを分けないと、mockの状態が解除されないという挙動をしていたので、ハマりポイントでした。
インターネット探してもあまり情報がなくて、みんなOperationのテストどうしてんだろ?と思いつつも、うまく動いて最高!!!