feat: export redis and mongo variables (#29)
* feat: export redis and mongo variables

* fix: move str_to_bool to utils
levisingularity authored Nov 13, 2023
1 parent c4aa371 commit fb71f58
Showing 3 changed files with 86 additions and 93 deletions.
5 changes: 5 additions & 0 deletions README.md
@@ -34,8 +34,13 @@ DAS_MONGODB_HOSTNAME=172.17.0.2
DAS_MONGODB_PORT=27017
DAS_MONGODB_USERNAME=mongo
DAS_MONGODB_PASSWORD=mongo
DAS_MONGODB_TLS_CA_FILE=global-bundle.pem [optional]
DAS_REDIS_HOSTNAME=127.0.0.1
DAS_REDIS_PORT=6379
DAS_USE_REDIS_CLUSTER=false [default: true]
DAS_USE_CACHED_NODES=false [default: true]
DAS_USE_CACHED_LINK_TYPES=false [default: true]
DAS_USE_CACHED_NODE_TYPES=false [default: true]
```

**2.2 or you can export the necessary environment variables using the environment file**
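The bracketed `[optional]` and `[default: true]` annotations above are documentation only; they are not part of the variable's value. As a minimal illustration (not part of the committed README), the adapter reads these settings straight from the process environment, and an unset boolean flag falls back to its documented default of true:

```python
import os

# Illustrative reads of the new settings; the names match the README above.
mongo_tls_ca = os.environ.get("DAS_MONGODB_TLS_CA_FILE")   # optional; TLS is used only when set
use_cluster = os.environ.get("DAS_USE_REDIS_CLUSTER")      # only the string "false" disables cluster mode
use_cached_nodes = os.environ.get("DAS_USE_CACHED_NODES")  # unset -> treated as true (see str_to_bool below)
```
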
169 changes: 76 additions & 93 deletions hyperon_das_atomdb/adapters/redis_mongo_db.py
@@ -27,10 +27,11 @@
)
from hyperon_das_atomdb.logger import logger
from hyperon_das_atomdb.utils.expression_hasher import ExpressionHasher
from hyperon_das_atomdb.utils.parse import str_to_bool

USE_CACHED_NODES = True
USE_CACHED_LINK_TYPES = True
USE_CACHED_NODE_TYPES = True
USE_CACHED_NODES = str_to_bool(os.environ.get("DAS_USE_CACHED_NODES"))
USE_CACHED_LINK_TYPES = str_to_bool(os.environ.get("DAS_USE_CACHED_LINK_TYPES"))
USE_CACHED_NODE_TYPES = str_to_bool(os.environ.get("DAS_USE_CACHED_NODE_TYPES"))


class NodeDocuments:
@@ -73,23 +74,17 @@ class RedisMongoDB(IAtomDB):
def __repr__(self) -> str:
return "<Atom database RedisMongo>" # pragma no cover

def __init__(self, database_name: str = 'das') -> None:
def __init__(self, database_name: str = "das") -> None:
"""
Initialize an instance of a custom class with Redis
and MongoDB connections.
"""
self.database_name = database_name
self._setup_databases()
self.mongo_link_collection = {
'1': self.mongo_db.get_collection(
MongoCollectionNames.LINKS_ARITY_1
),
'2': self.mongo_db.get_collection(
MongoCollectionNames.LINKS_ARITY_2
),
'N': self.mongo_db.get_collection(
MongoCollectionNames.LINKS_ARITY_N
),
"1": self.mongo_db.get_collection(MongoCollectionNames.LINKS_ARITY_1),
"2": self.mongo_db.get_collection(MongoCollectionNames.LINKS_ARITY_2),
"N": self.mongo_db.get_collection(MongoCollectionNames.LINKS_ARITY_N),
}
self.mongo_nodes_collection = self.mongo_db.get_collection(
MongoCollectionNames.NODES
@@ -126,44 +121,56 @@ def _setup_databases(self) -> None:
self.redis = self._connection_redis()

def _connection_mongo_db(self) -> Database:
mongo_hostname = os.environ.get('DAS_MONGODB_HOSTNAME')
mongo_port = os.environ.get('DAS_MONGODB_PORT')
mongo_username = os.environ.get('DAS_MONGODB_USERNAME')
mongo_password = os.environ.get('DAS_MONGODB_PASSWORD')
mongo_hostname = os.environ.get("DAS_MONGODB_HOSTNAME")
mongo_port = os.environ.get("DAS_MONGODB_PORT")
mongo_username = os.environ.get("DAS_MONGODB_USERNAME")
mongo_password = os.environ.get("DAS_MONGODB_PASSWORD")
mongo_tls_ca_file = os.environ.get("DAS_MONGODB_TLS_CA_FILE")

logger().info(
f"Connecting to MongoDB at {mongo_hostname}:{mongo_port}"
)
logger().info(f"Connecting to MongoDB at {mongo_hostname}:{mongo_port}")

try:
self.mongo_db = MongoClient(
f'mongodb://{mongo_username}:{mongo_password}@{mongo_hostname}:{mongo_port}'
)[self.database_name]
if mongo_tls_ca_file:
self.mongo_db = MongoClient(
f"mongodb://{mongo_username}:{mongo_password}@{mongo_hostname}:{mongo_port}?tls=true&tlsCAFile={mongo_tls_ca_file}&replicaSet=rs0&readPreference=secondaryPreferred&retryWrites=false"
)[self.database_name] # aws
else:
self.mongo_db = MongoClient(
f"mongodb://{mongo_username}:{mongo_password}@{mongo_hostname}:{mongo_port}"
)[self.database_name]
return self.mongo_db
except ValueError as e:
raise ConnectionMongoDBException(
message='error creating a MongoClient', details=str(e)
message="error creating a MongoClient", details=str(e)
)

def _connection_redis(self) -> Redis:
redis_hostname = os.environ.get('DAS_REDIS_HOSTNAME')
redis_port = os.environ.get('DAS_REDIS_PORT')
redis_hostname = os.environ.get("DAS_REDIS_HOSTNAME")
redis_port = os.environ.get("DAS_REDIS_PORT")
redis_password = os.environ.get("DAS_REDIS_PASSWORD")
redis_username = os.environ.get("DAS_REDIS_USERNAME")
redis_cluster = str_to_bool(os.environ.get("DAS_USE_REDIS_CLUSTER"))

redis_connection = {
"host": redis_hostname,
"port": redis_port,
"decode_responses": False,
}

if redis_port == "7000":
if redis_password and redis_username:
redis_connection["password"] = redis_password
redis_connection["username"] = redis_username

if redis_cluster:
logger().info(
f"Connecting to Redis cluster at {redis_hostname}:{redis_port}"
)
self.redis = RedisCluster(
host=redis_hostname, port=redis_port, decode_responses=False
)
self.redis = RedisCluster(**redis_connection)
else:
logger().info(
"Connecting to standalone Redis at "
f"{redis_hostname}:{redis_port}"
)
self.redis = Redis(
host=redis_hostname, port=redis_port, decode_responses=False
"Connecting to standalone Redis at " f"{redis_hostname}:{redis_port}"
)
self.redis = Redis(**redis_connection)

return self.redis
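
For quick reference, the connection behaviour introduced in `_connection_mongo_db` and `_connection_redis` can be summarized as follows. This is a minimal sketch under the same environment variables, not the committed implementation; it only assembles the connection parameters and does not import pymongo or redis or open a connection.

```python
import os


def build_mongo_uri() -> str:
    """Sketch of the URI assembled in _connection_mongo_db: the TLS query
    parameters are appended only when DAS_MONGODB_TLS_CA_FILE is set."""
    user = os.environ.get("DAS_MONGODB_USERNAME")
    password = os.environ.get("DAS_MONGODB_PASSWORD")
    host = os.environ.get("DAS_MONGODB_HOSTNAME")
    port = os.environ.get("DAS_MONGODB_PORT")
    ca_file = os.environ.get("DAS_MONGODB_TLS_CA_FILE")
    uri = f"mongodb://{user}:{password}@{host}:{port}"
    if ca_file:
        uri += (
            f"?tls=true&tlsCAFile={ca_file}&replicaSet=rs0"
            "&readPreference=secondaryPreferred&retryWrites=false"
        )
    return uri


def describe_redis_connection() -> tuple:
    """Sketch of the kwargs built in _connection_redis and of the
    cluster/standalone choice driven by DAS_USE_REDIS_CLUSTER."""
    connection = {
        "host": os.environ.get("DAS_REDIS_HOSTNAME"),
        "port": os.environ.get("DAS_REDIS_PORT"),
        "decode_responses": False,
    }
    username = os.environ.get("DAS_REDIS_USERNAME")
    password = os.environ.get("DAS_REDIS_PASSWORD")
    if username and password:
        # Credentials are forwarded only when both are provided.
        connection["username"] = username
        connection["password"] = password
    raw = os.environ.get("DAS_USE_REDIS_CLUSTER")
    use_cluster = True if raw is None else raw.lower() != "false"
    return ("RedisCluster" if use_cluster else "Redis", connection)
```

The committed methods pass the same kwargs to `Redis`/`RedisCluster` and hand the equivalent URI to `MongoClient`.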

@@ -194,20 +201,18 @@ def prefetch(self) -> None:
document[MongoFieldNames.NODE_NAME]
self.node_documents.add(node_id, document)
else:
self.node_documents.count = (
self.mongo_nodes_collection.count_documents({})
)
self.node_documents.count = self.mongo_nodes_collection.count_documents({})
if USE_CACHED_LINK_TYPES:
for tag in ["1", "2", "N"]:
for document in self.mongo_link_collection[tag].find():
self.link_type_cache[
document[MongoFieldNames.ID_HASH]
] = document[MongoFieldNames.TYPE_NAME]
self.link_type_cache[document[MongoFieldNames.ID_HASH]] = document[
MongoFieldNames.TYPE_NAME
]
if USE_CACHED_NODE_TYPES:
for document in self.mongo_nodes_collection.find():
self.node_type_cache[
document[MongoFieldNames.ID_HASH]
] = document[MongoFieldNames.TYPE_NAME]
self.node_type_cache[document[MongoFieldNames.ID_HASH]] = document[
MongoFieldNames.TYPE_NAME
]
for document in self.mongo_types_collection.find():
hash_id = document[MongoFieldNames.ID_HASH]
named_type = document[MongoFieldNames.TYPE_NAME]
@@ -219,9 +224,7 @@ def prefetch(self) -> None:
self.named_type_hash[named_type] = named_type_hash
self.named_type_hash_reverse[named_type_hash] = named_type
if type_document is not None:
self.named_types[named_type] = type_document[
MongoFieldNames.TYPE_NAME
]
self.named_types[named_type] = type_document[MongoFieldNames.TYPE_NAME]
self.parent_type[named_type_hash] = type_document[
MongoFieldNames.TYPE_NAME_HASH
]
@@ -233,16 +236,14 @@ def _retrieve_mongo_document(self, handle: str, arity=-1) -> dict:
if arity == 0:
return self.mongo_nodes_collection.find_one(mongo_filter)
elif arity == 2:
return self.mongo_link_collection['2'].find_one(mongo_filter)
return self.mongo_link_collection["2"].find_one(mongo_filter)
elif arity == 1:
return self.mongo_link_collection['1'].find_one(mongo_filter)
return self.mongo_link_collection["1"].find_one(mongo_filter)
else:
return self.mongo_link_collection['N'].find_one(mongo_filter)
return self.mongo_link_collection["N"].find_one(mongo_filter)
# The order of keys in search is important. Greater to smallest
# probability of proper arity
for collection in [
self.mongo_link_collection[key] for key in ['2', '1', 'N']
]:
for collection in [self.mongo_link_collection[key] for key in ["2", "1", "N"]]:
document = collection.find_one(mongo_filter)
if document:
return document
@@ -267,9 +268,7 @@ def _build_named_type_hash_template(
answer.append(v)
return answer

def _build_named_type_template(
self, template: Union[str, List[Any]]
) -> List[Any]:
def _build_named_type_template(self, template: Union[str, List[Any]]) -> List[Any]:
if isinstance(template, str):
ret = self.named_type_hash_reverse.get(template, None)
return ret
@@ -287,9 +286,7 @@ def _get_mongo_document_keys(self, document: Dict) -> List[str]:
answer = []
index = 0
while True:
key = document.get(
f'{MongoFieldNames.KEY_PREFIX.value}_{index}', None
)
key = document.get(f"{MongoFieldNames.KEY_PREFIX.value}_{index}", None)
if key is None:
return answer
else:
@@ -304,9 +301,7 @@ def _build_deep_representation(self, handle, arity=-1):
answer["type"] = document[MongoFieldNames.TYPE_NAME]
answer["targets"] = []
for target_handle in self._get_mongo_document_keys(document):
answer["targets"].append(
self._build_deep_representation(target_handle)
)
answer["targets"].append(self._build_deep_representation(target_handle))
else:
answer["type"] = document[MongoFieldNames.TYPE_NAME]
answer["name"] = document[MongoFieldNames.NODE_NAME]
@@ -315,9 +310,7 @@ def _create_node_handle(self, node_type: str, node_name: str) -> str:
def _create_node_handle(self, node_type: str, node_name: str) -> str:
return ExpressionHasher.terminal_hash(node_type, node_name)

def _create_link_handle(
self, link_type: str, target_handles: List[str]
) -> str:
def _create_link_handle(self, link_type: str, target_handles: List[str]) -> str:
return ExpressionHasher.expression_hash(
self._get_atom_type_hash(link_type), target_handles
)
@@ -342,26 +335,22 @@ def get_node_handle(self, node_type: str, node_name: str) -> str:
node_handle = self._create_node_handle(node_type, node_name)
document = self._retrieve_mongo_document(node_handle, 0)
if document is not None:
return document['_id']
return document["_id"]
else:
raise NodeDoesNotExistException(
message='This node does not exist',
details=f'{node_type}:{node_name}',
message="This node does not exist",
details=f"{node_type}:{node_name}",
)

def get_link_handle(
self, link_type: str, target_handles: List[str]
) -> str:
def get_link_handle(self, link_type: str, target_handles: List[str]) -> str:
link_handle = self._create_link_handle(link_type, target_handles)
document = self._retrieve_mongo_document(
link_handle, len(target_handles)
)
document = self._retrieve_mongo_document(link_handle, len(target_handles))
if document is not None:
return document['_id']
return document["_id"]
else:
raise LinkDoesNotExistException(
message='This link does not exist',
details=f'{link_type}:{target_handles}',
message="This link does not exist",
details=f"{link_type}:{target_handles}",
)

def get_link_targets(self, link_handle: str) -> List[str]:
@@ -373,7 +362,7 @@ def get_link_targets(self, link_handle: str) -> List[str]:
def is_ordered(self, link_handle: str) -> bool:
document = self._retrieve_mongo_document(link_handle)
if document is None:
raise ValueError(f'Invalid handle: {link_handle}')
raise ValueError(f"Invalid handle: {link_handle}")
return True

def get_matched_links(
@@ -407,20 +396,18 @@ def get_matched_links(
[link_type_hash, *target_handles]
)

patterns_matched = self._retrieve_key_value(
KeyPrefix.PATTERNS, pattern_hash
)
patterns_matched = self._retrieve_key_value(KeyPrefix.PATTERNS, pattern_hash)

if len(patterns_matched) > 0:
if extra_parameters and extra_parameters.get('toplevel_only'):
if extra_parameters and extra_parameters.get("toplevel_only"):
return self._filter_non_toplevel(patterns_matched)

return patterns_matched

def get_all_nodes(self, node_type: str, names: bool = False) -> List[str]:
node_type_hash = self._get_atom_type_hash(node_type)
if node_type_hash is None:
raise ValueError(f'Invalid node type: {node_type}')
raise ValueError(f"Invalid node type: {node_type}")
if names:
return [
document[MongoFieldNames.NODE_NAME]
@@ -446,7 +433,7 @@ def get_matched_type_template(
KeyPrefix.TEMPLATES, template_hash
)
if len(templates_matched) > 0:
if extra_parameters and extra_parameters.get('toplevel_only'):
if extra_parameters and extra_parameters.get("toplevel_only"):
return self._filter_non_toplevel(templates_matched)
return templates_matched
except Exception as exception:
@@ -460,14 +447,12 @@ def get_matched_type(
KeyPrefix.TEMPLATES, named_type_hash
)
if len(templates_matched) > 0:
if extra_parameters and extra_parameters.get('toplevel_only'):
if extra_parameters and extra_parameters.get("toplevel_only"):
return self._filter_non_toplevel(templates_matched)
return templates_matched

def get_node_name(self, node_handle: str) -> str:
answer = self._retrieve_key_value(
KeyPrefix.NAMED_ENTITIES, node_handle
)
answer = self._retrieve_key_value(KeyPrefix.NAMED_ENTITIES, node_handle)
if not answer:
raise ValueError(f"Invalid handle: {node_handle}")
return answer[0].decode()
@@ -476,7 +461,7 @@ def get_matched_node_name(self, node_type: str, substring: str) -> str:
node_type_hash = self._get_atom_type_hash(node_type)
mongo_filter = {
MongoFieldNames.TYPE: node_type_hash,
MongoFieldNames.NODE_NAME: {'$regex': substring},
MongoFieldNames.NODE_NAME: {"$regex": substring},
}
return [
document[MongoFieldNames.ID_HASH]
@@ -485,9 +470,7 @@ def get_atom_as_dict(self, handle, arity=-1) -> dict:

def get_atom_as_dict(self, handle, arity=-1) -> dict:
answer = {}
document = (
self.node_documents.get(handle, None) if arity <= 0 else None
)
document = self.node_documents.get(handle, None) if arity <= 0 else None
if document is None:
document = self._retrieve_mongo_document(handle, arity)
if document:
@@ -549,6 +532,6 @@ def _filter_non_toplevel(self, matches: list) -> list:
for match in matches:
link_handle = match[0]
link = self._retrieve_mongo_document(link_handle, len(match[-1]))
if link['is_toplevel']:
if link["is_toplevel"]:
matches_toplevel_only.append(match)
return matches_toplevel_only
5 changes: 5 additions & 0 deletions hyperon_das_atomdb/utils/parse.py
@@ -0,0 +1,5 @@
def str_to_bool(value: str) -> bool:
if value is None:
return True

return False if value.lower() == "false" else True
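
Illustrative calls (not part of the commit) showing the semantics this helper gives the `DAS_USE_*` flags: an unset variable defaults to `True`, and only the literal string `"false"`, in any casing, turns a flag off.

```python
assert str_to_bool(None) is True       # variable not set -> default enabled
assert str_to_bool("false") is False   # explicit opt-out
assert str_to_bool("False") is False   # comparison is case-insensitive
assert str_to_bool("true") is True
assert str_to_bool("0") is True        # anything other than "false" stays enabled
```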
