
Commit 2e751b1

Merge pull request #48 from wayfair-incubator/efficient_writes
efficient writes and update tests
2 parents: 275530d + 869e2e7

7 files changed: +19 / -23 lines changed


dagger/modeler/definition.py

Lines changed: 7 additions & 0 deletions
@@ -497,6 +497,7 @@ async def create_instance(
         *,
         repartition: bool = True,
         seed: random.Random = None,
+        submit_task: bool = False,
         **kwargs,
     ) -> ITemplateDAGInstance[KT, VT]:
         """Method for creating an instance of a workflow definition
@@ -506,6 +507,7 @@ async def create_instance(
         :param repartition: Flag indicating if the creation of this instance needs to be stored on the current node or
             by the owner of the partition defined by the partition_key_lookup
         :param seed: the seed to use to create all internal instances of the workflow
+        :param submit_task: if True also submit the task for execution
         :param **kwargs: Other keyword arguments
         :return: An instance of the workflow
         """
@@ -541,6 +543,11 @@ async def create_instance(
         if repartition:
             await self.app.tasks_topic.send(key=template_instance.runtime_parameters[partition_key_lookup], value=template_instance)  # type: ignore
         else:
+            if submit_task:
+                template_instance.status = TaskStatus(
+                    code=TaskStatusEnum.SUBMITTED.name,
+                    value=TaskStatusEnum.SUBMITTED.value,
+                )
             await self.app._store_and_create_task(template_instance)  # type: ignore
         return template_instance
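With the new `submit_task` flag, a caller can create and submit a workflow in a single step instead of following `create_instance` with a separate `workflow_engine.submit(...)` call, which is exactly what the integration tests below are updated to do. A minimal usage sketch (`order_template` and `order_id` are hypothetical stand-ins for an already-registered workflow definition and one of its runtime parameters):

import uuid

async def create_order_workflow(order_template, order_id):
    # order_template is assumed to be an already-built workflow definition
    # exposing create_instance(); order_id is an illustrative runtime parameter.
    instance = await order_template.create_instance(
        uuid.uuid1(),
        repartition=False,  # keep the instance on the current worker
        submit_task=True,   # new in this commit: mark SUBMITTED and start immediately
        order_id=order_id,
    )
    return instance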

dagger/service/services.py

Lines changed: 1 addition & 1 deletion
@@ -775,9 +775,9 @@ def main(self, override_logging=False) -> None:
 
     async def _store_and_create_task(self, task):
         if isinstance(task, ITemplateDAGInstance):
-            await self._store_root_template_instance(task)
             if task.status.code == TaskStatusEnum.SUBMITTED.name:
                 await task.start(workflow_instance=task)
+            await self._store_root_template_instance(task)
 
     async def _process_tasks_create_event(self, stream):
         """Upon creation of tasks, store them in the datastore.

dagger/store/stores.py

Lines changed: 1 addition & 0 deletions
@@ -168,6 +168,7 @@ async def execute_system_timer_task(self) -> None:  # pragma: no cover
                 finished = await task.start(workflow_instance)
             else:
                 await task.start(workflow_instance)
+            await self.app._update_instance(task=workflow_instance)  # type: ignore
             if finished:
                 await self.remove_trigger(trigger)
             if not task or task.status.code in TERMINAL_STATUSES:

dagger/tasks/task.py

Lines changed: 3 additions & 11 deletions
@@ -253,7 +253,6 @@ async def on_complete(  # noqa: C901
         else:
             time_completed = int(time.time())
             self.time_completed = time_completed
-        await dagger.service.services.Dagger.app._update_instance(task=workflow_instance)  # type: ignore
         if not iterate:
             logger.debug("Skipping on_complete as iterate is false")
             return
@@ -341,7 +340,6 @@ async def start(
             runtime_parameters=workflow_instance.runtime_parameters,
             workflow_instance=workflow_instance,
         )
-        await dagger.service.services.Dagger.app._update_instance(task=workflow_instance)  # type: ignore
         if self.status.code == TaskStatusEnum.FAILURE.name:
             await self.on_complete(
                 status=self.status, workflow_instance=workflow_instance
@@ -421,7 +419,6 @@ async def start(self, workflow_instance: Optional[ITemplateDAGInstance]) -> bool
             code=TaskStatusEnum.EXECUTING.name, value=TaskStatusEnum.EXECUTING.value
         )
         self.time_submitted = int(time.time())
-        await dagger.service.services.Dagger.app._update_instance(task=workflow_instance)  # type: ignore
         if self.time_to_execute and int(time.time()) < self.time_to_execute:
             return False
         if (
@@ -582,7 +579,6 @@ async def start(self, workflow_instance: Optional[ITemplateDAGInstance]) -> None
             logger.warning(
                 f"The task instance to skip with id {next_task_id} was not found. Skipped but did not set status to {TaskStatusEnum.SKIPPED.value}"
             )
-        await dagger.service.services.Dagger.app._update_instance(task=workflow_instance)  # type: ignore
         await self.on_complete(workflow_instance=workflow_instance)
 
     async def execute(
@@ -715,7 +711,6 @@ async def start(self, workflow_instance: Optional[ITemplateDAGInstance]) -> None
             code=TaskStatusEnum.EXECUTING.name, value=TaskStatusEnum.EXECUTING.value
         )
         self.time_submitted = int(time.time())
-        await dagger.service.services.Dagger.app._update_instance(task=workflow_instance)  # type: ignore
 
     async def _update_correletable_key(self, workflow_instance: ITask) -> None:
         """Updates the correletable key if the local is not the same as global key.
@@ -889,10 +884,9 @@ async def process_event_helper(self, event):  # noqa: C901
                     await task_instance.on_complete(
                         workflow_instance=workflow_instance
                     )
-                else:
-                    await dagger.service.services.Dagger.app._update_instance(
-                        task=workflow_instance
-                    )  # type: ignore
+                await dagger.service.services.Dagger.app._update_instance(
+                    task=workflow_instance
+                )  # type: ignore
                 processed_task = True
 
         if getattr(self.__task, "match_only_one", False):
@@ -985,7 +979,6 @@ async def start(self, workflow_instance: Optional[ITemplateDAGInstance]) -> None
             runtime_parameters=workflow_instance.runtime_parameters,
             workflow_instance=workflow_instance,
         )
-        await dagger.service.services.Dagger.app._update_instance(task=workflow_instance)  # type: ignore
         logger.debug(
             f"Starting task {self.task_name} with root dag id {self.root_dag}, parent task id {self.parent_id}, and task id {self.id}"
         )
@@ -1062,7 +1055,6 @@ async def start(self, workflow_instance: Optional[ITemplateDAGInstance]) -> None
             runtime_parameters=workflow_instance.runtime_parameters,
             workflow_instance=workflow_instance,
         )
-        await dagger.service.services.Dagger.app._update_instance(task=workflow_instance)  # type: ignore
         logger.debug(
             f"Starting task {self.task_name} with parent task id {self.parent_id}, and task id {self.id}"
         )

dagger/templates/template.py

Lines changed: 2 additions & 0 deletions
@@ -66,6 +66,7 @@ async def create_instance(
         *,
         repartition: bool = True,
         seed: random.Random = None,
+        submit_task: bool = False,
         **kwargs,
     ) -> ITemplateDAGInstance:  # pragma: no cover
         """Method for creating an instance of a workflow definition
@@ -75,6 +76,7 @@ async def create_instance(
         :param repartition: Flag indicating if the creation of this instance needs to be stored on the current node or
             by the owner of the partition defined by the partition_key_lookup
         :param seed: the seed to use to create all internal instances of the workflow
+        :param submit_task: if True also submit the task for execution
         :param **kwargs: Other keyword arguments
         :return: An instance of the workflow
         """

integration_tests/test_app.py

Lines changed: 5 additions & 3 deletions
@@ -727,19 +727,21 @@ async def create_and_submit_pizza_delivery_workflow(
     pizza_workflow_instance = await pizza_workflow_template.create_instance(
         uuid.uuid1(),
         repartition=False,  # Create this instance on the current worker
+        submit_task=True,
         order_id=order_id,
         customer_id=customer_id,
         pizza_type=pizza_type,
     )
-    await workflow_engine.submit(pizza_workflow_instance, repartition=False)
 
 
 @workflow_engine.faust_app.agent(simple_topic_stop)
 async def simple_data_stream_stop(stream):
     async for value in stream:
 
         instance = await workflow_engine.get_instance(running_task_ids[-1])
-        await instance.stop()
+        await instance.stop(
+            runtime_parameters=instance.runtime_parameters, workflow_instance=instance
+        )
 
 
 @workflow_engine.faust_app.agent(simple_topic)
@@ -767,10 +769,10 @@ async def simple_data_stream(stream):
             complete_by_time=120000,
             repartition=False,
             seed=rd,
+            submit_task=True,
         )
         templates.append(instance)
         running_task_ids.append(instance.id)
-        await workflow_engine.submit(instance, repartition=False)
 
 
 @workflow_engine.faust_app.agent(orders_topic)

tests/tasks/test_task.py

Lines changed: 0 additions & 8 deletions
@@ -149,11 +149,9 @@ async def test_parallel_composite_task_start_non_terminal(
                 workflow_instance=workflow_instance_fixture
             )
             assert parallel_composite_task_fixture.execute.called
-            assert dagger.service.services.Dagger.app._update_instance.called
             assert not parallel_composite_task_fixture.on_complete.called
             assert child_task1.start.called
             assert child_task2.start.called
-            assert dagger.service.services.Dagger.app._update_instance.called
         except Exception:
             pytest.fail("Error should not be thrown")
 
@@ -544,7 +542,6 @@ async def test_executortask(self, executor_fixture, workflow_instance_fixture):
         parent_task.notify = CoroutineMock()
         assert executor_fixture.get_id() == executor_fixture.id
         await executor_fixture.start(workflow_instance=workflow_instance_fixture)
-        assert dagger.service.services.Dagger.app._update_instance.called
         assert executor_fixture.status.code == TaskStatusEnum.COMPLETED.name
         assert executor_fixture.time_completed != 0
         assert parent_task.notify.called
@@ -569,7 +566,6 @@ async def test_decisiontask(self, decision_fixture, workflow_instance_fixture):
         workflow_instance_fixture.runtime_parameters = {}
         assert decision_fixture.get_id() == decision_fixture.id
         await decision_fixture.start(workflow_instance=workflow_instance_fixture)
-        assert dagger.service.services.Dagger.app._update_instance.called
         assert decision_fixture.on_complete.called
         with pytest.raises(NotImplementedError):
             await decision_fixture.execute(
@@ -646,7 +642,6 @@ async def test_sensortask(self, sensor_fixture, workflow_instance_fixture):
         ret_val = sensor_fixture.get_correlatable_key(payload)
         assert payload == ret_val
         await sensor_fixture.start(workflow_instance=workflow_instance_fixture)
-        assert dagger.service.services.Dagger.app._update_instance.called
         assert not sensor_fixture.on_complete.called
         with pytest.raises(NotImplementedError):
             await sensor_fixture.execute(
@@ -714,7 +709,6 @@ async def test_current_triggertask(
         )
         assert trigger_fixture.get_id() == trigger_fixture.id
         await trigger_fixture.start(workflow_instance=workflow_instance_fixture)
-        assert dagger.service.services.Dagger.app._update_instance.called
         assert trigger_fixture.status.code == TaskStatusEnum.COMPLETED.name
         assert (
             dagger.service.services.Dagger.app._store.process_trigger_task_complete.called
@@ -729,7 +723,6 @@ async def test_future_interval_fixture(
         dagger.service.services.Dagger.app._update_instance = CoroutineMock()
         assert interval_fixture.get_id() == interval_fixture.id
         await interval_fixture.start(workflow_instance=workflow_instance_fixture)
-        assert dagger.service.services.Dagger.app._update_instance.called
         assert interval_fixture.status.code == TaskStatusEnum.EXECUTING.name
 
     @pytest.mark.asyncio
@@ -743,7 +736,6 @@ async def test_current_interval_fixture(
        dagger.service.services.Dagger.app._store.insert_trigger = CoroutineMock()
        assert interval_fixture.get_id() == interval_fixture.id
        await interval_fixture.start(workflow_instance=workflow_instance_fixture)
-       assert dagger.service.services.Dagger.app._update_instance.called
        assert interval_fixture.status.code == TaskStatusEnum.COMPLETED.name
 
     @pytest.mark.asyncio
