diff --git a/arxiv/auth/README.md b/arxiv/auth/README.md
new file mode 100644
index 00000000..56698f3b
--- /dev/null
+++ b/arxiv/auth/README.md
@@ -0,0 +1,57 @@
+# ``arxiv-auth`` Library
+
+This package provides a Flask add on and other code for working with arxiv
+authenticated users in arXiv services.
+
+Housing these components in arxiv-base ensures that users
+and sessions are represented and manipulated consistently. The login+logout, user
+accounts(TBD), API client registry(TBD), and authenticator(TBD) services all
+rely on this package.
+
+# Quick start
+For use-cases to check if a request is from an authenticated arxiv user, do the
+following:
+
+1. Add arxiv-base to your dependencies
+2. Install :class:`arxiv.auth.auth.Auth` onto your application. This adds a
+ function called for each request to Flask that adds an instance of
+ :class:`arxiv.auth.domain.Session` at ``flask.request.auth`` if the client is
+ authenticated.
+3. Add to the ``flask.config`` to setup :class:`arxiv_auth.auth.Auth` and
+ related classes
+
+Here's an example of how you might do #2 and #3:
+```
+ from flask import Flask
+ from arxiv.base import Base
+ from arxiv.auth.auth import auth
+
+ app = Flask(__name__)
+ Base(app)
+
+ # config settings required to use legacy auth
+ app.config['CLASSIC_SESSION_HASH'] = '{hash_private_secret}'
+ app.config['CLASSIC_DB_URI'] = '{your_sqlalchemy_db_uri_to_legacy_db'}
+ app.config['SESSION_DURATION'] = 36000
+ app.config['CLASSIC_COOKIE_NAME'] = 'tapir_session'
+
+ auth.Auth(app) # <- Install the Auth to get auth checks and request.auth
+
+ @app.route("/")
+ def are_you_logged_in():
+ if request.auth is not None:
+ return "
Hello, You are logged in.
"
+ else:
+ return "Hello unknown client.
"
+```
+
+# Middleware
+
+In during NG there was middleware for arxiv-auth that could be used in NGINX to
+do the authentication there. As of 2023 it is not in use.
+
+See :class:`arxiv.auth.auth.middleware.AuthMiddleware`
+
+If you are not deploying this application in the cloud behind NGINX (and
+therefore will not support sessions from the distributed store), you do not
+need the auth middleware.
diff --git a/arxiv/auth/__init__.py b/arxiv/auth/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/arxiv/auth/app.py b/arxiv/auth/app.py
new file mode 100644
index 00000000..6046b2e4
--- /dev/null
+++ b/arxiv/auth/app.py
@@ -0,0 +1,10 @@
+"""test app."""
+
+import sys
+sys.path.append('./arxiv')
+
+from flask import Flask
+from arxiv_users import auth, legacy
+
+app = Flask('test')
+legacy.create_all()
diff --git a/arxiv/auth/auth/__init__.py b/arxiv/auth/auth/__init__.py
new file mode 100644
index 00000000..8d91c78c
--- /dev/null
+++ b/arxiv/auth/auth/__init__.py
@@ -0,0 +1,199 @@
+"""Provides tools for working with authenticated user/client sessions."""
+
+from typing import Optional, Union, Any, List
+import os
+
+from werkzeug.datastructures.structures import MultiDict
+
+from flask import Flask, request, Response
+from retry import retry
+
+from ...db import transaction
+from ..legacy import util
+from ..legacy.cookies import parse_cookie
+from .. import domain, legacy
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class Auth(object):
+ """
+ Attaches session and authentication information to the request.
+
+ Set env var or `Flask.config` `ARXIV_AUTH_DEBUG` to True to get
+ additional debugging in the logs. Only use this for short term debugging of
+ configs. This may be used in produciton but should not be left on in production.
+
+ Intended for use in a Flask application factory, for example:
+
+ .. code-block:: python
+
+ from flask import Flask
+ from arxiv.users.auth import Auth
+ from someapp import routes
+
+
+ def create_web_app() -> Flask:
+ app = Flask('someapp')
+ app.config.from_pyfile('config.py')
+ Auth(app) # Registers the before_reques auth check
+
+ @app.route("/hello")
+ def hello():
+ if request.auth:
+ return f"Hello {request.auth.user.name}!"
+ else:
+ return f"Hello world! (not authenticated)"
+
+ return app
+
+
+ """
+
+ def __init__(self, app: Optional[Flask] = None) -> None:
+ """
+ Initialize ``app`` with `Auth`.
+
+ Parameters
+ ----------
+ app : :class:`Flask`
+
+ """
+ if app is not None:
+ self.init_app(app)
+ if self.app.config.get('AUTH_UPDATED_SESSION_REF'):
+ self.auth_session_name = "auth"
+ else:
+ self.auth_session_name = "session"
+
+ @retry(legacy.exceptions.Unavailable, tries=3, delay=0.5, backoff=2)
+ def _get_legacy_session(self,
+ cookie_value: str) -> Optional[domain.Session]:
+ """
+ Attempt to load a legacy auth session.
+
+ Returns
+ -------
+ :class:`domain.Session` or None
+
+ """
+ if cookie_value is None:
+ return None
+ try:
+ with transaction():
+ return legacy.sessions.load(cookie_value)
+ except legacy.exceptions.UnknownSession as e:
+ logger.debug('No legacy session available: %s', e)
+ except legacy.exceptions.InvalidCookie as e:
+ logger.debug('Invalid legacy cookie: %s', e)
+ except legacy.exceptions.SessionExpired as e:
+ logger.debug('Legacy session is expired: %s', e)
+ return None
+
+ def init_app(self, app: Flask) -> None:
+ """
+ Attach :meth:`.load_session` to the Flask app.
+
+ Parameters
+ ----------
+ app : :class:`Flask`
+
+ """
+ self.app = app
+ app.config['arxiv_auth.Auth'] = self
+
+ if app.config.get('ARXIV_AUTH_DEBUG') or os.getenv('ARXIV_AUTH_DEBUG'):
+ self.auth_debug()
+ logger.debug("ARXIV_AUTH_DEBUG is set and auth debug messages to logging are turned on")
+
+ self.app.before_request(self.load_session)
+ self.app.config.setdefault('DEFAULT_LOGOUT_REDIRECT_URL',
+ 'https://arxiv.org')
+ self.app.config.setdefault('DEFAULT_LOGIN_REDIRECT_URL',
+ 'https://arxiv.org')
+
+ if app.config.get('ARXIV_AUTH_DEBUG') or os.getenv('ARXIV_AUTH_DEBUG'):
+ self.auth_debug()
+ logger.debug("ARXIV_AUTH_DEBUG is set and auth debug messages to logging is turned on")
+
+
+ def load_session(self) -> Optional[Response]:
+ """Look for an active session, and attach it to the request.
+
+ The typical scenario will involve the
+ :class:`.middleware.AuthMiddleware` unpacking a session token and
+ adding it to the WSGI request environ.
+
+ As a fallback, if the legacy database is available, this method will
+ also attempt to load an active legacy session.
+
+ """
+ # Check the WSGI request environ for the key, which is where the auth
+ # middleware puts any unpacked auth information from the request OR any
+ # exceptions that need to be raised withing the request context.
+ req_auth: Optional[Union[domain.Session, Exception]] = \
+ request.environ.get(self.auth_session_name)
+
+ # Middlware may raise exception, needs to be raised in to be handled correctly.
+ if isinstance(req_auth, Exception):
+ logger.debug('Middleware passed an exception: %s', req_auth)
+ raise req_auth
+
+ if not req_auth:
+ if util.is_configured():
+ req_auth = self.first_valid(self.legacy_cookies())
+ else:
+ logger.warning('No legacy DB, will not check tapir auth.')
+
+ # Attach auth to the request so other can access easily. request.auth
+ setattr(request, self.auth_session_name, req_auth)
+ return None
+
+ def first_valid(self, cookies: List[str]) -> Optional[domain.Session]:
+ """First valid legacy session or None if there are none."""
+ first = next(filter(bool,
+ map(self._get_legacy_session,
+ cookies)), None)
+
+ if first is None:
+ logger.debug("Out of %d cookies, no legacy cookie found", len(cookies))
+ else:
+ logger.debug("Out of %d cookies, found a good legacy cookie", len(cookies))
+
+ return first
+
+ def legacy_cookies(self) -> List[str]:
+ """Gets list of legacy cookies.
+
+ Duplicate cookies occur due to the browser sending both the
+ cookies for both arxiv.org and sub.arxiv.org. If this is being
+ served at sub.arxiv.org, there is no response that will cause
+ the browser to alter its cookie store for arxiv.org. Duplicate
+ cookies must be handled gracefully to for the domain and
+ subdomain to coexist.
+
+ The standard way to avoid this problem is to append part of
+ the domain's name to the cookie key but this needs to work
+ even if the configuration is not ideal.
+
+ """
+ # By default, werkzeug uses a dict-based struct that supports only a
+ # single value per key. This isn't really up to speed with RFC 6265.
+ # Luckily we can just pass in an alternate struct to parse_cookie()
+ # that can cope with multiple values.
+ raw_cookie = request.environ.get('HTTP_COOKIE', None)
+ if raw_cookie is None:
+ return []
+ cookies = parse_cookie(raw_cookie, cls=MultiDict)
+ return cookies.getlist(self.app.config['CLASSIC_COOKIE_NAME'])
+
+ def auth_debug(self) -> None:
+ """Sets several auth loggers to DEBUG.
+
+ This is useful to get an idea of what is going on with auth.
+ """
+ logger.setLevel(logging.DEBUG)
+ legacy.sessions.logger.setLevel(logging.DEBUG)
+ legacy.authenticate.logger.setLevel(logging.DEBUG)
diff --git a/arxiv/auth/auth/decorators.py b/arxiv/auth/auth/decorators.py
new file mode 100644
index 00000000..4d4ddb9c
--- /dev/null
+++ b/arxiv/auth/auth/decorators.py
@@ -0,0 +1,248 @@
+"""
+Scope-based authorization of user/client requests.
+
+This module provides :func:`scoped`, a decorator factory used to protect Flask
+routes for which authorization is required. This is done by specifying a
+required authorization scope (see :mod:`arxiv.users.auth.scopes`) and/or by
+providing a custom authorizer function.
+
+For routes that involve specific resources, a ``resource`` callback should also
+be provided. That callback function should accept the same arguments as the
+route function, and return the identifier for the resource as a string.
+
+Using :func:`scoped` with an authorizer function allows you to define
+application-specific authorization logic on a per-request basis without adding
+complexity to request controllers. The call signature of the authorizer
+function should be: ``(session: domain.Session, *args, **kwargs) -> bool``,
+where `*args` and `**kwargs` are the positional and keyword arguments,
+respectively, passed by Flask to the decorated route function (e.g. the
+URL parameters).
+
+.. note:: The authorizer function is only called if the session does not have
+ a global or resource-specific instance of the required scope, or if a
+ required scope is not specified.
+
+Here's an example of how you might use this in a Flask application:
+
+.. code-block:: python
+
+ from arxiv.users.auth.decorators import scoped
+ from arxiv.users.auth import scopes
+ from arxiv.users import domain
+
+
+ def is_owner(session: domain.Session, user_id: str, **kwargs) -> bool:
+ '''Check whether the authenticated user matches the requested user.'''
+ return session.user.user_id == user_id
+
+
+ def get_resource_id(user_id: str) -> str:
+ '''Get the user ID from the request.'''
+ return user_id
+
+
+ def redirect_to_login(user_id: str) -> Response:
+ '''Send the unauthorized user to the log in page.'''
+ return url_for('login')
+
+
+ @blueprint.route('//profile', methods=['GET'])
+ @scoped(scopes.EDIT_PROFILE, resource=get_resource_id,
+ authorizer=user_is_owner, unauthorized=redirect_to_login)
+ def edit_profile(user_id: str):
+ '''User can update their account information.'''
+ data, code, headers = profile.get_profile(user_id)
+ return render_template('accounts/profile.html', **data)
+
+
+When the decorated route function is called...
+
+- If no session is available from either the middleware or the legacy database,
+ the ``unauthorized`` callback is called, and/or :class:`Unauthorized`
+ exception is raised.
+- If a required scope was provided, the session is checked for the presence of
+ that scope in this order:
+
+ - Global scope (`:*`), e.g. for administrators.
+ - Resource-specific scope (`:[resource_id]`), i.e. explicitly granted for a
+ particular resource.
+ - Generic scope (no resource part).
+
+- If an authorization function was provided, the function is called only if
+ a required scope was not provided, or if only the generic scope was found.
+- Session data is added directly to the Flask request object as
+ ``request.auth``, for ease of access elsewhere in the application.
+- Finally, if no exceptions have been raised, the route is called with the
+ original parameters.
+
+"""
+
+from typing import Optional, Callable, Any, List
+from functools import wraps
+from flask import request
+from werkzeug.exceptions import Unauthorized, Forbidden
+import logging
+
+from .. import domain
+
+INVALID_TOKEN = {'reason': 'Invalid authorization token'}
+INVALID_SCOPE = {'reason': 'Token not authorized for this action'}
+
+
+logger = logging.getLogger(__name__)
+logger.propagate = False
+
+
+def scoped(required: Optional[domain.Scope] = None,
+ resource: Optional[Callable] = None,
+ authorizer: Optional[Callable] = None,
+ unauthorized: Optional[Callable] = None) -> Callable:
+ """
+ Generate a decorator to enforce authorization requirements.
+
+ Parameters
+ ----------
+ required : str
+ The scope required on a user or client session in order use the
+ decorated route. See :mod:`arxiv.users.auth.scopes`. If not provided,
+ no scope will be enforced.
+ resource : function
+ If a route provides actions for a specific resource, a callable should
+ be provided that accepts the route arguments and returns the resource
+ identifier (str).
+ authorizer : function
+ In addition, an authorizer function may be passed to provide more
+ specific authorization checks. For example, this function may check
+ that the requesting user is the owner of a resource. Should have the
+ signature: ``(session: domain.Session, *args, **kwargs) -> bool``.
+ ``*args`` and ``**kwargs`` are the parameters passed to the decorated
+ function. If the authorizer returns ``False``, an :class:`.`
+ exception is raised.
+ unauthorized : function
+ DEPRECATED: Do not use this parameter. It is likely it does not do
+ what you expect.
+ A callback may be passed to handle cases in which the request is
+ unauthorized. This function will be passed the same arguments as the
+ original route function. If the callback returns anything other than
+ ``None``, the return value will be treated as a response and execution
+ will stop. Otherwise, an ``Unauthorized`` exception will be raised.
+ If a callback is not provided (default) an ``Unauthorized`` exception
+ will be raised.
+
+ Returns
+ -------
+ function
+ A decorator that enforces the required scope and calls the (optionally)
+ provided authorizer.
+
+ """
+ if required and not isinstance(required, domain.Scope):
+ required = domain.Scope(required)
+
+ def protector(func: Callable) -> Callable:
+ """Decorator that provides scope enforcement."""
+ @wraps(func)
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
+ """
+ Check the authorization token before executing the method.
+
+ Will also raise exceptions passed by the auth middleware.
+
+ Raises
+ ------
+ :class:`.Unauthorized`
+ Raised when session data is not available.
+ :class:`.`
+ Raised when the session has insufficient auth scope, or the
+ provided authorizer returns ``False``.
+
+ """
+ if hasattr(request, 'auth'):
+ session = request.auth
+ elif hasattr(request, 'session'):
+ session = request.session
+ else:
+ raise Unauthorized('No active session on request')
+ scopes: List[domain.Scope] = []
+ authorized: bool = False
+ logger.debug('Required: %s, authorizer: %s, unauthorized: %s',
+ required, authorizer, unauthorized)
+ # Use of the decorator implies that an auth session ought to be
+ # present. So we'll complain here if it's not.
+ if not session or not (session.user or session.client):
+ logger.debug('No valid session; aborting')
+ if unauthorized is not None:
+ response = unauthorized(*args, **kwargs)
+ if response is not None:
+ return response
+ raise Unauthorized('Not a valid session')
+
+ if session.authorizations is not None:
+ scopes = session.authorizations.scopes
+ logger.debug('session has scopes: %s', scopes)
+
+ # If a required scope is provided, we first check to see whether
+ # the session globally or explicitly authorizes the request. We
+ # then fall back to the locally-defined authorizer function if it
+ # is provided.
+ if required and scopes:
+ # A global scope is usually granted to administrators, or
+ # perhaps moderators (e.g. view submission content).
+ # For example: `submission:read:*`.
+ if required.as_global() in scopes:
+ logger.debug('Authorized with global scope')
+ authorized = True
+
+ # A resource-specific scope may be granted at the auth layer.
+ # For example, an admin may provide provisional access to a
+ # specific resource for a specific role. This kind of
+ # authorization is only supported if the service provides a
+ # ``resource()`` callback to get the resource identifier.
+ elif (resource is not None
+ and (
+ required.for_resource(str(resource(*args, **kwargs)))
+ in scopes)):
+ logger.debug('Authorized by specific resource')
+ authorized = True
+
+ # If both the global and resource-specific scope authorization
+ # fail, then we look for the general scope in the session.
+ elif required in scopes:
+ logger.debug('Required scope is present')
+ # If an authorizer callback is provided by the service,
+ # then we will enforce whatever it returns.
+ if authorizer:
+ authorized = authorizer(session, *args, **kwargs)
+ logger.debug('Authorizer func returned %s', authorized)
+ # If no authorizer callback is provided, it is implied that
+ # the general scope is sufficient to authorize the request.
+ elif authorizer is None:
+ logger.debug('No authorizer func provided')
+ authorized = True
+ # The required scope is not present. There is nothing left to
+ # check.
+ else:
+ logger.debug('Required scope is not present')
+ authorized = False
+
+ elif required is None and authorizer is None:
+ logger.debug('No scope required, no authorizer function;'
+ ' request is authorized.')
+ authorized = True
+
+ # If a specific scope is not required, we rely entirely on the
+ # authorizer callback.
+ elif authorizer is not None:
+ logger.debug('Calling authorizer callback')
+ authorized = authorizer(session, *args, **kwargs)
+ else:
+ logger.debug('No authorization path available')
+
+ if not authorized:
+ logger.debug('Session is not authorized')
+ raise Forbidden('Access denied')
+
+ logger.debug('Request is authorized, proceeding')
+ return func(*args, **kwargs)
+ return wrapper
+ return protector
diff --git a/arxiv/auth/auth/exceptions.py b/arxiv/auth/auth/exceptions.py
new file mode 100644
index 00000000..e15deccd
--- /dev/null
+++ b/arxiv/auth/auth/exceptions.py
@@ -0,0 +1,29 @@
+"""Authn/z-related exceptions raised by components in this module."""
+
+
+class InvalidToken(ValueError):
+ """Token in request is not valid."""
+
+
+class MissingToken(ValueError):
+ """No token found in request."""
+
+
+class ConfigurationError(RuntimeError):
+ """The application is not configured correctly."""
+
+
+class SessionCreationFailed(RuntimeError):
+ """Failed to create a session in the session store."""
+
+
+class SessionDeletionFailed(RuntimeError):
+ """Failed to delete a session in the session store."""
+
+
+class UnknownSession(RuntimeError):
+ """Failed to locate a session in the session store."""
+
+
+class ExpiredToken(RuntimeError):
+ """An expired token was passed."""
diff --git a/arxiv/auth/auth/middleware.py b/arxiv/auth/auth/middleware.py
new file mode 100644
index 00000000..be4c63db
--- /dev/null
+++ b/arxiv/auth/auth/middleware.py
@@ -0,0 +1,105 @@
+"""
+Middleware for interpreting authn/z information on requestsself.
+
+This module provides :class:`AuthMiddleware`, which unpacks encrypted JSON
+Web Tokens provided via the ``Authorization`` header. This is intended to
+support requests that have been pre-authorized by the web server using the
+authenticator service (see :mod:`authenticator`).
+
+The configuration parameter ``JWT_SECRET`` must be set in the WSGI request
+environ (e.g. Apache's SetEnv) or in the runtime environment. This must be
+the same secret that was used by the authenticator service to mint the token.
+
+To install the middleware, use the pattern described in
+:mod:`arxiv.base.middleware`. For example:
+
+.. code-block:: python
+
+ from arxiv.base import Base
+ from arxiv.base.middleware import wrap
+ from arxiv.users import auth
+
+
+ def create_web_app() -> Flask:
+ app = Flask('foo')
+ Base(app)
+ auth.Auth(app)
+ wrap(app, [auth.middleware.AuthMiddleware])
+ return app
+
+
+For convenience, this is intended to be used with
+:mod:`arxiv.users.auth.decorators`.
+
+"""
+
+import os
+from typing import Callable, Iterable, Tuple
+import jwt
+import logging
+
+from werkzeug.exceptions import Unauthorized, InternalServerError
+
+from arxiv.base.middleware import BaseMiddleware
+
+from . import tokens
+from .exceptions import InvalidToken, ConfigurationError, MissingToken
+from .. import domain
+
+logger = logging.getLogger(__name__)
+
+WSGIRequest = Tuple[dict, Callable]
+
+
+class AuthMiddleware(BaseMiddleware):
+ """
+ Middleware to handle auth information on requests.
+
+ Before the request is handled by the application, the ``Authorization``
+ header is parsed for an encrypted JWT. If successfully decrypted,
+ information about the user and their authorization scope is attached
+ to the request.
+
+ This can be accessed in the application via
+ ``flask.request.environ['session']``. If Authorization header was not
+ included, then that value will be ``None``.
+
+ If the JWT could not be decrypted, the value will be an
+ :class:`Unauthorized` exception instance. We cannot raise the exception
+ here, because the middleware is executed outside of the Flask application.
+ It's up to something running inside the application (e.g.
+ :func:`arxiv.users.auth.decorators.scoped`) to raise the exception.
+
+ """
+
+ def before(self, environ: dict, start_response: Callable) -> WSGIRequest:
+ """Decode and unpack the auth token on the request."""
+ environ['auth'] = None # Create the session key, at a minimum.
+ environ['token'] = None
+ token = environ.get('HTTP_AUTHORIZATION') # We may not have a token.
+ if token is None:
+ logger.debug('No auth token')
+ return environ, start_response
+
+ # The token secret should be set in the WSGI environ, or in os.environ.
+ secret = environ.get('JWT_SECRET', os.environ.get('JWT_SECRET'))
+ if secret is None:
+ raise ConfigurationError('Missing decryption token')
+
+ try:
+ # Try to verify the token in the Authorization header, and attach
+ # the decoded session data to the request.
+ session: domain.Session = tokens.decode(token, secret)
+ environ['auth'] = session
+
+ # Attach the encrypted token so that we can use it in subrequests.
+ environ['token'] = token
+ except InvalidToken as e: # Let the application decide what to do.
+ logger.debug(f'Auth token not valid: {token}')
+ exception = Unauthorized('Invalid auth token')
+ environ['auth'] = exception
+ except Exception as e:
+ logger.error(f'Unhandled exception: {e}')
+ exception = InternalServerError(f'Unhandled: {e}') # type: ignore
+ environ['auth'] = exception
+ return environ, start_response
diff --git a/arxiv/auth/auth/scopes.py b/arxiv/auth/auth/scopes.py
new file mode 100644
index 00000000..2c56858a
--- /dev/null
+++ b/arxiv/auth/auth/scopes.py
@@ -0,0 +1,241 @@
+"""
+Authorization scopes for arXiv users and clients.
+
+The concept of authorization scope comes from OAuth 2.0 (`RFC 6749 ยง3.3
+`_). For a nice primer, see
+`this blog post `_. The basic idea is that
+the authorization associated with an access token can be limited, e.g. to
+limit what actions an API client can take on behalf of a user.
+
+In this package, the scope concept is applied to both API client and end-user
+sessions. When the session is created, we consult the relevant bits of data in
+our system (e.g. what roles the user has, what privileges are associated with
+those roles) to determine what the user is authorized to do. Those privileges
+are attached to the user's session as authorization scopes.
+
+This module simply defines a set of constants (str) that can be used to refer
+to specific authorization scopes. Rather than refer to scopes by writing new
+str objects, these constants should be imported and used. This improves the
+semantics of code, and reduces the risk of programming errors. For an example,
+see :mod:`arxiv.users.auth.decorators`.
+
+"""
+from typing import Optional, Dict
+from ..domain import Scope
+
+def _scope_str(domain:str, action:str, resource:str = None):
+ return str(Scope(domain=domain, action=action, resource=resource))
+
+
+class domains:
+ """Known authorization domains."""
+
+ PUBLIC = 'public'
+ """The public arXiv site, including APIs."""
+ PROFILE = 'profile'
+ """arXiv user profile."""
+ SUBMISSION = 'submission'
+ """Submission interfaces and actions."""
+ UPLOAD = 'upload'
+ """File uploads, including those for submissions."""
+ COMPILE = 'compile'
+ """PDF compilation."""
+ FULLTEXT = 'fulltext'
+ """Fulltext extraction."""
+ PREVIEW = 'preview'
+ """Submission previews."""
+
+class actions:
+ """Known authorization actions."""
+
+ UPDATE = 'update'
+ CREATE = 'create'
+ DELETE = 'delete'
+ RELEASE = 'release'
+ READ = 'read'
+ PROXY = 'proxy'
+
+READ_PUBLIC = _scope_str(domains.PUBLIC, actions.READ)
+"""
+Authorizes access to public endpoints.
+
+This is implicitly granted to all anonymous users. For endpoints requiring
+authentication (e.g. APIs) this scope can be used to denote baseline read
+access for clients.
+"""
+
+EDIT_PROFILE = _scope_str(domains.PROFILE, actions.UPDATE)
+"""
+Authorizes editing user profile.
+
+This includes things like affiliation, full name, etc..
+"""
+
+VIEW_PROFILE = _scope_str(domains.PROFILE, actions.READ)
+"""
+Authorizes viewing the content of a user profile.
+
+This includes things like affiliation, full name, and e-mail address.
+"""
+
+CREATE_SUBMISSION = _scope_str(domains.SUBMISSION, actions.CREATE)
+"""Authorizes creating a new submission."""
+
+EDIT_SUBMISSION = _scope_str(domains.SUBMISSION, actions.UPDATE)
+"""Authorizes updating a submission that has not yet been announced."""
+
+VIEW_SUBMISSION = _scope_str(domains.SUBMISSION, actions.READ)
+"""Authorizes viewing a submission."""
+
+DELETE_SUBMISSION = _scope_str(domains.SUBMISSION, actions.DELETE)
+"""Authorizes deleting a submission."""
+
+PROXY_SUBMISSION = _scope_str(domains.SUBMISSION, actions.PROXY)
+"""
+Authorizes creating a submission on behalf of another user.
+
+This authorization is specifically for human users submitting on behalf of
+other human users. For client authorization to submit on behalf of a user,
+submission:create
should be used instead.
+"""
+
+READ_UPLOAD = _scope_str(domains.UPLOAD, actions.READ)
+"""Authorizes viewing the content of an upload workspace."""
+
+WRITE_UPLOAD = _scope_str(domains.UPLOAD, actions.UPDATE)
+"""Authorizes uploading files to to a workspace."""
+
+RELEASE_UPLOAD = _scope_str(domains.UPLOAD, actions.RELEASE)
+"""Authorizes releasing an upload workspace."""
+
+DELETE_UPLOAD_WORKSPACE = _scope_str(domains.UPLOAD, 'delete_workspace')
+"""Can delete an entire workspace in the file management service."""
+
+DELETE_UPLOAD_FILE = _scope_str(domains.UPLOAD, actions.DELETE)
+"""Can delete files from a file management upload workspace."""
+
+READ_UPLOAD_LOGS = _scope_str(domains.UPLOAD, 'read_logs')
+"""Can read logs for a file management upload workspace."""
+
+READ_UPLOAD_SERVICE_LOGS = _scope_str(domains.UPLOAD, 'read_service_logs')
+"""Can read service logs in the file management service."""
+
+CREATE_UPLOAD_CHECKPOINT = _scope_str(domains.UPLOAD, 'create_checkpoint')
+"""Create an upload workspace checkpoint."""
+
+DELETE_UPLOAD_CHECKPOINT = _scope_str(domains.UPLOAD, 'delete_checkpoint')
+"""Delete an upload workspace checkpoint."""
+
+READ_UPLOAD_CHECKPOINT = _scope_str(domains.UPLOAD, 'read_checkpoints')
+"""Read from an upload workspace checkpoint."""
+
+RESTORE_UPLOAD_CHECKPOINT = _scope_str(domains.UPLOAD, 'restore_checkpoint')
+"""Restore an upload workspace to a checkpoint."""
+
+READ_COMPILE = _scope_str(domains.COMPILE, actions.READ)
+"""Can read documents generated by the compilation service."""
+
+CREATE_COMPILE = _scope_str(domains.COMPILE, actions.CREATE)
+"""Can create new documents via the compilation service."""
+
+READ_FULLTEXT = _scope_str(domains.FULLTEXT, actions.READ)
+"""Can access plain text extracted from compiled documents."""
+
+CREATE_FULLTEXT = _scope_str(domains.FULLTEXT, actions.CREATE)
+"""Can trigger new plain text extractions from compiled documents."""
+
+READ_PREVIEW = _scope_str(domains.PREVIEW, actions.READ)
+"""Can view a submission preview."""
+
+CREATE_PREVIEW = _scope_str(domains.PREVIEW, actions.CREATE)
+"""Can create a new submission preview."""
+
+GENERAL_USER = [
+ READ_PUBLIC, # Access to public APIs.
+
+ # Profile management.
+ EDIT_PROFILE,
+ VIEW_PROFILE,
+
+ # Ability to use the submission system.
+ CREATE_SUBMISSION,
+ EDIT_SUBMISSION,
+ VIEW_SUBMISSION,
+ DELETE_SUBMISSION,
+
+ # Allows usage of the compilation service during submission.
+ READ_COMPILE,
+ CREATE_COMPILE,
+
+ # Allows usage of the file management service during submission.
+ READ_UPLOAD,
+ WRITE_UPLOAD,
+ DELETE_UPLOAD_FILE,
+
+ # Ability to create and view submission previews.
+ READ_PREVIEW,
+ CREATE_PREVIEW,
+]
+"""
+The default scopes afforded to an authenticated user.
+
+This static list will be deprecated by role-based access control (RBAC) at a
+later milestone of arXiv.
+"""
+
+_ADMIN_USER = GENERAL_USER + [
+ CREATE_UPLOAD_CHECKPOINT,
+ DELETE_UPLOAD_CHECKPOINT,
+ READ_UPLOAD_CHECKPOINT,
+ RESTORE_UPLOAD_CHECKPOINT,
+ READ_FULLTEXT,
+ CREATE_FULLTEXT,
+ DELETE_UPLOAD_WORKSPACE,
+ READ_UPLOAD_LOGS,
+ READ_UPLOAD_SERVICE_LOGS
+]
+ADMIN_USER = [str(Scope.from_str(scope).as_global()) for scope in _ADMIN_USER]
+"""
+Scopes afforded to an administrator.
+
+This static list will be deprecated by role-based access control (RBAC) at a
+later milestone of arXiv.
+"""
+
+_HUMAN_LABELS: Dict[Scope, str] = {
+ EDIT_PROFILE: "Grants authorization to change the contents of your user"
+ " profile. This includes your affiliation, preferred name,"
+ " default submission category, etc.",
+ VIEW_PROFILE: "Grants authorization to view the contents of your user"
+ " profile. This includes your affiliation, preferred name,"
+ " default submission category, etc.",
+ CREATE_SUBMISSION: "Grants authorization to submit papers on your behalf.",
+ EDIT_SUBMISSION: "Grants authorization to make changes to your submissions"
+ " that have not yet been announced. For example, to"
+ " update the DOI or journal reference field on your"
+ " behalf. Note that this affects only the metadata of"
+ " your submission, and not the content.",
+ VIEW_SUBMISSION: "Grants authorization to view your submissions, including"
+ " those that have not yet been announced. Note that this"
+ " only applies to the submission metadata, and not to the"
+ " uploaded submission content.",
+ READ_UPLOAD: "Grants authorization to view the contents of your uploads.",
+ WRITE_UPLOAD: "Grants authorization to add and delete files on your"
+ " behalf.",
+ READ_UPLOAD_LOGS: "Grants authorization to read logs of upload activity"
+ " related to your submissions.",
+ READ_COMPILE: "Grants authorization to read a compilation task, product,"
+ " and any log output, related to your submissions.",
+ CREATE_COMPILE: "Grants authorization to compile your submission source"
+ " files, e.g. to produce a PDF.",
+ READ_PREVIEW: "Grants authorization to retrieve the preview (e.g. PDF) of"
+ " your submission.",
+ CREATE_PREVIEW: "Grants authorization to update the preview (e.g. PDF) of"
+ " your submission.",
+}
+
+
+def get_human_label(scope: str) -> Optional[str]:
+ """Get a human-readable label for a scope, for display to end users."""
+ label: Optional[str] = _HUMAN_LABELS.get(scope, None)
+ return label
diff --git a/arxiv/auth/auth/sessions/__init__.py b/arxiv/auth/auth/sessions/__init__.py
new file mode 100644
index 00000000..0d2b82ab
--- /dev/null
+++ b/arxiv/auth/auth/sessions/__init__.py
@@ -0,0 +1,12 @@
+"""
+Integration with the distributed session store.
+
+In this implementation, we use a key-value store to hold session data
+in JSON format. When a session is created, a JWT cookie value is
+created that contains information sufficient to retrieve the session.
+
+See :mod:`.store`.
+
+"""
+
+from .store import SessionStore
diff --git a/arxiv/auth/auth/sessions/store.py b/arxiv/auth/auth/sessions/store.py
new file mode 100644
index 00000000..cd70eb36
--- /dev/null
+++ b/arxiv/auth/auth/sessions/store.py
@@ -0,0 +1,271 @@
+"""
+Internal service API for the distributed session store.
+
+Used to create, delete, and verify user and client session.
+"""
+
+import uuid
+import random
+from datetime import datetime, timedelta
+import dateutil.parser
+from pytz import timezone, UTC
+import logging
+
+from typing import Optional, Union
+
+import redis
+import rediscluster
+
+import jwt
+
+from .. import domain
+from ..exceptions import SessionCreationFailed, InvalidToken, \
+ SessionDeletionFailed, UnknownSession, ExpiredToken
+
+from arxiv.base.globals import get_application_config, get_application_global
+
+logger = logging.getLogger(__name__)
+EASTERN = timezone('US/Eastern')
+
+
+def _generate_nonce(length: int = 8) -> str:
+ return ''.join([str(random.randint(0, 9)) for i in range(length)])
+
+
+class SessionStore(object):
+ """
+ Manages a connection to Redis.
+
+ In fact, the StrictRedis instance is thread safe and connections are
+ attached at the time a command is executed. This class simply provides a
+ container for configuration.
+
+ Pass fake=True to use FakeRedis for testing of development.
+ """
+
+ def __init__(self, host: str, port: int, db: int, secret: str,
+ duration: int = 7200, token: Optional[str] = None,
+ cluster: bool = True, fake: bool = False) -> None:
+ """Open the connection to Redis."""
+ self._secret = secret
+ self._duration = duration
+ if fake:
+ logger.warning('Using FakeRedis')
+ import fakeredis # this is a dev dependency needed during testing
+ self.r = fakeredis.FakeStrictRedis()
+ else:
+ logger.debug('New Redis connection at %s, port %s', host, port)
+ if cluster:
+ self.r = rediscluster.StrictRedisCluster(
+ startup_nodes=[{'host': host, 'port': str(port)}],
+ skip_full_coverage_check=True
+ )
+ else:
+ self.r = redis.StrictRedis(host=host, port=port)
+
+ def create(self, authorizations: domain.Authorizations,
+ ip_address: str, remote_host: str, tracking_cookie: str = '',
+ user: Optional[domain.User] = None,
+ client: Optional[domain.Client] = None,
+ session_id: Optional[str] = None) -> domain.Session:
+ """
+ Create a new session.
+
+ Parameters
+ ----------
+ authorizations : :class:`domain.Authorizations`
+ ip_address : str
+ remote_host : str
+ tracking_cookie : str
+ user : :class:`domain.User`
+ client : :class:`domain.Client`
+
+ Returns
+ -------
+ :class:`.Session`
+
+ """
+ if session_id is None:
+ session_id = str(uuid.uuid4())
+ start_time = datetime.now(tz=UTC)
+ end_time = start_time + timedelta(seconds=self._duration)
+ session = domain.Session(
+ session_id=session_id,
+ user=user,
+ client=client,
+ start_time=start_time,
+ end_time=end_time,
+ authorizations=authorizations,
+ nonce=_generate_nonce()
+ )
+ logger.debug('storing session %s', session)
+ try:
+ self.r.set(session_id,
+ jwt.encode(session.json_safe_dict(), self._secret),
+ ex=self._duration)
+ except redis.exceptions.ConnectionError as e:
+ raise SessionCreationFailed(f'Connection failed: {e}') from e
+ except Exception as e:
+ raise SessionCreationFailed(f'Failed to create: {e}') from e
+
+ return session
+
+ def generate_cookie(self, session: domain.Session) -> str:
+ """Generate a cookie from a :class:`domain.Session`."""
+ if session.end_time is None:
+ raise RuntimeError('Session has no expiry')
+ if session.user is None:
+ raise RuntimeError('Session user is not set')
+ return self._pack_cookie({
+ 'user_id': session.user.user_id,
+ 'session_id': session.session_id,
+ 'nonce': session.nonce,
+ 'expires': session.end_time.isoformat()
+ })
+
+ def delete(self, cookie: str) -> None:
+ """
+ Delete a session.
+
+ Parameters
+ ----------
+ cookie : str
+
+ """
+ cookie_data = self._unpack_cookie(cookie)
+ self.delete_by_id(cookie_data['session_id'])
+
+ def delete_by_id(self, session_id: str) -> None:
+ """
+ Delete a session in the key-value store by ID.
+
+ Parameters
+ ----------
+ session_id : str
+
+ """
+ try:
+ self.r.delete(session_id)
+ except redis.exceptions.ConnectionError as e:
+ raise SessionDeletionFailed(f'Connection failed: {e}') from e
+ except Exception as e:
+ raise SessionDeletionFailed(f'Failed to delete: {e}') from e
+
+ def validate_session_against_cookie(self, session: domain.Session,
+ cookie: str) -> None:
+ """
+ Validate session data against a cookie.
+
+ Parameters
+ ----------
+ session : :class:`Session`
+ cookie : str
+
+ Raises
+ ------
+ :class:`InvalidToken`
+ Raised if the data in the cookie does not match the session data.
+
+ """
+ cookie_data = self._unpack_cookie(cookie)
+ if cookie_data['nonce'] != session.nonce \
+ or session.user is None \
+ or session.user.user_id != cookie_data['user_id']:
+ raise InvalidToken('Invalid token; likely a forgery')
+
+ def load(self, cookie: str, decode: bool = True) \
+ -> Union[domain.Session, str, bytes]:
+ """Load a session using a session cookie."""
+ try:
+ cookie_data = self._unpack_cookie(cookie)
+ expires = dateutil.parser.parse(cookie_data['expires'])
+ except (KeyError, jwt.exceptions.DecodeError) as e:
+ raise InvalidToken('Token payload malformed') from e
+
+ if expires <= datetime.now(tz=UTC):
+ raise InvalidToken('Session has expired')
+
+ session = self.load_by_id(cookie_data['session_id'], decode=decode)
+
+ if not decode:
+ assert isinstance(session, str) or isinstance(session, bytes)
+ return session
+ assert isinstance(session, domain.Session)
+ if session.expired:
+ raise ExpiredToken('Session has expired')
+ if session.user is None and session.client is None:
+ raise InvalidToken('Neither user nor client data are present')
+
+ self.validate_session_against_cookie(session, cookie)
+ return session
+
+ def load_by_id(self, session_id: str, decode: bool = True) \
+ -> Union[domain.Session, str, bytes]:
+ """Get session data by session ID."""
+ session_jwt: str = self.r.get(session_id)
+ if not session_jwt:
+ logger.debug(f'No such session: {session_id}')
+ raise UnknownSession(f'Failed to find session {session_id}')
+ if decode:
+ return self._decode(session_jwt)
+ return session_jwt
+
+ def _encode(self, session_data: dict) -> bytes:
+ return jwt.encode(session_data, self._secret)
+
+ def _decode(self, session_jwt: str) -> domain.Session:
+ try:
+ return domain.Session.parse_obj(
+ jwt.decode(session_jwt, self._secret, algorithms=['HS256']))
+ except jwt.exceptions.InvalidSignatureError:
+ raise InvalidToken('Invalid or corrupted session token')
+
+ def _unpack_cookie(self, cookie: str) -> dict:
+ secret = self._secret
+ try:
+ data = dict(jwt.decode(cookie, secret, algorithms=['HS256']))
+ except jwt.exceptions.DecodeError as e:
+ raise InvalidToken('Session cookie is malformed') from e
+ return data
+
+ def _pack_cookie(self, cookie_data: dict) -> str:
+ secret = self._secret
+ return jwt.encode(cookie_data, secret)
+
+ @classmethod
+ def init_app(cls, app: object = None) -> None:
+ """Set default configuration parameters for an application instance."""
+ config = get_application_config(app)
+ config.setdefault('REDIS_HOST', 'localhost')
+ config.setdefault('REDIS_PORT', '7000')
+ config.setdefault('REDIS_DATABASE', '0')
+ config.setdefault('REDIS_TOKEN', None)
+ config.setdefault('REDIS_CLUSTER', '1')
+ config.setdefault('JWT_SECRET', 'foosecret')
+ config.setdefault('SESSION_DURATION', '7200')
+ config.setdefault('REDIS_FAKE', False)
+
+ @classmethod
+ def get_session(cls, app: object = None) -> 'SessionStore':
+ """Get a new session with the search index."""
+ config = get_application_config(app)
+ host = config.get('REDIS_HOST', 'localhost')
+ port = int(config.get('REDIS_PORT', '7000'))
+ db = int(config.get('REDIS_DATABASE', '0'))
+ token = config.get('REDIS_TOKEN', None)
+ cluster = config.get('REDIS_CLUSTER', '1') == '1'
+ secret = config['JWT_SECRET']
+ duration = int(config.get('SESSION_DURATION', '7200'))
+ fake = config.get('REDIS_FAKE', False)
+ return cls(host, port, db, secret, duration, token=token,
+ cluster=cluster, fake=fake)
+
+ @classmethod
+ def current_session(cls) -> 'SessionStore':
+ """Get/create :class:`.SearchSession` for this context."""
+ g = get_application_global()
+ if not g:
+ return cls.get_session()
+ if 'redis' not in g:
+ g.redis = cls.get_session()
+ return g.redis # type: ignore
diff --git a/arxiv/auth/auth/sessions/tests/__init__.py b/arxiv/auth/auth/sessions/tests/__init__.py
new file mode 100644
index 00000000..d4540d44
--- /dev/null
+++ b/arxiv/auth/auth/sessions/tests/__init__.py
@@ -0,0 +1 @@
+"""Tests for :mod:`accounts.services.session_store`."""
diff --git a/arxiv/auth/auth/sessions/tests/test_integration.py b/arxiv/auth/auth/sessions/tests/test_integration.py
new file mode 100644
index 00000000..2d001e4c
--- /dev/null
+++ b/arxiv/auth/auth/sessions/tests/test_integration.py
@@ -0,0 +1,82 @@
+"""Integration tests for the session_store session store with Redis."""
+
+from unittest import TestCase, mock
+import jwt
+
+from .... import domain
+from .. import store
+
+
+class TestDistributedSessionServiceIntegration(TestCase):
+ """Test integration with Redis."""
+
+ @classmethod
+ def setUpClass(self):
+ self.secret = 'bazsecret'
+
+ @mock.patch(f'{store.__name__}.get_application_config')
+ def test_store_create(self, mock_get_config):
+ """An entry should be created in Redis."""
+ mock_get_config.return_value = {
+ 'JWT_SECRET': self.secret,
+ 'REDIS_FAKE': True
+ }
+ ip = '127.0.0.1'
+ remote_host = 'foo-host.foo.com'
+ user = domain.User(
+ user_id='1',
+ username='theuser',
+ email='the@user.com',
+ )
+ authorizations = domain.Authorizations(
+ classic=2,
+ scopes=['foo:write'],
+ endorsements=[]
+ )
+ s = store.SessionStore.current_session()
+ session = s.create(authorizations, ip, remote_host, user=user)
+ cookie = s.generate_cookie(session)
+
+ # API still works as expected.
+ self.assertIsInstance(session, domain.Session)
+ self.assertTrue(bool(session.session_id))
+ self.assertIsNotNone(cookie)
+
+ r = s.r
+ raw = r.get(session.session_id)
+ stored_data = jwt.decode(raw, self.secret, algorithms=['HS256'])
+ cookie_data = jwt.decode(cookie, self.secret, algorithms=['HS256'])
+ self.assertEqual(stored_data['nonce'], cookie_data['nonce'])
+
+ # def test_invalidate_session(self):
+ # """Invalidate a session from the datastore."""
+ # r = rediscluster.StrictRedisCluster(startup_nodes=[dict(host='localhost', port='7000')])
+ # data_in = {'end_time': time.time() + 30 * 60, 'user_id': 1,
+ # 'nonce': '123'}
+ # r.set('fookey', json.dumps(data_in))
+ # data0 = json.loads(r.get('fookey'))
+ # now = time.time()
+ # self.assertGreaterEqual(data0['end_time'], now)
+ # store.invalidate(
+ # store.current_session()._pack_cookie({
+ # 'session_id': 'fookey',
+ # 'nonce': '123',
+ # 'user_id': 1
+ # })
+ # )
+ # data1 = json.loads(r.get('fookey'))
+ # now = time.time()
+ # self.assertGreaterEqual(now, data1['end_time'])
+
+ @mock.patch(f'{store.__name__}.get_application_config')
+ def test_delete_session(self, mock_get_config):
+ """Delete a session from the datastore."""
+ mock_get_config.return_value = {
+ 'JWT_SECRET': self.secret,
+ 'REDIS_FAKE': True
+ }
+ s = store.SessionStore.current_session()
+ r = s.r
+ r.set('fookey', b'foovalue')
+ s.delete_by_id('fookey')
+ self.assertIsNone(r.get('fookey'))
diff --git a/arxiv/auth/auth/sessions/tests/test_unit.py b/arxiv/auth/auth/sessions/tests/test_unit.py
new file mode 100644
index 00000000..d5b82490
--- /dev/null
+++ b/arxiv/auth/auth/sessions/tests/test_unit.py
@@ -0,0 +1,319 @@
+"""Tests for :mod:`arxiv.users.auth.sessions.store`."""
+
+from unittest import TestCase, mock
+import time
+import jwt
+import json
+from datetime import datetime, timedelta
+from pytz import timezone, UTC
+from redis.exceptions import ConnectionError
+
+from .... import domain
+from .. import store
+
+EASTERN = timezone('US/Eastern')
+
+
+class TestDistributedSessionService(TestCase):
+ """The store session service puts sessions in a key-value store."""
+
+ @mock.patch(f'{store.__name__}.get_application_config')
+ @mock.patch(f'{store.__name__}.rediscluster')
+ def test_create(self, mock_redis, mock_get_config):
+ """Accept a :class:`.User` and returns a :class:`.Session`."""
+ mock_get_config.return_value = {'JWT_SECRET': 'foosecret'}
+ mock_redis.exceptions.ConnectionError = ConnectionError
+ mock_redis_connection = mock.MagicMock()
+ mock_redis.StrictRedisCluster.return_value = mock_redis_connection
+ ip = '127.0.0.1'
+ remote_host = 'foo-host.foo.com'
+ user = domain.User(
+ user_id='1',
+ username='theuser',
+ email='the@user.com'
+ )
+ auths = domain.Authorizations(
+ classic=2,
+ scopes=['foo:write'],
+ endorsements=[]
+ )
+ r = store.SessionStore('localhost', 7000, 0, 'foosecret')
+ session = r.create(auths, ip, remote_host, user=user)
+ cookie = r.generate_cookie(session)
+ self.assertIsInstance(session, domain.Session)
+ self.assertTrue(bool(session.session_id))
+ self.assertIsNotNone(cookie)
+ self.assertEqual(mock_redis_connection.set.call_count, 1)
+
+ @mock.patch(f'{store.__name__}.get_application_config')
+ @mock.patch(f'{store.__name__}.rediscluster')
+ def test_delete(self, mock_redis, mock_get_config):
+ """Delete a session from the datastore."""
+ mock_get_config.return_value = {'JWT_SECRET': 'foosecret'}
+ mock_redis.exceptions.ConnectionError = ConnectionError
+ mock_redis_connection = mock.MagicMock()
+ mock_redis.StrictRedisCluster.return_value = mock_redis_connection
+ r = store.SessionStore('localhost', 7000, 0, 'foosecret')
+ r.delete_by_id('fookey')
+ self.assertEqual(mock_redis_connection.delete.call_count, 1)
+
+ @mock.patch(f'{store.__name__}.get_application_config')
+ @mock.patch(f'{store.__name__}.rediscluster')
+ def test_connection_failed(self, mock_redis, mock_get_config):
+ """:class:`.SessionCreationFailed` is raised when creation fails."""
+ mock_get_config.return_value = {'JWT_SECRET': 'foosecret'}
+ mock_redis.exceptions.ConnectionError = ConnectionError
+ mock_redis_connection = mock.MagicMock()
+ mock_redis_connection.set.side_effect = ConnectionError
+ mock_redis.StrictRedisCluster.return_value = mock_redis_connection
+ ip = '127.0.0.1'
+ remote_host = 'foo-host.foo.com'
+ user = domain.User(
+ user_id='1',
+ username='theuser',
+ email='the@user.com'
+ )
+ auths = domain.Authorizations(
+ classic=2,
+ scopes=['foo:write'],
+ endorsements=[]
+ )
+ r = store.SessionStore('localhost', 7000, 0, 'foosecret')
+ with self.assertRaises(store.SessionCreationFailed):
+ r.create(auths, ip, remote_host, user=user)
+
+
+class TestGetSession(TestCase):
+ """Tests for :func:`store.SessionStore.current_session().load`."""
+
+ @mock.patch(f'{store.__name__}.get_application_config')
+ @mock.patch(f'{store.__name__}.rediscluster.StrictRedisCluster')
+ def test_not_a_token(self, mock_get_redis, mock_get_config):
+ """Something other than a JWT is passed."""
+ mock_get_config.return_value = {
+ 'JWT_SECRET': 'barsecret',
+ 'REDIS_HOST': 'redis',
+ 'REDIS_PORT': '1234',
+ 'REDIS_DATABASE': 4
+ }
+ mock_redis = mock.MagicMock()
+ mock_get_redis.return_value = mock_redis
+ with self.assertRaises(store.InvalidToken):
+ store.SessionStore.current_session().load('notatoken')
+
+ @mock.patch(f'{store.__name__}.get_application_config')
+ @mock.patch(f'{store.__name__}.rediscluster.StrictRedisCluster')
+ def test_malformed_token(self, mock_get_redis, mock_get_config):
+ """A JWT with missing claims is passed."""
+ secret = 'barsecret'
+ mock_get_config.return_value = {
+ 'JWT_SECRET': secret,
+ 'REDIS_HOST': 'redis',
+ 'REDIS_PORT': '1234',
+ 'REDIS_DATABASE': 4
+ }
+ mock_redis = mock.MagicMock()
+ mock_get_redis.return_value = mock_redis
+ required_claims = ['session_id', 'nonce']
+ for exc in required_claims:
+ claims = {claim: '' for claim in required_claims if claim != exc}
+ malformed_token = jwt.encode(claims, secret)
+ with self.assertRaises(store.InvalidToken):
+ store.SessionStore.current_session().load(malformed_token)
+
+ @mock.patch(f'{store.__name__}.get_application_config')
+ @mock.patch(f'{store.__name__}.rediscluster.StrictRedisCluster')
+ def test_token_with_bad_encryption(self, mock_get_redis, mock_get_config):
+ """A JWT produced with a different secret is passed."""
+ secret = 'barsecret'
+ mock_get_config.return_value = {
+ 'JWT_SECRET': secret,
+ 'REDIS_HOST': 'redis',
+ 'REDIS_PORT': '1234',
+ 'REDIS_DATABASE': 4
+ }
+ mock_redis = mock.MagicMock()
+ mock_get_redis.return_value = mock_redis
+ start_time = datetime.now(tz=UTC)
+ end_time = start_time + timedelta(seconds=7200)
+ claims = {
+ 'user_id': '1234',
+ 'session_id': 'ajx9043jjx00s',
+ 'nonce': '0039299290099',
+ 'expires': end_time.isoformat()
+ }
+ bad_token = jwt.encode(claims, 'nottherightsecret')
+ with self.assertRaises(store.InvalidToken):
+ store.SessionStore.current_session().load(bad_token)
+
+ @mock.patch(f'{store.__name__}.get_application_config')
+ @mock.patch(f'{store.__name__}.rediscluster.StrictRedisCluster')
+ def test_expired_token(self, mock_get_redis, mock_get_config):
+ """A JWT produced with a different secret is passed."""
+ secret = 'barsecret'
+ mock_get_config.return_value = {
+ 'JWT_SECRET': secret,
+ 'REDIS_HOST': 'redis',
+ 'REDIS_PORT': '1234',
+ 'REDIS_DATABASE': 4
+ }
+ mock_redis = mock.MagicMock()
+ start_time = datetime.now(tz=UTC)
+ mock_redis.get.return_value = json.dumps({
+ 'user_id': '1234',
+ 'session_id': 'ajx9043jjx00s',
+ 'nonce': '0039299290099',
+ 'expires': start_time.isoformat(),
+ })
+ mock_get_redis.return_value = mock_redis
+
+ claims = {
+ 'user_id': '1234',
+ 'session_id': 'ajx9043jjx00s',
+ 'nonce': '0039299290099',
+ 'expires': start_time.isoformat(),
+ }
+ expired_token = jwt.encode(claims, secret)
+ with self.assertRaises(store.InvalidToken):
+ store.SessionStore.current_session().load(expired_token)
+
+ @mock.patch(f'{store.__name__}.get_application_config')
+ @mock.patch(f'{store.__name__}.rediscluster.StrictRedisCluster')
+ def test_forged_token(self, mock_get_redis, mock_get_config):
+ """A JWT with the wrong nonce is passed."""
+ start_time = datetime.now(tz=UTC)
+ end_time = start_time + timedelta(seconds=7200)
+
+ secret = 'barsecret'
+ mock_get_config.return_value = {
+ 'JWT_SECRET': secret,
+ 'REDIS_HOST': 'redis',
+ 'REDIS_PORT': '1234',
+ 'REDIS_DATABASE': 4
+ }
+ mock_redis = mock.MagicMock()
+ mock_redis.get.return_value = jwt.encode({
+ 'session_id': 'ajx9043jjx00s',
+ 'nonce': '0039299290098',
+ 'start_time': start_time.isoformat(),
+ 'end_time': end_time.isoformat(),
+ 'user': {
+ 'user_id': '1235',
+ 'username': 'foouser',
+ 'email': 'foo@foo.com'
+ }
+ }, secret)
+ mock_get_redis.return_value = mock_redis
+
+ claims = {
+ 'user_id': '1234',
+ 'session_id': 'ajx9043jjx00s',
+ 'nonce': '0039299290099', # <- Doesn't match!
+ 'expires': end_time.isoformat(),
+ }
+ expired_token = jwt.encode(claims, secret)
+ with self.assertRaises(store.InvalidToken):
+ # loaded token is getting non Datetime end_time
+ store.SessionStore.current_session().load(expired_token)
+
+ @mock.patch(f'{store.__name__}.get_application_config')
+ @mock.patch(f'{store.__name__}.rediscluster.StrictRedisCluster')
+ def test_other_forged_token(self, mock_get_redis, mock_get_config):
+ """A JWT with the wrong user_id is passed."""
+ start_time = datetime.now(tz=UTC)
+ end_time = start_time + timedelta(seconds=7200)
+
+ secret = 'barsecret'
+ mock_get_config.return_value = {
+ 'JWT_SECRET': secret,
+ 'REDIS_HOST': 'redis',
+ 'REDIS_PORT': '1234',
+ 'REDIS_DATABASE': 4
+ }
+ mock_redis = mock.MagicMock()
+ mock_redis.get.return_value = jwt.encode({
+ 'session_id': 'ajx9043jjx00s',
+ 'nonce': '0039299290099',
+ 'start_time': start_time.isoformat(),
+ 'user': {
+ 'user_id': '1235',
+ 'username': 'foouser',
+ 'email': 'foo@foo.com'
+ }
+ }, secret)
+ mock_get_redis.return_value = mock_redis
+ claims = {
+ 'user_id': '1234', # <- Doesn't match!
+ 'session_id': 'ajx9043jjx00s',
+ 'nonce': '0039299290099',
+ 'expires': end_time.isoformat(),
+ }
+ expired_token = jwt.encode(claims, secret)
+ with self.assertRaises(store.InvalidToken):
+ store.SessionStore.current_session().load(expired_token)
+
+ @mock.patch(f'{store.__name__}.get_application_config')
+ @mock.patch(f'{store.__name__}.rediscluster.StrictRedisCluster')
+ def test_empty_session(self, mock_get_redis, mock_get_config):
+ """Session has been removed, or may never have existed."""
+ start_time = datetime.now(tz=UTC)
+ end_time = start_time + timedelta(seconds=7200)
+
+ secret = 'barsecret'
+ mock_get_config.return_value = {
+ 'JWT_SECRET': secret,
+ 'REDIS_HOST': 'redis',
+ 'REDIS_PORT': '1234',
+ 'REDIS_DATABASE': 4
+ }
+ mock_redis = mock.MagicMock()
+ mock_redis.get.return_value = '' # <- Empty record!
+ mock_get_redis.return_value = mock_redis
+
+ claims = {
+ 'user_id': '1234',
+ 'session_id': 'ajx9043jjx00s',
+ 'nonce': '0039299290099',
+ 'expires': end_time.isoformat(),
+ }
+ expired_token = jwt.encode(claims, secret)
+ with self.assertRaises(store.UnknownSession):
+ store.SessionStore.current_session().load(expired_token)
+
+ @mock.patch(f'{store.__name__}.get_application_config')
+ @mock.patch(f'{store.__name__}.rediscluster.StrictRedisCluster')
+ def test_valid_token(self, mock_get_redis, mock_get_config):
+ """A valid token is passed."""
+ start_time = datetime.now(tz=UTC)
+ end_time = start_time + timedelta(seconds=7200)
+
+ secret = 'barsecret'
+ mock_get_config.return_value = {
+ 'JWT_SECRET': secret,
+ 'REDIS_HOST': 'redis',
+ 'REDIS_PORT': '1234',
+ 'REDIS_DATABASE': 4
+ }
+ mock_redis = mock.MagicMock()
+ mock_redis.get.return_value = jwt.encode({
+ 'session_id': 'ajx9043jjx00s',
+ 'start_time': datetime.now(tz=UTC).isoformat(),
+ 'nonce': '0039299290098',
+ 'user': {
+ 'user_id': '1234',
+ 'username': 'foouser',
+ 'email': 'foo@foo.com'
+ }
+ }, secret)
+ mock_get_redis.return_value = mock_redis
+
+ claims = {
+ 'user_id': '1234',
+ 'session_id': 'ajx9043jjx00s',
+ 'nonce': '0039299290098',
+ 'expires': end_time.isoformat(),
+ }
+ valid_token = jwt.encode(claims, secret)
+
+ session = store.SessionStore.current_session().load(valid_token)
+ self.assertIsInstance(session, domain.Session, "Returns a session")
diff --git a/arxiv/auth/auth/tests/__init__.py b/arxiv/auth/auth/tests/__init__.py
new file mode 100644
index 00000000..ed72440d
--- /dev/null
+++ b/arxiv/auth/auth/tests/__init__.py
@@ -0,0 +1 @@
+"""Tests for :mod:`arxiv.users.auth`."""
diff --git a/arxiv/auth/auth/tests/test_decorators.py b/arxiv/auth/auth/tests/test_decorators.py
new file mode 100644
index 00000000..584860e5
--- /dev/null
+++ b/arxiv/auth/auth/tests/test_decorators.py
@@ -0,0 +1,240 @@
+"""Tests for :mod:`arxiv.users.auth.decorators`."""
+
+import pytest
+from datetime import datetime
+from pytz import timezone, UTC
+
+from flask import request, current_app
+from werkzeug.exceptions import Unauthorized, Forbidden
+
+from .. import scopes, decorators
+from ... import domain
+
+EASTERN = timezone('US/Eastern')
+""" @mock.patch(f'{decorators.__name__}.request')"""
+
+def test_no_session(mocker, request_context):
+ """No session is present on the request."""
+ with request_context:
+ mock_req = mocker.patch(f'{decorators.__name__}.request')
+ mock_req.auth = None
+
+ assert not hasattr(request, 'called')
+
+ @decorators.scoped(scopes.CREATE_SUBMISSION)
+ def _protected():
+ request.called = True
+
+ with pytest.raises(Unauthorized):
+ _protected()
+
+ assert not hasattr(request, 'called'), "The protected function should not have its body called"
+
+
+def test_scope_is_missing(mocker, request_context):
+ """Session does not have required scope."""
+ with request_context:
+ mock_req = mocker.patch(f'{decorators.__name__}.request')
+ mock_req.auth = domain.Session(
+ session_id='fooid',
+ start_time=datetime.now(tz=UTC),
+ user=domain.User(
+ user_id='235678',
+ email='foo@foo.com',
+ username='foouser'
+ ),
+ authorizations=domain.Authorizations(
+ scopes=[scopes.VIEW_SUBMISSION]
+ )
+ )
+
+ assert not hasattr(request, 'called')
+
+ @decorators.scoped(scopes.CREATE_SUBMISSION)
+ def protected():
+ """A protected function."""
+ request.called = True
+
+ with pytest.raises(Forbidden):
+ protected()
+
+ assert not hasattr(request, 'called'), "The protected function should not have its body called"
+
+def test_scope_is_present(mocker, request_context):
+ """Session has required scope."""
+ with request_context:
+ request.auth = domain.Session(
+ session_id='fooid',
+ start_time=datetime.now(tz=UTC),
+ user=domain.User(
+ user_id='235678',
+ email='foo@foo.com',
+ username='foouser'
+ ),
+ authorizations=domain.Authorizations(
+ scopes=[scopes.VIEW_SUBMISSION, scopes.CREATE_SUBMISSION]
+ )
+ )
+
+ assert not hasattr(request, 'called')
+
+ @decorators.scoped(scopes.CREATE_SUBMISSION)
+ def protected():
+ """A protected function."""
+ print("HERE IN PROTECTED")
+ request.called = True
+
+ protected()
+ assert request.called
+
+
+def test_user_and_client_are_missing(mocker, request_context):
+ """Session does not user nor client information."""
+ with request_context:
+ mock_req = mocker.patch(f'{decorators.__name__}.request')
+ mock_req.auth = domain.Session(
+ session_id='fooid',
+ start_time=datetime.now(tz=UTC),
+ authorizations=domain.Authorizations(
+ scopes=[scopes.CREATE_SUBMISSION]
+ )
+ )
+ assert not hasattr(request, 'called')
+ @decorators.scoped(scopes.CREATE_SUBMISSION)
+ def protected():
+ """A protected function."""
+ request.called = True
+
+ with pytest.raises(Unauthorized):
+ protected()
+
+ assert not hasattr(request, 'called'), "The protected function should not have its body called"
+
+def test_authorizer_returns_false(mocker, request_context):
+ """Session has required scope, but authorizer func returns false."""
+ with request_context:
+ mock_req = mocker.patch(f'{decorators.__name__}.request')
+ mock_req.auth = domain.Session(
+ session_id='fooid',
+ start_time=datetime.now(tz=UTC),
+ user=domain.User(
+ user_id='235678',
+ email='foo@foo.com',
+ username='foouser'
+ ),
+ authorizations=domain.Authorizations(
+ scopes=[scopes.CREATE_SUBMISSION]
+ )
+ )
+ assert not hasattr(request, 'called')
+ assert not hasattr(request, 'authorizer_called')
+
+ def return_false(session: domain.Session) -> bool:
+ request.authorizer_called = True
+ return False
+
+ @decorators.scoped(scopes.CREATE_SUBMISSION, authorizer=return_false)
+ def protected():
+ """A protected function."""
+ request.called = True
+
+ with pytest.raises(Forbidden):
+ protected()
+
+ assert not hasattr(request, 'called')
+ assert request.authorizer_called
+
+
+def test_authorizer_returns_true(mocker, request_context):
+ """Session has required scope, authorizer func returns true."""
+ with request_context:
+ mock_req = mocker.patch(f'{decorators.__name__}.request')
+ mock_req.auth = domain.Session(
+ session_id='fooid',
+ start_time=datetime.now(tz=UTC),
+ user=domain.User(
+ user_id='235678',
+ email='foo@foo.com',
+ username='foouser'
+ ),
+ authorizations=domain.Authorizations(
+ scopes=[scopes.CREATE_SUBMISSION]
+ )
+ )
+ assert not hasattr(request, 'called')
+ assert not hasattr(request, 'authorizer_called')
+
+ def return_true(session: domain.Session) -> bool:
+ request.authorizer_called = True
+ return True
+
+ @decorators.scoped(scopes.CREATE_SUBMISSION, authorizer=return_true)
+ def protected():
+ """A protected function."""
+ request.called = True
+
+ protected()
+
+ assert request.called
+ assert request.authorizer_called
+
+def test_session_has_global(mocker, request_context):
+ """Session has global scope, and authorizer func returns false."""
+ with request_context:
+ mock_req = mocker.patch(f'{decorators.__name__}.request')
+ mock_req.auth = domain.Session(
+ session_id='fooid',
+ start_time=datetime.now(tz=UTC),
+ user=domain.User(
+ user_id='235678',
+ email='foo@foo.com',
+ username='foouser'
+ ),
+ authorizations=domain.Authorizations(
+ scopes=[domain.Scope(scopes.CREATE_SUBMISSION).as_global()]
+ )
+ )
+
+ def return_false(session: domain.Session) -> bool:
+ return False
+
+ @decorators.scoped(scopes.CREATE_SUBMISSION, authorizer=return_false)
+ def protected():
+ """A protected function."""
+ request.called = True
+
+ protected()
+ assert request.called
+
+
+def test_session_has_resource_scope(mocker, request_context):
+ """Session has resource scope, and authorizer func returns false."""
+ with request_context:
+ mock_req = mocker.patch(f'{decorators.__name__}.request')
+ mock_req.auth = domain.Session(
+ session_id='fooid',
+ start_time=datetime.now(tz=UTC),
+ user=domain.User(
+ user_id='235678',
+ email='foo@foo.com',
+ username='foouser'
+ ),
+ authorizations=domain.Authorizations(
+ scopes=[domain.Scope(scopes.EDIT_SUBMISSION).for_resource('1')]
+ )
+ )
+
+ def return_false(session: domain.Session) -> bool:
+ return False
+
+ def get_resource(*args, **kwargs) -> bool:
+ return '1'
+
+ @decorators.scoped(scopes.EDIT_SUBMISSION, resource=get_resource,
+ authorizer=return_false)
+ def protected():
+ """A protected function."""
+ request.called = True
+
+ protected()
+ assert request.called
diff --git a/arxiv/auth/auth/tests/test_extension.py b/arxiv/auth/auth/tests/test_extension.py
new file mode 100644
index 00000000..32668382
--- /dev/null
+++ b/arxiv/auth/auth/tests/test_extension.py
@@ -0,0 +1,106 @@
+"""Tests for :class:`arxiv.users.auth.Auth`."""
+from logging import DEBUG
+import pytest
+
+from datetime import datetime
+from pytz import timezone, UTC
+from ... import auth, domain
+
+EASTERN = timezone('US/Eastern')
+
+@pytest.fixture
+def app_with_cookie(app):
+ app.config['CLASSIC_COOKIE_NAME'] = 'foo_cookie'
+ return app
+
+def test_no_session_legacy_available(mocker, app_with_cookie):
+ """No session is present on the request, but database is present."""
+ inst = app_with_cookie.config['arxiv_auth.Auth']
+ auth.logger.setLevel(DEBUG)
+ with app_with_cookie.test_request_context():
+ mock_legacy = mocker.patch(f'{auth.__name__}.legacy')
+ mock_request = mocker.patch(f'{auth.__name__}.request')
+ mock_request.environ = {'auth': None,
+ 'HTTP_COOKIE': 'foo_cookie=sessioncookie123'}
+
+ mock_legacy.is_configured.return_value = True
+ mock_legacy.sessions.load.return_value = None
+
+ inst.load_session()
+ assert mock_request.auth is None
+
+ assert mock_legacy.sessions.load.call_count == 1, "An attempt is made to load a legacy session"
+
+def test_legacy_is_valid(mocker, app_with_cookie):
+ """A valid legacy session is available."""
+ inst = app_with_cookie.config['arxiv_auth.Auth']
+ with app_with_cookie.test_request_context():
+ mock_legacy = mocker.patch(f'{auth.__name__}.legacy')
+ mock_request = mocker.patch(f'{auth.__name__}.request')
+ mock_request.environ = {'auth': None,
+ 'HTTP_COOKIE': 'foo_cookie=sessioncookie123'}
+
+ mock_request.auth = None
+ mock_legacy.is_configured.return_value = True
+ session = domain.Session(
+ session_id='fooid',
+ start_time=datetime.now(tz=UTC),
+ user=domain.User(
+ user_id='235678',
+ email='foo@foo.com',
+ username='foouser'
+ ),
+ authorizations=domain.Authorizations(
+ scopes=[auth.scopes.VIEW_SUBMISSION]
+ )
+ )
+ mock_legacy.sessions.load.return_value = session
+
+ inst.load_session()
+ assert mock_request.auth == session, "Session is attached to the request at auth"
+
+def test_auth_session_rename(mocker, app_with_cookie):
+ """
+ The auth session is accessed via ``request.auth``.
+
+ Per ARXIVNG-1920 using ``request.auth`` is deprecated.
+ """
+ inst = app_with_cookie.config['arxiv_auth.Auth']
+ with app_with_cookie.test_request_context():
+ mock_legacy = mocker.patch(f'{auth.__name__}.legacy')
+ mock_request = mocker.patch(f'{auth.__name__}.request')
+ mock_request.environ = {'auth': None,
+ 'HTTP_COOKIE': 'foo_cookie=sessioncookie123'}
+
+ mock_request.auth = None
+ mock_legacy.is_configured.return_value = True
+
+ authex = domain.Session(
+ session_id='fooid',
+ start_time=datetime.now(tz=UTC),
+ user=domain.User(
+ user_id='235678',
+ email='foo@foo.com',
+ username='foouser'
+ ),
+ authorizations=domain.Authorizations(
+ scopes=[auth.scopes.VIEW_SUBMISSION]
+ )
+ )
+ mock_legacy.sessions.load.return_value = authex
+
+ inst.load_session()
+ assert mock_request.auth == authex, "Auth is attached to the request"
+
+
+
+def test_middleware_exception(mocker, app_with_cookie):
+ """Middleware has passed an exception."""
+ inst = app_with_cookie.config['arxiv_auth.Auth']
+ with app_with_cookie.test_request_context():
+ mock_request = mocker.patch(f'{auth.__name__}.request')
+ mock_request.environ = {'auth': RuntimeError('Nope!')}
+
+
+ with pytest.raises(RuntimeError):
+ inst.load_session()
diff --git a/arxiv/auth/auth/tests/test_middleware.py b/arxiv/auth/auth/tests/test_middleware.py
new file mode 100644
index 00000000..0da509fb
--- /dev/null
+++ b/arxiv/auth/auth/tests/test_middleware.py
@@ -0,0 +1,103 @@
+"""Test :mod:`arxiv.users.auth.middleware`."""
+
+import os
+from unittest import TestCase, mock
+from datetime import datetime
+from pytz import timezone, UTC
+import json
+
+from flask import Flask, Blueprint
+from flask import request, current_app
+
+from arxiv.base.middleware import wrap
+from arxiv import status
+
+from ..middleware import AuthMiddleware
+from .. import tokens, scopes
+from ... import domain
+
+EASTERN = timezone('US/Eastern')
+
+
+blueprint = Blueprint('fooprint', __name__, url_prefix='')
+
+
+@blueprint.route('/public', methods=['GET'])
+def public(): # type: ignore
+ """Return the request auth as a JSON document, or raise exceptions."""
+ data = request.environ.get('auth')
+ if data:
+ if isinstance(data, Exception):
+ raise data
+ return domain.Session.parse_obj(data).json_safe_dict()
+ return json.dumps({})
+
+
+class TestAuthMiddleware(TestCase):
+ """Test :class:`.AuthMiddleware` on a Flask app."""
+
+ def setUp(self):
+ """Instantiate an app and attach middlware."""
+ self.secret = 'foosecret'
+ self.app = Flask('foo')
+ os.environ['JWT_SECRET'] = self.secret
+ self.app.register_blueprint(blueprint)
+ wrap(self.app, [AuthMiddleware])
+ self.client = self.app.test_client()
+
+ def test_no_token(self):
+ """No token is passed in the request."""
+ response = self.client.get('/public')
+ data = json.loads(response.data)
+ self.assertEqual(data, {}, "No session data is set")
+
+ def test_user_token(self):
+ """A valid user token is passed in the request."""
+ session = domain.Session(
+ session_id='foo1234',
+ start_time=datetime.now(tz=UTC),
+ user=domain.User(
+ user_id='43219',
+ email='foo@foo.com',
+ username='foouser',
+ name=domain.UserFullName(forename='Foo', surname='User')
+ ),
+ # client: Optional[Client] = None
+ # end_time: Optional[datetime] = None
+ authorizations=domain.Authorizations(
+ scopes=scopes.GENERAL_USER
+ ),
+ ip_address='10.10.10.10',
+ remote_host='foo-host.something.com',
+ nonce='asdfkjl3kmalkml;xml;mla'
+ )
+ token = tokens.encode(session, self.secret)
+ response = self.client.get('/public', headers={'Authorization': token})
+ data = json.loads(response.data)
+ self.assertEqual(data, session.json_safe_dict(),
+ "Session data are added to the request")
+
+ def test_forged_user_token(self):
+ """A forged user token is passed in the request."""
+ session = domain.Session(
+ session_id='foo1234',
+ start_time=datetime.now(tz=UTC),
+ user=domain.User(
+ user_id='43219',
+ email='foo@foo.com',
+ username='foouser',
+ name=domain.UserFullName(forename='Foo', surname='User')
+ ),
+ # client: Optional[Client] = None
+ # end_time: Optional[datetime] = None
+ authorizations=domain.Authorizations(
+ scopes=scopes.GENERAL_USER
+ ),
+ ip_address='10.10.10.10',
+ remote_host='foo-host.something.com',
+ nonce='asdfkjl3kmalkml;xml;mla'
+ )
+ token = tokens.encode(session, 'notthesecret')
+ response = self.client.get('/public', headers={'Authorization': token})
+ self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED,
+ "A 401 exception is passed by the middleware")
diff --git a/arxiv/auth/auth/tests/test_tokens.py b/arxiv/auth/auth/tests/test_tokens.py
new file mode 100644
index 00000000..2524a946
--- /dev/null
+++ b/arxiv/auth/auth/tests/test_tokens.py
@@ -0,0 +1,65 @@
+"""Tests for :mod:`arxiv.users.auth.tokens`."""
+
+from unittest import TestCase
+from datetime import datetime
+
+from arxiv.taxonomy.definitions import CATEGORIES
+from .. import tokens
+from ... import domain
+from ...auth import scopes
+
+
+class TestEncodeDecode(TestCase):
+ """Tests for :func:`tokens.encode` and :func:`tokens.decode`."""
+
+ def test_encode_session(self):
+ """Encode a typical user session."""
+ session = domain.Session(
+ session_id='asdf1234',
+ start_time=datetime.now(), end_time=datetime.now(),
+ user=domain.User(
+ user_id='12345',
+ email='foo@bar.com',
+ username='emanresu',
+ name=domain.UserFullName(forename='First', surname='Last', suffix='Lastest'),
+ profile=domain.UserProfile(
+ affiliation='FSU',
+ rank=3,
+ country='us',
+ default_category=CATEGORIES['astro-ph.CO'],
+ submission_groups=['grp_physics']
+ )
+ ),
+ authorizations=domain.Authorizations(scopes=[scopes.VIEW_SUBMISSION, scopes.CREATE_SUBMISSION])
+ )
+ secret = 'foosecret'
+ token = tokens.encode(session, secret)
+
+ data = tokens.decode(token, secret)
+ self.assertEqual(session, data)
+
+ def test_mismatched_secrets(self):
+ """Secret used to encode is not the same as the one used to decode."""
+ session = domain.Session(
+ session_id='asdf1234',
+ start_time=datetime.now(), end_time=datetime.now(),
+ user=domain.User(
+ user_id='12345',
+ email='foo@bar.com',
+ username='emanresu',
+ name=domain.UserFullName(forename='First', surname='Last', suffix='Lastest'),
+ profile=domain.UserProfile(
+ affiliation='FSU',
+ rank=3,
+ country='us',
+ default_category=CATEGORIES['astro-ph.CO'],
+ submission_groups=['grp_physics']
+ )
+ ),
+ authorizations=domain.Authorizations(scopes=[scopes.VIEW_SUBMISSION, scopes.CREATE_SUBMISSION])
+ )
+ secret = 'foosecret'
+ token = tokens.encode(session, secret)
+
+ with self.assertRaises(tokens.exceptions.InvalidToken):
+ tokens.decode(token, 'not the secret')
diff --git a/arxiv/auth/auth/tokens.py b/arxiv/auth/auth/tokens.py
new file mode 100644
index 00000000..ac2ccc5a
--- /dev/null
+++ b/arxiv/auth/auth/tokens.py
@@ -0,0 +1,49 @@
+"""
+Functions for working with authn/z tokens on user/client requests.
+
+Encrypted JSON Web Tokens can be used inside the arXiv system to securely
+convey authn/z information about each request. These tokens will usually be
+generated by the :mod:`authorizer` in response to an
+authorization subrequest from the web server, and contain information about
+the identity of the user and/or client as well as authorization information
+(e.g. :mod:`arxiv.users.auth.scopes`).
+
+It is essential that these JWTs are encrypted and decrypted precisely the same
+way in all arXiv services, so we include these routines here for convenience.
+
+"""
+
+import jwt
+from . import exceptions
+from .. import domain
+
+
+def encode(session: domain.Session, secret: str) -> str:
+ """
+ Encode session information as an encrypted JWT.
+
+ Parameters
+ ----------
+ session : :class:`.domain.Session`
+ User or client session data, including authorization information.
+ secret : str
+ A secret key used to encrypt the token. This secret is required to
+ decode the token later on (e.g. in the application handling the
+ request).
+
+ Returns
+ -------
+ str
+ An encrypted JWT.
+
+ """
+ return jwt.encode(session.json_safe_dict(), secret)
+
+
+def decode(token: str, secret: str) -> domain.Session:
+ """Decode an auth token to access session information."""
+ try:
+ data = dict(jwt.decode(token, secret, algorithms=['HS256']))
+ except jwt.exceptions.DecodeError as e:
+ raise exceptions.InvalidToken('Not a valid token') from e
+ return domain.session_from_dict(data)
diff --git a/arxiv/auth/conftest.py b/arxiv/auth/conftest.py
new file mode 100644
index 00000000..9d7b2483
--- /dev/null
+++ b/arxiv/auth/conftest.py
@@ -0,0 +1,44 @@
+import pytest
+import os
+# os.environ['CLASSIC_DB_URI'] = 'sqlite:///:memory:'
+
+from flask import Flask
+
+from ..base import Base
+from ..config import Settings
+from ..base.middleware import wrap
+from ..db.models import configure_db
+
+from ..auth.auth import Auth
+from ..auth.auth.middleware import AuthMiddleware
+
+
+@pytest.fixture()
+def app():
+
+ app = Flask('test_auth_app')
+ app.config['CLASSIC_DATABASE_URI'] = 'sqlite:///test.db'
+ app.config['CLASSIC_SESSION_HASH'] = f'fake set in {__file__}'
+ app.config['SESSION_DURATION'] = f'fake set in {__file__}'
+ app.config['CLASSIC_COOKIE_NAME'] = f'fake set in {__file__}'
+ app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///test.db'
+ app.config['AUTH_UPDATED_SESSION_REF'] = True
+ settings = Settings (
+ CLASSIC_DB_URI = 'sqlite:///test.db',
+ LATEXML_DB_URI = None
+ )
+ engine, _ = configure_db(settings)
+ app.config['DB_ENGINE'] = engine
+
+ Base(app)
+
+ Auth(app)
+ wrap(app, [AuthMiddleware])
+
+
+ return app
+
+
+@pytest.fixture()
+def request_context(app):
+ yield app.test_request_context()
diff --git a/arxiv/auth/domain.py b/arxiv/auth/domain.py
new file mode 100644
index 00000000..a655a7d2
--- /dev/null
+++ b/arxiv/auth/domain.py
@@ -0,0 +1,357 @@
+"""Defines user concepts for use in arXiv services."""
+
+
+from typing import Any, Optional, List, NamedTuple
+from collections.abc import Iterable
+
+from datetime import datetime
+from pytz import timezone, UTC
+
+from pydantic import BaseModel, ConfigDict, ValidationError, validator
+from arxiv.taxonomy.category import Category
+from arxiv.taxonomy import definitions
+from arxiv.db.models import Demographic
+
+EASTERN = timezone('US/Eastern')
+
+STAFF = ('1', 'Staff')
+PROFESSOR = ('2', 'Professor')
+POST_DOC = ('3', 'Post doc')
+GRAD_STUDENT = ('4', 'Grad student')
+OTHER = ('5', 'Other')
+RANKS = [STAFF, PROFESSOR, POST_DOC, GRAD_STUDENT, OTHER]
+
+
+def _check_category(data: Any) -> Category:
+ if isinstance(data, Category):
+ return data
+ if not isinstance(data, str):
+ raise ValidationError(f"object of type {type(data)} cannnot be used as a Category", Category)
+ cat = Category(data)
+ cat.name # possible rasie value error on non-existance
+ return cat
+
+
+class UserProfile(BaseModel):
+ """User profile data."""
+
+ model_config = ConfigDict(arbitrary_types_allowed=True)
+
+ affiliation: str
+ """Institutional affiliation."""
+
+ country: str
+ """Should be an ISO 3166-1 alpha-2 country code."""
+
+ rank: int
+ """Academic rank. Must be one of :const:`.RANKS`."""
+
+ submission_groups: List[str]
+ """
+ Groups to which the user prefers to submit.
+
+ Items should be one of :ref:`arxiv.taxonomy.definitions.GROUPS`.
+ """
+
+ default_category: Optional[Category]
+ """
+ Default submission category.
+
+ Should be one of :ref:`arxiv.taxonomy.CATEGORIES`.
+ """
+
+ # @validator('default_category')
+ # @classmethod
+ # def check_category(cls, data: Any) -> Category:
+ # """Checks if `data` is a category."""
+ # return _check_category(data)
+
+ homepage_url: str = ''
+ """User's homepage or external profile URL."""
+
+ remember_me: bool = True
+ """Indicates whether the user prefers permanent session cookies."""
+
+ @property
+ def rank_display(self) -> str:
+ """The display name of the user's rank."""
+ _rank: str = dict(RANKS)[str(self.rank)]
+ return _rank
+
+ @property
+ def default_archive(self) -> Optional[str]:
+ """The archive of the default category."""
+ return self.default_category.in_archive if self.default_category else None
+
+ @property
+ def default_subject(self) -> Optional[str]:
+ """The subject of the default category."""
+ if self.default_category is not None:
+ subject: str
+ if '.' in self.default_category.id:
+ subject = self.default_category.id.split('.', 1)[1]
+ else:
+ subject = self.default_category.id
+ return subject
+ return None
+
+ @property
+ def groups_display(self) -> str:
+ """Display-ready representation of active groups for this profile."""
+ return ", ".join([
+ definitions.GROUPS[group]['name']
+ for group in self.submission_groups
+ ])
+
+ @staticmethod
+ def from_orm (model: Demographic) -> 'UserProfile':
+ if model.subject_class:
+ category = definitions.CATEGORIES[f'{model.archive}.{model.subject_class}']
+ elif model.archive:
+ category = definitions.CATEGORIES[f'{model.archive}']
+ else:
+ category = None
+
+ return UserProfile(
+ affiliation=model.affiliation,
+ country=model.country,
+ rank=model.type,
+ submission_groups=model.groups,
+ default_category=category,
+ homepage_url=model.url,
+ )
+
+
+class Scope(str):
+ """Represents an authorization policy."""
+
+ def __new__(cls, domain, action=None, resource=None):
+ """Handle __new__."""
+ return str.__new__(cls, cls.from_parts(domain, action, resource))
+
+ @property
+ def domain(self) -> str:
+ """
+ The domain to which the scope applies.
+
+ This will generally refer to a specific service.
+ """
+ return self.parts[0]
+
+ @property
+ def action(self) -> str:
+ """An action within ``domain`."""
+ return self.parts[1]
+
+ @property
+ def resource(self) -> Optional[str]:
+ """The specific resource to which this policy applies."""
+ return self.parts[2]
+
+ @property
+ def parts(self) -> str:
+ """Get parts of the Scope."""
+ parts = self.split(':')
+ parts = parts + [None] * (3 - len(parts))
+ return parts
+
+ def for_resource(self, resource_id: str) -> 'Scope':
+ """Create a copy of this scope with a specific resource."""
+ return Scope(domain=self.domain, action=self.action, resource=resource_id)
+
+ def as_global(self) -> 'Scope':
+ """Create a copy of this scope with a global resource."""
+ return self.for_resource('*')
+
+ @classmethod
+ def from_parts(cls, domain, action=None, resource=None):
+ """Create a scope string from parts."""
+ return ":".join([o for o in [domain,action,resource] if o is not None])
+
+ @classmethod
+ def to_parts(cls, scopestr):
+ """Split a scop string to parts."""
+ parts = scopestr.split(':')
+ return parts + [None] * (3 - len(parts))
+
+ @classmethod
+ def from_str(cls, scopestr:str) -> "Scope":
+ """Make a Scope from a string."""
+ parts = cls.to_parts(scopestr)
+ return cls(domain=parts[0], action=parts[1], resource=parts[2])
+
+
+class Authorizations(BaseModel):
+ """Authorization information, e.g. associated with a :class:`.Session`."""
+
+ model_config = ConfigDict(arbitrary_types_allowed=True)
+
+ classic: int = 0
+ """Capability code associated with a user's session."""
+
+ scopes: List[str] = []
+ """Authorized :class:`.scope`s. See also :mod:`arxiv.users.auth.scopes`."""
+
+
+ @classmethod
+ def before_init(cls, data: dict) -> None:
+ if 'scopes' in data:
+ if type(data['scopes']) is str:
+ data['scopes'] = [
+ Scope(*scope.split(':')) for scope
+ in data['scopes'].split()
+ ]
+ elif type(data['scopes']) is list:
+ data['scopes'] = [
+ Scope(**scope) if type(scope) is dict
+ else Scope(*scope.split(':'))
+ for scope in data['scopes']
+ ]
+
+
+class UserFullName(BaseModel):
+ """Represents a user's full name."""
+
+ forename: str
+ """First name or given name."""
+
+ surname: str
+ """Last name or family name."""
+
+ suffix: Optional[str] = ''
+ """Any title or qualifier used as a suffix/postfix."""
+
+
+class User(BaseModel):
+ """Represents an arXiv user and their authorizations."""
+
+ username: str
+ """Slug-like username."""
+
+ email: str
+ """The user's primary e-mail address."""
+
+ user_id: Optional[str] = None
+ """Unique identifier for the user. If ``None``, the user does not exist."""
+
+ name: Optional[UserFullName] = None
+ """The user's full name (if available)."""
+
+ profile: Optional[UserProfile] = None
+ """The user's account profile (if available)."""
+
+ verified: bool = False
+ """Whether or not the users' e-mail address has been verified."""
+
+ # def asdict(self) -> dict:
+ # """Generate a dict representation of this :class:`.User`."""
+ # data = super(User, self)._asdict()
+ # if self.name is not None:
+ # data['name'] = self.name._asdict()
+ # if self.profile is not None:
+ # data['profile'] = self.profile._asdict()
+ # return data
+
+ # TODO: consider whether this information is relevant beyond the
+ # ``arxiv.users.legacy.authenticate`` module.
+ #
+ # approved: bool = True
+ # """Whether or not the users' account is approved."""
+ #
+ # banned: bool = False
+ # """Whether or not the user has been banned."""
+ #
+ # deleted: bool = False
+ # """Whether or not the user has been deleted."""
+
+
+class Client(BaseModel):
+ """API client."""
+
+ owner_id: str
+ """The arXiv user responsible for the client."""
+
+ client_id: Optional[str] = None
+ """Unique identifier for a :class:`.Client`."""
+
+ name: Optional[str] = None
+ """Human-friendly name of the API client."""
+
+ url: Optional[str] = None
+ """Homepage or other resource describing the API client."""
+
+ description: Optional[str] = None
+ """Brief description of the API client."""
+
+ redirect_uri: Optional[str] = None
+ """The authorized redirect URI for the client."""
+
+
+class Session(BaseModel):
+ """Represents an authenticated session in the arXiv system."""
+
+ model_config = ConfigDict(arbitrary_types_allowed=True)
+
+ session_id: str
+ """Unique identifier for the session."""
+
+ start_time: datetime
+ """The ISO-8601 datetime when the session was created."""
+
+ user: Optional[User] = None
+ """The user for which the session was created."""
+
+ client: Optional[Client] = None
+ """The client for which the session was created."""
+
+ end_time: Optional[datetime] = None
+ """The ISO-8601 datetime when the session ended."""
+
+ authorizations: Optional[Authorizations] = None
+ """Authorizations for the current session."""
+
+ ip_address: Optional[str] = None
+ """The IP address of the client for which the session was created."""
+
+ remote_host: Optional[str] = None
+ """The hostname of the client for which the session was created."""
+
+ nonce: Optional[str] = None
+ """A pseudo-random nonce generated when the session was created."""
+
+ def is_authorized(self, scope: Scope, resource: str) -> bool:
+ """Check whether this session is authorized for a specific resource."""
+ return (self.authorizations is not None and (
+ scope.as_global() in self.authorizations.scopes
+ or scope.for_resource(resource) in self.authorizations.scopes))
+
+ @property
+ def expired(self) -> bool:
+ """Expired if the current time is later than :attr:`.end_time`."""
+ return bool(self.end_time is not None
+ and datetime.now(tz=UTC) >= self.end_time)
+
+ @property
+ def expires(self) -> Optional[int]:
+ """
+ Number of seconds until the session expires.
+
+ If the session is already expired, returns 0.
+ """
+ if self.end_time is None:
+ return None
+ duration = (self.end_time - datetime.now(tz=UTC)).total_seconds()
+ return int(max(duration, 0))
+
+ def json_safe_dict(self) -> dict:
+ """Creates a json dict with the datetimes converted to ISO datetime strs."""
+ out = self.dict()
+ if self.start_time:
+ out['start_time'] = self.start_time.isoformat()
+ if self.end_time:
+ out['end_time'] = self.end_time.isoformat()
+ return out
+
+def session_from_dict(data: dict) -> Session:
+ """Create a Session from a dict."""
+ return Session.parse_obj(data)
diff --git a/arxiv/auth/helpers.py b/arxiv/auth/helpers.py
new file mode 100644
index 00000000..47f75576
--- /dev/null
+++ b/arxiv/auth/helpers.py
@@ -0,0 +1,49 @@
+"""Helpers and utilities for :mod:`arxiv.users`."""
+
+import os
+from typing import List
+from pytz import timezone, UTC
+import uuid
+from datetime import timedelta, datetime
+from . import auth, domain
+from arxiv.base.globals import get_application_config
+from arxiv.taxonomy.definitions import CATEGORIES
+
+
+def generate_token(user_id: str, email: str, username: str,
+ first_name: str = 'Jane', last_name: str = 'Doe',
+ suffix_name: str = 'IV',
+ affiliation: str = 'Cornell University',
+ rank: int = 3,
+ country: str = 'us',
+ default_category: domain.Category = CATEGORIES['astro-ph.GA'],
+ submission_groups: str = 'grp_physics',
+ scope: List[domain.Scope] = [],
+ verified: bool = False) -> str:
+ """Generate an auth token for dev/testing purposes."""
+ # Specify the validity period for the session.
+ start = datetime.now(tz=timezone('US/Eastern'))
+ end = start + timedelta(seconds=36000) # Make this as long as you want.
+
+ # Create a user
+ session = domain.Session(
+ session_id=str(uuid.uuid4()),
+ start_time=start, end_time=end,
+ user=domain.User(
+ user_id=user_id,
+ email=email,
+ username=username,
+ name=domain.UserFullName(forename=first_name, surname=last_name, suffix=suffix_name),
+ profile=domain.UserProfile(
+ affiliation=affiliation,
+ rank=int(rank),
+ country=country,
+ default_category=default_category,
+ submission_groups=submission_groups.split(',')
+ ),
+ verified=verified
+ ),
+ authorizations=domain.Authorizations(scopes=scope)
+ )
+ token = auth.tokens.encode(session, get_application_config()['JWT_SECRET'])
+ return token
diff --git a/arxiv/auth/legacy/__init__.py b/arxiv/auth/legacy/__init__.py
new file mode 100644
index 00000000..820be948
--- /dev/null
+++ b/arxiv/auth/legacy/__init__.py
@@ -0,0 +1,8 @@
+"""
+Integrations with the legacy arXiv database for users and sessions.
+
+This package provides integrations with legacy user and sessions data in the
+classic DB. These components were pulled out as a separate package because
+they are required by both the accounts service and the authn/z middlware,
+and maintaining them in both places would create too much duplication.
+"""
diff --git a/arxiv/auth/legacy/accounts.py b/arxiv/auth/legacy/accounts.py
new file mode 100644
index 00000000..0f418b31
--- /dev/null
+++ b/arxiv/auth/legacy/accounts.py
@@ -0,0 +1,304 @@
+"""Provide methods for working with user accounts."""
+
+from typing import Optional, Tuple, Any
+import logging
+
+from sqlalchemy.exc import OperationalError
+
+from .. import domain
+from . import util, exceptions
+from .passwords import hash_password
+from .exceptions import Unavailable
+from ...db import session
+from ...db.models import TapirUser, TapirUsersPassword, \
+ TapirNickname, Demographic
+
+
+logger = logging.getLogger(__name__)
+
+
+def does_username_exist(username: str) -> bool:
+ """
+ Determine whether a user with a particular username already exists.
+
+ Parameters
+ ----------
+ username : str
+
+ Returns
+ -------
+ bool
+
+ """
+ try:
+ data = session.query(TapirNickname) \
+ .filter(TapirNickname.nickname == username) \
+ .first()
+ except OperationalError as e:
+ raise Unavailable('Database is temporarily unavailable') from e
+ if data:
+ return True
+ return False
+
+
+def does_email_exist(email: str) -> bool:
+ """
+ Determine whether a user with a particular address already exists.
+
+ Parameters
+ ----------
+ email : str
+
+ Returns
+ -------
+ bool
+
+ """
+ try:
+ data = session.query(TapirUser).filter(TapirUser.email == email).first()
+ print (session.query(TapirUser).all())
+ except OperationalError as e:
+ raise Unavailable('Database is temporarily unavailable') from e
+ if data:
+ return True
+ return False
+
+
+def register(user: domain.User, password: str, ip: str,
+ remote_host: str) -> Tuple[domain.User, domain.Authorizations]:
+ """
+ Create a new user.
+
+ Parameters
+ ----------
+ user : :class:`.domain.User`
+ User data for the new account.
+ password : str
+ Password for the account.
+ ip : str
+ The IP address of the client requesting the registration.
+ remote_host : str
+ The remote hostname of the client requesting the registration.
+
+ Returns
+ -------
+ :class:`.domain.User`
+ Data about the created user.
+ :class:`.domain.Authorizations`
+ Privileges attached to the created user.
+
+ """
+ try:
+ db_user, db_nick, db_profile = _create(user, password, ip, remote_host)
+ session.commit()
+ except OperationalError as e:
+ raise Unavailable('Database is temporarily unavailable') from e
+ except Exception as e:
+ logger.debug(e)
+ raise exceptions.RegistrationFailed('Could not create user')# from e
+
+ user = domain.User(
+ user_id=str(db_user.user_id),
+ username=db_nick.nickname,
+ email=db_user.email,
+ name=domain.UserFullName(
+ forename=db_user.first_name,
+ surname=db_user.last_name,
+ suffix=db_user.suffix_name
+ ),
+ profile=domain.UserProfile.from_orm(db_profile) if db_profile is not None else None
+ )
+ auths = domain.Authorizations(
+ classic=util.compute_capabilities(db_user),
+ scopes=util.get_scopes(db_user),
+ )
+ return user, auths
+
+
+def get_user_by_id(user_id: str) -> domain.User:
+ """Load user data from the database."""
+ try:
+ db_user, db_nick, db_profile = _get_user_data(user_id)
+ except OperationalError as e:
+ raise Unavailable('Database is temporarily unavailable') from e
+ user = domain.User(
+ user_id=str(db_user.user_id),
+ username=db_nick.nickname,
+ email=db_user.email,
+ name=domain.UserFullName(
+ forename=db_user.first_name,
+ surname=db_user.last_name,
+ suffix=db_user.suffix_name
+ ),
+ profile=domain.UserProfile.from_orm(db_profile) if db_profile is not None else None
+ )
+ return user
+
+
+def update(user: domain.User) -> Tuple[domain.User, domain.Authorizations]:
+ """Update a user in the database."""
+ if user.user_id is None:
+ raise ValueError('User ID must be set')
+
+ db_user, db_nick, db_profile = _get_user_data(user.user_id)
+ # TODO: we probably want to think a bit more about changing usernames
+ # and e-mail addresses.
+ #
+ # _update_field_if_changed(db_nick, 'nickname', user.username)
+ # _update_field_if_changed(db_user, 'email', user.email)
+ if user.name is not None:
+ _update_field_if_changed(db_user, 'first_name', user.name.forename)
+ _update_field_if_changed(db_user, 'last_name', user.name.surname)
+ _update_field_if_changed(db_user, 'suffix_name', user.name.suffix)
+ if user.profile is not None:
+ if db_profile is not None:
+ def _has_group(group: str) -> int:
+ if user.profile is None:
+ return 0
+ return int(group in user.profile.submission_groups)
+
+ _update_field_if_changed(db_profile, 'affiliation',
+ user.profile.affiliation)
+ _update_field_if_changed(db_profile, 'country',
+ user.profile.country)
+ _update_field_if_changed(db_profile, 'type', user.profile.rank)
+ _update_field_if_changed(db_profile, 'url',
+ user.profile.homepage_url)
+ _update_field_if_changed(db_profile, 'archive',
+ user.profile.default_archive)
+ _update_field_if_changed(db_profile, 'subject_class',
+ user.profile.default_subject)
+ for grp, field in Demographic.GROUP_FLAGS:
+ _update_field_if_changed(db_profile, field,
+ _has_group(grp))
+ session.add(db_profile)
+ else:
+ db_profile = _create_profile(user, db_user)
+
+ session.add(db_nick)
+ session.add(db_user)
+ session.commit()
+
+ user = domain.User(
+ user_id=str(db_user.user_id),
+ username=db_nick.nickname,
+ email=db_user.email,
+ name=domain.UserFullName(
+ forename=db_user.first_name,
+ surname=db_user.last_name,
+ suffix=db_user.suffix_name
+ ),
+ profile=domain.UserProfile.from_orm(db_profile) if db_profile is not None else None
+ )
+ auths = domain.Authorizations(
+ classic=util.compute_capabilities(db_user),
+ scopes=util.get_scopes(db_user),
+ )
+ return user, auths
+
+
+def _update_field_if_changed(obj: Any, field: Any, update_with: Any) -> None:
+ if getattr(obj, field) != update_with:
+ setattr(obj, field, update_with)
+
+
+def _get_user_data(user_id: str) -> Tuple[TapirUser, TapirNickname, Demographic]:
+
+ try:
+ db_user, db_nick = session.query(TapirUser, TapirNickname) \
+ .filter(TapirUser.user_id == user_id) \
+ .filter(TapirUser.flag_approved == 1) \
+ .filter(TapirUser.flag_deleted == 0) \
+ .filter(TapirUser.flag_banned == 0) \
+ .filter(TapirNickname.flag_primary == 1) \
+ .filter(TapirNickname.flag_valid == 1) \
+ .filter(TapirNickname.user_id == TapirUser.user_id) \
+ .first()
+ except TypeError: # first() returns a single None if no match.
+ raise exceptions.NoSuchUser('User does not exist')
+ # Profile may not exist.
+ db_profile = session.query(Demographic) \
+ .filter(Demographic.user_id == user_id) \
+ .first()
+ if not db_user:
+ raise exceptions.NoSuchUser('User does not exist')
+ return db_user, db_nick, db_profile
+
+
+def _create_profile(user: domain.User, db_user: TapirUser) -> Demographic:
+ def _has_group(group: str) -> int:
+ if user.profile is None:
+ return 0
+ return int(group in user.profile.submission_groups)
+
+ db_profile = Demographic(
+ user=db_user,
+ country=user.profile.country if user.profile else None,
+ affiliation=user.profile.affiliation if user.profile else None,
+ url=user.profile.homepage_url if user.profile else None,
+ type=user.profile.rank if user.profile else None,
+ archive=user.profile.default_archive if user.profile else None,
+ subject_class=user.profile.default_subject if user.profile else None,
+ original_subject_classes='',
+ flag_group_physics=_has_group('grp_physics'),
+ flag_group_math=_has_group('grp_math'),
+ flag_group_cs=_has_group('grp_cs'),
+ flag_group_q_bio=_has_group('grp_q-bio'),
+ flag_group_q_fin=_has_group('grp_q-fin'),
+ flag_group_stat=_has_group('grp_stat'),
+ flag_group_eess=_has_group('grp_eess'),
+ flag_group_econ=_has_group('grp_econ'),
+ )
+ session.add(db_profile)
+ return db_profile
+
+
+def _create(user: domain.User, password: str, ip: str, remote_host: str) \
+ -> Tuple[TapirUser, TapirNickname, Optional[Demographic]]:
+ if not user.name.forename:
+ raise ValueError("Must have forename to create user")
+ if not user.name.surname:
+ raise ValueError("Must have surname to create user")
+ data = dict(
+ email=user.email,
+ policy_class=2,
+ joined_ip_num=ip,
+ joined_remote_host=remote_host,
+ joined_date=util.now(),
+ tracking_cookie='' # TODO: set this.
+ )
+ if user.name is not None:
+ data.update(dict(
+ first_name=user.name.forename,
+ last_name=user.name.surname,
+ suffix_name=user.name.suffix
+ ))
+ print (data)
+
+ # Main user entry.
+ db_user = TapirUser(**data)
+ session.add(db_user)
+ # Nickname is where we keep the username.
+ db_nick = TapirNickname(
+ user=db_user,
+ nickname=user.username,
+ flag_valid=1,
+ flag_primary=1
+ )
+ session.add(db_nick)
+
+ db_profile: Optional[Demographic]
+ if user.profile is not None:
+ db_profile = _create_profile(user, db_user)
+ else:
+ db_profile = None
+
+ db_pass = TapirUsersPassword(
+ user=db_user,
+ password_storage=2,
+ password_enc=hash_password(password)
+ )
+ session.add(db_pass)
+ from sqlalchemy import select
+ print(session.execute(select(TapirUser.email)).all())
+ return db_user, db_nick, db_profile
diff --git a/arxiv/auth/legacy/authenticate.py b/arxiv/auth/legacy/authenticate.py
new file mode 100644
index 00000000..8992e81f
--- /dev/null
+++ b/arxiv/auth/legacy/authenticate.py
@@ -0,0 +1,328 @@
+"""Provide an API for user authentication using the legacy database."""
+
+from typing import Optional, Tuple
+import logging
+
+from sqlalchemy.exc import OperationalError
+
+from . import util
+from .. import domain
+
+from . passwords import check_password, is_ascii
+from ...db import session
+from ...db.models import TapirUser, TapirUsersPassword, TapirPermanentToken, \
+ TapirNickname, Demographic
+from .exceptions import NoSuchUser, AuthenticationFailed, \
+ PasswordAuthenticationFailed, Unavailable
+
+logger = logging.getLogger(__name__)
+
+PassData = Tuple[TapirUser, TapirUsersPassword, TapirNickname, Demographic]
+
+
+def authenticate(username_or_email: Optional[str] = None,
+ password: Optional[str] = None, token: Optional[str] = None) \
+ -> Tuple[domain.User, domain.Authorizations]:
+ """
+ Validate username/password. If successful, retrieve user details.
+
+ Parameters
+ ----------
+ username_or_email : str
+ Users may log in with either their username or their email address.
+ password : str
+ Password (as entered). Danger, Will Robinson!
+ token : str
+ Alternatively, the user may provide a bearer token. This is currently
+ used to support "permanent" sessions, in which the token is used to
+ "automatically" log the user in (i.e. without entering credentials).
+
+ Returns
+ -------
+ :class:`domain.User`
+ :class:`domain.Authorizations`
+
+ Raises
+ ------
+ :class:`AuthenticationFailed`
+ Failed to authenticate user with provided credentials.
+ :class:`Unavailable`
+ Unable to connect to DB.
+ """
+ try:
+ if username_or_email and password:
+ passdata = _authenticate_password(username_or_email, password)
+ # The "tapir permanent token" is effectively a bearer token. If passed,
+ # a new session will be "automatically" created (from the user's
+ # perspective).
+ elif token:
+ db_token = _authenticate_token(token)
+ passdata = _get_user_by_user_id(db_token.user_id)
+ else:
+ logger.debug('Neither username/password nor token provided')
+ raise AuthenticationFailed('Username+password or token required')
+ except OperationalError as e:
+ # Note OperationalError can be a lot of different things and not just
+ # the DB being unavailable. So this message can be deceptive.
+ raise Unavailable('Database is temporarily unavailable') from e
+ except Exception as ex:
+ raise AuthenticationFailed() from ex
+
+ db_user, _, db_nick, db_profile = passdata
+ user = domain.User(
+ user_id=str(db_user.user_id),
+ username=db_nick.nickname,
+ email=db_user.email,
+ name=domain.UserFullName(
+ forename=db_user.first_name,
+ surname=db_user.last_name,
+ suffix=db_user.suffix_name
+ ),
+ profile=domain.UserProfile.from_orm(db_profile) if db_profile else None,
+ verified=bool(db_user.flag_email_verified)
+ )
+ auths = domain.Authorizations(
+ classic=util.compute_capabilities(db_user),
+ scopes=util.get_scopes(db_user),
+ )
+ return user, auths
+
+
+def _authenticate_token(token: str) -> TapirPermanentToken:
+ """
+ Authenticate using a permanent token.
+
+ Parameters
+ ----------
+ token : str
+
+ Returns
+ -------
+ :class:`.TapirUser`
+ :class:`.TapirPermanentToken`
+ :class:`.TapirNickname`
+ :class:`.Demographic`
+
+ Raises
+ ------
+ :class:`AuthenticationFailed`
+ Raised if the token is malformed, or there is no corresponding token
+ in the database.
+
+ """
+ try:
+ user_id, secret = token.split('-')
+ except ValueError as e:
+ raise AuthenticationFailed('Token is malformed') from e
+ try:
+ return _get_token(user_id, secret)
+ except NoSuchUser as e:
+ logger.debug('Not a valid permanent token')
+ raise AuthenticationFailed('Invalid token') from e
+
+
+def _authenticate_password(username_or_email: str, password: str) -> PassData:
+ """
+ Authenticate using username/email and password.
+
+ Parameters
+ ----------
+ username_or_email : str
+ Either the email address or username of the authenticating user.
+ password : str
+
+ Returns
+ -------
+ :class:`.TapirUser`
+ :class:`.TapirUsersPassword`
+ :class:`.TapirNickname`
+
+ Raises
+ ------
+ :class:`AuthenticationFailed`
+ Raised if the user does not exist or the password is incorrect.
+ :class:`RuntimeError`
+ Raised when other problems arise.
+
+ """
+ logger.debug(f'Authenticate with password, user: {username_or_email}')
+
+ if not password:
+ raise ValueError('Passed empty password')
+ if not isinstance(password, str):
+ raise ValueError('Passed non-str password: {type(password)}')
+ if not is_ascii(password):
+ raise ValueError('Password non-ascii password')
+
+ if not username_or_email:
+ raise ValueError('Passed empty username_or_email')
+ if not isinstance(password, str):
+ raise ValueError('Passed non-str username_or_email: {type(username_or_email)}')
+ if len(username_or_email) > 255:
+ raise ValueError(f'Passed username_or_email too long: len {len(username_or_email)}')
+ if not is_ascii(username_or_email):
+ raise ValueError('Passed non-ascii username_or_email')
+
+ if '@' in username_or_email:
+ passdata = _get_user_by_email(username_or_email)
+ else:
+ passdata = _get_user_by_username(username_or_email)
+
+ db_user, db_pass, db_nick, db_profile = passdata
+ logger.debug(f'Got user with user_id: {db_user.user_id}')
+ try:
+ if check_password(password, db_pass.password_enc):
+ return passdata
+ except PasswordAuthenticationFailed as e:
+ raise AuthenticationFailed('Invalid username or password') from e
+
+
+def _get_user_by_user_id(user_id: int) -> PassData:
+ tapir_user: TapirUser = session.query(TapirUser) \
+ .filter(TapirUser.user_id == int(user_id)) \
+ .filter(TapirUser.flag_approved == 1) \
+ .filter(TapirUser.flag_deleted == 0) \
+ .filter(TapirUser.flag_banned == 0) \
+ .first()
+ return _get_passdata(tapir_user)
+
+
+def _get_user_by_email(email: str) -> PassData:
+ if not email or '@' not in email:
+ raise ValueError("must be an email address")
+ tapir_user: TapirUser = session.query(TapirUser) \
+ .filter(TapirUser.email == email) \
+ .filter(TapirUser.flag_approved == 1) \
+ .filter(TapirUser.flag_deleted == 0) \
+ .filter(TapirUser.flag_banned == 0) \
+ .first()
+ return _get_passdata(tapir_user)
+
+
+def _get_user_by_username(username: str) -> PassData:
+ """Username is the tapir nickname."""
+ if not username or '@' in username:
+ raise ValueError("username must not contain a @")
+ tapir_nick = session.query(TapirNickname) \
+ .filter(TapirNickname.nickname == username) \
+ .filter(TapirNickname.flag_valid == 1) \
+ .first()
+ if not tapir_nick:
+ raise NoSuchUser('User lacks a nickname')
+
+ tapir_user = session.query(TapirUser) \
+ .filter(TapirUser.user_id == tapir_nick.user_id) \
+ .filter(TapirUser.flag_approved == 1) \
+ .filter(TapirUser.flag_deleted == 0) \
+ .filter(TapirUser.flag_banned == 0) \
+ .first()
+ return _get_passdata(tapir_user)
+
+
+def _get_passdata(tapir_user: TapirUser) -> PassData:
+ """
+ Retrieve password, nick name and profile data.
+
+ Parameters
+ ----------
+ username_or_email : str
+
+ Returns
+ -------
+ :class:`.TapirUser`
+ :class:`.TapirUsersPassword`
+ :class:`.TapirNickname`
+ :class:`.Demographic`
+
+ Raises
+ ------
+ :class:`NoSuchUser`
+ Raised when the user cannot be found.
+ :class:`RuntimeError`
+ Raised when other problems arise.
+
+ """
+ if not tapir_user:
+ raise NoSuchUser('User does not exist')
+
+ tapir_nick = session.query(TapirNickname) \
+ .filter(TapirNickname.user_id ==tapir_user.user_id) \
+ .filter(TapirNickname.flag_valid == 1) \
+ .first()
+ if not tapir_nick:
+ raise NoSuchUser('User lacks a nickname')
+
+ tapir_password: TapirUsersPassword = session.query(TapirUsersPassword) \
+ .filter(TapirUsersPassword.user_id == tapir_user.user_id) \
+ .first()
+ if not tapir_password:
+ raise RuntimeError(f'Missing password')
+
+ tapir_profile: Demographic = session.query(Demographic) \
+ .filter(Demographic.user_id == tapir_user.user_id) \
+ .first()
+ return tapir_user, tapir_password, tapir_nick, tapir_profile
+
+
+def _invalidate_token(user_id: str, secret: str) -> None:
+ """
+ Invalidate a user's permanent login token.
+
+ Parameters
+ ----------
+ user_id : str
+ secret : str
+
+ Raises
+ ------
+ :class:`NoSuchUser`
+ Raised when the token or user cannot be found.
+
+ """
+ db_token = _get_token(user_id, secret)
+ db_token.valid = 0
+ session.add(db_token)
+ session.commit()
+
+
+def _get_token(user_id: str, secret: str) -> TapirPermanentToken:
+ """
+ Retrieve a user's permanent token.
+
+ User ID and token are used together as the primary key for the token.
+
+ Parameters
+ ----------
+ user_id : str
+ secret : str
+ valid : int
+ (default: 1)
+
+ Returns
+ -------
+ :class:`.TapirPermanentToken`
+
+ Raises
+ ------
+ :class:`NoSuchUser`
+ Raised when the token or user cannot be found.
+
+ """
+ if not user_id.isdigit():
+ raise ValueError("user_id must be digits")
+ if not user_id:
+ raise ValueError("user_id must not be empty")
+ if len(user_id) > 50:
+ raise ValueError("user_id too long")
+ if len(secret) > 40:
+ raise ValueError("secret too long")
+
+ db_token: TapirPermanentToken = session.query(TapirPermanentToken) \
+ .filter(TapirPermanentToken.user_id == user_id) \
+ .filter(TapirPermanentToken.secret == secret) \
+ .filter(TapirPermanentToken.valid == 1) \
+ .first() # The token must still be valid.
+ if not db_token:
+ raise NoSuchUser('No such token')
+ else:
+ return db_token
diff --git a/arxiv/auth/legacy/cookies.py b/arxiv/auth/legacy/cookies.py
new file mode 100644
index 00000000..88e66a21
--- /dev/null
+++ b/arxiv/auth/legacy/cookies.py
@@ -0,0 +1,137 @@
+"""Provides functions for working with legacy session cookies.
+
+The legacy cookie is 6 parts seperated with ':'.
+
+The parts are:
+1. session id
+2. tapir_users.user_id
+3. ip the session was started at
+4. time issued at as unix epoch
+5. capabilities
+6. b64 encoced sha1 hash of parts 1-5
+
+In a way it is similar to a JWT where parts 1-5 are similar to the JWT
+payload, though fixed in strucutre, and part 6 forms the signature.
+
+Parts 1-5 are not b64 encoded.
+"""
+
+from typing import Tuple, List
+from base64 import b64encode
+import hashlib
+from datetime import datetime, timedelta
+
+from werkzeug.http import parse_cookie
+from werkzeug.datastructures import MultiDict
+
+from .exceptions import InvalidCookie
+from . import util
+
+
+def unpack(cookie: str) -> Tuple[str, str, str, datetime, datetime, str]:
+ """
+ Unpack the legacy session cookie.
+
+ Parameters
+ ----------
+ cookie : str
+ The value of session cookie.
+
+ Returns
+ -------
+ str
+ The session ID associated with the cookie.
+ str
+ The user ID of the authenticated account.
+ str
+ The IP address of the client when the session was created.
+ datetime
+ The datetime when the session was created.
+ datetime
+ The datetime when the session expires.
+ str
+ Legacy user privilege level.
+
+ Raises
+ ------
+ :class:`InvalidCookie`
+ Raised if the cookie is detectably malformed or tampered with.
+
+ """
+ parts = cookie.split(':')
+ if len(parts) < 5:
+ raise InvalidCookie('Malformed cookie')
+
+ session_id = parts[0]
+ user_id = parts[1]
+ ip = parts[2]
+ issued_at = util.from_epoch(int(parts[3]))
+ expires_at = issued_at + timedelta(seconds=util.get_session_duration())
+ capabilities = parts[4]
+ try:
+ expected = pack(session_id, user_id, ip, issued_at, capabilities)
+ except Exception as e:
+ raise InvalidCookie('Invalid session cookie; problem while repacking') from e
+
+ if expected == cookie:
+ return session_id, user_id, ip, issued_at, expires_at, capabilities
+ else:
+ raise InvalidCookie('Invalid session cookie; not as expected')
+
+
+def pack(session_id: str, user_id: str, ip: str, issued_at: datetime,
+ capabilities: str) -> str:
+ """
+ Generate a value for the classic session cookie.
+
+ Parameters
+ ----------
+ session_id : str
+ The session ID associated with the cookie.
+ user_id : str
+ The user ID of the authenticated account.
+ ip : str
+ Client IP address.
+ issued_at : datetime
+ The UNIX time at which the session was initiated.
+ capabilities : str
+ This is essentially a user privilege level.
+
+ Returns
+ -------
+ str
+ Signed session cookie value.
+
+ """
+ session_hash = util.get_session_hash()
+ value = ':'.join(map(str, [session_id, user_id, ip, util.epoch(issued_at),
+ capabilities]))
+ to_sign = f'{value}-{session_hash}'.encode('utf-8')
+ cookie_hash = b64encode(hashlib.sha1(to_sign).digest())
+ return value + ':' + cookie_hash.decode('utf-8')[:-1]
+
+
+def get_cookies(request, cookie_name:str) -> List[str]:
+ """Gets list of legacy cookies.
+
+ Duplicate cookies occur due to the browser sending both the
+ cookies for both arxiv.org and sub.arxiv.org. If this is being
+ served at sub.arxiv.org, there is no response that will cause
+ the browser to alter its cookie store for arxiv.org. Duplicate
+ cookies must be handled gracefully to for the domain and
+ subdomain to coexist.
+
+ The standard way to avoid this problem is to append part of
+ the domain's name to the cookie key but this needs to work
+ even if the configuration is not ideal.
+
+ """
+ # By default, werkzeug uses a dict-based struct that supports only a
+ # single value per key. This isn't really up to speed with RFC 6265.
+ # Luckily we can just pass in an alternate struct to parse_cookie()
+ # that can cope with multiple values.
+ raw_cookie = request.environ.get('HTTP_COOKIE', None)
+ if raw_cookie is None:
+ return []
+ cookies = parse_cookie(raw_cookie, cls=MultiDict)
+ return cookies.getlist(cookie_name)
diff --git a/arxiv/auth/legacy/endorsements.py b/arxiv/auth/legacy/endorsements.py
new file mode 100644
index 00000000..05daa774
--- /dev/null
+++ b/arxiv/auth/legacy/endorsements.py
@@ -0,0 +1,420 @@
+"""
+Provide endorsement authorizations for users.
+
+Endorsements are authorization scopes tied to specific classificatory
+categories, and are used primarily to determine whether or not a user may
+submit a paper with a particular primary or secondary classification.
+
+This module preserves the behavior of the legacy system with respect to
+interpreting endorsements and evaluating potential autoendorsement. The
+relevant policies can be found on the `arXiv help pages
+`_.
+"""
+
+from typing import List, Dict, Optional, Set, Union
+from collections import Counter
+from datetime import datetime
+from functools import lru_cache as memoize
+from itertools import groupby
+
+from sqlalchemy.sql.expression import literal
+
+from . import util
+from .. import domain
+from ...taxonomy import definitions
+from ...db import session
+from ...db.models import Endorsement, PaperOwner, Document, \
+ t_arXiv_in_category, Category, EndorsementDomain, t_arXiv_white_email, \
+ t_arXiv_black_email
+
+
+GENERAL_CATEGORIES = [
+ definitions.CATEGORIES['math.GM'],
+ definitions.CATEGORIES['physics.gen-ph'],
+]
+
+WINDOW_START = util.from_epoch(157783680)
+
+Endorsements = List[Union[domain.Category,str]]
+
+
+def get_endorsements(user: domain.User) -> Endorsements:
+ """
+ Get all endorsements (explicit and implicit) for a user.
+
+ Parameters
+ ----------
+ user : :class:`.domain.User`
+
+ Returns
+ -------
+ list
+ Each item is a :class:`.domain.Category` for which the user is
+ either explicitly or implicitly endorsed.
+
+ """
+ endorsements = list(set(explicit_endorsements(user))
+ | set(implicit_endorsements(user)))
+
+ return endorsements
+
+
+@memoize()
+def _categories_in_archive(archive: str) -> Set[str]:
+ return set(category for category, definition
+ in definitions.CATEGORIES_ACTIVE.items()
+ if definition.in_archive == archive)
+
+
+@memoize()
+def _category(archive: str, subject_class: str) -> domain.Category:
+ if subject_class:
+ return definitions.CATEGORIES[f'{archive}.{subject_class}']
+ return definitions.CATEGORIES[archive]
+
+
+@memoize()
+def _get_archive(category: domain.Category) -> str:
+ return category.in_archive
+
+
+def _all_archives(endorsements: Endorsements) -> bool:
+ archives = set(_get_archive(category) for category in endorsements
+ if category.id.endswith(".*"))
+ missing = set(definitions.ARCHIVES_ACTIVE.keys()) - archives
+ return len(missing) == 0 or (len(missing) == 1 and 'test' in missing)
+
+
+def _all_subjects_in_archive(archive: str, endorsements: Endorsements) -> bool:
+ return len(_categories_in_archive(archive) - set(endorsements)) == 0
+
+
+def compress_endorsements(endorsements: Endorsements) -> Endorsements:
+ """
+ Compress endorsed categories using wildcard notation if possible.
+
+ We want to avoid simply enumerating all of the categories that exist. If
+ all subjects in an archive are present, we represent that as "{archive}.*".
+ If all subjects in all archives are present, we represent that as "*.*".
+
+ Parameters
+ ----------
+ endorsements : list
+ A list of endorsed categories.
+
+ Returns
+ -------
+ list
+ The same endorsed categories, compressed with wildcards where possible.
+
+ """
+ compressed: Endorsements = []
+ grouped = groupby(sorted(endorsements, key=_get_archive), key=_get_archive)
+ for archive, archive_endorsements in grouped:
+ archive_endorsements_list = list(archive_endorsements)
+ if _all_subjects_in_archive(archive, archive_endorsements_list):
+ compressed.append(domain.Category(id=f"{archive}.*",
+ full_name=f"all of {archive}",
+ is_active=True,
+ in_archive=archive,
+ is_general=False,
+ ))
+
+ else:
+ for endorsement in archive_endorsements_list:
+ compressed.append(endorsement)
+ if _all_archives(compressed):
+ return list(definitions.CATEGORIES.values())
+ return compressed
+
+
+def explicit_endorsements(user: domain.User) -> Endorsements:
+ """
+ Load endorsed categories for a user.
+
+ These are endorsements (including auto-endorsements) that have been
+ explicitly commemorated.
+
+ Parameters
+ ----------
+ user : :class:`.domain.User`
+
+ Returns
+ -------
+ list
+ Each item is a :class:`.domain.Category` for which the user is
+ explicitly endorsed.
+
+ """
+ data: List[Endorsement] = (
+ session.query(
+ Endorsement.archive,
+ Endorsement.subject_class,
+ Endorsement.point_value,
+ )
+ .filter(Endorsement.endorsee_id == user.user_id)
+ .filter(Endorsement.flag_valid == 1)
+ .all()
+ )
+ pooled: Counter = Counter()
+ for archive, subject, points in data:
+ pooled[_category(archive, subject)] += points
+ return [category for category, points in pooled.items() if points]
+
+
+def implicit_endorsements(user: domain.User) -> Endorsements:
+ """
+ Determine categories for which a user may be autoendorsed.
+
+ In the classic system, this was determined upon request, when the user
+ attempted to submit to a particular category. Because we are separating
+ authorization concerns (which includes endorsement) from the submission
+ system itself, we want to calculate possible autoendorsement categories
+ ahead of time.
+
+ New development of autoendorsement-related functionality should not happen
+ here. This function and related code are intended only to preserve the
+ business logic already implemented in the classic system.
+
+ Parameters
+ ----------
+ :class:`.User`
+
+ Returns
+ -------
+ list
+ Each item is a :class:`.domain.Category` for which the user may be
+ auto-endorsed.
+
+ """
+ candidates = [definitions.CATEGORIES[category]
+ for category, data in definitions.CATEGORIES_ACTIVE.items()]
+ policies = category_policies()
+ invalidated = invalidated_autoendorsements(user)
+ papers = domain_papers(user)
+ user_is_academic = is_academic(user)
+ return [
+ category for category in candidates
+ if category in policies
+ and not _disqualifying_invalidations(category, invalidated)
+ and (policies[category]['endorse_all']
+ or _endorse_by_email(category, policies, user_is_academic)
+ or _endorse_by_papers(category, policies, papers))
+ ]
+
+
+def is_academic(user: domain.User) -> bool:
+ """
+ Determine whether a user is academic, based on their email address.
+
+ Uses whitelist and blacklist patterns in the database.
+
+ Parameters
+ ----------
+ user : :class:`.domain.User`
+
+ Returns
+ -------
+ bool
+
+ """
+ in_whitelist = (
+ session.query(t_arXiv_white_email.c.pattern)
+ .filter(literal(user.email).like(t_arXiv_white_email.c.pattern))
+ .first()
+ )
+ if in_whitelist:
+ return True
+ in_blacklist = (
+ session.query(t_arXiv_black_email.c.pattern)
+ .filter(literal(user.email).like(t_arXiv_black_email.c.pattern))
+ .first()
+ )
+ if in_blacklist:
+ return False
+ return True
+
+
+def _disqualifying_invalidations(category: domain.Category,
+ invalidated: Endorsements) -> bool:
+ """
+ Evaluate whether endorsement invalidations are disqualifying.
+
+ This enforces the policy that invalidated (revoked) auto-endorsements can
+ prevent future auto-endorsement.
+
+ Parameters
+ ----------
+ category : :class:`.Category`
+ The category for which an auto-endorsement is being considered.
+ invalidated : list
+ Categories for which the user has had auto-endorsements invalidated
+ (revoked).
+
+ Returns
+ -------
+ bool
+
+ """
+ return bool((category in GENERAL_CATEGORIES and category in invalidated)
+ or (category not in GENERAL_CATEGORIES and invalidated))
+
+
+def _endorse_by_email(category: domain.Category,
+ policies: Dict[domain.Category, Dict],
+ user_is_academic: bool) -> bool:
+ """
+ Evaluate whether an auto-endorsement can be issued based on email address.
+
+ This enforces the policy that some categories allow auto-endorsement for
+ academic users.
+
+ Parameters
+ ----------
+ category : :class:`.Category`
+ The category for which an auto-endorsement is being considered.
+ policies : dict
+ Describes auto-endorsement policies for each category (inherited from
+ their endorsement domains).
+ user_is_academic : bool
+ Whether or not the user has been determined to be academic.
+
+ Returns
+ -------
+ bool
+
+ """
+ policy = policies.get(category)
+ if policy is None or 'endorse_email' not in policy:
+ return False
+ return policy['endorse_email'] and user_is_academic
+
+
+def _endorse_by_papers(category: domain.Category,
+ policies: Dict[domain.Category, Dict],
+ papers: Dict[str, int]) -> bool:
+ """
+ Evaluate whether an auto-endorsement can be issued based on prior papers.
+
+ This enforces the policy that some categories allow auto-endorsements for
+ users who have published a minimum number of papers in categories that
+ share an endoresement domain.
+
+ Parameters
+ ----------
+ category : :class:`.Category`
+ The category for which an auto-endorsement is being considered.
+ policies : dict
+ Describes auto-endorsement policies for each category (inherited from
+ their endorsement domains).
+ papers : dict
+ The number of papers that the user has published in each endorsement
+ domain. Keys are str names of endorsement domains, values are int.
+
+ Returns
+ -------
+ bool
+
+ """
+ N_papers = papers.get(policies[category]['domain'], 0)
+ min_papers = policies[category]['min_papers']
+ return bool(N_papers >= min_papers)
+
+
+def domain_papers(user: domain.User,
+ start_date: Optional[datetime] = None) -> Dict[str, int]:
+ """
+ Calculate the number of papers that a user owns in each endorsement domain.
+
+ This includes both submitted and claimed papers.
+
+ Parameters
+ ----------
+ user : :class:`.domain.User`
+ start_date : :class:`.datetime` or None
+ If provided, will only count papers published after this date.
+
+ Returns
+ -------
+ dict
+ Keys are classification domains (str), values are the number of papers
+ in each respective domain (int).
+
+ """
+ query = session.query(PaperOwner.document_id,
+ Document.document_id,
+ t_arXiv_in_category.c.document_id,
+ Category.endorsement_domain) \
+ .filter(PaperOwner.user_id == user.user_id) \
+ .filter(Document.document_id == PaperOwner.document_id) \
+ .filter(t_arXiv_in_category.c.document_id == Document.document_id) \
+ .filter(Category.archive == t_arXiv_in_category.c.archive) \
+ .filter(Category.subject_class == t_arXiv_in_category.c.subject_class)
+
+ if start_date:
+ query = query.filter(Document.dated > util.epoch(start_date))
+ data = query.all()
+ return dict(Counter(domain for _, _, _, domain in data).items())
+
+
+def category_policies() -> Dict[domain.Category, Dict]:
+ """
+ Load auto-endorsement policies for each category from the database.
+
+ Each category belongs to an endorsement domain, which defines the
+ auto-endorsement policies. We retrieve those policies from the perspective
+ of the individueal category for ease of lookup.
+
+ Returns
+ -------
+ dict
+ Keys are :class:`.domain.Category` instances. Values are dicts with
+ policiy details.
+
+ """
+ data = session.query(Category.archive,
+ Category.subject_class,
+ EndorsementDomain.endorse_all,
+ EndorsementDomain.endorse_email,
+ EndorsementDomain.papers_to_endorse,
+ EndorsementDomain.endorsement_domain) \
+ .filter(Category.definitive == 1) \
+ .filter(Category.active == 1) \
+ .filter(Category.endorsement_domain
+ == EndorsementDomain.endorsement_domain) \
+ .all()
+
+ policies = {}
+ for arch, subj, endorse_all, endorse_email, min_papers, e_domain in data:
+ policies[_category(arch, subj)] = {
+ 'domain': e_domain,
+ 'endorse_all': endorse_all == 'y',
+ 'endorse_email': endorse_email == 'y',
+ 'min_papers': min_papers
+ }
+
+ return policies
+
+
+def invalidated_autoendorsements(user: domain.User) -> Endorsements:
+ """
+ Load any invalidated (revoked) auto-endorsements for a user.
+
+ Parameters
+ ----------
+ user : :class:`.domain.User`
+
+ Returns
+ -------
+ list
+ Items are :class:`.domain.Category` for which the user has had past
+ auto-endorsements revoked.
+
+ """
+ data: List[Endorsement] = session.query(Endorsement.archive,
+ Endorsement.subject_class) \
+ .filter(Endorsement.endorsee_id == user.user_id) \
+ .filter(Endorsement.flag_valid == 0) \
+ .filter(Endorsement.type == 'auto') \
+ .all()
+ return [_category(archive, subject) for archive, subject in data]
diff --git a/arxiv/auth/legacy/exceptions.py b/arxiv/auth/legacy/exceptions.py
new file mode 100644
index 00000000..0f4ee8d6
--- /dev/null
+++ b/arxiv/auth/legacy/exceptions.py
@@ -0,0 +1,45 @@
+"""Exceptions for legacy user/session integration."""
+
+
+class AuthenticationFailed(RuntimeError):
+ """Failed to authenticate user with provided credentials."""
+
+
+class NoSuchUser(RuntimeError):
+ """A reference to a non-existant user was passed."""
+
+
+class PasswordAuthenticationFailed(RuntimeError):
+ """An invalid username/password combination were provided."""
+
+
+class SessionCreationFailed(RuntimeError):
+ """Failed to create a session in the legacy database."""
+
+
+class SessionDeletionFailed(RuntimeError):
+ """Failed to delete a session in the legacy database."""
+
+
+class UnknownSession(RuntimeError):
+ """Failed to locate a session in the legacy database."""
+
+
+class SessionExpired(RuntimeError):
+ """A reference was made to an expired session."""
+
+
+class InvalidCookie(ValueError):
+ """The value of a passed legacy cookie is not valid."""
+
+
+class RegistrationFailed(RuntimeError):
+ """Could not create a new user."""
+
+
+class UpdateUserFailed(RuntimeError):
+ """Could not update a user."""
+
+
+class Unavailable(RuntimeError):
+ """The database is temporarily unavailable."""
diff --git a/arxiv/auth/legacy/passwords.py b/arxiv/auth/legacy/passwords.py
new file mode 100644
index 00000000..915a16c8
--- /dev/null
+++ b/arxiv/auth/legacy/passwords.py
@@ -0,0 +1,47 @@
+"""Passwords from legacy."""
+
+import secrets
+from base64 import b64encode, b64decode
+import hashlib
+
+from .exceptions import PasswordAuthenticationFailed
+
+
+def _hash_salt_and_password(salt: bytes, password: str) -> bytes:
+ return hashlib.sha1(salt + b'-' + password.encode('ascii')).digest()
+
+
+def hash_password(password: str) -> str:
+ """Generate a secure hash of a password.
+
+ The password must be ascii.
+ """
+ salt = secrets.token_bytes(4)
+ hashed = _hash_salt_and_password(salt, password)
+ return b64encode(salt + hashed).decode('ascii')
+
+
+def check_password(password: str, encrypted: bytes):
+ """Check a password against an encrypted hash."""
+ try:
+ password.encode('ascii')
+ except UnicodeEncodeError:
+ raise PasswordAuthenticationFailed('Password not ascii')
+
+ decoded = b64decode(encrypted)
+ salt = decoded[:4]
+ enc_hashed = decoded[4:]
+ pass_hashed = _hash_salt_and_password(salt, password)
+ if pass_hashed != enc_hashed:
+ raise PasswordAuthenticationFailed('Incorrect password')
+ else:
+ return True
+
+
+def is_ascii(string):
+ """Returns true if the string is only ascii chars."""
+ try:
+ string.encode('ascii')
+ return True
+ except UnicodeEncodeError:
+ return False
diff --git a/arxiv/auth/legacy/sessions.py b/arxiv/auth/legacy/sessions.py
new file mode 100644
index 00000000..bb11fe4f
--- /dev/null
+++ b/arxiv/auth/legacy/sessions.py
@@ -0,0 +1,242 @@
+"""Provides API for legacy user sessions."""
+from datetime import datetime, timedelta
+from pytz import timezone, UTC
+
+import logging
+
+from typing import Optional, Tuple
+
+from sqlalchemy.exc import SQLAlchemyError, OperationalError
+from sqlalchemy.orm.exc import NoResultFound
+
+from .. import domain
+from ...db import session
+from . import cookies, util
+
+from ...db.models import TapirSession, TapirSessionsAudit, TapirUser, \
+ TapirNickname, Demographic
+from .exceptions import UnknownSession, SessionCreationFailed, \
+ SessionExpired, InvalidCookie, Unavailable
+
+logger = logging.getLogger(__name__)
+EASTERN = timezone('US/Eastern')
+
+
+def _load(session_id: str) -> TapirSession:
+ """Get TapirSession from session id."""
+ db_session: TapirSession = session.query(TapirSession) \
+ .filter(TapirSession.session_id == session_id) \
+ .first()
+ if not db_session:
+ logger.debug(f'No session found with id {session_id}')
+ raise UnknownSession('No such session')
+ return db_session
+
+
+def _load_audit(session_id: str) -> TapirSessionsAudit:
+ """Get TapirSessionsAudit from session id."""
+ db_sessions_audit: TapirSessionsAudit = session.query(TapirSessionsAudit) \
+ .filter(TapirSessionsAudit.session_id == session_id) \
+ .first()
+ if not db_sessions_audit:
+ logger.debug(f'No session audit found with id {session_id}')
+ raise UnknownSession('No such session audit')
+ return db_sessions_audit
+
+
+def load(cookie: str) -> domain.Session:
+ """
+ Given a session cookie (from request), load the logged-in user.
+
+ Parameters
+ ----------
+ cookie : str
+ Legacy cookie value passed with the request.
+
+ Returns
+ -------
+ :class:`.domain.Session`
+
+ Raises
+ ------
+ :class:`.legacy.exceptions.SessionExpired`
+ :class:`.legacy.exceptions.UnknownSession`
+
+ """
+ session_id, user_id, ip, issued_at, expires_at, _ = cookies.unpack(cookie)
+ logger.debug('Load session %s for user %s at %s',
+ session_id, user_id, ip)
+
+ if expires_at <= datetime.now(tz=UTC):
+ raise SessionExpired(f'Session {session_id} has expired in cookie')
+
+ data: Tuple[TapirUser, TapirSession, TapirNickname, Demographic]
+ try:
+ data = session.query(TapirUser, TapirSession, TapirNickname, Demographic) \
+ .join(TapirSession).join(TapirNickname).join(Demographic) \
+ .filter(TapirUser.user_id == user_id) \
+ .filter(TapirSession.session_id == session_id ) \
+ .first()
+ except OperationalError as e:
+ raise Unavailable('Database is temporarily unavailable') from e
+
+ if not data:
+ raise UnknownSession('No such user or session')
+
+ db_user, db_session, db_nick, db_profile = data
+
+ if db_session.end_time != 0 and db_session.end_time < util.now():
+ raise SessionExpired(f'Session {session_id} has expired in the DB')
+
+ user = domain.User(
+ user_id=str(user_id),
+ username=db_nick.nickname,
+ email=db_user.email,
+ name=domain.UserFullName(
+ forename=db_user.first_name,
+ surname=db_user.last_name,
+ suffix=db_user.suffix_name
+ ),
+ profile=domain.UserProfile.from_orm(db_profile) if db_profile else None,
+ verified=bool(db_user.flag_email_verified)
+ )
+
+ authorizations = domain.Authorizations(
+ classic=util.compute_capabilities(db_user),
+ scopes=util.get_scopes(db_user)
+ )
+ user_session = domain.Session(session_id=str(db_session.session_id),
+ start_time=issued_at, end_time=expires_at,
+ user=user, authorizations=authorizations)
+ logger.debug('loaded session %s', user_session.session_id)
+ return user_session
+
+
+def create(authorizations: domain.Authorizations,
+ ip: str, remote_host: str, tracking_cookie: str = '',
+ user: Optional[domain.User] = None) -> domain.Session:
+ """
+ Create a new legacy session for an authenticated user.
+
+ Parameters
+ ----------
+ user : :class:`.User`
+ ip : str
+ Client IP address.
+ remote_host : str
+ Client hostname.
+ tracking_cookie : str
+ Tracking cookie payload from client request.
+
+ Returns
+ -------
+ :class:`.Session`
+
+ """
+ if user is None:
+ raise SessionCreationFailed('Legacy sessions require a user')
+
+ logger.debug('create session for user %s', user.user_id)
+ start = datetime.now(tz=UTC)
+ end = start + timedelta(seconds=util.get_session_duration())
+ try:
+ tapir_session = TapirSession(
+ user_id=user.user_id,
+ last_reissue=util.epoch(start),
+ start_time=util.epoch(start),
+ end_time=0
+ )
+ tapir_sessions_audit = TapirSessionsAudit(
+ session=tapir_session,
+ ip_addr=ip,
+ remote_host=remote_host,
+ tracking_cookie=tracking_cookie
+ )
+ session.add(tapir_sessions_audit)
+ session.commit()
+ except Exception as e:
+ raise SessionCreationFailed(f'Failed to create: {e}') from e
+
+ session_id = str(tapir_session.session_id)
+ user_session = domain.Session(session_id=session_id,
+ start_time=start, end_time=end,
+ user=user,
+ authorizations=authorizations,
+ ip_address=ip, remote_host=remote_host)
+ logger.debug('created session %s', user_session.session_id)
+ return user_session
+
+
+def generate_cookie(session: domain.Session) -> str:
+ """
+ Generate a cookie from a :class:`domain.Session`.
+
+ Parameters
+ ----------
+ session : :class:`domain.Session`
+
+ Returns
+ -------
+ str
+
+ """
+ if session.user is None \
+ or session.user.user_id is None \
+ or session.ip_address is None \
+ or session.authorizations is None:
+ raise RuntimeError('Cannot generate cookie without an authorized user')
+
+ return cookies.pack(str(session.session_id), session.user.user_id,
+ session.ip_address, session.start_time,
+ str(session.authorizations.classic))
+
+
+def invalidate(cookie: str) -> None:
+ """
+ Invalidate a legacy user session.
+
+ Parameters
+ ----------
+ cookie : str
+ Session cookie generated when the session was created.
+
+ Raises
+ ------
+ :class:`UnknownSession`
+ The session could not be found, or the cookie was not valid.
+
+ """
+ try:
+ session_id, user_id, ip, _, _, _ = cookies.unpack(cookie)
+ except InvalidCookie as e:
+ raise UnknownSession('No such session') from e
+
+ invalidate_by_id(session_id)
+
+
+def invalidate_by_id(session_id: str) -> None:
+ """
+ Invalidate a legacy user session by ID.
+
+ Parameters
+ ----------
+ session_id : str
+ Unique identifier for the session.
+
+ Raises
+ ------
+ :class:`UnknownSession`
+ The session could not be found, or the cookie was not valid.
+
+ """
+ delta = datetime.now(tz=UTC) - datetime.fromtimestamp(0, tz=EASTERN)
+ end = (delta).total_seconds()
+ try:
+ tapir_session = _load(session_id)
+ tapir_session.end_time = end - 1
+ session.merge(tapir_session)
+ session.commit()
+ except NoResultFound as e:
+ raise UnknownSession(f'No such session {session_id}') from e
+ except SQLAlchemyError as e:
+ raise IOError(f'Database error') from e
diff --git a/arxiv/auth/legacy/tests/__init__.py b/arxiv/auth/legacy/tests/__init__.py
new file mode 100644
index 00000000..32781b97
--- /dev/null
+++ b/arxiv/auth/legacy/tests/__init__.py
@@ -0,0 +1,4 @@
+"""Tests for :mod:`arxiv.users.legacy`."""
+import os
+
+os.environ['CLASSIC_DB_URI'] = 'sqlite:///:memory:'
\ No newline at end of file
diff --git a/arxiv/auth/legacy/tests/test_accounts.py b/arxiv/auth/legacy/tests/test_accounts.py
new file mode 100644
index 00000000..7f4192a1
--- /dev/null
+++ b/arxiv/auth/legacy/tests/test_accounts.py
@@ -0,0 +1,386 @@
+"""Tests for :mod:`arxiv.users.legacy.accounts`."""
+
+import tempfile
+from datetime import datetime
+import shutil
+import hashlib
+from pytz import UTC
+from unittest import TestCase
+from sqlalchemy import select
+from flask import Flask
+
+from arxiv.config import Settings
+from arxiv.taxonomy import definitions
+from arxiv.db import transaction
+from arxiv.db import models
+
+from .util import temporary_db
+from .. import util, authenticate, exceptions
+from .. import accounts
+from ... import domain
+
+
+def get_user(session, user_id):
+ """Helper to get user database objects by user id."""
+ db_user, db_nick = (
+ session.query(models.TapirUser, models.TapirNickname)
+ .filter(models.TapirUser.user_id == user_id)
+ .filter(models.TapirNickname.flag_primary == 1)
+ .filter(models.TapirNickname.user_id == models.TapirUser.user_id)
+ .first()
+ )
+
+ db_profile = session.query(models.Demographic) \
+ .filter(models.Demographic.user_id == user_id) \
+ .first()
+
+ return db_user, db_nick, db_profile
+
+
+class SetUpUserMixin(TestCase):
+ """Mixin for creating a test user and other database goodies."""
+
+ def setUp(self):
+ """Set up the database."""
+ self.db_path = tempfile.mkdtemp()
+ self.db_uri = f'sqlite:///{self.db_path}/test.db'
+ self.user_id = '15830'
+ self.app = Flask('test')
+ self.app.config['CLASSIC_SESSION_HASH'] = 'foohash'
+ self.app.config['CLASSIC_COOKIE_NAME'] = 'tapir_session_cookie'
+ self.app.config['SESSION_DURATION'] = '36000'
+ settings = Settings(
+ CLASSIC_DB_URI=self.db_uri,
+ LATEXML_DB_URI=None)
+
+ self.engine, _ = models.configure_db(settings) # Insert tapir policy classes
+
+ with temporary_db("sqlite:///:memory:", drop=False) as session:
+ self.user_class = session.scalar(
+ select(models.TapirPolicyClass).where(models.TapirPolicyClass.class_id==2))
+ self.email = 'first@last.iv'
+ self.db_user = models.TapirUser(
+ user_id=self.user_id,
+ first_name='first',
+ last_name='last',
+ suffix_name='iv',
+ email=self.email,
+ policy_class=self.user_class.class_id,
+ flag_edit_users=1,
+ flag_email_verified=1,
+ flag_edit_system=0,
+ flag_approved=1,
+ flag_deleted=0,
+ flag_banned=0,
+ tracking_cookie='foocookie',
+ )
+ self.username = 'foouser'
+ self.db_nick = models.TapirNickname(
+ nickname=self.username,
+ user_id=self.user_id,
+ user_seq=1,
+ flag_valid=1,
+ role=0,
+ policy=0,
+ flag_primary=1
+ )
+ self.salt = b'foo'
+ self.password = b'thepassword'
+ hashed = hashlib.sha1(self.salt + b'-' + self.password).digest()
+ self.db_password = models.TapirUsersPassword(
+ user_id=self.user_id,
+ password_storage=2,
+ password_enc=hashed
+ )
+ n = util.epoch(datetime.now(tz=UTC))
+ self.secret = 'foosecret'
+ self.db_token = models.TapirPermanentToken(
+ user_id=self.user_id,
+ secret=self.secret,
+ valid=1,
+ issued_when=n,
+ issued_to='127.0.0.1',
+ remote_host='foohost.foo.com',
+ session_id=0
+ )
+ session.add(self.user_class)
+ session.add(self.db_user)
+ session.add(self.db_password)
+ session.add(self.db_nick)
+ session.add(self.db_token)
+ session.commit()
+
+ def tearDown(self):
+ """Drop tables from the in-memory db file."""
+ util.drop_all(self.engine)
+
+
+class TestUsernameExists(SetUpUserMixin):
+ """Tests for :mod:`accounts.does_username_exist`."""
+
+ def test_with_nonexistant_user(self):
+ """There is no user with the passed username."""
+ with self.app.app_context():
+ self.assertFalse(accounts.does_username_exist('baruser'))
+
+ def test_with_existant_user(self):
+ """There is a user with the passed username."""
+ # with temporary_db(self.db_uri, create=False, drop=False):
+ # with transaction() as session:
+ # print (f'NICKS: {session.query(models.TapirNickname).all()}')
+ # self.setUp()
+ with self.app.app_context():
+ self.assertTrue(accounts.does_username_exist('foouser'))
+
+
+class TestEmailExists(SetUpUserMixin):
+ """Tests for :mod:`accounts.does_email_exist`."""
+
+ def test_with_nonexistant_email(self):
+ """There is no user with the passed email."""
+ with self.app.app_context():
+ self.assertFalse(accounts.does_email_exist('foo@bar.com'))
+
+ def test_with_existant_email(self):
+ """There is a user with the passed email."""
+ self.setUp()
+ with self.app.app_context():
+ self.assertTrue(accounts.does_email_exist('first@last.iv'))
+
+
+class TestRegister(SetUpUserMixin, TestCase):
+ """Tests for :mod:`accounts.register`."""
+
+ def test_register_with_duplicate_username(self):
+ """The username is already in the system."""
+ user = domain.User(username='foouser', email='foo@bar.com')
+ ip = '1.2.3.4'
+ with self.app.app_context():
+ with self.assertRaises(exceptions.RegistrationFailed):
+ accounts.register(user, 'apassword1', ip=ip, remote_host=ip)
+
+ def test_register_with_duplicate_email(self):
+ """The email address is already in the system."""
+ user = domain.User(username='bazuser', email='first@last.iv')
+ ip = '1.2.3.4'
+ with self.app.app_context():
+ with self.assertRaises(exceptions.RegistrationFailed):
+ accounts.register(user, 'apassword1', ip=ip, remote_host=ip)
+
+ def test_register_with_name_details(self):
+ """Registration includes the user's name."""
+ name = domain.UserFullName(forename='foo', surname='user', suffix='iv')
+ user = domain.User(username='bazuser', email='new@account.edu',
+ name=name)
+ ip = '1.2.3.4'
+
+ with self.app.app_context():
+ with transaction() as session:
+ u, _ = accounts.register(user, 'apassword1', ip=ip, remote_host=ip)
+ db_user, db_nick, db_profile = get_user(session, u.user_id)
+
+ self.assertEqual(db_user.first_name, name.forename)
+ self.assertEqual(db_user.last_name, name.surname)
+ self.assertEqual(db_user.suffix_name, name.suffix)
+
+ def test_register_with_bare_minimum(self):
+ """Registration includes only a username, name, email address, password."""
+ user = domain.User(username='bazuser', email='new@account.edu',
+ name = domain.UserFullName(forename='foo', surname='user', suffix='iv'))
+ ip = '1.2.3.4'
+
+ with self.app.app_context():
+ with transaction() as session:
+ u, _ = accounts.register(user, 'apassword1', ip=ip, remote_host=ip)
+ db_user, db_nick, db_profile = get_user(session, u.user_id)
+
+ self.assertEqual(db_user.flag_email_verified, 0)
+ self.assertEqual(db_nick.nickname, user.username)
+ self.assertEqual(db_user.email, user.email)
+
+ def test_register_with_profile(self):
+ """Registration includes profile information."""
+ profile = domain.UserProfile(
+ affiliation='School of Hard Knocks',
+ country='de',
+ rank=1,
+ submission_groups=['grp_cs', 'grp_q-bio'],
+ default_category=definitions.CATEGORIES['cs.DL'],
+ homepage_url='https://google.com'
+ )
+ name = domain.UserFullName(forename='foo', surname='user', suffix='iv')
+ user = domain.User(username='bazuser', email='new@account.edu',
+ name=name, profile=profile)
+ ip = '1.2.3.4'
+
+ with self.app.app_context():
+ with transaction() as session:
+ u, _ = accounts.register(user, 'apassword1', ip=ip, remote_host=ip)
+ db_user, db_nick, db_profile = get_user(session, u.user_id)
+
+ self.assertEqual(db_profile.affiliation, profile.affiliation)
+ self.assertEqual(db_profile.country, profile.country),
+ self.assertEqual(db_profile.type, profile.rank),
+ self.assertEqual(db_profile.flag_group_cs, 1)
+ self.assertEqual(db_profile.flag_group_q_bio, 1)
+ self.assertEqual(db_profile.flag_group_physics, 0)
+ self.assertEqual(db_profile.archive, 'cs')
+ self.assertEqual(db_profile.subject_class, 'DL')
+
+ def test_can_authenticate_after_registration(self):
+ """A may authenticate a bare-minimum user after registration."""
+ user = domain.User(username='bazuser', email='new@account.edu',
+ name=domain.UserFullName(forename='foo', surname='user'))
+ ip = '1.2.3.4'
+
+ with self.app.app_context():
+ with transaction() as session:
+ u, _ = accounts.register(user, 'apassword1', ip=ip, remote_host=ip)
+ db_user, db_nick, db_profile = get_user(session, u.user_id)
+ auth_user, auths = authenticate.authenticate(
+ username_or_email=user.username,
+ password='apassword1'
+ )
+ self.assertEqual(str(db_user.user_id), auth_user.user_id)
+
+
+class TestGetUserById(SetUpUserMixin):
+ """Tests for :func:`accounts.get_user_by_id`."""
+
+ def test_user_exists(self):
+ """A well-rounded user exists with the requested user id."""
+ profile = domain.UserProfile(
+ affiliation='School of Hard Knocks',
+ country='de',
+ rank=1,
+ submission_groups=['grp_cs', 'grp_q-bio'],
+ default_category=definitions.CATEGORIES['cs.DL'],
+ homepage_url='https://google.com'
+ )
+ name = domain.UserFullName(forename='foo', surname='user', suffix='iv')
+ user = domain.User(username='bazuser', email='new@account.edu',
+ name=name, profile=profile)
+ ip = '1.2.3.4'
+
+ with self.app.app_context():
+ u, _ = accounts.register(user, 'apassword1', ip=ip, remote_host=ip)
+ loaded_user = accounts.get_user_by_id(u.user_id)
+
+ self.assertEqual(loaded_user.username, user.username)
+ self.assertEqual(loaded_user.email, user.email)
+ self.assertEqual(loaded_user.profile.affiliation, profile.affiliation)
+
+ def test_user_does_not_exist(self):
+ """No user with the specified username."""
+ with self.app.app_context():
+ with self.assertRaises(exceptions.NoSuchUser):
+ accounts.get_user_by_id('1234')
+
+ def test_with_no_profile(self):
+ """The user exists, but there is no profile."""
+ name = domain.UserFullName(forename='foo', surname='user', suffix='iv')
+ user = domain.User(username='bazuser', email='new@account.edu',
+ name=name)
+ ip = '1.2.3.4'
+
+ with self.app.app_context():
+ u, _ = accounts.register(user, 'apassword1', ip=ip, remote_host=ip)
+ loaded_user = accounts.get_user_by_id(u.user_id)
+
+ self.assertEqual(loaded_user.username, user.username)
+ self.assertEqual(loaded_user.email, user.email)
+ self.assertIsNone(loaded_user.profile)
+
+
+class TestUpdate(SetUpUserMixin, TestCase):
+ """Tests for :func:`accounts.update`."""
+
+ def test_user_without_id(self):
+ """A :class:`domain.User` is passed without an ID."""
+ user = domain.User(username='bazuser', email='new@account.edu')
+ with self.app.app_context():
+ with self.assertRaises(ValueError):
+ accounts.update(user)
+
+ def test_update_nonexistant_user(self):
+ """A :class:`domain.User` is passed that is not in the database."""
+ user = domain.User(username='bazuser', email='new@account.edu',
+ user_id='12345')
+ with self.app.app_context():
+ with self.assertRaises(exceptions.NoSuchUser):
+ accounts.update(user)
+
+ def test_update_name(self):
+ """The user's name is changed."""
+ name = domain.UserFullName(forename='foo', surname='user', suffix='iv')
+ user = domain.User(username='bazuser', email='new@account.edu',
+ name=name)
+ ip = '1.2.3.4'
+
+ with self.app.app_context():
+ user, _ = accounts.register(user, 'apassword1', ip=ip,
+ remote_host=ip)
+
+ with self.app.app_context():
+ with transaction() as session:
+ updated_name = domain.UserFullName(forename='Foo',
+ surname=name.surname,
+ suffix=name.suffix)
+ updated_user = domain.User(user_id=user.user_id,
+ username=user.username,
+ email=user.email,
+ name=updated_name)
+
+ updated_user, _ = accounts.update(updated_user)
+ self.assertEqual(user.user_id, updated_user.user_id)
+ self.assertEqual(updated_user.name.forename, 'Foo')
+ db_user, db_nick, db_profile = get_user(session, user.user_id)
+ self.assertEqual(db_user.first_name, 'Foo')
+
+ def test_update_profile(self):
+ """Changes are made to profile information."""
+ profile = domain.UserProfile(
+ affiliation='School of Hard Knocks',
+ country='de',
+ rank=1,
+ submission_groups=['grp_cs', 'grp_q-bio'],
+ default_category=definitions.CATEGORIES['cs.DL'],
+ homepage_url='https://google.com'
+ )
+ name = domain.UserFullName(forename='foo', surname='user', suffix='iv')
+ user = domain.User(username='bazuser', email='new@account.edu',
+ name=name, profile=profile)
+ ip = '1.2.3.4'
+
+ with self.app.app_context():
+ user, _ = accounts.register(user, 'apassword1', ip=ip,
+ remote_host=ip)
+
+ updated_profile = domain.UserProfile(
+ affiliation='School of Hard Knocks',
+ country='us',
+ rank=2,
+ submission_groups=['grp_cs', 'grp_physics'],
+ default_category=definitions.CATEGORIES['cs.IR'],
+ homepage_url='https://google.com'
+ )
+ updated_user = domain.User(user_id=user.user_id,
+ username=user.username,
+ email=user.email,
+ name=name,
+ profile=updated_profile)
+
+ with self.app.app_context():
+ with transaction() as session:
+ u, _ = accounts.update(updated_user)
+ db_user, db_nick, db_profile = get_user(session, u.user_id)
+
+ self.assertEqual(db_profile.affiliation,
+ updated_profile.affiliation)
+ self.assertEqual(db_profile.country, updated_profile.country),
+ self.assertEqual(db_profile.type, updated_profile.rank),
+ self.assertEqual(db_profile.flag_group_cs, 1)
+ self.assertEqual(db_profile.flag_group_q_bio, 0)
+ self.assertEqual(db_profile.flag_group_physics, 1)
+ self.assertEqual(db_profile.archive, 'cs')
+ self.assertEqual(db_profile.subject_class, 'IR')
diff --git a/arxiv/auth/legacy/tests/test_authenticate.py b/arxiv/auth/legacy/tests/test_authenticate.py
new file mode 100644
index 00000000..e6329135
--- /dev/null
+++ b/arxiv/auth/legacy/tests/test_authenticate.py
@@ -0,0 +1,299 @@
+"""Tests for :mod:`accounts.services.user_data`."""
+
+from unittest import TestCase, mock
+from datetime import datetime
+from pytz import timezone, UTC
+import tempfile
+import shutil
+import hashlib
+
+from flask import Flask
+
+from sqlalchemy import select
+from sqlalchemy.exc import SQLAlchemyError, OperationalError
+
+from arxiv.config import Settings
+from arxiv.db import models
+
+from .. import authenticate, exceptions, util
+
+from .util import temporary_db
+
+from ..passwords import hash_password
+
+EASTERN = timezone('US/Eastern')
+
+
+class TestAuthenticateWithPermanentToken(TestCase):
+ """User has a permanent token."""
+
+ def setUp(self):
+ """Instantiate a user and its credentials in the DB."""
+ self.db = f'sqlite:///:memory:'
+ self.user_id = '1'
+
+ self.app = Flask('test')
+ settings = Settings(
+ CLASSIC_DB_URI=self.db,
+ LATEXML_DB_URI=None)
+
+ self.engine, _ = models.configure_db(settings)
+
+ with temporary_db(self.db, drop=False) as session:
+ self.user_class = session.scalar(
+ select(models.TapirPolicyClass).where(models.TapirPolicyClass.class_id==2))
+ self.email = 'first@last.iv'
+ self.db_user = models.TapirUser(
+ user_id=self.user_id,
+ first_name='first',
+ last_name='last',
+ suffix_name='iv',
+ email=self.email,
+ policy_class=self.user_class.class_id,
+ flag_edit_users=1,
+ flag_email_verified=1,
+ flag_edit_system=0,
+ flag_approved=1,
+ flag_deleted=0,
+ flag_banned=0,
+ tracking_cookie='foocookie',
+ )
+ self.username = 'foouser'
+ self.db_nick = models.TapirNickname(
+ nickname=self.username,
+ user_id=self.user_id,
+ user_seq=1,
+ flag_valid=1,
+ role=0,
+ policy=0,
+ flag_primary=1
+ )
+ self.salt = b'foo'
+ self.password = b'thepassword'
+ hashed = hashlib.sha1(self.salt + b'-' + self.password).digest()
+ self.db_password = models.TapirUsersPassword(
+ user_id=self.user_id,
+ password_storage=2,
+ password_enc=hashed
+ )
+ n = util.epoch(datetime.now(tz=UTC))
+ self.secret = 'foosecret'
+ self.db_token = models.TapirPermanentToken(
+ user_id=self.user_id,
+ secret=self.secret,
+ valid=1,
+ issued_when=n,
+ issued_to='127.0.0.1',
+ remote_host='foohost.foo.com',
+ session_id=0
+ )
+ session.add(self.user_class)
+ session.add(self.db_user)
+ session.add(self.db_password)
+ session.add(self.db_nick)
+ session.add(self.db_token)
+ session.commit()
+
+ def tearDown(self):
+ """Drop tables from the in-memory db file."""
+ util.drop_all(self.engine)
+
+ def test_token_is_malformed(self):
+ """Token is present, but it has the wrong format."""
+ bad_token = 'footokenhasnohyphen'
+ with self.app.app_context():
+ with self.assertRaises(exceptions.AuthenticationFailed):
+ authenticate.authenticate(token=bad_token)
+
+ def test_token_is_incorrect(self):
+ """Token is present, but there is no such token in the database."""
+ bad_token = '1234-nosuchtoken'
+ # with temporary_db(self.db, create=False):
+ with self.app.app_context():
+ with self.assertRaises(exceptions.AuthenticationFailed):
+ authenticate.authenticate(token=bad_token)
+
+ def test_token_is_invalid(self):
+ """The token is present, but it is not valid."""
+ with self.app.app_context():
+ authenticate._invalidate_token(self.user_id, self.secret)
+
+ with self.assertRaises(exceptions.AuthenticationFailed):
+ token = f'{self.user_id}-{self.secret}'
+ authenticate.authenticate(token=token)
+
+ def test_token_is_valid(self):
+ """The token is valid!."""
+ with self.app.app_context():
+ token = f'{self.user_id}-{self.secret}'
+ user, auths = authenticate.authenticate(token=token)
+ self.assertIsInstance(user, authenticate.domain.User,
+ "Returns data about the user")
+ self.assertIsInstance(auths, authenticate.domain.Authorizations,
+ "Returns authorization data")
+ self.assertEqual(user.user_id, self.user_id,
+ "User ID is set correctly")
+ self.assertEqual(user.username, self.username,
+ "Username is set correctly")
+ self.assertEqual(user.email, self.email,
+ "User email is set correctly")
+ self.assertEqual(auths.classic, 6,
+ "authorizations are set")
+
+
+class TestAuthenticateWithPassword(TestCase):
+ """User is attempting login with username+password."""
+
+ def setUp(self):
+ """Instantiate a user."""
+ self.db = f'sqlite:///:memory:'
+
+ self.app = Flask('test')
+ settings = Settings(
+ CLASSIC_DB_URI=self.db,
+ LATEXML_DB_URI=None)
+
+ self.engine, _ = models.configure_db(settings)
+ self.user_id = '5'
+ with temporary_db(self.db, drop=False) as session:
+ # We have a good old-fashioned user.
+ self.user_class = session.scalar(
+ select(models.TapirPolicyClass).where(models.TapirPolicyClass.class_id==2))
+ self.email = 'first@last.iv'
+ self.db_user = models.TapirUser(
+ user_id=self.user_id,
+ first_name='first',
+ last_name='last',
+ suffix_name='iv',
+ email=self.email,
+ policy_class=self.user_class.class_id,
+ flag_edit_users=1,
+ flag_email_verified=1,
+ flag_edit_system=0,
+ flag_approved=1,
+ flag_deleted=0,
+ flag_banned=0,
+ tracking_cookie='foocookie',
+ )
+ self.username = 'foouser'
+ self.db_nick = models.TapirNickname(
+ nickname=self.username,
+ user_id=self.user_id,
+ user_seq=1,
+ flag_valid=1,
+ role=0,
+ policy=0,
+ flag_primary=1
+ )
+ self.salt = b'foo'
+ self.password = 'thepassword'
+ # hashed = hashlib.sha1(self.salt + b'-' + self.password).digest()
+
+ hashed = hash_password(self.password)
+ self.db_password = models.TapirUsersPassword(
+ user_id=self.user_id,
+ password_storage=2,
+ password_enc=hashed
+ )
+ session.add(self.user_class)
+ session.add(self.db_user)
+ session.add(self.db_password)
+ session.add(self.db_nick)
+ session.commit()
+
+ def tearDown(self):
+ """Drop tables from the in-memory db file."""
+ util.drop_all(self.engine)
+
+ def test_no_username(self):
+ """Username is not entered."""
+ username = ''
+ password = 'foopass'
+ with self.assertRaises(exceptions.AuthenticationFailed):
+ with self.app.app_context():
+ authenticate.authenticate(username, password)
+
+ def test_no_password(self):
+ """Password is not entered."""
+ username = 'foouser'
+ password = ''
+ with self.assertRaises(exceptions.AuthenticationFailed):
+ with self.app.app_context():
+ authenticate.authenticate(username, password)
+
+ def test_password_is_incorrect(self):
+ """Password is incorrect."""
+ with self.app.app_context():
+ with self.assertRaises(exceptions.AuthenticationFailed):
+ authenticate.authenticate('foouser', 'notthepassword')
+
+ def test_password_is_correct(self):
+ """Password is correct."""
+ with self.app.app_context():
+
+ user, auths = authenticate.authenticate('foouser', 'thepassword')
+ self.assertIsInstance(user, authenticate.domain.User,
+ "Returns data about the user")
+ self.assertIsInstance(auths, authenticate.domain.Authorizations,
+ "Returns authorization data")
+ self.assertEqual(user.user_id, self.user_id,
+ "User ID is set correctly")
+ self.assertEqual(user.username, self.username,
+ "Username is set correctly")
+ self.assertEqual(user.email, self.email,
+ "User email is set correctly")
+ self.assertEqual(auths.classic, 6,
+ "Authorizations are set")
+
+ def test_login_with_email_and_correct_password(self):
+ """User attempts to log in with e-mail address."""
+ with self.app.app_context():
+ user, auths = authenticate.authenticate('first@last.iv',
+ 'thepassword')
+ self.assertIsInstance(user, authenticate.domain.User,
+ "Returns data about the user")
+ self.assertIsInstance(auths, authenticate.domain.Authorizations,
+ "Returns authorization data")
+ self.assertEqual(user.user_id, self.user_id,
+ "User ID is set correctly")
+ self.assertEqual(user.username, self.username,
+ "Username is set correctly")
+ self.assertEqual(user.email, self.email,
+ "User email is set correctly")
+ self.assertEqual(auths.classic, 6,
+ "authorizations are set")
+
+ def test_no_such_user(self):
+ """Username does not exist."""
+ with self.app.app_context():
+ with self.assertRaises(exceptions.AuthenticationFailed):
+ authenticate.authenticate('nobody', 'thepassword')
+
+
+ def test_bad_data(self):
+ """Test with bad data."""
+ with self.app.app_context():
+ with self.assertRaises(exceptions.AuthenticationFailed):
+ authenticate.authenticate('abc', '')
+ with self.assertRaises(exceptions.AuthenticationFailed):
+ authenticate.authenticate('abc', 234)
+ with self.assertRaises(exceptions.AuthenticationFailed):
+ authenticate.authenticate('abc', 'ฮฒ')
+ with self.assertRaises(exceptions.AuthenticationFailed):
+ authenticate.authenticate('', 'password')
+ with self.assertRaises(exceptions.AuthenticationFailed):
+ authenticate.authenticate('ฮฒ', 'password')
+ with self.assertRaises(exceptions.AuthenticationFailed):
+ authenticate.authenticate('long'*100, 'password')
+ with self.assertRaises(exceptions.AuthenticationFailed):
+ authenticate.authenticate(1234, 'password')
+
+
+ with self.assertRaises(exceptions.AuthenticationFailed):
+ authenticate.authenticate(None, None, 'abcc-something')
+ with self.assertRaises(exceptions.AuthenticationFailed):
+ authenticate.authenticate(None, None, '-something')
+ with self.assertRaises(exceptions.AuthenticationFailed):
+ authenticate.authenticate(None, None, ('long'*20) + '-something')
+ with self.assertRaises(exceptions.AuthenticationFailed):
+ authenticate.authenticate(None, None, '1234-' + 40 * 'long')
diff --git a/arxiv/auth/legacy/tests/test_bootstrap.py b/arxiv/auth/legacy/tests/test_bootstrap.py
new file mode 100644
index 00000000..e7528d99
--- /dev/null
+++ b/arxiv/auth/legacy/tests/test_bootstrap.py
@@ -0,0 +1,301 @@
+"""Test the legacy integration with synthetic data."""
+
+import os
+import sys
+from typing import Tuple
+from unittest import TestCase
+from flask import Flask
+import locale
+
+from typing import List
+import random
+from datetime import datetime
+from pytz import timezone, UTC
+from mimesis import Person, Internet, Datetime, locales
+
+from sqlalchemy import select, func
+from arxiv.db import transaction
+from arxiv.db import models
+from arxiv.config import Settings
+from arxiv.taxonomy import definitions
+from .. import util, sessions, authenticate, exceptions
+from ..passwords import hash_password
+from ... import domain
+
+LOCALES = locales.Locale.values()
+EASTERN = timezone('US/Eastern')
+
+
+def _random_category() -> Tuple[str, str]:
+ category = random.choice(list(definitions.CATEGORIES_ACTIVE.items()))
+ archive = category[1].in_archive
+ subject_class = category[0].split('.')[-1] if '.' in category[0] else ''
+ return archive, subject_class
+
+
+def _get_locale() -> str:
+ return LOCALES[random.randint(0, len(LOCALES) - 1)]
+
+
+def _prob(P: int) -> bool:
+ return random.randint(0, 100) < P
+
+
+class TestBootstrap(TestCase):
+ """Tests against legacy user integrations with fake data."""
+
+ @classmethod
+ def setUpClass(cls):
+ """Generate some fake data."""
+ cls.app = Flask('test')
+ cls.app.config['CLASSIC_SESSION_HASH'] = 'foohash'
+ cls.app.config['CLASSIC_COOKIE_NAME'] = 'tapir_session_cookie'
+ cls.app.config['SESSION_DURATION'] = '36000'
+ settings = Settings(
+ CLASSIC_DB_URI='sqlite:///:memory:',
+ LATEXML_DB_URI=None)
+
+ engine, _ = models.configure_db(settings)
+
+ with cls.app.app_context():
+ util.create_all(engine)
+ with transaction() as session:
+ edc = session.execute(select(models.Endorsement)).all()
+ for row in edc:
+ print(row)
+ assert len(edc) == 0, "Expect the table to be empty at the start"
+
+ session.add(models.EndorsementDomain(
+ endorsement_domain='test_domain_bootstrap',
+ endorse_all='n',
+ mods_endorse_all='n',
+ endorse_email='y',
+ papers_to_endorse=3
+ ))
+
+ for category in definitions.CATEGORIES_ACTIVE.keys():
+ if '.' in category:
+ archive, subject_class = category.split('.', 1)
+ else:
+ archive, subject_class = category, ''
+
+ with transaction() as session:
+ #print(f"arch: {archive} sc: {subject_class}")
+ session.add(models.Category(
+ archive=archive,
+ subject_class=subject_class,
+ definitive=1,
+ active=1,
+ endorsement_domain='test_domain_bootstrap'
+ ))
+
+ COUNT = 100
+
+ cls.users = []
+
+ _users = []
+ _domain_users = []
+ for i in range(COUNT):
+ with transaction() as session:
+ locale = _get_locale()
+ person = Person(locale)
+ net = Internet()
+ ip_addr = net.ip_v4()
+ email = person.email()
+ approved = 1 if _prob(90) else 0
+ deleted = 1 if _prob(2) else 0
+ banned = 1 if random.randint(0, 100) <= 1 else 0
+ first_name = person.name()
+ last_name = person.surname()
+ suffix_name = person.title()
+ name = (first_name, last_name, suffix_name)
+ joined_date = util.epoch(
+ Datetime(locale).datetime().replace(tzinfo=EASTERN)
+ )
+ db_user = models.TapirUser(
+ first_name=first_name,
+ last_name=last_name,
+ suffix_name=suffix_name,
+ share_first_name=1,
+ share_last_name=1,
+ email=email,
+ flag_approved=approved,
+ flag_deleted=deleted,
+ flag_banned=banned,
+ flag_edit_users=0,
+ flag_edit_system=0,
+ flag_email_verified=1,
+ share_email=8,
+ email_bouncing=0,
+ policy_class=2, # Public user. TODO: consider admin.
+ joined_date=joined_date,
+ joined_ip_num=ip_addr,
+ joined_remote_host=ip_addr
+ )
+ session.add(db_user)
+
+ # Create a username.
+ username_is_valid = 1 if _prob(90) else 0
+ username = person.username()
+ db_nick = models.TapirNickname(
+ user=db_user,
+ nickname=username,
+ flag_valid=username_is_valid,
+ flag_primary=1
+ )
+
+ # Create the user's profile.
+ archive, subject_class = _random_category()
+ db_profile = models.Demographic(
+ user=db_user,
+ country=locale,
+ affiliation=person.university(),
+ url=net.url(),
+ type=random.randint(1, 5),
+ archive=archive,
+ subject_class=subject_class,
+ original_subject_classes='',
+ flag_group_math=1 if _prob(5) else 0,
+ flag_group_cs=1 if _prob(5) else 0,
+ flag_group_nlin=1 if _prob(5) else 0,
+ flag_group_q_bio=1 if _prob(5) else 0,
+ flag_group_q_fin=1 if _prob(5) else 0,
+ flag_group_stat=1 if _prob(5) else 0
+ )
+
+ # Set the user's password.
+ password = person.password()
+ db_password = models.TapirUsersPassword(
+ user=db_user,
+ password_storage=2,
+ password_enc=hash_password(password)
+ )
+
+ # Create some endorsements.
+ archive, subject_class = _random_category()
+ net_points = 0
+ for _ in range(0, random.randint(1, 4)):
+ etype = random.choice(['auto', 'user', 'admin'])
+ point_value = random.randint(-10, 10)
+ net_points += point_value
+ if len(_users) > 0 and etype == 'auto':
+ endorser_id = random.choice(_users).user_id
+ else:
+ endorser_id = None
+ issued_when = util.epoch(
+ Datetime(locale).datetime().replace(tzinfo=EASTERN)
+ )
+ session.add(models.Endorsement(
+ endorsee=db_user,
+ endorser_id=endorser_id,
+ archive=archive,
+ subject_class=subject_class,
+ flag_valid=1,
+ type=etype,
+ point_value=point_value,
+ issued_when=issued_when
+ ))
+
+ session.add(db_password)
+ session.add(db_nick)
+ session.add(db_profile)
+ _users.append(db_user)
+ _domain_users.append((
+ domain.User(
+ user_id=str(db_user.user_id),
+ username=db_nick.nickname,
+ email=db_user.email,
+ name=domain.UserFullName(
+ forename=db_user.first_name,
+ surname=db_user.last_name,
+ suffix=db_user.suffix_name
+ ),
+ profile=domain.UserProfile.from_orm(db_profile),
+ verified=bool(db_user.flag_email_verified)
+ ),
+ domain.Authorizations(
+ classic=util.compute_capabilities(db_user),
+ )
+ ))
+ session.commit()
+ # We'll use these data to run tests.
+ cls.users.append((
+ email, username, password, name,
+ (archive, subject_class, net_points),
+ (approved, deleted, banned, username_is_valid),
+ ))
+
+
+ def test_authenticate_and_use_session(self):
+ """Attempt to authenticate users and create/load auth sessions."""
+ with self.app.app_context():
+ for datum in self.users:
+ email, username, password, name, endorsement, status = datum
+ approved, deleted, banned, username_is_valid = status
+
+
+ # Banned or deleted users may not log in.
+ if deleted or banned:
+ with self.assertRaises(exceptions.AuthenticationFailed):
+ authenticate.authenticate(email, password)
+ continue
+
+ # Users who are not approved may not log in.
+ elif not approved:
+ with self.assertRaises(exceptions.AuthenticationFailed):
+ authenticate.authenticate(email, password)
+ continue
+
+ # username not valid may not log in
+ elif not username_is_valid:
+ print( f"USERNAME_IS_VALID: {datum}")
+ with self.assertRaises(exceptions.AuthenticationFailed):
+ authenticate.authenticate(email, password)
+ continue
+
+ # Approved users may log in.
+ assert approved and not deleted and not banned and username_is_valid
+ user, auths = authenticate.authenticate(email, password)
+ self.assertIsInstance(user, domain.User,
+ "User data is returned")
+ self.assertEqual(user.email, email,
+ "Email is set correctly")
+ self.assertEqual(user.username, username,
+ "Username is set correctly")
+
+ first_name, last_name, suffix_name = name
+ self.assertEqual(user.name.forename, first_name,
+ "Forename is set correctly")
+ self.assertEqual(user.name.surname, last_name,
+ "Surname is set correctly")
+ self.assertEqual(user.name.suffix, suffix_name,
+ "Suffix is set correctly")
+ self.assertIsInstance(auths, domain.Authorizations,
+ "Authorizations data are returned")
+ # if endorsement[2] > 0:
+ # self.assertTrue(auths.endorsed_for(
+ # domain.Category(
+ # id=f'{endorsement[0]}.{endorsement[1]}',
+ # full_name="fake",
+ # is_active=False,
+ # in_archive=endorsement[0],
+ # is_general=False
+ # )
+ # ), "Endorsements are included in authorizations")
+
+ net = Internet()
+ ip = net.ip_v4()
+ session = sessions.create(auths, ip, ip, user=user)
+ cookie = sessions.generate_cookie(session)
+
+ session_loaded = sessions.load(cookie)
+ self.assertEqual(session.user, session_loaded.user,
+ "Loaded the correct user")
+ self.assertEqual(session.session_id, session_loaded.session_id,
+ "Loaded the correct session")
+
+ # Invalidate 10% of the sessions, and try again.
+ if _prob(10):
+ sessions.invalidate(cookie)
+ with self.assertRaises(exceptions.SessionExpired):
+ sessions.load(cookie)
diff --git a/arxiv/auth/legacy/tests/test_endorsement_auto.py b/arxiv/auth/legacy/tests/test_endorsement_auto.py
new file mode 100644
index 00000000..8aa57d69
--- /dev/null
+++ b/arxiv/auth/legacy/tests/test_endorsement_auto.py
@@ -0,0 +1,325 @@
+"""Tests for :mod:`arxiv.users.legacy.endorsements` using a live test DB."""
+
+import os
+from unittest import TestCase, mock
+from datetime import datetime
+from pytz import timezone, UTC
+
+from flask import Flask
+from sqlalchemy import insert
+from mimesis import Person, Internet, Datetime
+
+from arxiv.db import transaction
+from arxiv.db import models
+from arxiv.config import Settings
+from arxiv.taxonomy import definitions
+from .. import endorsements, util
+from ... import domain
+
+EASTERN = timezone('US/Eastern')
+
+
+class TestAutoEndorsement(TestCase):
+ """Tests for :func:`get_autoendorsements`."""
+
+ def setUp(self):
+ """Generate some fake data."""
+
+ self.app = Flask('test')
+ self.app.config['CLASSIC_SESSION_HASH'] = 'foohash'
+ self.app.config['CLASSIC_COOKIE_NAME'] = 'tapir_session_cookie'
+ self.app.config['SESSION_DURATION'] = '36000'
+ self.app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite://' #in memory
+ settings = Settings(
+ CLASSIC_DB_URI='sqlite:///:memory:',
+ LATEXML_DB_URI=None)
+
+ engine, _ = models.configure_db(settings)
+ self.default_tracking_data = {
+ 'remote_addr': '0.0.0.0',
+ 'remote_host': 'foo-host.foo.com',
+ 'tracking_cookie': '0'
+ }
+
+ with self.app.app_context():
+ util.create_all(engine)
+ with transaction() as session:
+ person = Person('en')
+ net = Internet()
+ ip_addr = net.ip_v4()
+ email = person.email()
+ approved = 1
+ deleted = 0
+ banned = 0
+ first_name = person.name()
+ last_name = person.surname()
+ suffix_name = person.title()
+ joined_date = util.epoch(
+ Datetime('en').datetime().replace(tzinfo=EASTERN)
+ )
+ db_user = models.TapirUser(
+ first_name=first_name,
+ last_name=last_name,
+ suffix_name=suffix_name,
+ share_first_name=1,
+ share_last_name=1,
+ email=email,
+ flag_approved=approved,
+ flag_deleted=deleted,
+ flag_banned=banned,
+ flag_edit_users=0,
+ flag_edit_system=0,
+ flag_email_verified=1,
+ share_email=8,
+ email_bouncing=0,
+ policy_class=2, # Public user. TODO: consider admin.
+ joined_date=joined_date,
+ joined_ip_num=ip_addr,
+ joined_remote_host=ip_addr
+ )
+ session.add(db_user)
+
+ self.user = domain.User(
+ user_id=str(db_user.user_id),
+ username='foouser',
+ email=db_user.email,
+ name=domain.UserFullName(
+ forename=db_user.first_name,
+ surname=db_user.last_name,
+ suffix=db_user.suffix_name
+ )
+ )
+
+
+ def test_invalidated_autoendorsements(self):
+ """The user has two autoendorsements that have been invalidated."""
+ with self.app.app_context():
+ with transaction() as session:
+ issued_when = util.epoch(
+ Datetime('en').datetime().replace(tzinfo=EASTERN)
+ )
+ session.add(models.Endorsement(
+ endorsee_id=self.user.user_id,
+ archive='astro-ph',
+ subject_class='CO',
+ flag_valid=0,
+ type='auto',
+ point_value=10,
+ issued_when=issued_when
+ ))
+ session.add(models.Endorsement(
+ endorsee_id=self.user.user_id,
+ archive='astro-ph',
+ subject_class='CO',
+ flag_valid=0,
+ type='auto',
+ point_value=10,
+ issued_when=issued_when
+ ))
+ session.add(models.Endorsement(
+ endorsee_id=self.user.user_id,
+ archive='astro-ph',
+ subject_class='CO',
+ flag_valid=1,
+ type='auto',
+ point_value=10,
+ issued_when=issued_when
+ ))
+ session.add(models.Endorsement(
+ endorsee_id=self.user.user_id,
+ archive='astro-ph',
+ subject_class='CO',
+ flag_valid=1,
+ type='user',
+ point_value=10,
+ issued_when=issued_when
+ ))
+
+ result = endorsements.invalidated_autoendorsements(self.user)
+ self.assertEqual(len(result), 2, "Two revoked endorsements are loaded")
+
+ def test_category_policies(self):
+ """Load category endorsement policies from the database."""
+ with self.app.app_context():
+ with transaction() as session:
+ session.add(models.Category(
+ archive='astro-ph',
+ subject_class='CO',
+ definitive=1,
+ active=1,
+ endorsement_domain='astro-ph'
+ ))
+ session.add(models.EndorsementDomain(
+ endorsement_domain='astro-ph',
+ endorse_all='n',
+ mods_endorse_all='n',
+ endorse_email='y',
+ papers_to_endorse=3
+ ))
+
+ policies = endorsements.category_policies()
+ category = definitions.CATEGORIES['astro-ph.CO']
+ self.assertIn(category, policies, "Data are loaded for categories")
+ self.assertEqual(policies[category]['domain'], 'astro-ph')
+ self.assertFalse(policies[category]['endorse_all'])
+ self.assertTrue(policies[category]['endorse_email'])
+ self.assertEqual(policies[category]['min_papers'], 3)
+
+ def test_domain_papers(self):
+ """Get the number of papers published in each domain."""
+ with self.app.app_context():
+ with transaction() as session:
+ # User owns three papers.
+ document1 = models.Document(
+ document_id=1,
+ title='Foo Title',
+ submitter_email='foo@bar.baz',
+ paper_id='2101.00123',
+ dated=util.epoch(datetime.now(tz=UTC))
+ )
+ session.add(document1)
+ session.add(models.PaperOwner(
+ document=document1,
+ user_id=self.user.user_id,
+ flag_author=0, # <- User is _not_ an author.
+ valid=1,
+ **self.default_tracking_data
+ ))
+ session.execute(
+ insert(models.t_arXiv_in_category)
+ .values(
+ document_id=1,
+ archive='cs',
+ subject_class='DL',
+ is_primary=1
+ )
+ )
+ session.add(models.Category(
+ archive='cs',
+ subject_class='DL',
+ definitive=1,
+ active=1,
+ endorsement_domain='firstdomain'
+ ))
+ # Here's another paper.
+ document2 = models.Document(
+ document_id=2,
+ title='Foo Title',
+ submitter_email='foo@bar.baz',
+ paper_id='2101.00124',
+ dated=util.epoch(datetime.now(tz=UTC))
+ )
+ session.add(document2)
+ session.add (models.PaperOwner(
+ document=document2,
+ user_id=self.user.user_id,
+ flag_author=1, # <- User is an author.
+ valid=1,
+ **self.default_tracking_data
+ ))
+ session.execute(
+ insert(models.t_arXiv_in_category)
+ .values(
+ document_id=2,
+ archive='cs',
+ subject_class='IR',
+ is_primary=1
+ )
+ )
+ session.add(models.Category(
+ archive='cs',
+ subject_class='IR',
+ definitive=1,
+ active=1,
+ endorsement_domain='firstdomain'
+ ))
+ # Here's a paper for which the user is an author.
+ document3 = models.Document(
+ document_id=3,
+ title='Foo Title',
+ submitter_email='foo@bar.baz',
+ paper_id='2101.00125',
+ dated=util.epoch(datetime.now(tz=UTC))
+ )
+ session.add(document3)
+ session.add (models.PaperOwner(
+ document=document3,
+ user_id=self.user.user_id,
+ flag_author=1,
+ valid=1,
+ **self.default_tracking_data
+ ))
+ # It has both a primary and a secondary classification.
+ session.execute(
+ insert(models.t_arXiv_in_category)
+ .values(
+ document_id=3,
+ archive='astro-ph',
+ subject_class='EP',
+ is_primary=1
+ )
+ )
+ session.execute(
+ insert(models.t_arXiv_in_category)
+ .values(
+ document_id=3,
+ archive='astro-ph',
+ subject_class='CO',
+ is_primary=0 # <- secondary!
+ )
+ )
+ session.add(models.Category(
+ archive='astro-ph',
+ subject_class='EP',
+ definitive=1,
+ active=1,
+ endorsement_domain='seconddomain'
+ ))
+ session.add(models.Category(
+ archive='astro-ph',
+ subject_class='CO',
+ definitive=1,
+ active=1,
+ endorsement_domain='seconddomain'
+ ))
+ papers = endorsements.domain_papers(self.user)
+ self.assertEqual(papers['firstdomain'], 2)
+ self.assertEqual(papers['seconddomain'], 2)
+
+ def test_is_academic(self):
+ """Determine whether a user is academic based on email."""
+ ok_patterns = ['%w3.org', '%aaas.org', '%agu.org', '%ams.org']
+ bad_patterns = ['%.com', '%.net', '%.biz.%']
+ with self.app.app_context():
+ with transaction() as session:
+ for pattern in ok_patterns:
+ session.execute(insert(
+ models.t_arXiv_white_email)
+ .values(pattern=str(pattern))
+ )
+ for pattern in bad_patterns:
+ session.execute(insert(
+ models.t_arXiv_black_email)
+ .values(pattern=str(pattern))
+ )
+
+ self.assertTrue(endorsements.is_academic(domain.User(
+ user_id='2',
+ email='someone@fsu.edu',
+ username='someone'
+ )))
+ self.assertFalse(endorsements.is_academic(domain.User(
+ user_id='2',
+ email='someone@fsu.biz.edu',
+ username='someone'
+ )))
+ self.assertTrue(endorsements.is_academic(domain.User(
+ user_id='2',
+ email='someone@aaas.org',
+ username='someone'
+ )))
+ self.assertFalse(endorsements.is_academic(domain.User(
+ user_id='2',
+ email='someone@foo.com',
+ username='someone'
+ )))
diff --git a/arxiv/auth/legacy/tests/test_endorsements.py b/arxiv/auth/legacy/tests/test_endorsements.py
new file mode 100644
index 00000000..96abe3e3
--- /dev/null
+++ b/arxiv/auth/legacy/tests/test_endorsements.py
@@ -0,0 +1,463 @@
+"""Tests for :mod:`arxiv.users.legacy.endorsements` using a live test DB."""
+
+import os
+from unittest import TestCase, mock
+from datetime import datetime
+from pytz import timezone, UTC
+
+from flask import Flask
+from mimesis import Person, Internet, Datetime
+from sqlalchemy import insert
+
+from arxiv.taxonomy import definitions
+from arxiv.config import Settings
+from arxiv.db import models, transaction
+from .. import endorsements, util
+from ... import domain
+
+EASTERN = timezone('US/Eastern')
+
+
+class TestEndorsement(TestCase):
+ """Tests for :func:`get_endorsements`."""
+
+ def setUp(self):
+ """Generate some fake data."""
+ self.app = Flask('test')
+ self.app.config['CLASSIC_DATABASE_URI'] = 'sqlite:///test.db'
+ self.app.config['CLASSIC_SESSION_HASH'] = 'foohash'
+ settings = Settings(
+ CLASSIC_DB_URI='sqlite:///test.db',
+ LATEXML_DB_URI=None)
+
+ engine, _ = models.configure_db(settings)
+
+ self.default_tracking_data = {
+ 'remote_addr': '0.0.0.0',
+ 'remote_host': 'foo-host.foo.com',
+ 'tracking_cookie': '0'
+ }
+
+ with self.app.app_context():
+ util.create_all(engine)
+ with transaction() as session:
+ person = Person('en')
+ net = Internet()
+ ip_addr = net.ip_v4()
+ email = "foouser@agu.org"
+ approved = 1
+ deleted = 0
+ banned = 0
+ first_name = person.name()
+ last_name = person.surname()
+ suffix_name = person.title()
+ joined_date = util.epoch(
+ Datetime('en').datetime().replace(tzinfo=EASTERN)
+ )
+ db_user = models.TapirUser(
+ first_name=first_name,
+ last_name=last_name,
+ suffix_name=suffix_name,
+ share_first_name=1,
+ share_last_name=1,
+ email=email,
+ flag_approved=approved,
+ flag_deleted=deleted,
+ flag_banned=banned,
+ flag_edit_users=0,
+ flag_edit_system=0,
+ flag_email_verified=1,
+ share_email=8,
+ email_bouncing=0,
+ policy_class=2, # Public user. TODO: consider admin.
+ joined_date=joined_date,
+ joined_ip_num=ip_addr,
+ joined_remote_host=ip_addr
+ )
+ session.add(db_user)
+
+ self.user = domain.User(
+ user_id=str(db_user.user_id),
+ username='foouser',
+ email=db_user.email,
+ name=domain.UserFullName(
+ forename=db_user.first_name,
+ surname=db_user.last_name,
+ suffix=db_user.suffix_name
+ )
+ )
+
+ ok_patterns = ['%w3.org', '%aaas.org', '%agu.org', '%ams.org']
+ bad_patterns = ['%.com', '%.net', '%.biz.%']
+
+ with transaction() as session:
+ for pattern in ok_patterns:
+ session.execute(
+ insert(models.t_arXiv_white_email)
+ .values(pattern=str(pattern))
+ )
+ for pattern in bad_patterns:
+ session.execute(
+ insert(models.t_arXiv_black_email)
+ .values(pattern=str(pattern))
+ )
+
+ session.add(models.EndorsementDomain(
+ endorsement_domain='test_domain',
+ endorse_all='n',
+ mods_endorse_all='n',
+ endorse_email='y',
+ papers_to_endorse=3
+ ))
+
+ # for category, definition in definitions.CATEGORIES_ACTIVE.items():
+ # if '.' in category:
+ # archive, subject_class = category.split('.', 1)
+ # else:
+ # archive, subject_class = category, ''
+ # session.add(models.Category(
+ # archive=archive,
+ # subject_class=subject_class,
+ # definitive=1,
+ # active=1,
+ # endorsement_domain='test_domain'
+ # ))
+
+ def test_get_endorsements(self):
+ """Test :func:`endoresement.get_endorsements`."""
+ with self.app.app_context():
+ with transaction() as session:
+ for category, definition in definitions.CATEGORIES_ACTIVE.items():
+ if '.' in category:
+ archive, subject_class = category.split('.', 1)
+ else:
+ archive, subject_class = category, ''
+ session.add(models.Category(
+ archive=archive,
+ subject_class=subject_class,
+ definitive=1,
+ active=1,
+ endorsement_domain='test_domain'
+ ))
+ all_endorsements = set(
+ endorsements.get_endorsements(self.user)
+ )
+ all_possible = set(definitions.CATEGORIES_ACTIVE.values())
+ self.assertEqual(all_endorsements, all_possible)
+
+ def tearDown(self):
+ """Remove the test DB."""
+ try:
+ os.remove('./test.db')
+ except FileNotFoundError:
+ pass
+
+
+class TestAutoEndorsement(TestCase):
+ """Tests for :func:`get_autoendorsements`."""
+
+ def setUp(self):
+ """Generate some fake data."""
+ self.app = Flask('test')
+ self.app.config['CLASSIC_DATABASE_URI'] = 'sqlite:///test.db'
+ self.app.config['CLASSIC_SESSION_HASH'] = 'foohash'
+ settings = Settings(
+ CLASSIC_DB_URI='sqlite:///test.db',
+ LATEXML_DB_URI=None)
+
+ engine, _ = models.configure_db(settings)
+
+ self.default_tracking_data = {
+ 'remote_addr': '0.0.0.0',
+ 'remote_host': 'foo-host.foo.com',
+ 'tracking_cookie': '0'
+ }
+
+ with self.app.app_context():
+ util.create_all(engine)
+ with transaction() as session:
+ person = Person('en')
+ net = Internet()
+ ip_addr = net.ip_v4()
+ email = person.email()
+ approved = 1
+ deleted = 0
+ banned = 0
+ first_name = person.name()
+ last_name = person.surname()
+ suffix_name = person.title()
+ joined_date = util.epoch(
+ Datetime('en').datetime().replace(tzinfo=EASTERN)
+ )
+ db_user = models.TapirUser(
+ first_name=first_name,
+ last_name=last_name,
+ suffix_name=suffix_name,
+ share_first_name=1,
+ share_last_name=1,
+ email=email,
+ flag_approved=approved,
+ flag_deleted=deleted,
+ flag_banned=banned,
+ flag_edit_users=0,
+ flag_edit_system=0,
+ flag_email_verified=1,
+ share_email=8,
+ email_bouncing=0,
+ policy_class=2, # Public user. TODO: consider admin.
+ joined_date=joined_date,
+ joined_ip_num=ip_addr,
+ joined_remote_host=ip_addr
+ )
+ session.add(db_user)
+
+ self.user = domain.User(
+ user_id=str(db_user.user_id),
+ username='foouser',
+ email=db_user.email,
+ name=domain.UserFullName(
+ forename=db_user.first_name,
+ surname=db_user.last_name,
+ suffix=db_user.suffix_name
+ )
+ )
+
+ def tearDown(self):
+ """Remove the test DB."""
+ try:
+ os.remove('./test.db')
+ except FileNotFoundError:
+ pass
+
+ def test_invalidated_autoendorsements(self):
+ """The user has two autoendorsements that have been invalidated."""
+ with self.app.app_context():
+ with transaction() as session:
+ issued_when = util.epoch(
+ Datetime('en').datetime().replace(tzinfo=EASTERN)
+ )
+ session.add(models.Endorsement(
+ endorsee_id=self.user.user_id,
+ archive='astro-ph',
+ subject_class='CO',
+ flag_valid=0,
+ type='auto',
+ point_value=10,
+ issued_when=issued_when
+ ))
+ session.add(models.Endorsement(
+ endorsee_id=self.user.user_id,
+ archive='astro-ph',
+ subject_class='CO',
+ flag_valid=0,
+ type='auto',
+ point_value=10,
+ issued_when=issued_when
+ ))
+ session.add(models.Endorsement(
+ endorsee_id=self.user.user_id,
+ archive='astro-ph',
+ subject_class='CO',
+ flag_valid=1,
+ type='auto',
+ point_value=10,
+ issued_when=issued_when
+ ))
+ session.add(models.Endorsement(
+ endorsee_id=self.user.user_id,
+ archive='astro-ph',
+ subject_class='CO',
+ flag_valid=1,
+ type='user',
+ point_value=10,
+ issued_when=issued_when
+ ))
+
+ result = endorsements.invalidated_autoendorsements(self.user)
+ self.assertEqual(len(result), 2, "Two revoked endorsements are loaded")
+
+ def test_category_policies(self):
+ """Load category endorsement policies from the database."""
+ with self.app.app_context():
+ with transaction() as session:
+ session.add(models.Category(
+ archive='astro-ph',
+ subject_class='CO',
+ definitive=1,
+ active=1,
+ endorsement_domain='astro-ph'
+ ))
+ session.add(models.EndorsementDomain(
+ endorsement_domain='astro-ph',
+ endorse_all='n',
+ mods_endorse_all='n',
+ endorse_email='y',
+ papers_to_endorse=3
+ ))
+
+ policies = endorsements.category_policies()
+ category = definitions.CATEGORIES['astro-ph.CO']
+ self.assertIn(category, policies, "Data are loaded for categories")
+ self.assertEqual(policies[category]['domain'], 'astro-ph')
+ self.assertFalse(policies[category]['endorse_all'])
+ self.assertTrue(policies[category]['endorse_email'])
+ self.assertEqual(policies[category]['min_papers'], 3)
+
+ def test_domain_papers(self):
+ """Get the number of papers published in each domain."""
+ with self.app.app_context():
+ with transaction() as session:
+ # User owns three papers.
+ document1 = models.Document(
+ document_id=1,
+ title='Foo Title',
+ submitter_email='foo@bar.baz',
+ paper_id='2101.00123',
+ dated=util.epoch(datetime.now(tz=UTC))
+ )
+ session.add(document1)
+ session.add(models.PaperOwner(
+ document=document1,
+ user_id=self.user.user_id,
+ flag_author=0,
+ valid=1,
+ **self.default_tracking_data
+ ))
+ session.execute(
+ insert(models.t_arXiv_in_category)
+ .values(
+ document_id=1,
+ archive='cs',
+ subject_class='DL',
+ is_primary=1
+ )
+ )
+ session.add(models.Category(
+ archive='cs',
+ subject_class='DL',
+ definitive=1,
+ active=1,
+ endorsement_domain='firstdomain'
+ ))
+ # Here's another paper.
+ document2 = models.Document(
+ document_id=2,
+ title='Foo Title',
+ submitter_email='foo@bar.baz',
+ paper_id='2101.00124',
+ dated=util.epoch(datetime.now(tz=UTC))
+ )
+ session.add(document2)
+ session.add(models.PaperOwner(
+ document=document2,
+ user_id=self.user.user_id,
+ flag_author=1,
+ valid=1,
+ **self.default_tracking_data
+ ))
+ session.execute(
+ insert(models.t_arXiv_in_category)
+ .values(
+ document_id=2,
+ archive='cs',
+ subject_class='IR',
+ is_primary=1
+ )
+ )
+ session.add(models.Category(
+ archive='cs',
+ subject_class='IR',
+ definitive=1,
+ active=1,
+ endorsement_domain='firstdomain'
+ ))
+ # Here's a paper for which the user is an author.
+ document3 = models.Document(
+ document_id=3,
+ title='Foo Title',
+ submitter_email='foo@bar.baz',
+ paper_id='2101.00125',
+ dated=util.epoch(datetime.now(tz=UTC))
+ )
+ session.add(document3)
+ session.add(models.PaperOwner(
+ document=document3,
+ user_id=self.user.user_id,
+ flag_author=1,
+ valid=1,
+ **self.default_tracking_data
+ ))
+ # It has both a primary and a secondary classification.
+ session.execute(
+ insert(models.t_arXiv_in_category)
+ .values(
+ document_id=3,
+ archive='astro-ph',
+ subject_class='EP',
+ is_primary=1
+ )
+ )
+ session.execute(
+ insert(models.t_arXiv_in_category)
+ .values(
+ document_id=3,
+ archive='astro-ph',
+ subject_class='CO',
+ is_primary=0 # <- secondary!
+ )
+ )
+ session.add(models.Category(
+ archive='astro-ph',
+ subject_class='EP',
+ definitive=1,
+ active=1,
+ endorsement_domain='seconddomain'
+ ))
+ session.add(models.Category(
+ archive='astro-ph',
+ subject_class='CO',
+ definitive=1,
+ active=1,
+ endorsement_domain='seconddomain'
+ ))
+ papers = endorsements.domain_papers(self.user)
+ self.assertEqual(papers['firstdomain'], 2)
+ self.assertEqual(papers['seconddomain'], 2)
+
+ def test_is_academic(self):
+ """Determine whether a user is academic based on email."""
+ ok_patterns = ['%w3.org', '%aaas.org', '%agu.org', '%ams.org']
+ bad_patterns = ['%.com', '%.net', '%.biz.%']
+ with self.app.app_context():
+ with transaction() as session:
+ for pattern in ok_patterns:
+ session.execute(
+ insert(models.t_arXiv_white_email)
+ .values(pattern=str(pattern))
+ )
+ for pattern in bad_patterns:
+ session.execute(
+ insert(models.t_arXiv_black_email)
+ .values(pattern=str(pattern))
+ )
+
+ self.assertTrue(endorsements.is_academic(domain.User(
+ user_id='2',
+ email='someone@fsu.edu',
+ username='someone'
+ )))
+ self.assertFalse(endorsements.is_academic(domain.User(
+ user_id='2',
+ email='someone@fsu.biz.edu',
+ username='someone'
+ )))
+ self.assertTrue(endorsements.is_academic(domain.User(
+ user_id='2',
+ email='someone@aaas.org',
+ username='someone'
+ )))
+ self.assertFalse(endorsements.is_academic(domain.User(
+ user_id='2',
+ email='someone@foo.com',
+ username='someone'
+ )))
diff --git a/arxiv/auth/legacy/tests/test_passwords.py b/arxiv/auth/legacy/tests/test_passwords.py
new file mode 100644
index 00000000..9c8a97fe
--- /dev/null
+++ b/arxiv/auth/legacy/tests/test_passwords.py
@@ -0,0 +1,30 @@
+"""Tests for :mod:`legacy_users.passwords`."""
+
+from unittest import TestCase
+from ..exceptions import PasswordAuthenticationFailed
+from ..passwords import hash_password, check_password
+
+from hypothesis import given, settings
+from hypothesis import strategies as st
+import string
+
+
+class TestCheckPassword(TestCase):
+ """Tests passwords."""
+ @given(st.text(alphabet=string.printable))
+ @settings(max_examples=500)
+ def test_check_passwords_successful(self, passw):
+ encrypted = hash_password(passw)
+ self.assertTrue( check_password(passw, encrypted.encode('ascii')),
+ f"should work for password '{passw}'")
+
+ @given(st.text(alphabet=string.printable), st.text(alphabet=st.characters()))
+ @settings(max_examples=5000)
+ def test_check_passwords_fuzz(self, passw, fuzzpw):
+ if passw == fuzzpw:
+ self.assertTrue(check_password(fuzzpw,
+ hash_password(passw).encode('ascii')))
+ else:
+ with self.assertRaises(PasswordAuthenticationFailed):
+ check_password(fuzzpw,
+ hash_password(passw).encode('ascii'))
diff --git a/arxiv/auth/legacy/tests/test_sessions.py b/arxiv/auth/legacy/tests/test_sessions.py
new file mode 100644
index 00000000..90111145
--- /dev/null
+++ b/arxiv/auth/legacy/tests/test_sessions.py
@@ -0,0 +1,104 @@
+"""Tests for legacy_users service."""
+import time
+from unittest import mock, TestCase
+from datetime import datetime
+from pytz import timezone, UTC
+
+from arxiv.db import models, transaction
+from .. import exceptions, sessions, util, cookies
+
+from .util import temporary_db
+
+EASTERN = timezone('US/Eastern')
+
+
+class TestCreateSession(TestCase):
+ """Tests for public function :func:`.`."""
+
+ @mock.patch(f'{sessions.__name__}.util.get_session_duration')
+ def test_create(self, mock_get_session_duration):
+ """Accept a :class:`.User` and returns a :class:`.Session`."""
+ mock_get_session_duration.return_value = 36000
+ user = sessions.domain.User(
+ user_id="1",
+ username='theuser',
+ email='the@user.com',
+ )
+ auths = sessions.domain.Authorizations(classic=6)
+ ip_address = '127.0.0.1'
+ remote_host = 'foo-host.foo.com'
+ tracking = "1.foo"
+ with temporary_db('sqlite:///:memory:', create=True):
+ user_session = sessions.create(auths, ip_address, remote_host,
+ tracking, user=user)
+ self.assertIsInstance(user_session, sessions.domain.Session)
+ tapir_session = sessions._load(user_session.session_id)
+ self.assertIsNotNone(user_session, 'verifying we have a session')
+ if tapir_session is not None:
+ self.assertEqual(
+ tapir_session.session_id,
+ int(user_session.session_id),
+ "Returned session has correct session id."
+ )
+ self.assertEqual(tapir_session.user_id, int(user.user_id),
+ "Returned session has correct user id.")
+ self.assertEqual(tapir_session.end_time, 0,
+ "End time is 0 (no end time)")
+
+ tapir_session_audit = sessions._load_audit(user_session.session_id)
+ self.assertIsNotNone(tapir_session_audit)
+ if tapir_session_audit is not None:
+ self.assertEqual(
+ tapir_session_audit.session_id,
+ int(user_session.session_id),
+ "Returned session audit has correct session id."
+ )
+ self.assertEqual(
+ tapir_session_audit.ip_addr,
+ user_session.ip_address,
+ "Returned session audit has correct ip address"
+ )
+ self.assertEqual(
+ tapir_session_audit.remote_host,
+ user_session.remote_host,
+ "Returned session audit has correct remote host"
+ )
+
+
+class TestInvalidateSession(TestCase):
+ """Tests for public function :func:`.invalidate`."""
+
+ @mock.patch(f'{cookies.__name__}.util.get_session_duration')
+ def test_invalidate(self, mock_get_duration):
+ """The session is invalidated by setting `end_time`."""
+ mock_get_duration.return_value = 36000
+ session_id = "424242424"
+ user_id = "12345"
+ ip = "127.0.0.1"
+ capabilities = 6
+ start = datetime.now(tz=UTC)
+
+ with temporary_db('sqlite:///:memory:') as db_session:
+ cookie = cookies.pack(session_id, user_id, ip, start, capabilities)
+ with transaction() as db_session:
+ tapir_session = models.TapirSession(
+ session_id=session_id,
+ user_id=12345,
+ last_reissue=util.epoch(start),
+ start_time=util.epoch(start),
+ end_time=0
+ )
+ db_session.add(tapir_session)
+
+ sessions.invalidate(cookie)
+ tapir_session = sessions._load(session_id)
+ time.sleep(1)
+ self.assertGreaterEqual(util.now(), tapir_session.end_time)
+
+ @mock.patch(f'{cookies.__name__}.util.get_session_duration')
+ def test_invalidate_nonexistant_session(self, mock_get_duration):
+ """An exception is raised if the session doesn't exist."""
+ mock_get_duration.return_value = 36000
+ with temporary_db('sqlite:///:memory:'):
+ with self.assertRaises(exceptions.UnknownSession):
+ sessions.invalidate('1:1:10.10.10.10:1531145500:4')
diff --git a/arxiv/auth/legacy/tests/test_util.py b/arxiv/auth/legacy/tests/test_util.py
new file mode 100644
index 00000000..1ff69488
--- /dev/null
+++ b/arxiv/auth/legacy/tests/test_util.py
@@ -0,0 +1,32 @@
+"""Tests for :mod:`legacy_users.util`."""
+
+from unittest import TestCase
+from arxiv.db import models
+from .util import temporary_db
+from .. import util, sessions
+
+class TestGetSession(TestCase):
+ """
+ Tests for private function :func:`._load`.
+
+ Gets a :class:`.TapirSession` given a session ID.
+ """
+
+ def test_load_returns_a_session(self) -> None:
+ """If ID matches a known session, returns a :class:`.TapirSession`."""
+ session_id = "424242424"
+ with temporary_db('sqlite:///:memory:') as db_session:
+ start = util.now()
+ db_session.add(models.TapirSession(
+ session_id=session_id,
+ user_id=12345,
+ last_reissue=start,
+ start_time=start,
+ end_time=0
+ ))
+ db_session.commit()
+ tapir_session = sessions._load(session_id)
+ self.assertIsNotNone(tapir_session, 'verifying we have a session')
+ self.assertEqual(tapir_session.session_id, int(session_id),
+ "Returned session has correct session id.")
+
diff --git a/arxiv/auth/legacy/tests/util.py b/arxiv/auth/legacy/tests/util.py
new file mode 100644
index 00000000..eab847fd
--- /dev/null
+++ b/arxiv/auth/legacy/tests/util.py
@@ -0,0 +1,33 @@
+"""Testing helpers."""
+from contextlib import contextmanager
+
+from flask import Flask
+
+from arxiv.config import Settings
+from arxiv.db import transaction
+from arxiv.db.models import configure_db
+from .. import util
+
+
+@contextmanager
+def temporary_db(db_uri: str, create: bool = True, drop: bool = True):
+ """Provide an in-memory sqlite database for testing purposes."""
+ app = Flask('foo')
+ app.config['CLASSIC_SESSION_HASH'] = 'foohash'
+ app.config['SESSION_DURATION'] = 3600
+ app.config['CLASSIC_COOKIE_NAME'] = 'tapir_session'
+ settings = Settings(
+ CLASSIC_DB_URI=db_uri,
+ LATEXML_DB_URI=None)
+
+ engine, _ = configure_db (settings)
+
+ with app.app_context():
+ if create:
+ util.create_all(engine)
+ try:
+ with transaction() as session:
+ yield session
+ finally:
+ if drop:
+ util.drop_all(engine)
diff --git a/arxiv/auth/legacy/util.py b/arxiv/auth/legacy/util.py
new file mode 100644
index 00000000..c718927f
--- /dev/null
+++ b/arxiv/auth/legacy/util.py
@@ -0,0 +1,103 @@
+"""Helpers and Flask application integration."""
+
+from typing import List, Any
+from datetime import datetime
+from pytz import timezone, UTC
+import logging
+
+from sqlalchemy import text, Engine
+
+from ...base.globals import get_application_config
+
+from ..auth import scopes
+from .. import domain
+from ...db import session, Base, SessionLocal
+from ...db.models import TapirUser, TapirPolicyClass
+
+EASTERN = timezone('US/Eastern')
+logger = logging.getLogger(__name__)
+logger.propagate = False
+
+
+def now() -> int:
+ """Get the current epoch/unix time."""
+ return epoch(datetime.now(tz=UTC))
+
+
+def epoch(t: datetime) -> int:
+ """Convert a :class:`.datetime` to UNIX time."""
+ delta = t - datetime.fromtimestamp(0, tz=EASTERN)
+ return int(round((delta).total_seconds()))
+
+
+def from_epoch(t: int) -> datetime:
+ """Get a :class:`datetime` from an UNIX timestamp."""
+ return datetime.fromtimestamp(t, tz=EASTERN)
+
+
+def create_all(engine: Engine) -> None:
+ """Create all tables in the database."""
+ Base.metadata.create_all(engine)
+ with SessionLocal() as session:
+ data = session.query(TapirPolicyClass).all()
+ if data:
+ return
+
+ for datum in TapirPolicyClass.POLICY_CLASSES:
+ session.add(TapirPolicyClass(**datum))
+ session.commit()
+
+def drop_all(engine: Engine) -> None:
+ """Drop all tables in the database."""
+ Base.metadata.drop_all(engine)
+
+
+def compute_capabilities(tapir_user: TapirUser) -> int:
+ """Calculate the privilege level code for a user."""
+ return int(sum([2 * tapir_user.flag_edit_users,
+ 4 * tapir_user.flag_email_verified,
+ 8 * tapir_user.flag_edit_system]))
+
+
+def get_scopes(db_user: TapirUser) -> List[domain.Scope]:
+ """Generate a list of authz scopes for a legacy user based on class."""
+ if db_user.policy_class == TapirPolicyClass.PUBLIC_USER:
+ return scopes.GENERAL_USER
+ if db_user.policy_class == TapirPolicyClass.ADMIN:
+ return scopes.ADMIN_USER
+ return []
+
+
+def is_configured() -> bool:
+ """Determine whether or not the legacy auth is configured of the `Flask` app."""
+ config = get_application_config()
+ return not bool(missing_configs(config))
+
+def missing_configs(config) -> List[str]:
+ """Returns missing keys for configs needed in `Flask.config` for legacy auth to work."""
+ missing = [key for key in ['CLASSIC_SESSION_HASH', 'SESSION_DURATION', 'CLASSIC_COOKIE_NAME']
+ if key not in config]
+ return missing
+
+def get_session_hash() -> str:
+ """Get the legacy hash secret."""
+ config = get_application_config()
+ session_hash: str = config['CLASSIC_SESSION_HASH']
+ return session_hash
+
+
+def get_session_duration() -> int:
+ """Get the session duration from the config."""
+ config = get_application_config()
+ timeout: str = config['SESSION_DURATION']
+ return int(timeout)
+
+
+def is_available(**kwargs: Any) -> bool:
+ """Check our connection to the database."""
+ try:
+ session.query("1").from_statement(text("SELECT 1")).all()
+ except Exception as e:
+ logger.error('Encountered an error talking to database: %s', e)
+ return False
+ return True
diff --git a/arxiv/auth/tests/__init__.py b/arxiv/auth/tests/__init__.py
new file mode 100644
index 00000000..1f835616
--- /dev/null
+++ b/arxiv/auth/tests/__init__.py
@@ -0,0 +1 @@
+"""Package-level tests for :mod:`arxiv.users`."""
diff --git a/arxiv/auth/tests/test_domain.py b/arxiv/auth/tests/test_domain.py
new file mode 100644
index 00000000..f34dd0a1
--- /dev/null
+++ b/arxiv/auth/tests/test_domain.py
@@ -0,0 +1,49 @@
+"""Tests for :mod:`arxiv.users.domain`."""
+
+from unittest import TestCase
+from datetime import datetime
+from arxiv.taxonomy import definitions
+from pytz import timezone
+from ..auth import scopes
+from ..auth import domain
+
+EASTERN = timezone('US/Eastern')
+
+
+class TestSession(TestCase):
+ def test_with_session(self):
+ session = domain.Session(
+ session_id='asdf1234',
+ start_time=datetime.now(), end_time=datetime.now(),
+ user=domain.User(
+ user_id='12345',
+ email='foo@bar.com',
+ username='emanresu',
+ name=domain.UserFullName(forename='First', surname='Last', suffix='Lastest'),
+ profile=domain.UserProfile(
+ affiliation='FSU',
+ rank=3,
+ country='us',
+ default_category=definitions.CATEGORIES['astro-ph.CO'],
+ submission_groups=['grp_physics']
+ )
+ ),
+ authorizations=domain.Authorizations(
+ scopes=[scopes.VIEW_SUBMISSION, scopes.CREATE_SUBMISSION],
+ )
+ )
+ session_data = session.dict()
+ self.assertEqual(session_data['authorizations']['scopes'],
+ ['submission:read','submission:create'])
+
+ self.assertEqual(session_data['user']['profile']['affiliation'], 'FSU')
+ self.assertEqual(session_data['user']['profile']['country'], 'us')
+ self.assertEqual(session_data['user']['profile']['submission_groups'], ['grp_physics'])
+ self.assertEqual(session_data['user']['profile']['default_category']['id'], 'astro-ph.CO')
+ self.assertEqual(
+ session_data['user']['name'],
+ {'forename': 'First', 'surname': 'Last', 'suffix': 'Lastest'}
+ )
+
+ as_session = domain.session_from_dict(session_data)
+ self.assertEqual(session, as_session)
diff --git a/arxiv/auth/tests/test_helpers.py b/arxiv/auth/tests/test_helpers.py
new file mode 100644
index 00000000..8e58ccaa
--- /dev/null
+++ b/arxiv/auth/tests/test_helpers.py
@@ -0,0 +1,63 @@
+"""Tests for :mod:`.helpers`."""
+
+from unittest import TestCase, mock
+import os
+import logging
+
+from flask import Flask
+
+from arxiv import status
+from arxiv.config import Settings
+from arxiv.db.models import configure_db
+from arxiv.base import Base
+from arxiv.base.middleware import wrap
+
+from .. import auth, helpers, legacy
+from ..auth import decorators
+from ..auth.middleware import AuthMiddleware
+from ..auth.scopes import VIEW_SUBMISSION, CREATE_SUBMISSION, EDIT_SUBMISSION
+
+
+class TestGenerateToken(TestCase):
+ """Tests for :func:`.helpers.generate_token`."""
+
+ def test_token_is_usable(self):
+ """Verify that :func:`.helpers.generate_token` makes usable tokens."""
+ os.environ['JWT_SECRET'] = 'thesecret'
+ scope = [VIEW_SUBMISSION, EDIT_SUBMISSION, CREATE_SUBMISSION]
+ token = helpers.generate_token("1234", "user@foo.com", "theuser",
+ scope=scope)
+
+ app = Flask('test')
+ app.config['CLASSIC_SESSION_HASH'] = 'foohash'
+ app.config['CLASSIC_COOKIE_NAME'] = 'tapir_session_cookie'
+ app.config['SESSION_DURATION'] = '36000'
+ app.config['AUTH_UPDATED_SESSION_REF'] = True
+
+ app.config.update({
+ 'JWT_SECRET': 'thesecret',
+ })
+
+ settings = Settings(
+ CLASSIC_DB_URI='sqlite:///:memory:',
+ LATEXML_DB_URI=None)
+
+ configure_db (settings)
+
+ Base(app)
+ auth.Auth(app)
+ wrap(app, [AuthMiddleware])
+
+ @app.route('/')
+ @decorators.scoped(EDIT_SUBMISSION)
+ def protected():
+ return "this is protected"
+
+ client = app.test_client()
+ with app.app_context():
+ response = client.get('/')
+ self.assertEqual(response.status_code,
+ status.HTTP_401_UNAUTHORIZED)
+
+ response = client.get('/', headers={'Authorization': token})
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
diff --git a/arxiv/base/__init__.py b/arxiv/base/__init__.py
index 64da60c4..0e99b73e 100644
--- a/arxiv/base/__init__.py
+++ b/arxiv/base/__init__.py
@@ -129,4 +129,6 @@ def register_blueprint(self: Flask, blueprint: Blueprint,
# See: https://github.com/pallets-eco/flask-sqlalchemy/blob/42a36a3cb604fd39d81d00b54ab3988bbd0ad184/src/flask_sqlalchemy/session.py#L109
@app.teardown_appcontext
def remove_scoped_session (response_or_exc: BaseException | None) -> None:
+ if response_or_exc:
+ session.rollback()
session.remove()
diff --git a/arxiv/config/__init__.py b/arxiv/config/__init__.py
index e5a0fc01..ee038ea0 100644
--- a/arxiv/config/__init__.py
+++ b/arxiv/config/__init__.py
@@ -170,7 +170,7 @@ class Settings(BaseSettings):
TRACKBACK_SECRET: SecretStr = SecretStr(token_hex(10))
CLASSIC_DB_URI: str = DEFAULT_DB
- LATEXML_DB_URI: str = DEFAULT_LATEXML_DB
+ LATEXML_DB_URI: Optional[str] = DEFAULT_LATEXML_DB
ECHO_SQL: bool = False
CLASSIC_DB_TRANSACTION_ISOLATION_LEVEL: Optional[IsolationLevel] = None
LATEXML_DB_TRANSACTION_ISOLATION_LEVEL: Optional[IsolationLevel] = None
diff --git a/arxiv/db/__init__.py b/arxiv/db/__init__.py
index 05763471..7bbc0072 100644
--- a/arxiv/db/__init__.py
+++ b/arxiv/db/__init__.py
@@ -40,7 +40,7 @@
from flask.globals import app_ctx
from flask import has_app_context
-from sqlalchemy import create_engine, MetaData
+from sqlalchemy import Engine, MetaData
from sqlalchemy.event import listens_for
from sqlalchemy.orm import sessionmaker, scoped_session, DeclarativeBase
@@ -56,18 +56,7 @@ class LaTeXMLBase(DeclarativeBase):
logger = logging.getLogger(__name__)
-engine = create_engine(settings.CLASSIC_DB_URI,
- echo=settings.ECHO_SQL,
- isolation_level=settings.CLASSIC_DB_TRANSACTION_ISOLATION_LEVEL,
- pool_recycle=600,
- max_overflow=(settings.REQUEST_CONCURRENCY - 5), # max overflow is how many + base pool size, which is 5 by default
- pool_pre_ping=settings.POOL_PRE_PING)
-latexml_engine = create_engine(settings.LATEXML_DB_URI,
- echo=settings.ECHO_SQL,
- isolation_level=settings.LATEXML_DB_TRANSACTION_ISOLATION_LEVEL,
- pool_recycle=600,
- max_overflow=(settings.REQUEST_CONCURRENCY - 5),
- pool_pre_ping=settings.POOL_PRE_PING)
+
SessionLocal = sessionmaker(autocommit=False, autoflush=False)
def _app_ctx_id () -> int:
@@ -75,6 +64,9 @@ def _app_ctx_id () -> int:
session = scoped_session(SessionLocal, scopefunc=_app_ctx_id)
+def get_engine () -> Engine:
+ return SessionLocal().get_bind(Base)
+
@contextmanager
def get_db ():
db = SessionLocal()
@@ -93,19 +85,20 @@ def transaction ():
if db.new or db.dirty or db.deleted:
db.commit()
except Exception as e:
- logger.warn(f'Commit failed, rolling back', exc_info=1)
+ logger.warning(f'Commit failed, rolling back', exc_info=1)
db.rollback()
+ raise
finally:
if not in_flask:
db.close()
-def config_query_timing( slightly_long_sec: float, long_sec: float):
- @listens_for(engine, "before_cursor_execute")
+def config_query_timing(slightly_long_sec: float, long_sec: float):
+ @listens_for(get_engine(), "before_cursor_execute")
def _record_query_start (conn, cursor, statement, parameters, context, executemany):
conn.info['query_start'] = datetime.now()
- @listens_for(engine, "after_cursor_execute")
+ @listens_for(get_engine(), "after_cursor_execute")
def _calculate_query_run_time (conn, cursor, statement, parameters, context, executemany):
if conn.info.get('query_start'):
delta: timedelta = (datetime.now() - conn.info['query_start'])
diff --git a/arxiv/db/models.py b/arxiv/db/models.py
index 8259daa6..a8b8ac09 100644
--- a/arxiv/db/models.py
+++ b/arxiv/db/models.py
@@ -3,7 +3,7 @@
This file was generated by using sqlacodgen
"""
-from typing import Optional, Literal, Any
+from typing import Optional, Literal, Any, Tuple, List
import re
import os
import hashlib
@@ -30,7 +30,9 @@
Text,
Table,
Enum,
- text
+ text,
+ create_engine,
+ Engine
)
from sqlalchemy.schema import FetchedValue
from sqlalchemy.orm import (
@@ -39,15 +41,67 @@
relationship
)
-from ..config import settings
+from ..config import Settings, settings
from . import Base, LaTeXMLBase, metadata, \
- SessionLocal, engine, latexml_engine
+ SessionLocal
from .types import intpk
tb_secret = settings.TRACKBACK_SECRET
tz = gettz(settings.ARXIV_BUSINESS_TZ)
+def configure_db (base_settings: Settings) -> Tuple[Engine, Optional[Engine]]:
+ if 'sqlite' in base_settings.CLASSIC_DB_URI:
+ engine = create_engine(base_settings.CLASSIC_DB_URI)
+ if base_settings.LATEXML_DB_URI:
+ latexml_engine = create_engine(base_settings.LATEXML_DB_URI)
+ else:
+ latexml_engine = None
+ else:
+ engine = create_engine(base_settings.CLASSIC_DB_URI,
+ echo=base_settings.ECHO_SQL,
+ isolation_level=base_settings.CLASSIC_DB_TRANSACTION_ISOLATION_LEVEL,
+ pool_recycle=600,
+ max_overflow=(base_settings.REQUEST_CONCURRENCY - 5), # max overflow is how many + base pool size, which is 5 by default
+ pool_pre_ping=base_settings.POOL_PRE_PING)
+ if base_settings.LATEXML_DB_URI:
+ latexml_engine = create_engine(base_settings.LATEXML_DB_URI,
+ echo=base_settings.ECHO_SQL,
+ isolation_level=base_settings.LATEXML_DB_TRANSACTION_ISOLATION_LEVEL,
+ pool_recycle=600,
+ max_overflow=(base_settings.REQUEST_CONCURRENCY - 5),
+ pool_pre_ping=base_settings.POOL_PRE_PING)
+ else:
+ latexml_engine = None
+ SessionLocal.configure(binds={
+ Base: engine,
+ LaTeXMLBase: (latexml_engine if latexml_engine else engine),
+ t_arXiv_stats_hourly: engine,
+ t_arXiv_admin_state: engine,
+ t_arXiv_bad_pw: engine,
+ t_arXiv_black_email: engine,
+ t_arXiv_block_email: engine,
+ t_arXiv_bogus_subject_class: engine,
+ t_arXiv_duplicates: engine,
+ t_arXiv_in_category: engine,
+ t_arXiv_moderators: engine,
+ t_arXiv_ownership_requests_papers: engine,
+ t_arXiv_refresh_list: engine,
+ t_arXiv_updates_tmp: engine,
+ t_arXiv_white_email: engine,
+ t_arXiv_xml_notifications: engine,
+ t_demographics_backup: engine,
+ t_tapir_email_change_tokens_used: engine,
+ t_tapir_email_tokens_used: engine,
+ t_tapir_error_log: engine,
+ t_tapir_no_cookies: engine,
+ t_tapir_periodic_tasks_log: engine,
+ t_tapir_periodic_tasks_log: engine,
+ t_tapir_permanent_tokens_used: engine,
+ t_tapir_save_post_variables: engine
+ })
+ return engine, latexml_engine
+
class MemberInstitution(Base):
__tablename__ = 'Subscription_UniversalInstitution'
@@ -97,7 +151,7 @@ class AdminLog(Base):
id: Mapped[intpk]
logtime: Mapped[Optional[str]] = mapped_column(String(24))
- created: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=FetchedValue())
+ created: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=text('CURRENT_TIMESTAMP'), server_onupdate=text('CURRENT_TIMESTAMP'))
paper_id: Mapped[Optional[str]] = mapped_column(String(20), index=True)
username: Mapped[Optional[str]] = mapped_column(String(20), index=True)
host: Mapped[Optional[str]] = mapped_column(String(64))
@@ -294,17 +348,18 @@ class Category(Base):
archive: Mapped[str] = mapped_column(ForeignKey('arXiv_archives.archive_id'), primary_key=True, nullable=False, server_default=FetchedValue())
subject_class: Mapped[str] = mapped_column(String(16), primary_key=True, nullable=False, server_default=FetchedValue())
- definitive: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue())
- active: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue())
+ definitive: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'0'"))
+ active: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'0'"))
category_name: Mapped[Optional[str]]
- endorse_all: Mapped[Literal['y', 'n', 'd']] = mapped_column(Enum('y', 'n', 'd'), nullable=False, server_default=FetchedValue())
- endorse_email: Mapped[Literal['y', 'n', 'd']] = mapped_column(Enum('y', 'n', 'd'), nullable=False, server_default=FetchedValue())
- papers_to_endorse: Mapped[int] = mapped_column(SmallInteger, nullable=False, server_default=FetchedValue())
+ endorse_all: Mapped[Literal['y', 'n', 'd']] = mapped_column(Enum('y', 'n', 'd'), nullable=False, server_default=text("'d'"))
+ endorse_email: Mapped[Literal['y', 'n', 'd']] = mapped_column(Enum('y', 'n', 'd'), nullable=False, server_default=text("'d'"))
+ papers_to_endorse: Mapped[int] = mapped_column(SmallInteger, nullable=False, server_default=text("'0'"))
endorsement_domain: Mapped[Optional[str]] = mapped_column(ForeignKey('arXiv_endorsement_domains.endorsement_domain'), index=True)
arXiv_archive = relationship('Archive', primaryjoin='Category.archive == Archive.archive_id', backref='arXiv_categories')
+ arXiv_endorsements = relationship('Endorsement', back_populates='arXiv_categories')
arXiv_endorsement_domain = relationship('EndorsementDomain', primaryjoin='Category.endorsement_domain == EndorsementDomain.endorsement_domain', backref='arXiv_categories')
-
+ arXiv_endorsement_requests = relationship('EndorsementRequest', back_populates='arXiv_categories')
class QuestionableCategory(Category):
__tablename__ = 'arXiv_questionable_categories'
@@ -342,8 +397,8 @@ class ControlHold(Base):
placed_by: Mapped[Optional[int]] = mapped_column(ForeignKey('tapir_users.user_id'), index=True)
last_changed_by: Mapped[Optional[int]] = mapped_column(ForeignKey('tapir_users.user_id'), index=True)
- tapir_user = relationship('TapirUser', primaryjoin='ControlHold.last_changed_by == TapirUser.user_id', backref='tapiruser_arXiv_control_holds')
- tapir_user1 = relationship('TapirUser', primaryjoin='ControlHold.placed_by == TapirUser.user_id', backref='tapiruser_arXiv_control_holds_0')
+ tapir_users = relationship('TapirUser', foreign_keys=[last_changed_by], back_populates='arXiv_control_holds')
+ tapir_users_ = relationship('TapirUser', foreign_keys=[placed_by], back_populates='arXiv_control_holds_')
@@ -369,8 +424,8 @@ class CrossControl(Base):
publish_date: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue())
arXiv_category = relationship('Category', primaryjoin='and_(CrossControl.archive == Category.archive, CrossControl.subject_class == Category.subject_class)', backref='arXiv_cross_controls')
- document = relationship('Document', primaryjoin='CrossControl.document_id == Document.document_id', backref='arXiv_cross_controls')
- user = relationship('TapirUser', primaryjoin='CrossControl.user_id == TapirUser.user_id', backref='arXiv_cross_controls')
+ document = relationship('Document', primaryjoin='CrossControl.document_id == Document.document_id', back_populates='arXiv_cross_controls')
+ user = relationship('TapirUser', primaryjoin='CrossControl.user_id == TapirUser.user_id', back_populates='arXiv_cross_controls')
@@ -432,12 +487,13 @@ class Document(Base):
authors: Mapped[Optional[str]] = mapped_column(Text)
submitter_email: Mapped[str] = mapped_column(String(64), nullable=False, index=True, server_default=FetchedValue())
submitter_id: Mapped[Optional[int]] = mapped_column(ForeignKey('tapir_users.user_id'), index=True)
- dated: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue())
+ dated: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'"))
primary_subject_class: Mapped[Optional[str]] = mapped_column(String(16))
created: Mapped[Optional[datetime]]
- submitter = relationship('TapirUser', primaryjoin='Document.submitter_id == TapirUser.user_id', backref='arXiv_documents')
-
+ owners = relationship('PaperOwner', back_populates='document')
+ submitter = relationship('TapirUser', primaryjoin='Document.submitter_id == TapirUser.user_id', back_populates='arXiv_documents')
+ arXiv_cross_controls = relationship('CrossControl', back_populates='document')
class DBLP(Document):
__tablename__ = 'arXiv_dblp'
@@ -490,8 +546,10 @@ class EndorsementRequest(Base):
issued_when: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue())
point_value: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue())
- arXiv_category = relationship('Category', primaryjoin='and_(EndorsementRequest.archive == Category.archive, EndorsementRequest.subject_class == Category.subject_class)', backref='arXiv_endorsement_requests')
- endorsee = relationship('TapirUser', primaryjoin='EndorsementRequest.endorsee_id == TapirUser.user_id', backref='arXiv_endorsement_requests')
+ arXiv_categories = relationship('Category', primaryjoin='and_(EndorsementRequest.archive == Category.archive, EndorsementRequest.subject_class == Category.subject_class)', back_populates='arXiv_endorsement_requests')
+ endorsee = relationship('TapirUser', primaryjoin='EndorsementRequest.endorsee_id == TapirUser.user_id', back_populates='arXiv_endorsement_requests', uselist=False)
+ endorsement = relationship('Endorsement', back_populates='request', uselist=False)
+ audit = relationship('EndorsementRequestsAudit', uselist=False)
class EndorsementRequestsAudit(EndorsementRequest):
@@ -524,10 +582,10 @@ class Endorsement(Base):
issued_when: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue())
request_id: Mapped[Optional[int]] = mapped_column(ForeignKey('arXiv_endorsement_requests.request_id'), index=True)
- arXiv_category = relationship('Category', primaryjoin='and_(Endorsement.archive == Category.archive, Endorsement.subject_class == Category.subject_class)', backref='arXiv_endorsements')
- endorsee = relationship('TapirUser', primaryjoin='Endorsement.endorsee_id == TapirUser.user_id', backref='tapiruser_arXiv_endorsements')
- endorser = relationship('TapirUser', primaryjoin='Endorsement.endorser_id == TapirUser.user_id', backref='tapiruser_arXiv_endorsements_0')
- request = relationship('EndorsementRequest', primaryjoin='Endorsement.request_id == EndorsementRequest.request_id', backref='arXiv_endorsements')
+ arXiv_categories = relationship('Category', primaryjoin='and_(Endorsement.archive == Category.archive, Endorsement.subject_class == Category.subject_class)', back_populates='arXiv_endorsements')
+ endorsee = relationship('TapirUser', primaryjoin='Endorsement.endorsee_id == TapirUser.user_id', back_populates='endorsee_of')
+ endorser = relationship('TapirUser', primaryjoin='Endorsement.endorser_id == TapirUser.user_id', back_populates='endorses')
+ request = relationship('EndorsementRequest', primaryjoin='Endorsement.request_id == EndorsementRequest.request_id', back_populates='endorsement')
class EndorsementsAudit(Endorsement):
@@ -598,7 +656,7 @@ class JrefControl(Base):
publish_date: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue())
document = relationship('Document', primaryjoin='JrefControl.document_id == Document.document_id', backref='arXiv_jref_controls')
- user = relationship('TapirUser', primaryjoin='JrefControl.user_id == TapirUser.user_id', backref='arXiv_jref_controls')
+ user = relationship('TapirUser', primaryjoin='JrefControl.user_id == TapirUser.user_id', back_populates='arXiv_jref_controls')
@@ -658,7 +716,7 @@ class Metadata(Base):
document = relationship('Document', primaryjoin='Metadata.document_id == Document.document_id', backref='arXiv_metadata')
arXiv_license = relationship('License', primaryjoin='Metadata.license == License.name', backref='arXiv_metadata')
- submitter = relationship('TapirUser', primaryjoin='Metadata.submitter_id == TapirUser.user_id', backref='arXiv_metadata')
+ submitter = relationship('TapirUser', primaryjoin='Metadata.submitter_id == TapirUser.user_id', back_populates='arXiv_metadata')
@@ -688,7 +746,7 @@ class ModeratorApiKey(Base):
issued_to: Mapped[str] = mapped_column(String(16), nullable=False, server_default=FetchedValue())
remote_host: Mapped[str] = mapped_column(String(255), nullable=False, server_default=FetchedValue())
- user = relationship('TapirUser', primaryjoin='ModeratorApiKey.user_id == TapirUser.user_id', backref='arXiv_moderator_api_keys')
+ user = relationship('TapirUser', primaryjoin='ModeratorApiKey.user_id == TapirUser.user_id', back_populates='arXiv_moderator_api_keys')
@@ -762,6 +820,12 @@ class OrcidConfig(Base):
value: Mapped[Optional[str]] = mapped_column(String(150))
+t_arXiv_ownership_requests_papers = Table(
+ 'arXiv_ownership_requests_papers', metadata,
+ Column('request_id', ForeignKey('arXiv_ownership_requests.request_id'), nullable=False, server_default=FetchedValue()),
+ Column('document_id', ForeignKey('arXiv_documents.document_id'), nullable=False, index=True, server_default=FetchedValue()),
+ Index('request_id', 'request_id', 'document_id', unique=True)
+)
class OwnershipRequest(Base):
__tablename__ = 'arXiv_ownership_requests'
@@ -771,11 +835,12 @@ class OwnershipRequest(Base):
endorsement_request_id: Mapped[Optional[int]] = mapped_column(ForeignKey('arXiv_endorsement_requests.request_id'), index=True)
workflow_status: Mapped[Literal['pending', 'accepted', 'rejected']] = mapped_column(Enum('pending', 'accepted', 'rejected'), nullable=False, server_default=FetchedValue())
+ request_audit = relationship('OwnershipRequestsAudit', back_populates='ownership_request', uselist=False)
endorsement_request = relationship('EndorsementRequest', primaryjoin='OwnershipRequest.endorsement_request_id == EndorsementRequest.request_id', backref='arXiv_ownership_requests')
- user = relationship('TapirUser', primaryjoin='OwnershipRequest.user_id == TapirUser.user_id', backref='arXiv_ownership_requests')
+ user = relationship('TapirUser', primaryjoin='OwnershipRequest.user_id == TapirUser.user_id', back_populates='arXiv_ownership_requests')
+ documents = relationship("Document", secondary=t_arXiv_ownership_requests_papers)
-
-class OwnershipRequestsAudit(OwnershipRequest):
+class OwnershipRequestsAudit(Base):
__tablename__ = 'arXiv_ownership_requests_audit'
request_id: Mapped[int] = mapped_column(ForeignKey('arXiv_ownership_requests.request_id'), primary_key=True, server_default=FetchedValue())
@@ -785,31 +850,33 @@ class OwnershipRequestsAudit(OwnershipRequest):
tracking_cookie: Mapped[str] = mapped_column(String(255), nullable=False, server_default=FetchedValue())
date: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue())
+ ownership_request = relationship('OwnershipRequest', primaryjoin='OwnershipRequestsAudit.request_id == OwnershipRequest.request_id', back_populates='request_audit', uselist=False)
-t_arXiv_ownership_requests_papers = Table(
- 'arXiv_ownership_requests_papers', metadata,
- Column('request_id', Integer, nullable=False, server_default=FetchedValue()),
- Column('document_id', Integer, nullable=False, index=True, server_default=FetchedValue()),
- Index('request_id', 'request_id', 'document_id')
-)
-
+class PaperOwner(Base):
+ __tablename__ = 'arXiv_paper_owners'
+ __table_args__ = (
+ ForeignKeyConstraint(['added_by'], ['tapir_users.user_id'], name='0_595'),
+ ForeignKeyConstraint(['document_id'], ['arXiv_documents.document_id'], name='0_593'),
+ ForeignKeyConstraint(['user_id'], ['tapir_users.user_id'], name='0_594'),
+ Index('added_by', 'added_by'),
+ # Index('document_id', 'document_id', 'user_id', unique=True),
+ # Index('user_id', 'user_id'),
+ )
+ document_id: Mapped[int] = mapped_column(ForeignKey('arXiv_documents.document_id'), primary_key=True)
+ user_id: Mapped[int] = mapped_column(ForeignKey('tapir_users.user_id'), primary_key=True)
+ date: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'0'"))
+ added_by: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'0'"))
+ remote_addr: Mapped[str] = mapped_column(String(16), nullable=False, server_default=text("''"))
+ remote_host: Mapped[str] = mapped_column(String(255), nullable=False, server_default=text("''"))
+ tracking_cookie: Mapped[str] = mapped_column(String(32), nullable=False, server_default=text("''"))
+ valid: Mapped[int] = mapped_column(SmallInteger, nullable=False, server_default=text("'0'"))
+ flag_author: Mapped[int] = mapped_column(SmallInteger, nullable=False, server_default=text("'0'"))
+ flag_auto: Mapped[int] = mapped_column(SmallInteger, nullable=False, server_default=text("'1'"))
-t_arXiv_paper_owners = Table(
- 'arXiv_paper_owners', metadata,
- Column('document_id', ForeignKey('arXiv_documents.document_id'), nullable=False, server_default=FetchedValue()),
- Column('user_id', ForeignKey('tapir_users.user_id'), nullable=False, index=True, server_default=FetchedValue()),
- Column('date', Integer, nullable=False, server_default=FetchedValue()),
- Column('added_by', ForeignKey('tapir_users.user_id'), nullable=False, index=True, server_default=FetchedValue()),
- Column('remote_addr', String(16), nullable=False, server_default=FetchedValue()),
- Column('remote_host', String(255), nullable=False, server_default=FetchedValue()),
- Column('tracking_cookie', String(32), nullable=False, server_default=FetchedValue()),
- Column('valid', Integer, nullable=False, server_default=FetchedValue()),
- Column('flag_author', Integer, nullable=False, server_default=FetchedValue()),
- Column('flag_auto', Integer, nullable=False, server_default=FetchedValue()),
- Index('owners_document_id', 'document_id', 'user_id')
-)
+ document = relationship('Document', back_populates='owners')
+ owner = relationship('TapirUser', foreign_keys="[PaperOwner.user_id]", back_populates='owned_papers')
class PaperSession(Base):
@@ -883,7 +950,7 @@ class ShowEmailRequest(Base):
request_id: Mapped[intpk]
document = relationship('Document', primaryjoin='ShowEmailRequest.document_id == Document.document_id', backref='arXiv_show_email_requests')
- user = relationship('TapirUser', primaryjoin='ShowEmailRequest.user_id == TapirUser.user_id', backref='arXiv_show_email_requests')
+ user = relationship('TapirUser', primaryjoin='ShowEmailRequest.user_id == TapirUser.user_id', back_populates='arXiv_show_email_requests')
@@ -963,7 +1030,7 @@ class SubmissionCategoryProposal(Base):
proposal_comment = relationship('AdminLog', primaryjoin='SubmissionCategoryProposal.proposal_comment_id == AdminLog.id', backref='arxivadminlog_arXiv_submission_category_proposals')
response_comment = relationship('AdminLog', primaryjoin='SubmissionCategoryProposal.response_comment_id == AdminLog.id', backref='arxivadminlog_arXiv_submission_category_proposals_0')
submission = relationship('Submission', primaryjoin='SubmissionCategoryProposal.submission_id == Submission.submission_id', backref='arXiv_submission_category_proposals')
- user = relationship('TapirUser', primaryjoin='SubmissionCategoryProposal.user_id == TapirUser.user_id', backref='arXiv_submission_category_proposals')
+ user = relationship('TapirUser', primaryjoin='SubmissionCategoryProposal.user_id == TapirUser.user_id', back_populates='arXiv_submission_category_proposal')
@@ -985,7 +1052,7 @@ class SubmissionControl(Base):
publish_date: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue())
document = relationship('Document', primaryjoin='SubmissionControl.document_id == Document.document_id', backref='arXiv_submission_controls')
- user = relationship('TapirUser', primaryjoin='SubmissionControl.user_id == TapirUser.user_id', backref='arXiv_submission_controls')
+ user = relationship('TapirUser', primaryjoin='SubmissionControl.user_id == TapirUser.user_id', back_populates='arXiv_submission_control')
@@ -1002,7 +1069,7 @@ class SubmissionFlag(Base):
updated: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=FetchedValue())
submission = relationship('Submission', primaryjoin='SubmissionFlag.submission_id == Submission.submission_id', backref='arXiv_submission_flags')
- user = relationship('TapirUser', primaryjoin='SubmissionFlag.user_id == TapirUser.user_id', backref='arXiv_submission_flags')
+ user = relationship('TapirUser', primaryjoin='SubmissionFlag.user_id == TapirUser.user_id', back_populates='arXiv_submission_flag')
@@ -1018,7 +1085,7 @@ class SubmissionHoldReason(Base):
comment = relationship('AdminLog', primaryjoin='SubmissionHoldReason.comment_id == AdminLog.id', backref='arXiv_submission_hold_reasons')
submission = relationship('Submission', primaryjoin='SubmissionHoldReason.submission_id == Submission.submission_id', backref='arXiv_submission_hold_reasons')
- user = relationship('TapirUser', primaryjoin='SubmissionHoldReason.user_id == TapirUser.user_id', backref='arXiv_submission_hold_reasons')
+ user = relationship('TapirUser', primaryjoin='SubmissionHoldReason.user_id == TapirUser.user_id', back_populates='arXiv_submission_hold_reason')
@@ -1061,7 +1128,7 @@ class SubmissionViewFlag(Base):
updated: Mapped[Optional[datetime]]
submission = relationship('Submission', primaryjoin='SubmissionViewFlag.submission_id == Submission.submission_id', backref='arXiv_submission_view_flags')
- user = relationship('TapirUser', primaryjoin='SubmissionViewFlag.user_id == TapirUser.user_id', backref='arXiv_submission_view_flags')
+ user = relationship('TapirUser', primaryjoin='SubmissionViewFlag.user_id == TapirUser.user_id', back_populates='arXiv_submission_view_flag')
@@ -1073,7 +1140,7 @@ class Submission(Base):
doc_paper_id: Mapped[Optional[str]] = mapped_column(String(20), index=True)
sword_id: Mapped[Optional[int]] = mapped_column(ForeignKey('arXiv_tracking.sword_id'), index=True)
userinfo: Mapped[Optional[int]] = mapped_column(Integer, server_default=FetchedValue())
- is_author: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue())
+ is_author: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'0'"))
agree_policy: Mapped[Optional[int]] = mapped_column(Integer, server_default=FetchedValue())
viewed: Mapped[Optional[int]] = mapped_column(Integer, server_default=FetchedValue())
stage: Mapped[Optional[int]] = mapped_column(Integer, server_default=FetchedValue())
@@ -1082,7 +1149,7 @@ class Submission(Base):
submitter_email: Mapped[Optional[str]] = mapped_column(String(64))
created: Mapped[Optional[datetime]]
updated: Mapped[Optional[datetime]]
- status: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue())
+ status: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'"))
sticky_status: Mapped[Optional[int]] = mapped_column(Integer)
must_process: Mapped[Optional[int]] = mapped_column(Integer, server_default=FetchedValue())
submit_time: Mapped[Optional[datetime]]
@@ -1091,7 +1158,7 @@ class Submission(Base):
source_format: Mapped[Optional[str]] = mapped_column(String(12))
source_flags: Mapped[Optional[str]] = mapped_column(String(12))
has_pilot_data: Mapped[Optional[int]] = mapped_column(Integer)
- is_withdrawn: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue())
+ is_withdrawn: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'0'"))
title: Mapped[Optional[str]] = mapped_column(Text)
authors: Mapped[Optional[str]] = mapped_column(Text)
comments: Mapped[Optional[str]] = mapped_column(Text)
@@ -1103,7 +1170,7 @@ class Submission(Base):
doi: Mapped[Optional[str]]
abstract: Mapped[Optional[str]] = mapped_column(Text)
license: Mapped[Optional[str]] = mapped_column(ForeignKey('arXiv_licenses.name', onupdate='CASCADE'), index=True)
- version: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue())
+ version: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'1'"))
type: Mapped[Optional[str]] = mapped_column(String(8), index=True)
is_ok: Mapped[Optional[int]] = mapped_column(Integer, index=True)
admin_ok: Mapped[Optional[int]] = mapped_column(Integer)
@@ -1114,13 +1181,13 @@ class Submission(Base):
package: Mapped[str] = mapped_column(String(255), nullable=False, server_default=FetchedValue())
rt_ticket_id: Mapped[Optional[int]] = mapped_column(Integer, index=True)
auto_hold: Mapped[Optional[int]] = mapped_column(Integer, server_default=FetchedValue())
- is_locked: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue())
+ is_locked: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'"))
agreement_id = mapped_column(ForeignKey('arXiv_submission_agreements.agreement_id'), index=True)
agreement = relationship('SubmissionAgreement', primaryjoin='Submission.agreement_id == SubmissionAgreement.agreement_id', backref='arXiv_submissions')
document = relationship('Document', primaryjoin='Submission.document_id == Document.document_id', backref='arXiv_submissions')
arXiv_license = relationship('License', primaryjoin='Submission.license == License.name', backref='arXiv_submissions')
- submitter = relationship('TapirUser', primaryjoin='Submission.submitter_id == TapirUser.user_id', backref='arXiv_submissions')
+ submitter = relationship('TapirUser', primaryjoin='Submission.submitter_id == TapirUser.user_id', back_populates='arXiv_submissions')
sword = relationship('Tracking', primaryjoin='Submission.sword_id == Tracking.sword_id', backref='arXiv_submissions')
@@ -1137,17 +1204,15 @@ class PilotDataset(Submission):
last_checked: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=FetchedValue())
-class SubmissionAbsClassifierDatum(Submission):
+class SubmissionAbsClassifierDatum(Base):
__tablename__ = 'arXiv_submission_abs_classifier_data'
submission_id: Mapped[int] = mapped_column(ForeignKey('arXiv_submissions.submission_id', ondelete='CASCADE'), primary_key=True, server_default=FetchedValue())
json: Mapped[Optional[str]] = mapped_column(Text)
last_update: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=FetchedValue())
- # status: Mapped[Optional[Literal['processing', 'success', 'failed', 'no connection']]] = mapped_column(Enum('processing', 'success', 'failed', 'no connection'))
- # ^This column is inherited
+ status: Mapped[Optional[Literal['processing', 'success', 'failed', 'no connection']]] = mapped_column(Enum('processing', 'success', 'failed', 'no connection'))
message: Mapped[Optional[str]] = mapped_column(Text)
- # is_oversize: Mapped[Optional[int]] = mapped_column(Integer)
- # ^This column is inherited
+ is_oversize: Mapped[Optional[int]] = mapped_column(Integer)
suggested_primary: Mapped[Optional[str]] = mapped_column(Text)
suggested_reason: Mapped[Optional[str]] = mapped_column(Text)
autoproposal_primary: Mapped[Optional[str]] = mapped_column(Text)
@@ -1156,17 +1221,15 @@ class SubmissionAbsClassifierDatum(Submission):
classifier_model_version: Mapped[Optional[str]] = mapped_column(Text)
-class SubmissionClassifierDatum(Submission):
+class SubmissionClassifierDatum(Base):
__tablename__ = 'arXiv_submission_classifier_data'
submission_id: Mapped[int] = mapped_column(ForeignKey('arXiv_submissions.submission_id', ondelete='CASCADE'), primary_key=True, server_default=FetchedValue())
json: Mapped[Optional[str]] = mapped_column(Text)
last_update: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=FetchedValue())
- # status: Mapped[Optional[Literal['processing', 'success', 'failed', 'no connection']]] = mapped_column(Enum('processing', 'success', 'failed', 'no connection'))
- # ^This column is inherited
+ status: Mapped[Optional[Literal['processing', 'success', 'failed', 'no connection']]] = mapped_column(Enum('processing', 'success', 'failed', 'no connection'))
message: Mapped[Optional[str]] = mapped_column(Text)
- # is_oversize: Mapped[Optional[int]] = mapped_column(Integer)
- # ^This column is inherited
+ is_oversize: Mapped[Optional[int]] = mapped_column(Integer)
class SubmitterFlag(Base):
@@ -1412,7 +1475,7 @@ class TapirAddress(Base):
share_addr: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue())
tapir_country = relationship('TapirCountry', primaryjoin='TapirAddress.country == TapirCountry.digraph', backref='tapir_address')
- user = relationship('TapirUser', primaryjoin='TapirAddress.user_id == TapirUser.user_id', backref='tapir_address')
+ user = relationship('TapirUser', primaryjoin='TapirAddress.user_id == TapirUser.user_id', back_populates='tapir_address')
@@ -1431,8 +1494,8 @@ class TapirAdminAudit(Base):
comment: Mapped[str] = mapped_column(Text, nullable=False)
entry_id: Mapped[intpk]
- tapir_user = relationship('TapirUser', primaryjoin='TapirAdminAudit.admin_user == TapirUser.user_id', backref='tapiruser_tapir_admin_audits')
- tapir_user1 = relationship('TapirUser', primaryjoin='TapirAdminAudit.affected_user == TapirUser.user_id', backref='tapiruser_tapir_admin_audits_0')
+ tapir_users = relationship('TapirUser', foreign_keys=[admin_user], back_populates='tapir_admin_audit')
+ tapir_users_ = relationship('TapirUser', foreign_keys=[affected_user], back_populates='tapir_admin_audit_')
session = relationship('TapirSession', primaryjoin='TapirAdminAudit.session_id == TapirSession.session_id', backref='tapir_admin_audits')
@@ -1462,7 +1525,7 @@ class TapirEmailChangeToken(Base):
consumed_when: Mapped[Optional[int]] = mapped_column(Integer)
consumed_from: Mapped[Optional[str]] = mapped_column(String(16))
- user = relationship('TapirUser', primaryjoin='TapirEmailChangeToken.user_id == TapirUser.user_id', backref='tapir_email_change_tokens')
+ user = relationship('TapirUser', primaryjoin='TapirEmailChangeToken.user_id == TapirUser.user_id', back_populates='tapir_email_change_tokens')
t_tapir_email_change_tokens_used = Table(
@@ -1514,8 +1577,8 @@ class TapirEmailMailing(Base):
mailing_name: Mapped[Optional[str]]
comment: Mapped[Optional[str]] = mapped_column(Text)
- tapir_user = relationship('TapirUser', primaryjoin='TapirEmailMailing.created_by == TapirUser.user_id', backref='tapiruser_tapir_email_mailings')
- tapir_user1 = relationship('TapirUser', primaryjoin='TapirEmailMailing.sent_by == TapirUser.user_id', backref='tapiruser_tapir_email_mailings_0')
+ tapir_users = relationship('TapirUser', foreign_keys=[created_by], back_populates='tapir_email_mailings')
+ tapir_users_ = relationship('TapirUser', foreign_keys=[sent_by], back_populates='tapir_email_mailings_')
template = relationship('TapirEmailTemplate', primaryjoin='TapirEmailMailing.template_id == TapirEmailTemplate.template_id', backref='tapir_email_mailings')
@@ -1538,8 +1601,8 @@ class TapirEmailTemplate(Base):
workflow_status: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue())
flag_system: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue())
- tapir_user = relationship('TapirUser', primaryjoin='TapirEmailTemplate.created_by == TapirUser.user_id', backref='tapiruser_tapir_email_templates')
- tapir_user1 = relationship('TapirUser', primaryjoin='TapirEmailTemplate.updated_by == TapirUser.user_id', backref='tapiruser_tapir_email_templates_0')
+ tapir_users = relationship('TapirUser', foreign_keys=[created_by], back_populates='tapir_email_templates')
+ tapir_users_ = relationship('TapirUser', foreign_keys=[updated_by], back_populates='tapir_email_templates_')
@@ -1555,7 +1618,7 @@ class TapirEmailToken(Base):
tracking_cookie: Mapped[str] = mapped_column(String(255), nullable=False, server_default=FetchedValue())
wants_perm_token: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue())
- user = relationship('TapirUser', primaryjoin='TapirEmailToken.user_id == TapirUser.user_id', backref='tapir_email_tokens')
+ user = relationship('TapirUser', primaryjoin='TapirEmailToken.user_id == TapirUser.user_id', back_populates='tapir_email_tokens')
@@ -1603,13 +1666,13 @@ class TapirNickname(Base):
nick_id: Mapped[intpk]
nickname: Mapped[str] = mapped_column(String(20), nullable=False, unique=True, server_default=FetchedValue())
user_id: Mapped[int] = mapped_column(ForeignKey('tapir_users.user_id'), nullable=False, server_default=FetchedValue())
- user_seq: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue())
- flag_valid: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue())
- role: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue())
- policy: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue())
- flag_primary: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue())
+ user_seq: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'0'"))
+ flag_valid: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'"))
+ role: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'"))
+ policy: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'"))
+ flag_primary: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'0'"))
- user = relationship('TapirUser', primaryjoin='TapirNickname.user_id == TapirUser.user_id', backref='tapir_nicknames')
+ user = relationship('TapirUser', primaryjoin='TapirNickname.user_id == TapirUser.user_id', back_populates='tapir_nicknames')
@@ -1656,7 +1719,7 @@ class TapirPermanentToken(Base):
session_id: Mapped[int] = mapped_column(ForeignKey('tapir_sessions.session_id'), nullable=False, index=True, server_default=FetchedValue())
session = relationship('TapirSession', primaryjoin='TapirPermanentToken.session_id == TapirSession.session_id', backref='tapir_permanent_tokens')
- user = relationship('TapirUser', primaryjoin='TapirPermanentToken.user_id == TapirUser.user_id', backref='tapir_permanent_tokens')
+ user = relationship('TapirUser', primaryjoin='TapirPermanentToken.user_id == TapirUser.user_id', back_populates='tapir_permanent_tokens')
@@ -1680,13 +1743,22 @@ class TapirPhone(Base):
phone_number: Mapped[Optional[str]] = mapped_column(String(32), index=True)
share_phone: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue())
- user = relationship('TapirUser', primaryjoin='TapirPhone.user_id == TapirUser.user_id', backref='tapir_phones')
+ user = relationship('TapirUser', primaryjoin='TapirPhone.user_id == TapirUser.user_id', back_populates='tapir_phone')
class TapirPolicyClass(Base):
__tablename__ = 'tapir_policy_classes'
+ ADMIN = 1
+ PUBLIC_USER = 2
+ LEGACY_USER = 3
+ POLICY_CLASSES = [
+ {"name": "Administrator", "class_id": ADMIN, "description": "", "password_storage": 2, "recovery_policy": 3, "permanent_login": 1},
+ {"name": "Public user", "class_id": PUBLIC_USER, "description": "", "password_storage": 2, "recovery_policy": 3, "permanent_login": 1},
+ {"name": "Legacy user", "class_id": LEGACY_USER, "description": "", "password_storage": 2, "recovery_policy": 3, "permanent_login": 1},
+ ]
+
class_id: Mapped[int] = mapped_column(SmallInteger, primary_key=True)
name: Mapped[str] = mapped_column(String(64), nullable=False, server_default=FetchedValue())
description: Mapped[str] = mapped_column(Text, nullable=False)
@@ -1694,7 +1766,7 @@ class TapirPolicyClass(Base):
recovery_policy: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue())
permanent_login: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue())
-
+ tapir_users = relationship('TapirUser', back_populates='tapir_policy_classes')
class TapirPresession(Base):
__tablename__ = 'tapir_presessions'
@@ -1719,7 +1791,7 @@ class TapirRecoveryToken(Base):
remote_host: Mapped[str] = mapped_column(String(255), nullable=False, server_default=FetchedValue())
tracking_cookie: Mapped[str] = mapped_column(String(255), nullable=False, server_default=FetchedValue())
- user = relationship('TapirUser', primaryjoin='TapirRecoveryToken.user_id == TapirUser.user_id', backref='tapir_recovery_tokens')
+ user = relationship('TapirUser', primaryjoin='TapirRecoveryToken.user_id == TapirUser.user_id', back_populates='tapir_recovery_tokens')
@@ -1734,7 +1806,7 @@ class TapirRecoveryTokensUsed(Base):
session_id: Mapped[Optional[int]] = mapped_column(ForeignKey('tapir_sessions.session_id'), index=True)
session = relationship('TapirSession', primaryjoin='TapirRecoveryTokensUsed.session_id == TapirSession.session_id', backref='tapir_recovery_tokens_useds')
- user = relationship('TapirUser', primaryjoin='TapirRecoveryTokensUsed.user_id == TapirUser.user_id', backref='tapir_recovery_tokens_useds')
+ user = relationship('TapirUser', primaryjoin='TapirRecoveryTokensUsed.user_id == TapirUser.user_id', back_populates='tapir_recovery_tokens_used')
@@ -1750,23 +1822,25 @@ class TapirRecoveryTokensUsed(Base):
class TapirSession(Base):
__tablename__ = 'tapir_sessions'
- session_id: Mapped[intpk]
- user_id: Mapped[int] = mapped_column(ForeignKey('tapir_users.user_id'), nullable=False, index=True, server_default=FetchedValue())
- last_reissue: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue())
- start_time: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue())
- end_time: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue())
+ session_id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
+ user_id: Mapped[int] = mapped_column(ForeignKey('tapir_users.user_id'), nullable=False, index=True, server_default=text("'0'"))
+ last_reissue: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'0'"))
+ start_time: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'"))
+ end_time: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'"))
- user = relationship('TapirUser', primaryjoin='TapirSession.user_id == TapirUser.user_id', backref='tapir_sessions')
+ user = relationship('TapirUser', primaryjoin='TapirSession.user_id == TapirUser.user_id', back_populates='tapir_sessions')
-class TapirSessionsAudit(TapirSession):
+class TapirSessionsAudit(Base):
__tablename__ = 'tapir_sessions_audit'
- session_id: Mapped[int] = mapped_column(ForeignKey('tapir_sessions.session_id'), primary_key=True, server_default=FetchedValue())
+ session_id: Mapped[int] = mapped_column(ForeignKey('tapir_sessions.session_id'), primary_key=True, server_default=text("'0'"), autoincrement="false")
ip_addr: Mapped[str] = mapped_column(String(16), nullable=False, index=True, server_default=FetchedValue())
remote_host: Mapped[str] = mapped_column(String(255), nullable=False, server_default=FetchedValue())
tracking_cookie: Mapped[str] = mapped_column(String(255), nullable=False, index=True, server_default=FetchedValue())
+ session = relationship('TapirSession')
+
class TapirStringVariable(Base):
@@ -1789,37 +1863,92 @@ class TapirString(Base):
class TapirUser(Base):
__tablename__ = 'tapir_users'
+ __table_args__ = (
+ ForeignKeyConstraint(['policy_class'], ['tapir_policy_classes.class_id'], name='0_510'),
+ Index('email', 'email', unique=True),
+ Index('first_name', 'first_name'),
+ Index('flag_approved', 'flag_approved'),
+ Index('flag_banned', 'flag_banned'),
+ Index('flag_can_lock', 'flag_can_lock'),
+ Index('flag_deleted', 'flag_deleted'),
+ Index('flag_edit_users', 'flag_edit_users'),
+ Index('flag_internal', 'flag_internal'),
+ Index('joined_date', 'joined_date'),
+ Index('joined_ip_num', 'joined_ip_num'),
+ Index('last_name', 'last_name'),
+ Index('policy_class', 'policy_class'),
+ Index('tracking_cookie', 'tracking_cookie')
+ )
user_id: Mapped[intpk]
first_name: Mapped[Optional[str]] = mapped_column(String(50), index=True)
last_name: Mapped[Optional[str]] = mapped_column(String(50), index=True)
suffix_name: Mapped[Optional[str]] = mapped_column(String(50))
- share_first_name: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue())
- share_last_name: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue())
- email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True, server_default=FetchedValue())
- share_email: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue())
- email_bouncing: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue())
- policy_class: Mapped[int] = mapped_column(ForeignKey('tapir_policy_classes.class_id'), nullable=False, index=True, server_default=FetchedValue())
- joined_date: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue())
+ share_first_name: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'1'"))
+ share_last_name: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'1'"))
+ email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True, server_default=text("''"))
+ share_email: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'8'"))
+ email_bouncing: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'0'"))
+ policy_class: Mapped[int] = mapped_column(ForeignKey('tapir_policy_classes.class_id'), nullable=False, index=True, server_default=text("'0'"))
+ joined_date: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'"))
joined_ip_num: Mapped[Optional[str]] = mapped_column(String(16), index=True)
- joined_remote_host: Mapped[str] = mapped_column(String(255), nullable=False, server_default=FetchedValue())
- flag_internal: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue())
- flag_edit_users: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue())
- flag_edit_system: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue())
- flag_email_verified: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue())
- flag_approved: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue())
- flag_deleted: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue())
- flag_banned: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue())
- flag_wants_email: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue())
- flag_html_email: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue())
- tracking_cookie: Mapped[str] = mapped_column(String(255), nullable=False, index=True, server_default=FetchedValue())
- flag_allow_tex_produced: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue())
- flag_can_lock: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue())
-
- tapir_policy_class = relationship('TapirPolicyClass', primaryjoin='TapirUser.policy_class == TapirPolicyClass.class_id', backref='tapir_users')
-
-
-class AuthorIds(TapirUser):
+ joined_remote_host: Mapped[str] = mapped_column(String(255), nullable=False, server_default=text("''"))
+ flag_internal: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'"))
+ flag_edit_users: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'"))
+ flag_edit_system: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'0'"))
+ flag_email_verified: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'0'"))
+ flag_approved: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'1'"))
+ flag_deleted: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'"))
+ flag_banned: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'"))
+ flag_wants_email: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'0'"))
+ flag_html_email: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'0'"))
+ tracking_cookie: Mapped[str] = mapped_column(String(255), nullable=False, index=True, server_default=text("''"))
+ flag_allow_tex_produced: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'0'"))
+ flag_can_lock: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'"))
+
+ tapir_policy_classes = relationship('TapirPolicyClass', back_populates='tapir_users')
+ arXiv_control_holds = relationship('ControlHold', foreign_keys='[ControlHold.last_changed_by]', back_populates='tapir_users')
+ arXiv_control_holds_ = relationship('ControlHold', foreign_keys='[ControlHold.placed_by]', back_populates='tapir_users_')
+ arXiv_documents = relationship('Document', back_populates='submitter')
+ arXiv_moderator_api_keys = relationship('ModeratorApiKey', back_populates='user')
+ tapir_address = relationship('TapirAddress', back_populates='user')
+ tapir_email_change_tokens = relationship('TapirEmailChangeToken', back_populates='user')
+ tapir_email_templates = relationship('TapirEmailTemplate', foreign_keys='[TapirEmailTemplate.created_by]', back_populates='tapir_users')
+ tapir_email_templates_ = relationship('TapirEmailTemplate', foreign_keys='[TapirEmailTemplate.updated_by]', back_populates='tapir_users_')
+ tapir_email_tokens = relationship('TapirEmailToken', back_populates='user')
+ tapir_nicknames = relationship('TapirNickname', back_populates='user', uselist=False)
+ tapir_phone = relationship('TapirPhone', back_populates='user')
+ tapir_recovery_tokens = relationship('TapirRecoveryToken', back_populates='user')
+ tapir_sessions = relationship('TapirSession', back_populates='user')
+ arXiv_cross_controls = relationship('CrossControl', back_populates='user')
+ arXiv_endorsement_requests = relationship('EndorsementRequest', back_populates='endorsee')
+ arXiv_jref_controls = relationship('JrefControl', back_populates='user')
+ arXiv_metadata = relationship('Metadata', back_populates='submitter')
+ arXiv_show_email_requests = relationship('ShowEmailRequest', back_populates='user')
+ arXiv_submission_control = relationship('SubmissionControl', back_populates='user')
+ arXiv_submissions = relationship('Submission', back_populates='submitter')
+ tapir_admin_audit = relationship('TapirAdminAudit', foreign_keys='[TapirAdminAudit.admin_user]', back_populates='tapir_users')
+ tapir_admin_audit_ = relationship('TapirAdminAudit', foreign_keys='[TapirAdminAudit.affected_user]', back_populates='tapir_users_')
+ tapir_email_mailings = relationship('TapirEmailMailing', foreign_keys='[TapirEmailMailing.created_by]', back_populates='tapir_users')
+ tapir_email_mailings_ = relationship('TapirEmailMailing', foreign_keys='[TapirEmailMailing.sent_by]', back_populates='tapir_users_')
+ tapir_permanent_tokens = relationship('TapirPermanentToken', back_populates='user')
+ tapir_recovery_tokens_used = relationship('TapirRecoveryTokensUsed', back_populates='user')
+
+ endorsee_of = relationship('Endorsement', foreign_keys='[Endorsement.endorsee_id]', back_populates='endorsee')
+ endorses = relationship('Endorsement', foreign_keys='[Endorsement.endorser_id]', back_populates='endorser')
+
+ arXiv_ownership_requests = relationship('OwnershipRequest', back_populates='user')
+ arXiv_submission_category_proposal = relationship('SubmissionCategoryProposal', back_populates='user')
+ arXiv_submission_flag = relationship('SubmissionFlag', back_populates='user')
+ arXiv_submission_hold_reason = relationship('SubmissionHoldReason', back_populates='user')
+ arXiv_submission_view_flag = relationship('SubmissionViewFlag', back_populates='user')
+
+ owned_papers = relationship("PaperOwner", foreign_keys="[PaperOwner.user_id]", back_populates="owner")
+
+ demographics = relationship('Demographic', foreign_keys="[Demographic.user_id]", uselist=False, back_populates='user')
+
+
+class AuthorIds(Base):
__tablename__ = 'arXiv_author_ids'
user_id: Mapped[int] = mapped_column(ForeignKey('tapir_users.user_id'), primary_key=True)
@@ -1827,7 +1956,7 @@ class AuthorIds(TapirUser):
updated: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=FetchedValue())
-class Demographic(TapirUser):
+class Demographic(Base):
__tablename__ = 'arXiv_demographics'
__table_args__ = (
ForeignKeyConstraint(['archive', 'subject_class'], ['arXiv_categories.archive', 'arXiv_categories.subject_class']),
@@ -1843,26 +1972,44 @@ class Demographic(TapirUser):
subject_class: Mapped[Optional[str]] = mapped_column(String(16))
original_subject_classes: Mapped[str] = mapped_column(String(255), nullable=False, server_default=FetchedValue())
flag_group_physics: Mapped[Optional[int]] = mapped_column(Integer, index=True)
- flag_group_math: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue())
- flag_group_cs: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue())
- flag_group_nlin: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue())
- flag_proxy: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue())
- flag_journal: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue())
- flag_xml: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue())
- dirty: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue())
- flag_group_test: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue())
- flag_suspect: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue())
- flag_group_q_bio: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue())
- flag_group_q_fin: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue())
- flag_group_stat: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue())
- flag_group_eess: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue())
- flag_group_econ: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue())
- veto_status: Mapped[Literal['ok', 'no-endorse', 'no-upload', 'no-replace']] = mapped_column(Enum('ok', 'no-endorse', 'no-upload', 'no-replace'), nullable=False, server_default=FetchedValue())
-
+ flag_group_math: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'"))
+ flag_group_cs: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'"))
+ flag_group_nlin: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'"))
+ flag_proxy: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'"))
+ flag_journal: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'"))
+ flag_xml: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'"))
+ dirty: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'0'"))
+ flag_group_test: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("'0'"))
+ flag_suspect: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'"))
+ flag_group_q_bio: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'"))
+ flag_group_q_fin: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'"))
+ flag_group_stat: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'"))
+ flag_group_eess: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'"))
+ flag_group_econ: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=text("'0'"))
+ veto_status: Mapped[Literal['ok', 'no-endorse', 'no-upload', 'no-replace']] = mapped_column(Enum('ok', 'no-endorse', 'no-upload', 'no-replace'), nullable=False, server_default=text("'ok'"))
+
+ user = relationship('TapirUser', back_populates='demographics')
arXiv_category = relationship('Category', primaryjoin='and_(Demographic.archive == Category.archive, Demographic.subject_class == Category.subject_class)', backref='arXiv_demographics')
+ GROUP_FLAGS = [
+ ('grp_physics', 'flag_group_physics'),
+ ('grp_math', 'flag_group_math'),
+ ('grp_cs', 'flag_group_cs'),
+ ('grp_q-bio', 'flag_group_q_bio'),
+ ('grp_q-fin', 'flag_group_q_fin'),
+ ('grp_q-stat', 'flag_group_stat'),
+ ('grp_q-econ', 'flag_group_econ'),
+ ('grp_eess', 'flag_group_eess'),
+ ]
-class OrcidIds(TapirUser):
+ @property
+ def groups(self) -> List[str]:
+ """Active groups for this user profile."""
+ return [group for group, column in self.GROUP_FLAGS
+ if getattr(self, column) == 1]
+
+
+class OrcidIds(Base):
__tablename__ = 'arXiv_orcid_ids'
user_id: Mapped[int] = mapped_column(ForeignKey('tapir_users.user_id'), primary_key=True)
@@ -1871,7 +2018,7 @@ class OrcidIds(TapirUser):
updated: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=FetchedValue())
-class QueueView(TapirUser):
+class QueueView(Base):
__tablename__ = 'arXiv_queue_view'
user_id: Mapped[int] = mapped_column(ForeignKey('tapir_users.user_id', ondelete='CASCADE'), primary_key=True, server_default=FetchedValue())
@@ -1880,14 +2027,14 @@ class QueueView(TapirUser):
total_views: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue())
-class SuspiciousName(TapirUser):
+class SuspiciousName(Base):
__tablename__ = 'arXiv_suspicious_names'
user_id: Mapped[int] = mapped_column(ForeignKey('tapir_users.user_id'), primary_key=True, server_default=FetchedValue())
full_name: Mapped[str] = mapped_column(String(255), nullable=False, server_default=FetchedValue())
-class SwordLicense(TapirUser):
+class SwordLicense(Base):
__tablename__ = 'arXiv_sword_licenses'
user_id: Mapped[int] = mapped_column(ForeignKey('tapir_users.user_id'), primary_key=True)
@@ -1895,7 +2042,7 @@ class SwordLicense(TapirUser):
updated: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=FetchedValue())
-class TapirDemographic(TapirUser):
+class TapirDemographic(Base):
__tablename__ = 'tapir_demographics'
user_id: Mapped[int] = mapped_column(ForeignKey('tapir_users.user_id'), primary_key=True, server_default=FetchedValue())
@@ -1910,7 +2057,7 @@ class TapirDemographic(TapirUser):
tapir_country = relationship('TapirCountry', primaryjoin='TapirDemographic.country == TapirCountry.digraph', backref='tapir_demographics')
-class TapirUsersHot(TapirUser):
+class TapirUsersHot(Base):
__tablename__ = 'tapir_users_hot'
user_id: Mapped[int] = mapped_column(ForeignKey('tapir_users.user_id'), primary_key=True, server_default=FetchedValue())
@@ -1919,13 +2066,16 @@ class TapirUsersHot(TapirUser):
number_sessions: Mapped[int] = mapped_column(Integer, nullable=False, index=True, server_default=FetchedValue())
-class TapirUsersPassword(TapirUser):
+class TapirUsersPassword(Base):
__tablename__ = 'tapir_users_password'
user_id: Mapped[int] = mapped_column(ForeignKey('tapir_users.user_id'), primary_key=True, server_default=FetchedValue())
password_storage: Mapped[int] = mapped_column(Integer, nullable=False, server_default=FetchedValue())
password_enc: Mapped[str] = mapped_column(String(50), nullable=False, server_default=FetchedValue())
+ user = relationship('TapirUser')
+
+
class DBLaTeXMLDocuments(LaTeXMLBase):
__tablename__ = 'arXiv_latexml_doc'
@@ -1970,31 +2120,3 @@ class DBLaTeXMLFeedback (LaTeXMLBase):
selected_html: Mapped[Optional[str]]
initiation_mode: Mapped[Optional[str]]
-SessionLocal.configure(binds={
- Base: engine,
- LaTeXMLBase: latexml_engine,
- t_arXiv_stats_hourly: engine,
- t_arXiv_admin_state: engine,
- t_arXiv_bad_pw: engine,
- t_arXiv_black_email: engine,
- t_arXiv_block_email: engine,
- t_arXiv_bogus_subject_class: engine,
- t_arXiv_duplicates: engine,
- t_arXiv_in_category: engine,
- t_arXiv_moderators: engine,
- t_arXiv_ownership_requests_papers: engine,
- t_arXiv_refresh_list: engine,
- t_arXiv_paper_owners: engine,
- t_arXiv_updates_tmp: engine,
- t_arXiv_white_email: engine,
- t_arXiv_xml_notifications: engine,
- t_demographics_backup: engine,
- t_tapir_email_change_tokens_used: engine,
- t_tapir_email_tokens_used: engine,
- t_tapir_error_log: engine,
- t_tapir_no_cookies: engine,
- t_tapir_periodic_tasks_log: engine,
- t_tapir_periodic_tasks_log: engine,
- t_tapir_permanent_tokens_used: engine,
- t_tapir_save_post_variables: engine
-})
\ No newline at end of file
diff --git a/arxiv/dev_environment/build.sh b/arxiv/dev_environment/build.sh
deleted file mode 100755
index 6b0f3397..00000000
--- a/arxiv/dev_environment/build.sh
+++ /dev/null
@@ -1,3 +0,0 @@
-docker build -t classic_db_img -f Dockerfile.classic_db
-terraform init
-terraform plan
\ No newline at end of file
diff --git a/arxiv/dev_environment/classic_db/Dockerfile b/arxiv/dev_environment/classic_db/Dockerfile
deleted file mode 100644
index 63b5d17a..00000000
--- a/arxiv/dev_environment/classic_db/Dockerfile
+++ /dev/null
@@ -1,17 +0,0 @@
-FROM mysql:8.0.36-debian
-
-######## Install Python 3.11 ########
-
-RUN apt install software-properties-common -y
-RUN add-apt-repository ppa:deadsnakes/ppa
-RUN apt-get update -y
-RUN apt install python3.11 python3-pip -y
-
-###### Install Cloud SQL Proxy ######
-
-WORKDIR /cloudsql
-
-RUN apt-get install -y curl
-
-RUN curl -o cloud-sql-proxy https://storage.googleapis.com/cloud-sql-connectors/cloud-sql-proxy/v2.10.1/cloud-sql-proxy.linux.amd64
-RUN chmod +x cloud-sql-proxy
diff --git a/arxiv/dev_environment/classic_db/wrapper.sh b/arxiv/dev_environment/classic_db/wrapper.sh
deleted file mode 100644
index d6660106..00000000
--- a/arxiv/dev_environment/classic_db/wrapper.sh
+++ /dev/null
@@ -1,5 +0,0 @@
-#!/bin/sh
-/cloudsql/cloud-sql-proxy arxiv-development:us-central1:latexml-db -u /cloudsql &
-/cloudsql/cloud-sql-proxy arxiv-development:us-east4:arxiv-db-dev -u /cloudsql &
-
-gunicorn --bind 0.0.0.0:8000 -t 600 -w 12 --threads 2 entry_point:app
diff --git a/arxiv/dev_environment/main.tf b/arxiv/dev_environment/main.tf
deleted file mode 100644
index bf98a41e..00000000
--- a/arxiv/dev_environment/main.tf
+++ /dev/null
@@ -1,29 +0,0 @@
-terraform {
- required_providers {
- google = {
- source = "hashicorp/google"
- version = "4.51.0"
- }
- }
-}
-
-provider "google" {
- project = "arxiv-development"
-}
-
-provider "docker" {
-
-}
-
-data "google_secret_manager_secret_version" "CLASSIC_DB_URI" {
- secret = "browse-sqlalchemy-db-uri"
- version = "latest"
-}
-
-resource "docker_container" "classic_db" {
- image = "classic_db_img"
- name = "classic_db"
- env = [
- "CLASSIC_DB_URI=${data.google_secret_manager_secret_version.CLASSIC_DB_URI.secret_data}"
- ]
-}
\ No newline at end of file
diff --git a/arxiv/taxonomy/__init__.py b/arxiv/taxonomy/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/poetry.lock b/poetry.lock
index e827c138..5cafd073 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
[[package]]
name = "alabaster"
@@ -321,6 +321,17 @@ files = [
[package.extras]
toml = ["tomli"]
+[[package]]
+name = "decorator"
+version = "5.1.1"
+description = "Decorators for Humans"
+optional = false
+python-versions = ">=3.5"
+files = [
+ {file = "decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186"},
+ {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"},
+]
+
[[package]]
name = "docutils"
version = "0.21.2"
@@ -332,6 +343,25 @@ files = [
{file = "docutils-0.21.2.tar.gz", hash = "sha256:3a6b18732edf182daa3cd12775bbb338cf5691468f91eeeb109deff6ebfa986f"},
]
+[[package]]
+name = "fakeredis"
+version = "2.7.1"
+description = "Fake implementation of redis API for testing purposes."
+optional = false
+python-versions = ">=3.8.1,<4.0"
+files = [
+ {file = "fakeredis-2.7.1-py3-none-any.whl", hash = "sha256:e2f7a88dad23be1191ad6212008e170d75c2d63dde979c2694be8cbfd917428e"},
+ {file = "fakeredis-2.7.1.tar.gz", hash = "sha256:854d758794dab9953be16b9a0f7fbd4bbd6b6964db7a9684e163291c1342ece6"},
+]
+
+[package.dependencies]
+redis = "<4.5"
+sortedcontainers = ">=2.4.0,<3.0.0"
+
+[package.extras]
+json = ["jsonpath-ng (>=1.5,<2.0)"]
+lua = ["lupa (>=1.14,<2.0)"]
+
[[package]]
name = "fastly"
version = "5.2.0"
@@ -1029,6 +1059,21 @@ files = [
{file = "MarkupSafe-2.1.5.tar.gz", hash = "sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b"},
]
+[[package]]
+name = "mimesis"
+version = "17.0.0"
+description = "Mimesis: Fake Data Generator."
+optional = false
+python-versions = "<4.0,>=3.10"
+files = [
+ {file = "mimesis-17.0.0-py3-none-any.whl", hash = "sha256:a088a14075c6d0356fea15e7687afb8f900a412daa80f2f3fe20f1771532402f"},
+ {file = "mimesis-17.0.0.tar.gz", hash = "sha256:57fd7c2762c668054f2ebb85f71d484fa2fd55647d2a02bac4f2e97b81f22d8d"},
+]
+
+[package.extras]
+factory = ["factory-boy (>=3.3.0,<4.0.0)"]
+pytest = ["pytest (>=7.2,<8.0)"]
+
[[package]]
name = "mypy"
version = "1.10.0"
@@ -1167,6 +1212,17 @@ files = [
{file = "protobuf-4.25.3.tar.gz", hash = "sha256:25b5d0b42fd000320bd7830b349e3b696435f3b329810427a6bcce6a5492cc5c"},
]
+[[package]]
+name = "py"
+version = "1.11.0"
+description = "library with cross-python path, ini-parsing, io, code, log facilities"
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
+files = [
+ {file = "py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378"},
+ {file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"},
+]
+
[[package]]
name = "pyasn1"
version = "0.6.0"
@@ -1275,6 +1331,23 @@ files = [
[package.extras]
windows-terminal = ["colorama (>=0.4.6)"]
+[[package]]
+name = "pyjwt"
+version = "2.8.0"
+description = "JSON Web Token implementation in Python"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "PyJWT-2.8.0-py3-none-any.whl", hash = "sha256:59127c392cc44c2da5bb3192169a91f429924e17aff6534d70fdc02ab3e04320"},
+ {file = "PyJWT-2.8.0.tar.gz", hash = "sha256:57e28d156e3d5c10088e0c68abb90bfac3df82b40a71bd0daa20c65ccd5c23de"},
+]
+
+[package.extras]
+crypto = ["cryptography (>=3.4.0)"]
+dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"]
+docs = ["sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"]
+tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"]
+
[[package]]
name = "pytest"
version = "8.2.1"
@@ -1313,6 +1386,23 @@ pytest = ">=4.6"
[package.extras]
testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"]
+[[package]]
+name = "pytest-mock"
+version = "3.14.0"
+description = "Thin-wrapper around the mock package for easier use with pytest"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "pytest-mock-3.14.0.tar.gz", hash = "sha256:2719255a1efeceadbc056d6bf3df3d1c5015530fb40cf347c0f9afac88410bd0"},
+ {file = "pytest_mock-3.14.0-py3-none-any.whl", hash = "sha256:0b72c38033392a5f4621342fe11e9219ac11ec9d375f8e2a0c164539e0d70f6f"},
+]
+
+[package.dependencies]
+pytest = ">=6.2.5"
+
+[package.extras]
+dev = ["pre-commit", "pytest-asyncio", "tox"]
+
[[package]]
name = "python-dateutil"
version = "2.9.0.post0"
@@ -1338,6 +1428,30 @@ files = [
{file = "pytz-2024.1.tar.gz", hash = "sha256:2a29735ea9c18baf14b448846bde5a48030ed267578472d8955cd0e7443a9812"},
]
+[[package]]
+name = "redis"
+version = "2.10.6"
+description = "Python client for Redis key-value store"
+optional = false
+python-versions = "*"
+files = [
+ {file = "redis-2.10.6-py2.py3-none-any.whl", hash = "sha256:8a1900a9f2a0a44ecf6e8b5eb3e967a9909dfed219ad66df094f27f7d6f330fb"},
+ {file = "redis-2.10.6.tar.gz", hash = "sha256:a22ca993cea2962dbb588f9f30d0015ac4afcc45bee27d3978c0dbe9e97c6c0f"},
+]
+
+[[package]]
+name = "redis-py-cluster"
+version = "1.3.6"
+description = "Library for communicating with Redis Clusters. Built on top of redis-py lib"
+optional = false
+python-versions = "*"
+files = [
+ {file = "redis-py-cluster-1.3.6.tar.gz", hash = "sha256:7db54b1de60bd34da3806676b112f07fc9afae556d8260ac02c3335d574ee42c"},
+]
+
+[package.dependencies]
+redis = "2.10.6"
+
[[package]]
name = "requests"
version = "2.32.0"
@@ -1359,6 +1473,21 @@ urllib3 = ">=1.21.1,<3"
socks = ["PySocks (>=1.5.6,!=1.5.7)"]
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
+[[package]]
+name = "retry"
+version = "0.9.2"
+description = "Easy to use retry decorator."
+optional = false
+python-versions = "*"
+files = [
+ {file = "retry-0.9.2-py2.py3-none-any.whl", hash = "sha256:ccddf89761fa2c726ab29391837d4327f819ea14d244c232a1d24c67a2f98606"},
+ {file = "retry-0.9.2.tar.gz", hash = "sha256:f8bfa8b99b69c4506d6f5bd3b0aabf77f98cdb17f3c9fc3f5ca820033336fba4"},
+]
+
+[package.dependencies]
+decorator = ">=3.4.2"
+py = ">=1.4.26,<2.0.0"
+
[[package]]
name = "rsa"
version = "4.9"
@@ -1390,6 +1519,21 @@ botocore = ">=1.33.2,<2.0a.0"
[package.extras]
crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"]
+[[package]]
+name = "setuptools"
+version = "70.1.0"
+description = "Easily download, build, install, upgrade, and uninstall Python packages"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "setuptools-70.1.0-py3-none-any.whl", hash = "sha256:d9b8b771455a97c8a9f3ab3448ebe0b29b5e105f1228bba41028be116985a267"},
+ {file = "setuptools-70.1.0.tar.gz", hash = "sha256:01a1e793faa5bd89abc851fa15d0a0db26f160890c7102cd8dce643e886b47f5"},
+]
+
+[package.extras]
+docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"]
+testing = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test", "mypy (==1.10.0)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.1)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (>=0.3.2)", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
+
[[package]]
name = "six"
version = "1.16.0"
@@ -1782,4 +1926,4 @@ sphinx = ["sphinx", "sphinx-autodoc-typehints", "sphinxcontrib-websupport"]
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
-content-hash = "7c3918d253eb10e738c8cf47af4235054785932dbeaa4c4447bc374ad10b5962"
+content-hash = "c355759252de731ba6f40ffceb152166a7873941532338ee6ba6aa4eb66594cc"
diff --git a/pyproject.toml b/pyproject.toml
index 5fa8f35f..7fc0b850 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -37,13 +37,21 @@ validators = "*"
sphinx = { version = "*", optional = true }
sphinxcontrib-websupport = { version = "*", optional = true }
sphinx-autodoc-typehints = { version = "*", optional = true }
+mimesis = "*"
+retry = "^0.9.2"
+pyjwt = "*"
+redis = "==2.10.6"
+redis-py-cluster = "==1.3.6"
+setuptools = "^70.0.0"
[tool.poetry.dev-dependencies]
pydocstyle = "*"
mypy = "*"
pytest = "*"
+pytest-mock = "^3.8.2"
pytest-cov = "*"
hypothesis = "*"
+fakeredis = "*"
click = "*"
diff --git a/test.db-journal b/test.db-journal
new file mode 100644
index 00000000..526187cd
Binary files /dev/null and b/test.db-journal differ