diff --git a/src/db_models.py b/src/db_models.py index 631b0f7..016bc41 100644 --- a/src/db_models.py +++ b/src/db_models.py @@ -24,6 +24,7 @@ class Entity(Model): phone_number_hash = CharField() password_hash = CharField() country_code = CharField() + device_id = CharField(null=True) client_publish_pub_key = TextField(null=True) client_device_id_pub_key = TextField(null=True) server_crypto_metadata = TextField(null=True) diff --git a/src/device_id.py b/src/device_id.py index 7b91a77..cd75493 100644 --- a/src/device_id.py +++ b/src/device_id.py @@ -11,7 +11,7 @@ def compute_device_id(secret_key, phone_number, public_key) -> str: Compute a device ID using HMAC and SHA-256. Args: - secret_key (str): The secret key used for HMAC. + secret_key (bytes): The secret key used for HMAC. phone_number (str): The phone number to be included in the HMAC input. public_key (str): The public key to be included in the HMAC input. @@ -19,5 +19,5 @@ def compute_device_id(secret_key, phone_number, public_key) -> str: str: The hexadecimal representation of the HMAC digest. """ combined_input = phone_number + public_key - hmac_object = hmac.new(secret_key.encode(), combined_input.encode(), hashlib.sha256) + hmac_object = hmac.new(secret_key, combined_input.encode(), hashlib.sha256) return hmac_object.hexdigest() diff --git a/src/grpc_entity_service.py b/src/grpc_entity_service.py index 2ace104..eb28950 100644 --- a/src/grpc_entity_service.py +++ b/src/grpc_entity_service.py @@ -26,6 +26,7 @@ decrypt_and_decode, ) from src.long_lived_token import generate_llt, verify_llt +from src.device_id import compute_device_id HASHING_KEY = load_key(get_configs("HASHING_SALT"), 32) KEYSTORE_PATH = get_configs("KEYSTORE_PATH") @@ -216,6 +217,9 @@ def complete_creation(request): "phone_number_hash": phone_number_hash, "password_hash": password_hash, "country_code": country_code_ciphertext_b64, + "device_id": compute_device_id( + shared_key, request.phone_number, request.client_device_id_pub_key + ), "client_publish_pub_key": request.client_publish_pub_key, "client_device_id_pub_key": request.client_device_id_pub_key, "server_crypto_metadata": crypto_metadata_ciphertext_b64, @@ -346,6 +350,8 @@ def initiate_authentication(request, entity_obj): return pow_response message, expires = pow_response + entity_obj.device_id = None + entity_obj.save() return response( requires_ownership_proof=True, @@ -407,6 +413,9 @@ def complete_authentication(request, entity_obj): long_lived_token = generate_llt(eid, shared_key) + entity_obj.device_id = compute_device_id( + shared_key, request.phone_number, request.client_device_id_pub_key + ) entity_obj.client_publish_pub_key = request.client_publish_pub_key entity_obj.client_device_id_pub_key = request.client_device_id_pub_key entity_obj.server_crypto_metadata = crypto_metadata_ciphertext_b64 diff --git a/src/otp_service.py b/src/otp_service.py index 9ffe8b6..760f572 100644 --- a/src/otp_service.py +++ b/src/otp_service.py @@ -14,7 +14,7 @@ TWILIO_AUTH_TOKEN = get_configs("TWILIO_AUTH_TOKEN") TWILIO_SERVICE_SID = get_configs("TWILIO_SERVICE_SID") MOCK_OTP = get_configs("MOCK_OTP") -MOCK_OTP = True if MOCK_OTP and MOCK_OTP.lower() == "true" else False +MOCK_OTP = MOCK_OTP.lower() == "true" if MOCK_OTP is not None else False RATE_LIMIT_WINDOWS = [ {"duration": 2, "count": 1}, # 2 minute window diff --git a/src/tokens.py b/src/tokens.py index 37eb726..94794e7 100644 --- a/src/tokens.py +++ b/src/tokens.py @@ -102,7 +102,6 @@ def remove_none_values(values): {"a": 1, "b": 2, "c": None} ] filtered_values = remove_none_values(values) - print(filtered_values) # Output: [{'a': 1, 'c': 3}, {'b': 2, 'c': 3}, {'a': 1, 'b': 2}] """ return [{k: v for k, v in value.items() if v is not None} for value in values] diff --git a/src/utils.py b/src/utils.py index 19e5464..da1032c 100644 --- a/src/utils.py +++ b/src/utils.py @@ -162,27 +162,28 @@ def get_configs(config_name: str, strict: bool = False) -> str: raise -def set_configs(config_name: str, config_value: str) -> None: +def set_configs(config_name, config_value) -> None: """ Sets the value of a configuration in the environment variables. Args: config_name (str): The name of the configuration to set. - config_value (str): The value of the configuration to set. + config_value (str or bool): The value of the configuration to set. Raises: - ValueError: If config_name or config_value is empty. + ValueError: If config_name is empty. """ - if not config_name or not config_value: + if not config_name: error_message = ( - f"Cannot set configuration. Invalid config_name '{config_name}' ", - "or config_value '{config_value}'.", + f"Cannot set configuration. Invalid config_name '{config_name}'." ) logger.error(error_message) raise ValueError(error_message) try: - os.environ[config_name] = config_value + if isinstance(config_value, bool): + config_value = str(config_value).lower() + os.environ[config_name] = str(config_value) except Exception as error: logger.error("Failed to set configuration '%s': %s", config_name, error) raise diff --git a/tests/test_entity.py b/tests/test_entity.py index ee8ba79..4f430d8 100644 --- a/tests/test_entity.py +++ b/tests/test_entity.py @@ -1,8 +1,10 @@ """Test module for entity controller functions.""" +import base64 import pytest from peewee import SqliteDatabase from src.utils import create_tables, set_configs, generate_eid +from src.device_id import compute_device_id @pytest.fixture() @@ -54,16 +56,15 @@ def test_create_entity_additional_fields(): eid = generate_eid(phone_number_hash) password_hash = "password_hash2" country_code = "CM" - publish_pub_key = "-----BEGIN PUBLIC KEY-----\n1234\n-----END PUBLIC KEY-----" - device_id_pub_key = ( - "-----BEGIN DEVICE PUBLIC KEY-----\n1234\n-----END DEVICE PUBLIC KEY-----" - ) - + publish_pub_key = base64.b64encode(b"\x82" * 32).decode("utf-8") + device_id_pub_key = base64.b64encode(b"\x82" * 32).decode("utf-8") + device_id = compute_device_id(b"\x82" * 32, "+237123456789", publish_pub_key) entity = create_entity( eid, phone_number_hash, password_hash, country_code, + device_id=device_id, client_publish_pub_key=publish_pub_key, client_device_id_pub_key=device_id_pub_key, ) @@ -73,6 +74,7 @@ def test_create_entity_additional_fields(): assert entity.phone_number_hash == phone_number_hash assert entity.password_hash == password_hash assert entity.country_code == country_code + assert entity.device_id == device_id assert entity.client_publish_pub_key == publish_pub_key assert entity.client_device_id_pub_key == device_id_pub_key diff --git a/tests/test_grpc_entity_service.py b/tests/test_grpc_entity_service.py index 4d2b376..cc46211 100644 --- a/tests/test_grpc_entity_service.py +++ b/tests/test_grpc_entity_service.py @@ -37,6 +37,7 @@ def configure_test_environment(session_temp_dir): the application mode to testing. """ set_configs("MODE", "testing") + set_configs("MOCK_OTP", True) set_configs("KEYSTORE_PATH", str(session_temp_dir)) hashing_key_path = session_temp_dir / "hash.key" diff --git a/tests/test_otp_service.py b/tests/test_otp_service.py index 6d36571..cec2508 100644 --- a/tests/test_otp_service.py +++ b/tests/test_otp_service.py @@ -10,6 +10,7 @@ def set_testing_mode(): """Set the application mode to testing.""" set_configs("MODE", "testing") + set_configs("MOCK_OTP", True) @pytest.fixture(autouse=True)