Skip to content

Commit

Permalink
feat: Secure Sourcing related backend APIs (#22)
Browse files Browse the repository at this point in the history
Co-authored-by: mfteloglu <mfteloglu@gmail.com>
  • Loading branch information
egekocabas and mfteloglu authored Jan 13, 2024
1 parent 62b17a3 commit 80e94e9
Show file tree
Hide file tree
Showing 19 changed files with 216 additions and 21 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ env:
TF_VAR_GITHUB_TOKEN: ${{ secrets.MINING_GITHUB_TOKEN }}
TF_VAR_ANALYTICS_BASE_URL: ${{ github.event_name == 'release' && vars.PROD_ANALYTICS_BASE_URL || vars.STAGING_ANALYTICS_BASE_URL}}

# Analytics and Sourcing Auth Flow
TF_VAR_PARMA_SHARED_SECRET_KEY: ${{ github.event_name == 'release' && secrets.PROD_PARMA_SHARED_SECRET_KEY || secrets.STAGING_PARMA_SHARED_SECRET_KEY}}

jobs:
deploy:
name: Deploy - ${{ matrix.DEPLOYMENT_ENV }}
Expand Down
1 change: 1 addition & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ COPY --chown=$MAMBA_USER:$MAMBA_USER parma_mining /app/parma_mining

ENV GITHUB_TOKEN=$GITHUB_TOKEN
ENV ANALYTICS_BASE_URL=$ANALYTICS_BASE_URL
ENV PARMA_SHARED_SECRET_KEY=$PARMA_SHARED_SECRET_KEY


EXPOSE 8080
Expand Down
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ dependencies:
- pydantic >=2
- pyyaml
- uvicorn >=0.23.2
- python-jose >=3.3.0
20 changes: 13 additions & 7 deletions parma_mining/github/analytics_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@ class AnalyticsClient:
feed_raw_url = urllib.parse.urljoin(analytics_base, "/feed-raw-data")
crawling_finished_url = urllib.parse.urljoin(analytics_base, "/crawling-finished")

def send_post_request(self, api_endpoint, data):
def send_post_request(self, token: str, api_endpoint, data):
"""Send a POST request to the given API endpoint with the given data."""
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {token}",
}

response = httpx.post(api_endpoint, json=data, headers=headers)
Expand All @@ -47,7 +48,9 @@ def send_post_request(self, api_endpoint, data):
f"response: {response.text}"
)

def register_measurements(self, mapping, parent_id=None, source_module_id=None):
def register_measurements(
self, token: str, mapping, parent_id=None, source_module_id=None
):
"""Register the given mapping as a measurement."""
result = []

Expand All @@ -66,7 +69,9 @@ def register_measurements(self, mapping, parent_id=None, source_module_id=None):
f"measurement {measurement_data['measurement_name']}"
)

response = self.send_post_request(self.measurement_url, measurement_data)
response = self.send_post_request(
token, self.measurement_url, measurement_data
)
measurement_data["source_measurement_id"] = response.get("id")

# add the source measurement id to mapping
Expand All @@ -76,6 +81,7 @@ def register_measurements(self, mapping, parent_id=None, source_module_id=None):

if "NestedMappings" in field_mapping:
nested_measurements = self.register_measurements(
token,
{"Mappings": field_mapping["NestedMappings"]},
parent_id=measurement_data["source_measurement_id"],
source_module_id=source_module_id,
Expand All @@ -84,7 +90,7 @@ def register_measurements(self, mapping, parent_id=None, source_module_id=None):
result.append(measurement_data)
return result, mapping

def feed_raw_data(self, input_data: ResponseModel):
def feed_raw_data(self, token: str, input_data: ResponseModel):
"""Feed the raw data to the analytics service."""
organization_json = json.loads(input_data.raw_data.updated_model_dump())

Expand All @@ -94,8 +100,8 @@ def feed_raw_data(self, input_data: ResponseModel):
"raw_data": organization_json,
}

return self.send_post_request(self.feed_raw_url, data)
return self.send_post_request(token, self.feed_raw_url, data)

def crawling_finished(self, data):
def crawling_finished(self, token, data):
"""Notify crawling is finished to the analytics."""
return self.send_post_request(self.crawling_finished_url, data)
return self.send_post_request(token, self.crawling_finished_url, data)
8 changes: 8 additions & 0 deletions parma_mining/github/api/dependencies/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""This package contains dependency functions used across the endpoints.
These dependencies are designed to provide reusable utility functions, such as
authentication and authorization checks, that can be injected into FastAPI route
handlers. By centralizing these dependencies, the application's code remains clean,
modular, and easy to maintain. Each module in this package is tailored to a specific
set of functionality.
"""
55 changes: 55 additions & 0 deletions parma_mining/github/api/dependencies/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""This module contains authentication and authorization dependencies.
Specifically designed for sourcing modules in the FastAPI application, it includes
functions to authenticate requests using JWTs and to authorize these requests by
validating the JWTs against defined secret keys. The module ensures that only valid and
authorized sourcing modules can access certain endpoints.
"""
import logging

from fastapi import HTTPException, Header, status

from parma_mining.mining_common.jwt_handler import JWTHandler

logger = logging.getLogger(__name__)


def authenticate(
    authorization: str = Header(None),
) -> str:
    """Authenticate the incoming request using the JWT in the Authorization header.

    The header may carry the token either as ``Bearer <token>`` or as the bare
    token. The ``Bearer`` scheme is matched case-insensitively and whitespace
    around the token is ignored.

    Args:
        authorization: The Authorization header containing the JWT.

    Returns:
        Extracted token from the Authorization header.
        (Whenever a request needs to be made to the Analytics Backend,
        this token can be used to authenticate the request.)

    Raises:
        HTTPException: 401 if the Authorization header is missing.
        HTTPException: 401 if the token is empty, invalid or expired.
    """
    if authorization is None:
        logger.error("Authorization header is required!")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authorization header is required!",
        )

    # Split "<scheme> <credentials>" on the first space only. Auth schemes are
    # case-insensitive, so "bearer" and "BEARER" are accepted as well.
    header_value = authorization.strip()
    scheme, _, credentials = header_value.partition(" ")
    # Anything that is not a Bearer header is treated as a bare token, as before.
    token = credentials.strip() if scheme.lower() == "bearer" else header_value

    # An empty token (e.g. a header of just "Bearer") can never be valid, so
    # reject it here rather than handing it to the JWT verifier.
    if not token or not JWTHandler.verify_jwt(token):
        logger.error("Invalid shared token or expired token")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid shared token or expired token",
        )

    return token
20 changes: 13 additions & 7 deletions parma_mining/github/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from datetime import datetime, timedelta

from dotenv import load_dotenv
from fastapi import FastAPI, status
from fastapi import Depends, FastAPI, status

from parma_mining.github.analytics_client import AnalyticsClient
from parma_mining.github.api.dependencies.auth import authenticate
from parma_mining.github.client import GitHubClient
from parma_mining.github.helper import collect_errors
from parma_mining.github.model import (
Expand Down Expand Up @@ -56,14 +57,14 @@ def root():


@app.get("/initialize", status_code=status.HTTP_200_OK)
def initialize(source_id: int) -> str:
def initialize(source_id: int, token: str = Depends(authenticate)) -> str:
"""Initialization endpoint for the API."""
# init frequency
time = "weekly"
normalization_map = GithubNormalizationMap().get_normalization_map()
# register the measurements to analytics
analytics_client.register_measurements(
normalization_map, source_module_id=source_id
token=token, mapping=normalization_map, source_module_id=source_id
)

# set and return results
Expand All @@ -77,7 +78,9 @@ def initialize(source_id: int) -> str:
"/companies",
status_code=status.HTTP_200_OK,
)
def get_organization_details(body: CompaniesRequest):
def get_organization_details(
body: CompaniesRequest, token: str = Depends(authenticate)
):
"""Endpoint to get detailed information about a dict of organizations."""
errors: dict[str, ErrorInfoModel] = {}
for company_id, company_data in body.companies.items():
Expand All @@ -100,7 +103,7 @@ def get_organization_details(body: CompaniesRequest):
)
# Write data to db via endpoint in analytics backend
try:
analytics_client.feed_raw_data(data)
analytics_client.feed_raw_data(token, data)
except AnalyticsError as e:
logger.error(
f"Can't send crawling data to the Analytics. Error: {e}"
Expand All @@ -113,11 +116,12 @@ def get_organization_details(body: CompaniesRequest):
collect_errors(company_id, errors, ClientInvalidBodyError(msg))

return analytics_client.crawling_finished(
token,
json.loads(
CrawlingFinishedInputModel(
task_id=body.task_id, errors=errors
).model_dump_json()
)
),
)


Expand All @@ -126,7 +130,9 @@ def get_organization_details(body: CompaniesRequest):
response_model=FinalDiscoveryResponse,
status_code=status.HTTP_200_OK,
)
def discover_companies(request: list[DiscoveryRequest]):
def discover_companies(
request: list[DiscoveryRequest], token: str = Depends(authenticate)
):
"""Endpoint to discover organizations based on provided names."""
if not request:
msg = "Request body cannot be empty for discovery"
Expand Down
1 change: 1 addition & 0 deletions parma_mining/mining_common/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Common functions for mining modules."""
44 changes: 44 additions & 0 deletions parma_mining/mining_common/jwt_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Module for JWT (JSON Web Token) handling.
This module contains the JWTHandler class which is designed to verify JWTs. The
verification process supports shared secret keys to enable authentication.
"""
import logging
import os

from jose import jwt
from jose.exceptions import ExpiredSignatureError, JWTError

logger = logging.getLogger(__name__)


class JWTHandler:
    """Verify JWTs that were signed with the shared secret key."""

    # Secret used to check token signatures, read once when the class is created.
    # NOTE(review): when the environment variable is unset or empty, the literal
    # string "PARMA_SHARED_SECRET_KEY" is used as the key. That is a predictable
    # secret -- confirm every deployment sets the variable.
    SHARED_SECRET_KEY: str = (
        os.environ.get("PARMA_SHARED_SECRET_KEY") or "PARMA_SHARED_SECRET_KEY"
    )
    # The only signing algorithm accepted when decoding.
    ALGORITHM: str = "HS256"

    @staticmethod
    def verify_jwt(token: str) -> bool:
        """Check a JWT against the shared secret key.

        Args:
            token: The JWT token to verify.

        Returns:
            True if the token decodes successfully.
            False if it has expired or cannot be decoded.
        """
        accepted_algorithms = [JWTHandler.ALGORITHM]
        try:
            # The decoded claims are not needed; only whether decoding succeeds.
            jwt.decode(
                token,
                JWTHandler.SHARED_SECRET_KEY,
                algorithms=accepted_algorithms,
            )
        except ExpiredSignatureError:
            # Logged separately so expiry is distinguishable from other failures.
            logger.error("JWT has expired.")
            return False
        except JWTError:
            logger.error("Invalid JWT, unable to decode.")
            return False
        return True
4 changes: 4 additions & 0 deletions terraform/module/service.tf
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ resource "google_cloud_run_service" "parma_mining_github_cloud_run" {
name = "ANALYTICS_BASE_URL"
value = var.ANALYTICS_BASE_URL
}
env {
name = "PARMA_SHARED_SECRET_KEY"
value = var.PARMA_SHARED_SECRET_KEY
}
}
}
}
Expand Down
8 changes: 8 additions & 0 deletions terraform/module/variables.tf
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,11 @@ variable "ANALYTICS_BASE_URL" {
description = "value"
type = string
}

/* ------------------------ Analytics and Sourcing Auth Flow ------------------------ */

# Module input: the shared secret passed in by the prod/staging root modules.
# service.tf exposes it to the Cloud Run service as the PARMA_SHARED_SECRET_KEY
# environment variable.
variable "PARMA_SHARED_SECRET_KEY" {
description = "Shared secret key for the analytics and sourcing auth flow"
type = string
# Marked sensitive so Terraform redacts the value in plan/apply output.
sensitive = true
}
2 changes: 2 additions & 0 deletions terraform/prod/main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,6 @@ module "main" {
region = local.region
GITHUB_TOKEN = var.GITHUB_TOKEN
ANALYTICS_BASE_URL = var.ANALYTICS_BASE_URL

PARMA_SHARED_SECRET_KEY = var.PARMA_SHARED_SECRET_KEY
}
8 changes: 8 additions & 0 deletions terraform/prod/variables.tf
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,11 @@ variable "ANALYTICS_BASE_URL" {
description = "value"
type = string
}

/* ------------------------ Analytics and Sourcing Auth Flow ------------------------ */

# Root input for prod: the deploy workflow supplies it through the
# TF_VAR_PARMA_SHARED_SECRET_KEY environment variable, and main.tf forwards it
# to the "main" module.
variable "PARMA_SHARED_SECRET_KEY" {
description = "Shared secret key for the analytics and sourcing auth flow"
type = string
# Marked sensitive so Terraform redacts the value in plan/apply output.
sensitive = true
}
2 changes: 2 additions & 0 deletions terraform/staging/main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,6 @@ module "main" {
region = local.region
GITHUB_TOKEN = var.GITHUB_TOKEN
ANALYTICS_BASE_URL = var.ANALYTICS_BASE_URL

PARMA_SHARED_SECRET_KEY = var.PARMA_SHARED_SECRET_KEY
}
8 changes: 8 additions & 0 deletions terraform/staging/variables.tf
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,11 @@ variable "ANALYTICS_BASE_URL" {
description = "value"
type = string
}

/* ------------------------ Analytics and Sourcing Auth Flow ------------------------ */

# Root input for staging: the deploy workflow supplies it through the
# TF_VAR_PARMA_SHARED_SECRET_KEY environment variable, and main.tf forwards it
# to the "main" module.
variable "PARMA_SHARED_SECRET_KEY" {
description = "Shared secret key for the analytics and sourcing auth flow"
type = string
# Marked sensitive so Terraform redacts the value in plan/apply output.
sensitive = true
}
Empty file added tests/__init__.py
Empty file.
Empty file added tests/dependencies/__init__.py
Empty file.
25 changes: 25 additions & 0 deletions tests/dependencies/mock_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""Mock implementations of auth function for testing.
This module provides a mock version of the authentication function. The mock is
designed for use in test environments where the actual authentication process is not
required.
"""
import logging

from fastapi import Header

logger = logging.getLogger(__name__)


def mock_authenticate(
    authorization: str = Header(None),
) -> str:
    """Stand in for the real authentication dependency during tests.

    No verification is performed: the header value is accepted and ignored.

    Args:
        authorization: The Authorization header of the request (unused).

    Returns:
        Dummy token for testing purposes.
    """
    dummy_token = "dummytoken"
    return dummy_token
Loading

0 comments on commit 80e94e9

Please sign in to comment.