Skip to content
This repository has been archived by the owner on Dec 2, 2024. It is now read-only.

Commit

Permalink
feat: sessions and hmac tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
tnix100 committed Aug 25, 2024
1 parent d54ec93 commit 5a1657e
Show file tree
Hide file tree
Showing 14 changed files with 262 additions and 158 deletions.
2 changes: 2 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ API_ROOT=
INTERNAL_API_ENDPOINT="http://127.0.0.1:3001" # used for proxying CL3 commands
INTERNAL_API_TOKEN="" # used for authenticating internal API requests (gives access to any account, meant to be used by CL3)

SENTRY_DSN=

CAPTCHA_SITEKEY=
CAPTCHA_SECRET=

Expand Down
12 changes: 8 additions & 4 deletions cloudlink.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,8 @@ def __init__(
self.server = server
self.websocket = websocket

# Set username, protocol version, IP, and trusted status
# Set account session ID, username, protocol version, IP, and trusted status
self.acc_session_id: Optional[str] = None
self.username: Optional[str] = None
try:
self.proto_version: int = int(self.req_params.get("v")[0])
Expand Down Expand Up @@ -255,7 +256,7 @@ def ip(self):
else:
return self.websocket.remote_address

def authenticate(self, account: dict[str, Any], token: str, listener: Optional[str] = None):
def authenticate(self, acc_session: dict[str, Any], token: str, account: dict[str, Any], listener: Optional[str] = None):
if self.username:
self.logout()

Expand All @@ -265,6 +266,7 @@ def authenticate(self, account: dict[str, Any], token: str, listener: Optional[s
return self.send_statuscode("Banned", listener)

# Authenticate
self.acc_session_id = acc_session["_id"]
self.username = account["_id"]
if self.username in self.server.usernames:
self.server.usernames[self.username].append(self)
Expand All @@ -275,6 +277,7 @@ def authenticate(self, account: dict[str, Any], token: str, listener: Optional[s
# Send auth payload
self.send("auth", {
"username": self.username,
"session": acc_session,
"token": token,
"account": account,
"relationships": self.proxy_api_request("/me/relationships", "get")["autoget"],
Expand Down Expand Up @@ -307,6 +310,7 @@ def proxy_api_request(
headers.update({
"X-Internal-Token": os.environ["INTERNAL_API_TOKEN"],
"X-Internal-Ip": self.ip,
"X-Internal-UA": self.websocket.request_headers.get("User-Agent"),
})
if self.username:
headers["X-Internal-Username"] = self.username
Expand Down Expand Up @@ -356,7 +360,7 @@ def send_statuscode(self, statuscode: str, listener: Optional[str] = None):
def kick(self):
async def _kick():
await self.websocket.close()
asyncio.create_task(_kick())
asyncio.run(_kick())

class CloudlinkCommands:
@staticmethod
Expand Down Expand Up @@ -389,7 +393,7 @@ async def authpswd(client: CloudlinkClient, val, listener: Optional[str] = None)
else:
if resp and not resp["error"]:
# Authenticate client
client.authenticate(resp["account"], resp["token"], listener=listener)
client.authenticate(resp["session"], resp["token"], resp["account"], listener=listener)

# Tell the client it is authenticated
client.send_statuscode("OK", listener)
Expand Down
57 changes: 20 additions & 37 deletions database.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
import os
import secrets
from radix import Radix
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey
from hashlib import sha256
from base64 import urlsafe_b64encode

from utils import log

CURRENT_DB_VERSION = 9
CURRENT_DB_VERSION = 10

# Create Redis connection
log("Connecting to Redis...")
Expand Down Expand Up @@ -43,8 +44,6 @@
# Create usersv0 indexes
try: db.usersv0.create_index([("lower_username", pymongo.ASCENDING)], name="lower_username", unique=True)
except: pass
try: db.usersv0.create_index([("tokens", pymongo.ASCENDING)], name="tokens", unique=True)
except: pass
try: db.usersv0.create_index([("created", pymongo.DESCENDING)], name="recent_users")
except: pass
try:
Expand Down Expand Up @@ -193,7 +192,6 @@
"avatar_color": None,
"quote": None,
"pswd": None,
"tokens": None,
"flags": 1,
"permissions": None,
"ban": None,
Expand All @@ -214,41 +212,17 @@
"registration": True
})
except pymongo.errors.DuplicateKeyError: pass


# Load existing signing keys or create new ones
signing_keys = {}
if db.config.count_documents({"_id": "signing_keys"}, limit=1):
data = db.config.count_documents({"_id": "signing_keys"}, limit=1)

acc_priv = Ed25519PrivateKey.from_private_bytes(data["acc_priv"])
email_priv = Ed25519PrivateKey.from_private_bytes(data["email_priv"])

signing_keys.update({
"acc_priv": acc_priv,
"acc_pub": acc_priv.public_key(),

"email_priv": email_priv,
"email_pub": email_priv.public_key()
try:
db.config.insert_one({
"_id": "signing_keys",
"acc": secrets.token_bytes(64),
"email": secrets.token_bytes(64)
})
else:
acc_priv = Ed25519PrivateKey.generate()
email_priv = Ed25519PrivateKey.generate()
except pymongo.errors.DuplicateKeyError: pass

signing_keys.update({
"acc_priv": acc_priv,
"acc_pub": acc_priv.public_key(),

"email_priv": email_priv,
"email_pub": email_priv.public_key()
})

data = {
"_id": "signing_keys",
"acc_priv": acc_priv.private_bytes_raw(),
"email_priv": email_priv.private_bytes_raw()
}
db.confing.insert_one(signing_keys)
# Load signing keys
signing_keys = db.config.find_one({"_id": "signing_keys"})


# Load netblocks
Expand Down Expand Up @@ -343,6 +317,15 @@ def get_total_pages(collection: str, query: dict, page_size: int = 25) -> int:
"mfa_recovery_code": user["mfa_recovery_code"][:10]
}})

# New sessions
log("[Migrator] Adding new sessions")
from sessions import AccSession
for user in db.usersv0.find({"tokens": {"$exists": True}}, projection={"_id": 1, "tokens": 1}):
if user["tokens"]:
for token in user["tokens"]:
rdb.set(urlsafe_b64encode(sha256(token.encode()).digest()), user["_id"], ex=1209600) # 14 days
db.usersv0.update_one({"_id": user["_id"]}, {"$set": {"tokens": []}})

db.config.update_one({"_id": "migration"}, {"$set": {"database": CURRENT_DB_VERSION}})
log(f"[Migrator] Finished Migrating DB to version {CURRENT_DB_VERSION}")

Expand Down
3 changes: 3 additions & 0 deletions errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
class InvalidTokenSignature(Exception): pass

class SessionNotFound(Exception): pass
23 changes: 15 additions & 8 deletions grpc_auth/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
auth_service_pb2 as pb2
)

from sentry_sdk import capture_exception

from database import db
from sessions import AccSession


class AuthService(pb2_grpc.AuthServicer):
Expand All @@ -22,15 +25,19 @@ def CheckToken(self, request, context):
if not authed:
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid or missing token")

account = db.usersv0.find_one({"tokens": request.token}, projection={
"_id": 1,
"ban.state": 1,
"ban.expires": 1
})
if account:
try:
username = AccSession.get_username_by_token(request.token)
except Exception as e:
capture_exception(e)
else:
account = db.usersv0.find_one({"_id": username}, projection={
"_id": 1,
"ban.state": 1,
"ban.expires": 1
})
if account and \
(account["ban"]["state"] == "perm_ban" or \
(account["ban"]["state"] == "temp_ban" and account["ban"]["expires"] > time.time())):
(account["ban"]["state"] == "perm_ban" or \
(account["ban"]["state"] == "temp_ban" and account["ban"]["expires"] > time.time())):
account = None

return pb2.CheckTokenResp(
Expand Down
4 changes: 4 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import asyncio
import os
import uvicorn
import sentry_sdk

from threading import Thread

Expand All @@ -16,6 +17,9 @@


if __name__ == "__main__":
# Initialise Sentry (uses SENTRY_DSN env var)
sentry_sdk.init()

# Create Cloudlink server
cl = CloudlinkServer()

Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ protobuf
pyotp
emoji
websockets
qrcode
qrcode
sentry-sdk
22 changes: 15 additions & 7 deletions rest_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from quart_cors import cors
from quart_schema import QuartSchema, RequestSchemaValidationError, validate_headers, hide
from pydantic import BaseModel
from sentry_sdk import capture_exception
import time, os

from .v0 import v0
Expand All @@ -10,6 +11,7 @@
from .admin import admin_bp

from database import db, blocked_ips, registration_blocked_ips
from sessions import AccSession
import security


Expand Down Expand Up @@ -41,6 +43,7 @@ async def internal_auth():
abort(401)

request.internal_ip = request.headers.get("X-Internal-Ip")
request.headers["User-Agent"] = request.headers.get("X-Internal-UA")
request.internal_username = request.headers.get("X-Internal-Username")
request.bypass_captcha = True

Expand Down Expand Up @@ -74,13 +77,18 @@ async def check_auth(headers: TokenHeader):
"ban.expires": 1
})
elif headers.token: # external auth
account = db.usersv0.find_one({"tokens": headers.token}, projection={
"_id": 1,
"flags": 1,
"permissions": 1,
"ban.state": 1,
"ban.expires": 1
})
try:
username = AccSession.get_username_by_token(headers.token)
except Exception as e:
capture_exception(e)
else:
account = db.usersv0.find_one({"_id": username}, projection={
"_id": 1,
"flags": 1,
"permissions": 1,
"ban.state": 1,
"ban.expires": 1
})

if account:
if account["ban"]["state"] == "perm_ban" or (account["ban"]["state"] == "temp_ban" and account["ban"]["expires"] > time.time()):
Expand Down
32 changes: 15 additions & 17 deletions rest_api/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import security
from database import db, get_total_pages, blocked_ips, registration_blocked_ips
from sessions import AccSession


admin_bp = Blueprint("admin_bp", __name__, url_prefix="/admin")
Expand Down Expand Up @@ -651,17 +652,17 @@ async def delete_user(username, query_args: DeleteUserQueryArgs):
{"_id": username}, {"$set": {"delete_after": None}}
)
elif deletion_mode in ["schedule", "immediate", "purge"]:
db.usersv0.update_one(
{"_id": username},
{
"$set": {
"tokens": [],
"delete_after": int(time.time()) + (604800 if deletion_mode == "schedule" else 0),
}
},
)
for client in app.cl.usernames.get(username, []):
client.kick()
if deletion_mode == "schedule":
db.usersv0.update_one(
{"_id": username},
{
"$set": {
"delete_after": int(time.time()) + (604800 if deletion_mode == "schedule" else 0),
}
},
)
for session in AccSession.get_all(username):
session.revoke()
if deletion_mode in ["immediate", "purge"]:
security.delete_account(username, purge=(deletion_mode == "purge"))
else:
Expand Down Expand Up @@ -828,12 +829,9 @@ async def kick_user(username):
if not security.has_permission(request.permissions, security.AdminPermissions.KICK_USERS):
abort(401)

# Revoke tokens
db.usersv0.update_one({"_id": username}, {"$set": {"tokens": []}})

# Kick clients
for client in app.cl.usernames.get(username, []):
client.kick()
# Revoke sessions
for session in AccSession.get_all(username):
session.revoke()

# Add log
security.add_audit_log(
Expand Down
Loading

0 comments on commit 5a1657e

Please sign in to comment.