Skip to content

Commit

Permalink
Merge pull request #45 from delta-mpc/registry
Browse files Browse the repository at this point in the history
Registry
  • Loading branch information
mh739025250 authored Nov 1, 2022
2 parents 0c75d89 + 12f588f commit d7e75e8
Show file tree
Hide file tree
Showing 8 changed files with 110 additions and 68 deletions.
2 changes: 1 addition & 1 deletion delta_node/chain/identity/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ async def update_name(self, address: str, name: str) -> str:
_logger.error(e)
raise

async def updaet_url(self, address: str, url: str) -> str:
async def update_url(self, address: str, url: str) -> str:
req = pb.UpdateUrlReq(address=address, url=url)
try:
resp = await self.stub.UpdateUrl(req)
Expand Down
18 changes: 12 additions & 6 deletions delta_node/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,21 @@ async def _run():
await db.init(config.db)
chain.init(config.chain_host, config.chain_port, ssl=False)
zk.init(config.zk_host, config.zk_port, ssl=False)
await registry.register(config.node_url, config.node_name)

r = registry.Registry(url=config.node_url, name=config.node_name)
await r.register()

registry_fut = asyncio.create_task(r.start())
runner_fut = asyncio.create_task(runner.run())
app_fut = asyncio.create_task(app.run("0.0.0.0", config.api_port))

fut = asyncio.gather(runner_fut, app_fut)
fut = asyncio.gather(registry_fut, runner_fut, app_fut)
loop.add_signal_handler(signal.SIGINT, lambda: fut.cancel())
loop.add_signal_handler(signal.SIGTERM, lambda: fut.cancel())
try:
await fut
finally:
await registry.unregister()
await r.stop()
await r.unregister()
chain.close()
zk.close()
await db.close()
Expand All @@ -52,7 +55,7 @@ def run():


async def _leave():
from delta_node import chain, config, db, log, pool, registry
from delta_node import chain, config, db, log, registry

if len(config.chain_host) == 0:
raise RuntimeError("chain connector host is required")
Expand All @@ -69,7 +72,10 @@ async def _leave():

await db.init(config.db)
chain.init(config.chain_host, config.chain_port, ssl=False)
await registry.unregister()

r = registry.Registry(url=config.node_url, name=config.node_name)
await r.unregister()

chain.close()
await db.close()
listener.stop()
Expand Down
113 changes: 69 additions & 44 deletions delta_node/registry/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from typing import Optional

import sqlalchemy as sa
from sqlalchemy.exc import NoResultFound
from async_lru import alru_cache
from delta_node import config, db
from delta_node.entity.identity import Node
from delta_node.chain import identity
from delta_node.entity.identity import Node
from sqlalchemy.exc import NoResultFound

__all__ = ["register", "get_node_address", "unregister"]
__all__ = ["get_node_address", "Registry"]


_logger = logging.getLogger(__name__)
Expand All @@ -27,52 +27,77 @@ async def get_node_address() -> str:
raise


async def register(
url: str = config.node_url,
name: str = config.node_name,
):
async with db.session_scope() as sess:
q = sa.select(Node).where(Node.id == 1)
node: Optional[Node] = (await sess.execute(q)).scalars().one_or_none()

if node:
# join first to avoid address changed when connect to monkey chain connector
_, address = await identity.get_client().join(url, name)
updated = False
if node.address != address:
node.address = address
updated = True
if node.url != url:
await identity.get_client().updaet_url(node.address, url)
node.url = url
updated = True
if node.name != name:
await identity.get_client().update_name(node.address, name)
node.name = name
updated = True
if updated:
class Registry(object):
def __init__(
self, url: str = config.node_url, name: str = config.node_name
) -> None:
self.url = url
self.name = name

self.running_task: Optional[asyncio.Task] = None

async def register(self):
_, address = await identity.get_client().join(self.url, self.name)

async with db.session_scope() as sess:
q = sa.select(Node).where(Node.id == 1)
node: Optional[Node] = (await sess.execute(q)).scalars().one_or_none()

if node is not None:
update = False
if node.address != address:
node.address = address
update = True
if node.url != self.url:
node.url = self.url
update = True
if node.name != self.name:
node.name = self.name
update = True
if update:
sess.add(node)
await sess.commit()
_logger.info(f"register new node, node address: {address}")
else:
_logger.info(f"registered node, node address: {address}")
else:
node = Node(url=self.url, name=self.name, address=address)
sess.add(node)
await sess.commit()
_logger.info(f"registered node, node address: {node.address}")
_logger.info(f"register new node, node address: {address}")

else:
_, address = await identity.get_client().join(url, name)
node = Node(url=url, name=name, address=address)
sess.add(node)
await sess.commit()
await sess.refresh(node)
_logger.info(f"register new node, node address: {node.address}")
async def unregister(self):
address = await get_node_address()
await identity.get_client().leave(address)

async with db.session_scope() as sess:
q = sa.select(Node).where(Node.id == 1)
node = (await sess.execute(q)).scalar_one()

await sess.delete(node)
await sess.commit()

async def unregister():
address = await get_node_address()
await identity.get_client().leave(address)
_logger.info(f"node {address} leave")

async with db.session_scope() as sess:
q = sa.select(Node).where(Node.id == 1)
node = (await sess.execute(q)).scalar_one()
async def start(self, interval: int = 60):
async def run():
while True:
await asyncio.sleep(interval)
await identity.get_client().join(self.url, self.name)

await sess.delete(node)
await sess.commit()
if self.running_task is None:
self.running_task = asyncio.create_task(run())
try:
await self.running_task
except asyncio.CancelledError:
pass
except Exception as e:
_logger.exception(e)
raise
else:
raise ValueError("registry is already started")

_logger.info(f"node {address} leave")
async def stop(self):
if self.running_task is not None:
self.running_task.cancel()
_logger.info("stop registry")
14 changes: 7 additions & 7 deletions delta_node/utils/precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@


def fix_precision(arr: AggValueType, precision: int) -> AggValueType:
arr = arr.astype(np.float64)
arr = arr * (10**precision)
arr = arr.astype(np.int64)
return arr
_arr = arr.astype(np.float64)
_arr = _arr * (10**precision)
_arr = _arr.astype(np.int64)
return _arr


def unfix_precision(arr: AggValueType, precision: int) -> AggValueType:
arr = arr.astype(np.float64)
arr = arr / (10**precision)
return arr
_arr = arr.astype(np.float64)
_arr = _arr / (10**precision)
return _arr
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
aiosqlite==0.17.0
async_lru==1.0.2
cryptography==3.4.7
delta-task==0.8.0
delta-task==0.8.1
fastapi==0.70.1
grpclib==0.4.2
httpx==0.23.0
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def run_tests(self):

setup(
name="delta_node",
version="0.8.0",
version="0.8.1",
packages=find_packages(),
package_data={"delta_node": ["dataset/examples/*.csv"]},
include_package_data=True,
Expand All @@ -39,7 +39,7 @@ def run_tests(self):
"aiosqlite==0.17.0",
"async_lru==1.0.2",
"cryptography==3.4.7",
"delta-task==0.8.0",
"delta-task==0.8.1",
"fastapi==0.70.1",
"grpclib==0.4.2",
"httpx==0.23.0",
Expand Down
2 changes: 1 addition & 1 deletion tests/chain/identity_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ async def test_identity(identity_client: identity.Client):
assert info.url == url
# update url
new_url = "http://127.0.0.1:6800"
await identity_client.updaet_url(address, new_url)
await identity_client.update_url(address, new_url)
info = await identity_client.get_node_info(address)
assert info.url == new_url
url = new_url
Expand Down
23 changes: 17 additions & 6 deletions tests/registry_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio

import pytest
from delta_node import db, registry, chain
Expand All @@ -10,14 +11,24 @@ async def test_register():
chain.init("127.0.0.1", 4500)
url = "http://127.0.0.1:6800"
name = "node1"
await registry.register(url, name)

r = registry.Registry(url, name)
await r.register()

fut = asyncio.create_task(r.start(interval=1))

address = await registry.get_node_address()
info = await identity.get_client().get_node_info(address)
assert info.address == address
assert info.name == name
assert info.url == url

node_info = await identity.get_client().get_node_info(address=address)
assert node_info.address == address
assert node_info.name == name
assert node_info.url == url
try:
await asyncio.wait_for(fut, timeout=2)
except asyncio.TimeoutError:
pass

await registry.unregister()
await r.stop()
await r.unregister()
chain.close()
await db.close()

0 comments on commit d7e75e8

Please sign in to comment.