Skip to content

Commit

Permalink
Merge pull request #1324 from phenobarbital/new-drivers
Browse files Browse the repository at this point in the history
New drivers
  • Loading branch information
phenobarbital authored Oct 24, 2024
2 parents 9ae076c + 0d657ad commit ce18ef0
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 89 deletions.
102 changes: 41 additions & 61 deletions asyncdb/drivers/pg.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from collections.abc import Callable, Iterable
from typing import Any, Optional, Union
from dataclasses import is_dataclass
import contextlib
from datamodel import BaseModel
import asyncpg
from asyncpg.exceptions import (
Expand Down Expand Up @@ -918,11 +919,15 @@ async def transaction(self):

async def commit(self):
if self._transaction:
await self._transaction.commit()
try:
await self._transaction.commit()
finally:
self._transaction = None

async def rollback(self):
if self._transaction:
await self._transaction.rollback()
self._transaction = None

async def cursor(self, sentence: Union[str, any], params: Iterable[Any] = None, **kwargs): # pylint: disable=W0236
if not sentence:
Expand Down Expand Up @@ -967,6 +972,38 @@ async def __anext__(self):
raise StopAsyncIteration

## COPY Functions
@contextlib.asynccontextmanager
async def handle_copy_errors(self, operation_name: str):
try:
yield
except (
QueryCanceledError,
StatementError,
UniqueViolationError,
ForeignKeyViolationError,
NotNullViolationError
) as err:
self._logger.warning(
f"AsyncPg {operation_name}: {err}"
)
raise
except UndefinedTableError as ex:
raise StatementError(
f"Error {operation_name}, table doesn't exist: {ex}"
) from ex
except UndefinedColumnError as ex:
raise StatementError(
f"Error {operation_name}, Undefined Column: {ex}"
) from ex
except (InvalidSQLStatementNameError, PostgresSyntaxError) as ex:
raise StatementError(
f"Error {operation_name}: Invalid Statement: {ex}"
) from ex
except Exception as ex:
raise DriverError(
f"Error {operation_name}: {ex}"
) from ex

## type: [ text, csv, binary ]
async def copy_from_table(self, table="", schema="public", output=None, file_type="csv", columns=None):
"""table_copy
Expand All @@ -976,7 +1013,7 @@ async def copy_from_table(self, table="", schema="public", output=None, file_typ
"""
if not self._connection:
await self.connection()
try:
async with self.handle_copy_errors("Copy From Table"):
result = await self._connection.copy_from_table(
table_name=table,
schema_name=schema,
Expand All @@ -985,23 +1022,6 @@ async def copy_from_table(self, table="", schema="public", output=None, file_typ
output=output,
)
return result
except (
QueryCanceledError,
StatementError,
UniqueViolationError,
ForeignKeyViolationError,
NotNullViolationError
) as err:
self._logger.warning(
f"AsyncPg Copy From Table: {err}"
)
raise
except UndefinedTableError as ex:
raise StatementError(f"Error on Copy, Table {table }doesn't exists: {ex}") from ex
except (InvalidSQLStatementNameError, PostgresSyntaxError, UndefinedColumnError) as ex:
raise StatementError(f"Error on Copy, Invalid Statement Error: {ex}") from ex
except Exception as ex:
raise DriverError(f"Error on Table Copy: {ex}") from ex

async def copy_to_table(self, table="", schema="public", source=None, file_type="csv", columns=None):
"""copy_to_table
Expand All @@ -1014,7 +1034,7 @@ async def copy_to_table(self, table="", schema="public", source=None, file_type=
if self._transaction:
# a transaction exists:
await self._transaction.commit()
try:
async with self.handle_copy_errors("Copy To Table"):
result = await self._connection.copy_to_table(
table_name=table,
schema_name=schema,
Expand All @@ -1023,25 +1043,6 @@ async def copy_to_table(self, table="", schema="public", source=None, file_type=
source=source,
)
return result
except (
QueryCanceledError,
StatementError,
UniqueViolationError,
ForeignKeyViolationError,
NotNullViolationError
) as err:
self._logger.warning(
f"AsyncPg Copy To Table: {err}"
)
raise
except UndefinedTableError as ex:
raise StatementError(
f"Error on Copy to Table {table } doesn't exists: {ex}") from ex
except (InvalidSQLStatementNameError, PostgresSyntaxError, UndefinedColumnError) as ex:
raise StatementError(
f"Error on Copy, Invalid Statement Error: {ex}") from ex
except Exception as ex:
raise DriverError(f"Error on Copy to Table {ex}") from ex

async def copy_into_table(self, table="", schema="public", source=None, columns=None):
"""copy_into_table
Expand All @@ -1054,32 +1055,11 @@ async def copy_into_table(self, table="", schema="public", source=None, columns=
if self._transaction:
# a transaction exists:
await self._transaction.commit()
try:
async with self.handle_copy_errors("Copy Into Table"):
result = await self._connection.copy_records_to_table(
table_name=table, schema_name=schema, columns=columns, records=source
)
return result
except (
QueryCanceledError,
StatementError,
UniqueViolationError,
ForeignKeyViolationError,
NotNullViolationError
) as err:
self._logger.warning(
f"AsyncPg Copy Into Table: {err}"
)
raise
except UndefinedTableError as ex:
raise StatementError(f"Error on Copy to Table {table } doesn't exists: {ex}") from ex
except (InvalidSQLStatementNameError, PostgresSyntaxError, UndefinedColumnError) as ex:
raise StatementError(f"Error on Copy, Invalid Statement Error: {ex}") from ex
except InterfaceError as ex:
raise DriverError(f"Error on Copy into Table Function: {ex}") from ex
except (RuntimeError, PostgresError) as ex:
raise DriverError(f"Postgres Error on Copy into Table: {ex}") from ex
except Exception as ex:
raise DriverError(f"Error on Copy into Table: {ex}") from ex

## Model Logic:
async def column_info(self, tablename: str, schema: str = None):
Expand Down
14 changes: 11 additions & 3 deletions asyncdb/drivers/rethink.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,11 +645,19 @@ async def write(
if isinstance(data, list) and len(data) > batch_size:
# Handle batch insertion for large lists
for start in range(0, len(data), batch_size):
batch = data[start : start + batch_size]
self._logger.debug(
f"Rethink: Saving batch {start + 1} to {start + batch_size} of {len(data)} records"
)
batch = data[start:start + batch_size]
result = await self._batch_insert(table, batch, on_conflict, changes, durability)
if result["errors"] > 0:
raise DriverError(f"INSERT Error in batch: {result['first_error']}")
return {"inserted": len(data), "batches": (len(data) + batch_size - 1) // batch_size}
raise DriverError(
f"INSERT Error in batch: {result['first_error']}"
)
return {
"inserted": len(data),
"batches": (len(data) + batch_size - 1) // batch_size
}
else:
result = await self._batch_insert(table, data, on_conflict, changes, durability)
if result["errors"] > 0:
Expand Down
2 changes: 1 addition & 1 deletion asyncdb/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
__title__ = "asyncdb"
__description__ = "Library for Asynchronous data source connections \
Collection of asyncio drivers."
__version__ = "2.9.5"
__version__ = "2.9.6"
__author__ = "Jesus Lara"
__author_email__ = "jesuslarag@gmail.com"
__license__ = "BSD"
47 changes: 23 additions & 24 deletions examples/test_asyncdb.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-
from asyncdb import AsyncDB, AsyncPool
from asyncdb.meta import asyncORM
from asyncdb.exceptions import NoDataFound, ProviderError, StatementError

"""
Expand All @@ -26,7 +25,7 @@
loop = asyncio.get_event_loop()
asyncio.set_event_loop(loop)

asyncpg_url = "postgres://troc_pgdata:12345678@127.0.0.1:5432/navigator_dev"
asyncpg_url = "postgres://troc_pgdata:12345678@127.0.0.1:5432/navigator"

pool = AsyncPool("pg", dsn=asyncpg_url, loop=loop)
loop.run_until_complete(pool.connect())
Expand All @@ -44,7 +43,7 @@ def adb():
if pool.is_connected():
#db = asyncio.get_running_loop().run_until_complete(dbpool.acquire())
db = loop.run_until_complete(pool.acquire())
return asyncORM(db=db)
return db

def sharing_token(token):
db = adb()
Expand Down Expand Up @@ -97,26 +96,28 @@ async def connect(c):
result, error = await conn.execute("SET TIMEZONE TO 'America/New_York'")
await t.commit()
# table copy
await c.copy_from_table(
table="stores",
schema="walmart",
columns=["store_id", "store_name"],
output="stores.csv",
)
async with await c.transaction() as t:
await t.copy_from_table(
table="stores",
schema="walmart",
columns=["store_id", "store_name"],
output="stores.csv",
)
# copy from file to table
# TODO: repair error io.UnsupportedOperation: read
# await c.copy_to_table(table = 'stores', schema = 'test', columns = [ 'store_id', 'store_name'], source = '/home/jesuslara/proyectos/navigator-next/stores.csv')
# copy from asyncpg records
# try:
# await c.copy_into_table(
# table="stores",
# schema="test",
# columns=["store_id", "store_name"],
# source=stores,
# )
# except (StatementError, ProviderError) as err:
# print(str(err))
# return False
async with await t.transaction() as t:
await t.copy_to_table(table = 'stores', schema = 'test', columns = [ 'store_id', 'store_name'], source = '/home/jesuslara/proyectos/navigator-next/stores.csv')
# copy from asyncpg records
try:
await c.copy_into_table(
table="stores",
schema="test",
columns=["store_id", "store_name"],
source=stores,
)
except (StatementError, ProviderError) as err:
print(str(err))
return False


async def prepared(p):
Expand All @@ -128,11 +129,9 @@ async def prepared(p):

if __name__ == '__main__':
try:
a = sharing_token('67C1BEE8DDC0BB873930D04FAF16B338F8CB09490571F8901E534937D4EFA8EE33230C435BDA93B7C7CEBA67858C4F70321A0D92201947F13278F495F92DDC0BE5FDFCF0684704C78A3E7BA5133ACADBE2E238F25D568AEC4170EB7A0BE819CE8F758B890855E5445EB22BE52439FA377D00C9E4225BC6DAEDD2DAC084446E7F697BF1CEC129DFB84FA129B7B8881C66EEFD91A0869DAE5D71FD5055FCFF75')
print(a.columns())
# # test: first with db connected:
e = AsyncDB("pg", dsn=asyncpg_url, loop=loop)
loop.run_until_complete(connect(e))
# loop.run_until_complete(prepared(e))
finally:
pool.terminate()
loop.close()

0 comments on commit ce18ef0

Please sign in to comment.