diff --git a/arxiv/auth/README.md b/arxiv/auth/README.md new file mode 100644 index 00000000..56698f3b --- /dev/null +++ b/arxiv/auth/README.md @@ -0,0 +1,57 @@ +# ``arxiv-auth`` Library + +This package provides a Flask add on and other code for working with arxiv +authenticated users in arXiv services. + +Housing these components in arxiv-base ensures that users +and sessions are represented and manipulated consistently. The login+logout, user +accounts(TBD), API client registry(TBD), and authenticator(TBD) services all +rely on this package. + +# Quick start +For use-cases to check if a request is from an authenticated arxiv user, do the +following: + +1. Add arxiv-base to your dependencies +2. Install :class:`arxiv.auth.auth.Auth` onto your application. This adds a + function called for each request to Flask that adds an instance of + :class:`arxiv.auth.domain.Session` at ``flask.request.auth`` if the client is + authenticated. +3. Add to the ``flask.config`` to setup :class:`arxiv_auth.auth.Auth` and + related classes + +Here's an example of how you might do #2 and #3: +``` + from flask import Flask + from arxiv.base import Base + from arxiv.auth.auth import auth + + app = Flask(__name__) + Base(app) + + # config settings required to use legacy auth + app.config['CLASSIC_SESSION_HASH'] = '{hash_private_secret}' + app.config['CLASSIC_DB_URI'] = '{your_sqlalchemy_db_uri_to_legacy_db'} + app.config['SESSION_DURATION'] = 36000 + app.config['CLASSIC_COOKIE_NAME'] = 'tapir_session' + + auth.Auth(app) # <- Install the Auth to get auth checks and request.auth + + @app.route("/") + def are_you_logged_in(): + if request.auth is not None: + return "

Hello, You are logged in.

" + else: + return "

Hello unknown client.

" +``` + +# Middleware + +In during NG there was middleware for arxiv-auth that could be used in NGINX to +do the authentication there. As of 2023 it is not in use. + +See :class:`arxiv.auth.auth.middleware.AuthMiddleware` + +If you are not deploying this application in the cloud behind NGINX (and +therefore will not support sessions from the distributed store), you do not +need the auth middleware. diff --git a/arxiv/auth/__init__.py b/arxiv/auth/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/arxiv/auth/app.py b/arxiv/auth/app.py new file mode 100644 index 00000000..6046b2e4 --- /dev/null +++ b/arxiv/auth/app.py @@ -0,0 +1,10 @@ +"""test app.""" + +import sys +sys.path.append('./arxiv') + +from flask import Flask +from arxiv_users import auth, legacy + +app = Flask('test') +legacy.create_all() diff --git a/arxiv/auth/auth/__init__.py b/arxiv/auth/auth/__init__.py new file mode 100644 index 00000000..8d91c78c --- /dev/null +++ b/arxiv/auth/auth/__init__.py @@ -0,0 +1,199 @@ +"""Provides tools for working with authenticated user/client sessions.""" + +from typing import Optional, Union, Any, List +import os + +from werkzeug.datastructures.structures import MultiDict + +from flask import Flask, request, Response +from retry import retry + +from ...db import transaction +from ..legacy import util +from ..legacy.cookies import parse_cookie +from .. import domain, legacy + +import logging + +logger = logging.getLogger(__name__) + + +class Auth(object): + """ + Attaches session and authentication information to the request. + + Set env var or `Flask.config` `ARXIV_AUTH_DEBUG` to True to get + additional debugging in the logs. Only use this for short term debugging of + configs. This may be used in produciton but should not be left on in production. + + Intended for use in a Flask application factory, for example: + + .. code-block:: python + + from flask import Flask + from arxiv.users.auth import Auth + from someapp import routes + + + def create_web_app() -> Flask: + app = Flask('someapp') + app.config.from_pyfile('config.py') + Auth(app) # Registers the before_reques auth check + + @app.route("/hello") + def hello(): + if request.auth: + return f"Hello {request.auth.user.name}!" + else: + return f"Hello world! (not authenticated)" + + return app + + + """ + + def __init__(self, app: Optional[Flask] = None) -> None: + """ + Initialize ``app`` with `Auth`. + + Parameters + ---------- + app : :class:`Flask` + + """ + if app is not None: + self.init_app(app) + if self.app.config.get('AUTH_UPDATED_SESSION_REF'): + self.auth_session_name = "auth" + else: + self.auth_session_name = "session" + + @retry(legacy.exceptions.Unavailable, tries=3, delay=0.5, backoff=2) + def _get_legacy_session(self, + cookie_value: str) -> Optional[domain.Session]: + """ + Attempt to load a legacy auth session. + + Returns + ------- + :class:`domain.Session` or None + + """ + if cookie_value is None: + return None + try: + with transaction(): + return legacy.sessions.load(cookie_value) + except legacy.exceptions.UnknownSession as e: + logger.debug('No legacy session available: %s', e) + except legacy.exceptions.InvalidCookie as e: + logger.debug('Invalid legacy cookie: %s', e) + except legacy.exceptions.SessionExpired as e: + logger.debug('Legacy session is expired: %s', e) + return None + + def init_app(self, app: Flask) -> None: + """ + Attach :meth:`.load_session` to the Flask app. + + Parameters + ---------- + app : :class:`Flask` + + """ + self.app = app + app.config['arxiv_auth.Auth'] = self + + if app.config.get('ARXIV_AUTH_DEBUG') or os.getenv('ARXIV_AUTH_DEBUG'): + self.auth_debug() + logger.debug("ARXIV_AUTH_DEBUG is set and auth debug messages to logging are turned on") + + self.app.before_request(self.load_session) + self.app.config.setdefault('DEFAULT_LOGOUT_REDIRECT_URL', + 'https://arxiv.org') + self.app.config.setdefault('DEFAULT_LOGIN_REDIRECT_URL', + 'https://arxiv.org') + + if app.config.get('ARXIV_AUTH_DEBUG') or os.getenv('ARXIV_AUTH_DEBUG'): + self.auth_debug() + logger.debug("ARXIV_AUTH_DEBUG is set and auth debug messages to logging is turned on") + + + def load_session(self) -> Optional[Response]: + """Look for an active session, and attach it to the request. + + The typical scenario will involve the + :class:`.middleware.AuthMiddleware` unpacking a session token and + adding it to the WSGI request environ. + + As a fallback, if the legacy database is available, this method will + also attempt to load an active legacy session. + + """ + # Check the WSGI request environ for the key, which is where the auth + # middleware puts any unpacked auth information from the request OR any + # exceptions that need to be raised withing the request context. + req_auth: Optional[Union[domain.Session, Exception]] = \ + request.environ.get(self.auth_session_name) + + # Middlware may raise exception, needs to be raised in to be handled correctly. + if isinstance(req_auth, Exception): + logger.debug('Middleware passed an exception: %s', req_auth) + raise req_auth + + if not req_auth: + if util.is_configured(): + req_auth = self.first_valid(self.legacy_cookies()) + else: + logger.warning('No legacy DB, will not check tapir auth.') + + # Attach auth to the request so other can access easily. request.auth + setattr(request, self.auth_session_name, req_auth) + return None + + def first_valid(self, cookies: List[str]) -> Optional[domain.Session]: + """First valid legacy session or None if there are none.""" + first = next(filter(bool, + map(self._get_legacy_session, + cookies)), None) + + if first is None: + logger.debug("Out of %d cookies, no legacy cookie found", len(cookies)) + else: + logger.debug("Out of %d cookies, found a good legacy cookie", len(cookies)) + + return first + + def legacy_cookies(self) -> List[str]: + """Gets list of legacy cookies. + + Duplicate cookies occur due to the browser sending both the + cookies for both arxiv.org and sub.arxiv.org. If this is being + served at sub.arxiv.org, there is no response that will cause + the browser to alter its cookie store for arxiv.org. Duplicate + cookies must be handled gracefully to for the domain and + subdomain to coexist. + + The standard way to avoid this problem is to append part of + the domain's name to the cookie key but this needs to work + even if the configuration is not ideal. + + """ + # By default, werkzeug uses a dict-based struct that supports only a + # single value per key. This isn't really up to speed with RFC 6265. + # Luckily we can just pass in an alternate struct to parse_cookie() + # that can cope with multiple values. + raw_cookie = request.environ.get('HTTP_COOKIE', None) + if raw_cookie is None: + return [] + cookies = parse_cookie(raw_cookie, cls=MultiDict) + return cookies.getlist(self.app.config['CLASSIC_COOKIE_NAME']) + + def auth_debug(self) -> None: + """Sets several auth loggers to DEBUG. + + This is useful to get an idea of what is going on with auth. + """ + logger.setLevel(logging.DEBUG) + legacy.sessions.logger.setLevel(logging.DEBUG) + legacy.authenticate.logger.setLevel(logging.DEBUG) diff --git a/arxiv/auth/auth/decorators.py b/arxiv/auth/auth/decorators.py new file mode 100644 index 00000000..4d4ddb9c --- /dev/null +++ b/arxiv/auth/auth/decorators.py @@ -0,0 +1,248 @@ +""" +Scope-based authorization of user/client requests. + +This module provides :func:`scoped`, a decorator factory used to protect Flask +routes for which authorization is required. This is done by specifying a +required authorization scope (see :mod:`arxiv.users.auth.scopes`) and/or by +providing a custom authorizer function. + +For routes that involve specific resources, a ``resource`` callback should also +be provided. That callback function should accept the same arguments as the +route function, and return the identifier for the resource as a string. + +Using :func:`scoped` with an authorizer function allows you to define +application-specific authorization logic on a per-request basis without adding +complexity to request controllers. The call signature of the authorizer +function should be: ``(session: domain.Session, *args, **kwargs) -> bool``, +where `*args` and `**kwargs` are the positional and keyword arguments, +respectively, passed by Flask to the decorated route function (e.g. the +URL parameters). + +.. note:: The authorizer function is only called if the session does not have + a global or resource-specific instance of the required scope, or if a + required scope is not specified. + +Here's an example of how you might use this in a Flask application: + +.. code-block:: python + + from arxiv.users.auth.decorators import scoped + from arxiv.users.auth import scopes + from arxiv.users import domain + + + def is_owner(session: domain.Session, user_id: str, **kwargs) -> bool: + '''Check whether the authenticated user matches the requested user.''' + return session.user.user_id == user_id + + + def get_resource_id(user_id: str) -> str: + '''Get the user ID from the request.''' + return user_id + + + def redirect_to_login(user_id: str) -> Response: + '''Send the unauthorized user to the log in page.''' + return url_for('login') + + + @blueprint.route('//profile', methods=['GET']) + @scoped(scopes.EDIT_PROFILE, resource=get_resource_id, + authorizer=user_is_owner, unauthorized=redirect_to_login) + def edit_profile(user_id: str): + '''User can update their account information.''' + data, code, headers = profile.get_profile(user_id) + return render_template('accounts/profile.html', **data) + + +When the decorated route function is called... + +- If no session is available from either the middleware or the legacy database, + the ``unauthorized`` callback is called, and/or :class:`Unauthorized` + exception is raised. +- If a required scope was provided, the session is checked for the presence of + that scope in this order: + + - Global scope (`:*`), e.g. for administrators. + - Resource-specific scope (`:[resource_id]`), i.e. explicitly granted for a + particular resource. + - Generic scope (no resource part). + +- If an authorization function was provided, the function is called only if + a required scope was not provided, or if only the generic scope was found. +- Session data is added directly to the Flask request object as + ``request.auth``, for ease of access elsewhere in the application. +- Finally, if no exceptions have been raised, the route is called with the + original parameters. + +""" + +from typing import Optional, Callable, Any, List +from functools import wraps +from flask import request +from werkzeug.exceptions import Unauthorized, Forbidden +import logging + +from .. import domain + +INVALID_TOKEN = {'reason': 'Invalid authorization token'} +INVALID_SCOPE = {'reason': 'Token not authorized for this action'} + + +logger = logging.getLogger(__name__) +logger.propagate = False + + +def scoped(required: Optional[domain.Scope] = None, + resource: Optional[Callable] = None, + authorizer: Optional[Callable] = None, + unauthorized: Optional[Callable] = None) -> Callable: + """ + Generate a decorator to enforce authorization requirements. + + Parameters + ---------- + required : str + The scope required on a user or client session in order use the + decorated route. See :mod:`arxiv.users.auth.scopes`. If not provided, + no scope will be enforced. + resource : function + If a route provides actions for a specific resource, a callable should + be provided that accepts the route arguments and returns the resource + identifier (str). + authorizer : function + In addition, an authorizer function may be passed to provide more + specific authorization checks. For example, this function may check + that the requesting user is the owner of a resource. Should have the + signature: ``(session: domain.Session, *args, **kwargs) -> bool``. + ``*args`` and ``**kwargs`` are the parameters passed to the decorated + function. If the authorizer returns ``False``, an :class:`.` + exception is raised. + unauthorized : function + DEPRECATED: Do not use this parameter. It is likely it does not do + what you expect. + A callback may be passed to handle cases in which the request is + unauthorized. This function will be passed the same arguments as the + original route function. If the callback returns anything other than + ``None``, the return value will be treated as a response and execution + will stop. Otherwise, an ``Unauthorized`` exception will be raised. + If a callback is not provided (default) an ``Unauthorized`` exception + will be raised. + + Returns + ------- + function + A decorator that enforces the required scope and calls the (optionally) + provided authorizer. + + """ + if required and not isinstance(required, domain.Scope): + required = domain.Scope(required) + + def protector(func: Callable) -> Callable: + """Decorator that provides scope enforcement.""" + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + """ + Check the authorization token before executing the method. + + Will also raise exceptions passed by the auth middleware. + + Raises + ------ + :class:`.Unauthorized` + Raised when session data is not available. + :class:`.` + Raised when the session has insufficient auth scope, or the + provided authorizer returns ``False``. + + """ + if hasattr(request, 'auth'): + session = request.auth + elif hasattr(request, 'session'): + session = request.session + else: + raise Unauthorized('No active session on request') + scopes: List[domain.Scope] = [] + authorized: bool = False + logger.debug('Required: %s, authorizer: %s, unauthorized: %s', + required, authorizer, unauthorized) + # Use of the decorator implies that an auth session ought to be + # present. So we'll complain here if it's not. + if not session or not (session.user or session.client): + logger.debug('No valid session; aborting') + if unauthorized is not None: + response = unauthorized(*args, **kwargs) + if response is not None: + return response + raise Unauthorized('Not a valid session') + + if session.authorizations is not None: + scopes = session.authorizations.scopes + logger.debug('session has scopes: %s', scopes) + + # If a required scope is provided, we first check to see whether + # the session globally or explicitly authorizes the request. We + # then fall back to the locally-defined authorizer function if it + # is provided. + if required and scopes: + # A global scope is usually granted to administrators, or + # perhaps moderators (e.g. view submission content). + # For example: `submission:read:*`. + if required.as_global() in scopes: + logger.debug('Authorized with global scope') + authorized = True + + # A resource-specific scope may be granted at the auth layer. + # For example, an admin may provide provisional access to a + # specific resource for a specific role. This kind of + # authorization is only supported if the service provides a + # ``resource()`` callback to get the resource identifier. + elif (resource is not None + and ( + required.for_resource(str(resource(*args, **kwargs))) + in scopes)): + logger.debug('Authorized by specific resource') + authorized = True + + # If both the global and resource-specific scope authorization + # fail, then we look for the general scope in the session. + elif required in scopes: + logger.debug('Required scope is present') + # If an authorizer callback is provided by the service, + # then we will enforce whatever it returns. + if authorizer: + authorized = authorizer(session, *args, **kwargs) + logger.debug('Authorizer func returned %s', authorized) + # If no authorizer callback is provided, it is implied that + # the general scope is sufficient to authorize the request. + elif authorizer is None: + logger.debug('No authorizer func provided') + authorized = True + # The required scope is not present. There is nothing left to + # check. + else: + logger.debug('Required scope is not present') + authorized = False + + elif required is None and authorizer is None: + logger.debug('No scope required, no authorizer function;' + ' request is authorized.') + authorized = True + + # If a specific scope is not required, we rely entirely on the + # authorizer callback. + elif authorizer is not None: + logger.debug('Calling authorizer callback') + authorized = authorizer(session, *args, **kwargs) + else: + logger.debug('No authorization path available') + + if not authorized: + logger.debug('Session is not authorized') + raise Forbidden('Access denied') + + logger.debug('Request is authorized, proceeding') + return func(*args, **kwargs) + return wrapper + return protector diff --git a/arxiv/auth/auth/exceptions.py b/arxiv/auth/auth/exceptions.py new file mode 100644 index 00000000..e15deccd --- /dev/null +++ b/arxiv/auth/auth/exceptions.py @@ -0,0 +1,29 @@ +"""Authn/z-related exceptions raised by components in this module.""" + + +class InvalidToken(ValueError): + """Token in request is not valid.""" + + +class MissingToken(ValueError): + """No token found in request.""" + + +class ConfigurationError(RuntimeError): + """The application is not configured correctly.""" + + +class SessionCreationFailed(RuntimeError): + """Failed to create a session in the session store.""" + + +class SessionDeletionFailed(RuntimeError): + """Failed to delete a session in the session store.""" + + +class UnknownSession(RuntimeError): + """Failed to locate a session in the session store.""" + + +class ExpiredToken(RuntimeError): + """An expired token was passed.""" diff --git a/arxiv/auth/auth/middleware.py b/arxiv/auth/auth/middleware.py new file mode 100644 index 00000000..be4c63db --- /dev/null +++ b/arxiv/auth/auth/middleware.py @@ -0,0 +1,105 @@ +""" +Middleware for interpreting authn/z information on requestsself. + +This module provides :class:`AuthMiddleware`, which unpacks encrypted JSON +Web Tokens provided via the ``Authorization`` header. This is intended to +support requests that have been pre-authorized by the web server using the +authenticator service (see :mod:`authenticator`). + +The configuration parameter ``JWT_SECRET`` must be set in the WSGI request +environ (e.g. Apache's SetEnv) or in the runtime environment. This must be +the same secret that was used by the authenticator service to mint the token. + +To install the middleware, use the pattern described in +:mod:`arxiv.base.middleware`. For example: + +.. code-block:: python + + from arxiv.base import Base + from arxiv.base.middleware import wrap + from arxiv.users import auth + + + def create_web_app() -> Flask: + app = Flask('foo') + Base(app) + auth.Auth(app) + wrap(app, [auth.middleware.AuthMiddleware]) + return app + + +For convenience, this is intended to be used with +:mod:`arxiv.users.auth.decorators`. + +""" + +import os +from typing import Callable, Iterable, Tuple +import jwt +import logging + +from werkzeug.exceptions import Unauthorized, InternalServerError + +from arxiv.base.middleware import BaseMiddleware + +from . import tokens +from .exceptions import InvalidToken, ConfigurationError, MissingToken +from .. import domain + +logger = logging.getLogger(__name__) + +WSGIRequest = Tuple[dict, Callable] + + +class AuthMiddleware(BaseMiddleware): + """ + Middleware to handle auth information on requests. + + Before the request is handled by the application, the ``Authorization`` + header is parsed for an encrypted JWT. If successfully decrypted, + information about the user and their authorization scope is attached + to the request. + + This can be accessed in the application via + ``flask.request.environ['session']``. If Authorization header was not + included, then that value will be ``None``. + + If the JWT could not be decrypted, the value will be an + :class:`Unauthorized` exception instance. We cannot raise the exception + here, because the middleware is executed outside of the Flask application. + It's up to something running inside the application (e.g. + :func:`arxiv.users.auth.decorators.scoped`) to raise the exception. + + """ + + def before(self, environ: dict, start_response: Callable) -> WSGIRequest: + """Decode and unpack the auth token on the request.""" + environ['auth'] = None # Create the session key, at a minimum. + environ['token'] = None + token = environ.get('HTTP_AUTHORIZATION') # We may not have a token. + if token is None: + logger.debug('No auth token') + return environ, start_response + + # The token secret should be set in the WSGI environ, or in os.environ. + secret = environ.get('JWT_SECRET', os.environ.get('JWT_SECRET')) + if secret is None: + raise ConfigurationError('Missing decryption token') + + try: + # Try to verify the token in the Authorization header, and attach + # the decoded session data to the request. + session: domain.Session = tokens.decode(token, secret) + environ['auth'] = session + + # Attach the encrypted token so that we can use it in subrequests. + environ['token'] = token + except InvalidToken as e: # Let the application decide what to do. + logger.debug(f'Auth token not valid: {token}') + exception = Unauthorized('Invalid auth token') + environ['auth'] = exception + except Exception as e: + logger.error(f'Unhandled exception: {e}') + exception = InternalServerError(f'Unhandled: {e}') # type: ignore + environ['auth'] = exception + return environ, start_response diff --git a/arxiv/auth/auth/scopes.py b/arxiv/auth/auth/scopes.py new file mode 100644 index 00000000..2c56858a --- /dev/null +++ b/arxiv/auth/auth/scopes.py @@ -0,0 +1,241 @@ +""" +Authorization scopes for arXiv users and clients. + +The concept of authorization scope comes from OAuth 2.0 (`RFC 6749 ยง3.3 +`_). For a nice primer, see +`this blog post `_. The basic idea is that +the authorization associated with an access token can be limited, e.g. to +limit what actions an API client can take on behalf of a user. + +In this package, the scope concept is applied to both API client and end-user +sessions. When the session is created, we consult the relevant bits of data in +our system (e.g. what roles the user has, what privileges are associated with +those roles) to determine what the user is authorized to do. Those privileges +are attached to the user's session as authorization scopes. + +This module simply defines a set of constants (str) that can be used to refer +to specific authorization scopes. Rather than refer to scopes by writing new +str objects, these constants should be imported and used. This improves the +semantics of code, and reduces the risk of programming errors. For an example, +see :mod:`arxiv.users.auth.decorators`. + +""" +from typing import Optional, Dict +from ..domain import Scope + +def _scope_str(domain:str, action:str, resource:str = None): + return str(Scope(domain=domain, action=action, resource=resource)) + + +class domains: + """Known authorization domains.""" + + PUBLIC = 'public' + """The public arXiv site, including APIs.""" + PROFILE = 'profile' + """arXiv user profile.""" + SUBMISSION = 'submission' + """Submission interfaces and actions.""" + UPLOAD = 'upload' + """File uploads, including those for submissions.""" + COMPILE = 'compile' + """PDF compilation.""" + FULLTEXT = 'fulltext' + """Fulltext extraction.""" + PREVIEW = 'preview' + """Submission previews.""" + +class actions: + """Known authorization actions.""" + + UPDATE = 'update' + CREATE = 'create' + DELETE = 'delete' + RELEASE = 'release' + READ = 'read' + PROXY = 'proxy' + +READ_PUBLIC = _scope_str(domains.PUBLIC, actions.READ) +""" +Authorizes access to public endpoints. + +This is implicitly granted to all anonymous users. For endpoints requiring +authentication (e.g. APIs) this scope can be used to denote baseline read +access for clients. +""" + +EDIT_PROFILE = _scope_str(domains.PROFILE, actions.UPDATE) +""" +Authorizes editing user profile. + +This includes things like affiliation, full name, etc.. +""" + +VIEW_PROFILE = _scope_str(domains.PROFILE, actions.READ) +""" +Authorizes viewing the content of a user profile. + +This includes things like affiliation, full name, and e-mail address. +""" + +CREATE_SUBMISSION = _scope_str(domains.SUBMISSION, actions.CREATE) +"""Authorizes creating a new submission.""" + +EDIT_SUBMISSION = _scope_str(domains.SUBMISSION, actions.UPDATE) +"""Authorizes updating a submission that has not yet been announced.""" + +VIEW_SUBMISSION = _scope_str(domains.SUBMISSION, actions.READ) +"""Authorizes viewing a submission.""" + +DELETE_SUBMISSION = _scope_str(domains.SUBMISSION, actions.DELETE) +"""Authorizes deleting a submission.""" + +PROXY_SUBMISSION = _scope_str(domains.SUBMISSION, actions.PROXY) +""" +Authorizes creating a submission on behalf of another user. + +This authorization is specifically for human users submitting on behalf of +other human users. For client authorization to submit on behalf of a user, +submission:create should be used instead. +""" + +READ_UPLOAD = _scope_str(domains.UPLOAD, actions.READ) +"""Authorizes viewing the content of an upload workspace.""" + +WRITE_UPLOAD = _scope_str(domains.UPLOAD, actions.UPDATE) +"""Authorizes uploading files to to a workspace.""" + +RELEASE_UPLOAD = _scope_str(domains.UPLOAD, actions.RELEASE) +"""Authorizes releasing an upload workspace.""" + +DELETE_UPLOAD_WORKSPACE = _scope_str(domains.UPLOAD, 'delete_workspace') +"""Can delete an entire workspace in the file management service.""" + +DELETE_UPLOAD_FILE = _scope_str(domains.UPLOAD, actions.DELETE) +"""Can delete files from a file management upload workspace.""" + +READ_UPLOAD_LOGS = _scope_str(domains.UPLOAD, 'read_logs') +"""Can read logs for a file management upload workspace.""" + +READ_UPLOAD_SERVICE_LOGS = _scope_str(domains.UPLOAD, 'read_service_logs') +"""Can read service logs in the file management service.""" + +CREATE_UPLOAD_CHECKPOINT = _scope_str(domains.UPLOAD, 'create_checkpoint') +"""Create an upload workspace checkpoint.""" + +DELETE_UPLOAD_CHECKPOINT = _scope_str(domains.UPLOAD, 'delete_checkpoint') +"""Delete an upload workspace checkpoint.""" + +READ_UPLOAD_CHECKPOINT = _scope_str(domains.UPLOAD, 'read_checkpoints') +"""Read from an upload workspace checkpoint.""" + +RESTORE_UPLOAD_CHECKPOINT = _scope_str(domains.UPLOAD, 'restore_checkpoint') +"""Restore an upload workspace to a checkpoint.""" + +READ_COMPILE = _scope_str(domains.COMPILE, actions.READ) +"""Can read documents generated by the compilation service.""" + +CREATE_COMPILE = _scope_str(domains.COMPILE, actions.CREATE) +"""Can create new documents via the compilation service.""" + +READ_FULLTEXT = _scope_str(domains.FULLTEXT, actions.READ) +"""Can access plain text extracted from compiled documents.""" + +CREATE_FULLTEXT = _scope_str(domains.FULLTEXT, actions.CREATE) +"""Can trigger new plain text extractions from compiled documents.""" + +READ_PREVIEW = _scope_str(domains.PREVIEW, actions.READ) +"""Can view a submission preview.""" + +CREATE_PREVIEW = _scope_str(domains.PREVIEW, actions.CREATE) +"""Can create a new submission preview.""" + +GENERAL_USER = [ + READ_PUBLIC, # Access to public APIs. + + # Profile management. + EDIT_PROFILE, + VIEW_PROFILE, + + # Ability to use the submission system. + CREATE_SUBMISSION, + EDIT_SUBMISSION, + VIEW_SUBMISSION, + DELETE_SUBMISSION, + + # Allows usage of the compilation service during submission. + READ_COMPILE, + CREATE_COMPILE, + + # Allows usage of the file management service during submission. + READ_UPLOAD, + WRITE_UPLOAD, + DELETE_UPLOAD_FILE, + + # Ability to create and view submission previews. + READ_PREVIEW, + CREATE_PREVIEW, +] +""" +The default scopes afforded to an authenticated user. + +This static list will be deprecated by role-based access control (RBAC) at a +later milestone of arXiv. +""" + +_ADMIN_USER = GENERAL_USER + [ + CREATE_UPLOAD_CHECKPOINT, + DELETE_UPLOAD_CHECKPOINT, + READ_UPLOAD_CHECKPOINT, + RESTORE_UPLOAD_CHECKPOINT, + READ_FULLTEXT, + CREATE_FULLTEXT, + DELETE_UPLOAD_WORKSPACE, + READ_UPLOAD_LOGS, + READ_UPLOAD_SERVICE_LOGS +] +ADMIN_USER = [str(Scope.from_str(scope).as_global()) for scope in _ADMIN_USER] +""" +Scopes afforded to an administrator. + +This static list will be deprecated by role-based access control (RBAC) at a +later milestone of arXiv. +""" + +_HUMAN_LABELS: Dict[Scope, str] = { + EDIT_PROFILE: "Grants authorization to change the contents of your user" + " profile. This includes your affiliation, preferred name," + " default submission category, etc.", + VIEW_PROFILE: "Grants authorization to view the contents of your user" + " profile. This includes your affiliation, preferred name," + " default submission category, etc.", + CREATE_SUBMISSION: "Grants authorization to submit papers on your behalf.", + EDIT_SUBMISSION: "Grants authorization to make changes to your submissions" + " that have not yet been announced. For example, to" + " update the DOI or journal reference field on your" + " behalf. Note that this affects only the metadata of" + " your submission, and not the content.", + VIEW_SUBMISSION: "Grants authorization to view your submissions, including" + " those that have not yet been announced. Note that this" + " only applies to the submission metadata, and not to the" + " uploaded submission content.", + READ_UPLOAD: "Grants authorization to view the contents of your uploads.", + WRITE_UPLOAD: "Grants authorization to add and delete files on your" + " behalf.", + READ_UPLOAD_LOGS: "Grants authorization to read logs of upload activity" + " related to your submissions.", + READ_COMPILE: "Grants authorization to read a compilation task, product," + " and any log output, related to your submissions.", + CREATE_COMPILE: "Grants authorization to compile your submission source" + " files, e.g. to produce a PDF.", + READ_PREVIEW: "Grants authorization to retrieve the preview (e.g. PDF) of" + " your submission.", + CREATE_PREVIEW: "Grants authorization to update the preview (e.g. PDF) of" + " your submission.", +} + + +def get_human_label(scope: str) -> Optional[str]: + """Get a human-readable label for a scope, for display to end users.""" + label: Optional[str] = _HUMAN_LABELS.get(scope, None) + return label diff --git a/arxiv/auth/auth/sessions/__init__.py b/arxiv/auth/auth/sessions/__init__.py new file mode 100644 index 00000000..0d2b82ab --- /dev/null +++ b/arxiv/auth/auth/sessions/__init__.py @@ -0,0 +1,12 @@ +""" +Integration with the distributed session store. + +In this implementation, we use a key-value store to hold session data +in JSON format. When a session is created, a JWT cookie value is +created that contains information sufficient to retrieve the session. + +See :mod:`.store`. + +""" + +from .store import SessionStore diff --git a/arxiv/auth/auth/sessions/store.py b/arxiv/auth/auth/sessions/store.py new file mode 100644 index 00000000..cd70eb36 --- /dev/null +++ b/arxiv/auth/auth/sessions/store.py @@ -0,0 +1,271 @@ +""" +Internal service API for the distributed session store. + +Used to create, delete, and verify user and client session. +""" + +import uuid +import random +from datetime import datetime, timedelta +import dateutil.parser +from pytz import timezone, UTC +import logging + +from typing import Optional, Union + +import redis +import rediscluster + +import jwt + +from .. import domain +from ..exceptions import SessionCreationFailed, InvalidToken, \ + SessionDeletionFailed, UnknownSession, ExpiredToken + +from arxiv.base.globals import get_application_config, get_application_global + +logger = logging.getLogger(__name__) +EASTERN = timezone('US/Eastern') + + +def _generate_nonce(length: int = 8) -> str: + return ''.join([str(random.randint(0, 9)) for i in range(length)]) + + +class SessionStore(object): + """ + Manages a connection to Redis. + + In fact, the StrictRedis instance is thread safe and connections are + attached at the time a command is executed. This class simply provides a + container for configuration. + + Pass fake=True to use FakeRedis for testing of development. + """ + + def __init__(self, host: str, port: int, db: int, secret: str, + duration: int = 7200, token: Optional[str] = None, + cluster: bool = True, fake: bool = False) -> None: + """Open the connection to Redis.""" + self._secret = secret + self._duration = duration + if fake: + logger.warning('Using FakeRedis') + import fakeredis # this is a dev dependency needed during testing + self.r = fakeredis.FakeStrictRedis() + else: + logger.debug('New Redis connection at %s, port %s', host, port) + if cluster: + self.r = rediscluster.StrictRedisCluster( + startup_nodes=[{'host': host, 'port': str(port)}], + skip_full_coverage_check=True + ) + else: + self.r = redis.StrictRedis(host=host, port=port) + + def create(self, authorizations: domain.Authorizations, + ip_address: str, remote_host: str, tracking_cookie: str = '', + user: Optional[domain.User] = None, + client: Optional[domain.Client] = None, + session_id: Optional[str] = None) -> domain.Session: + """ + Create a new session. + + Parameters + ---------- + authorizations : :class:`domain.Authorizations` + ip_address : str + remote_host : str + tracking_cookie : str + user : :class:`domain.User` + client : :class:`domain.Client` + + Returns + ------- + :class:`.Session` + + """ + if session_id is None: + session_id = str(uuid.uuid4()) + start_time = datetime.now(tz=UTC) + end_time = start_time + timedelta(seconds=self._duration) + session = domain.Session( + session_id=session_id, + user=user, + client=client, + start_time=start_time, + end_time=end_time, + authorizations=authorizations, + nonce=_generate_nonce() + ) + logger.debug('storing session %s', session) + try: + self.r.set(session_id, + jwt.encode(session.json_safe_dict(), self._secret), + ex=self._duration) + except redis.exceptions.ConnectionError as e: + raise SessionCreationFailed(f'Connection failed: {e}') from e + except Exception as e: + raise SessionCreationFailed(f'Failed to create: {e}') from e + + return session + + def generate_cookie(self, session: domain.Session) -> str: + """Generate a cookie from a :class:`domain.Session`.""" + if session.end_time is None: + raise RuntimeError('Session has no expiry') + if session.user is None: + raise RuntimeError('Session user is not set') + return self._pack_cookie({ + 'user_id': session.user.user_id, + 'session_id': session.session_id, + 'nonce': session.nonce, + 'expires': session.end_time.isoformat() + }) + + def delete(self, cookie: str) -> None: + """ + Delete a session. + + Parameters + ---------- + cookie : str + + """ + cookie_data = self._unpack_cookie(cookie) + self.delete_by_id(cookie_data['session_id']) + + def delete_by_id(self, session_id: str) -> None: + """ + Delete a session in the key-value store by ID. + + Parameters + ---------- + session_id : str + + """ + try: + self.r.delete(session_id) + except redis.exceptions.ConnectionError as e: + raise SessionDeletionFailed(f'Connection failed: {e}') from e + except Exception as e: + raise SessionDeletionFailed(f'Failed to delete: {e}') from e + + def validate_session_against_cookie(self, session: domain.Session, + cookie: str) -> None: + """ + Validate session data against a cookie. + + Parameters + ---------- + session : :class:`Session` + cookie : str + + Raises + ------ + :class:`InvalidToken` + Raised if the data in the cookie does not match the session data. + + """ + cookie_data = self._unpack_cookie(cookie) + if cookie_data['nonce'] != session.nonce \ + or session.user is None \ + or session.user.user_id != cookie_data['user_id']: + raise InvalidToken('Invalid token; likely a forgery') + + def load(self, cookie: str, decode: bool = True) \ + -> Union[domain.Session, str, bytes]: + """Load a session using a session cookie.""" + try: + cookie_data = self._unpack_cookie(cookie) + expires = dateutil.parser.parse(cookie_data['expires']) + except (KeyError, jwt.exceptions.DecodeError) as e: + raise InvalidToken('Token payload malformed') from e + + if expires <= datetime.now(tz=UTC): + raise InvalidToken('Session has expired') + + session = self.load_by_id(cookie_data['session_id'], decode=decode) + + if not decode: + assert isinstance(session, str) or isinstance(session, bytes) + return session + assert isinstance(session, domain.Session) + if session.expired: + raise ExpiredToken('Session has expired') + if session.user is None and session.client is None: + raise InvalidToken('Neither user nor client data are present') + + self.validate_session_against_cookie(session, cookie) + return session + + def load_by_id(self, session_id: str, decode: bool = True) \ + -> Union[domain.Session, str, bytes]: + """Get session data by session ID.""" + session_jwt: str = self.r.get(session_id) + if not session_jwt: + logger.debug(f'No such session: {session_id}') + raise UnknownSession(f'Failed to find session {session_id}') + if decode: + return self._decode(session_jwt) + return session_jwt + + def _encode(self, session_data: dict) -> bytes: + return jwt.encode(session_data, self._secret) + + def _decode(self, session_jwt: str) -> domain.Session: + try: + return domain.Session.parse_obj( + jwt.decode(session_jwt, self._secret, algorithms=['HS256'])) + except jwt.exceptions.InvalidSignatureError: + raise InvalidToken('Invalid or corrupted session token') + + def _unpack_cookie(self, cookie: str) -> dict: + secret = self._secret + try: + data = dict(jwt.decode(cookie, secret, algorithms=['HS256'])) + except jwt.exceptions.DecodeError as e: + raise InvalidToken('Session cookie is malformed') from e + return data + + def _pack_cookie(self, cookie_data: dict) -> str: + secret = self._secret + return jwt.encode(cookie_data, secret) + + @classmethod + def init_app(cls, app: object = None) -> None: + """Set default configuration parameters for an application instance.""" + config = get_application_config(app) + config.setdefault('REDIS_HOST', 'localhost') + config.setdefault('REDIS_PORT', '7000') + config.setdefault('REDIS_DATABASE', '0') + config.setdefault('REDIS_TOKEN', None) + config.setdefault('REDIS_CLUSTER', '1') + config.setdefault('JWT_SECRET', 'foosecret') + config.setdefault('SESSION_DURATION', '7200') + config.setdefault('REDIS_FAKE', False) + + @classmethod + def get_session(cls, app: object = None) -> 'SessionStore': + """Get a new session with the search index.""" + config = get_application_config(app) + host = config.get('REDIS_HOST', 'localhost') + port = int(config.get('REDIS_PORT', '7000')) + db = int(config.get('REDIS_DATABASE', '0')) + token = config.get('REDIS_TOKEN', None) + cluster = config.get('REDIS_CLUSTER', '1') == '1' + secret = config['JWT_SECRET'] + duration = int(config.get('SESSION_DURATION', '7200')) + fake = config.get('REDIS_FAKE', False) + return cls(host, port, db, secret, duration, token=token, + cluster=cluster, fake=fake) + + @classmethod + def current_session(cls) -> 'SessionStore': + """Get/create :class:`.SearchSession` for this context.""" + g = get_application_global() + if not g: + return cls.get_session() + if 'redis' not in g: + g.redis = cls.get_session() + return g.redis # type: ignore diff --git a/arxiv/auth/auth/sessions/tests/__init__.py b/arxiv/auth/auth/sessions/tests/__init__.py new file mode 100644 index 00000000..d4540d44 --- /dev/null +++ b/arxiv/auth/auth/sessions/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for :mod:`accounts.services.session_store`.""" diff --git a/arxiv/auth/auth/sessions/tests/test_integration.py b/arxiv/auth/auth/sessions/tests/test_integration.py new file mode 100644 index 00000000..2d001e4c --- /dev/null +++ b/arxiv/auth/auth/sessions/tests/test_integration.py @@ -0,0 +1,82 @@ +"""Integration tests for the session_store session store with Redis.""" + +from unittest import TestCase, mock +import jwt + +from .... import domain +from .. import store + + +class TestDistributedSessionServiceIntegration(TestCase): + """Test integration with Redis.""" + + @classmethod + def setUpClass(self): + self.secret = 'bazsecret' + + @mock.patch(f'{store.__name__}.get_application_config') + def test_store_create(self, mock_get_config): + """An entry should be created in Redis.""" + mock_get_config.return_value = { + 'JWT_SECRET': self.secret, + 'REDIS_FAKE': True + } + ip = '127.0.0.1' + remote_host = 'foo-host.foo.com' + user = domain.User( + user_id='1', + username='theuser', + email='the@user.com', + ) + authorizations = domain.Authorizations( + classic=2, + scopes=['foo:write'], + endorsements=[] + ) + s = store.SessionStore.current_session() + session = s.create(authorizations, ip, remote_host, user=user) + cookie = s.generate_cookie(session) + + # API still works as expected. + self.assertIsInstance(session, domain.Session) + self.assertTrue(bool(session.session_id)) + self.assertIsNotNone(cookie) + + r = s.r + raw = r.get(session.session_id) + stored_data = jwt.decode(raw, self.secret, algorithms=['HS256']) + cookie_data = jwt.decode(cookie, self.secret, algorithms=['HS256']) + self.assertEqual(stored_data['nonce'], cookie_data['nonce']) + + # def test_invalidate_session(self): + # """Invalidate a session from the datastore.""" + # r = rediscluster.StrictRedisCluster(startup_nodes=[dict(host='localhost', port='7000')]) + # data_in = {'end_time': time.time() + 30 * 60, 'user_id': 1, + # 'nonce': '123'} + # r.set('fookey', json.dumps(data_in)) + # data0 = json.loads(r.get('fookey')) + # now = time.time() + # self.assertGreaterEqual(data0['end_time'], now) + # store.invalidate( + # store.current_session()._pack_cookie({ + # 'session_id': 'fookey', + # 'nonce': '123', + # 'user_id': 1 + # }) + # ) + # data1 = json.loads(r.get('fookey')) + # now = time.time() + # self.assertGreaterEqual(now, data1['end_time']) + + @mock.patch(f'{store.__name__}.get_application_config') + def test_delete_session(self, mock_get_config): + """Delete a session from the datastore.""" + mock_get_config.return_value = { + 'JWT_SECRET': self.secret, + 'REDIS_FAKE': True + } + s = store.SessionStore.current_session() + r = s.r + r.set('fookey', b'foovalue') + s.delete_by_id('fookey') + self.assertIsNone(r.get('fookey')) diff --git a/arxiv/auth/auth/sessions/tests/test_unit.py b/arxiv/auth/auth/sessions/tests/test_unit.py new file mode 100644 index 00000000..d5b82490 --- /dev/null +++ b/arxiv/auth/auth/sessions/tests/test_unit.py @@ -0,0 +1,319 @@ +"""Tests for :mod:`arxiv.users.auth.sessions.store`.""" + +from unittest import TestCase, mock +import time +import jwt +import json +from datetime import datetime, timedelta +from pytz import timezone, UTC +from redis.exceptions import ConnectionError + +from .... import domain +from .. import store + +EASTERN = timezone('US/Eastern') + + +class TestDistributedSessionService(TestCase): + """The store session service puts sessions in a key-value store.""" + + @mock.patch(f'{store.__name__}.get_application_config') + @mock.patch(f'{store.__name__}.rediscluster') + def test_create(self, mock_redis, mock_get_config): + """Accept a :class:`.User` and returns a :class:`.Session`.""" + mock_get_config.return_value = {'JWT_SECRET': 'foosecret'} + mock_redis.exceptions.ConnectionError = ConnectionError + mock_redis_connection = mock.MagicMock() + mock_redis.StrictRedisCluster.return_value = mock_redis_connection + ip = '127.0.0.1' + remote_host = 'foo-host.foo.com' + user = domain.User( + user_id='1', + username='theuser', + email='the@user.com' + ) + auths = domain.Authorizations( + classic=2, + scopes=['foo:write'], + endorsements=[] + ) + r = store.SessionStore('localhost', 7000, 0, 'foosecret') + session = r.create(auths, ip, remote_host, user=user) + cookie = r.generate_cookie(session) + self.assertIsInstance(session, domain.Session) + self.assertTrue(bool(session.session_id)) + self.assertIsNotNone(cookie) + self.assertEqual(mock_redis_connection.set.call_count, 1) + + @mock.patch(f'{store.__name__}.get_application_config') + @mock.patch(f'{store.__name__}.rediscluster') + def test_delete(self, mock_redis, mock_get_config): + """Delete a session from the datastore.""" + mock_get_config.return_value = {'JWT_SECRET': 'foosecret'} + mock_redis.exceptions.ConnectionError = ConnectionError + mock_redis_connection = mock.MagicMock() + mock_redis.StrictRedisCluster.return_value = mock_redis_connection + r = store.SessionStore('localhost', 7000, 0, 'foosecret') + r.delete_by_id('fookey') + self.assertEqual(mock_redis_connection.delete.call_count, 1) + + @mock.patch(f'{store.__name__}.get_application_config') + @mock.patch(f'{store.__name__}.rediscluster') + def test_connection_failed(self, mock_redis, mock_get_config): + """:class:`.SessionCreationFailed` is raised when creation fails.""" + mock_get_config.return_value = {'JWT_SECRET': 'foosecret'} + mock_redis.exceptions.ConnectionError = ConnectionError + mock_redis_connection = mock.MagicMock() + mock_redis_connection.set.side_effect = ConnectionError + mock_redis.StrictRedisCluster.return_value = mock_redis_connection + ip = '127.0.0.1' + remote_host = 'foo-host.foo.com' + user = domain.User( + user_id='1', + username='theuser', + email='the@user.com' + ) + auths = domain.Authorizations( + classic=2, + scopes=['foo:write'], + endorsements=[] + ) + r = store.SessionStore('localhost', 7000, 0, 'foosecret') + with self.assertRaises(store.SessionCreationFailed): + r.create(auths, ip, remote_host, user=user) + + +class TestGetSession(TestCase): + """Tests for :func:`store.SessionStore.current_session().load`.""" + + @mock.patch(f'{store.__name__}.get_application_config') + @mock.patch(f'{store.__name__}.rediscluster.StrictRedisCluster') + def test_not_a_token(self, mock_get_redis, mock_get_config): + """Something other than a JWT is passed.""" + mock_get_config.return_value = { + 'JWT_SECRET': 'barsecret', + 'REDIS_HOST': 'redis', + 'REDIS_PORT': '1234', + 'REDIS_DATABASE': 4 + } + mock_redis = mock.MagicMock() + mock_get_redis.return_value = mock_redis + with self.assertRaises(store.InvalidToken): + store.SessionStore.current_session().load('notatoken') + + @mock.patch(f'{store.__name__}.get_application_config') + @mock.patch(f'{store.__name__}.rediscluster.StrictRedisCluster') + def test_malformed_token(self, mock_get_redis, mock_get_config): + """A JWT with missing claims is passed.""" + secret = 'barsecret' + mock_get_config.return_value = { + 'JWT_SECRET': secret, + 'REDIS_HOST': 'redis', + 'REDIS_PORT': '1234', + 'REDIS_DATABASE': 4 + } + mock_redis = mock.MagicMock() + mock_get_redis.return_value = mock_redis + required_claims = ['session_id', 'nonce'] + for exc in required_claims: + claims = {claim: '' for claim in required_claims if claim != exc} + malformed_token = jwt.encode(claims, secret) + with self.assertRaises(store.InvalidToken): + store.SessionStore.current_session().load(malformed_token) + + @mock.patch(f'{store.__name__}.get_application_config') + @mock.patch(f'{store.__name__}.rediscluster.StrictRedisCluster') + def test_token_with_bad_encryption(self, mock_get_redis, mock_get_config): + """A JWT produced with a different secret is passed.""" + secret = 'barsecret' + mock_get_config.return_value = { + 'JWT_SECRET': secret, + 'REDIS_HOST': 'redis', + 'REDIS_PORT': '1234', + 'REDIS_DATABASE': 4 + } + mock_redis = mock.MagicMock() + mock_get_redis.return_value = mock_redis + start_time = datetime.now(tz=UTC) + end_time = start_time + timedelta(seconds=7200) + claims = { + 'user_id': '1234', + 'session_id': 'ajx9043jjx00s', + 'nonce': '0039299290099', + 'expires': end_time.isoformat() + } + bad_token = jwt.encode(claims, 'nottherightsecret') + with self.assertRaises(store.InvalidToken): + store.SessionStore.current_session().load(bad_token) + + @mock.patch(f'{store.__name__}.get_application_config') + @mock.patch(f'{store.__name__}.rediscluster.StrictRedisCluster') + def test_expired_token(self, mock_get_redis, mock_get_config): + """A JWT produced with a different secret is passed.""" + secret = 'barsecret' + mock_get_config.return_value = { + 'JWT_SECRET': secret, + 'REDIS_HOST': 'redis', + 'REDIS_PORT': '1234', + 'REDIS_DATABASE': 4 + } + mock_redis = mock.MagicMock() + start_time = datetime.now(tz=UTC) + mock_redis.get.return_value = json.dumps({ + 'user_id': '1234', + 'session_id': 'ajx9043jjx00s', + 'nonce': '0039299290099', + 'expires': start_time.isoformat(), + }) + mock_get_redis.return_value = mock_redis + + claims = { + 'user_id': '1234', + 'session_id': 'ajx9043jjx00s', + 'nonce': '0039299290099', + 'expires': start_time.isoformat(), + } + expired_token = jwt.encode(claims, secret) + with self.assertRaises(store.InvalidToken): + store.SessionStore.current_session().load(expired_token) + + @mock.patch(f'{store.__name__}.get_application_config') + @mock.patch(f'{store.__name__}.rediscluster.StrictRedisCluster') + def test_forged_token(self, mock_get_redis, mock_get_config): + """A JWT with the wrong nonce is passed.""" + start_time = datetime.now(tz=UTC) + end_time = start_time + timedelta(seconds=7200) + + secret = 'barsecret' + mock_get_config.return_value = { + 'JWT_SECRET': secret, + 'REDIS_HOST': 'redis', + 'REDIS_PORT': '1234', + 'REDIS_DATABASE': 4 + } + mock_redis = mock.MagicMock() + mock_redis.get.return_value = jwt.encode({ + 'session_id': 'ajx9043jjx00s', + 'nonce': '0039299290098', + 'start_time': start_time.isoformat(), + 'end_time': end_time.isoformat(), + 'user': { + 'user_id': '1235', + 'username': 'foouser', + 'email': 'foo@foo.com' + } + }, secret) + mock_get_redis.return_value = mock_redis + + claims = { + 'user_id': '1234', + 'session_id': 'ajx9043jjx00s', + 'nonce': '0039299290099', # <- Doesn't match! + 'expires': end_time.isoformat(), + } + expired_token = jwt.encode(claims, secret) + with self.assertRaises(store.InvalidToken): + # loaded token is getting non Datetime end_time + store.SessionStore.current_session().load(expired_token) + + @mock.patch(f'{store.__name__}.get_application_config') + @mock.patch(f'{store.__name__}.rediscluster.StrictRedisCluster') + def test_other_forged_token(self, mock_get_redis, mock_get_config): + """A JWT with the wrong user_id is passed.""" + start_time = datetime.now(tz=UTC) + end_time = start_time + timedelta(seconds=7200) + + secret = 'barsecret' + mock_get_config.return_value = { + 'JWT_SECRET': secret, + 'REDIS_HOST': 'redis', + 'REDIS_PORT': '1234', + 'REDIS_DATABASE': 4 + } + mock_redis = mock.MagicMock() + mock_redis.get.return_value = jwt.encode({ + 'session_id': 'ajx9043jjx00s', + 'nonce': '0039299290099', + 'start_time': start_time.isoformat(), + 'user': { + 'user_id': '1235', + 'username': 'foouser', + 'email': 'foo@foo.com' + } + }, secret) + mock_get_redis.return_value = mock_redis + claims = { + 'user_id': '1234', # <- Doesn't match! + 'session_id': 'ajx9043jjx00s', + 'nonce': '0039299290099', + 'expires': end_time.isoformat(), + } + expired_token = jwt.encode(claims, secret) + with self.assertRaises(store.InvalidToken): + store.SessionStore.current_session().load(expired_token) + + @mock.patch(f'{store.__name__}.get_application_config') + @mock.patch(f'{store.__name__}.rediscluster.StrictRedisCluster') + def test_empty_session(self, mock_get_redis, mock_get_config): + """Session has been removed, or may never have existed.""" + start_time = datetime.now(tz=UTC) + end_time = start_time + timedelta(seconds=7200) + + secret = 'barsecret' + mock_get_config.return_value = { + 'JWT_SECRET': secret, + 'REDIS_HOST': 'redis', + 'REDIS_PORT': '1234', + 'REDIS_DATABASE': 4 + } + mock_redis = mock.MagicMock() + mock_redis.get.return_value = '' # <- Empty record! + mock_get_redis.return_value = mock_redis + + claims = { + 'user_id': '1234', + 'session_id': 'ajx9043jjx00s', + 'nonce': '0039299290099', + 'expires': end_time.isoformat(), + } + expired_token = jwt.encode(claims, secret) + with self.assertRaises(store.UnknownSession): + store.SessionStore.current_session().load(expired_token) + + @mock.patch(f'{store.__name__}.get_application_config') + @mock.patch(f'{store.__name__}.rediscluster.StrictRedisCluster') + def test_valid_token(self, mock_get_redis, mock_get_config): + """A valid token is passed.""" + start_time = datetime.now(tz=UTC) + end_time = start_time + timedelta(seconds=7200) + + secret = 'barsecret' + mock_get_config.return_value = { + 'JWT_SECRET': secret, + 'REDIS_HOST': 'redis', + 'REDIS_PORT': '1234', + 'REDIS_DATABASE': 4 + } + mock_redis = mock.MagicMock() + mock_redis.get.return_value = jwt.encode({ + 'session_id': 'ajx9043jjx00s', + 'start_time': datetime.now(tz=UTC).isoformat(), + 'nonce': '0039299290098', + 'user': { + 'user_id': '1234', + 'username': 'foouser', + 'email': 'foo@foo.com' + } + }, secret) + mock_get_redis.return_value = mock_redis + + claims = { + 'user_id': '1234', + 'session_id': 'ajx9043jjx00s', + 'nonce': '0039299290098', + 'expires': end_time.isoformat(), + } + valid_token = jwt.encode(claims, secret) + + session = store.SessionStore.current_session().load(valid_token) + self.assertIsInstance(session, domain.Session, "Returns a session") diff --git a/arxiv/auth/auth/tests/__init__.py b/arxiv/auth/auth/tests/__init__.py new file mode 100644 index 00000000..ed72440d --- /dev/null +++ b/arxiv/auth/auth/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for :mod:`arxiv.users.auth`.""" diff --git a/arxiv/auth/auth/tests/test_decorators.py b/arxiv/auth/auth/tests/test_decorators.py new file mode 100644 index 00000000..584860e5 --- /dev/null +++ b/arxiv/auth/auth/tests/test_decorators.py @@ -0,0 +1,240 @@ +"""Tests for :mod:`arxiv.users.auth.decorators`.""" + +import pytest +from datetime import datetime +from pytz import timezone, UTC + +from flask import request, current_app +from werkzeug.exceptions import Unauthorized, Forbidden + +from .. import scopes, decorators +from ... import domain + +EASTERN = timezone('US/Eastern') +""" @mock.patch(f'{decorators.__name__}.request')""" + +def test_no_session(mocker, request_context): + """No session is present on the request.""" + with request_context: + mock_req = mocker.patch(f'{decorators.__name__}.request') + mock_req.auth = None + + assert not hasattr(request, 'called') + + @decorators.scoped(scopes.CREATE_SUBMISSION) + def _protected(): + request.called = True + + with pytest.raises(Unauthorized): + _protected() + + assert not hasattr(request, 'called'), "The protected function should not have its body called" + + +def test_scope_is_missing(mocker, request_context): + """Session does not have required scope.""" + with request_context: + mock_req = mocker.patch(f'{decorators.__name__}.request') + mock_req.auth = domain.Session( + session_id='fooid', + start_time=datetime.now(tz=UTC), + user=domain.User( + user_id='235678', + email='foo@foo.com', + username='foouser' + ), + authorizations=domain.Authorizations( + scopes=[scopes.VIEW_SUBMISSION] + ) + ) + + assert not hasattr(request, 'called') + + @decorators.scoped(scopes.CREATE_SUBMISSION) + def protected(): + """A protected function.""" + request.called = True + + with pytest.raises(Forbidden): + protected() + + assert not hasattr(request, 'called'), "The protected function should not have its body called" + +def test_scope_is_present(mocker, request_context): + """Session has required scope.""" + with request_context: + request.auth = domain.Session( + session_id='fooid', + start_time=datetime.now(tz=UTC), + user=domain.User( + user_id='235678', + email='foo@foo.com', + username='foouser' + ), + authorizations=domain.Authorizations( + scopes=[scopes.VIEW_SUBMISSION, scopes.CREATE_SUBMISSION] + ) + ) + + assert not hasattr(request, 'called') + + @decorators.scoped(scopes.CREATE_SUBMISSION) + def protected(): + """A protected function.""" + print("HERE IN PROTECTED") + request.called = True + + protected() + assert request.called + + +def test_user_and_client_are_missing(mocker, request_context): + """Session does not user nor client information.""" + with request_context: + mock_req = mocker.patch(f'{decorators.__name__}.request') + mock_req.auth = domain.Session( + session_id='fooid', + start_time=datetime.now(tz=UTC), + authorizations=domain.Authorizations( + scopes=[scopes.CREATE_SUBMISSION] + ) + ) + assert not hasattr(request, 'called') + @decorators.scoped(scopes.CREATE_SUBMISSION) + def protected(): + """A protected function.""" + request.called = True + + with pytest.raises(Unauthorized): + protected() + + assert not hasattr(request, 'called'), "The protected function should not have its body called" + +def test_authorizer_returns_false(mocker, request_context): + """Session has required scope, but authorizer func returns false.""" + with request_context: + mock_req = mocker.patch(f'{decorators.__name__}.request') + mock_req.auth = domain.Session( + session_id='fooid', + start_time=datetime.now(tz=UTC), + user=domain.User( + user_id='235678', + email='foo@foo.com', + username='foouser' + ), + authorizations=domain.Authorizations( + scopes=[scopes.CREATE_SUBMISSION] + ) + ) + assert not hasattr(request, 'called') + assert not hasattr(request, 'authorizer_called') + + def return_false(session: domain.Session) -> bool: + request.authorizer_called = True + return False + + @decorators.scoped(scopes.CREATE_SUBMISSION, authorizer=return_false) + def protected(): + """A protected function.""" + request.called = True + + with pytest.raises(Forbidden): + protected() + + assert not hasattr(request, 'called') + assert request.authorizer_called + + +def test_authorizer_returns_true(mocker, request_context): + """Session has required scope, authorizer func returns true.""" + with request_context: + mock_req = mocker.patch(f'{decorators.__name__}.request') + mock_req.auth = domain.Session( + session_id='fooid', + start_time=datetime.now(tz=UTC), + user=domain.User( + user_id='235678', + email='foo@foo.com', + username='foouser' + ), + authorizations=domain.Authorizations( + scopes=[scopes.CREATE_SUBMISSION] + ) + ) + assert not hasattr(request, 'called') + assert not hasattr(request, 'authorizer_called') + + def return_true(session: domain.Session) -> bool: + request.authorizer_called = True + return True + + @decorators.scoped(scopes.CREATE_SUBMISSION, authorizer=return_true) + def protected(): + """A protected function.""" + request.called = True + + protected() + + assert request.called + assert request.authorizer_called + +def test_session_has_global(mocker, request_context): + """Session has global scope, and authorizer func returns false.""" + with request_context: + mock_req = mocker.patch(f'{decorators.__name__}.request') + mock_req.auth = domain.Session( + session_id='fooid', + start_time=datetime.now(tz=UTC), + user=domain.User( + user_id='235678', + email='foo@foo.com', + username='foouser' + ), + authorizations=domain.Authorizations( + scopes=[domain.Scope(scopes.CREATE_SUBMISSION).as_global()] + ) + ) + + def return_false(session: domain.Session) -> bool: + return False + + @decorators.scoped(scopes.CREATE_SUBMISSION, authorizer=return_false) + def protected(): + """A protected function.""" + request.called = True + + protected() + assert request.called + + +def test_session_has_resource_scope(mocker, request_context): + """Session has resource scope, and authorizer func returns false.""" + with request_context: + mock_req = mocker.patch(f'{decorators.__name__}.request') + mock_req.auth = domain.Session( + session_id='fooid', + start_time=datetime.now(tz=UTC), + user=domain.User( + user_id='235678', + email='foo@foo.com', + username='foouser' + ), + authorizations=domain.Authorizations( + scopes=[domain.Scope(scopes.EDIT_SUBMISSION).for_resource('1')] + ) + ) + + def return_false(session: domain.Session) -> bool: + return False + + def get_resource(*args, **kwargs) -> bool: + return '1' + + @decorators.scoped(scopes.EDIT_SUBMISSION, resource=get_resource, + authorizer=return_false) + def protected(): + """A protected function.""" + request.called = True + + protected() + assert request.called diff --git a/arxiv/auth/auth/tests/test_extension.py b/arxiv/auth/auth/tests/test_extension.py new file mode 100644 index 00000000..32668382 --- /dev/null +++ b/arxiv/auth/auth/tests/test_extension.py @@ -0,0 +1,106 @@ +"""Tests for :class:`arxiv.users.auth.Auth`.""" +from logging import DEBUG +import pytest + +from datetime import datetime +from pytz import timezone, UTC +from ... import auth, domain + +EASTERN = timezone('US/Eastern') + +@pytest.fixture +def app_with_cookie(app): + app.config['CLASSIC_COOKIE_NAME'] = 'foo_cookie' + return app + +def test_no_session_legacy_available(mocker, app_with_cookie): + """No session is present on the request, but database is present.""" + inst = app_with_cookie.config['arxiv_auth.Auth'] + auth.logger.setLevel(DEBUG) + with app_with_cookie.test_request_context(): + mock_legacy = mocker.patch(f'{auth.__name__}.legacy') + mock_request = mocker.patch(f'{auth.__name__}.request') + mock_request.environ = {'auth': None, + 'HTTP_COOKIE': 'foo_cookie=sessioncookie123'} + + mock_legacy.is_configured.return_value = True + mock_legacy.sessions.load.return_value = None + + inst.load_session() + assert mock_request.auth is None + + assert mock_legacy.sessions.load.call_count == 1, "An attempt is made to load a legacy session" + +def test_legacy_is_valid(mocker, app_with_cookie): + """A valid legacy session is available.""" + inst = app_with_cookie.config['arxiv_auth.Auth'] + with app_with_cookie.test_request_context(): + mock_legacy = mocker.patch(f'{auth.__name__}.legacy') + mock_request = mocker.patch(f'{auth.__name__}.request') + mock_request.environ = {'auth': None, + 'HTTP_COOKIE': 'foo_cookie=sessioncookie123'} + + mock_request.auth = None + mock_legacy.is_configured.return_value = True + session = domain.Session( + session_id='fooid', + start_time=datetime.now(tz=UTC), + user=domain.User( + user_id='235678', + email='foo@foo.com', + username='foouser' + ), + authorizations=domain.Authorizations( + scopes=[auth.scopes.VIEW_SUBMISSION] + ) + ) + mock_legacy.sessions.load.return_value = session + + inst.load_session() + assert mock_request.auth == session, "Session is attached to the request at auth" + +def test_auth_session_rename(mocker, app_with_cookie): + """ + The auth session is accessed via ``request.auth``. + + Per ARXIVNG-1920 using ``request.auth`` is deprecated. + """ + inst = app_with_cookie.config['arxiv_auth.Auth'] + with app_with_cookie.test_request_context(): + mock_legacy = mocker.patch(f'{auth.__name__}.legacy') + mock_request = mocker.patch(f'{auth.__name__}.request') + mock_request.environ = {'auth': None, + 'HTTP_COOKIE': 'foo_cookie=sessioncookie123'} + + mock_request.auth = None + mock_legacy.is_configured.return_value = True + + authex = domain.Session( + session_id='fooid', + start_time=datetime.now(tz=UTC), + user=domain.User( + user_id='235678', + email='foo@foo.com', + username='foouser' + ), + authorizations=domain.Authorizations( + scopes=[auth.scopes.VIEW_SUBMISSION] + ) + ) + mock_legacy.sessions.load.return_value = authex + + inst.load_session() + assert mock_request.auth == authex, "Auth is attached to the request" + + + +def test_middleware_exception(mocker, app_with_cookie): + """Middleware has passed an exception.""" + inst = app_with_cookie.config['arxiv_auth.Auth'] + with app_with_cookie.test_request_context(): + mock_request = mocker.patch(f'{auth.__name__}.request') + mock_request.environ = {'auth': RuntimeError('Nope!')} + + + with pytest.raises(RuntimeError): + inst.load_session() diff --git a/arxiv/auth/auth/tests/test_middleware.py b/arxiv/auth/auth/tests/test_middleware.py new file mode 100644 index 00000000..0da509fb --- /dev/null +++ b/arxiv/auth/auth/tests/test_middleware.py @@ -0,0 +1,103 @@ +"""Test :mod:`arxiv.users.auth.middleware`.""" + +import os +from unittest import TestCase, mock +from datetime import datetime +from pytz import timezone, UTC +import json + +from flask import Flask, Blueprint +from flask import request, current_app + +from arxiv.base.middleware import wrap +from arxiv import status + +from ..middleware import AuthMiddleware +from .. import tokens, scopes +from ... import domain + +EASTERN = timezone('US/Eastern') + + +blueprint = Blueprint('fooprint', __name__, url_prefix='') + + +@blueprint.route('/public', methods=['GET']) +def public(): # type: ignore + """Return the request auth as a JSON document, or raise exceptions.""" + data = request.environ.get('auth') + if data: + if isinstance(data, Exception): + raise data + return domain.Session.parse_obj(data).json_safe_dict() + return json.dumps({}) + + +class TestAuthMiddleware(TestCase): + """Test :class:`.AuthMiddleware` on a Flask app.""" + + def setUp(self): + """Instantiate an app and attach middlware.""" + self.secret = 'foosecret' + self.app = Flask('foo') + os.environ['JWT_SECRET'] = self.secret + self.app.register_blueprint(blueprint) + wrap(self.app, [AuthMiddleware]) + self.client = self.app.test_client() + + def test_no_token(self): + """No token is passed in the request.""" + response = self.client.get('/public') + data = json.loads(response.data) + self.assertEqual(data, {}, "No session data is set") + + def test_user_token(self): + """A valid user token is passed in the request.""" + session = domain.Session( + session_id='foo1234', + start_time=datetime.now(tz=UTC), + user=domain.User( + user_id='43219', + email='foo@foo.com', + username='foouser', + name=domain.UserFullName(forename='Foo', surname='User') + ), + # client: Optional[Client] = None + # end_time: Optional[datetime] = None + authorizations=domain.Authorizations( + scopes=scopes.GENERAL_USER + ), + ip_address='10.10.10.10', + remote_host='foo-host.something.com', + nonce='asdfkjl3kmalkml;xml;mla' + ) + token = tokens.encode(session, self.secret) + response = self.client.get('/public', headers={'Authorization': token}) + data = json.loads(response.data) + self.assertEqual(data, session.json_safe_dict(), + "Session data are added to the request") + + def test_forged_user_token(self): + """A forged user token is passed in the request.""" + session = domain.Session( + session_id='foo1234', + start_time=datetime.now(tz=UTC), + user=domain.User( + user_id='43219', + email='foo@foo.com', + username='foouser', + name=domain.UserFullName(forename='Foo', surname='User') + ), + # client: Optional[Client] = None + # end_time: Optional[datetime] = None + authorizations=domain.Authorizations( + scopes=scopes.GENERAL_USER + ), + ip_address='10.10.10.10', + remote_host='foo-host.something.com', + nonce='asdfkjl3kmalkml;xml;mla' + ) + token = tokens.encode(session, 'notthesecret') + response = self.client.get('/public', headers={'Authorization': token}) + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED, + "A 401 exception is passed by the middleware") diff --git a/arxiv/auth/auth/tests/test_tokens.py b/arxiv/auth/auth/tests/test_tokens.py new file mode 100644 index 00000000..2524a946 --- /dev/null +++ b/arxiv/auth/auth/tests/test_tokens.py @@ -0,0 +1,65 @@ +"""Tests for :mod:`arxiv.users.auth.tokens`.""" + +from unittest import TestCase +from datetime import datetime + +from arxiv.taxonomy.definitions import CATEGORIES +from .. import tokens +from ... import domain +from ...auth import scopes + + +class TestEncodeDecode(TestCase): + """Tests for :func:`tokens.encode` and :func:`tokens.decode`.""" + + def test_encode_session(self): + """Encode a typical user session.""" + session = domain.Session( + session_id='asdf1234', + start_time=datetime.now(), end_time=datetime.now(), + user=domain.User( + user_id='12345', + email='foo@bar.com', + username='emanresu', + name=domain.UserFullName(forename='First', surname='Last', suffix='Lastest'), + profile=domain.UserProfile( + affiliation='FSU', + rank=3, + country='us', + default_category=CATEGORIES['astro-ph.CO'], + submission_groups=['grp_physics'] + ) + ), + authorizations=domain.Authorizations(scopes=[scopes.VIEW_SUBMISSION, scopes.CREATE_SUBMISSION]) + ) + secret = 'foosecret' + token = tokens.encode(session, secret) + + data = tokens.decode(token, secret) + self.assertEqual(session, data) + + def test_mismatched_secrets(self): + """Secret used to encode is not the same as the one used to decode.""" + session = domain.Session( + session_id='asdf1234', + start_time=datetime.now(), end_time=datetime.now(), + user=domain.User( + user_id='12345', + email='foo@bar.com', + username='emanresu', + name=domain.UserFullName(forename='First', surname='Last', suffix='Lastest'), + profile=domain.UserProfile( + affiliation='FSU', + rank=3, + country='us', + default_category=CATEGORIES['astro-ph.CO'], + submission_groups=['grp_physics'] + ) + ), + authorizations=domain.Authorizations(scopes=[scopes.VIEW_SUBMISSION, scopes.CREATE_SUBMISSION]) + ) + secret = 'foosecret' + token = tokens.encode(session, secret) + + with self.assertRaises(tokens.exceptions.InvalidToken): + tokens.decode(token, 'not the secret') diff --git a/arxiv/auth/auth/tokens.py b/arxiv/auth/auth/tokens.py new file mode 100644 index 00000000..ac2ccc5a --- /dev/null +++ b/arxiv/auth/auth/tokens.py @@ -0,0 +1,49 @@ +""" +Functions for working with authn/z tokens on user/client requests. + +Encrypted JSON Web Tokens can be used inside the arXiv system to securely +convey authn/z information about each request. These tokens will usually be +generated by the :mod:`authorizer` in response to an +authorization subrequest from the web server, and contain information about +the identity of the user and/or client as well as authorization information +(e.g. :mod:`arxiv.users.auth.scopes`). + +It is essential that these JWTs are encrypted and decrypted precisely the same +way in all arXiv services, so we include these routines here for convenience. + +""" + +import jwt +from . import exceptions +from .. import domain + + +def encode(session: domain.Session, secret: str) -> str: + """ + Encode session information as an encrypted JWT. + + Parameters + ---------- + session : :class:`.domain.Session` + User or client session data, including authorization information. + secret : str + A secret key used to encrypt the token. This secret is required to + decode the token later on (e.g. in the application handling the + request). + + Returns + ------- + str + An encrypted JWT. + + """ + return jwt.encode(session.json_safe_dict(), secret) + + +def decode(token: str, secret: str) -> domain.Session: + """Decode an auth token to access session information.""" + try: + data = dict(jwt.decode(token, secret, algorithms=['HS256'])) + except jwt.exceptions.DecodeError as e: + raise exceptions.InvalidToken('Not a valid token') from e + return domain.session_from_dict(data) diff --git a/arxiv/auth/conftest.py b/arxiv/auth/conftest.py new file mode 100644 index 00000000..9d7b2483 --- /dev/null +++ b/arxiv/auth/conftest.py @@ -0,0 +1,44 @@ +import pytest +import os +# os.environ['CLASSIC_DB_URI'] = 'sqlite:///:memory:' + +from flask import Flask + +from ..base import Base +from ..config import Settings +from ..base.middleware import wrap +from ..db.models import configure_db + +from ..auth.auth import Auth +from ..auth.auth.middleware import AuthMiddleware + + +@pytest.fixture() +def app(): + + app = Flask('test_auth_app') + app.config['CLASSIC_DATABASE_URI'] = 'sqlite:///test.db' + app.config['CLASSIC_SESSION_HASH'] = f'fake set in {__file__}' + app.config['SESSION_DURATION'] = f'fake set in {__file__}' + app.config['CLASSIC_COOKIE_NAME'] = f'fake set in {__file__}' + app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///test.db' + app.config['AUTH_UPDATED_SESSION_REF'] = True + settings = Settings ( + CLASSIC_DB_URI = 'sqlite:///test.db', + LATEXML_DB_URI = None + ) + engine, _ = configure_db(settings) + app.config['DB_ENGINE'] = engine + + Base(app) + + Auth(app) + wrap(app, [AuthMiddleware]) + + + return app + + +@pytest.fixture() +def request_context(app): + yield app.test_request_context() diff --git a/arxiv/auth/domain.py b/arxiv/auth/domain.py new file mode 100644 index 00000000..a655a7d2 --- /dev/null +++ b/arxiv/auth/domain.py @@ -0,0 +1,357 @@ +"""Defines user concepts for use in arXiv services.""" + + +from typing import Any, Optional, List, NamedTuple +from collections.abc import Iterable + +from datetime import datetime +from pytz import timezone, UTC + +from pydantic import BaseModel, ConfigDict, ValidationError, validator +from arxiv.taxonomy.category import Category +from arxiv.taxonomy import definitions +from arxiv.db.models import Demographic + +EASTERN = timezone('US/Eastern') + +STAFF = ('1', 'Staff') +PROFESSOR = ('2', 'Professor') +POST_DOC = ('3', 'Post doc') +GRAD_STUDENT = ('4', 'Grad student') +OTHER = ('5', 'Other') +RANKS = [STAFF, PROFESSOR, POST_DOC, GRAD_STUDENT, OTHER] + + +def _check_category(data: Any) -> Category: + if isinstance(data, Category): + return data + if not isinstance(data, str): + raise ValidationError(f"object of type {type(data)} cannnot be used as a Category", Category) + cat = Category(data) + cat.name # possible rasie value error on non-existance + return cat + + +class UserProfile(BaseModel): + """User profile data.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + affiliation: str + """Institutional affiliation.""" + + country: str + """Should be an ISO 3166-1 alpha-2 country code.""" + + rank: int + """Academic rank. Must be one of :const:`.RANKS`.""" + + submission_groups: List[str] + """ + Groups to which the user prefers to submit. + + Items should be one of :ref:`arxiv.taxonomy.definitions.GROUPS`. + """ + + default_category: Optional[Category] + """ + Default submission category. + + Should be one of :ref:`arxiv.taxonomy.CATEGORIES`. + """ + + # @validator('default_category') + # @classmethod + # def check_category(cls, data: Any) -> Category: + # """Checks if `data` is a category.""" + # return _check_category(data) + + homepage_url: str = '' + """User's homepage or external profile URL.""" + + remember_me: bool = True + """Indicates whether the user prefers permanent session cookies.""" + + @property + def rank_display(self) -> str: + """The display name of the user's rank.""" + _rank: str = dict(RANKS)[str(self.rank)] + return _rank + + @property + def default_archive(self) -> Optional[str]: + """The archive of the default category.""" + return self.default_category.in_archive if self.default_category else None + + @property + def default_subject(self) -> Optional[str]: + """The subject of the default category.""" + if self.default_category is not None: + subject: str + if '.' in self.default_category.id: + subject = self.default_category.id.split('.', 1)[1] + else: + subject = self.default_category.id + return subject + return None + + @property + def groups_display(self) -> str: + """Display-ready representation of active groups for this profile.""" + return ", ".join([ + definitions.GROUPS[group]['name'] + for group in self.submission_groups + ]) + + @staticmethod + def from_orm (model: Demographic) -> 'UserProfile': + if model.subject_class: + category = definitions.CATEGORIES[f'{model.archive}.{model.subject_class}'] + elif model.archive: + category = definitions.CATEGORIES[f'{model.archive}'] + else: + category = None + + return UserProfile( + affiliation=model.affiliation, + country=model.country, + rank=model.type, + submission_groups=model.groups, + default_category=category, + homepage_url=model.url, + ) + + +class Scope(str): + """Represents an authorization policy.""" + + def __new__(cls, domain, action=None, resource=None): + """Handle __new__.""" + return str.__new__(cls, cls.from_parts(domain, action, resource)) + + @property + def domain(self) -> str: + """ + The domain to which the scope applies. + + This will generally refer to a specific service. + """ + return self.parts[0] + + @property + def action(self) -> str: + """An action within ``domain`.""" + return self.parts[1] + + @property + def resource(self) -> Optional[str]: + """The specific resource to which this policy applies.""" + return self.parts[2] + + @property + def parts(self) -> str: + """Get parts of the Scope.""" + parts = self.split(':') + parts = parts + [None] * (3 - len(parts)) + return parts + + def for_resource(self, resource_id: str) -> 'Scope': + """Create a copy of this scope with a specific resource.""" + return Scope(domain=self.domain, action=self.action, resource=resource_id) + + def as_global(self) -> 'Scope': + """Create a copy of this scope with a global resource.""" + return self.for_resource('*') + + @classmethod + def from_parts(cls, domain, action=None, resource=None): + """Create a scope string from parts.""" + return ":".join([o for o in [domain,action,resource] if o is not None]) + + @classmethod + def to_parts(cls, scopestr): + """Split a scop string to parts.""" + parts = scopestr.split(':') + return parts + [None] * (3 - len(parts)) + + @classmethod + def from_str(cls, scopestr:str) -> "Scope": + """Make a Scope from a string.""" + parts = cls.to_parts(scopestr) + return cls(domain=parts[0], action=parts[1], resource=parts[2]) + + +class Authorizations(BaseModel): + """Authorization information, e.g. associated with a :class:`.Session`.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + classic: int = 0 + """Capability code associated with a user's session.""" + + scopes: List[str] = [] + """Authorized :class:`.scope`s. See also :mod:`arxiv.users.auth.scopes`.""" + + + @classmethod + def before_init(cls, data: dict) -> None: + if 'scopes' in data: + if type(data['scopes']) is str: + data['scopes'] = [ + Scope(*scope.split(':')) for scope + in data['scopes'].split() + ] + elif type(data['scopes']) is list: + data['scopes'] = [ + Scope(**scope) if type(scope) is dict + else Scope(*scope.split(':')) + for scope in data['scopes'] + ] + + +class UserFullName(BaseModel): + """Represents a user's full name.""" + + forename: str + """First name or given name.""" + + surname: str + """Last name or family name.""" + + suffix: Optional[str] = '' + """Any title or qualifier used as a suffix/postfix.""" + + +class User(BaseModel): + """Represents an arXiv user and their authorizations.""" + + username: str + """Slug-like username.""" + + email: str + """The user's primary e-mail address.""" + + user_id: Optional[str] = None + """Unique identifier for the user. If ``None``, the user does not exist.""" + + name: Optional[UserFullName] = None + """The user's full name (if available).""" + + profile: Optional[UserProfile] = None + """The user's account profile (if available).""" + + verified: bool = False + """Whether or not the users' e-mail address has been verified.""" + + # def asdict(self) -> dict: + # """Generate a dict representation of this :class:`.User`.""" + # data = super(User, self)._asdict() + # if self.name is not None: + # data['name'] = self.name._asdict() + # if self.profile is not None: + # data['profile'] = self.profile._asdict() + # return data + + # TODO: consider whether this information is relevant beyond the + # ``arxiv.users.legacy.authenticate`` module. + # + # approved: bool = True + # """Whether or not the users' account is approved.""" + # + # banned: bool = False + # """Whether or not the user has been banned.""" + # + # deleted: bool = False + # """Whether or not the user has been deleted.""" + + +class Client(BaseModel): + """API client.""" + + owner_id: str + """The arXiv user responsible for the client.""" + + client_id: Optional[str] = None + """Unique identifier for a :class:`.Client`.""" + + name: Optional[str] = None + """Human-friendly name of the API client.""" + + url: Optional[str] = None + """Homepage or other resource describing the API client.""" + + description: Optional[str] = None + """Brief description of the API client.""" + + redirect_uri: Optional[str] = None + """The authorized redirect URI for the client.""" + + +class Session(BaseModel): + """Represents an authenticated session in the arXiv system.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + session_id: str + """Unique identifier for the session.""" + + start_time: datetime + """The ISO-8601 datetime when the session was created.""" + + user: Optional[User] = None + """The user for which the session was created.""" + + client: Optional[Client] = None + """The client for which the session was created.""" + + end_time: Optional[datetime] = None + """The ISO-8601 datetime when the session ended.""" + + authorizations: Optional[Authorizations] = None + """Authorizations for the current session.""" + + ip_address: Optional[str] = None + """The IP address of the client for which the session was created.""" + + remote_host: Optional[str] = None + """The hostname of the client for which the session was created.""" + + nonce: Optional[str] = None + """A pseudo-random nonce generated when the session was created.""" + + def is_authorized(self, scope: Scope, resource: str) -> bool: + """Check whether this session is authorized for a specific resource.""" + return (self.authorizations is not None and ( + scope.as_global() in self.authorizations.scopes + or scope.for_resource(resource) in self.authorizations.scopes)) + + @property + def expired(self) -> bool: + """Expired if the current time is later than :attr:`.end_time`.""" + return bool(self.end_time is not None + and datetime.now(tz=UTC) >= self.end_time) + + @property + def expires(self) -> Optional[int]: + """ + Number of seconds until the session expires. + + If the session is already expired, returns 0. + """ + if self.end_time is None: + return None + duration = (self.end_time - datetime.now(tz=UTC)).total_seconds() + return int(max(duration, 0)) + + def json_safe_dict(self) -> dict: + """Creates a json dict with the datetimes converted to ISO datetime strs.""" + out = self.dict() + if self.start_time: + out['start_time'] = self.start_time.isoformat() + if self.end_time: + out['end_time'] = self.end_time.isoformat() + return out + +def session_from_dict(data: dict) -> Session: + """Create a Session from a dict.""" + return Session.parse_obj(data) diff --git a/arxiv/auth/helpers.py b/arxiv/auth/helpers.py new file mode 100644 index 00000000..47f75576 --- /dev/null +++ b/arxiv/auth/helpers.py @@ -0,0 +1,49 @@ +"""Helpers and utilities for :mod:`arxiv.users`.""" + +import os +from typing import List +from pytz import timezone, UTC +import uuid +from datetime import timedelta, datetime +from . import auth, domain +from arxiv.base.globals import get_application_config +from arxiv.taxonomy.definitions import CATEGORIES + + +def generate_token(user_id: str, email: str, username: str, + first_name: str = 'Jane', last_name: str = 'Doe', + suffix_name: str = 'IV', + affiliation: str = 'Cornell University', + rank: int = 3, + country: str = 'us', + default_category: domain.Category = CATEGORIES['astro-ph.GA'], + submission_groups: str = 'grp_physics', + scope: List[domain.Scope] = [], + verified: bool = False) -> str: + """Generate an auth token for dev/testing purposes.""" + # Specify the validity period for the session. + start = datetime.now(tz=timezone('US/Eastern')) + end = start + timedelta(seconds=36000) # Make this as long as you want. + + # Create a user + session = domain.Session( + session_id=str(uuid.uuid4()), + start_time=start, end_time=end, + user=domain.User( + user_id=user_id, + email=email, + username=username, + name=domain.UserFullName(forename=first_name, surname=last_name, suffix=suffix_name), + profile=domain.UserProfile( + affiliation=affiliation, + rank=int(rank), + country=country, + default_category=default_category, + submission_groups=submission_groups.split(',') + ), + verified=verified + ), + authorizations=domain.Authorizations(scopes=scope) + ) + token = auth.tokens.encode(session, get_application_config()['JWT_SECRET']) + return token diff --git a/arxiv/auth/legacy/__init__.py b/arxiv/auth/legacy/__init__.py new file mode 100644 index 00000000..820be948 --- /dev/null +++ b/arxiv/auth/legacy/__init__.py @@ -0,0 +1,8 @@ +""" +Integrations with the legacy arXiv database for users and sessions. + +This package provides integrations with legacy user and sessions data in the +classic DB. These components were pulled out as a separate package because +they are required by both the accounts service and the authn/z middlware, +and maintaining them in both places would create too much duplication. +""" diff --git a/arxiv/auth/legacy/accounts.py b/arxiv/auth/legacy/accounts.py new file mode 100644 index 00000000..0f418b31 --- /dev/null +++ b/arxiv/auth/legacy/accounts.py @@ -0,0 +1,304 @@ +"""Provide methods for working with user accounts.""" + +from typing import Optional, Tuple, Any +import logging + +from sqlalchemy.exc import OperationalError + +from .. import domain +from . import util, exceptions +from .passwords import hash_password +from .exceptions import Unavailable +from ...db import session +from ...db.models import TapirUser, TapirUsersPassword, \ + TapirNickname, Demographic + + +logger = logging.getLogger(__name__) + + +def does_username_exist(username: str) -> bool: + """ + Determine whether a user with a particular username already exists. + + Parameters + ---------- + username : str + + Returns + ------- + bool + + """ + try: + data = session.query(TapirNickname) \ + .filter(TapirNickname.nickname == username) \ + .first() + except OperationalError as e: + raise Unavailable('Database is temporarily unavailable') from e + if data: + return True + return False + + +def does_email_exist(email: str) -> bool: + """ + Determine whether a user with a particular address already exists. + + Parameters + ---------- + email : str + + Returns + ------- + bool + + """ + try: + data = session.query(TapirUser).filter(TapirUser.email == email).first() + print (session.query(TapirUser).all()) + except OperationalError as e: + raise Unavailable('Database is temporarily unavailable') from e + if data: + return True + return False + + +def register(user: domain.User, password: str, ip: str, + remote_host: str) -> Tuple[domain.User, domain.Authorizations]: + """ + Create a new user. + + Parameters + ---------- + user : :class:`.domain.User` + User data for the new account. + password : str + Password for the account. + ip : str + The IP address of the client requesting the registration. + remote_host : str + The remote hostname of the client requesting the registration. + + Returns + ------- + :class:`.domain.User` + Data about the created user. + :class:`.domain.Authorizations` + Privileges attached to the created user. + + """ + try: + db_user, db_nick, db_profile = _create(user, password, ip, remote_host) + session.commit() + except OperationalError as e: + raise Unavailable('Database is temporarily unavailable') from e + except Exception as e: + logger.debug(e) + raise exceptions.RegistrationFailed('Could not create user')# from e + + user = domain.User( + user_id=str(db_user.user_id), + username=db_nick.nickname, + email=db_user.email, + name=domain.UserFullName( + forename=db_user.first_name, + surname=db_user.last_name, + suffix=db_user.suffix_name + ), + profile=domain.UserProfile.from_orm(db_profile) if db_profile is not None else None + ) + auths = domain.Authorizations( + classic=util.compute_capabilities(db_user), + scopes=util.get_scopes(db_user), + ) + return user, auths + + +def get_user_by_id(user_id: str) -> domain.User: + """Load user data from the database.""" + try: + db_user, db_nick, db_profile = _get_user_data(user_id) + except OperationalError as e: + raise Unavailable('Database is temporarily unavailable') from e + user = domain.User( + user_id=str(db_user.user_id), + username=db_nick.nickname, + email=db_user.email, + name=domain.UserFullName( + forename=db_user.first_name, + surname=db_user.last_name, + suffix=db_user.suffix_name + ), + profile=domain.UserProfile.from_orm(db_profile) if db_profile is not None else None + ) + return user + + +def update(user: domain.User) -> Tuple[domain.User, domain.Authorizations]: + """Update a user in the database.""" + if user.user_id is None: + raise ValueError('User ID must be set') + + db_user, db_nick, db_profile = _get_user_data(user.user_id) + # TODO: we probably want to think a bit more about changing usernames + # and e-mail addresses. + # + # _update_field_if_changed(db_nick, 'nickname', user.username) + # _update_field_if_changed(db_user, 'email', user.email) + if user.name is not None: + _update_field_if_changed(db_user, 'first_name', user.name.forename) + _update_field_if_changed(db_user, 'last_name', user.name.surname) + _update_field_if_changed(db_user, 'suffix_name', user.name.suffix) + if user.profile is not None: + if db_profile is not None: + def _has_group(group: str) -> int: + if user.profile is None: + return 0 + return int(group in user.profile.submission_groups) + + _update_field_if_changed(db_profile, 'affiliation', + user.profile.affiliation) + _update_field_if_changed(db_profile, 'country', + user.profile.country) + _update_field_if_changed(db_profile, 'type', user.profile.rank) + _update_field_if_changed(db_profile, 'url', + user.profile.homepage_url) + _update_field_if_changed(db_profile, 'archive', + user.profile.default_archive) + _update_field_if_changed(db_profile, 'subject_class', + user.profile.default_subject) + for grp, field in Demographic.GROUP_FLAGS: + _update_field_if_changed(db_profile, field, + _has_group(grp)) + session.add(db_profile) + else: + db_profile = _create_profile(user, db_user) + + session.add(db_nick) + session.add(db_user) + session.commit() + + user = domain.User( + user_id=str(db_user.user_id), + username=db_nick.nickname, + email=db_user.email, + name=domain.UserFullName( + forename=db_user.first_name, + surname=db_user.last_name, + suffix=db_user.suffix_name + ), + profile=domain.UserProfile.from_orm(db_profile) if db_profile is not None else None + ) + auths = domain.Authorizations( + classic=util.compute_capabilities(db_user), + scopes=util.get_scopes(db_user), + ) + return user, auths + + +def _update_field_if_changed(obj: Any, field: Any, update_with: Any) -> None: + if getattr(obj, field) != update_with: + setattr(obj, field, update_with) + + +def _get_user_data(user_id: str) -> Tuple[TapirUser, TapirNickname, Demographic]: + + try: + db_user, db_nick = session.query(TapirUser, TapirNickname) \ + .filter(TapirUser.user_id == user_id) \ + .filter(TapirUser.flag_approved == 1) \ + .filter(TapirUser.flag_deleted == 0) \ + .filter(TapirUser.flag_banned == 0) \ + .filter(TapirNickname.flag_primary == 1) \ + .filter(TapirNickname.flag_valid == 1) \ + .filter(TapirNickname.user_id == TapirUser.user_id) \ + .first() + except TypeError: # first() returns a single None if no match. + raise exceptions.NoSuchUser('User does not exist') + # Profile may not exist. + db_profile = session.query(Demographic) \ + .filter(Demographic.user_id == user_id) \ + .first() + if not db_user: + raise exceptions.NoSuchUser('User does not exist') + return db_user, db_nick, db_profile + + +def _create_profile(user: domain.User, db_user: TapirUser) -> Demographic: + def _has_group(group: str) -> int: + if user.profile is None: + return 0 + return int(group in user.profile.submission_groups) + + db_profile = Demographic( + user=db_user, + country=user.profile.country if user.profile else None, + affiliation=user.profile.affiliation if user.profile else None, + url=user.profile.homepage_url if user.profile else None, + type=user.profile.rank if user.profile else None, + archive=user.profile.default_archive if user.profile else None, + subject_class=user.profile.default_subject if user.profile else None, + original_subject_classes='', + flag_group_physics=_has_group('grp_physics'), + flag_group_math=_has_group('grp_math'), + flag_group_cs=_has_group('grp_cs'), + flag_group_q_bio=_has_group('grp_q-bio'), + flag_group_q_fin=_has_group('grp_q-fin'), + flag_group_stat=_has_group('grp_stat'), + flag_group_eess=_has_group('grp_eess'), + flag_group_econ=_has_group('grp_econ'), + ) + session.add(db_profile) + return db_profile + + +def _create(user: domain.User, password: str, ip: str, remote_host: str) \ + -> Tuple[TapirUser, TapirNickname, Optional[Demographic]]: + if not user.name.forename: + raise ValueError("Must have forename to create user") + if not user.name.surname: + raise ValueError("Must have surname to create user") + data = dict( + email=user.email, + policy_class=2, + joined_ip_num=ip, + joined_remote_host=remote_host, + joined_date=util.now(), + tracking_cookie='' # TODO: set this. + ) + if user.name is not None: + data.update(dict( + first_name=user.name.forename, + last_name=user.name.surname, + suffix_name=user.name.suffix + )) + print (data) + + # Main user entry. + db_user = TapirUser(**data) + session.add(db_user) + # Nickname is where we keep the username. + db_nick = TapirNickname( + user=db_user, + nickname=user.username, + flag_valid=1, + flag_primary=1 + ) + session.add(db_nick) + + db_profile: Optional[Demographic] + if user.profile is not None: + db_profile = _create_profile(user, db_user) + else: + db_profile = None + + db_pass = TapirUsersPassword( + user=db_user, + password_storage=2, + password_enc=hash_password(password) + ) + session.add(db_pass) + from sqlalchemy import select + print(session.execute(select(TapirUser.email)).all()) + return db_user, db_nick, db_profile diff --git a/arxiv/auth/legacy/authenticate.py b/arxiv/auth/legacy/authenticate.py new file mode 100644 index 00000000..8992e81f --- /dev/null +++ b/arxiv/auth/legacy/authenticate.py @@ -0,0 +1,328 @@ +"""Provide an API for user authentication using the legacy database.""" + +from typing import Optional, Tuple +import logging + +from sqlalchemy.exc import OperationalError + +from . import util +from .. import domain + +from . passwords import check_password, is_ascii +from ...db import session +from ...db.models import TapirUser, TapirUsersPassword, TapirPermanentToken, \ + TapirNickname, Demographic +from .exceptions import NoSuchUser, AuthenticationFailed, \ + PasswordAuthenticationFailed, Unavailable + +logger = logging.getLogger(__name__) + +PassData = Tuple[TapirUser, TapirUsersPassword, TapirNickname, Demographic] + + +def authenticate(username_or_email: Optional[str] = None, + password: Optional[str] = None, token: Optional[str] = None) \ + -> Tuple[domain.User, domain.Authorizations]: + """ + Validate username/password. If successful, retrieve user details. + + Parameters + ---------- + username_or_email : str + Users may log in with either their username or their email address. + password : str + Password (as entered). Danger, Will Robinson! + token : str + Alternatively, the user may provide a bearer token. This is currently + used to support "permanent" sessions, in which the token is used to + "automatically" log the user in (i.e. without entering credentials). + + Returns + ------- + :class:`domain.User` + :class:`domain.Authorizations` + + Raises + ------ + :class:`AuthenticationFailed` + Failed to authenticate user with provided credentials. + :class:`Unavailable` + Unable to connect to DB. + """ + try: + if username_or_email and password: + passdata = _authenticate_password(username_or_email, password) + # The "tapir permanent token" is effectively a bearer token. If passed, + # a new session will be "automatically" created (from the user's + # perspective). + elif token: + db_token = _authenticate_token(token) + passdata = _get_user_by_user_id(db_token.user_id) + else: + logger.debug('Neither username/password nor token provided') + raise AuthenticationFailed('Username+password or token required') + except OperationalError as e: + # Note OperationalError can be a lot of different things and not just + # the DB being unavailable. So this message can be deceptive. + raise Unavailable('Database is temporarily unavailable') from e + except Exception as ex: + raise AuthenticationFailed() from ex + + db_user, _, db_nick, db_profile = passdata + user = domain.User( + user_id=str(db_user.user_id), + username=db_nick.nickname, + email=db_user.email, + name=domain.UserFullName( + forename=db_user.first_name, + surname=db_user.last_name, + suffix=db_user.suffix_name + ), + profile=domain.UserProfile.from_orm(db_profile) if db_profile else None, + verified=bool(db_user.flag_email_verified) + ) + auths = domain.Authorizations( + classic=util.compute_capabilities(db_user), + scopes=util.get_scopes(db_user), + ) + return user, auths + + +def _authenticate_token(token: str) -> TapirPermanentToken: + """ + Authenticate using a permanent token. + + Parameters + ---------- + token : str + + Returns + ------- + :class:`.TapirUser` + :class:`.TapirPermanentToken` + :class:`.TapirNickname` + :class:`.Demographic` + + Raises + ------ + :class:`AuthenticationFailed` + Raised if the token is malformed, or there is no corresponding token + in the database. + + """ + try: + user_id, secret = token.split('-') + except ValueError as e: + raise AuthenticationFailed('Token is malformed') from e + try: + return _get_token(user_id, secret) + except NoSuchUser as e: + logger.debug('Not a valid permanent token') + raise AuthenticationFailed('Invalid token') from e + + +def _authenticate_password(username_or_email: str, password: str) -> PassData: + """ + Authenticate using username/email and password. + + Parameters + ---------- + username_or_email : str + Either the email address or username of the authenticating user. + password : str + + Returns + ------- + :class:`.TapirUser` + :class:`.TapirUsersPassword` + :class:`.TapirNickname` + + Raises + ------ + :class:`AuthenticationFailed` + Raised if the user does not exist or the password is incorrect. + :class:`RuntimeError` + Raised when other problems arise. + + """ + logger.debug(f'Authenticate with password, user: {username_or_email}') + + if not password: + raise ValueError('Passed empty password') + if not isinstance(password, str): + raise ValueError('Passed non-str password: {type(password)}') + if not is_ascii(password): + raise ValueError('Password non-ascii password') + + if not username_or_email: + raise ValueError('Passed empty username_or_email') + if not isinstance(password, str): + raise ValueError('Passed non-str username_or_email: {type(username_or_email)}') + if len(username_or_email) > 255: + raise ValueError(f'Passed username_or_email too long: len {len(username_or_email)}') + if not is_ascii(username_or_email): + raise ValueError('Passed non-ascii username_or_email') + + if '@' in username_or_email: + passdata = _get_user_by_email(username_or_email) + else: + passdata = _get_user_by_username(username_or_email) + + db_user, db_pass, db_nick, db_profile = passdata + logger.debug(f'Got user with user_id: {db_user.user_id}') + try: + if check_password(password, db_pass.password_enc): + return passdata + except PasswordAuthenticationFailed as e: + raise AuthenticationFailed('Invalid username or password') from e + + +def _get_user_by_user_id(user_id: int) -> PassData: + tapir_user: TapirUser = session.query(TapirUser) \ + .filter(TapirUser.user_id == int(user_id)) \ + .filter(TapirUser.flag_approved == 1) \ + .filter(TapirUser.flag_deleted == 0) \ + .filter(TapirUser.flag_banned == 0) \ + .first() + return _get_passdata(tapir_user) + + +def _get_user_by_email(email: str) -> PassData: + if not email or '@' not in email: + raise ValueError("must be an email address") + tapir_user: TapirUser = session.query(TapirUser) \ + .filter(TapirUser.email == email) \ + .filter(TapirUser.flag_approved == 1) \ + .filter(TapirUser.flag_deleted == 0) \ + .filter(TapirUser.flag_banned == 0) \ + .first() + return _get_passdata(tapir_user) + + +def _get_user_by_username(username: str) -> PassData: + """Username is the tapir nickname.""" + if not username or '@' in username: + raise ValueError("username must not contain a @") + tapir_nick = session.query(TapirNickname) \ + .filter(TapirNickname.nickname == username) \ + .filter(TapirNickname.flag_valid == 1) \ + .first() + if not tapir_nick: + raise NoSuchUser('User lacks a nickname') + + tapir_user = session.query(TapirUser) \ + .filter(TapirUser.user_id == tapir_nick.user_id) \ + .filter(TapirUser.flag_approved == 1) \ + .filter(TapirUser.flag_deleted == 0) \ + .filter(TapirUser.flag_banned == 0) \ + .first() + return _get_passdata(tapir_user) + + +def _get_passdata(tapir_user: TapirUser) -> PassData: + """ + Retrieve password, nick name and profile data. + + Parameters + ---------- + username_or_email : str + + Returns + ------- + :class:`.TapirUser` + :class:`.TapirUsersPassword` + :class:`.TapirNickname` + :class:`.Demographic` + + Raises + ------ + :class:`NoSuchUser` + Raised when the user cannot be found. + :class:`RuntimeError` + Raised when other problems arise. + + """ + if not tapir_user: + raise NoSuchUser('User does not exist') + + tapir_nick = session.query(TapirNickname) \ + .filter(TapirNickname.user_id ==tapir_user.user_id) \ + .filter(TapirNickname.flag_valid == 1) \ + .first() + if not tapir_nick: + raise NoSuchUser('User lacks a nickname') + + tapir_password: TapirUsersPassword = session.query(TapirUsersPassword) \ + .filter(TapirUsersPassword.user_id == tapir_user.user_id) \ + .first() + if not tapir_password: + raise RuntimeError(f'Missing password') + + tapir_profile: Demographic = session.query(Demographic) \ + .filter(Demographic.user_id == tapir_user.user_id) \ + .first() + return tapir_user, tapir_password, tapir_nick, tapir_profile + + +def _invalidate_token(user_id: str, secret: str) -> None: + """ + Invalidate a user's permanent login token. + + Parameters + ---------- + user_id : str + secret : str + + Raises + ------ + :class:`NoSuchUser` + Raised when the token or user cannot be found. + + """ + db_token = _get_token(user_id, secret) + db_token.valid = 0 + session.add(db_token) + session.commit() + + +def _get_token(user_id: str, secret: str) -> TapirPermanentToken: + """ + Retrieve a user's permanent token. + + User ID and token are used together as the primary key for the token. + + Parameters + ---------- + user_id : str + secret : str + valid : int + (default: 1) + + Returns + ------- + :class:`.TapirPermanentToken` + + Raises + ------ + :class:`NoSuchUser` + Raised when the token or user cannot be found. + + """ + if not user_id.isdigit(): + raise ValueError("user_id must be digits") + if not user_id: + raise ValueError("user_id must not be empty") + if len(user_id) > 50: + raise ValueError("user_id too long") + if len(secret) > 40: + raise ValueError("secret too long") + + db_token: TapirPermanentToken = session.query(TapirPermanentToken) \ + .filter(TapirPermanentToken.user_id == user_id) \ + .filter(TapirPermanentToken.secret == secret) \ + .filter(TapirPermanentToken.valid == 1) \ + .first() # The token must still be valid. + if not db_token: + raise NoSuchUser('No such token') + else: + return db_token diff --git a/arxiv/auth/legacy/cookies.py b/arxiv/auth/legacy/cookies.py new file mode 100644 index 00000000..88e66a21 --- /dev/null +++ b/arxiv/auth/legacy/cookies.py @@ -0,0 +1,137 @@ +"""Provides functions for working with legacy session cookies. + +The legacy cookie is 6 parts seperated with ':'. + +The parts are: +1. session id +2. tapir_users.user_id +3. ip the session was started at +4. time issued at as unix epoch +5. capabilities +6. b64 encoced sha1 hash of parts 1-5 + +In a way it is similar to a JWT where parts 1-5 are similar to the JWT +payload, though fixed in strucutre, and part 6 forms the signature. + +Parts 1-5 are not b64 encoded. +""" + +from typing import Tuple, List +from base64 import b64encode +import hashlib +from datetime import datetime, timedelta + +from werkzeug.http import parse_cookie +from werkzeug.datastructures import MultiDict + +from .exceptions import InvalidCookie +from . import util + + +def unpack(cookie: str) -> Tuple[str, str, str, datetime, datetime, str]: + """ + Unpack the legacy session cookie. + + Parameters + ---------- + cookie : str + The value of session cookie. + + Returns + ------- + str + The session ID associated with the cookie. + str + The user ID of the authenticated account. + str + The IP address of the client when the session was created. + datetime + The datetime when the session was created. + datetime + The datetime when the session expires. + str + Legacy user privilege level. + + Raises + ------ + :class:`InvalidCookie` + Raised if the cookie is detectably malformed or tampered with. + + """ + parts = cookie.split(':') + if len(parts) < 5: + raise InvalidCookie('Malformed cookie') + + session_id = parts[0] + user_id = parts[1] + ip = parts[2] + issued_at = util.from_epoch(int(parts[3])) + expires_at = issued_at + timedelta(seconds=util.get_session_duration()) + capabilities = parts[4] + try: + expected = pack(session_id, user_id, ip, issued_at, capabilities) + except Exception as e: + raise InvalidCookie('Invalid session cookie; problem while repacking') from e + + if expected == cookie: + return session_id, user_id, ip, issued_at, expires_at, capabilities + else: + raise InvalidCookie('Invalid session cookie; not as expected') + + +def pack(session_id: str, user_id: str, ip: str, issued_at: datetime, + capabilities: str) -> str: + """ + Generate a value for the classic session cookie. + + Parameters + ---------- + session_id : str + The session ID associated with the cookie. + user_id : str + The user ID of the authenticated account. + ip : str + Client IP address. + issued_at : datetime + The UNIX time at which the session was initiated. + capabilities : str + This is essentially a user privilege level. + + Returns + ------- + str + Signed session cookie value. + + """ + session_hash = util.get_session_hash() + value = ':'.join(map(str, [session_id, user_id, ip, util.epoch(issued_at), + capabilities])) + to_sign = f'{value}-{session_hash}'.encode('utf-8') + cookie_hash = b64encode(hashlib.sha1(to_sign).digest()) + return value + ':' + cookie_hash.decode('utf-8')[:-1] + + +def get_cookies(request, cookie_name:str) -> List[str]: + """Gets list of legacy cookies. + + Duplicate cookies occur due to the browser sending both the + cookies for both arxiv.org and sub.arxiv.org. If this is being + served at sub.arxiv.org, there is no response that will cause + the browser to alter its cookie store for arxiv.org. Duplicate + cookies must be handled gracefully to for the domain and + subdomain to coexist. + + The standard way to avoid this problem is to append part of + the domain's name to the cookie key but this needs to work + even if the configuration is not ideal. + + """ + # By default, werkzeug uses a dict-based struct that supports only a + # single value per key. This isn't really up to speed with RFC 6265. + # Luckily we can just pass in an alternate struct to parse_cookie() + # that can cope with multiple values. + raw_cookie = request.environ.get('HTTP_COOKIE', None) + if raw_cookie is None: + return [] + cookies = parse_cookie(raw_cookie, cls=MultiDict) + return cookies.getlist(cookie_name) diff --git a/arxiv/auth/legacy/endorsements.py b/arxiv/auth/legacy/endorsements.py new file mode 100644 index 00000000..05daa774 --- /dev/null +++ b/arxiv/auth/legacy/endorsements.py @@ -0,0 +1,420 @@ +""" +Provide endorsement authorizations for users. + +Endorsements are authorization scopes tied to specific classificatory +categories, and are used primarily to determine whether or not a user may +submit a paper with a particular primary or secondary classification. + +This module preserves the behavior of the legacy system with respect to +interpreting endorsements and evaluating potential autoendorsement. The +relevant policies can be found on the `arXiv help pages +`_. +""" + +from typing import List, Dict, Optional, Set, Union +from collections import Counter +from datetime import datetime +from functools import lru_cache as memoize +from itertools import groupby + +from sqlalchemy.sql.expression import literal + +from . import util +from .. import domain +from ...taxonomy import definitions +from ...db import session +from ...db.models import Endorsement, PaperOwner, Document, \ + t_arXiv_in_category, Category, EndorsementDomain, t_arXiv_white_email, \ + t_arXiv_black_email + + +GENERAL_CATEGORIES = [ + definitions.CATEGORIES['math.GM'], + definitions.CATEGORIES['physics.gen-ph'], +] + +WINDOW_START = util.from_epoch(157783680) + +Endorsements = List[Union[domain.Category,str]] + + +def get_endorsements(user: domain.User) -> Endorsements: + """ + Get all endorsements (explicit and implicit) for a user. + + Parameters + ---------- + user : :class:`.domain.User` + + Returns + ------- + list + Each item is a :class:`.domain.Category` for which the user is + either explicitly or implicitly endorsed. + + """ + endorsements = list(set(explicit_endorsements(user)) + | set(implicit_endorsements(user))) + + return endorsements + + +@memoize() +def _categories_in_archive(archive: str) -> Set[str]: + return set(category for category, definition + in definitions.CATEGORIES_ACTIVE.items() + if definition.in_archive == archive) + + +@memoize() +def _category(archive: str, subject_class: str) -> domain.Category: + if subject_class: + return definitions.CATEGORIES[f'{archive}.{subject_class}'] + return definitions.CATEGORIES[archive] + + +@memoize() +def _get_archive(category: domain.Category) -> str: + return category.in_archive + + +def _all_archives(endorsements: Endorsements) -> bool: + archives = set(_get_archive(category) for category in endorsements + if category.id.endswith(".*")) + missing = set(definitions.ARCHIVES_ACTIVE.keys()) - archives + return len(missing) == 0 or (len(missing) == 1 and 'test' in missing) + + +def _all_subjects_in_archive(archive: str, endorsements: Endorsements) -> bool: + return len(_categories_in_archive(archive) - set(endorsements)) == 0 + + +def compress_endorsements(endorsements: Endorsements) -> Endorsements: + """ + Compress endorsed categories using wildcard notation if possible. + + We want to avoid simply enumerating all of the categories that exist. If + all subjects in an archive are present, we represent that as "{archive}.*". + If all subjects in all archives are present, we represent that as "*.*". + + Parameters + ---------- + endorsements : list + A list of endorsed categories. + + Returns + ------- + list + The same endorsed categories, compressed with wildcards where possible. + + """ + compressed: Endorsements = [] + grouped = groupby(sorted(endorsements, key=_get_archive), key=_get_archive) + for archive, archive_endorsements in grouped: + archive_endorsements_list = list(archive_endorsements) + if _all_subjects_in_archive(archive, archive_endorsements_list): + compressed.append(domain.Category(id=f"{archive}.*", + full_name=f"all of {archive}", + is_active=True, + in_archive=archive, + is_general=False, + )) + + else: + for endorsement in archive_endorsements_list: + compressed.append(endorsement) + if _all_archives(compressed): + return list(definitions.CATEGORIES.values()) + return compressed + + +def explicit_endorsements(user: domain.User) -> Endorsements: + """ + Load endorsed categories for a user. + + These are endorsements (including auto-endorsements) that have been + explicitly commemorated. + + Parameters + ---------- + user : :class:`.domain.User` + + Returns + ------- + list + Each item is a :class:`.domain.Category` for which the user is + explicitly endorsed. + + """ + data: List[Endorsement] = ( + session.query( + Endorsement.archive, + Endorsement.subject_class, + Endorsement.point_value, + ) + .filter(Endorsement.endorsee_id == user.user_id) + .filter(Endorsement.flag_valid == 1) + .all() + ) + pooled: Counter = Counter() + for archive, subject, points in data: + pooled[_category(archive, subject)] += points + return [category for category, points in pooled.items() if points] + + +def implicit_endorsements(user: domain.User) -> Endorsements: + """ + Determine categories for which a user may be autoendorsed. + + In the classic system, this was determined upon request, when the user + attempted to submit to a particular category. Because we are separating + authorization concerns (which includes endorsement) from the submission + system itself, we want to calculate possible autoendorsement categories + ahead of time. + + New development of autoendorsement-related functionality should not happen + here. This function and related code are intended only to preserve the + business logic already implemented in the classic system. + + Parameters + ---------- + :class:`.User` + + Returns + ------- + list + Each item is a :class:`.domain.Category` for which the user may be + auto-endorsed. + + """ + candidates = [definitions.CATEGORIES[category] + for category, data in definitions.CATEGORIES_ACTIVE.items()] + policies = category_policies() + invalidated = invalidated_autoendorsements(user) + papers = domain_papers(user) + user_is_academic = is_academic(user) + return [ + category for category in candidates + if category in policies + and not _disqualifying_invalidations(category, invalidated) + and (policies[category]['endorse_all'] + or _endorse_by_email(category, policies, user_is_academic) + or _endorse_by_papers(category, policies, papers)) + ] + + +def is_academic(user: domain.User) -> bool: + """ + Determine whether a user is academic, based on their email address. + + Uses whitelist and blacklist patterns in the database. + + Parameters + ---------- + user : :class:`.domain.User` + + Returns + ------- + bool + + """ + in_whitelist = ( + session.query(t_arXiv_white_email.c.pattern) + .filter(literal(user.email).like(t_arXiv_white_email.c.pattern)) + .first() + ) + if in_whitelist: + return True + in_blacklist = ( + session.query(t_arXiv_black_email.c.pattern) + .filter(literal(user.email).like(t_arXiv_black_email.c.pattern)) + .first() + ) + if in_blacklist: + return False + return True + + +def _disqualifying_invalidations(category: domain.Category, + invalidated: Endorsements) -> bool: + """ + Evaluate whether endorsement invalidations are disqualifying. + + This enforces the policy that invalidated (revoked) auto-endorsements can + prevent future auto-endorsement. + + Parameters + ---------- + category : :class:`.Category` + The category for which an auto-endorsement is being considered. + invalidated : list + Categories for which the user has had auto-endorsements invalidated + (revoked). + + Returns + ------- + bool + + """ + return bool((category in GENERAL_CATEGORIES and category in invalidated) + or (category not in GENERAL_CATEGORIES and invalidated)) + + +def _endorse_by_email(category: domain.Category, + policies: Dict[domain.Category, Dict], + user_is_academic: bool) -> bool: + """ + Evaluate whether an auto-endorsement can be issued based on email address. + + This enforces the policy that some categories allow auto-endorsement for + academic users. + + Parameters + ---------- + category : :class:`.Category` + The category for which an auto-endorsement is being considered. + policies : dict + Describes auto-endorsement policies for each category (inherited from + their endorsement domains). + user_is_academic : bool + Whether or not the user has been determined to be academic. + + Returns + ------- + bool + + """ + policy = policies.get(category) + if policy is None or 'endorse_email' not in policy: + return False + return policy['endorse_email'] and user_is_academic + + +def _endorse_by_papers(category: domain.Category, + policies: Dict[domain.Category, Dict], + papers: Dict[str, int]) -> bool: + """ + Evaluate whether an auto-endorsement can be issued based on prior papers. + + This enforces the policy that some categories allow auto-endorsements for + users who have published a minimum number of papers in categories that + share an endoresement domain. + + Parameters + ---------- + category : :class:`.Category` + The category for which an auto-endorsement is being considered. + policies : dict + Describes auto-endorsement policies for each category (inherited from + their endorsement domains). + papers : dict + The number of papers that the user has published in each endorsement + domain. Keys are str names of endorsement domains, values are int. + + Returns + ------- + bool + + """ + N_papers = papers.get(policies[category]['domain'], 0) + min_papers = policies[category]['min_papers'] + return bool(N_papers >= min_papers) + + +def domain_papers(user: domain.User, + start_date: Optional[datetime] = None) -> Dict[str, int]: + """ + Calculate the number of papers that a user owns in each endorsement domain. + + This includes both submitted and claimed papers. + + Parameters + ---------- + user : :class:`.domain.User` + start_date : :class:`.datetime` or None + If provided, will only count papers published after this date. + + Returns + ------- + dict + Keys are classification domains (str), values are the number of papers + in each respective domain (int). + + """ + query = session.query(PaperOwner.document_id, + Document.document_id, + t_arXiv_in_category.c.document_id, + Category.endorsement_domain) \ + .filter(PaperOwner.user_id == user.user_id) \ + .filter(Document.document_id == PaperOwner.document_id) \ + .filter(t_arXiv_in_category.c.document_id == Document.document_id) \ + .filter(Category.archive == t_arXiv_in_category.c.archive) \ + .filter(Category.subject_class == t_arXiv_in_category.c.subject_class) + + if start_date: + query = query.filter(Document.dated > util.epoch(start_date)) + data = query.all() + return dict(Counter(domain for _, _, _, domain in data).items()) + + +def category_policies() -> Dict[domain.Category, Dict]: + """ + Load auto-endorsement policies for each category from the database. + + Each category belongs to an endorsement domain, which defines the + auto-endorsement policies. We retrieve those policies from the perspective + of the individueal category for ease of lookup. + + Returns + ------- + dict + Keys are :class:`.domain.Category` instances. Values are dicts with + policiy details. + + """ + data = session.query(Category.archive, + Category.subject_class, + EndorsementDomain.endorse_all, + EndorsementDomain.endorse_email, + EndorsementDomain.papers_to_endorse, + EndorsementDomain.endorsement_domain) \ + .filter(Category.definitive == 1) \ + .filter(Category.active == 1) \ + .filter(Category.endorsement_domain + == EndorsementDomain.endorsement_domain) \ + .all() + + policies = {} + for arch, subj, endorse_all, endorse_email, min_papers, e_domain in data: + policies[_category(arch, subj)] = { + 'domain': e_domain, + 'endorse_all': endorse_all == 'y', + 'endorse_email': endorse_email == 'y', + 'min_papers': min_papers + } + + return policies + + +def invalidated_autoendorsements(user: domain.User) -> Endorsements: + """ + Load any invalidated (revoked) auto-endorsements for a user. + + Parameters + ---------- + user : :class:`.domain.User` + + Returns + ------- + list + Items are :class:`.domain.Category` for which the user has had past + auto-endorsements revoked. + + """ + data: List[Endorsement] = session.query(Endorsement.archive, + Endorsement.subject_class) \ + .filter(Endorsement.endorsee_id == user.user_id) \ + .filter(Endorsement.flag_valid == 0) \ + .filter(Endorsement.type == 'auto') \ + .all() + return [_category(archive, subject) for archive, subject in data] diff --git a/arxiv/auth/legacy/exceptions.py b/arxiv/auth/legacy/exceptions.py new file mode 100644 index 00000000..0f4ee8d6 --- /dev/null +++ b/arxiv/auth/legacy/exceptions.py @@ -0,0 +1,45 @@ +"""Exceptions for legacy user/session integration.""" + + +class AuthenticationFailed(RuntimeError): + """Failed to authenticate user with provided credentials.""" + + +class NoSuchUser(RuntimeError): + """A reference to a non-existant user was passed.""" + + +class PasswordAuthenticationFailed(RuntimeError): + """An invalid username/password combination were provided.""" + + +class SessionCreationFailed(RuntimeError): + """Failed to create a session in the legacy database.""" + + +class SessionDeletionFailed(RuntimeError): + """Failed to delete a session in the legacy database.""" + + +class UnknownSession(RuntimeError): + """Failed to locate a session in the legacy database.""" + + +class SessionExpired(RuntimeError): + """A reference was made to an expired session.""" + + +class InvalidCookie(ValueError): + """The value of a passed legacy cookie is not valid.""" + + +class RegistrationFailed(RuntimeError): + """Could not create a new user.""" + + +class UpdateUserFailed(RuntimeError): + """Could not update a user.""" + + +class Unavailable(RuntimeError): + """The database is temporarily unavailable.""" diff --git a/arxiv/auth/legacy/passwords.py b/arxiv/auth/legacy/passwords.py new file mode 100644 index 00000000..915a16c8 --- /dev/null +++ b/arxiv/auth/legacy/passwords.py @@ -0,0 +1,47 @@ +"""Passwords from legacy.""" + +import secrets +from base64 import b64encode, b64decode +import hashlib + +from .exceptions import PasswordAuthenticationFailed + + +def _hash_salt_and_password(salt: bytes, password: str) -> bytes: + return hashlib.sha1(salt + b'-' + password.encode('ascii')).digest() + + +def hash_password(password: str) -> str: + """Generate a secure hash of a password. + + The password must be ascii. + """ + salt = secrets.token_bytes(4) + hashed = _hash_salt_and_password(salt, password) + return b64encode(salt + hashed).decode('ascii') + + +def check_password(password: str, encrypted: bytes): + """Check a password against an encrypted hash.""" + try: + password.encode('ascii') + except UnicodeEncodeError: + raise PasswordAuthenticationFailed('Password not ascii') + + decoded = b64decode(encrypted) + salt = decoded[:4] + enc_hashed = decoded[4:] + pass_hashed = _hash_salt_and_password(salt, password) + if pass_hashed != enc_hashed: + raise PasswordAuthenticationFailed('Incorrect password') + else: + return True + + +def is_ascii(string): + """Returns true if the string is only ascii chars.""" + try: + string.encode('ascii') + return True + except UnicodeEncodeError: + return False diff --git a/arxiv/auth/legacy/sessions.py b/arxiv/auth/legacy/sessions.py new file mode 100644 index 00000000..bb11fe4f --- /dev/null +++ b/arxiv/auth/legacy/sessions.py @@ -0,0 +1,242 @@ +"""Provides API for legacy user sessions.""" +from datetime import datetime, timedelta +from pytz import timezone, UTC + +import logging + +from typing import Optional, Tuple + +from sqlalchemy.exc import SQLAlchemyError, OperationalError +from sqlalchemy.orm.exc import NoResultFound + +from .. import domain +from ...db import session +from . import cookies, util + +from ...db.models import TapirSession, TapirSessionsAudit, TapirUser, \ + TapirNickname, Demographic +from .exceptions import UnknownSession, SessionCreationFailed, \ + SessionExpired, InvalidCookie, Unavailable + +logger = logging.getLogger(__name__) +EASTERN = timezone('US/Eastern') + + +def _load(session_id: str) -> TapirSession: + """Get TapirSession from session id.""" + db_session: TapirSession = session.query(TapirSession) \ + .filter(TapirSession.session_id == session_id) \ + .first() + if not db_session: + logger.debug(f'No session found with id {session_id}') + raise UnknownSession('No such session') + return db_session + + +def _load_audit(session_id: str) -> TapirSessionsAudit: + """Get TapirSessionsAudit from session id.""" + db_sessions_audit: TapirSessionsAudit = session.query(TapirSessionsAudit) \ + .filter(TapirSessionsAudit.session_id == session_id) \ + .first() + if not db_sessions_audit: + logger.debug(f'No session audit found with id {session_id}') + raise UnknownSession('No such session audit') + return db_sessions_audit + + +def load(cookie: str) -> domain.Session: + """ + Given a session cookie (from request), load the logged-in user. + + Parameters + ---------- + cookie : str + Legacy cookie value passed with the request. + + Returns + ------- + :class:`.domain.Session` + + Raises + ------ + :class:`.legacy.exceptions.SessionExpired` + :class:`.legacy.exceptions.UnknownSession` + + """ + session_id, user_id, ip, issued_at, expires_at, _ = cookies.unpack(cookie) + logger.debug('Load session %s for user %s at %s', + session_id, user_id, ip) + + if expires_at <= datetime.now(tz=UTC): + raise SessionExpired(f'Session {session_id} has expired in cookie') + + data: Tuple[TapirUser, TapirSession, TapirNickname, Demographic] + try: + data = session.query(TapirUser, TapirSession, TapirNickname, Demographic) \ + .join(TapirSession).join(TapirNickname).join(Demographic) \ + .filter(TapirUser.user_id == user_id) \ + .filter(TapirSession.session_id == session_id ) \ + .first() + except OperationalError as e: + raise Unavailable('Database is temporarily unavailable') from e + + if not data: + raise UnknownSession('No such user or session') + + db_user, db_session, db_nick, db_profile = data + + if db_session.end_time != 0 and db_session.end_time < util.now(): + raise SessionExpired(f'Session {session_id} has expired in the DB') + + user = domain.User( + user_id=str(user_id), + username=db_nick.nickname, + email=db_user.email, + name=domain.UserFullName( + forename=db_user.first_name, + surname=db_user.last_name, + suffix=db_user.suffix_name + ), + profile=domain.UserProfile.from_orm(db_profile) if db_profile else None, + verified=bool(db_user.flag_email_verified) + ) + + authorizations = domain.Authorizations( + classic=util.compute_capabilities(db_user), + scopes=util.get_scopes(db_user) + ) + user_session = domain.Session(session_id=str(db_session.session_id), + start_time=issued_at, end_time=expires_at, + user=user, authorizations=authorizations) + logger.debug('loaded session %s', user_session.session_id) + return user_session + + +def create(authorizations: domain.Authorizations, + ip: str, remote_host: str, tracking_cookie: str = '', + user: Optional[domain.User] = None) -> domain.Session: + """ + Create a new legacy session for an authenticated user. + + Parameters + ---------- + user : :class:`.User` + ip : str + Client IP address. + remote_host : str + Client hostname. + tracking_cookie : str + Tracking cookie payload from client request. + + Returns + ------- + :class:`.Session` + + """ + if user is None: + raise SessionCreationFailed('Legacy sessions require a user') + + logger.debug('create session for user %s', user.user_id) + start = datetime.now(tz=UTC) + end = start + timedelta(seconds=util.get_session_duration()) + try: + tapir_session = TapirSession( + user_id=user.user_id, + last_reissue=util.epoch(start), + start_time=util.epoch(start), + end_time=0 + ) + tapir_sessions_audit = TapirSessionsAudit( + session=tapir_session, + ip_addr=ip, + remote_host=remote_host, + tracking_cookie=tracking_cookie + ) + session.add(tapir_sessions_audit) + session.commit() + except Exception as e: + raise SessionCreationFailed(f'Failed to create: {e}') from e + + session_id = str(tapir_session.session_id) + user_session = domain.Session(session_id=session_id, + start_time=start, end_time=end, + user=user, + authorizations=authorizations, + ip_address=ip, remote_host=remote_host) + logger.debug('created session %s', user_session.session_id) + return user_session + + +def generate_cookie(session: domain.Session) -> str: + """ + Generate a cookie from a :class:`domain.Session`. + + Parameters + ---------- + session : :class:`domain.Session` + + Returns + ------- + str + + """ + if session.user is None \ + or session.user.user_id is None \ + or session.ip_address is None \ + or session.authorizations is None: + raise RuntimeError('Cannot generate cookie without an authorized user') + + return cookies.pack(str(session.session_id), session.user.user_id, + session.ip_address, session.start_time, + str(session.authorizations.classic)) + + +def invalidate(cookie: str) -> None: + """ + Invalidate a legacy user session. + + Parameters + ---------- + cookie : str + Session cookie generated when the session was created. + + Raises + ------ + :class:`UnknownSession` + The session could not be found, or the cookie was not valid. + + """ + try: + session_id, user_id, ip, _, _, _ = cookies.unpack(cookie) + except InvalidCookie as e: + raise UnknownSession('No such session') from e + + invalidate_by_id(session_id) + + +def invalidate_by_id(session_id: str) -> None: + """ + Invalidate a legacy user session by ID. + + Parameters + ---------- + session_id : str + Unique identifier for the session. + + Raises + ------ + :class:`UnknownSession` + The session could not be found, or the cookie was not valid. + + """ + delta = datetime.now(tz=UTC) - datetime.fromtimestamp(0, tz=EASTERN) + end = (delta).total_seconds() + try: + tapir_session = _load(session_id) + tapir_session.end_time = end - 1 + session.merge(tapir_session) + session.commit() + except NoResultFound as e: + raise UnknownSession(f'No such session {session_id}') from e + except SQLAlchemyError as e: + raise IOError(f'Database error') from e diff --git a/arxiv/auth/legacy/tests/__init__.py b/arxiv/auth/legacy/tests/__init__.py new file mode 100644 index 00000000..32781b97 --- /dev/null +++ b/arxiv/auth/legacy/tests/__init__.py @@ -0,0 +1,4 @@ +"""Tests for :mod:`arxiv.users.legacy`.""" +import os + +os.environ['CLASSIC_DB_URI'] = 'sqlite:///:memory:' \ No newline at end of file diff --git a/arxiv/auth/legacy/tests/test_accounts.py b/arxiv/auth/legacy/tests/test_accounts.py new file mode 100644 index 00000000..7f4192a1 --- /dev/null +++ b/arxiv/auth/legacy/tests/test_accounts.py @@ -0,0 +1,386 @@ +"""Tests for :mod:`arxiv.users.legacy.accounts`.""" + +import tempfile +from datetime import datetime +import shutil +import hashlib +from pytz import UTC +from unittest import TestCase +from sqlalchemy import select +from flask import Flask + +from arxiv.config import Settings +from arxiv.taxonomy import definitions +from arxiv.db import transaction +from arxiv.db import models + +from .util import temporary_db +from .. import util, authenticate, exceptions +from .. import accounts +from ... import domain + + +def get_user(session, user_id): + """Helper to get user database objects by user id.""" + db_user, db_nick = ( + session.query(models.TapirUser, models.TapirNickname) + .filter(models.TapirUser.user_id == user_id) + .filter(models.TapirNickname.flag_primary == 1) + .filter(models.TapirNickname.user_id == models.TapirUser.user_id) + .first() + ) + + db_profile = session.query(models.Demographic) \ + .filter(models.Demographic.user_id == user_id) \ + .first() + + return db_user, db_nick, db_profile + + +class SetUpUserMixin(TestCase): + """Mixin for creating a test user and other database goodies.""" + + def setUp(self): + """Set up the database.""" + self.db_path = tempfile.mkdtemp() + self.db_uri = f'sqlite:///{self.db_path}/test.db' + self.user_id = '15830' + self.app = Flask('test') + self.app.config['CLASSIC_SESSION_HASH'] = 'foohash' + self.app.config['CLASSIC_COOKIE_NAME'] = 'tapir_session_cookie' + self.app.config['SESSION_DURATION'] = '36000' + settings = Settings( + CLASSIC_DB_URI=self.db_uri, + LATEXML_DB_URI=None) + + self.engine, _ = models.configure_db(settings) # Insert tapir policy classes + + with temporary_db("sqlite:///:memory:", drop=False) as session: + self.user_class = session.scalar( + select(models.TapirPolicyClass).where(models.TapirPolicyClass.class_id==2)) + self.email = 'first@last.iv' + self.db_user = models.TapirUser( + user_id=self.user_id, + first_name='first', + last_name='last', + suffix_name='iv', + email=self.email, + policy_class=self.user_class.class_id, + flag_edit_users=1, + flag_email_verified=1, + flag_edit_system=0, + flag_approved=1, + flag_deleted=0, + flag_banned=0, + tracking_cookie='foocookie', + ) + self.username = 'foouser' + self.db_nick = models.TapirNickname( + nickname=self.username, + user_id=self.user_id, + user_seq=1, + flag_valid=1, + role=0, + policy=0, + flag_primary=1 + ) + self.salt = b'foo' + self.password = b'thepassword' + hashed = hashlib.sha1(self.salt + b'-' + self.password).digest() + self.db_password = models.TapirUsersPassword( + user_id=self.user_id, + password_storage=2, + password_enc=hashed + ) + n = util.epoch(datetime.now(tz=UTC)) + self.secret = 'foosecret' + self.db_token = models.TapirPermanentToken( + user_id=self.user_id, + secret=self.secret, + valid=1, + issued_when=n, + issued_to='127.0.0.1', + remote_host='foohost.foo.com', + session_id=0 + ) + session.add(self.user_class) + session.add(self.db_user) + session.add(self.db_password) + session.add(self.db_nick) + session.add(self.db_token) + session.commit() + + def tearDown(self): + """Drop tables from the in-memory db file.""" + util.drop_all(self.engine) + + +class TestUsernameExists(SetUpUserMixin): + """Tests for :mod:`accounts.does_username_exist`.""" + + def test_with_nonexistant_user(self): + """There is no user with the passed username.""" + with self.app.app_context(): + self.assertFalse(accounts.does_username_exist('baruser')) + + def test_with_existant_user(self): + """There is a user with the passed username.""" + # with temporary_db(self.db_uri, create=False, drop=False): + # with transaction() as session: + # print (f'NICKS: {session.query(models.TapirNickname).all()}') + # self.setUp() + with self.app.app_context(): + self.assertTrue(accounts.does_username_exist('foouser')) + + +class TestEmailExists(SetUpUserMixin): + """Tests for :mod:`accounts.does_email_exist`.""" + + def test_with_nonexistant_email(self): + """There is no user with the passed email.""" + with self.app.app_context(): + self.assertFalse(accounts.does_email_exist('foo@bar.com')) + + def test_with_existant_email(self): + """There is a user with the passed email.""" + self.setUp() + with self.app.app_context(): + self.assertTrue(accounts.does_email_exist('first@last.iv')) + + +class TestRegister(SetUpUserMixin, TestCase): + """Tests for :mod:`accounts.register`.""" + + def test_register_with_duplicate_username(self): + """The username is already in the system.""" + user = domain.User(username='foouser', email='foo@bar.com') + ip = '1.2.3.4' + with self.app.app_context(): + with self.assertRaises(exceptions.RegistrationFailed): + accounts.register(user, 'apassword1', ip=ip, remote_host=ip) + + def test_register_with_duplicate_email(self): + """The email address is already in the system.""" + user = domain.User(username='bazuser', email='first@last.iv') + ip = '1.2.3.4' + with self.app.app_context(): + with self.assertRaises(exceptions.RegistrationFailed): + accounts.register(user, 'apassword1', ip=ip, remote_host=ip) + + def test_register_with_name_details(self): + """Registration includes the user's name.""" + name = domain.UserFullName(forename='foo', surname='user', suffix='iv') + user = domain.User(username='bazuser', email='new@account.edu', + name=name) + ip = '1.2.3.4' + + with self.app.app_context(): + with transaction() as session: + u, _ = accounts.register(user, 'apassword1', ip=ip, remote_host=ip) + db_user, db_nick, db_profile = get_user(session, u.user_id) + + self.assertEqual(db_user.first_name, name.forename) + self.assertEqual(db_user.last_name, name.surname) + self.assertEqual(db_user.suffix_name, name.suffix) + + def test_register_with_bare_minimum(self): + """Registration includes only a username, name, email address, password.""" + user = domain.User(username='bazuser', email='new@account.edu', + name = domain.UserFullName(forename='foo', surname='user', suffix='iv')) + ip = '1.2.3.4' + + with self.app.app_context(): + with transaction() as session: + u, _ = accounts.register(user, 'apassword1', ip=ip, remote_host=ip) + db_user, db_nick, db_profile = get_user(session, u.user_id) + + self.assertEqual(db_user.flag_email_verified, 0) + self.assertEqual(db_nick.nickname, user.username) + self.assertEqual(db_user.email, user.email) + + def test_register_with_profile(self): + """Registration includes profile information.""" + profile = domain.UserProfile( + affiliation='School of Hard Knocks', + country='de', + rank=1, + submission_groups=['grp_cs', 'grp_q-bio'], + default_category=definitions.CATEGORIES['cs.DL'], + homepage_url='https://google.com' + ) + name = domain.UserFullName(forename='foo', surname='user', suffix='iv') + user = domain.User(username='bazuser', email='new@account.edu', + name=name, profile=profile) + ip = '1.2.3.4' + + with self.app.app_context(): + with transaction() as session: + u, _ = accounts.register(user, 'apassword1', ip=ip, remote_host=ip) + db_user, db_nick, db_profile = get_user(session, u.user_id) + + self.assertEqual(db_profile.affiliation, profile.affiliation) + self.assertEqual(db_profile.country, profile.country), + self.assertEqual(db_profile.type, profile.rank), + self.assertEqual(db_profile.flag_group_cs, 1) + self.assertEqual(db_profile.flag_group_q_bio, 1) + self.assertEqual(db_profile.flag_group_physics, 0) + self.assertEqual(db_profile.archive, 'cs') + self.assertEqual(db_profile.subject_class, 'DL') + + def test_can_authenticate_after_registration(self): + """A may authenticate a bare-minimum user after registration.""" + user = domain.User(username='bazuser', email='new@account.edu', + name=domain.UserFullName(forename='foo', surname='user')) + ip = '1.2.3.4' + + with self.app.app_context(): + with transaction() as session: + u, _ = accounts.register(user, 'apassword1', ip=ip, remote_host=ip) + db_user, db_nick, db_profile = get_user(session, u.user_id) + auth_user, auths = authenticate.authenticate( + username_or_email=user.username, + password='apassword1' + ) + self.assertEqual(str(db_user.user_id), auth_user.user_id) + + +class TestGetUserById(SetUpUserMixin): + """Tests for :func:`accounts.get_user_by_id`.""" + + def test_user_exists(self): + """A well-rounded user exists with the requested user id.""" + profile = domain.UserProfile( + affiliation='School of Hard Knocks', + country='de', + rank=1, + submission_groups=['grp_cs', 'grp_q-bio'], + default_category=definitions.CATEGORIES['cs.DL'], + homepage_url='https://google.com' + ) + name = domain.UserFullName(forename='foo', surname='user', suffix='iv') + user = domain.User(username='bazuser', email='new@account.edu', + name=name, profile=profile) + ip = '1.2.3.4' + + with self.app.app_context(): + u, _ = accounts.register(user, 'apassword1', ip=ip, remote_host=ip) + loaded_user = accounts.get_user_by_id(u.user_id) + + self.assertEqual(loaded_user.username, user.username) + self.assertEqual(loaded_user.email, user.email) + self.assertEqual(loaded_user.profile.affiliation, profile.affiliation) + + def test_user_does_not_exist(self): + """No user with the specified username.""" + with self.app.app_context(): + with self.assertRaises(exceptions.NoSuchUser): + accounts.get_user_by_id('1234') + + def test_with_no_profile(self): + """The user exists, but there is no profile.""" + name = domain.UserFullName(forename='foo', surname='user', suffix='iv') + user = domain.User(username='bazuser', email='new@account.edu', + name=name) + ip = '1.2.3.4' + + with self.app.app_context(): + u, _ = accounts.register(user, 'apassword1', ip=ip, remote_host=ip) + loaded_user = accounts.get_user_by_id(u.user_id) + + self.assertEqual(loaded_user.username, user.username) + self.assertEqual(loaded_user.email, user.email) + self.assertIsNone(loaded_user.profile) + + +class TestUpdate(SetUpUserMixin, TestCase): + """Tests for :func:`accounts.update`.""" + + def test_user_without_id(self): + """A :class:`domain.User` is passed without an ID.""" + user = domain.User(username='bazuser', email='new@account.edu') + with self.app.app_context(): + with self.assertRaises(ValueError): + accounts.update(user) + + def test_update_nonexistant_user(self): + """A :class:`domain.User` is passed that is not in the database.""" + user = domain.User(username='bazuser', email='new@account.edu', + user_id='12345') + with self.app.app_context(): + with self.assertRaises(exceptions.NoSuchUser): + accounts.update(user) + + def test_update_name(self): + """The user's name is changed.""" + name = domain.UserFullName(forename='foo', surname='user', suffix='iv') + user = domain.User(username='bazuser', email='new@account.edu', + name=name) + ip = '1.2.3.4' + + with self.app.app_context(): + user, _ = accounts.register(user, 'apassword1', ip=ip, + remote_host=ip) + + with self.app.app_context(): + with transaction() as session: + updated_name = domain.UserFullName(forename='Foo', + surname=name.surname, + suffix=name.suffix) + updated_user = domain.User(user_id=user.user_id, + username=user.username, + email=user.email, + name=updated_name) + + updated_user, _ = accounts.update(updated_user) + self.assertEqual(user.user_id, updated_user.user_id) + self.assertEqual(updated_user.name.forename, 'Foo') + db_user, db_nick, db_profile = get_user(session, user.user_id) + self.assertEqual(db_user.first_name, 'Foo') + + def test_update_profile(self): + """Changes are made to profile information.""" + profile = domain.UserProfile( + affiliation='School of Hard Knocks', + country='de', + rank=1, + submission_groups=['grp_cs', 'grp_q-bio'], + default_category=definitions.CATEGORIES['cs.DL'], + homepage_url='https://google.com' + ) + name = domain.UserFullName(forename='foo', surname='user', suffix='iv') + user = domain.User(username='bazuser', email='new@account.edu', + name=name, profile=profile) + ip = '1.2.3.4' + + with self.app.app_context(): + user, _ = accounts.register(user, 'apassword1', ip=ip, + remote_host=ip) + + updated_profile = domain.UserProfile( + affiliation='School of Hard Knocks', + country='us', + rank=2, + submission_groups=['grp_cs', 'grp_physics'], + default_category=definitions.CATEGORIES['cs.IR'], + homepage_url='https://google.com' + ) + updated_user = domain.User(user_id=user.user_id, + username=user.username, + email=user.email, + name=name, + profile=updated_profile) + + with self.app.app_context(): + with transaction() as session: + u, _ = accounts.update(updated_user) + db_user, db_nick, db_profile = get_user(session, u.user_id) + + self.assertEqual(db_profile.affiliation, + updated_profile.affiliation) + self.assertEqual(db_profile.country, updated_profile.country), + self.assertEqual(db_profile.type, updated_profile.rank), + self.assertEqual(db_profile.flag_group_cs, 1) + self.assertEqual(db_profile.flag_group_q_bio, 0) + self.assertEqual(db_profile.flag_group_physics, 1) + self.assertEqual(db_profile.archive, 'cs') + self.assertEqual(db_profile.subject_class, 'IR') diff --git a/arxiv/auth/legacy/tests/test_authenticate.py b/arxiv/auth/legacy/tests/test_authenticate.py new file mode 100644 index 00000000..e6329135 --- /dev/null +++ b/arxiv/auth/legacy/tests/test_authenticate.py @@ -0,0 +1,299 @@ +"""Tests for :mod:`accounts.services.user_data`.""" + +from unittest import TestCase, mock +from datetime import datetime +from pytz import timezone, UTC +import tempfile +import shutil +import hashlib + +from flask import Flask + +from sqlalchemy import select +from sqlalchemy.exc import SQLAlchemyError, OperationalError + +from arxiv.config import Settings +from arxiv.db import models + +from .. import authenticate, exceptions, util + +from .util import temporary_db + +from ..passwords import hash_password + +EASTERN = timezone('US/Eastern') + + +class TestAuthenticateWithPermanentToken(TestCase): + """User has a permanent token.""" + + def setUp(self): + """Instantiate a user and its credentials in the DB.""" + self.db = f'sqlite:///:memory:' + self.user_id = '1' + + self.app = Flask('test') + settings = Settings( + CLASSIC_DB_URI=self.db, + LATEXML_DB_URI=None) + + self.engine, _ = models.configure_db(settings) + + with temporary_db(self.db, drop=False) as session: + self.user_class = session.scalar( + select(models.TapirPolicyClass).where(models.TapirPolicyClass.class_id==2)) + self.email = 'first@last.iv' + self.db_user = models.TapirUser( + user_id=self.user_id, + first_name='first', + last_name='last', + suffix_name='iv', + email=self.email, + policy_class=self.user_class.class_id, + flag_edit_users=1, + flag_email_verified=1, + flag_edit_system=0, + flag_approved=1, + flag_deleted=0, + flag_banned=0, + tracking_cookie='foocookie', + ) + self.username = 'foouser' + self.db_nick = models.TapirNickname( + nickname=self.username, + user_id=self.user_id, + user_seq=1, + flag_valid=1, + role=0, + policy=0, + flag_primary=1 + ) + self.salt = b'foo' + self.password = b'thepassword' + hashed = hashlib.sha1(self.salt + b'-' + self.password).digest() + self.db_password = models.TapirUsersPassword( + user_id=self.user_id, + password_storage=2, + password_enc=hashed + ) + n = util.epoch(datetime.now(tz=UTC)) + self.secret = 'foosecret' + self.db_token = models.TapirPermanentToken( + user_id=self.user_id, + secret=self.secret, + valid=1, + issued_when=n, + issued_to='127.0.0.1', + remote_host='foohost.foo.com', + session_id=0 + ) + session.add(self.user_class) + session.add(self.db_user) + session.add(self.db_password) + session.add(self.db_nick) + session.add(self.db_token) + session.commit() + + def tearDown(self): + """Drop tables from the in-memory db file.""" + util.drop_all(self.engine) + + def test_token_is_malformed(self): + """Token is present, but it has the wrong format.""" + bad_token = 'footokenhasnohyphen' + with self.app.app_context(): + with self.assertRaises(exceptions.AuthenticationFailed): + authenticate.authenticate(token=bad_token) + + def test_token_is_incorrect(self): + """Token is present, but there is no such token in the database.""" + bad_token = '1234-nosuchtoken' + # with temporary_db(self.db, create=False): + with self.app.app_context(): + with self.assertRaises(exceptions.AuthenticationFailed): + authenticate.authenticate(token=bad_token) + + def test_token_is_invalid(self): + """The token is present, but it is not valid.""" + with self.app.app_context(): + authenticate._invalidate_token(self.user_id, self.secret) + + with self.assertRaises(exceptions.AuthenticationFailed): + token = f'{self.user_id}-{self.secret}' + authenticate.authenticate(token=token) + + def test_token_is_valid(self): + """The token is valid!.""" + with self.app.app_context(): + token = f'{self.user_id}-{self.secret}' + user, auths = authenticate.authenticate(token=token) + self.assertIsInstance(user, authenticate.domain.User, + "Returns data about the user") + self.assertIsInstance(auths, authenticate.domain.Authorizations, + "Returns authorization data") + self.assertEqual(user.user_id, self.user_id, + "User ID is set correctly") + self.assertEqual(user.username, self.username, + "Username is set correctly") + self.assertEqual(user.email, self.email, + "User email is set correctly") + self.assertEqual(auths.classic, 6, + "authorizations are set") + + +class TestAuthenticateWithPassword(TestCase): + """User is attempting login with username+password.""" + + def setUp(self): + """Instantiate a user.""" + self.db = f'sqlite:///:memory:' + + self.app = Flask('test') + settings = Settings( + CLASSIC_DB_URI=self.db, + LATEXML_DB_URI=None) + + self.engine, _ = models.configure_db(settings) + self.user_id = '5' + with temporary_db(self.db, drop=False) as session: + # We have a good old-fashioned user. + self.user_class = session.scalar( + select(models.TapirPolicyClass).where(models.TapirPolicyClass.class_id==2)) + self.email = 'first@last.iv' + self.db_user = models.TapirUser( + user_id=self.user_id, + first_name='first', + last_name='last', + suffix_name='iv', + email=self.email, + policy_class=self.user_class.class_id, + flag_edit_users=1, + flag_email_verified=1, + flag_edit_system=0, + flag_approved=1, + flag_deleted=0, + flag_banned=0, + tracking_cookie='foocookie', + ) + self.username = 'foouser' + self.db_nick = models.TapirNickname( + nickname=self.username, + user_id=self.user_id, + user_seq=1, + flag_valid=1, + role=0, + policy=0, + flag_primary=1 + ) + self.salt = b'foo' + self.password = 'thepassword' + # hashed = hashlib.sha1(self.salt + b'-' + self.password).digest() + + hashed = hash_password(self.password) + self.db_password = models.TapirUsersPassword( + user_id=self.user_id, + password_storage=2, + password_enc=hashed + ) + session.add(self.user_class) + session.add(self.db_user) + session.add(self.db_password) + session.add(self.db_nick) + session.commit() + + def tearDown(self): + """Drop tables from the in-memory db file.""" + util.drop_all(self.engine) + + def test_no_username(self): + """Username is not entered.""" + username = '' + password = 'foopass' + with self.assertRaises(exceptions.AuthenticationFailed): + with self.app.app_context(): + authenticate.authenticate(username, password) + + def test_no_password(self): + """Password is not entered.""" + username = 'foouser' + password = '' + with self.assertRaises(exceptions.AuthenticationFailed): + with self.app.app_context(): + authenticate.authenticate(username, password) + + def test_password_is_incorrect(self): + """Password is incorrect.""" + with self.app.app_context(): + with self.assertRaises(exceptions.AuthenticationFailed): + authenticate.authenticate('foouser', 'notthepassword') + + def test_password_is_correct(self): + """Password is correct.""" + with self.app.app_context(): + + user, auths = authenticate.authenticate('foouser', 'thepassword') + self.assertIsInstance(user, authenticate.domain.User, + "Returns data about the user") + self.assertIsInstance(auths, authenticate.domain.Authorizations, + "Returns authorization data") + self.assertEqual(user.user_id, self.user_id, + "User ID is set correctly") + self.assertEqual(user.username, self.username, + "Username is set correctly") + self.assertEqual(user.email, self.email, + "User email is set correctly") + self.assertEqual(auths.classic, 6, + "Authorizations are set") + + def test_login_with_email_and_correct_password(self): + """User attempts to log in with e-mail address.""" + with self.app.app_context(): + user, auths = authenticate.authenticate('first@last.iv', + 'thepassword') + self.assertIsInstance(user, authenticate.domain.User, + "Returns data about the user") + self.assertIsInstance(auths, authenticate.domain.Authorizations, + "Returns authorization data") + self.assertEqual(user.user_id, self.user_id, + "User ID is set correctly") + self.assertEqual(user.username, self.username, + "Username is set correctly") + self.assertEqual(user.email, self.email, + "User email is set correctly") + self.assertEqual(auths.classic, 6, + "authorizations are set") + + def test_no_such_user(self): + """Username does not exist.""" + with self.app.app_context(): + with self.assertRaises(exceptions.AuthenticationFailed): + authenticate.authenticate('nobody', 'thepassword') + + + def test_bad_data(self): + """Test with bad data.""" + with self.app.app_context(): + with self.assertRaises(exceptions.AuthenticationFailed): + authenticate.authenticate('abc', '') + with self.assertRaises(exceptions.AuthenticationFailed): + authenticate.authenticate('abc', 234) + with self.assertRaises(exceptions.AuthenticationFailed): + authenticate.authenticate('abc', 'ฮฒ') + with self.assertRaises(exceptions.AuthenticationFailed): + authenticate.authenticate('', 'password') + with self.assertRaises(exceptions.AuthenticationFailed): + authenticate.authenticate('ฮฒ', 'password') + with self.assertRaises(exceptions.AuthenticationFailed): + authenticate.authenticate('long'*100, 'password') + with self.assertRaises(exceptions.AuthenticationFailed): + authenticate.authenticate(1234, 'password') + + + with self.assertRaises(exceptions.AuthenticationFailed): + authenticate.authenticate(None, None, 'abcc-something') + with self.assertRaises(exceptions.AuthenticationFailed): + authenticate.authenticate(None, None, '-something') + with self.assertRaises(exceptions.AuthenticationFailed): + authenticate.authenticate(None, None, ('long'*20) + '-something') + with self.assertRaises(exceptions.AuthenticationFailed): + authenticate.authenticate(None, None, '1234-' + 40 * 'long') diff --git a/arxiv/auth/legacy/tests/test_bootstrap.py b/arxiv/auth/legacy/tests/test_bootstrap.py new file mode 100644 index 00000000..e7528d99 --- /dev/null +++ b/arxiv/auth/legacy/tests/test_bootstrap.py @@ -0,0 +1,301 @@ +"""Test the legacy integration with synthetic data.""" + +import os +import sys +from typing import Tuple +from unittest import TestCase +from flask import Flask +import locale + +from typing import List +import random +from datetime import datetime +from pytz import timezone, UTC +from mimesis import Person, Internet, Datetime, locales + +from sqlalchemy import select, func +from arxiv.db import transaction +from arxiv.db import models +from arxiv.config import Settings +from arxiv.taxonomy import definitions +from .. import util, sessions, authenticate, exceptions +from ..passwords import hash_password +from ... import domain + +LOCALES = locales.Locale.values() +EASTERN = timezone('US/Eastern') + + +def _random_category() -> Tuple[str, str]: + category = random.choice(list(definitions.CATEGORIES_ACTIVE.items())) + archive = category[1].in_archive + subject_class = category[0].split('.')[-1] if '.' in category[0] else '' + return archive, subject_class + + +def _get_locale() -> str: + return LOCALES[random.randint(0, len(LOCALES) - 1)] + + +def _prob(P: int) -> bool: + return random.randint(0, 100) < P + + +class TestBootstrap(TestCase): + """Tests against legacy user integrations with fake data.""" + + @classmethod + def setUpClass(cls): + """Generate some fake data.""" + cls.app = Flask('test') + cls.app.config['CLASSIC_SESSION_HASH'] = 'foohash' + cls.app.config['CLASSIC_COOKIE_NAME'] = 'tapir_session_cookie' + cls.app.config['SESSION_DURATION'] = '36000' + settings = Settings( + CLASSIC_DB_URI='sqlite:///:memory:', + LATEXML_DB_URI=None) + + engine, _ = models.configure_db(settings) + + with cls.app.app_context(): + util.create_all(engine) + with transaction() as session: + edc = session.execute(select(models.Endorsement)).all() + for row in edc: + print(row) + assert len(edc) == 0, "Expect the table to be empty at the start" + + session.add(models.EndorsementDomain( + endorsement_domain='test_domain_bootstrap', + endorse_all='n', + mods_endorse_all='n', + endorse_email='y', + papers_to_endorse=3 + )) + + for category in definitions.CATEGORIES_ACTIVE.keys(): + if '.' in category: + archive, subject_class = category.split('.', 1) + else: + archive, subject_class = category, '' + + with transaction() as session: + #print(f"arch: {archive} sc: {subject_class}") + session.add(models.Category( + archive=archive, + subject_class=subject_class, + definitive=1, + active=1, + endorsement_domain='test_domain_bootstrap' + )) + + COUNT = 100 + + cls.users = [] + + _users = [] + _domain_users = [] + for i in range(COUNT): + with transaction() as session: + locale = _get_locale() + person = Person(locale) + net = Internet() + ip_addr = net.ip_v4() + email = person.email() + approved = 1 if _prob(90) else 0 + deleted = 1 if _prob(2) else 0 + banned = 1 if random.randint(0, 100) <= 1 else 0 + first_name = person.name() + last_name = person.surname() + suffix_name = person.title() + name = (first_name, last_name, suffix_name) + joined_date = util.epoch( + Datetime(locale).datetime().replace(tzinfo=EASTERN) + ) + db_user = models.TapirUser( + first_name=first_name, + last_name=last_name, + suffix_name=suffix_name, + share_first_name=1, + share_last_name=1, + email=email, + flag_approved=approved, + flag_deleted=deleted, + flag_banned=banned, + flag_edit_users=0, + flag_edit_system=0, + flag_email_verified=1, + share_email=8, + email_bouncing=0, + policy_class=2, # Public user. TODO: consider admin. + joined_date=joined_date, + joined_ip_num=ip_addr, + joined_remote_host=ip_addr + ) + session.add(db_user) + + # Create a username. + username_is_valid = 1 if _prob(90) else 0 + username = person.username() + db_nick = models.TapirNickname( + user=db_user, + nickname=username, + flag_valid=username_is_valid, + flag_primary=1 + ) + + # Create the user's profile. + archive, subject_class = _random_category() + db_profile = models.Demographic( + user=db_user, + country=locale, + affiliation=person.university(), + url=net.url(), + type=random.randint(1, 5), + archive=archive, + subject_class=subject_class, + original_subject_classes='', + flag_group_math=1 if _prob(5) else 0, + flag_group_cs=1 if _prob(5) else 0, + flag_group_nlin=1 if _prob(5) else 0, + flag_group_q_bio=1 if _prob(5) else 0, + flag_group_q_fin=1 if _prob(5) else 0, + flag_group_stat=1 if _prob(5) else 0 + ) + + # Set the user's password. + password = person.password() + db_password = models.TapirUsersPassword( + user=db_user, + password_storage=2, + password_enc=hash_password(password) + ) + + # Create some endorsements. + archive, subject_class = _random_category() + net_points = 0 + for _ in range(0, random.randint(1, 4)): + etype = random.choice(['auto', 'user', 'admin']) + point_value = random.randint(-10, 10) + net_points += point_value + if len(_users) > 0 and etype == 'auto': + endorser_id = random.choice(_users).user_id + else: + endorser_id = None + issued_when = util.epoch( + Datetime(locale).datetime().replace(tzinfo=EASTERN) + ) + session.add(models.Endorsement( + endorsee=db_user, + endorser_id=endorser_id, + archive=archive, + subject_class=subject_class, + flag_valid=1, + type=etype, + point_value=point_value, + issued_when=issued_when + )) + + session.add(db_password) + session.add(db_nick) + session.add(db_profile) + _users.append(db_user) + _domain_users.append(( + domain.User( + user_id=str(db_user.user_id), + username=db_nick.nickname, + email=db_user.email, + name=domain.UserFullName( + forename=db_user.first_name, + surname=db_user.last_name, + suffix=db_user.suffix_name + ), + profile=domain.UserProfile.from_orm(db_profile), + verified=bool(db_user.flag_email_verified) + ), + domain.Authorizations( + classic=util.compute_capabilities(db_user), + ) + )) + session.commit() + # We'll use these data to run tests. + cls.users.append(( + email, username, password, name, + (archive, subject_class, net_points), + (approved, deleted, banned, username_is_valid), + )) + + + def test_authenticate_and_use_session(self): + """Attempt to authenticate users and create/load auth sessions.""" + with self.app.app_context(): + for datum in self.users: + email, username, password, name, endorsement, status = datum + approved, deleted, banned, username_is_valid = status + + + # Banned or deleted users may not log in. + if deleted or banned: + with self.assertRaises(exceptions.AuthenticationFailed): + authenticate.authenticate(email, password) + continue + + # Users who are not approved may not log in. + elif not approved: + with self.assertRaises(exceptions.AuthenticationFailed): + authenticate.authenticate(email, password) + continue + + # username not valid may not log in + elif not username_is_valid: + print( f"USERNAME_IS_VALID: {datum}") + with self.assertRaises(exceptions.AuthenticationFailed): + authenticate.authenticate(email, password) + continue + + # Approved users may log in. + assert approved and not deleted and not banned and username_is_valid + user, auths = authenticate.authenticate(email, password) + self.assertIsInstance(user, domain.User, + "User data is returned") + self.assertEqual(user.email, email, + "Email is set correctly") + self.assertEqual(user.username, username, + "Username is set correctly") + + first_name, last_name, suffix_name = name + self.assertEqual(user.name.forename, first_name, + "Forename is set correctly") + self.assertEqual(user.name.surname, last_name, + "Surname is set correctly") + self.assertEqual(user.name.suffix, suffix_name, + "Suffix is set correctly") + self.assertIsInstance(auths, domain.Authorizations, + "Authorizations data are returned") + # if endorsement[2] > 0: + # self.assertTrue(auths.endorsed_for( + # domain.Category( + # id=f'{endorsement[0]}.{endorsement[1]}', + # full_name="fake", + # is_active=False, + # in_archive=endorsement[0], + # is_general=False + # ) + # ), "Endorsements are included in authorizations") + + net = Internet() + ip = net.ip_v4() + session = sessions.create(auths, ip, ip, user=user) + cookie = sessions.generate_cookie(session) + + session_loaded = sessions.load(cookie) + self.assertEqual(session.user, session_loaded.user, + "Loaded the correct user") + self.assertEqual(session.session_id, session_loaded.session_id, + "Loaded the correct session") + + # Invalidate 10% of the sessions, and try again. + if _prob(10): + sessions.invalidate(cookie) + with self.assertRaises(exceptions.SessionExpired): + sessions.load(cookie) diff --git a/arxiv/auth/legacy/tests/test_endorsement_auto.py b/arxiv/auth/legacy/tests/test_endorsement_auto.py new file mode 100644 index 00000000..8aa57d69 --- /dev/null +++ b/arxiv/auth/legacy/tests/test_endorsement_auto.py @@ -0,0 +1,325 @@ +"""Tests for :mod:`arxiv.users.legacy.endorsements` using a live test DB.""" + +import os +from unittest import TestCase, mock +from datetime import datetime +from pytz import timezone, UTC + +from flask import Flask +from sqlalchemy import insert +from mimesis import Person, Internet, Datetime + +from arxiv.db import transaction +from arxiv.db import models +from arxiv.config import Settings +from arxiv.taxonomy import definitions +from .. import endorsements, util +from ... import domain + +EASTERN = timezone('US/Eastern') + + +class TestAutoEndorsement(TestCase): + """Tests for :func:`get_autoendorsements`.""" + + def setUp(self): + """Generate some fake data.""" + + self.app = Flask('test') + self.app.config['CLASSIC_SESSION_HASH'] = 'foohash' + self.app.config['CLASSIC_COOKIE_NAME'] = 'tapir_session_cookie' + self.app.config['SESSION_DURATION'] = '36000' + self.app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite://' #in memory + settings = Settings( + CLASSIC_DB_URI='sqlite:///:memory:', + LATEXML_DB_URI=None) + + engine, _ = models.configure_db(settings) + self.default_tracking_data = { + 'remote_addr': '0.0.0.0', + 'remote_host': 'foo-host.foo.com', + 'tracking_cookie': '0' + } + + with self.app.app_context(): + util.create_all(engine) + with transaction() as session: + person = Person('en') + net = Internet() + ip_addr = net.ip_v4() + email = person.email() + approved = 1 + deleted = 0 + banned = 0 + first_name = person.name() + last_name = person.surname() + suffix_name = person.title() + joined_date = util.epoch( + Datetime('en').datetime().replace(tzinfo=EASTERN) + ) + db_user = models.TapirUser( + first_name=first_name, + last_name=last_name, + suffix_name=suffix_name, + share_first_name=1, + share_last_name=1, + email=email, + flag_approved=approved, + flag_deleted=deleted, + flag_banned=banned, + flag_edit_users=0, + flag_edit_system=0, + flag_email_verified=1, + share_email=8, + email_bouncing=0, + policy_class=2, # Public user. TODO: consider admin. + joined_date=joined_date, + joined_ip_num=ip_addr, + joined_remote_host=ip_addr + ) + session.add(db_user) + + self.user = domain.User( + user_id=str(db_user.user_id), + username='foouser', + email=db_user.email, + name=domain.UserFullName( + forename=db_user.first_name, + surname=db_user.last_name, + suffix=db_user.suffix_name + ) + ) + + + def test_invalidated_autoendorsements(self): + """The user has two autoendorsements that have been invalidated.""" + with self.app.app_context(): + with transaction() as session: + issued_when = util.epoch( + Datetime('en').datetime().replace(tzinfo=EASTERN) + ) + session.add(models.Endorsement( + endorsee_id=self.user.user_id, + archive='astro-ph', + subject_class='CO', + flag_valid=0, + type='auto', + point_value=10, + issued_when=issued_when + )) + session.add(models.Endorsement( + endorsee_id=self.user.user_id, + archive='astro-ph', + subject_class='CO', + flag_valid=0, + type='auto', + point_value=10, + issued_when=issued_when + )) + session.add(models.Endorsement( + endorsee_id=self.user.user_id, + archive='astro-ph', + subject_class='CO', + flag_valid=1, + type='auto', + point_value=10, + issued_when=issued_when + )) + session.add(models.Endorsement( + endorsee_id=self.user.user_id, + archive='astro-ph', + subject_class='CO', + flag_valid=1, + type='user', + point_value=10, + issued_when=issued_when + )) + + result = endorsements.invalidated_autoendorsements(self.user) + self.assertEqual(len(result), 2, "Two revoked endorsements are loaded") + + def test_category_policies(self): + """Load category endorsement policies from the database.""" + with self.app.app_context(): + with transaction() as session: + session.add(models.Category( + archive='astro-ph', + subject_class='CO', + definitive=1, + active=1, + endorsement_domain='astro-ph' + )) + session.add(models.EndorsementDomain( + endorsement_domain='astro-ph', + endorse_all='n', + mods_endorse_all='n', + endorse_email='y', + papers_to_endorse=3 + )) + + policies = endorsements.category_policies() + category = definitions.CATEGORIES['astro-ph.CO'] + self.assertIn(category, policies, "Data are loaded for categories") + self.assertEqual(policies[category]['domain'], 'astro-ph') + self.assertFalse(policies[category]['endorse_all']) + self.assertTrue(policies[category]['endorse_email']) + self.assertEqual(policies[category]['min_papers'], 3) + + def test_domain_papers(self): + """Get the number of papers published in each domain.""" + with self.app.app_context(): + with transaction() as session: + # User owns three papers. + document1 = models.Document( + document_id=1, + title='Foo Title', + submitter_email='foo@bar.baz', + paper_id='2101.00123', + dated=util.epoch(datetime.now(tz=UTC)) + ) + session.add(document1) + session.add(models.PaperOwner( + document=document1, + user_id=self.user.user_id, + flag_author=0, # <- User is _not_ an author. + valid=1, + **self.default_tracking_data + )) + session.execute( + insert(models.t_arXiv_in_category) + .values( + document_id=1, + archive='cs', + subject_class='DL', + is_primary=1 + ) + ) + session.add(models.Category( + archive='cs', + subject_class='DL', + definitive=1, + active=1, + endorsement_domain='firstdomain' + )) + # Here's another paper. + document2 = models.Document( + document_id=2, + title='Foo Title', + submitter_email='foo@bar.baz', + paper_id='2101.00124', + dated=util.epoch(datetime.now(tz=UTC)) + ) + session.add(document2) + session.add (models.PaperOwner( + document=document2, + user_id=self.user.user_id, + flag_author=1, # <- User is an author. + valid=1, + **self.default_tracking_data + )) + session.execute( + insert(models.t_arXiv_in_category) + .values( + document_id=2, + archive='cs', + subject_class='IR', + is_primary=1 + ) + ) + session.add(models.Category( + archive='cs', + subject_class='IR', + definitive=1, + active=1, + endorsement_domain='firstdomain' + )) + # Here's a paper for which the user is an author. + document3 = models.Document( + document_id=3, + title='Foo Title', + submitter_email='foo@bar.baz', + paper_id='2101.00125', + dated=util.epoch(datetime.now(tz=UTC)) + ) + session.add(document3) + session.add (models.PaperOwner( + document=document3, + user_id=self.user.user_id, + flag_author=1, + valid=1, + **self.default_tracking_data + )) + # It has both a primary and a secondary classification. + session.execute( + insert(models.t_arXiv_in_category) + .values( + document_id=3, + archive='astro-ph', + subject_class='EP', + is_primary=1 + ) + ) + session.execute( + insert(models.t_arXiv_in_category) + .values( + document_id=3, + archive='astro-ph', + subject_class='CO', + is_primary=0 # <- secondary! + ) + ) + session.add(models.Category( + archive='astro-ph', + subject_class='EP', + definitive=1, + active=1, + endorsement_domain='seconddomain' + )) + session.add(models.Category( + archive='astro-ph', + subject_class='CO', + definitive=1, + active=1, + endorsement_domain='seconddomain' + )) + papers = endorsements.domain_papers(self.user) + self.assertEqual(papers['firstdomain'], 2) + self.assertEqual(papers['seconddomain'], 2) + + def test_is_academic(self): + """Determine whether a user is academic based on email.""" + ok_patterns = ['%w3.org', '%aaas.org', '%agu.org', '%ams.org'] + bad_patterns = ['%.com', '%.net', '%.biz.%'] + with self.app.app_context(): + with transaction() as session: + for pattern in ok_patterns: + session.execute(insert( + models.t_arXiv_white_email) + .values(pattern=str(pattern)) + ) + for pattern in bad_patterns: + session.execute(insert( + models.t_arXiv_black_email) + .values(pattern=str(pattern)) + ) + + self.assertTrue(endorsements.is_academic(domain.User( + user_id='2', + email='someone@fsu.edu', + username='someone' + ))) + self.assertFalse(endorsements.is_academic(domain.User( + user_id='2', + email='someone@fsu.biz.edu', + username='someone' + ))) + self.assertTrue(endorsements.is_academic(domain.User( + user_id='2', + email='someone@aaas.org', + username='someone' + ))) + self.assertFalse(endorsements.is_academic(domain.User( + user_id='2', + email='someone@foo.com', + username='someone' + ))) diff --git a/arxiv/auth/legacy/tests/test_endorsements.py b/arxiv/auth/legacy/tests/test_endorsements.py new file mode 100644 index 00000000..96abe3e3 --- /dev/null +++ b/arxiv/auth/legacy/tests/test_endorsements.py @@ -0,0 +1,463 @@ +"""Tests for :mod:`arxiv.users.legacy.endorsements` using a live test DB.""" + +import os +from unittest import TestCase, mock +from datetime import datetime +from pytz import timezone, UTC + +from flask import Flask +from mimesis import Person, Internet, Datetime +from sqlalchemy import insert + +from arxiv.taxonomy import definitions +from arxiv.config import Settings +from arxiv.db import models, transaction +from .. import endorsements, util +from ... import domain + +EASTERN = timezone('US/Eastern') + + +class TestEndorsement(TestCase): + """Tests for :func:`get_endorsements`.""" + + def setUp(self): + """Generate some fake data.""" + self.app = Flask('test') + self.app.config['CLASSIC_DATABASE_URI'] = 'sqlite:///test.db' + self.app.config['CLASSIC_SESSION_HASH'] = 'foohash' + settings = Settings( + CLASSIC_DB_URI='sqlite:///test.db', + LATEXML_DB_URI=None) + + engine, _ = models.configure_db(settings) + + self.default_tracking_data = { + 'remote_addr': '0.0.0.0', + 'remote_host': 'foo-host.foo.com', + 'tracking_cookie': '0' + } + + with self.app.app_context(): + util.create_all(engine) + with transaction() as session: + person = Person('en') + net = Internet() + ip_addr = net.ip_v4() + email = "foouser@agu.org" + approved = 1 + deleted = 0 + banned = 0 + first_name = person.name() + last_name = person.surname() + suffix_name = person.title() + joined_date = util.epoch( + Datetime('en').datetime().replace(tzinfo=EASTERN) + ) + db_user = models.TapirUser( + first_name=first_name, + last_name=last_name, + suffix_name=suffix_name, + share_first_name=1, + share_last_name=1, + email=email, + flag_approved=approved, + flag_deleted=deleted, + flag_banned=banned, + flag_edit_users=0, + flag_edit_system=0, + flag_email_verified=1, + share_email=8, + email_bouncing=0, + policy_class=2, # Public user. TODO: consider admin. + joined_date=joined_date, + joined_ip_num=ip_addr, + joined_remote_host=ip_addr + ) + session.add(db_user) + + self.user = domain.User( + user_id=str(db_user.user_id), + username='foouser', + email=db_user.email, + name=domain.UserFullName( + forename=db_user.first_name, + surname=db_user.last_name, + suffix=db_user.suffix_name + ) + ) + + ok_patterns = ['%w3.org', '%aaas.org', '%agu.org', '%ams.org'] + bad_patterns = ['%.com', '%.net', '%.biz.%'] + + with transaction() as session: + for pattern in ok_patterns: + session.execute( + insert(models.t_arXiv_white_email) + .values(pattern=str(pattern)) + ) + for pattern in bad_patterns: + session.execute( + insert(models.t_arXiv_black_email) + .values(pattern=str(pattern)) + ) + + session.add(models.EndorsementDomain( + endorsement_domain='test_domain', + endorse_all='n', + mods_endorse_all='n', + endorse_email='y', + papers_to_endorse=3 + )) + + # for category, definition in definitions.CATEGORIES_ACTIVE.items(): + # if '.' in category: + # archive, subject_class = category.split('.', 1) + # else: + # archive, subject_class = category, '' + # session.add(models.Category( + # archive=archive, + # subject_class=subject_class, + # definitive=1, + # active=1, + # endorsement_domain='test_domain' + # )) + + def test_get_endorsements(self): + """Test :func:`endoresement.get_endorsements`.""" + with self.app.app_context(): + with transaction() as session: + for category, definition in definitions.CATEGORIES_ACTIVE.items(): + if '.' in category: + archive, subject_class = category.split('.', 1) + else: + archive, subject_class = category, '' + session.add(models.Category( + archive=archive, + subject_class=subject_class, + definitive=1, + active=1, + endorsement_domain='test_domain' + )) + all_endorsements = set( + endorsements.get_endorsements(self.user) + ) + all_possible = set(definitions.CATEGORIES_ACTIVE.values()) + self.assertEqual(all_endorsements, all_possible) + + def tearDown(self): + """Remove the test DB.""" + try: + os.remove('./test.db') + except FileNotFoundError: + pass + + +class TestAutoEndorsement(TestCase): + """Tests for :func:`get_autoendorsements`.""" + + def setUp(self): + """Generate some fake data.""" + self.app = Flask('test') + self.app.config['CLASSIC_DATABASE_URI'] = 'sqlite:///test.db' + self.app.config['CLASSIC_SESSION_HASH'] = 'foohash' + settings = Settings( + CLASSIC_DB_URI='sqlite:///test.db', + LATEXML_DB_URI=None) + + engine, _ = models.configure_db(settings) + + self.default_tracking_data = { + 'remote_addr': '0.0.0.0', + 'remote_host': 'foo-host.foo.com', + 'tracking_cookie': '0' + } + + with self.app.app_context(): + util.create_all(engine) + with transaction() as session: + person = Person('en') + net = Internet() + ip_addr = net.ip_v4() + email = person.email() + approved = 1 + deleted = 0 + banned = 0 + first_name = person.name() + last_name = person.surname() + suffix_name = person.title() + joined_date = util.epoch( + Datetime('en').datetime().replace(tzinfo=EASTERN) + ) + db_user = models.TapirUser( + first_name=first_name, + last_name=last_name, + suffix_name=suffix_name, + share_first_name=1, + share_last_name=1, + email=email, + flag_approved=approved, + flag_deleted=deleted, + flag_banned=banned, + flag_edit_users=0, + flag_edit_system=0, + flag_email_verified=1, + share_email=8, + email_bouncing=0, + policy_class=2, # Public user. TODO: consider admin. + joined_date=joined_date, + joined_ip_num=ip_addr, + joined_remote_host=ip_addr + ) + session.add(db_user) + + self.user = domain.User( + user_id=str(db_user.user_id), + username='foouser', + email=db_user.email, + name=domain.UserFullName( + forename=db_user.first_name, + surname=db_user.last_name, + suffix=db_user.suffix_name + ) + ) + + def tearDown(self): + """Remove the test DB.""" + try: + os.remove('./test.db') + except FileNotFoundError: + pass + + def test_invalidated_autoendorsements(self): + """The user has two autoendorsements that have been invalidated.""" + with self.app.app_context(): + with transaction() as session: + issued_when = util.epoch( + Datetime('en').datetime().replace(tzinfo=EASTERN) + ) + session.add(models.Endorsement( + endorsee_id=self.user.user_id, + archive='astro-ph', + subject_class='CO', + flag_valid=0, + type='auto', + point_value=10, + issued_when=issued_when + )) + session.add(models.Endorsement( + endorsee_id=self.user.user_id, + archive='astro-ph', + subject_class='CO', + flag_valid=0, + type='auto', + point_value=10, + issued_when=issued_when + )) + session.add(models.Endorsement( + endorsee_id=self.user.user_id, + archive='astro-ph', + subject_class='CO', + flag_valid=1, + type='auto', + point_value=10, + issued_when=issued_when + )) + session.add(models.Endorsement( + endorsee_id=self.user.user_id, + archive='astro-ph', + subject_class='CO', + flag_valid=1, + type='user', + point_value=10, + issued_when=issued_when + )) + + result = endorsements.invalidated_autoendorsements(self.user) + self.assertEqual(len(result), 2, "Two revoked endorsements are loaded") + + def test_category_policies(self): + """Load category endorsement policies from the database.""" + with self.app.app_context(): + with transaction() as session: + session.add(models.Category( + archive='astro-ph', + subject_class='CO', + definitive=1, + active=1, + endorsement_domain='astro-ph' + )) + session.add(models.EndorsementDomain( + endorsement_domain='astro-ph', + endorse_all='n', + mods_endorse_all='n', + endorse_email='y', + papers_to_endorse=3 + )) + + policies = endorsements.category_policies() + category = definitions.CATEGORIES['astro-ph.CO'] + self.assertIn(category, policies, "Data are loaded for categories") + self.assertEqual(policies[category]['domain'], 'astro-ph') + self.assertFalse(policies[category]['endorse_all']) + self.assertTrue(policies[category]['endorse_email']) + self.assertEqual(policies[category]['min_papers'], 3) + + def test_domain_papers(self): + """Get the number of papers published in each domain.""" + with self.app.app_context(): + with transaction() as session: + # User owns three papers. + document1 = models.Document( + document_id=1, + title='Foo Title', + submitter_email='foo@bar.baz', + paper_id='2101.00123', + dated=util.epoch(datetime.now(tz=UTC)) + ) + session.add(document1) + session.add(models.PaperOwner( + document=document1, + user_id=self.user.user_id, + flag_author=0, + valid=1, + **self.default_tracking_data + )) + session.execute( + insert(models.t_arXiv_in_category) + .values( + document_id=1, + archive='cs', + subject_class='DL', + is_primary=1 + ) + ) + session.add(models.Category( + archive='cs', + subject_class='DL', + definitive=1, + active=1, + endorsement_domain='firstdomain' + )) + # Here's another paper. + document2 = models.Document( + document_id=2, + title='Foo Title', + submitter_email='foo@bar.baz', + paper_id='2101.00124', + dated=util.epoch(datetime.now(tz=UTC)) + ) + session.add(document2) + session.add(models.PaperOwner( + document=document2, + user_id=self.user.user_id, + flag_author=1, + valid=1, + **self.default_tracking_data + )) + session.execute( + insert(models.t_arXiv_in_category) + .values( + document_id=2, + archive='cs', + subject_class='IR', + is_primary=1 + ) + ) + session.add(models.Category( + archive='cs', + subject_class='IR', + definitive=1, + active=1, + endorsement_domain='firstdomain' + )) + # Here's a paper for which the user is an author. + document3 = models.Document( + document_id=3, + title='Foo Title', + submitter_email='foo@bar.baz', + paper_id='2101.00125', + dated=util.epoch(datetime.now(tz=UTC)) + ) + session.add(document3) + session.add(models.PaperOwner( + document=document3, + user_id=self.user.user_id, + flag_author=1, + valid=1, + **self.default_tracking_data + )) + # It has both a primary and a secondary classification. + session.execute( + insert(models.t_arXiv_in_category) + .values( + document_id=3, + archive='astro-ph', + subject_class='EP', + is_primary=1 + ) + ) + session.execute( + insert(models.t_arXiv_in_category) + .values( + document_id=3, + archive='astro-ph', + subject_class='CO', + is_primary=0 # <- secondary! + ) + ) + session.add(models.Category( + archive='astro-ph', + subject_class='EP', + definitive=1, + active=1, + endorsement_domain='seconddomain' + )) + session.add(models.Category( + archive='astro-ph', + subject_class='CO', + definitive=1, + active=1, + endorsement_domain='seconddomain' + )) + papers = endorsements.domain_papers(self.user) + self.assertEqual(papers['firstdomain'], 2) + self.assertEqual(papers['seconddomain'], 2) + + def test_is_academic(self): + """Determine whether a user is academic based on email.""" + ok_patterns = ['%w3.org', '%aaas.org', '%agu.org', '%ams.org'] + bad_patterns = ['%.com', '%.net', '%.biz.%'] + with self.app.app_context(): + with transaction() as session: + for pattern in ok_patterns: + session.execute( + insert(models.t_arXiv_white_email) + .values(pattern=str(pattern)) + ) + for pattern in bad_patterns: + session.execute( + insert(models.t_arXiv_black_email) + .values(pattern=str(pattern)) + ) + + self.assertTrue(endorsements.is_academic(domain.User( + user_id='2', + email='someone@fsu.edu', + username='someone' + ))) + self.assertFalse(endorsements.is_academic(domain.User( + user_id='2', + email='someone@fsu.biz.edu', + username='someone' + ))) + self.assertTrue(endorsements.is_academic(domain.User( + user_id='2', + email='someone@aaas.org', + username='someone' + ))) + self.assertFalse(endorsements.is_academic(domain.User( + user_id='2', + email='someone@foo.com', + username='someone' + ))) diff --git a/arxiv/auth/legacy/tests/test_passwords.py b/arxiv/auth/legacy/tests/test_passwords.py new file mode 100644 index 00000000..9c8a97fe --- /dev/null +++ b/arxiv/auth/legacy/tests/test_passwords.py @@ -0,0 +1,30 @@ +"""Tests for :mod:`legacy_users.passwords`.""" + +from unittest import TestCase +from ..exceptions import PasswordAuthenticationFailed +from ..passwords import hash_password, check_password + +from hypothesis import given, settings +from hypothesis import strategies as st +import string + + +class TestCheckPassword(TestCase): + """Tests passwords.""" + @given(st.text(alphabet=string.printable)) + @settings(max_examples=500) + def test_check_passwords_successful(self, passw): + encrypted = hash_password(passw) + self.assertTrue( check_password(passw, encrypted.encode('ascii')), + f"should work for password '{passw}'") + + @given(st.text(alphabet=string.printable), st.text(alphabet=st.characters())) + @settings(max_examples=5000) + def test_check_passwords_fuzz(self, passw, fuzzpw): + if passw == fuzzpw: + self.assertTrue(check_password(fuzzpw, + hash_password(passw).encode('ascii'))) + else: + with self.assertRaises(PasswordAuthenticationFailed): + check_password(fuzzpw, + hash_password(passw).encode('ascii')) diff --git a/arxiv/auth/legacy/tests/test_sessions.py b/arxiv/auth/legacy/tests/test_sessions.py new file mode 100644 index 00000000..90111145 --- /dev/null +++ b/arxiv/auth/legacy/tests/test_sessions.py @@ -0,0 +1,104 @@ +"""Tests for legacy_users service.""" +import time +from unittest import mock, TestCase +from datetime import datetime +from pytz import timezone, UTC + +from arxiv.db import models, transaction +from .. import exceptions, sessions, util, cookies + +from .util import temporary_db + +EASTERN = timezone('US/Eastern') + + +class TestCreateSession(TestCase): + """Tests for public function :func:`.`.""" + + @mock.patch(f'{sessions.__name__}.util.get_session_duration') + def test_create(self, mock_get_session_duration): + """Accept a :class:`.User` and returns a :class:`.Session`.""" + mock_get_session_duration.return_value = 36000 + user = sessions.domain.User( + user_id="1", + username='theuser', + email='the@user.com', + ) + auths = sessions.domain.Authorizations(classic=6) + ip_address = '127.0.0.1' + remote_host = 'foo-host.foo.com' + tracking = "1.foo" + with temporary_db('sqlite:///:memory:', create=True): + user_session = sessions.create(auths, ip_address, remote_host, + tracking, user=user) + self.assertIsInstance(user_session, sessions.domain.Session) + tapir_session = sessions._load(user_session.session_id) + self.assertIsNotNone(user_session, 'verifying we have a session') + if tapir_session is not None: + self.assertEqual( + tapir_session.session_id, + int(user_session.session_id), + "Returned session has correct session id." + ) + self.assertEqual(tapir_session.user_id, int(user.user_id), + "Returned session has correct user id.") + self.assertEqual(tapir_session.end_time, 0, + "End time is 0 (no end time)") + + tapir_session_audit = sessions._load_audit(user_session.session_id) + self.assertIsNotNone(tapir_session_audit) + if tapir_session_audit is not None: + self.assertEqual( + tapir_session_audit.session_id, + int(user_session.session_id), + "Returned session audit has correct session id." + ) + self.assertEqual( + tapir_session_audit.ip_addr, + user_session.ip_address, + "Returned session audit has correct ip address" + ) + self.assertEqual( + tapir_session_audit.remote_host, + user_session.remote_host, + "Returned session audit has correct remote host" + ) + + +class TestInvalidateSession(TestCase): + """Tests for public function :func:`.invalidate`.""" + + @mock.patch(f'{cookies.__name__}.util.get_session_duration') + def test_invalidate(self, mock_get_duration): + """The session is invalidated by setting `end_time`.""" + mock_get_duration.return_value = 36000 + session_id = "424242424" + user_id = "12345" + ip = "127.0.0.1" + capabilities = 6 + start = datetime.now(tz=UTC) + + with temporary_db('sqlite:///:memory:') as db_session: + cookie = cookies.pack(session_id, user_id, ip, start, capabilities) + with transaction() as db_session: + tapir_session = models.TapirSession( + session_id=session_id, + user_id=12345, + last_reissue=util.epoch(start), + start_time=util.epoch(start), + end_time=0 + ) + db_session.add(tapir_session) + + sessions.invalidate(cookie) + tapir_session = sessions._load(session_id) + time.sleep(1) + self.assertGreaterEqual(util.now(), tapir_session.end_time) + + @mock.patch(f'{cookies.__name__}.util.get_session_duration') + def test_invalidate_nonexistant_session(self, mock_get_duration): + """An exception is raised if the session doesn't exist.""" + mock_get_duration.return_value = 36000 + with temporary_db('sqlite:///:memory:'): + with self.assertRaises(exceptions.UnknownSession): + sessions.invalidate('1:1:10.10.10.10:1531145500:4') diff --git a/arxiv/auth/legacy/tests/test_util.py b/arxiv/auth/legacy/tests/test_util.py new file mode 100644 index 00000000..1ff69488 --- /dev/null +++ b/arxiv/auth/legacy/tests/test_util.py @@ -0,0 +1,32 @@ +"""Tests for :mod:`legacy_users.util`.""" + +from unittest import TestCase +from arxiv.db import models +from .util import temporary_db +from .. import util, sessions + +class TestGetSession(TestCase): + """ + Tests for private function :func:`._load`. + + Gets a :class:`.TapirSession` given a session ID. + """ + + def test_load_returns_a_session(self) -> None: + """If ID matches a known session, returns a :class:`.TapirSession`.""" + session_id = "424242424" + with temporary_db('sqlite:///:memory:') as db_session: + start = util.now() + db_session.add(models.TapirSession( + session_id=session_id, + user_id=12345, + last_reissue=start, + start_time=start, + end_time=0 + )) + db_session.commit() + tapir_session = sessions._load(session_id) + self.assertIsNotNone(tapir_session, 'verifying we have a session') + self.assertEqual(tapir_session.session_id, int(session_id), + "Returned session has correct session id.") + diff --git a/arxiv/auth/legacy/tests/util.py b/arxiv/auth/legacy/tests/util.py new file mode 100644 index 00000000..eab847fd --- /dev/null +++ b/arxiv/auth/legacy/tests/util.py @@ -0,0 +1,33 @@ +"""Testing helpers.""" +from contextlib import contextmanager + +from flask import Flask + +from arxiv.config import Settings +from arxiv.db import transaction +from arxiv.db.models import configure_db +from .. import util + + +@contextmanager +def temporary_db(db_uri: str, create: bool = True, drop: bool = True): + """Provide an in-memory sqlite database for testing purposes.""" + app = Flask('foo') + app.config['CLASSIC_SESSION_HASH'] = 'foohash' + app.config['SESSION_DURATION'] = 3600 + app.config['CLASSIC_COOKIE_NAME'] = 'tapir_session' + settings = Settings( + CLASSIC_DB_URI=db_uri, + LATEXML_DB_URI=None) + + engine, _ = configure_db (settings) + + with app.app_context(): + if create: + util.create_all(engine) + try: + with transaction() as session: + yield session + finally: + if drop: + util.drop_all(engine) diff --git a/arxiv/auth/legacy/util.py b/arxiv/auth/legacy/util.py new file mode 100644 index 00000000..c718927f --- /dev/null +++ b/arxiv/auth/legacy/util.py @@ -0,0 +1,103 @@ +"""Helpers and Flask application integration.""" + +from typing import List, Any +from datetime import datetime +from pytz import timezone, UTC +import logging + +from sqlalchemy import text, Engine + +from ...base.globals import get_application_config + +from ..auth import scopes +from .. import domain +from ...db import session, Base, SessionLocal +from ...db.models import TapirUser, TapirPolicyClass + +EASTERN = timezone('US/Eastern') +logger = logging.getLogger(__name__) +logger.propagate = False + + +def now() -> int: + """Get the current epoch/unix time.""" + return epoch(datetime.now(tz=UTC)) + + +def epoch(t: datetime) -> int: + """Convert a :class:`.datetime` to UNIX time.""" + delta = t - datetime.fromtimestamp(0, tz=EASTERN) + return int(round((delta).total_seconds())) + + +def from_epoch(t: int) -> datetime: + """Get a :class:`datetime` from an UNIX timestamp.""" + return datetime.fromtimestamp(t, tz=EASTERN) + + +def create_all(engine: Engine) -> None: + """Create all tables in the database.""" + Base.metadata.create_all(engine) + with SessionLocal() as session: + data = session.query(TapirPolicyClass).all() + if data: + return + + for datum in TapirPolicyClass.POLICY_CLASSES: + session.add(TapirPolicyClass(**datum)) + session.commit() + +def drop_all(engine: Engine) -> None: + """Drop all tables in the database.""" + Base.metadata.drop_all(engine) + + +def compute_capabilities(tapir_user: TapirUser) -> int: + """Calculate the privilege level code for a user.""" + return int(sum([2 * tapir_user.flag_edit_users, + 4 * tapir_user.flag_email_verified, + 8 * tapir_user.flag_edit_system])) + + +def get_scopes(db_user: TapirUser) -> List[domain.Scope]: + """Generate a list of authz scopes for a legacy user based on class.""" + if db_user.policy_class == TapirPolicyClass.PUBLIC_USER: + return scopes.GENERAL_USER + if db_user.policy_class == TapirPolicyClass.ADMIN: + return scopes.ADMIN_USER + return [] + + +def is_configured() -> bool: + """Determine whether or not the legacy auth is configured of the `Flask` app.""" + config = get_application_config() + return not bool(missing_configs(config)) + +def missing_configs(config) -> List[str]: + """Returns missing keys for configs needed in `Flask.config` for legacy auth to work.""" + missing = [key for key in ['CLASSIC_SESSION_HASH', 'SESSION_DURATION', 'CLASSIC_COOKIE_NAME'] + if key not in config] + return missing + +def get_session_hash() -> str: + """Get the legacy hash secret.""" + config = get_application_config() + session_hash: str = config['CLASSIC_SESSION_HASH'] + return session_hash + + +def get_session_duration() -> int: + """Get the session duration from the config.""" + config = get_application_config() + timeout: str = config['SESSION_DURATION'] + return int(timeout) + + +def is_available(**kwargs: Any) -> bool: + """Check our connection to the database.""" + try: + session.query("1").from_statement(text("SELECT 1")).all() + except Exception as e: + logger.error('Encountered an error talking to database: %s', e) + return False + return True diff --git a/arxiv/auth/tests/__init__.py b/arxiv/auth/tests/__init__.py new file mode 100644 index 00000000..1f835616 --- /dev/null +++ b/arxiv/auth/tests/__init__.py @@ -0,0 +1 @@ +"""Package-level tests for :mod:`arxiv.users`.""" diff --git a/arxiv/auth/tests/test_domain.py b/arxiv/auth/tests/test_domain.py new file mode 100644 index 00000000..f34dd0a1 --- /dev/null +++ b/arxiv/auth/tests/test_domain.py @@ -0,0 +1,49 @@ +"""Tests for :mod:`arxiv.users.domain`.""" + +from unittest import TestCase +from datetime import datetime +from arxiv.taxonomy import definitions +from pytz import timezone +from ..auth import scopes +from ..auth import domain + +EASTERN = timezone('US/Eastern') + + +class TestSession(TestCase): + def test_with_session(self): + session = domain.Session( + session_id='asdf1234', + start_time=datetime.now(), end_time=datetime.now(), + user=domain.User( + user_id='12345', + email='foo@bar.com', + username='emanresu', + name=domain.UserFullName(forename='First', surname='Last', suffix='Lastest'), + profile=domain.UserProfile( + affiliation='FSU', + rank=3, + country='us', + default_category=definitions.CATEGORIES['astro-ph.CO'], + submission_groups=['grp_physics'] + ) + ), + authorizations=domain.Authorizations( + scopes=[scopes.VIEW_SUBMISSION, scopes.CREATE_SUBMISSION], + ) + ) + session_data = session.dict() + self.assertEqual(session_data['authorizations']['scopes'], + ['submission:read','submission:create']) + + self.assertEqual(session_data['user']['profile']['affiliation'], 'FSU') + self.assertEqual(session_data['user']['profile']['country'], 'us') + self.assertEqual(session_data['user']['profile']['submission_groups'], ['grp_physics']) + self.assertEqual(session_data['user']['profile']['default_category']['id'], 'astro-ph.CO') + self.assertEqual( + session_data['user']['name'], + {'forename': 'First', 'surname': 'Last', 'suffix': 'Lastest'} + ) + + as_session = domain.session_from_dict(session_data) + self.assertEqual(session, as_session) diff --git a/arxiv/auth/tests/test_helpers.py b/arxiv/auth/tests/test_helpers.py new file mode 100644 index 00000000..8e58ccaa --- /dev/null +++ b/arxiv/auth/tests/test_helpers.py @@ -0,0 +1,63 @@ +"""Tests for :mod:`.helpers`.""" + +from unittest import TestCase, mock +import os +import logging + +from flask import Flask + +from arxiv import status +from arxiv.config import Settings +from arxiv.db.models import configure_db +from arxiv.base import Base +from arxiv.base.middleware import wrap + +from .. import auth, helpers, legacy +from ..auth import decorators +from ..auth.middleware import AuthMiddleware +from ..auth.scopes import VIEW_SUBMISSION, CREATE_SUBMISSION, EDIT_SUBMISSION + + +class TestGenerateToken(TestCase): + """Tests for :func:`.helpers.generate_token`.""" + + def test_token_is_usable(self): + """Verify that :func:`.helpers.generate_token` makes usable tokens.""" + os.environ['JWT_SECRET'] = 'thesecret' + scope = [VIEW_SUBMISSION, EDIT_SUBMISSION, CREATE_SUBMISSION] + token = helpers.generate_token("1234", "user@foo.com", "theuser", + scope=scope) + + app = Flask('test') + app.config['CLASSIC_SESSION_HASH'] = 'foohash' + app.config['CLASSIC_COOKIE_NAME'] = 'tapir_session_cookie' + app.config['SESSION_DURATION'] = '36000' + app.config['AUTH_UPDATED_SESSION_REF'] = True + + app.config.update({ + 'JWT_SECRET': 'thesecret', + }) + + settings = Settings( + CLASSIC_DB_URI='sqlite:///:memory:', + LATEXML_DB_URI=None) + + configure_db (settings) + + Base(app) + auth.Auth(app) + wrap(app, [AuthMiddleware]) + + @app.route('/') + @decorators.scoped(EDIT_SUBMISSION) + def protected(): + return "this is protected" + + client = app.test_client() + with app.app_context(): + response = client.get('/') + self.assertEqual(response.status_code, + status.HTTP_401_UNAUTHORIZED) + + response = client.get('/', headers={'Authorization': token}) + self.assertEqual(response.status_code, status.HTTP_200_OK) diff --git a/arxiv/base/__init__.py b/arxiv/base/__init__.py index 64da60c4..0e99b73e 100644 --- a/arxiv/base/__init__.py +++ b/arxiv/base/__init__.py @@ -129,4 +129,6 @@ def register_blueprint(self: Flask, blueprint: Blueprint, # See: https://github.com/pallets-eco/flask-sqlalchemy/blob/42a36a3cb604fd39d81d00b54ab3988bbd0ad184/src/flask_sqlalchemy/session.py#L109 @app.teardown_appcontext def remove_scoped_session (response_or_exc: BaseException | None) -> None: + if response_or_exc: + session.rollback() session.remove() diff --git a/arxiv/config/__init__.py b/arxiv/config/__init__.py index e5a0fc01..ee038ea0 100644 --- a/arxiv/config/__init__.py +++ b/arxiv/config/__init__.py @@ -170,7 +170,7 @@ class Settings(BaseSettings): TRACKBACK_SECRET: SecretStr = SecretStr(token_hex(10)) CLASSIC_DB_URI: str = DEFAULT_DB - LATEXML_DB_URI: str = DEFAULT_LATEXML_DB + LATEXML_DB_URI: Optional[str] = DEFAULT_LATEXML_DB ECHO_SQL: bool = False CLASSIC_DB_TRANSACTION_ISOLATION_LEVEL: Optional[IsolationLevel] = None LATEXML_DB_TRANSACTION_ISOLATION_LEVEL: Optional[IsolationLevel] = None diff --git a/arxiv/db/__init__.py b/arxiv/db/__init__.py index 05763471..7bbc0072 100644 --- a/arxiv/db/__init__.py +++ b/arxiv/db/__init__.py @@ -40,7 +40,7 @@ from flask.globals import app_ctx from flask import has_app_context -from sqlalchemy import create_engine, MetaData +from sqlalchemy import Engine, MetaData from sqlalchemy.event import listens_for from sqlalchemy.orm import sessionmaker, scoped_session, DeclarativeBase @@ -56,18 +56,7 @@ class LaTeXMLBase(DeclarativeBase): logger = logging.getLogger(__name__) -engine = create_engine(settings.CLASSIC_DB_URI, - echo=settings.ECHO_SQL, - isolation_level=settings.CLASSIC_DB_TRANSACTION_ISOLATION_LEVEL, - pool_recycle=600, - max_overflow=(settings.REQUEST_CONCURRENCY - 5), # max overflow is how many + base pool size, which is 5 by default - pool_pre_ping=settings.POOL_PRE_PING) -latexml_engine = create_engine(settings.LATEXML_DB_URI, - echo=settings.ECHO_SQL, - isolation_level=settings.LATEXML_DB_TRANSACTION_ISOLATION_LEVEL, - pool_recycle=600, - max_overflow=(settings.REQUEST_CONCURRENCY - 5), - pool_pre_ping=settings.POOL_PRE_PING) + SessionLocal = sessionmaker(autocommit=False, autoflush=False) def _app_ctx_id () -> int: @@ -75,6 +64,9 @@ def _app_ctx_id () -> int: session = scoped_session(SessionLocal, scopefunc=_app_ctx_id) +def get_engine () -> Engine: + return SessionLocal().get_bind(Base) + @contextmanager def get_db (): db = SessionLocal() @@ -93,19 +85,20 @@ def transaction (): if db.new or db.dirty or db.deleted: db.commit() except Exception as e: - logger.warn(f'Commit failed, rolling back', exc_info=1) + logger.warning(f'Commit failed, rolling back', exc_info=1) db.rollback() + raise finally: if not in_flask: db.close() -def config_query_timing( slightly_long_sec: float, long_sec: float): - @listens_for(engine, "before_cursor_execute") +def config_query_timing(slightly_long_sec: float, long_sec: float): + @listens_for(get_engine(), "before_cursor_execute") def _record_query_start (conn, cursor, statement, parameters, context, executemany): conn.info['query_start'] = datetime.now() - @listens_for(engine, "after_cursor_execute") + @listens_for(get_engine(), "after_cursor_execute") def _calculate_query_run_time (conn, cursor, statement, parameters, context, executemany): if conn.info.get('query_start'): delta: timedelta = (datetime.now() - conn.info['query_start']) diff --git a/arxiv/db/models.py b/arxiv/db/models.py index 8259daa6..a8b8ac09 100644 --- a/arxiv/db/models.py +++ b/arxiv/db/models.py @@ -3,7 +3,7 @@ This file was generated by using sqlacodgen """ -from typing import Optional, Literal, Any +from typing import Optional, Literal, Any, Tuple, List import re import os import hashlib @@ -30,7 +30,9 @@ Text, Table, Enum, - text + text, + create_engine, + Engine ) from sqlalchemy.schema import FetchedValue from sqlalchemy.orm import ( @@ -39,15 +41,67 @@ relationship ) -from ..config import settings +from ..config import Settings, settings from . import Base, LaTeXMLBase, metadata, \ - SessionLocal, engine, latexml_engine + SessionLocal from .types import intpk tb_secret = settings.TRACKBACK_SECRET tz = gettz(settings.ARXIV_BUSINESS_TZ) +def configure_db (base_settings: Settings) -> Tuple[Engine, Optional[Engine]]: + if 'sqlite' in base_settings.CLASSIC_DB_URI: + engine = create_engine(base_settings.CLASSIC_DB_URI) + if base_settings.LATEXML_DB_URI: + latexml_engine = create_engine(base_settings.LATEXML_DB_URI) + else: + latexml_engine = None + else: + engine = create_engine(base_settings.CLASSIC_DB_URI, + echo=base_settings.ECHO_SQL, + isolation_level=base_settings.CLASSIC_DB_TRANSACTION_ISOLATION_LEVEL, + pool_recycle=600, + max_overflow=(base_settings.REQUEST_CONCURRENCY - 5), # max overflow is how many + base pool size, which is 5 by default + pool_pre_ping=base_settings.POOL_PRE_PING) + if base_settings.LATEXML_DB_URI: + latexml_engine = create_engine(base_settings.LATEXML_DB_URI, + echo=base_settings.ECHO_SQL, + isolation_level=base_settings.LATEXML_DB_TRANSACTION_ISOLATION_LEVEL, + pool_recycle=600, + max_overflow=(base_settings.REQUEST_CONCURRENCY - 5), + pool_pre_ping=base_settings.POOL_PRE_PING) + else: + latexml_engine = None + SessionLocal.configure(binds={ + Base: engine, + LaTeXMLBase: (latexml_engine if latexml_engine else engine), + t_arXiv_stats_hourly: engine, + t_arXiv_admin_state: engine, + t_arXiv_bad_pw: engine, + t_arXiv_black_email: engine, + t_arXiv_block_email: engine, + t_arXiv_bogus_subject_class: engine, + t_arXiv_duplicates: engine, + t_arXiv_in_category: engine, + t_arXiv_moderators: engine, + t_arXiv_ownership_requests_papers: engine, + t_arXiv_refresh_list: engine, + t_arXiv_updates_tmp: engine, + t_arXiv_white_email: engine, + t_arXiv_xml_notifications: engine, + t_demographics_backup: engine, + t_tapir_email_change_tokens_used: engine, + t_tapir_email_tokens_used: engine, + t_tapir_error_log: engine, + t_tapir_no_cookies: engine, + t_tapir_periodic_tasks_log: engine, + t_tapir_periodic_tasks_log: engine, + t_tapir_permanent_tokens_used: engine, + t_tapir_save_post_variables: engine + }) + return engine, latexml_engine + class MemberInstitution(Base): __tablename__ = 'Subscription_UniversalInstitution' @@ -97,7 +151,7 @@ class AdminLog(Base): id: Mapped[intpk] logtime: Mapped[Optional[str]] = mapped_column(String(24)) - created: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=FetchedValue()) + created: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=text('CURRENT_TIMESTAMP'), server_onupdate=text('CURRENT_TIMESTAMP')) paper_id: Mapped[Optional[str]] = mapped_column(String(20), index=True) username: Mapped[Optional[str]] = mapped_column(String(20), index=True) host: Mapped[Optional[str]] = mapped_column(String(64)) @@ -294,17 +348,18 @@ class Category(Base): archive: Mapped[str] = mapped_column(ForeignKey('arXiv_archives.archive_id'), primary_key=True, nullable=False, server_default=FetchedValue()) subject_class: Mapped[str] = mapped_column(String(16), primary_key=True, nullable=False, server_default=FetchedValue()) - definitive: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue()) - active: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue()) + definitive: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'0'")) + active: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'0'")) category_name: Mapped[Optional[str]] - endorse_all: Mapped[Literal['y', 'n', 'd']] = mapped_column(Enum('y', 'n', 'd'), nullable=False, server_default=FetchedValue()) - endorse_email: Mapped[Literal['y', 'n', 'd']] = mapped_column(Enum('y', 'n', 'd'), nullable=False, server_default=FetchedValue()) - papers_to_endorse: Mapped[int] = mapped_column(SmallInteger, nullable=False, server_default=FetchedValue()) + endorse_all: Mapped[Literal['y', 'n', 'd']] = mapped_column(Enum('y', 'n', 'd'), nullable=False, server_default=text("'d'")) + endorse_email: Mapped[Literal['y', 'n', 'd']] = mapped_column(Enum('y', 'n', 'd'), nullable=False, server_default=text("'d'")) + papers_to_endorse: Mapped[int] = mapped_column(SmallInteger, nullable=False, server_default=text("'0'")) endorsement_domain: Mapped[Optional[str]] = mapped_column(ForeignKey('arXiv_endorsement_domains.endorsement_domain'), index=True) arXiv_archive = relationship('Archive', primaryjoin='Category.archive == Archive.archive_id', backref='arXiv_categories') + arXiv_endorsements = relationship('Endorsement', back_populates='arXiv_categories') arXiv_endorsement_domain = relationship('EndorsementDomain', primaryjoin='Category.endorsement_domain == EndorsementDomain.endorsement_domain', backref='arXiv_categories') - + arXiv_endorsement_requests = relationship('EndorsementRequest', back_populates='arXiv_categories') class QuestionableCategory(Category): __tablename__ = 'arXiv_questionable_categories' @@ -342,8 +397,8 @@ class ControlHold(Base): placed_by: Mapped[Optional[int]] = mapped_column(ForeignKey('tapir_users.user_id'), index=True) last_changed_by: Mapped[Optional[int]] = mapped_column(ForeignKey('tapir_users.user_id'), index=True) - tapir_user = relationship('TapirUser', primaryjoin='ControlHold.last_changed_by == TapirUser.user_id', backref='tapiruser_arXiv_control_holds') - tapir_user1 = relationship('TapirUser', primaryjoin='ControlHold.placed_by == TapirUser.user_id', backref='tapiruser_arXiv_control_holds_0') + tapir_users = relationship('TapirUser', foreign_keys=[last_changed_by], back_populates='arXiv_control_holds') + tapir_users_ = relationship('TapirUser', foreign_keys=[placed_by], back_populates='arXiv_control_holds_') @@ -369,8 +424,8 @@ class CrossControl(Base): publish_date: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue()) arXiv_category = relationship('Category', primaryjoin='and_(CrossControl.archive == Category.archive, CrossControl.subject_class == Category.subject_class)', backref='arXiv_cross_controls') - document = relationship('Document', primaryjoin='CrossControl.document_id == Document.document_id', backref='arXiv_cross_controls') - user = relationship('TapirUser', primaryjoin='CrossControl.user_id == TapirUser.user_id', backref='arXiv_cross_controls') + document = relationship('Document', primaryjoin='CrossControl.document_id == Document.document_id', back_populates='arXiv_cross_controls') + user = relationship('TapirUser', primaryjoin='CrossControl.user_id == TapirUser.user_id', back_populates='arXiv_cross_controls') @@ -432,12 +487,13 @@ class Document(Base): authors: Mapped[Optional[str]] = mapped_column(Text) submitter_email: Mapped[str] = mapped_column(String(64), nullable=False, index=True, server_default=FetchedValue()) submitter_id: Mapped[Optional[int]] = mapped_column(ForeignKey('tapir_users.user_id'), index=True) - dated: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue()) + dated: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'")) primary_subject_class: Mapped[Optional[str]] = mapped_column(String(16)) created: Mapped[Optional[datetime]] - submitter = relationship('TapirUser', primaryjoin='Document.submitter_id == TapirUser.user_id', backref='arXiv_documents') - + owners = relationship('PaperOwner', back_populates='document') + submitter = relationship('TapirUser', primaryjoin='Document.submitter_id == TapirUser.user_id', back_populates='arXiv_documents') + arXiv_cross_controls = relationship('CrossControl', back_populates='document') class DBLP(Document): __tablename__ = 'arXiv_dblp' @@ -490,8 +546,10 @@ class EndorsementRequest(Base): issued_when: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue()) point_value: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue()) - arXiv_category = relationship('Category', primaryjoin='and_(EndorsementRequest.archive == Category.archive, EndorsementRequest.subject_class == Category.subject_class)', backref='arXiv_endorsement_requests') - endorsee = relationship('TapirUser', primaryjoin='EndorsementRequest.endorsee_id == TapirUser.user_id', backref='arXiv_endorsement_requests') + arXiv_categories = relationship('Category', primaryjoin='and_(EndorsementRequest.archive == Category.archive, EndorsementRequest.subject_class == Category.subject_class)', back_populates='arXiv_endorsement_requests') + endorsee = relationship('TapirUser', primaryjoin='EndorsementRequest.endorsee_id == TapirUser.user_id', back_populates='arXiv_endorsement_requests', uselist=False) + endorsement = relationship('Endorsement', back_populates='request', uselist=False) + audit = relationship('EndorsementRequestsAudit', uselist=False) class EndorsementRequestsAudit(EndorsementRequest): @@ -524,10 +582,10 @@ class Endorsement(Base): issued_when: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue()) request_id: Mapped[Optional[int]] = mapped_column(ForeignKey('arXiv_endorsement_requests.request_id'), index=True) - arXiv_category = relationship('Category', primaryjoin='and_(Endorsement.archive == Category.archive, Endorsement.subject_class == Category.subject_class)', backref='arXiv_endorsements') - endorsee = relationship('TapirUser', primaryjoin='Endorsement.endorsee_id == TapirUser.user_id', backref='tapiruser_arXiv_endorsements') - endorser = relationship('TapirUser', primaryjoin='Endorsement.endorser_id == TapirUser.user_id', backref='tapiruser_arXiv_endorsements_0') - request = relationship('EndorsementRequest', primaryjoin='Endorsement.request_id == EndorsementRequest.request_id', backref='arXiv_endorsements') + arXiv_categories = relationship('Category', primaryjoin='and_(Endorsement.archive == Category.archive, Endorsement.subject_class == Category.subject_class)', back_populates='arXiv_endorsements') + endorsee = relationship('TapirUser', primaryjoin='Endorsement.endorsee_id == TapirUser.user_id', back_populates='endorsee_of') + endorser = relationship('TapirUser', primaryjoin='Endorsement.endorser_id == TapirUser.user_id', back_populates='endorses') + request = relationship('EndorsementRequest', primaryjoin='Endorsement.request_id == EndorsementRequest.request_id', back_populates='endorsement') class EndorsementsAudit(Endorsement): @@ -598,7 +656,7 @@ class JrefControl(Base): publish_date: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue()) document = relationship('Document', primaryjoin='JrefControl.document_id == Document.document_id', backref='arXiv_jref_controls') - user = relationship('TapirUser', primaryjoin='JrefControl.user_id == TapirUser.user_id', backref='arXiv_jref_controls') + user = relationship('TapirUser', primaryjoin='JrefControl.user_id == TapirUser.user_id', back_populates='arXiv_jref_controls') @@ -658,7 +716,7 @@ class Metadata(Base): document = relationship('Document', primaryjoin='Metadata.document_id == Document.document_id', backref='arXiv_metadata') arXiv_license = relationship('License', primaryjoin='Metadata.license == License.name', backref='arXiv_metadata') - submitter = relationship('TapirUser', primaryjoin='Metadata.submitter_id == TapirUser.user_id', backref='arXiv_metadata') + submitter = relationship('TapirUser', primaryjoin='Metadata.submitter_id == TapirUser.user_id', back_populates='arXiv_metadata') @@ -688,7 +746,7 @@ class ModeratorApiKey(Base): issued_to: Mapped[str] = mapped_column(String(16), nullable=False, server_default=FetchedValue()) remote_host: Mapped[str] = mapped_column(String(255), nullable=False, server_default=FetchedValue()) - user = relationship('TapirUser', primaryjoin='ModeratorApiKey.user_id == TapirUser.user_id', backref='arXiv_moderator_api_keys') + user = relationship('TapirUser', primaryjoin='ModeratorApiKey.user_id == TapirUser.user_id', back_populates='arXiv_moderator_api_keys') @@ -762,6 +820,12 @@ class OrcidConfig(Base): value: Mapped[Optional[str]] = mapped_column(String(150)) +t_arXiv_ownership_requests_papers = Table( + 'arXiv_ownership_requests_papers', metadata, + Column('request_id', ForeignKey('arXiv_ownership_requests.request_id'), nullable=False, server_default=FetchedValue()), + Column('document_id', ForeignKey('arXiv_documents.document_id'), nullable=False, index=True, server_default=FetchedValue()), + Index('request_id', 'request_id', 'document_id', unique=True) +) class OwnershipRequest(Base): __tablename__ = 'arXiv_ownership_requests' @@ -771,11 +835,12 @@ class OwnershipRequest(Base): endorsement_request_id: Mapped[Optional[int]] = mapped_column(ForeignKey('arXiv_endorsement_requests.request_id'), index=True) workflow_status: Mapped[Literal['pending', 'accepted', 'rejected']] = mapped_column(Enum('pending', 'accepted', 'rejected'), nullable=False, server_default=FetchedValue()) + request_audit = relationship('OwnershipRequestsAudit', back_populates='ownership_request', uselist=False) endorsement_request = relationship('EndorsementRequest', primaryjoin='OwnershipRequest.endorsement_request_id == EndorsementRequest.request_id', backref='arXiv_ownership_requests') - user = relationship('TapirUser', primaryjoin='OwnershipRequest.user_id == TapirUser.user_id', backref='arXiv_ownership_requests') + user = relationship('TapirUser', primaryjoin='OwnershipRequest.user_id == TapirUser.user_id', back_populates='arXiv_ownership_requests') + documents = relationship("Document", secondary=t_arXiv_ownership_requests_papers) - -class OwnershipRequestsAudit(OwnershipRequest): +class OwnershipRequestsAudit(Base): __tablename__ = 'arXiv_ownership_requests_audit' request_id: Mapped[int] = mapped_column(ForeignKey('arXiv_ownership_requests.request_id'), primary_key=True, server_default=FetchedValue()) @@ -785,31 +850,33 @@ class OwnershipRequestsAudit(OwnershipRequest): tracking_cookie: Mapped[str] = mapped_column(String(255), nullable=False, server_default=FetchedValue()) date: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue()) + ownership_request = relationship('OwnershipRequest', primaryjoin='OwnershipRequestsAudit.request_id == OwnershipRequest.request_id', back_populates='request_audit', uselist=False) -t_arXiv_ownership_requests_papers = Table( - 'arXiv_ownership_requests_papers', metadata, - Column('request_id', Integer, nullable=False, server_default=FetchedValue()), - Column('document_id', Integer, nullable=False, index=True, server_default=FetchedValue()), - Index('request_id', 'request_id', 'document_id') -) - +class PaperOwner(Base): + __tablename__ = 'arXiv_paper_owners' + __table_args__ = ( + ForeignKeyConstraint(['added_by'], ['tapir_users.user_id'], name='0_595'), + ForeignKeyConstraint(['document_id'], ['arXiv_documents.document_id'], name='0_593'), + ForeignKeyConstraint(['user_id'], ['tapir_users.user_id'], name='0_594'), + Index('added_by', 'added_by'), + # Index('document_id', 'document_id', 'user_id', unique=True), + # Index('user_id', 'user_id'), + ) + document_id: Mapped[int] = mapped_column(ForeignKey('arXiv_documents.document_id'), primary_key=True) + user_id: Mapped[int] = mapped_column(ForeignKey('tapir_users.user_id'), primary_key=True) + date: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'0'")) + added_by: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'0'")) + remote_addr: Mapped[str] = mapped_column(String(16), nullable=False, server_default=text("''")) + remote_host: Mapped[str] = mapped_column(String(255), nullable=False, server_default=text("''")) + tracking_cookie: Mapped[str] = mapped_column(String(32), nullable=False, server_default=text("''")) + valid: Mapped[int] = mapped_column(SmallInteger, nullable=False, server_default=text("'0'")) + flag_author: Mapped[int] = mapped_column(SmallInteger, nullable=False, server_default=text("'0'")) + flag_auto: Mapped[int] = mapped_column(SmallInteger, nullable=False, server_default=text("'1'")) -t_arXiv_paper_owners = Table( - 'arXiv_paper_owners', metadata, - Column('document_id', ForeignKey('arXiv_documents.document_id'), nullable=False, server_default=FetchedValue()), - Column('user_id', ForeignKey('tapir_users.user_id'), nullable=False, index=True, server_default=FetchedValue()), - Column('date', Integer, nullable=False, server_default=FetchedValue()), - Column('added_by', ForeignKey('tapir_users.user_id'), nullable=False, index=True, server_default=FetchedValue()), - Column('remote_addr', String(16), nullable=False, server_default=FetchedValue()), - Column('remote_host', String(255), nullable=False, server_default=FetchedValue()), - Column('tracking_cookie', String(32), nullable=False, server_default=FetchedValue()), - Column('valid', Integer, nullable=False, server_default=FetchedValue()), - Column('flag_author', Integer, nullable=False, server_default=FetchedValue()), - Column('flag_auto', Integer, nullable=False, server_default=FetchedValue()), - Index('owners_document_id', 'document_id', 'user_id') -) + document = relationship('Document', back_populates='owners') + owner = relationship('TapirUser', foreign_keys="[PaperOwner.user_id]", back_populates='owned_papers') class PaperSession(Base): @@ -883,7 +950,7 @@ class ShowEmailRequest(Base): request_id: Mapped[intpk] document = relationship('Document', primaryjoin='ShowEmailRequest.document_id == Document.document_id', backref='arXiv_show_email_requests') - user = relationship('TapirUser', primaryjoin='ShowEmailRequest.user_id == TapirUser.user_id', backref='arXiv_show_email_requests') + user = relationship('TapirUser', primaryjoin='ShowEmailRequest.user_id == TapirUser.user_id', back_populates='arXiv_show_email_requests') @@ -963,7 +1030,7 @@ class SubmissionCategoryProposal(Base): proposal_comment = relationship('AdminLog', primaryjoin='SubmissionCategoryProposal.proposal_comment_id == AdminLog.id', backref='arxivadminlog_arXiv_submission_category_proposals') response_comment = relationship('AdminLog', primaryjoin='SubmissionCategoryProposal.response_comment_id == AdminLog.id', backref='arxivadminlog_arXiv_submission_category_proposals_0') submission = relationship('Submission', primaryjoin='SubmissionCategoryProposal.submission_id == Submission.submission_id', backref='arXiv_submission_category_proposals') - user = relationship('TapirUser', primaryjoin='SubmissionCategoryProposal.user_id == TapirUser.user_id', backref='arXiv_submission_category_proposals') + user = relationship('TapirUser', primaryjoin='SubmissionCategoryProposal.user_id == TapirUser.user_id', back_populates='arXiv_submission_category_proposal') @@ -985,7 +1052,7 @@ class SubmissionControl(Base): publish_date: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue()) document = relationship('Document', primaryjoin='SubmissionControl.document_id == Document.document_id', backref='arXiv_submission_controls') - user = relationship('TapirUser', primaryjoin='SubmissionControl.user_id == TapirUser.user_id', backref='arXiv_submission_controls') + user = relationship('TapirUser', primaryjoin='SubmissionControl.user_id == TapirUser.user_id', back_populates='arXiv_submission_control') @@ -1002,7 +1069,7 @@ class SubmissionFlag(Base): updated: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=FetchedValue()) submission = relationship('Submission', primaryjoin='SubmissionFlag.submission_id == Submission.submission_id', backref='arXiv_submission_flags') - user = relationship('TapirUser', primaryjoin='SubmissionFlag.user_id == TapirUser.user_id', backref='arXiv_submission_flags') + user = relationship('TapirUser', primaryjoin='SubmissionFlag.user_id == TapirUser.user_id', back_populates='arXiv_submission_flag') @@ -1018,7 +1085,7 @@ class SubmissionHoldReason(Base): comment = relationship('AdminLog', primaryjoin='SubmissionHoldReason.comment_id == AdminLog.id', backref='arXiv_submission_hold_reasons') submission = relationship('Submission', primaryjoin='SubmissionHoldReason.submission_id == Submission.submission_id', backref='arXiv_submission_hold_reasons') - user = relationship('TapirUser', primaryjoin='SubmissionHoldReason.user_id == TapirUser.user_id', backref='arXiv_submission_hold_reasons') + user = relationship('TapirUser', primaryjoin='SubmissionHoldReason.user_id == TapirUser.user_id', back_populates='arXiv_submission_hold_reason') @@ -1061,7 +1128,7 @@ class SubmissionViewFlag(Base): updated: Mapped[Optional[datetime]] submission = relationship('Submission', primaryjoin='SubmissionViewFlag.submission_id == Submission.submission_id', backref='arXiv_submission_view_flags') - user = relationship('TapirUser', primaryjoin='SubmissionViewFlag.user_id == TapirUser.user_id', backref='arXiv_submission_view_flags') + user = relationship('TapirUser', primaryjoin='SubmissionViewFlag.user_id == TapirUser.user_id', back_populates='arXiv_submission_view_flag') @@ -1073,7 +1140,7 @@ class Submission(Base): doc_paper_id: Mapped[Optional[str]] = mapped_column(String(20), index=True) sword_id: Mapped[Optional[int]] = mapped_column(ForeignKey('arXiv_tracking.sword_id'), index=True) userinfo: Mapped[Optional[int]] = mapped_column(Integer, server_default=FetchedValue()) - is_author: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue()) + is_author: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'0'")) agree_policy: Mapped[Optional[int]] = mapped_column(Integer, server_default=FetchedValue()) viewed: Mapped[Optional[int]] = mapped_column(Integer, server_default=FetchedValue()) stage: Mapped[Optional[int]] = mapped_column(Integer, server_default=FetchedValue()) @@ -1082,7 +1149,7 @@ class Submission(Base): submitter_email: Mapped[Optional[str]] = mapped_column(String(64)) created: Mapped[Optional[datetime]] updated: Mapped[Optional[datetime]] - status: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue()) + status: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'")) sticky_status: Mapped[Optional[int]] = mapped_column(Integer) must_process: Mapped[Optional[int]] = mapped_column(Integer, server_default=FetchedValue()) submit_time: Mapped[Optional[datetime]] @@ -1091,7 +1158,7 @@ class Submission(Base): source_format: Mapped[Optional[str]] = mapped_column(String(12)) source_flags: Mapped[Optional[str]] = mapped_column(String(12)) has_pilot_data: Mapped[Optional[int]] = mapped_column(Integer) - is_withdrawn: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue()) + is_withdrawn: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'0'")) title: Mapped[Optional[str]] = mapped_column(Text) authors: Mapped[Optional[str]] = mapped_column(Text) comments: Mapped[Optional[str]] = mapped_column(Text) @@ -1103,7 +1170,7 @@ class Submission(Base): doi: Mapped[Optional[str]] abstract: Mapped[Optional[str]] = mapped_column(Text) license: Mapped[Optional[str]] = mapped_column(ForeignKey('arXiv_licenses.name', onupdate='CASCADE'), index=True) - version: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue()) + version: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'1'")) type: Mapped[Optional[str]] = mapped_column(String(8), index=True) is_ok: Mapped[Optional[int]] = mapped_column(Integer, index=True) admin_ok: Mapped[Optional[int]] = mapped_column(Integer) @@ -1114,13 +1181,13 @@ class Submission(Base): package: Mapped[str] = mapped_column(String(255), nullable=False, server_default=FetchedValue()) rt_ticket_id: Mapped[Optional[int]] = mapped_column(Integer, index=True) auto_hold: Mapped[Optional[int]] = mapped_column(Integer, server_default=FetchedValue()) - is_locked: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue()) + is_locked: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'")) agreement_id = mapped_column(ForeignKey('arXiv_submission_agreements.agreement_id'), index=True) agreement = relationship('SubmissionAgreement', primaryjoin='Submission.agreement_id == SubmissionAgreement.agreement_id', backref='arXiv_submissions') document = relationship('Document', primaryjoin='Submission.document_id == Document.document_id', backref='arXiv_submissions') arXiv_license = relationship('License', primaryjoin='Submission.license == License.name', backref='arXiv_submissions') - submitter = relationship('TapirUser', primaryjoin='Submission.submitter_id == TapirUser.user_id', backref='arXiv_submissions') + submitter = relationship('TapirUser', primaryjoin='Submission.submitter_id == TapirUser.user_id', back_populates='arXiv_submissions') sword = relationship('Tracking', primaryjoin='Submission.sword_id == Tracking.sword_id', backref='arXiv_submissions') @@ -1137,17 +1204,15 @@ class PilotDataset(Submission): last_checked: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=FetchedValue()) -class SubmissionAbsClassifierDatum(Submission): +class SubmissionAbsClassifierDatum(Base): __tablename__ = 'arXiv_submission_abs_classifier_data' submission_id: Mapped[int] = mapped_column(ForeignKey('arXiv_submissions.submission_id', ondelete='CASCADE'), primary_key=True, server_default=FetchedValue()) json: Mapped[Optional[str]] = mapped_column(Text) last_update: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=FetchedValue()) - # status: Mapped[Optional[Literal['processing', 'success', 'failed', 'no connection']]] = mapped_column(Enum('processing', 'success', 'failed', 'no connection')) - # ^This column is inherited + status: Mapped[Optional[Literal['processing', 'success', 'failed', 'no connection']]] = mapped_column(Enum('processing', 'success', 'failed', 'no connection')) message: Mapped[Optional[str]] = mapped_column(Text) - # is_oversize: Mapped[Optional[int]] = mapped_column(Integer) - # ^This column is inherited + is_oversize: Mapped[Optional[int]] = mapped_column(Integer) suggested_primary: Mapped[Optional[str]] = mapped_column(Text) suggested_reason: Mapped[Optional[str]] = mapped_column(Text) autoproposal_primary: Mapped[Optional[str]] = mapped_column(Text) @@ -1156,17 +1221,15 @@ class SubmissionAbsClassifierDatum(Submission): classifier_model_version: Mapped[Optional[str]] = mapped_column(Text) -class SubmissionClassifierDatum(Submission): +class SubmissionClassifierDatum(Base): __tablename__ = 'arXiv_submission_classifier_data' submission_id: Mapped[int] = mapped_column(ForeignKey('arXiv_submissions.submission_id', ondelete='CASCADE'), primary_key=True, server_default=FetchedValue()) json: Mapped[Optional[str]] = mapped_column(Text) last_update: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=FetchedValue()) - # status: Mapped[Optional[Literal['processing', 'success', 'failed', 'no connection']]] = mapped_column(Enum('processing', 'success', 'failed', 'no connection')) - # ^This column is inherited + status: Mapped[Optional[Literal['processing', 'success', 'failed', 'no connection']]] = mapped_column(Enum('processing', 'success', 'failed', 'no connection')) message: Mapped[Optional[str]] = mapped_column(Text) - # is_oversize: Mapped[Optional[int]] = mapped_column(Integer) - # ^This column is inherited + is_oversize: Mapped[Optional[int]] = mapped_column(Integer) class SubmitterFlag(Base): @@ -1412,7 +1475,7 @@ class TapirAddress(Base): share_addr: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue()) tapir_country = relationship('TapirCountry', primaryjoin='TapirAddress.country == TapirCountry.digraph', backref='tapir_address') - user = relationship('TapirUser', primaryjoin='TapirAddress.user_id == TapirUser.user_id', backref='tapir_address') + user = relationship('TapirUser', primaryjoin='TapirAddress.user_id == TapirUser.user_id', back_populates='tapir_address') @@ -1431,8 +1494,8 @@ class TapirAdminAudit(Base): comment: Mapped[str] = mapped_column(Text, nullable=False) entry_id: Mapped[intpk] - tapir_user = relationship('TapirUser', primaryjoin='TapirAdminAudit.admin_user == TapirUser.user_id', backref='tapiruser_tapir_admin_audits') - tapir_user1 = relationship('TapirUser', primaryjoin='TapirAdminAudit.affected_user == TapirUser.user_id', backref='tapiruser_tapir_admin_audits_0') + tapir_users = relationship('TapirUser', foreign_keys=[admin_user], back_populates='tapir_admin_audit') + tapir_users_ = relationship('TapirUser', foreign_keys=[affected_user], back_populates='tapir_admin_audit_') session = relationship('TapirSession', primaryjoin='TapirAdminAudit.session_id == TapirSession.session_id', backref='tapir_admin_audits') @@ -1462,7 +1525,7 @@ class TapirEmailChangeToken(Base): consumed_when: Mapped[Optional[int]] = mapped_column(Integer) consumed_from: Mapped[Optional[str]] = mapped_column(String(16)) - user = relationship('TapirUser', primaryjoin='TapirEmailChangeToken.user_id == TapirUser.user_id', backref='tapir_email_change_tokens') + user = relationship('TapirUser', primaryjoin='TapirEmailChangeToken.user_id == TapirUser.user_id', back_populates='tapir_email_change_tokens') t_tapir_email_change_tokens_used = Table( @@ -1514,8 +1577,8 @@ class TapirEmailMailing(Base): mailing_name: Mapped[Optional[str]] comment: Mapped[Optional[str]] = mapped_column(Text) - tapir_user = relationship('TapirUser', primaryjoin='TapirEmailMailing.created_by == TapirUser.user_id', backref='tapiruser_tapir_email_mailings') - tapir_user1 = relationship('TapirUser', primaryjoin='TapirEmailMailing.sent_by == TapirUser.user_id', backref='tapiruser_tapir_email_mailings_0') + tapir_users = relationship('TapirUser', foreign_keys=[created_by], back_populates='tapir_email_mailings') + tapir_users_ = relationship('TapirUser', foreign_keys=[sent_by], back_populates='tapir_email_mailings_') template = relationship('TapirEmailTemplate', primaryjoin='TapirEmailMailing.template_id == TapirEmailTemplate.template_id', backref='tapir_email_mailings') @@ -1538,8 +1601,8 @@ class TapirEmailTemplate(Base): workflow_status: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue()) flag_system: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue()) - tapir_user = relationship('TapirUser', primaryjoin='TapirEmailTemplate.created_by == TapirUser.user_id', backref='tapiruser_tapir_email_templates') - tapir_user1 = relationship('TapirUser', primaryjoin='TapirEmailTemplate.updated_by == TapirUser.user_id', backref='tapiruser_tapir_email_templates_0') + tapir_users = relationship('TapirUser', foreign_keys=[created_by], back_populates='tapir_email_templates') + tapir_users_ = relationship('TapirUser', foreign_keys=[updated_by], back_populates='tapir_email_templates_') @@ -1555,7 +1618,7 @@ class TapirEmailToken(Base): tracking_cookie: Mapped[str] = mapped_column(String(255), nullable=False, server_default=FetchedValue()) wants_perm_token: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue()) - user = relationship('TapirUser', primaryjoin='TapirEmailToken.user_id == TapirUser.user_id', backref='tapir_email_tokens') + user = relationship('TapirUser', primaryjoin='TapirEmailToken.user_id == TapirUser.user_id', back_populates='tapir_email_tokens') @@ -1603,13 +1666,13 @@ class TapirNickname(Base): nick_id: Mapped[intpk] nickname: Mapped[str] = mapped_column(String(20), nullable=False, unique=True, server_default=FetchedValue()) user_id: Mapped[int] = mapped_column(ForeignKey('tapir_users.user_id'), nullable=False, server_default=FetchedValue()) - user_seq: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue()) - flag_valid: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue()) - role: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue()) - policy: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue()) - flag_primary: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue()) + user_seq: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'0'")) + flag_valid: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'")) + role: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'")) + policy: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'")) + flag_primary: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'0'")) - user = relationship('TapirUser', primaryjoin='TapirNickname.user_id == TapirUser.user_id', backref='tapir_nicknames') + user = relationship('TapirUser', primaryjoin='TapirNickname.user_id == TapirUser.user_id', back_populates='tapir_nicknames') @@ -1656,7 +1719,7 @@ class TapirPermanentToken(Base): session_id: Mapped[int] = mapped_column(ForeignKey('tapir_sessions.session_id'), nullable=False, index=True, server_default=FetchedValue()) session = relationship('TapirSession', primaryjoin='TapirPermanentToken.session_id == TapirSession.session_id', backref='tapir_permanent_tokens') - user = relationship('TapirUser', primaryjoin='TapirPermanentToken.user_id == TapirUser.user_id', backref='tapir_permanent_tokens') + user = relationship('TapirUser', primaryjoin='TapirPermanentToken.user_id == TapirUser.user_id', back_populates='tapir_permanent_tokens') @@ -1680,13 +1743,22 @@ class TapirPhone(Base): phone_number: Mapped[Optional[str]] = mapped_column(String(32), index=True) share_phone: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue()) - user = relationship('TapirUser', primaryjoin='TapirPhone.user_id == TapirUser.user_id', backref='tapir_phones') + user = relationship('TapirUser', primaryjoin='TapirPhone.user_id == TapirUser.user_id', back_populates='tapir_phone') class TapirPolicyClass(Base): __tablename__ = 'tapir_policy_classes' + ADMIN = 1 + PUBLIC_USER = 2 + LEGACY_USER = 3 + POLICY_CLASSES = [ + {"name": "Administrator", "class_id": ADMIN, "description": "", "password_storage": 2, "recovery_policy": 3, "permanent_login": 1}, + {"name": "Public user", "class_id": PUBLIC_USER, "description": "", "password_storage": 2, "recovery_policy": 3, "permanent_login": 1}, + {"name": "Legacy user", "class_id": LEGACY_USER, "description": "", "password_storage": 2, "recovery_policy": 3, "permanent_login": 1}, + ] + class_id: Mapped[int] = mapped_column(SmallInteger, primary_key=True) name: Mapped[str] = mapped_column(String(64), nullable=False, server_default=FetchedValue()) description: Mapped[str] = mapped_column(Text, nullable=False) @@ -1694,7 +1766,7 @@ class TapirPolicyClass(Base): recovery_policy: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue()) permanent_login: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue()) - + tapir_users = relationship('TapirUser', back_populates='tapir_policy_classes') class TapirPresession(Base): __tablename__ = 'tapir_presessions' @@ -1719,7 +1791,7 @@ class TapirRecoveryToken(Base): remote_host: Mapped[str] = mapped_column(String(255), nullable=False, server_default=FetchedValue()) tracking_cookie: Mapped[str] = mapped_column(String(255), nullable=False, server_default=FetchedValue()) - user = relationship('TapirUser', primaryjoin='TapirRecoveryToken.user_id == TapirUser.user_id', backref='tapir_recovery_tokens') + user = relationship('TapirUser', primaryjoin='TapirRecoveryToken.user_id == TapirUser.user_id', back_populates='tapir_recovery_tokens') @@ -1734,7 +1806,7 @@ class TapirRecoveryTokensUsed(Base): session_id: Mapped[Optional[int]] = mapped_column(ForeignKey('tapir_sessions.session_id'), index=True) session = relationship('TapirSession', primaryjoin='TapirRecoveryTokensUsed.session_id == TapirSession.session_id', backref='tapir_recovery_tokens_useds') - user = relationship('TapirUser', primaryjoin='TapirRecoveryTokensUsed.user_id == TapirUser.user_id', backref='tapir_recovery_tokens_useds') + user = relationship('TapirUser', primaryjoin='TapirRecoveryTokensUsed.user_id == TapirUser.user_id', back_populates='tapir_recovery_tokens_used') @@ -1750,23 +1822,25 @@ class TapirRecoveryTokensUsed(Base): class TapirSession(Base): __tablename__ = 'tapir_sessions' - session_id: Mapped[intpk] - user_id: Mapped[int] = mapped_column(ForeignKey('tapir_users.user_id'), nullable=False, index=True, server_default=FetchedValue()) - last_reissue: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue()) - start_time: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue()) - end_time: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue()) + session_id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + user_id: Mapped[int] = mapped_column(ForeignKey('tapir_users.user_id'), nullable=False, index=True, server_default=text("'0'")) + last_reissue: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'0'")) + start_time: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'")) + end_time: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'")) - user = relationship('TapirUser', primaryjoin='TapirSession.user_id == TapirUser.user_id', backref='tapir_sessions') + user = relationship('TapirUser', primaryjoin='TapirSession.user_id == TapirUser.user_id', back_populates='tapir_sessions') -class TapirSessionsAudit(TapirSession): +class TapirSessionsAudit(Base): __tablename__ = 'tapir_sessions_audit' - session_id: Mapped[int] = mapped_column(ForeignKey('tapir_sessions.session_id'), primary_key=True, server_default=FetchedValue()) + session_id: Mapped[int] = mapped_column(ForeignKey('tapir_sessions.session_id'), primary_key=True, server_default=text("'0'"), autoincrement="false") ip_addr: Mapped[str] = mapped_column(String(16), nullable=False, index=True, server_default=FetchedValue()) remote_host: Mapped[str] = mapped_column(String(255), nullable=False, server_default=FetchedValue()) tracking_cookie: Mapped[str] = mapped_column(String(255), nullable=False, index=True, server_default=FetchedValue()) + session = relationship('TapirSession') + class TapirStringVariable(Base): @@ -1789,37 +1863,92 @@ class TapirString(Base): class TapirUser(Base): __tablename__ = 'tapir_users' + __table_args__ = ( + ForeignKeyConstraint(['policy_class'], ['tapir_policy_classes.class_id'], name='0_510'), + Index('email', 'email', unique=True), + Index('first_name', 'first_name'), + Index('flag_approved', 'flag_approved'), + Index('flag_banned', 'flag_banned'), + Index('flag_can_lock', 'flag_can_lock'), + Index('flag_deleted', 'flag_deleted'), + Index('flag_edit_users', 'flag_edit_users'), + Index('flag_internal', 'flag_internal'), + Index('joined_date', 'joined_date'), + Index('joined_ip_num', 'joined_ip_num'), + Index('last_name', 'last_name'), + Index('policy_class', 'policy_class'), + Index('tracking_cookie', 'tracking_cookie') + ) user_id: Mapped[intpk] first_name: Mapped[Optional[str]] = mapped_column(String(50), index=True) last_name: Mapped[Optional[str]] = mapped_column(String(50), index=True) suffix_name: Mapped[Optional[str]] = mapped_column(String(50)) - share_first_name: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue()) - share_last_name: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue()) - email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True, server_default=FetchedValue()) - share_email: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue()) - email_bouncing: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue()) - policy_class: Mapped[int] = mapped_column(ForeignKey('tapir_policy_classes.class_id'), nullable=False, index=True, server_default=FetchedValue()) - joined_date: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue()) + share_first_name: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'1'")) + share_last_name: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'1'")) + email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True, server_default=text("''")) + share_email: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'8'")) + email_bouncing: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'0'")) + policy_class: Mapped[int] = mapped_column(ForeignKey('tapir_policy_classes.class_id'), nullable=False, index=True, server_default=text("'0'")) + joined_date: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'")) joined_ip_num: Mapped[Optional[str]] = mapped_column(String(16), index=True) - joined_remote_host: Mapped[str] = mapped_column(String(255), nullable=False, server_default=FetchedValue()) - flag_internal: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue()) - flag_edit_users: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue()) - flag_edit_system: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue()) - flag_email_verified: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue()) - flag_approved: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue()) - flag_deleted: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue()) - flag_banned: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue()) - flag_wants_email: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue()) - flag_html_email: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue()) - tracking_cookie: Mapped[str] = mapped_column(String(255), nullable=False, index=True, server_default=FetchedValue()) - flag_allow_tex_produced: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue()) - flag_can_lock: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue()) - - tapir_policy_class = relationship('TapirPolicyClass', primaryjoin='TapirUser.policy_class == TapirPolicyClass.class_id', backref='tapir_users') - - -class AuthorIds(TapirUser): + joined_remote_host: Mapped[str] = mapped_column(String(255), nullable=False, server_default=text("''")) + flag_internal: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'")) + flag_edit_users: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'")) + flag_edit_system: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'0'")) + flag_email_verified: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'0'")) + flag_approved: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'1'")) + flag_deleted: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'")) + flag_banned: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'")) + flag_wants_email: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'0'")) + flag_html_email: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'0'")) + tracking_cookie: Mapped[str] = mapped_column(String(255), nullable=False, index=True, server_default=text("''")) + flag_allow_tex_produced: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'0'")) + flag_can_lock: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'")) + + tapir_policy_classes = relationship('TapirPolicyClass', back_populates='tapir_users') + arXiv_control_holds = relationship('ControlHold', foreign_keys='[ControlHold.last_changed_by]', back_populates='tapir_users') + arXiv_control_holds_ = relationship('ControlHold', foreign_keys='[ControlHold.placed_by]', back_populates='tapir_users_') + arXiv_documents = relationship('Document', back_populates='submitter') + arXiv_moderator_api_keys = relationship('ModeratorApiKey', back_populates='user') + tapir_address = relationship('TapirAddress', back_populates='user') + tapir_email_change_tokens = relationship('TapirEmailChangeToken', back_populates='user') + tapir_email_templates = relationship('TapirEmailTemplate', foreign_keys='[TapirEmailTemplate.created_by]', back_populates='tapir_users') + tapir_email_templates_ = relationship('TapirEmailTemplate', foreign_keys='[TapirEmailTemplate.updated_by]', back_populates='tapir_users_') + tapir_email_tokens = relationship('TapirEmailToken', back_populates='user') + tapir_nicknames = relationship('TapirNickname', back_populates='user', uselist=False) + tapir_phone = relationship('TapirPhone', back_populates='user') + tapir_recovery_tokens = relationship('TapirRecoveryToken', back_populates='user') + tapir_sessions = relationship('TapirSession', back_populates='user') + arXiv_cross_controls = relationship('CrossControl', back_populates='user') + arXiv_endorsement_requests = relationship('EndorsementRequest', back_populates='endorsee') + arXiv_jref_controls = relationship('JrefControl', back_populates='user') + arXiv_metadata = relationship('Metadata', back_populates='submitter') + arXiv_show_email_requests = relationship('ShowEmailRequest', back_populates='user') + arXiv_submission_control = relationship('SubmissionControl', back_populates='user') + arXiv_submissions = relationship('Submission', back_populates='submitter') + tapir_admin_audit = relationship('TapirAdminAudit', foreign_keys='[TapirAdminAudit.admin_user]', back_populates='tapir_users') + tapir_admin_audit_ = relationship('TapirAdminAudit', foreign_keys='[TapirAdminAudit.affected_user]', back_populates='tapir_users_') + tapir_email_mailings = relationship('TapirEmailMailing', foreign_keys='[TapirEmailMailing.created_by]', back_populates='tapir_users') + tapir_email_mailings_ = relationship('TapirEmailMailing', foreign_keys='[TapirEmailMailing.sent_by]', back_populates='tapir_users_') + tapir_permanent_tokens = relationship('TapirPermanentToken', back_populates='user') + tapir_recovery_tokens_used = relationship('TapirRecoveryTokensUsed', back_populates='user') + + endorsee_of = relationship('Endorsement', foreign_keys='[Endorsement.endorsee_id]', back_populates='endorsee') + endorses = relationship('Endorsement', foreign_keys='[Endorsement.endorser_id]', back_populates='endorser') + + arXiv_ownership_requests = relationship('OwnershipRequest', back_populates='user') + arXiv_submission_category_proposal = relationship('SubmissionCategoryProposal', back_populates='user') + arXiv_submission_flag = relationship('SubmissionFlag', back_populates='user') + arXiv_submission_hold_reason = relationship('SubmissionHoldReason', back_populates='user') + arXiv_submission_view_flag = relationship('SubmissionViewFlag', back_populates='user') + + owned_papers = relationship("PaperOwner", foreign_keys="[PaperOwner.user_id]", back_populates="owner") + + demographics = relationship('Demographic', foreign_keys="[Demographic.user_id]", uselist=False, back_populates='user') + + +class AuthorIds(Base): __tablename__ = 'arXiv_author_ids' user_id: Mapped[int] = mapped_column(ForeignKey('tapir_users.user_id'), primary_key=True) @@ -1827,7 +1956,7 @@ class AuthorIds(TapirUser): updated: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=FetchedValue()) -class Demographic(TapirUser): +class Demographic(Base): __tablename__ = 'arXiv_demographics' __table_args__ = ( ForeignKeyConstraint(['archive', 'subject_class'], ['arXiv_categories.archive', 'arXiv_categories.subject_class']), @@ -1843,26 +1972,44 @@ class Demographic(TapirUser): subject_class: Mapped[Optional[str]] = mapped_column(String(16)) original_subject_classes: Mapped[str] = mapped_column(String(255), nullable=False, server_default=FetchedValue()) flag_group_physics: Mapped[Optional[int]] = mapped_column(Integer, index=True) - flag_group_math: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue()) - flag_group_cs: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue()) - flag_group_nlin: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue()) - flag_proxy: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue()) - flag_journal: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue()) - flag_xml: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue()) - dirty: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue()) - flag_group_test: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue()) - flag_suspect: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue()) - flag_group_q_bio: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue()) - flag_group_q_fin: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue()) - flag_group_stat: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue()) - flag_group_eess: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue()) - flag_group_econ: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue()) - veto_status: Mapped[Literal['ok', 'no-endorse', 'no-upload', 'no-replace']] = mapped_column(Enum('ok', 'no-endorse', 'no-upload', 'no-replace'), nullable=False, server_default=FetchedValue()) - + flag_group_math: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'")) + flag_group_cs: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'")) + flag_group_nlin: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'")) + flag_proxy: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'")) + flag_journal: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'")) + flag_xml: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'")) + dirty: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'0'")) + flag_group_test: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'0'")) + flag_suspect: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'")) + flag_group_q_bio: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'")) + flag_group_q_fin: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'")) + flag_group_stat: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'")) + flag_group_eess: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'")) + flag_group_econ: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'")) + veto_status: Mapped[Literal['ok', 'no-endorse', 'no-upload', 'no-replace']] = mapped_column(Enum('ok', 'no-endorse', 'no-upload', 'no-replace'), nullable=False, server_default=text("'ok'")) + + user = relationship('TapirUser', back_populates='demographics') arXiv_category = relationship('Category', primaryjoin='and_(Demographic.archive == Category.archive, Demographic.subject_class == Category.subject_class)', backref='arXiv_demographics') + GROUP_FLAGS = [ + ('grp_physics', 'flag_group_physics'), + ('grp_math', 'flag_group_math'), + ('grp_cs', 'flag_group_cs'), + ('grp_q-bio', 'flag_group_q_bio'), + ('grp_q-fin', 'flag_group_q_fin'), + ('grp_q-stat', 'flag_group_stat'), + ('grp_q-econ', 'flag_group_econ'), + ('grp_eess', 'flag_group_eess'), + ] -class OrcidIds(TapirUser): + @property + def groups(self) -> List[str]: + """Active groups for this user profile.""" + return [group for group, column in self.GROUP_FLAGS + if getattr(self, column) == 1] + + +class OrcidIds(Base): __tablename__ = 'arXiv_orcid_ids' user_id: Mapped[int] = mapped_column(ForeignKey('tapir_users.user_id'), primary_key=True) @@ -1871,7 +2018,7 @@ class OrcidIds(TapirUser): updated: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=FetchedValue()) -class QueueView(TapirUser): +class QueueView(Base): __tablename__ = 'arXiv_queue_view' user_id: Mapped[int] = mapped_column(ForeignKey('tapir_users.user_id', ondelete='CASCADE'), primary_key=True, server_default=FetchedValue()) @@ -1880,14 +2027,14 @@ class QueueView(TapirUser): total_views: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue()) -class SuspiciousName(TapirUser): +class SuspiciousName(Base): __tablename__ = 'arXiv_suspicious_names' user_id: Mapped[int] = mapped_column(ForeignKey('tapir_users.user_id'), primary_key=True, server_default=FetchedValue()) full_name: Mapped[str] = mapped_column(String(255), nullable=False, server_default=FetchedValue()) -class SwordLicense(TapirUser): +class SwordLicense(Base): __tablename__ = 'arXiv_sword_licenses' user_id: Mapped[int] = mapped_column(ForeignKey('tapir_users.user_id'), primary_key=True) @@ -1895,7 +2042,7 @@ class SwordLicense(TapirUser): updated: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=FetchedValue()) -class TapirDemographic(TapirUser): +class TapirDemographic(Base): __tablename__ = 'tapir_demographics' user_id: Mapped[int] = mapped_column(ForeignKey('tapir_users.user_id'), primary_key=True, server_default=FetchedValue()) @@ -1910,7 +2057,7 @@ class TapirDemographic(TapirUser): tapir_country = relationship('TapirCountry', primaryjoin='TapirDemographic.country == TapirCountry.digraph', backref='tapir_demographics') -class TapirUsersHot(TapirUser): +class TapirUsersHot(Base): __tablename__ = 'tapir_users_hot' user_id: Mapped[int] = mapped_column(ForeignKey('tapir_users.user_id'), primary_key=True, server_default=FetchedValue()) @@ -1919,13 +2066,16 @@ class TapirUsersHot(TapirUser): number_sessions: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue()) -class TapirUsersPassword(TapirUser): +class TapirUsersPassword(Base): __tablename__ = 'tapir_users_password' user_id: Mapped[int] = mapped_column(ForeignKey('tapir_users.user_id'), primary_key=True, server_default=FetchedValue()) password_storage: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue()) password_enc: Mapped[str] = mapped_column(String(50), nullable=False, server_default=FetchedValue()) + user = relationship('TapirUser') + + class DBLaTeXMLDocuments(LaTeXMLBase): __tablename__ = 'arXiv_latexml_doc' @@ -1970,31 +2120,3 @@ class DBLaTeXMLFeedback (LaTeXMLBase): selected_html: Mapped[Optional[str]] initiation_mode: Mapped[Optional[str]] -SessionLocal.configure(binds={ - Base: engine, - LaTeXMLBase: latexml_engine, - t_arXiv_stats_hourly: engine, - t_arXiv_admin_state: engine, - t_arXiv_bad_pw: engine, - t_arXiv_black_email: engine, - t_arXiv_block_email: engine, - t_arXiv_bogus_subject_class: engine, - t_arXiv_duplicates: engine, - t_arXiv_in_category: engine, - t_arXiv_moderators: engine, - t_arXiv_ownership_requests_papers: engine, - t_arXiv_refresh_list: engine, - t_arXiv_paper_owners: engine, - t_arXiv_updates_tmp: engine, - t_arXiv_white_email: engine, - t_arXiv_xml_notifications: engine, - t_demographics_backup: engine, - t_tapir_email_change_tokens_used: engine, - t_tapir_email_tokens_used: engine, - t_tapir_error_log: engine, - t_tapir_no_cookies: engine, - t_tapir_periodic_tasks_log: engine, - t_tapir_periodic_tasks_log: engine, - t_tapir_permanent_tokens_used: engine, - t_tapir_save_post_variables: engine -}) \ No newline at end of file diff --git a/arxiv/dev_environment/build.sh b/arxiv/dev_environment/build.sh deleted file mode 100755 index 6b0f3397..00000000 --- a/arxiv/dev_environment/build.sh +++ /dev/null @@ -1,3 +0,0 @@ -docker build -t classic_db_img -f Dockerfile.classic_db -terraform init -terraform plan \ No newline at end of file diff --git a/arxiv/dev_environment/classic_db/Dockerfile b/arxiv/dev_environment/classic_db/Dockerfile deleted file mode 100644 index 63b5d17a..00000000 --- a/arxiv/dev_environment/classic_db/Dockerfile +++ /dev/null @@ -1,17 +0,0 @@ -FROM mysql:8.0.36-debian - -######## Install Python 3.11 ######## - -RUN apt install software-properties-common -y -RUN add-apt-repository ppa:deadsnakes/ppa -RUN apt-get update -y -RUN apt install python3.11 python3-pip -y - -###### Install Cloud SQL Proxy ###### - -WORKDIR /cloudsql - -RUN apt-get install -y curl - -RUN curl -o cloud-sql-proxy https://storage.googleapis.com/cloud-sql-connectors/cloud-sql-proxy/v2.10.1/cloud-sql-proxy.linux.amd64 -RUN chmod +x cloud-sql-proxy diff --git a/arxiv/dev_environment/classic_db/wrapper.sh b/arxiv/dev_environment/classic_db/wrapper.sh deleted file mode 100644 index d6660106..00000000 --- a/arxiv/dev_environment/classic_db/wrapper.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/bin/sh -/cloudsql/cloud-sql-proxy arxiv-development:us-central1:latexml-db -u /cloudsql & -/cloudsql/cloud-sql-proxy arxiv-development:us-east4:arxiv-db-dev -u /cloudsql & - -gunicorn --bind 0.0.0.0:8000 -t 600 -w 12 --threads 2 entry_point:app diff --git a/arxiv/dev_environment/main.tf b/arxiv/dev_environment/main.tf deleted file mode 100644 index bf98a41e..00000000 --- a/arxiv/dev_environment/main.tf +++ /dev/null @@ -1,29 +0,0 @@ -terraform { - required_providers { - google = { - source = "hashicorp/google" - version = "4.51.0" - } - } -} - -provider "google" { - project = "arxiv-development" -} - -provider "docker" { - -} - -data "google_secret_manager_secret_version" "CLASSIC_DB_URI" { - secret = "browse-sqlalchemy-db-uri" - version = "latest" -} - -resource "docker_container" "classic_db" { - image = "classic_db_img" - name = "classic_db" - env = [ - "CLASSIC_DB_URI=${data.google_secret_manager_secret_version.CLASSIC_DB_URI.secret_data}" - ] -} \ No newline at end of file diff --git a/arxiv/taxonomy/__init__.py b/arxiv/taxonomy/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/poetry.lock b/poetry.lock index e827c138..5cafd073 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "alabaster" @@ -321,6 +321,17 @@ files = [ [package.extras] toml = ["tomli"] +[[package]] +name = "decorator" +version = "5.1.1" +description = "Decorators for Humans" +optional = false +python-versions = ">=3.5" +files = [ + {file = "decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186"}, + {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, +] + [[package]] name = "docutils" version = "0.21.2" @@ -332,6 +343,25 @@ files = [ {file = "docutils-0.21.2.tar.gz", hash = "sha256:3a6b18732edf182daa3cd12775bbb338cf5691468f91eeeb109deff6ebfa986f"}, ] +[[package]] +name = "fakeredis" +version = "2.7.1" +description = "Fake implementation of redis API for testing purposes." +optional = false +python-versions = ">=3.8.1,<4.0" +files = [ + {file = "fakeredis-2.7.1-py3-none-any.whl", hash = "sha256:e2f7a88dad23be1191ad6212008e170d75c2d63dde979c2694be8cbfd917428e"}, + {file = "fakeredis-2.7.1.tar.gz", hash = "sha256:854d758794dab9953be16b9a0f7fbd4bbd6b6964db7a9684e163291c1342ece6"}, +] + +[package.dependencies] +redis = "<4.5" +sortedcontainers = ">=2.4.0,<3.0.0" + +[package.extras] +json = ["jsonpath-ng (>=1.5,<2.0)"] +lua = ["lupa (>=1.14,<2.0)"] + [[package]] name = "fastly" version = "5.2.0" @@ -1029,6 +1059,21 @@ files = [ {file = "MarkupSafe-2.1.5.tar.gz", hash = "sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b"}, ] +[[package]] +name = "mimesis" +version = "17.0.0" +description = "Mimesis: Fake Data Generator." +optional = false +python-versions = "<4.0,>=3.10" +files = [ + {file = "mimesis-17.0.0-py3-none-any.whl", hash = "sha256:a088a14075c6d0356fea15e7687afb8f900a412daa80f2f3fe20f1771532402f"}, + {file = "mimesis-17.0.0.tar.gz", hash = "sha256:57fd7c2762c668054f2ebb85f71d484fa2fd55647d2a02bac4f2e97b81f22d8d"}, +] + +[package.extras] +factory = ["factory-boy (>=3.3.0,<4.0.0)"] +pytest = ["pytest (>=7.2,<8.0)"] + [[package]] name = "mypy" version = "1.10.0" @@ -1167,6 +1212,17 @@ files = [ {file = "protobuf-4.25.3.tar.gz", hash = "sha256:25b5d0b42fd000320bd7830b349e3b696435f3b329810427a6bcce6a5492cc5c"}, ] +[[package]] +name = "py" +version = "1.11.0" +description = "library with cross-python path, ini-parsing, io, code, log facilities" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +files = [ + {file = "py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378"}, + {file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"}, +] + [[package]] name = "pyasn1" version = "0.6.0" @@ -1275,6 +1331,23 @@ files = [ [package.extras] windows-terminal = ["colorama (>=0.4.6)"] +[[package]] +name = "pyjwt" +version = "2.8.0" +description = "JSON Web Token implementation in Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "PyJWT-2.8.0-py3-none-any.whl", hash = "sha256:59127c392cc44c2da5bb3192169a91f429924e17aff6534d70fdc02ab3e04320"}, + {file = "PyJWT-2.8.0.tar.gz", hash = "sha256:57e28d156e3d5c10088e0c68abb90bfac3df82b40a71bd0daa20c65ccd5c23de"}, +] + +[package.extras] +crypto = ["cryptography (>=3.4.0)"] +dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"] +docs = ["sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"] +tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] + [[package]] name = "pytest" version = "8.2.1" @@ -1313,6 +1386,23 @@ pytest = ">=4.6" [package.extras] testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"] +[[package]] +name = "pytest-mock" +version = "3.14.0" +description = "Thin-wrapper around the mock package for easier use with pytest" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-mock-3.14.0.tar.gz", hash = "sha256:2719255a1efeceadbc056d6bf3df3d1c5015530fb40cf347c0f9afac88410bd0"}, + {file = "pytest_mock-3.14.0-py3-none-any.whl", hash = "sha256:0b72c38033392a5f4621342fe11e9219ac11ec9d375f8e2a0c164539e0d70f6f"}, +] + +[package.dependencies] +pytest = ">=6.2.5" + +[package.extras] +dev = ["pre-commit", "pytest-asyncio", "tox"] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -1338,6 +1428,30 @@ files = [ {file = "pytz-2024.1.tar.gz", hash = "sha256:2a29735ea9c18baf14b448846bde5a48030ed267578472d8955cd0e7443a9812"}, ] +[[package]] +name = "redis" +version = "2.10.6" +description = "Python client for Redis key-value store" +optional = false +python-versions = "*" +files = [ + {file = "redis-2.10.6-py2.py3-none-any.whl", hash = "sha256:8a1900a9f2a0a44ecf6e8b5eb3e967a9909dfed219ad66df094f27f7d6f330fb"}, + {file = "redis-2.10.6.tar.gz", hash = "sha256:a22ca993cea2962dbb588f9f30d0015ac4afcc45bee27d3978c0dbe9e97c6c0f"}, +] + +[[package]] +name = "redis-py-cluster" +version = "1.3.6" +description = "Library for communicating with Redis Clusters. Built on top of redis-py lib" +optional = false +python-versions = "*" +files = [ + {file = "redis-py-cluster-1.3.6.tar.gz", hash = "sha256:7db54b1de60bd34da3806676b112f07fc9afae556d8260ac02c3335d574ee42c"}, +] + +[package.dependencies] +redis = "2.10.6" + [[package]] name = "requests" version = "2.32.0" @@ -1359,6 +1473,21 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "retry" +version = "0.9.2" +description = "Easy to use retry decorator." +optional = false +python-versions = "*" +files = [ + {file = "retry-0.9.2-py2.py3-none-any.whl", hash = "sha256:ccddf89761fa2c726ab29391837d4327f819ea14d244c232a1d24c67a2f98606"}, + {file = "retry-0.9.2.tar.gz", hash = "sha256:f8bfa8b99b69c4506d6f5bd3b0aabf77f98cdb17f3c9fc3f5ca820033336fba4"}, +] + +[package.dependencies] +decorator = ">=3.4.2" +py = ">=1.4.26,<2.0.0" + [[package]] name = "rsa" version = "4.9" @@ -1390,6 +1519,21 @@ botocore = ">=1.33.2,<2.0a.0" [package.extras] crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"] +[[package]] +name = "setuptools" +version = "70.1.0" +description = "Easily download, build, install, upgrade, and uninstall Python packages" +optional = false +python-versions = ">=3.8" +files = [ + {file = "setuptools-70.1.0-py3-none-any.whl", hash = "sha256:d9b8b771455a97c8a9f3ab3448ebe0b29b5e105f1228bba41028be116985a267"}, + {file = "setuptools-70.1.0.tar.gz", hash = "sha256:01a1e793faa5bd89abc851fa15d0a0db26f160890c7102cd8dce643e886b47f5"}, +] + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] +testing = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test", "mypy (==1.10.0)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.1)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (>=0.3.2)", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] + [[package]] name = "six" version = "1.16.0" @@ -1782,4 +1926,4 @@ sphinx = ["sphinx", "sphinx-autodoc-typehints", "sphinxcontrib-websupport"] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "7c3918d253eb10e738c8cf47af4235054785932dbeaa4c4447bc374ad10b5962" +content-hash = "c355759252de731ba6f40ffceb152166a7873941532338ee6ba6aa4eb66594cc" diff --git a/pyproject.toml b/pyproject.toml index 5fa8f35f..7fc0b850 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,13 +37,21 @@ validators = "*" sphinx = { version = "*", optional = true } sphinxcontrib-websupport = { version = "*", optional = true } sphinx-autodoc-typehints = { version = "*", optional = true } +mimesis = "*" +retry = "^0.9.2" +pyjwt = "*" +redis = "==2.10.6" +redis-py-cluster = "==1.3.6" +setuptools = "^70.0.0" [tool.poetry.dev-dependencies] pydocstyle = "*" mypy = "*" pytest = "*" +pytest-mock = "^3.8.2" pytest-cov = "*" hypothesis = "*" +fakeredis = "*" click = "*" diff --git a/test.db-journal b/test.db-journal new file mode 100644 index 00000000..526187cd Binary files /dev/null and b/test.db-journal differ