Skip to content

Commit

Permalink
Merge pull request #330 from pepkit/313_mint_tokens
Browse files Browse the repository at this point in the history
Add the ability to mint new API tokens through the UI
  • Loading branch information
nleroy917 authored Jul 10, 2024
2 parents 72f22a9 + 922ed0a commit 83e7fe5
Show file tree
Hide file tree
Showing 26 changed files with 927 additions and 127 deletions.
1 change: 1 addition & 0 deletions pephub/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@
JWT_SECRET = token_hex(32)
JWT_EXPIRATION = 4320 # 3 days in minutes
JWT_EXPIRATION_SECONDS = JWT_EXPIRATION * 60 # seconds
MAX_NEW_KEYS = 5

AUTH_CODE_EXPIRATION = 5 * 60 # seconds

Expand Down
42 changes: 20 additions & 22 deletions pephub/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import json
import logging
import os
from datetime import datetime, timedelta
from datetime import datetime
from secrets import token_hex
from typing import Any, Dict, Generator, List, Optional, Union
from typing import Any, Dict, List, Optional, Union
from cachetools import cached, TTLCache

import jwt
Expand Down Expand Up @@ -31,10 +31,11 @@
DEFAULT_POSTGRES_USER,
DEFAULT_QDRANT_HOST,
DEFAULT_QDRANT_PORT,
JWT_EXPIRATION,
JWT_SECRET,
)
from .helpers import jwt_encode_user_data
from .routers.models import ForkRequest
from .developer_keys import dev_key_handler

_LOGGER_PEPHUB = logging.getLogger("uvicorn.access")

Expand Down Expand Up @@ -83,14 +84,8 @@ def _request_user_data_from_github(access_token: str) -> UserData:
)

@staticmethod
def jwt_encode_user_data(user_data: dict) -> str:
exp = datetime.utcnow() + timedelta(minutes=JWT_EXPIRATION)
encoded_user_data = jwt.encode(
{**user_data, "exp": exp}, JWT_SECRET, algorithm="HS256"
)
if isinstance(encoded_user_data, bytes):
encoded_user_data = encoded_user_data.decode("utf-8")
return encoded_user_data
def jwt_encode_user_data(user_data: dict, exp: datetime = None) -> str:
return jwt_encode_user_data(user_data, exp=exp)


# database connection
Expand Down Expand Up @@ -132,19 +127,22 @@ def get_db() -> PEPDatabaseAgent:
return agent


def read_authorization_header(Authorization: str = Header(None)) -> Union[dict, None]:
def read_authorization_header(authorization: str = Header(None)) -> Union[dict, None]:
"""
Reads and decodes a JWT, returning the decoded variables.
:param Authorization: JWT provided via FastAPI injection from the API cookie.
"""
if Authorization is None:
if authorization is None:
return None
else:
Authorization = Authorization.replace("Bearer ", "")
authorization = authorization.replace("Bearer ", "")
try:
# Python jwt.decode verifies content as well so this is safe.
session_info = jwt.decode(Authorization, JWT_SECRET, algorithms=["HS256"])
# check last 5 chars
if dev_key_handler.is_key_bad(authorization[-5:]):
raise HTTPException(401, "JWT has been revoked")
session_info = jwt.decode(authorization, JWT_SECRET, algorithms=["HS256"])
except jwt.exceptions.InvalidSignatureError as e:
_LOGGER_PEPHUB.error(e)
return None
Expand Down Expand Up @@ -201,7 +199,7 @@ def get_project(
description="Return the project with the samples pephub_id",
include_in_schema=False,
),
) -> Dict[str, Any]:
) -> Dict[str, Any]: # type: ignore
try:
proj = agent.project.get(namespace, project, tag, raw=True, with_id=with_id)
yield proj
Expand All @@ -217,7 +215,7 @@ def get_config(
project: str,
tag: Optional[str] = DEFAULT_TAG,
agent: PEPDatabaseAgent = Depends(get_db),
) -> Dict[str, Any]:
) -> Dict[str, Any]: # type: ignore
try:
config = agent.project.get_config(namespace, project, tag)
yield config
Expand All @@ -233,7 +231,7 @@ def get_subsamples(
project: str,
tag: Optional[str] = DEFAULT_TAG,
agent: PEPDatabaseAgent = Depends(get_db),
) -> Dict[str, Any]:
) -> Dict[str, Any]: # type: ignore # type: ignore
try:
subsamples = agent.project.get_subsamples(namespace, project, tag)
yield subsamples
Expand All @@ -250,7 +248,7 @@ def get_project_annotation(
tag: Optional[str] = DEFAULT_TAG,
agent: PEPDatabaseAgent = Depends(get_db),
namespace_access_list: List[str] = Depends(get_namespace_access_list),
) -> AnnotationModel:
) -> AnnotationModel: # type: ignore
try:
anno = agent.annotation.get(
namespace, project, tag, admin=namespace_access_list
Expand Down Expand Up @@ -320,7 +318,7 @@ def verify_user_can_read_project(
def verify_user_can_fork(
fork_request: ForkRequest,
namespace_access_list: List[str] = Depends(get_namespace_access_list),
) -> bool:
) -> bool: # type: ignore
fork_namespace = fork_request.fork_to
if fork_namespace in (namespace_access_list or []):
yield
Expand All @@ -344,7 +342,7 @@ def get_qdrant_enabled() -> bool:

def get_qdrant(
qdrant_enabled: bool = Depends(get_qdrant_enabled),
) -> Union[QdrantClient, None]:
) -> Union[QdrantClient, None]: # type: ignore
"""
Return connection to qdrant client
"""
Expand Down Expand Up @@ -383,7 +381,7 @@ def get_namespace_info(
namespace: str,
agent: PEPDatabaseAgent = Depends(get_db),
user: str = Depends(get_user_from_session_info),
) -> Namespace:
) -> Namespace: # type: ignore
"""
Get the information on a namespace, if it exists.
"""
Expand Down
81 changes: 81 additions & 0 deletions pephub/developer_keys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from datetime import datetime, timedelta
from typing import Dict, List

import secrets

from fastapi import HTTPException

from .routers.models import DeveloperKey

from .helpers import jwt_encode_user_data
from .const import MAX_NEW_KEYS


class DeveloperKeyHandler:
def __init__(self, default_exp: int = 30 * 24 * 60 * 60):
self._keys: Dict[str, List[DeveloperKey]] = {}
self._default_exp = default_exp
self._bad_jwts = []

def add_key(self, namespace: str, key: DeveloperKey):
"""
Add a key to the handler for a given namespace/user
:param namespace: namespace for the key
:param key: DeveloperKey object
"""
if namespace not in self._keys:
self._keys[namespace] = []
if len(self._keys[namespace]) >= MAX_NEW_KEYS:
raise HTTPException(
status_code=400,
detail="You have reached the maximum number of keys allowed",
)
self._keys[namespace].append(key)

def get_keys_for_namespace(self, namespace: str) -> List[DeveloperKey]:
"""
Get all the keys for a given namespace
:param namespace: namespace for the key
"""
return self._keys.get(namespace) or []

def remove_key(self, namespace: str, last_five_chars: str):
"""
Remove a key from the handler for a given namespace/user
:param namespace: namespace for the key
:param key: key to remove
"""
if namespace in self._keys:
self._keys[namespace] = [
key for key in self._keys[namespace] if key.key[-5:] != last_five_chars
]
self._bad_jwts.append(last_five_chars)

def mint_key_for_namespace(
self, namespace: str, session_info: dict
) -> DeveloperKey:
"""
Mint a new key for a given namespace
:param namespace: namespace for the key
"""
salt = secrets.token_hex(32)
session_info["salt"] = salt
expiry = datetime.utcnow() + timedelta(seconds=self._default_exp)
new_key = jwt_encode_user_data(session_info, exp=expiry)
key = DeveloperKey(
key=new_key,
created_at=datetime.utcnow().isoformat(),
expires=expiry.isoformat(),
)
self.add_key(namespace, key)
return key

def is_key_bad(self, last_five_chars: str) -> bool:
return last_five_chars in self._bad_jwts


dev_key_handler = DeveloperKeyHandler()
20 changes: 19 additions & 1 deletion pephub/helpers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import io
import zipfile
from datetime import date
from datetime import date, datetime, timedelta
from typing import Any, Dict, List, Tuple, Union

import jwt
import pandas as pd
import yaml
from fastapi import Response, UploadFile
Expand All @@ -15,6 +16,23 @@
SAMPLE_RAW_DICT_KEY,
SUBSAMPLE_RAW_LIST_KEY,
)
from .const import JWT_EXPIRATION, JWT_SECRET


def jwt_encode_user_data(user_data: dict, exp: datetime = None) -> str:
"""
Encode user data into a JWT token.
:param user_data: user data to encode
:param exp: expiration time for the token
"""
exp = exp or datetime.utcnow() + timedelta(minutes=JWT_EXPIRATION)
encoded_user_data = jwt.encode(
{**user_data, "exp": exp}, JWT_SECRET, algorithm="HS256"
)
if isinstance(encoded_user_data, bytes):
encoded_user_data = encoded_user_data.decode("utf-8")
return encoded_user_data


def zip_pep(project: Dict[str, Any]) -> Response:
Expand Down
23 changes: 23 additions & 0 deletions pephub/limiter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from fastapi import Request, HTTPException
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded


limiter = Limiter(key_func=get_remote_address)


def _custom_rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded):
"""
Custom rate limit exceeded handler. Simple wrapper around slowapi's handler to ensure that
we properly raise an HTTPException with status code 429.
:param request: request object
:param exc: RateLimitExceeded exception
"""
_ = _rate_limit_exceeded_handler(request, exc)
raise HTTPException(
status_code=429,
detail="You are requesting too many new keys. Please try again later.",
)
8 changes: 8 additions & 0 deletions pephub/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
import coloredlogs
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from slowapi.errors import RateLimitExceeded


from ._version import __version__ as server_v
from .const import ALL_VERSIONS, PKG_NAME, TAGS_METADATA
from .limiter import limiter, _custom_rate_limit_exceeded_handler
from .routers.api.v1.base import api as api_base
from .routers.api.v1.namespace import namespace as api_namespace
from .routers.api.v1.namespace import namespaces as api_namespaces
Expand Down Expand Up @@ -58,6 +61,11 @@
#
# # logfire.instrument_fastapi(app)

# rate limiting

app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _custom_rate_limit_exceeded_handler)

# CORS is required for the validation HTML SPA to work externally
origins = ["*"]
app.add_middleware(
Expand Down
2 changes: 1 addition & 1 deletion pephub/routers/api/v1/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@
get_db,
get_namespace_access_list,
get_namespace_info,
get_user_from_session_info,
read_authorization_header,
verify_user_can_write_namespace,
get_pepdb_namespace_info,
get_user_from_session_info,
)
from ....helpers import parse_user_file_upload, split_upload_files_on_init_file
from ...models import FavoriteRequest, ProjectJsonRequest, ProjectRawModel
Expand Down
50 changes: 49 additions & 1 deletion pephub/routers/auth/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
from dotenv import load_dotenv
from fastapi import APIRouter, BackgroundTasks, Depends, Header, Request
from fastapi.exceptions import HTTPException
from fastapi.responses import RedirectResponse
from fastapi.responses import RedirectResponse, JSONResponse
from fastapi.templating import Jinja2Templates

from ...limiter import limiter

from ...const import (
AUTH_CODE_EXPIRATION,
BASE_TEMPLATES_PATH,
Expand All @@ -29,7 +31,9 @@
InitializeDeviceCodeResponse,
JWTDeviceTokenResponse,
TokenExchange,
RevokeRequest,
)
from ...developer_keys import dev_key_handler

load_dotenv()

Expand Down Expand Up @@ -66,6 +70,50 @@ def delete_device_code_after(code: str, expiration: int = AUTH_CODE_EXPIRATION):
DEVICE_CODES.pop(code, None)


@auth.get("/user/keys")
def get_user_keys(session_info: Union[dict, None] = Depends(read_authorization_header)):
if session_info:
keys = dev_key_handler.get_keys_for_namespace(session_info["login"])

# obfuscate the keys -- we never want to show the full key
for key in keys:
key.key = key.key[:5] + "*" * 10 + key.key[-5:]

return {"keys": keys}

else:
raise HTTPException(status_code=401, detail="Invalid token")


@auth.post("/user/keys")
@limiter.limit("5/minute")
def mint_user_key(
request: Request,
session_info: Union[dict, None] = Depends(read_authorization_header),
):
if session_info:
key = dev_key_handler.mint_key_for_namespace(
session_info["login"], session_info=session_info
)
return {"key": key}
else:
raise HTTPException(status_code=401, detail="Invalid token")


@auth.delete("/user/keys")
def delete_user_key(
revoke_request: RevokeRequest,
session_info: Union[dict, None] = Depends(read_authorization_header),
):
if session_info:
dev_key_handler.remove_key(
session_info["login"], revoke_request.last_five_chars
)
return JSONResponse({"message": "Key deleted successfully."}, status_code=202)
else:
raise HTTPException(status_code=401, detail="Invalid token")


@auth.get("/login", response_class=RedirectResponse)
def login(
client_redirect_uri: Union[str, None] = None,
Expand Down
Loading

0 comments on commit 83e7fe5

Please sign in to comment.