Skip to content

Commit

Permalink
Merge pull request #515 from scidsg/captcha-progress
Browse files Browse the repository at this point in the history
Save progress when the wrong CAPTCHA is failed
  • Loading branch information
micahflee committed Sep 6, 2024
2 parents 82677d3 + 2f0319d commit 953bcd1
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 5 deletions.
80 changes: 76 additions & 4 deletions hushline/crypto.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,85 @@
import os
from base64 import urlsafe_b64decode, urlsafe_b64encode

from cryptography.fernet import Fernet
from cryptography.hazmat.primitives.kdf.scrypt import Scrypt
from flask import current_app
from pysequoia import Cert, encrypt

# https://cryptography.io/en/latest/hazmat/primitives/key-derivation-functions/#scrypt
SCRYPT_LENGTH = 32 # The desired length of the derived key in bytes.
SCRYPT_N = 2**14 # CPU/Memory cost parameter. It must be larger than 1 and be a power of 2.
SCRYPT_R = 8 # Block size parameter.
SCRYPT_P = 1 # Parallelization parameter.


def generate_salt() -> str:
"""
Generate a random salt for use in encryption key derivation.
"""
return urlsafe_b64encode(os.urandom(32)).decode()


def get_encryption_key(scope: bytes | str | None = None, salt: str | None = None) -> Fernet:
"""
Return the default Fernet encryption key. If a scope and salt are provided, a unique encryption
key will be derived based on the scope and salt.
"""
encryption_key = os.environ.get("ENCRYPTION_KEY")
if encryption_key is None:
raise ValueError("Encryption key not found. Please check your .env file.")

# If a scope is provided, we will use it to derive a unique encryption key
if scope is not None and salt is not None:
# Convert the scope to bytes if it is a string
if isinstance(scope, str):
scope_bytes = scope.encode()
elif isinstance(scope, bytes):
scope_bytes = scope

# Convert the encryption key and salt to bytes
encryption_key_bytes = urlsafe_b64decode(encryption_key)
salt_bytes = urlsafe_b64decode(salt)

# Use Scrypt to derive a unique encryption key based on the scope
kdf = Scrypt(
salt=salt_bytes,
length=SCRYPT_LENGTH,
n=SCRYPT_N,
r=SCRYPT_R,
p=SCRYPT_P,
)

# Concatenate the encryption key with the scope
items = (encryption_key_bytes, scope_bytes)
result = len(items).to_bytes(8, "big")
result += b"".join(len(item).to_bytes(8, "big") + item for item in items)

# Derive the new key
new_encryption_key_bytes = kdf.derive(result)
encryption_key = urlsafe_b64encode(new_encryption_key_bytes).decode()

return Fernet(encryption_key)


encryption_key = os.environ.get("ENCRYPTION_KEY")

if encryption_key is None:
raise ValueError("Encryption key not found. Please check your .env file.")

fernet = Fernet(encryption_key)


def encrypt_field(data: bytes | str | None) -> str | None:
def encrypt_field(
data: bytes | str | None, scope: bytes | str | None = None, salt: str | None = None
) -> str | None:
"""
Encrypts the data with the default encryption key. If both scope and salt are provided,
a unique encryption key will be derived based on the scope and salt.
"""
if data is None:
return None

fernet = get_encryption_key(scope, salt)

# Check if data is already a bytes object
if not isinstance(data, bytes):
# If data is a string, encode it to bytes
Expand All @@ -26,9 +90,17 @@ def encrypt_field(data: bytes | str | None) -> str | None:
return fernet.encrypt_at_time(data, current_time=0).decode()


def decrypt_field(data: str | None) -> str | None:
def decrypt_field(
data: str | None, scope: bytes | str | None = None, salt: str | None = None
) -> str | None:
"""
Decrypts the data with the default encryption key. If both scope and salt are provided,
a unique encryption key will be derived based on the scope and salt.
"""
if data is None:
return None

fernet = get_encryption_key(scope, salt)
return fernet.decrypt(data.encode()).decode()


Expand Down
32 changes: 31 additions & 1 deletion hushline/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from wtforms import Field, Form, PasswordField, StringField, TextAreaField
from wtforms.validators import DataRequired, Length, Optional, ValidationError

from .crypto import encrypt_message
from .crypto import decrypt_field, encrypt_field, encrypt_message, generate_salt
from .db import db
from .forms import ComplexPassword
from .model import AuthenticationLog, InviteCode, Message, SMTPEncryption, User
Expand Down Expand Up @@ -118,6 +118,27 @@ def profile(username: str) -> Response | str:
flash("🫥 User not found.")
return redirect(url_for("index"))

# If the encrypted message is stored in the session, use it to populate the form
scope = "submit_message"
if (
f"{scope}:salt" in session
and f"{scope}:contact_method" in session
and f"{scope}:content" in session
):
try:
form.contact_method.data = decrypt_field(
session[f"{scope}:contact_method"], scope, session[f"{scope}:salt"]
)
form.content.data = decrypt_field(
session[f"{scope}:content"], scope, session[f"{scope}:salt"]
)
except Exception:
app.logger.error("Error decrypting content", exc_info=True)

session.pop(f"{scope}:contact_method", None)
session.pop(f"{scope}:content", None)
session.pop(f"{scope}:salt", None)

# Generate a simple math problem using secrets module (e.g., "What is 6 + 7?")
num1 = secrets.randbelow(10) + 1 # To get a number between 1 and 10
num2 = secrets.randbelow(10) + 1 # To get a number between 1 and 10
Expand Down Expand Up @@ -179,6 +200,15 @@ def submit_message(username: str) -> Response | str:

captcha_answer = request.form.get("captcha_answer", "")
if not validate_captcha(captcha_answer):
# Encrypt the message and store it in the session
scope = "submit_message"
salt = generate_salt()
session[f"{scope}:contact_method"] = encrypt_field(
form.contact_method.data, scope, salt
)
session[f"{scope}:content"] = encrypt_field(form.content.data, scope, salt)
session[f"{scope}:salt"] = salt

return redirect(url_for("profile", username=username))

content = form.content.data
Expand Down
39 changes: 39 additions & 0 deletions tests/test_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,42 @@ def test_profile_extra_fields(client: FlaskClient, app: Flask) -> None:
soup
) or "<script>alert('xss')</script>" in str(soup)
assert "<script>alert('xss')</script>" not in str(soup)


def test_profile_submit_message_with_invalid_captcha(client: FlaskClient) -> None:
# Register a user
user = register_user(client, "test_user_concat", "Secure-Test-Pass123")
assert user is not None

# Log in the user
login_success = login_user(client, "test_user_concat", "Secure-Test-Pass123")
assert login_success

# Prepare the message and contact method data
message_content = "This is a test message."
contact_method = "email@example.com"
message_data = {
"content": message_content,
"contact_method": contact_method,
"client_side_encrypted": "false",
"captcha_answer": 0, # the answer is never 0
}

# Send a POST request to submit the message
response = client.post(
f"/to/{user.primary_username}",
data=message_data,
follow_redirects=True,
)

# Make sure there's a CAPTCHA error
assert response.status_code == 200
assert b"Incorrect CAPTCHA." in response.data

# Make sure the contact method and message content are there
assert contact_method.encode() in response.data
assert message_content.encode() in response.data

# Verify that the message is not saved in the database
message = db.session.scalars(db.select(Message).filter_by(user_id=user.id).limit(1)).first()
assert message is None

0 comments on commit 953bcd1

Please sign in to comment.