Skip to content

Commit

Permalink
feat(grpc): add X25519 public key validation in gRPC entity service
Browse files Browse the repository at this point in the history
- Added validation for X25519 public keys in gRPC entity service requests.
- Implemented a function `is_valid_x25519_public_key` to check the validity of X25519 public keys encoded in base64.
- Updated `validate_request_fields` function to include validation for X25519 public keys.
- Added tests for entity creation and authentication with invalid X25519 public keys.

Fixes: #99
  • Loading branch information
PromiseFru committed Jun 8, 2024
1 parent 8209947 commit 6cadc0e
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 41 deletions.
82 changes: 51 additions & 31 deletions src/grpc_entity_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
generate_crypto_metadata,
generate_eid,
get_shared_key,
is_valid_x25519_public_key,
)
from src.long_lived_token import generate_llt

Expand Down Expand Up @@ -62,9 +63,9 @@ def error_response(context, sys_msg, status_code, user_msg=None, _type=None):
return vault_pb2.CreateEntityResponse()


def check_missing_fields(context, request, required_fields):
def validate_request_fields(context, request, required_fields):
"""
Check for missing fields in the gRPC request.
Validates the fields in the gRPC request.
Args:
context: gRPC context.
Expand All @@ -82,6 +83,25 @@ def check_missing_fields(context, request, required_fields):
f"Missing required fields: {', '.join(missing_fields)}",
grpc.StatusCode.INVALID_ARGUMENT,
)

x25519_fields = [
"client_publish_pub_key",
"client_device_id_pub_key",
]
invalid_fields = {}

for field in set(x25519_fields) & set(required_fields):
is_valid, error = is_valid_x25519_public_key(getattr(request, field))
if not is_valid:
invalid_fields[field] = error

if invalid_fields:
return error_response(
context,
f"Invalid fields: {invalid_fields}",
grpc.StatusCode.INVALID_ARGUMENT,
)

return None


Expand Down Expand Up @@ -169,6 +189,15 @@ def complete_creation(request):
)
)

shared_key = get_shared_key(
os.path.join(KEYSTORE_PATH, f"{eid}_device_id.db"),
server_device_id_keypair.pnt_keystore,
server_device_id_keypair.secret_key,
base64.b64decode(request.client_device_id_pub_key),
)

long_lived_token = generate_llt(eid, shared_key)

fields = {
"eid": eid,
"phone_number_hash": phone_number_hash,
Expand All @@ -181,15 +210,6 @@ def complete_creation(request):

create_entity(**fields)

shared_key = get_shared_key(
os.path.join(KEYSTORE_PATH, f"{eid}_device_id.db"),
server_device_id_keypair.pnt_keystore,
server_device_id_keypair.secret_key,
base64.b64decode(request.client_device_id_pub_key),
)

long_lived_token = generate_llt(eid, shared_key)

logger.info("Entity created successfully")

return vault_pb2.CreateEntityResponse(
Expand Down Expand Up @@ -226,11 +246,11 @@ def initiate_creation(request):
)

try:
missing_fields_response = check_missing_fields(
invalid_fields_response = validate_request_fields(
context, request, ["phone_number"]
)
if missing_fields_response:
return missing_fields_response
if invalid_fields_response:
return invalid_fields_response

phone_number_hash = generate_hmac(HASHING_KEY, request.phone_number)
entity_obj = find_entity(phone_number_hash=phone_number_hash)
Expand All @@ -249,11 +269,11 @@ def initiate_creation(request):
"client_publish_pub_key",
"client_device_id_pub_key",
]
missing_fields_response = check_missing_fields(
invalid_fields_response = validate_request_fields(
context, request, required_fields
)
if missing_fields_response:
return missing_fields_response
if invalid_fields_response:
return invalid_fields_response

return complete_creation(request)

Expand Down Expand Up @@ -282,11 +302,11 @@ def initiate_authentication(request, entity_obj):
Returns:
vault_pb2.AuthenticateEntityResponse: Authentication response.
"""
missing_fields_response = check_missing_fields(
invalid_fields_response = validate_request_fields(
context, request, ["password"]
)
if missing_fields_response:
return missing_fields_response
if invalid_fields_response:
return invalid_fields_response

if not verify_hmac(HASHING_KEY, request.password, entity_obj.password_hash):
return error_response(
Expand Down Expand Up @@ -322,16 +342,16 @@ def complete_authentication(request, entity_obj):
Returns:
vault_pb2.AuthenticateEntityResponse: Authentication response.
"""
missing_fields_response = check_missing_fields(
invalid_fields_response = validate_request_fields(
context,
request,
[
"client_publish_pub_key",
"client_device_id_pub_key",
],
)
if missing_fields_response:
return missing_fields_response
if invalid_fields_response:
return invalid_fields_response

success, response = handle_pow_verification(context, request)
if not success:
Expand All @@ -355,11 +375,6 @@ def complete_authentication(request, entity_obj):
)
)

entity_obj.client_publish_pub_key = request.client_publish_pub_key
entity_obj.client_device_id_pub_key = request.client_device_id_pub_key
entity_obj.server_crypto_metadata = crypto_metadata_ciphertext_b64
entity_obj.save()

shared_key = get_shared_key(
os.path.join(KEYSTORE_PATH, f"{eid}_device_id.db"),
server_device_id_keypair.pnt_keystore,
Expand All @@ -369,6 +384,11 @@ def complete_authentication(request, entity_obj):

long_lived_token = generate_llt(eid, shared_key)

entity_obj.client_publish_pub_key = request.client_publish_pub_key
entity_obj.client_device_id_pub_key = request.client_device_id_pub_key
entity_obj.server_crypto_metadata = crypto_metadata_ciphertext_b64
entity_obj.save()

return vault_pb2.AuthenticateEntityResponse(
long_lived_token=long_lived_token,
message="Entity authenticated successfully!",
Expand All @@ -381,11 +401,11 @@ def complete_authentication(request, entity_obj):
)

try:
missing_fields_response = check_missing_fields(
invalid_fields_response = validate_request_fields(
context, request, ["phone_number"]
)
if missing_fields_response:
return missing_fields_response
if invalid_fields_response:
return invalid_fields_response

phone_number_hash = generate_hmac(HASHING_KEY, request.phone_number)
entity_obj = find_entity(phone_number_hash=phone_number_hash)
Expand Down
28 changes: 28 additions & 0 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import base64
from functools import wraps

from cryptography.hazmat.primitives.asymmetric import x25519 as x25519_core
import mysql.connector
from peewee import DatabaseError
from smswithoutborders_libsig.keypairs import x25519
Expand Down Expand Up @@ -360,3 +361,30 @@ def convert_to_fernet_key(secret_key):
raise ValueError("Secret key must be 32 bytes long")

return base64.urlsafe_b64encode(secret_key)


def is_valid_x25519_public_key(encoded_key):
"""
Validates an X25519 public key encoded in base64.
Args:
encoded_key (bytes): The base64-encoded public key to validate.
Returns:
tuple[bool, str]: A tuple where the first element is a boolean i
ndicating whether the key is valid, and the second element is an
error message if the key is invalid, or None if the key
is valid.
"""
try:
decoded_key = base64.b64decode(encoded_key)
except (TypeError, ValueError) as err:
logger.exception("Base64 decoding error: %s", err)
return False, "Invalid base64 encoding"

try:
x25519_core.X25519PublicKey.from_public_bytes(decoded_key)
return True, None
except ValueError as err:
logger.exception("X25519 public key validation error: %s", err)
return False, f"Invalid X25519 public key: {err}"
89 changes: 79 additions & 10 deletions tests/test_grpc_entity_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,8 @@ def test_entity_complete_creation_invalid_proof(grpc_server_mock):
"ownership_proof_response": "12346",
"country_code": "CM",
"password": "password",
"client_publish_pub_key": "client_publish_pub_key",
"client_device_id_pub_key": "client_device_id_pub_key",
"client_publish_pub_key": base64.b64encode(b"\x82" * 32).decode("utf-8"),
"client_device_id_pub_key": base64.b64encode(b"\x82" * 32).decode("utf-8"),
}

request = vault_pb2.CreateEntityRequest(**request_data)
Expand Down Expand Up @@ -333,6 +333,36 @@ def test_entity_complete_creation_success(grpc_server_mock):
assert client_device_id_shared_key == server_device_id_shared_key


def test_entity_complete_creation_invalid_public_keys(grpc_server_mock):
"""Test case for entity creation with invalid public keys."""
request_data = {
"phone_number": "+237123456789",
"ownership_proof_response": "123456",
"country_code": "CM",
"password": "Password@1234",
"client_publish_pub_key": "invalid_key",
"client_device_id_pub_key": "invalid_key",
}

request = vault_pb2.CreateEntityRequest(**request_data)

create_entity_method = grpc_server_mock.invoke_unary_unary(
method_descriptor=(
vault_pb2.DESCRIPTOR.services_by_name["Entity"].methods_by_name[
"CreateEntity"
]
),
invocation_metadata={},
request=request,
timeout=1,
)

_, _, code, details = create_entity_method.termination()

assert code == grpc.StatusCode.INVALID_ARGUMENT
assert "Invalid fields:" in details


def test_entity_initiate_authentication_success(grpc_server_mock):
"""Test case for successful initiation of entity authentication."""
from src.db_models import Entity
Expand Down Expand Up @@ -373,7 +403,13 @@ def test_entity_complete_authentication_success(grpc_server_mock):
"""Test case for successful completion of entity authentication."""
from src.db_models import Entity

request_data = {"phone_number": "+237123456789", "password": "Password@1234"}
request_data = {
"phone_number": "+237123456789",
"password": "Password@1234",
"ownership_proof_response": "123456",
"client_publish_pub_key": base64.b64encode(b"\x82" * 32).decode("utf-8"),
"client_device_id_pub_key": base64.b64encode(b"\x82" * 32).decode("utf-8"),
}
hash_key = load_key(get_configs("HASHING_SALT"), 32)
phone_number_hash = generate_hmac(hash_key, request_data["phone_number"])

Expand All @@ -384,13 +420,6 @@ def test_entity_complete_authentication_success(grpc_server_mock):
password_hash=generate_hmac(hash_key, request_data["password"]),
)

request_data = {
"phone_number": request_data["phone_number"],
"password": request_data["password"],
"ownership_proof_response": "123456",
"client_publish_pub_key": "Kqprob8WuflOMpcR6SGg8yQumerTvm1MQeAtcgFxWFY",
"client_device_id_pub_key": "UD6gLBg0RJ/olGhJItmDxHOdv0550BDpGGnMIcvbCkc=",
}
request = vault_pb2.AuthenticateEntityRequest(**request_data)

authenticate_entity_method = grpc_server_mock.invoke_unary_unary(
Expand Down Expand Up @@ -467,3 +496,43 @@ def test_entity_authenticate_incorrect_password(grpc_server_mock):

assert code == grpc.StatusCode.UNAUTHENTICATED
assert "Incorrect credentials." in details


def test_entity_complete_authentication_invalid_public_keys(grpc_server_mock):
"""Test case for entity authentication with invalid public keys."""
from src.db_models import Entity

request_data = {
"phone_number": "+237123456789",
"password": "Password@1234",
"ownership_proof_response": "123456",
"client_publish_pub_key": "invalid_key",
"client_device_id_pub_key": "invalid_key",
}
hash_key = load_key(get_configs("HASHING_SALT"), 32)
phone_number_hash = generate_hmac(hash_key, request_data["phone_number"])

Entity.create(
phone_number_hash=phone_number_hash,
eid=generate_eid(phone_number_hash),
country_code="CM",
password_hash=generate_hmac(hash_key, request_data["password"]),
)

request = vault_pb2.AuthenticateEntityRequest(**request_data)

authenticate_entity_method = grpc_server_mock.invoke_unary_unary(
method_descriptor=(
vault_pb2.DESCRIPTOR.services_by_name["Entity"].methods_by_name[
"AuthenticateEntity"
]
),
invocation_metadata={},
request=request,
timeout=1,
)

_, _, code, details = authenticate_entity_method.termination()

assert code == grpc.StatusCode.INVALID_ARGUMENT
assert "Invalid fields:" in details

0 comments on commit 6cadc0e

Please sign in to comment.