Skip to content

Commit

Permalink
more refactoring of portal_utils
Browse files Browse the repository at this point in the history
  • Loading branch information
dmichaels-harvard committed Dec 19, 2023
1 parent b63d4b2 commit 14e9f9f
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 71 deletions.
154 changes: 90 additions & 64 deletions dcicutils/portal_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from collections import deque
import io
import json
from pyramid.paster import get_app
from pyramid.router import Router
from pyramid.config import Configurator as PyramidConfigurator
from pyramid.paster import get_app as pyramid_get_app
from pyramid.response import Response as PyramidResponse
from pyramid.router import Router as PyramidRouter
import os
import re
import requests
from requests.models import Response as RequestResponse
from typing import Optional, Type, Union
from typing import Callable, Dict, List, Optional, Type, Union
from uuid import uuid4 as uuid
# from waitress import serve
from webtest.app import TestApp, TestResponse
from dcicutils.common import OrchestratedApp, ORCHESTRATED_APPS
from dcicutils.ff_utils import get_metadata, get_schema, patch_metadata, post_metadata
Expand Down Expand Up @@ -37,7 +41,7 @@ class Portal:
KEYS_FILE_DIRECTORY = os.path.expanduser(f"~")

def __init__(self,
arg: Optional[Union[Portal, TestApp, VirtualApp, Router, dict, tuple, str]] = None,
arg: Optional[Union[Portal, TestApp, VirtualApp, PyramidRouter, dict, tuple, str]] = None,
env: Optional[str] = None, server: Optional[str] = None,
app: Optional[OrchestratedApp] = None) -> None:

Expand Down Expand Up @@ -69,14 +73,14 @@ def init_from_portal(portal: Portal, unspecified: Optional[list] = None) -> None
self._app = portal._app
self._vapp = portal._vapp

def init_from_vapp(vapp: Union[TestApp, VirtualApp, Router], unspecified: Optional[list] = []) -> None:
def init_from_vapp(vapp: Union[TestApp, VirtualApp, PyramidRouter], unspecified: Optional[list] = []) -> None:
init(unspecified)
self._vapp = Portal._create_testapp(vapp)
self._vapp = Portal._create_vapp(vapp)

def init_from_ini_file(ini_file: str, unspecified: Optional[list] = []) -> None:
init(unspecified)
self._ini_file = ini_file
self._vapp = Portal._create_testapp(ini_file)
self._vapp = Portal._create_vapp(ini_file)

def init_from_key(key: dict, server: Optional[str], unspecified: Optional[list] = []) -> None:
init(unspecified)
Expand Down Expand Up @@ -138,7 +142,7 @@ def normalize_server(server: str) -> Optional[str]:

if isinstance(arg, Portal):
init_from_portal(arg, unspecified=[env, server, app])
elif isinstance(arg, (TestApp, VirtualApp, Router)):
elif isinstance(arg, (TestApp, VirtualApp, PyramidRouter)):
init_from_vapp(arg, unspecified=[env, server, app])
elif isinstance(arg, str) and arg.endswith(".ini"):
init_from_ini_file(arg, unspecified=[env, server, app])
Expand Down Expand Up @@ -201,35 +205,36 @@ def get_metadata(self, object_id: str) -> Optional[dict]:
def patch_metadata(self, object_id: str, data: str) -> Optional[dict]:
if self._key:
return patch_metadata(obj_id=object_id, patch_item=data, key=self._key)
return self.patch(f"/{object_id}", data)
return self.patch(f"/{object_id}", data).json()

def post_metadata(self, object_type: str, data: str) -> Optional[dict]:
if self._key:
return post_metadata(schema_name=object_type, post_item=data, key=self._key)
return self.post(f"/{object_type}", data)
return self.post(f"/{object_type}", data).json()

def get(self, uri: str, follow: bool = True, **kwargs) -> Optional[Union[RequestResponse, TestResponse]]:
if self._vapp:
response = self._vapp.get(self.url(uri), **self._kwargs(**kwargs))
if response and response.status_code in [301, 302, 303, 307, 308] and follow:
response = response.follow()
return self._response(response)
return requests.get(self.url(uri), allow_redirects=follow, **self._kwargs(**kwargs))
if not self._vapp:
return requests.get(self.url(uri), allow_redirects=follow, **self._kwargs(**kwargs))
response = self._vapp.get(self.url(uri), **self._kwargs(**kwargs))
if response and response.status_code in [301, 302, 303, 307, 308] and follow:
response = response.follow()
return self._response(response)

def patch(self, uri: str, data: Optional[dict] = None,
json: Optional[dict] = None, **kwargs) -> Optional[Union[RequestResponse, TestResponse]]:
if self._vapp:
return self._vapp.patch_json(self.url(uri), json or data, **self._kwargs(**kwargs))
return requests.patch(self.url(uri), data=data, json=json, **self._kwargs(**kwargs))
if not self._vapp:
return requests.patch(self.url(uri), data=data, json=json, **self._kwargs(**kwargs))
return self._response(self._vapp.patch_json(self.url(uri), json or data, **self._kwargs(**kwargs)))

def post(self, uri: str, data: Optional[dict] = None, json: Optional[dict] = None,
files: Optional[dict] = None, **kwargs) -> Optional[Union[RequestResponse, TestResponse]]:
if self._vapp:
if files:
return self._vapp.post(self.url(uri), json or data, upload_files=files, **self._kwargs(**kwargs))
else:
return self._vapp.post_json(self.url(uri), json or data, upload_files=files, **self._kwargs(**kwargs))
return requests.post(self.url(uri), data=data, json=json, files=files, **self._kwargs(**kwargs))
if not self._vapp:
return requests.post(self.url(uri), data=data, json=json, files=files, **self._kwargs(**kwargs))
if files:
response = self._vapp.post(self.url(uri), json or data, upload_files=files, **self._kwargs(**kwargs))
else:
response = self._vapp.post_json(self.url(uri), json or data, upload_files=files, **self._kwargs(**kwargs))
return self._response(response)

def get_schema(self, schema_name: str) -> Optional[dict]:
return get_schema(self.schema_name(schema_name), portal_vapp=self._vapp, key=self._key)
Expand Down Expand Up @@ -311,7 +316,7 @@ def infer_app_from_env(env: str) -> Optional[str]: # noqa
if is_valid_app(app) or (app := infer_app_from_env(env)):
return os.path.join(Portal.KEYS_FILE_DIRECTORY, f".{app.lower()}-keys.json")

def _response(self, response) -> Optional[RequestResponse]:
def _response(self, response: TestResponse) -> Optional[RequestResponse]:
if response and isinstance(getattr(response.__class__, "json"), property):
class RequestResponseWrapper: # For consistency change json property to method.
def __init__(self, response, **kwargs):
Expand All @@ -325,51 +330,72 @@ def json(self): # noqa
return response

@staticmethod
def create_for_testing(ini_file: Optional[str] = None) -> Portal:
if isinstance(ini_file, str):
return Portal(Portal._create_testapp(ini_file))
minimal_ini_for_unit_testing = "[app:app]\nuse = egg:encoded\nsqlalchemy.url = postgresql://dummy\n"
with temporary_file(content=minimal_ini_for_unit_testing, suffix=".ini") as ini_file:
return Portal(Portal._create_testapp(ini_file))

@staticmethod
def create_for_testing_local(ini_file: Optional[str] = None) -> Portal:
if isinstance(ini_file, str) and ini_file:
return Portal(Portal._create_testapp(ini_file))
minimal_ini_for_testing_local = "\n".join([
"[app:app]\nuse = egg:encoded\nfile_upload_bucket = dummy",
"sqlalchemy.url = postgresql://postgres@localhost:5441/postgres?host=/tmp/snovault/pgdata",
"multiauth.groupfinder = encoded.authorization.smaht_groupfinder",
"multiauth.policies = auth0 session remoteuser accesskey",
"multiauth.policy.session.namespace = mailto",
"multiauth.policy.session.use = encoded.authentication.NamespacedAuthenticationPolicy",
"multiauth.policy.session.base = pyramid.authentication.SessionAuthenticationPolicy",
"multiauth.policy.remoteuser.namespace = remoteuser",
"multiauth.policy.remoteuser.use = encoded.authentication.NamespacedAuthenticationPolicy",
"multiauth.policy.remoteuser.base = pyramid.authentication.RemoteUserAuthenticationPolicy",
"multiauth.policy.accesskey.namespace = accesskey",
"multiauth.policy.accesskey.use = encoded.authentication.NamespacedAuthenticationPolicy",
"multiauth.policy.accesskey.base = encoded.authentication.BasicAuthAuthenticationPolicy",
"multiauth.policy.accesskey.check = encoded.authentication.basic_auth_check",
"multiauth.policy.auth0.use = encoded.authentication.NamespacedAuthenticationPolicy",
"multiauth.policy.auth0.namespace = auth0",
"multiauth.policy.auth0.base = encoded.authentication.Auth0AuthenticationPolicy"
])
with temporary_file(content=minimal_ini_for_testing_local, suffix=".ini") as minimal_ini_file:
return Portal(Portal._create_testapp(minimal_ini_file))
def create_for_testing(arg: Optional[Union[str, bool, List[dict], dict, Callable]] = None) -> Portal:
if isinstance(arg, list) or isinstance(arg, dict) or isinstance(arg, Callable):
return Portal(Portal._create_router_for_testing(arg))
if isinstance(arg, str) and arg.endswith(".ini"):
return Portal(Portal._create_vapp(arg))
if arg == "local" or arg is True:
minimal_ini_for_testing = "\n".join([
"[app:app]\nuse = egg:encoded\nfile_upload_bucket = dummy",
"sqlalchemy.url = postgresql://postgres@localhost:5441/postgres?host=/tmp/snovault/pgdata",
"multiauth.groupfinder = encoded.authorization.smaht_groupfinder",
"multiauth.policies = auth0 session remoteuser accesskey",
"multiauth.policy.session.namespace = mailto",
"multiauth.policy.session.use = encoded.authentication.NamespacedAuthenticationPolicy",
"multiauth.policy.session.base = pyramid.authentication.SessionAuthenticationPolicy",
"multiauth.policy.remoteuser.namespace = remoteuser",
"multiauth.policy.remoteuser.use = encoded.authentication.NamespacedAuthenticationPolicy",
"multiauth.policy.remoteuser.base = pyramid.authentication.RemoteUserAuthenticationPolicy",
"multiauth.policy.accesskey.namespace = accesskey",
"multiauth.policy.accesskey.use = encoded.authentication.NamespacedAuthenticationPolicy",
"multiauth.policy.accesskey.base = encoded.authentication.BasicAuthAuthenticationPolicy",
"multiauth.policy.accesskey.check = encoded.authentication.basic_auth_check",
"multiauth.policy.auth0.use = encoded.authentication.NamespacedAuthenticationPolicy",
"multiauth.policy.auth0.namespace = auth0",
"multiauth.policy.auth0.base = encoded.authentication.Auth0AuthenticationPolicy"
])
else:
minimal_ini_for_testing = "[app:app]\nuse = egg:encoded\nsqlalchemy.url = postgresql://dummy\n"
with temporary_file(content=minimal_ini_for_testing, suffix=".ini") as ini_file:
return Portal(Portal._create_vapp(ini_file))

@staticmethod
def _create_testapp(arg: Union[TestApp, VirtualApp, Router, str] = None, app_name: Optional[str] = None) -> TestApp:
def _create_vapp(arg: Union[TestApp, VirtualApp, PyramidRouter, str] = None) -> TestApp:
if isinstance(arg, TestApp):
return arg
elif isinstance(arg, VirtualApp):
if not isinstance(arg.wrapped_app, TestApp):
raise Exception("Portal._create_testapp VirtualApp argument error.")
raise Exception("Portal._create_vapp VirtualApp argument error.")
return arg.wrapped_app
if isinstance(arg, Router):
if isinstance(arg, PyramidRouter):
router = arg
elif isinstance(arg, str) or arg is None:
router = get_app(arg or "development.ini", app_name or "app")
elif isinstance(arg, str) or not arg:
router = pyramid_get_app(arg or "development.ini", "app")
else:
raise Exception("Portal._create_testapp argument error.")
raise Exception("Portal._create_vapp argument error.")
return TestApp(router, {"HTTP_ACCEPT": "application/json", "REMOTE_USER": "TEST"})

@staticmethod
def _create_router_for_testing(endpoints: Optional[List[Dict[str, Union[str, Callable]]]] = None):
if isinstance(endpoints, dict):
endpoints = [endpoints]
elif isinstance(endpoints, Callable):
endpoints = [{"path": "/", "method": "GET", "function": endpoints}]
if not isinstance(endpoints, list) or not endpoints:
endpoints = [{"path": "/", "method": "GET", "function": lambda request: {"status": "OK"}}]
with PyramidConfigurator() as config:
nendpoints = 0
for endpoint in endpoints:
if (endpoint_path := endpoint.get("path")) and (endpoint_function := endpoint.get("function")):
endpoint_method = endpoint.get("method", "GET")
def endpoint_wrapper(request):
response = endpoint_function(request)
return PyramidResponse(json.dumps(response), content_type="application/json; charset=utf-8")
endpoint_id = str(uuid())
config.add_route(endpoint_id, endpoint_path)
config.add_view(endpoint_wrapper, route_name=endpoint_id, request_method=endpoint_method)
nendpoints += 1
if nendpoints == 0:
return Portal._create_router_for_testing([])
return config.make_wsgi_app()
11 changes: 4 additions & 7 deletions dcicutils/structured_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pyramid.router import Router
import re
import sys
from typing import Any, Callable, List, Optional, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from webtest.app import TestApp
from dcicutils.common import OrchestratedApp
from dcicutils.data_readers import CsvReader, Excel, RowReader
Expand Down Expand Up @@ -616,12 +616,9 @@ def _ref_exists_single(self, type_name: str, value: str) -> bool:
return self.get_metadata(f"/{type_name}/{value}") is not None

@staticmethod
def create_for_testing(ini_file: Optional[str] = None, schemas: Optional[List[dict]] = None) -> Portal:
return Portal(PortalBase.create_for_testing(ini_file), schemas=schemas)

@staticmethod
def create_for_testing_local(ini_file: Optional[str] = None, schemas: Optional[List[dict]] = None) -> Portal:
return Portal(PortalBase.create_for_testing_local(ini_file), schemas=schemas)
def create_for_testing(arg: Optional[Union[str, bool, List[dict], dict, Callable]] = None,
schemas: Optional[List[dict]] = None) -> Portal:
return Portal(PortalBase.create_for_testing(arg), schemas=schemas)


def _split_dotted_string(value: str):
Expand Down

0 comments on commit 14e9f9f

Please sign in to comment.