Skip to content

Commit 9197f3a

Browse files
committed
fix logic
1 parent 1f446b1 commit 9197f3a

File tree

4 files changed

+60
-12
lines changed

4 files changed

+60
-12
lines changed

pymongo/asynchronous/client_session.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -406,15 +406,16 @@ def __init__(self, opts: Optional[TransactionOptions], client: AsyncMongoClient[
406406
self.recovery_token = None
407407
self.attempt = 0
408408
self.client = client
409+
self.has_completed_command = False
409410

410411
def active(self) -> bool:
411412
return self.state in (_TxnState.STARTING, _TxnState.IN_PROGRESS)
412413

413414
def starting(self) -> bool:
414415
return self.state == _TxnState.STARTING
415416

416-
def set_in_progress(self):
417-
self.state = _TxnState.IN_PROGRESS
417+
def set_starting(self):
418+
self.state = _TxnState.STARTING
418419

419420
@property
420421
def pinned_conn(self) -> Optional[AsyncConnection]:
@@ -1067,6 +1068,8 @@ def _apply_to(
10671068
)
10681069

10691070
if self._transaction.state == _TxnState.STARTING:
1071+
# First command begins a new transaction.
1072+
self._transaction.state = _TxnState.IN_PROGRESS
10701073
command["startTransaction"] = True
10711074

10721075
assert self._transaction.opts

pymongo/asynchronous/mongo_client.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2777,8 +2777,11 @@ async def run(self) -> T:
27772777
try:
27782778
res = await self._read() if self._is_read else await self._write()
27792779
await self._retry_policy.record_success(self._attempt_number > 0)
2780-
if self._session._starting_transaction:
2781-
self._session._transaction.set_in_progress()
2780+
# Track whether the transaction has completed command.
2781+
# If we need to apply backpressure to the first command,
2782+
# we will need to revert back to starting state.
2783+
if self._session.in_transaction:
2784+
self._session._transaction.has_completed_command = True
27822785
return res
27832786
except ServerSelectionTimeoutError:
27842787
# The application may think the write was never attempted
@@ -2793,6 +2796,7 @@ async def run(self) -> T:
27932796
always_retryable = False
27942797
overloaded = False
27952798
exc_to_check = exc
2799+
27962800
# Execute specialized catch on read
27972801
if self._is_read:
27982802
if isinstance(exc, (ConnectionFailure, OperationFailure)):
@@ -2813,8 +2817,16 @@ async def run(self) -> T:
28132817
self._retrying = True
28142818
self._last_error = exc
28152819
self._attempt_number += 1
2816-
else:
2817-
raise
2820+
2821+
# Revert back to starting state if we're in a transaction but haven't completed the first
2822+
# command.
2823+
if (
2824+
self._session.in_transaction
2825+
and not self._session._transaction.has_completed_command
2826+
):
2827+
self._session._transaction.set_starting()
2828+
else:
2829+
raise
28182830

28192831
# Specialized catch on write operation
28202832
if not self._is_read:
@@ -2848,6 +2860,15 @@ async def run(self) -> T:
28482860
self._last_error = exc
28492861
if self._last_error is None:
28502862
self._last_error = exc
2863+
# Revert back to starting state if we're in a transaction but haven't completed the first
2864+
# command.
2865+
if (
2866+
self._session.in_transaction
2867+
and not self._session._transaction.has_completed_command
2868+
):
2869+
self._session._transaction.set_starting()
2870+
else:
2871+
raise
28512872

28522873
if self._client.topology_description.topology_type == TOPOLOGY_TYPE.Sharded:
28532874
self._deprioritized_servers.append(self._server)

pymongo/synchronous/client_session.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -404,15 +404,16 @@ def __init__(self, opts: Optional[TransactionOptions], client: MongoClient[Any])
404404
self.recovery_token = None
405405
self.attempt = 0
406406
self.client = client
407+
self.has_completed_command = False
407408

408409
def active(self) -> bool:
409410
return self.state in (_TxnState.STARTING, _TxnState.IN_PROGRESS)
410411

411412
def starting(self) -> bool:
412413
return self.state == _TxnState.STARTING
413414

414-
def set_in_progress(self):
415-
self.state = _TxnState.IN_PROGRESS
415+
def set_starting(self):
416+
self.state = _TxnState.STARTING
416417

417418
@property
418419
def pinned_conn(self) -> Optional[Connection]:
@@ -1063,6 +1064,8 @@ def _apply_to(
10631064
)
10641065

10651066
if self._transaction.state == _TxnState.STARTING:
1067+
# First command begins a new transaction.
1068+
self._transaction.state = _TxnState.IN_PROGRESS
10661069
command["startTransaction"] = True
10671070

10681071
assert self._transaction.opts

pymongo/synchronous/mongo_client.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2767,8 +2767,11 @@ def run(self) -> T:
27672767
try:
27682768
res = self._read() if self._is_read else self._write()
27692769
self._retry_policy.record_success(self._attempt_number > 0)
2770-
if self._session._starting_transaction:
2771-
self._session._transaction.set_in_progress()
2770+
# Track whether the transaction has completed command.
2771+
# If we need to apply backpressure to the first command,
2772+
# we will need to revert back to starting state.
2773+
if self._session.in_transaction:
2774+
self._session._transaction.has_completed_command = True
27722775
return res
27732776
except ServerSelectionTimeoutError:
27742777
# The application may think the write was never attempted
@@ -2783,6 +2786,7 @@ def run(self) -> T:
27832786
always_retryable = False
27842787
overloaded = False
27852788
exc_to_check = exc
2789+
27862790
# Execute specialized catch on read
27872791
if self._is_read:
27882792
if isinstance(exc, (ConnectionFailure, OperationFailure)):
@@ -2803,8 +2807,16 @@ def run(self) -> T:
28032807
self._retrying = True
28042808
self._last_error = exc
28052809
self._attempt_number += 1
2806-
else:
2807-
raise
2810+
2811+
# Revert back to starting state if we're in a transaction but haven't completed the first
2812+
# command.
2813+
if (
2814+
self._session.in_transaction
2815+
and not self._session._transaction.has_completed_command
2816+
):
2817+
self._session._transaction.set_starting()
2818+
else:
2819+
raise
28082820

28092821
# Specialized catch on write operation
28102822
if not self._is_read:
@@ -2838,6 +2850,15 @@ def run(self) -> T:
28382850
self._last_error = exc
28392851
if self._last_error is None:
28402852
self._last_error = exc
2853+
# Revert back to starting state if we're in a transaction but haven't completed the first
2854+
# command.
2855+
if (
2856+
self._session.in_transaction
2857+
and not self._session._transaction.has_completed_command
2858+
):
2859+
self._session._transaction.set_starting()
2860+
else:
2861+
raise
28412862

28422863
if self._client.topology_description.topology_type == TOPOLOGY_TYPE.Sharded:
28432864
self._deprioritized_servers.append(self._server)

0 commit comments

Comments
 (0)