Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add savepoint API #433

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions edgedb/asyncio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,9 @@ def _exclusive(self):
finally:
self._locked = False

async def declare_savepoint(self, savepoint: str) -> transaction.Savepoint:
return await self._declare_savepoint(savepoint)


class AsyncIORetry(transaction.BaseRetry):

Expand Down
13 changes: 13 additions & 0 deletions edgedb/blocking_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,14 @@ async def close(self, timeout=None):
self._closing = False


class Savepoint(transaction.Savepoint):
def release(self):
self._tx._client._iter_coroutine(super().release())

def rollback(self):
self._tx._client._iter_coroutine(super().rollback())


class Iteration(transaction.BaseTransaction, abstract.Executor):

__slots__ = ("_managed", "_lock")
Expand Down Expand Up @@ -320,6 +328,11 @@ def _exclusive(self):
finally:
self._lock.release()

def declare_savepoint(self, savepoint: str) -> Savepoint:
return self._client._iter_coroutine(
self._declare_savepoint(savepoint, cls=Savepoint)
)


class Retry(transaction.BaseRetry):

Expand Down
51 changes: 51 additions & 0 deletions edgedb/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#


from __future__ import annotations

import enum

from . import abstract
Expand All @@ -32,12 +34,47 @@ class TransactionState(enum.Enum):
FAILED = 4


class Savepoint:
__slots__ = ('_name', '_tx', '_active')

def __init__(self, name: str, transaction: BaseTransaction):
self._name = name
self._tx = transaction
self._active = True

@property
def active(self):
return self._active

def _ensure_active(self):
if not self._active:
raise errors.InterfaceError(
f"savepoint {self._name!r} is no longer active"
)

async def release(self):
self._ensure_active()
await self._tx._privileged_execute(f"release savepoint {self._name}")
del self._tx._savepoints[self._name]
self._active = False

async def rollback(self):
self._ensure_active()
await self._tx._privileged_execute(
f"rollback to savepoint {self._name}"
)
names = list(self._tx._savepoints)
for name in names[names.index(self._name):]:
self._tx._savepoints.pop(name)._active = False


class BaseTransaction:

__slots__ = (
'_client',
'_connection',
'_options',
'_savepoints',
'_state',
'__retry',
'__iteration',
Expand All @@ -48,6 +85,7 @@ def __init__(self, retry, client, iteration):
self._client = client
self._connection = None
self._options = retry._options.transaction_options
self._savepoints = {}
self._state = TransactionState.NEW
self.__retry = retry
self.__iteration = iteration
Expand Down Expand Up @@ -128,6 +166,9 @@ async def _exit(self, extype, ex):
if not self.__started:
return False

for sp in self._savepoints.values():
sp._active = False

try:
if extype is None:
query = self._make_commit_query()
Expand Down Expand Up @@ -200,6 +241,16 @@ async def _privileged_execute(self, query: str) -> None:
state=self._get_state(),
))

async def _declare_savepoint(self, savepoint: str, cls=Savepoint):
if savepoint in self._savepoints:
raise errors.InterfaceError(
f"savepoint {savepoint!r} already exists"
)
await self._ensure_transaction()
await self._privileged_execute(f"declare savepoint {savepoint}")
self._savepoints[savepoint] = rv = cls(savepoint, self)
return rv


class BaseRetry:

Expand Down
55 changes: 55 additions & 0 deletions tests/test_async_tx.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ class TestAsyncTx(tb.AsyncQueryTestCase):
};
'''

TEARDOWN_METHOD = '''
DELETE test::TransactionTest;
'''

TEARDOWN = '''
DROP TYPE test::TransactionTest;
'''
Expand Down Expand Up @@ -104,3 +108,54 @@ async def test_async_transaction_exclusive(self):
):
await asyncio.wait_for(f1, timeout=5)
await asyncio.wait_for(f2, timeout=5)

async def test_async_transaction_savepoint_1(self):
async for tx in self.client.transaction():
async with tx:
sp1 = await tx.declare_savepoint("sp1")
sp2 = await tx.declare_savepoint("sp2")
with self.assertRaisesRegex(
edgedb.InterfaceError, "savepoint.*already exists"
):
await tx.declare_savepoint("sp1")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is contrary to server savepoint semantics, masking an earlier savepoint is allowed:

edgedb> start transaction;
OK: START TRANSACTION
edgedb[tx]> declare savepoint sp1;
OK: DECLARE SAVEPOINT
edgedb[tx]> declare savepoint sp1;
OK: DECLARE SAVEPOINT

Copy link
Contributor

@tailhook tailhook May 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This releases previous savepoint with that name, right?
Not it doesn't. It creates a nested one with that name. And then you have to release twice.

So releasing the one that was created first should release inner, to have correct semantics.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think Paul is right.

await tx.execute('''
INSERT test::TransactionTest { name := '1' };
''')
await sp2.release()
with self.assertRaisesRegex(
edgedb.InterfaceError, "savepoint.*is no longer active"
):
await sp2.release()
await sp1.release()

result = await self.client.query('SELECT test::TransactionTest.name')

self.assertEqual(result, ["1"])

async def test_async_transaction_savepoint_2(self):
async for tx in self.client.transaction():
async with tx:
await tx.execute('''
INSERT test::TransactionTest { name := '1' };
''')
sp1 = await tx.declare_savepoint("sp1")
await tx.execute('''
INSERT test::TransactionTest { name := '2' };
''')
sp2 = await tx.declare_savepoint("sp2")
await tx.execute('''
INSERT test::TransactionTest { name := '3' };
''')
await sp1.rollback()
with self.assertRaisesRegex(
edgedb.InterfaceError, "savepoint.*is no longer active"
):
await sp1.rollback()
with self.assertRaisesRegex(
edgedb.InterfaceError, "savepoint.*is no longer active"
):
await sp2.rollback()

result = await self.client.query('SELECT test::TransactionTest.name')

self.assertEqual(result, ["1"])
55 changes: 55 additions & 0 deletions tests/test_sync_tx.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ class TestSyncTx(tb.SyncQueryTestCase):
};
'''

TEARDOWN_METHOD = '''
DELETE test::TransactionTest;
'''

TEARDOWN = '''
DROP TYPE test::TransactionTest;
'''
Expand Down Expand Up @@ -113,3 +117,54 @@ def test_sync_transaction_exclusive(self):
):
f1.result(timeout=5)
f2.result(timeout=5)

def test_sync_transaction_savepoint_1(self):
for tx in self.client.transaction():
with tx:
sp1 = tx.declare_savepoint("sp1")
sp2 = tx.declare_savepoint("sp2")
with self.assertRaisesRegex(
edgedb.InterfaceError, "savepoint.*already exists"
):
tx.declare_savepoint("sp1")
tx.execute('''
INSERT test::TransactionTest { name := '1' };
''')
sp2.release()
with self.assertRaisesRegex(
edgedb.InterfaceError, "savepoint.*is no longer active"
):
sp2.release()
sp1.release()

result = self.client.query('SELECT test::TransactionTest.name')

self.assertEqual(result, ["1"])

def test_sync_transaction_savepoint_2(self):
for tx in self.client.transaction():
with tx:
tx.execute('''
INSERT test::TransactionTest { name := '1' };
''')
sp1 = tx.declare_savepoint("sp1")
tx.execute('''
INSERT test::TransactionTest { name := '2' };
''')
sp2 = tx.declare_savepoint("sp2")
tx.execute('''
INSERT test::TransactionTest { name := '3' };
''')
sp1.rollback()
with self.assertRaisesRegex(
edgedb.InterfaceError, "savepoint.*is no longer active"
):
sp1.rollback()
with self.assertRaisesRegex(
edgedb.InterfaceError, "savepoint.*is no longer active"
):
sp2.rollback()

result = self.client.query('SELECT test::TransactionTest.name')

self.assertEqual(result, ["1"])