diff --git a/hushline/crypto.py b/hushline/crypto.py index 9d25a603..f72d6173 100644 --- a/hushline/crypto.py +++ b/hushline/crypto.py @@ -1,21 +1,85 @@ import os +from base64 import urlsafe_b64decode, urlsafe_b64encode from cryptography.fernet import Fernet +from cryptography.hazmat.primitives.kdf.scrypt import Scrypt from flask import current_app from pysequoia import Cert, encrypt +# https://cryptography.io/en/latest/hazmat/primitives/key-derivation-functions/#scrypt +SCRYPT_LENGTH = 32 # The desired length of the derived key in bytes. +SCRYPT_N = 2**14 # CPU/Memory cost parameter. It must be larger than 1 and be a power of 2. +SCRYPT_R = 8 # Block size parameter. +SCRYPT_P = 1 # Parallelization parameter. + + +def generate_salt() -> str: + """ + Generate a random salt for use in encryption key derivation. + """ + return urlsafe_b64encode(os.urandom(32)).decode() + + +def get_encryption_key(scope: bytes | str | None = None, salt: str | None = None) -> Fernet: + """ + Return the default Fernet encryption key. If a scope and salt are provided, a unique encryption + key will be derived based on the scope and salt. + """ + encryption_key = os.environ.get("ENCRYPTION_KEY") + if encryption_key is None: + raise ValueError("Encryption key not found. Please check your .env file.") + + # If a scope is provided, we will use it to derive a unique encryption key + if scope is not None and salt is not None: + # Convert the scope to bytes if it is a string + if isinstance(scope, str): + scope_bytes = scope.encode() + elif isinstance(scope, bytes): + scope_bytes = scope + + # Convert the encryption key and salt to bytes + encryption_key_bytes = urlsafe_b64decode(encryption_key) + salt_bytes = urlsafe_b64decode(salt) + + # Use Scrypt to derive a unique encryption key based on the scope + kdf = Scrypt( + salt=salt_bytes, + length=SCRYPT_LENGTH, + n=SCRYPT_N, + r=SCRYPT_R, + p=SCRYPT_P, + ) + + # Concatenate the encryption key with the scope + items = (encryption_key_bytes, scope_bytes) + result = len(items).to_bytes(8, "big") + result += b"".join(len(item).to_bytes(8, "big") + item for item in items) + + # Derive the new key + new_encryption_key_bytes = kdf.derive(result) + encryption_key = urlsafe_b64encode(new_encryption_key_bytes).decode() + + return Fernet(encryption_key) + + encryption_key = os.environ.get("ENCRYPTION_KEY") if encryption_key is None: raise ValueError("Encryption key not found. Please check your .env file.") -fernet = Fernet(encryption_key) - -def encrypt_field(data: bytes | str | None) -> str | None: +def encrypt_field( + data: bytes | str | None, scope: bytes | str | None = None, salt: str | None = None +) -> str | None: + """ + Encrypts the data with the default encryption key. If both scope and salt are provided, + a unique encryption key will be derived based on the scope and salt. + """ if data is None: return None + fernet = get_encryption_key(scope, salt) + # Check if data is already a bytes object if not isinstance(data, bytes): # If data is a string, encode it to bytes @@ -26,9 +90,17 @@ def encrypt_field(data: bytes | str | None) -> str | None: return fernet.encrypt_at_time(data, current_time=0).decode() -def decrypt_field(data: str | None) -> str | None: +def decrypt_field( + data: str | None, scope: bytes | str | None = None, salt: str | None = None +) -> str | None: + """ + Decrypts the data with the default encryption key. If both scope and salt are provided, + a unique encryption key will be derived based on the scope and salt. + """ if data is None: return None + + fernet = get_encryption_key(scope, salt) return fernet.decrypt(data.encode()).decode() diff --git a/hushline/routes.py b/hushline/routes.py index c12abd44..64e1939b 100644 --- a/hushline/routes.py +++ b/hushline/routes.py @@ -21,7 +21,7 @@ from wtforms import Field, Form, PasswordField, StringField, TextAreaField from wtforms.validators import DataRequired, Length, Optional, ValidationError -from .crypto import encrypt_message +from .crypto import decrypt_field, encrypt_field, encrypt_message, generate_salt from .db import db from .forms import ComplexPassword from .model import AuthenticationLog, InviteCode, Message, SMTPEncryption, User @@ -118,6 +118,27 @@ def profile(username: str) -> Response | str: flash("🫥 User not found.") return redirect(url_for("index")) + # If the encrypted message is stored in the session, use it to populate the form + scope = "submit_message" + if ( + f"{scope}:salt" in session + and f"{scope}:contact_method" in session + and f"{scope}:content" in session + ): + try: + form.contact_method.data = decrypt_field( + session[f"{scope}:contact_method"], scope, session[f"{scope}:salt"] + ) + form.content.data = decrypt_field( + session[f"{scope}:content"], scope, session[f"{scope}:salt"] + ) + except Exception: + app.logger.error("Error decrypting content", exc_info=True) + + session.pop(f"{scope}:contact_method", None) + session.pop(f"{scope}:content", None) + session.pop(f"{scope}:salt", None) + # Generate a simple math problem using secrets module (e.g., "What is 6 + 7?") num1 = secrets.randbelow(10) + 1 # To get a number between 1 and 10 num2 = secrets.randbelow(10) + 1 # To get a number between 1 and 10 @@ -179,6 +200,15 @@ def submit_message(username: str) -> Response | str: captcha_answer = request.form.get("captcha_answer", "") if not validate_captcha(captcha_answer): + # Encrypt the message and store it in the session + scope = "submit_message" + salt = generate_salt() + session[f"{scope}:contact_method"] = encrypt_field( + form.contact_method.data, scope, salt + ) + session[f"{scope}:content"] = encrypt_field(form.content.data, scope, salt) + session[f"{scope}:salt"] = salt + return redirect(url_for("profile", username=username)) content = form.content.data diff --git a/tests/test_profile.py b/tests/test_profile.py index 2387bb79..4a2eaa5d 100644 --- a/tests/test_profile.py +++ b/tests/test_profile.py @@ -173,3 +173,42 @@ def test_profile_extra_fields(client: FlaskClient, app: Flask) -> None: soup ) or "<script>alert('xss')</script>" in str(soup) assert "" not in str(soup) + + +def test_profile_submit_message_with_invalid_captcha(client: FlaskClient) -> None: + # Register a user + user = register_user(client, "test_user_concat", "Secure-Test-Pass123") + assert user is not None + + # Log in the user + login_success = login_user(client, "test_user_concat", "Secure-Test-Pass123") + assert login_success + + # Prepare the message and contact method data + message_content = "This is a test message." + contact_method = "email@example.com" + message_data = { + "content": message_content, + "contact_method": contact_method, + "client_side_encrypted": "false", + "captcha_answer": 0, # the answer is never 0 + } + + # Send a POST request to submit the message + response = client.post( + f"/to/{user.primary_username}", + data=message_data, + follow_redirects=True, + ) + + # Make sure there's a CAPTCHA error + assert response.status_code == 200 + assert b"Incorrect CAPTCHA." in response.data + + # Make sure the contact method and message content are there + assert contact_method.encode() in response.data + assert message_content.encode() in response.data + + # Verify that the message is not saved in the database + message = db.session.scalars(db.select(Message).filter_by(user_id=user.id).limit(1)).first() + assert message is None