diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml deleted file mode 100644 index 4ce7a2c..0000000 --- a/.github/workflows/pytest.yml +++ /dev/null @@ -1,32 +0,0 @@ -# This workflow will install Python dependencies, run tests and lint with a single version of Python -# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python - -name: VATcomply tests - -on: - push: - branches: [ "master" ] - pull_request: - branches: [ "master" ] - -permissions: - contents: read - -jobs: - build: - - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - - name: Set up Python 3.11 - uses: actions/setup-python@v3 - with: - python-version: "3.11" - - name: Install dependencies - run: | - python -m pip install --upgrade pip - if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - - name: Test with pytest - run: | - pytest diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..93e665d --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,29 @@ +name: VATcomply tests + +on: + push: + branches: ["master"] + pull_request: + branches: ["master"] + +permissions: + contents: read + +jobs: + build: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - name: Set up Python 3.11 + uses: actions/setup-python@v3 + with: + python-version: "3.11" + - name: Install Dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + - name: Run Tests + run: | + python manage.py test diff --git a/LICENSE b/LICENSE index 8a1effd..2764648 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2021 Madis Väin +Copyright (c) 2023 Madis Väin Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/Makefile b/Makefile index 132489d..74a93d1 100644 --- a/Makefile +++ b/Makefile @@ -1,11 +1,14 @@ run: - uvicorn app:app --reload + export DEBUG=True; uvicorn vatcomply.asgi:application --reload pip: pip install -r requirements.in --upgrade migrate: - PYTHONPATH=.:$PYTHONPATH alembic upgrade head + python manage.py migrate + +migrations: + python manage.py makemigrations test: - export TESTING=True; pytest -s --disable-warnings + python manage.py test --keepdb \ No newline at end of file diff --git a/Procfile b/Procfile deleted file mode 100644 index 3064a7d..0000000 --- a/Procfile +++ /dev/null @@ -1,2 +0,0 @@ -web: gunicorn app:app --worker-class uvicorn.workers.UvicornWorker --workers ${WEB_CONCURRENCY} --max-requests 1000 -release: alembic upgrade head \ No newline at end of file diff --git a/README.md b/README.md index 5a45544..4b1ba40 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # VATcomply -VATcomply is a free API service for vat number validation, user ip geolocation and foreign exchange rates [published by the European Central Bank](https://www.ecb.europa.eu/stats/policy_and_exchange_rates/euro_reference_exchange_rates/html/index.en.html). +[VATcomply](https://www.vatcomply.com) is a free API service for vat number validation, user ip geolocation and foreign exchange rates [published by the European Central Bank](https://www.ecb.europa.eu/stats/policy_and_exchange_rates/euro_reference_exchange_rates/html/index.en.html). ## Usage @@ -12,12 +12,6 @@ Get the latest foreign exchange rates. GET /rates ``` -Get historical rates for any day since 1999. - -```http -GET /rates/2018-03-26 -``` - Rates are quoted against the Euro by default. Quote against a different currency by setting the base parameter in your request. ```http @@ -30,7 +24,7 @@ Request specific exchange rates by setting the symbols parameter. GET /rates?symbols=USD,GBP ``` -#### Rates history +#### Rates history and query parameters combinations Get historical rates for a date @@ -60,7 +54,7 @@ let demo = () => { alert("£1 = $" + rate.toFixed(4)); }; -fetch("https://api.vatcomply.com/latest") +fetch("https://api.vatcomply.com/rates") .then((resp) => resp.json()) .then((data) => (fx.rates = data.rates)) .then(demo); @@ -68,14 +62,13 @@ fetch("https://api.vatcomply.com/latest") ## Stack -VATcomply API is built upon [Starlette](https://github.com/encode/starlette) the asyncronous Python framework to achieve high throughput. The current setup can asyncronously handle thousands of requests per second. +VATcomply API is built upon [Django](https://www.djangoproject.com/) with asyncronous views, [Pydantic](https://docs.pydantic.dev/latest/) and asyncronous ORM queries to achieve high throughput. The current setup can asyncronously handle thousands of requests per second. #### Libraries used -- [Starlette](https://github.com/encode/starlette) -- [SQLAlchemy](https://www.sqlalchemy.org/) +- [Django](https://www.djangoproject.com/) +- [Pydantic](https://docs.pydantic.dev/latest/) - [APScheduler](https://github.com/agronholm/apscheduler) -- [uvloop](https://github.com/MagicStack/uvloop) - [ultraJSON](https://github.com/esnme/ultrajson) ## Deployment @@ -105,10 +98,10 @@ On initialization it will check the database. If it's empty all the historic rat ## Development ```shell -uvicorn app:app --reload +export DEBUG=True; uvicorn vatcomply.asgi:application --reload ``` -or +or for simplicity a Makefile is provided with all the commands for development. ```shell make run @@ -116,15 +109,29 @@ make run ## Migrations +Make migrations + +```shell +make migrations +``` + +Run migrations + +```shell +make migrate +``` + +## Tests + ```shell -PYTHONPATH=.:\$PYTHONPATH alembic revision --autogenerate -m "create users table" +make test ``` ## Contributing Thanks for your interest in the project! All pull requests are welcome from developers of all skill levels. To get started, simply fork the master branch on GitHub to your personal account and then clone the fork into your development environment. -Madis Väin (madisvain on Github, Twitter) is the original creator of the VATcomply API. +Madis Väin ([madisvain](https://github.com/madisvain) on Github) is the original creator of the VATcomply API. ## License diff --git a/alembic.ini b/alembic.ini deleted file mode 100644 index 411ea77..0000000 --- a/alembic.ini +++ /dev/null @@ -1,72 +0,0 @@ -# A generic, single database configuration. - -[alembic] -# path to migration scripts -script_location = migrations - -# template used to generate migration files -# file_template = %%(rev)s_%%(slug)s - -# timezone to use when rendering the date -# within the migration file as well as the filename. -# string value is passed to dateutil.tz.gettz() -# leave blank for localtime -# timezone = - -# max length of characters to apply to the -# "slug" field -# truncate_slug_length = 40 - -# set to 'true' to run the environment during -# the 'revision' command, regardless of autogenerate -# revision_environment = false - -# set to 'true' to allow .pyc and .pyo files without -# a source .py file to be detected as revisions in the -# versions/ directory -# sourceless = false - -# version location specification; this defaults -# to migrations/versions. When using multiple version -# directories, initial revisions must be specified with --version-path -# version_locations = %(here)s/bar %(here)s/bat migrations/versions - -# the output encoding used when revision files -# are written from script.py.mako -# output_encoding = utf-8 - - -# Logging configuration -[loggers] -keys = root,sqlalchemy,alembic - -[handlers] -keys = console - -[formatters] -keys = generic - -[logger_root] -level = WARN -handlers = console -qualname = - -[logger_sqlalchemy] -level = WARN -handlers = -qualname = sqlalchemy.engine - -[logger_alembic] -level = INFO -handlers = -qualname = alembic - -[handler_console] -class = StreamHandler -args = (sys.stderr,) -level = NOTSET -formatter = generic - -[formatter_generic] -format = %(levelname)-5.5s [%(name)s] %(message)s -datefmt = %H:%M:%S diff --git a/app.json b/app.json deleted file mode 100644 index c2d6a8e..0000000 --- a/app.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "name": "VATcomply", - "description": "A painless VAT compliance & Exchange rates API", - "website": "https://www.vatcomply.com", - "repository": "https://github.com/madisvain/vatcomply", - "keywords": [ - "python", - "starlette", - "api", - "exchange rates", - "vat compliance" - ], - "env": { - "WEB_CONCURRENCY": { - "description": "The number of gunicorn workers to run.", - "value": "4" - }, - "SENTRY_DSN": { - "description": "To enable Sentry integration provide your Sentry DSN.", - "required": false - } - } -} \ No newline at end of file diff --git a/app.py b/app.py deleted file mode 100644 index 3653453..0000000 --- a/app.py +++ /dev/null @@ -1,327 +0,0 @@ -import fcntl -import json -import pendulum -import sentry_sdk -import ujson -import uvicorn -import zeep - -from apscheduler.schedulers.asyncio import AsyncIOScheduler -from babel.numbers import get_currency_name, get_currency_symbol -from decimal import Decimal -from passlib.hash import pbkdf2_sha256 -from pydantic import ValidationError -from pydantic.error_wrappers import ErrorWrapper -from sentry_sdk.integrations.asgi import SentryAsgiMiddleware -from starlette.applications import Starlette -from starlette.middleware.cors import CORSMiddleware -from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware -from starlette.middleware.trustedhost import TrustedHostMiddleware -from starlette.responses import JSONResponse -from typing import Any - -from db import database, Countries, Rates, Users -from errors import AlreadyExistsError -from models import ( - LoginValidationModel, - RatesQueryValidationModel, - RegistrationValidationModel, - VATValidationModel, - VatRatesModel, -) -from settings import ( - ALLOWED_HOSTS, - DEBUG, - FORCE_HTTPS, - SENTRY_DSN, - SYMBOLS, - TESTING, - VIES_URL, -) -from utils import load_countries, load_rates - - -CORS_HEADERS = { - "Access-Control-Allow-Credentials": "true", - "Access-Control-Allow-Methods": "GET", - "Access-Control-Allow-Origin": "*", -} - - -class UJSONResponse(JSONResponse): - media_type = "application/json" - - def render(self, content: Any) -> bytes: - assert ujson is not None, "ujson must be installed to use UJSONResponse" - return ujson.dumps(content, ensure_ascii=False).encode("utf-8") - - -app = Starlette(debug=DEBUG) - -""" Allowed hosts """ -if ALLOWED_HOSTS: - app.add_middleware(TrustedHostMiddleware, allowed_hosts=list(ALLOWED_HOSTS)) - -""" Force HTTPS """ -if FORCE_HTTPS: - app.add_middleware(HTTPSRedirectMiddleware) - -""" Sentry """ -if SENTRY_DSN: - sentry_sdk.init(dsn=SENTRY_DSN) - app.add_middleware(SentryAsgiMiddleware) - -""" CORS """ -# if CORS: -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, -) - -""" Authentication """ -# app.add_middleware(AuthenticationMiddleware, backend=TokenAuthenticationBackend()) - - -""" Startup & Shutdown """ - - -@app.on_event("startup") -async def startup(): - await database.connect() - - # Schedule exchangerate updates - try: - _ = open("scheduler.lock", "w") - fcntl.lockf(_.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB) - - scheduler = AsyncIOScheduler() - scheduler.start() - - if not TESTING: - # Countries - scheduler.add_job(load_countries) - - # Updates lates 90 days data - scheduler.add_job(load_rates, "interval", hours=1, minutes=10) - - # Fill up database with rates - scheduler.add_job(load_rates, kwargs={"last_90_days": False}) - except BlockingIOError: - pass - - -@app.on_event("shutdown") -async def shutdown(): - await database.disconnect() - - -""" API """ - - -@app.route("/login", methods=["POST"]) -async def login(request): - try: - data = await request.json() - login = LoginValidationModel(**data) - - return UJSONResponse(login.dict()) - except ValidationError as e: - return UJSONResponse(e.errors(), status_code=400) - - -@app.route("/register", methods=["POST"]) -async def register(request): - try: - data = await request.json() - registration = RegistrationValidationModel(**data) - - # Check if the email is unique - user = await database.fetch_one( - query=Users.select().where(Users.c.email == registration.email) - ) - if user: - raise ValidationError( - [ - ErrorWrapper( - AlreadyExistsError(email=registration.email), loc="email" - ) - ], - model=RegistrationValidationModel, - ) - - await database.execute( - query=Users.insert(), - values={ - "email": registration.email, - "password": pbkdf2_sha256.hash( - registration.password.get_secret_value() - ), - }, - ) - response = registration.dict() - del response["password"] - return UJSONResponse(response, status_code=201) - except ValidationError as e: - return UJSONResponse(e.errors(), status_code=400) - - -@app.route("/vat") -# @requires("authenticated") -async def vat(request): - try: - query = VATValidationModel(**request.query_params) - client = zeep.Client(wsdl=str(VIES_URL)) - try: - response = zeep.helpers.serialize_object( - client.service.checkVat( - countryCode=query.vat_number[:2], vatNumber=query.vat_number[2:] - ) - ) - except zeep.exceptions.Fault as e: - return UJSONResponse({"error": e.message}) - - return UJSONResponse( - { - "valid": response["valid"], - "vat_number": response["vatNumber"], - "name": response["name"], - "address": response["address"].strip() if response["address"] else "", - "country_code": response["countryCode"], - }, - headers=CORS_HEADERS, - ) - except ValidationError as e: - return UJSONResponse(e.errors(), headers=CORS_HEADERS, status_code=400) - - -@app.route("/geolocate", methods=["GET", "HEAD"]) -async def geolocate(request): - country_code = request.headers.get("CF-IPCountry") - ip = request.headers.get("CF-Connecting-IP") - - if not country_code: - return UJSONResponse({"ip": ip}, headers=CORS_HEADERS, status_code=404) - - # Get the rates data - record = await database.fetch_one( - query=Countries.select().where(Countries.c.iso2 == country_code.upper()) - ) - - return UJSONResponse( - { - "iso2": record.iso2 if record else None, - "iso3": record.iso3 if record else None, - "country_code": country_code.upper() if country_code else None, - "name": record.name if record else None, - "numeric_code": record.numeric_code if record else None, - "phone_code": record.phone_code if record else None, - "capital": record.capital if record else None, - "currency": record.currency if record else None, - "tld": record.tld if record else None, - "region": record.region if record else None, - "subregion": record.subregion if record else None, - "latitude": Decimal(record.latitude) if record else None, - "longitude": Decimal(record.longitude) if record else None, - "emoji": record.emoji if record else None, - "ip": ip, - }, - headers=CORS_HEADERS, - ) - - -@app.route("/countries") -# @requires("authenticated") -async def countries(request): - records = await database.fetch_all( - query=Countries.select().order_by(Countries.c.iso2.asc()) - ) - - countries = [] - for country in records: - countries.append( - { - "iso2": country.iso2, - "iso3": country.iso3, - "name": country.name, - "numeric_code": country.numeric_code, - "phone_code": country.phone_code, - "capital": country.capital, - "currency": country.currency, - "tld": country.tld, - "region": country.region, - "subregion": country.subregion, - "latitude": Decimal(country.latitude), - "longitude": Decimal(country.longitude), - "emoji": country.emoji, - } - ) - return UJSONResponse( - countries, - headers=CORS_HEADERS, - ) - - -@app.route("/rates") -@app.route("/rates/latest") -@app.route("/rates/{date}") -async def rates(request): - query_params = dict(request.query_params) - if "date" in request.path_params: - query_params["date"] = request.path_params["date"] - - try: - query = RatesQueryValidationModel(**query_params) - - # Find the date - date = query.date if query.date else pendulum.now().date() - - # Get the rates data - record = await database.fetch_one( - query=Rates.select() - .where(Rates.c.date <= date) - .order_by(Rates.c.date.desc()) - .limit(1) - ) - - # Base re-calculation - rates = {"EUR": 1} - rates.update(record.rates) - if query.base and query.base != "EUR": - base_rate = Decimal(record.rates[query.base]) - rates = { - currency: Decimal(rate) / base_rate for currency, rate in rates.items() - } - rates.update({"EUR": Decimal(1) / base_rate}) - - # Symbols - if query.symbols: - for rate in list(rates): - if rate not in query.symbols: - del rates[rate] - - return UJSONResponse( - {"date": record.date.isoformat(), "base": query.base, "rates": rates}, - headers=CORS_HEADERS, - ) - except ValidationError as e: - return UJSONResponse(e.errors(), headers=CORS_HEADERS, status_code=400) - - -@app.route("/currencies") -# @requires("authenticated") -async def currencies(request): - currencies = {} - for symbol in list(SYMBOLS): - currencies[symbol] = { - "name": get_currency_name(symbol, locale="en"), - "symbol": get_currency_symbol(symbol, locale="en"), - } - - headers = {"Cache-Control": "max-age=86400"} - headers.update(CORS_HEADERS) - return UJSONResponse(currencies, headers=headers) - - -if __name__ == "__main__": - uvicorn.run(app, host="127.0.0.1", port=8000) diff --git a/auth.py b/auth.py deleted file mode 100644 index d3b0131..0000000 --- a/auth.py +++ /dev/null @@ -1,26 +0,0 @@ -from starlette.authentication import AuthenticationBackend, AuthCredentials, AuthenticationError, SimpleUser - -from settings import TESTING - - -class TokenAuthenticationBackend(AuthenticationBackend): - """ https://github.com/encode/django-rest-framework/blob/master/rest_framework/authentication.py#L144 - """ - - keyword = "Token" - - async def authenticate(self, request): - if "authorization" not in request.headers: - return - - auth = request.headers["authorization"].split() - if not auth or auth[0].lower() != self.keyword.lower(): - return - - if len(auth) == 1: - raise AuthenticationError("Invalid token header. No credentials provided.") - elif len(auth) > 2: - raise AuthenticationError("Invalid token header. Token string should not contain spaces.") - - # TODO: Implement validation - return AuthCredentials(["authenticated"]), SimpleUser("username") diff --git a/bin/post_compile b/bin/post_compile deleted file mode 100644 index 6a5da16..0000000 --- a/bin/post_compile +++ /dev/null @@ -1 +0,0 @@ -PYTHONPATH=.:$PYTHONPATH alembic upgrade head \ No newline at end of file diff --git a/conftest.py b/conftest.py deleted file mode 100644 index c074937..0000000 --- a/conftest.py +++ /dev/null @@ -1,52 +0,0 @@ -import asyncio -import pytest - -from starlette.config import environ -from starlette.testclient import TestClient -from sqlalchemy import create_engine -from sqlalchemy_utils import database_exists, create_database, drop_database - -# This sets `os.environ`, but provides some additional protection. -# If we placed it below the application import, it would raise an error -# informing us that 'TESTING' had already been read from the environment. -environ["TESTING"] = "True" - -import app - -from db import metadata -from settings import TEST_DATABASE_URL -from utils import load_rates - - -@pytest.fixture(scope="session", autouse=True) -def create_test_database(): - """ - Create a clean database on every test case. - For safety, we should abort if a database already exists. - - We use the `sqlalchemy_utils` package here for a few helpers in consistently - creating and dropping the database. - """ - url = str(TEST_DATABASE_URL) - engine = create_engine(url) - assert not database_exists(url), "Test database already exists. Aborting tests." - create_database(url) # Create the test database. - metadata.create_all(engine) # Create the tables. - asyncio.get_event_loop().run_until_complete(load_rates(last_90_days=False)) - yield # Run the tests. - drop_database(url) # Drop the test database. - - -@pytest.fixture() -def client(): - """ - When using the 'client' fixture in test cases, we'll get full database - rollbacks between test cases: - - def test_homepage(client): - url = app.url_path_for('homepage') - response = client.get(url) - assert response.status_code == 200 - """ - with TestClient(app) as client: - yield client diff --git a/db.py b/db.py deleted file mode 100644 index 9861250..0000000 --- a/db.py +++ /dev/null @@ -1,48 +0,0 @@ -import databases -import sqlalchemy - -from settings import DATABASE_URL, TEST_DATABASE_URL, TESTING - -metadata = sqlalchemy.MetaData() - -Countries = sqlalchemy.Table( - "countries", - metadata, - sqlalchemy.Column("name", sqlalchemy.String, nullable=False), - sqlalchemy.Column("iso2", sqlalchemy.String, nullable=False, unique=True), - sqlalchemy.Column("iso3", sqlalchemy.String, nullable=False), - sqlalchemy.Column("numeric_code", sqlalchemy.String), - sqlalchemy.Column("phone_code", sqlalchemy.String), - sqlalchemy.Column("capital", sqlalchemy.String, nullable=False), - sqlalchemy.Column("currency", sqlalchemy.String, nullable=False), - sqlalchemy.Column("tld", sqlalchemy.String, nullable=False), - sqlalchemy.Column("region", sqlalchemy.String, nullable=False), - sqlalchemy.Column("subregion", sqlalchemy.String, nullable=False), - sqlalchemy.Column("latitude", sqlalchemy.String, nullable=False), - sqlalchemy.Column("longitude", sqlalchemy.String, nullable=False), - sqlalchemy.Column("emoji", sqlalchemy.String, nullable=False), -) - -Rates = sqlalchemy.Table( - "rates", - metadata, - sqlalchemy.Column("date", sqlalchemy.Date, primary_key=True), - sqlalchemy.Column("rates", sqlalchemy.JSON), -) - -Users = sqlalchemy.Table( - "users", - metadata, - sqlalchemy.Column("pk", sqlalchemy.Integer, primary_key=True), - sqlalchemy.Column("email", sqlalchemy.String, nullable=False, unique=True), - sqlalchemy.Column("password", sqlalchemy.String, nullable=False), - sqlalchemy.Column( - "created_at", sqlalchemy.DateTime, default=sqlalchemy.func.now(), nullable=False - ), - sqlalchemy.Column("last_login", sqlalchemy.DateTime), -) - -if TESTING: - database = databases.Database(TEST_DATABASE_URL, force_rollback=True) -else: - database = databases.Database(DATABASE_URL) diff --git a/errors.py b/errors.py deleted file mode 100644 index 003181d..0000000 --- a/errors.py +++ /dev/null @@ -1,6 +0,0 @@ -from pydantic import PydanticValueError - - -class AlreadyExistsError(PydanticValueError): - code = "already_exists" - msg_template = "User with an email '{email}' already exists" diff --git a/manage.py b/manage.py new file mode 100755 index 0000000..06eba20 --- /dev/null +++ b/manage.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python +"""Django's command-line utility for administrative tasks.""" +import os +import sys + + +def main(): + """Run administrative tasks.""" + os.environ.setdefault("DJANGO_SETTINGS_MODULE", "vatcomply.settings") + try: + from django.core.management import execute_from_command_line + except ImportError as exc: + raise ImportError( + "Couldn't import Django. Are you sure it's installed and " + "available on your PYTHONPATH environment variable? Did you " + "forget to activate a virtual environment?" + ) from exc + execute_from_command_line(sys.argv) + + +if __name__ == "__main__": + main() diff --git a/migrations/README b/migrations/README deleted file mode 100644 index 98e4f9c..0000000 --- a/migrations/README +++ /dev/null @@ -1 +0,0 @@ -Generic single-database configuration. \ No newline at end of file diff --git a/migrations/env.py b/migrations/env.py deleted file mode 100644 index 8bd0fcb..0000000 --- a/migrations/env.py +++ /dev/null @@ -1,78 +0,0 @@ -from alembic import context -from logging.config import fileConfig -from sqlalchemy import engine_from_config -from sqlalchemy import pool - -import db -from settings import DATABASE_URL - -# this is the Alembic Config object, which provides -# access to the values within the .ini file in use. -config = context.config - -# Database URL configuration -config.set_main_option("sqlalchemy.url", str(DATABASE_URL)) - -target_metadata = db.metadata - -# Interpret the config file for Python logging. -# This line sets up loggers basically. -fileConfig(config.config_file_name) - -# add your model's MetaData object here -# for 'autogenerate' support -# from myapp import mymodel -# target_metadata = mymodel.Base.metadata -# target_metadata = None - -# other values from the config, defined by the needs of env.py, -# can be acquired: -# my_important_option = config.get_main_option("my_important_option") -# ... etc. - - -def run_migrations_offline(): - """Run migrations in 'offline' mode. - - This configures the context with just a URL - and not an Engine, though an Engine is acceptable - here as well. By skipping the Engine creation - we don't even need a DBAPI to be available. - - Calls to context.execute() here emit the given string to the - script output. - - """ - url = config.get_main_option("sqlalchemy.url") - context.configure(url=url, target_metadata=target_metadata, literal_binds=True) - - with context.begin_transaction(): - context.run_migrations() - - -def run_migrations_online(): - """Run migrations in 'online' mode. - - In this scenario we need to create an Engine - and associate a connection with the context. - - """ - connectable = engine_from_config( - config.get_section(config.config_ini_section), - prefix="sqlalchemy.", - poolclass=pool.NullPool, - ) - - with connectable.connect() as connection: - context.configure( - connection=connection, target_metadata=target_metadata, compare_type=True - ) - - with context.begin_transaction(): - context.run_migrations() - - -if context.is_offline_mode(): - run_migrations_offline() -else: - run_migrations_online() diff --git a/migrations/script.py.mako b/migrations/script.py.mako deleted file mode 100644 index 2c01563..0000000 --- a/migrations/script.py.mako +++ /dev/null @@ -1,24 +0,0 @@ -"""${message} - -Revision ID: ${up_revision} -Revises: ${down_revision | comma,n} -Create Date: ${create_date} - -""" -from alembic import op -import sqlalchemy as sa -${imports if imports else ""} - -# revision identifiers, used by Alembic. -revision = ${repr(up_revision)} -down_revision = ${repr(down_revision)} -branch_labels = ${repr(branch_labels)} -depends_on = ${repr(depends_on)} - - -def upgrade(): - ${upgrades if upgrades else "pass"} - - -def downgrade(): - ${downgrades if downgrades else "pass"} diff --git a/migrations/versions/17ee4cab92e6_create_countries_table.py b/migrations/versions/17ee4cab92e6_create_countries_table.py deleted file mode 100644 index e89f10e..0000000 --- a/migrations/versions/17ee4cab92e6_create_countries_table.py +++ /dev/null @@ -1,44 +0,0 @@ -"""create countries table - -Revision ID: 17ee4cab92e6 -Revises: c8d6ab7251df -Create Date: 2021-10-11 22:44:08.477311 - -""" -from alembic import op -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision = "17ee4cab92e6" -down_revision = "c8d6ab7251df" -branch_labels = None -depends_on = None - - -def upgrade(): - # ### commands auto generated by Alembic - please adjust! ### - op.create_table( - "countries", - sa.Column("name", sa.String(), nullable=False), - sa.Column("iso2", sa.String(), nullable=False), - sa.Column("iso3", sa.String(), nullable=False), - sa.Column("numeric_code", sa.String(), nullable=True), - sa.Column("phone_code", sa.String(), nullable=True), - sa.Column("capital", sa.String(), nullable=False), - sa.Column("currency", sa.String(), nullable=False), - sa.Column("tld", sa.String(), nullable=False), - sa.Column("region", sa.String(), nullable=False), - sa.Column("subregion", sa.String(), nullable=False), - sa.Column("latitude", sa.String(), nullable=False), - sa.Column("longitude", sa.String(), nullable=False), - sa.Column("emoji", sa.String(), nullable=False), - sa.UniqueConstraint("iso2"), - ) - # ### end Alembic commands ### - - -def downgrade(): - # ### commands auto generated by Alembic - please adjust! ### - op.drop_table("countries") - # ### end Alembic commands ### diff --git a/migrations/versions/ae9f5ea7bd9a_create_rates_table.py b/migrations/versions/ae9f5ea7bd9a_create_rates_table.py deleted file mode 100644 index 1202f50..0000000 --- a/migrations/versions/ae9f5ea7bd9a_create_rates_table.py +++ /dev/null @@ -1,28 +0,0 @@ -"""create rates table - -Revision ID: ae9f5ea7bd9a -Revises: -Create Date: 2019-07-11 11:07:50.885209 - -""" -from alembic import op -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision = "ae9f5ea7bd9a" -down_revision = None -branch_labels = None -depends_on = None - - -def upgrade(): - op.create_table( - "rates", - sa.Column("date", sa.Date, primary_key=True), - sa.Column("rates", sa.JSON), - ) - - -def downgrade(): - op.drop_table("rates") diff --git a/migrations/versions/c8d6ab7251df_create_users_table.py b/migrations/versions/c8d6ab7251df_create_users_table.py deleted file mode 100644 index 2a0a3c3..0000000 --- a/migrations/versions/c8d6ab7251df_create_users_table.py +++ /dev/null @@ -1,37 +0,0 @@ -"""create users table - -Revision ID: c8d6ab7251df -Revises: ae9f5ea7bd9a -Create Date: 2020-01-05 22:57:21.139411 - -""" -from alembic import op -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision = "c8d6ab7251df" -down_revision = "ae9f5ea7bd9a" -branch_labels = None -depends_on = None - - -def upgrade(): - # ### commands auto generated by Alembic - please adjust! ### - op.create_table( - "users", - sa.Column("pk", sa.Integer(), nullable=False), - sa.Column("email", sa.String(), nullable=False), - sa.Column("password", sa.String(), nullable=False), - sa.Column("created_at", sa.DateTime(), nullable=False), - sa.Column("last_login", sa.DateTime(), nullable=True), - sa.PrimaryKeyConstraint("pk"), - sa.UniqueConstraint("email"), - ) - # ### end Alembic commands ### - - -def downgrade(): - # ### commands auto generated by Alembic - please adjust! ### - op.drop_table("users") - # ### end Alembic commands ### diff --git a/models.py b/models.py deleted file mode 100644 index c88325d..0000000 --- a/models.py +++ /dev/null @@ -1,67 +0,0 @@ -from datetime import date -from typing import Optional - -from passlib.hash import pbkdf2_sha256 -from pydantic import BaseModel, validator, root_validator, EmailStr, SecretStr - -from db import database, Users -from settings import SYMBOLS - - -class LoginValidationModel(BaseModel): - email: EmailStr - password: SecretStr - - @root_validator - def password_length(cls, values): - email, password = values.get("email"), values.get("password") - user = database.fetch_one(query=Users.select().where(Users.c.email == email)) - if not user or pbkdf2_sha256.verify(password.get_secret_value(), user.password): - raise ValueError( - "Please enter a correct username and password. Note that both fields may be case-sensitive." - ) - return values - - -class RatesQueryValidationModel(BaseModel): - base: str = "EUR" - date: Optional[date] - symbols: Optional[list] - - @validator("base") - def base_validation(cls, base): - if base not in list(SYMBOLS): - raise ValueError(f"Base currency {base} is not supported.") - return base - - @validator("symbols", pre=True, whole=True) - def symbols_validation(cls, symbols): - symbols = symbols.split(",") - diff = list(set(symbols) - set(list(SYMBOLS))) - if diff: - raise ValueError(f"Symbols {', '.join(diff)} are not supported.") - return symbols - - -class RegistrationValidationModel(BaseModel): - email: EmailStr - password: SecretStr - - @validator("password") - def password_length(cls, v): - if len(v.get_secret_value()) < 6: - raise ValueError("Password must be at least 6 characters long.") - return v - - -class VATValidationModel(BaseModel): - vat_number: str - - @validator("vat_number") - def format_validation(cls, v): - # TODO: implement Regex validators - return v - - -class VatRatesModel(BaseModel): - vat_nu: str diff --git a/requirements.in b/requirements.in index b73cfc8..d67ce79 100644 --- a/requirements.in +++ b/requirements.in @@ -1,18 +1,12 @@ -alembic apscheduler -babel -databases[asyncpg] +django +django-cors-headers gunicorn httpx -passlib +lxml pendulum -psycopg2-binary -pydantic[email] -pytest -sentry-sdk -starlette -sqlalchemy-utils -requests +pydantic +sentry-sdk[django] ujson uvicorn zeep \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 37856e3..5d24ce3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,49 +1,40 @@ -aiosqlite==0.19.0 -alembic==1.11.1 -anyio==3.7.0 -APScheduler==3.10.1 -asyncpg==0.27.0 +annotated-types==0.6.0 +anyio==4.2.0 +APScheduler==3.10.4 +asgiref==3.7.2 attrs==23.1.0 -Babel==2.12.1 -certifi==2023.5.7 -charset-normalizer==3.1.0 -click==8.1.3 -databases==0.7.0 -dnspython==2.3.0 -email-validator==2.0.0.post2 -gunicorn==20.1.0 +certifi==2023.11.17 +charset-normalizer==3.3.2 +click==8.1.7 +Django==4.2.8 +django-cors-headers==4.3.1 +exceptiongroup==1.2.0 +gunicorn==21.2.0 h11==0.14.0 -httpcore==0.17.2 -httpx==0.24.1 -idna==3.4 -iniconfig==2.0.0 +httpcore==1.0.2 +httpx==0.26.0 +idna==3.6 isodate==0.6.1 -lxml==4.9.2 -Mako==1.2.4 -MarkupSafe==2.1.3 -packaging==23.1 -passlib==1.7.4 -pendulum==2.1.2 -platformdirs==3.5.3 -pluggy==1.0.0 -psycopg2-binary==2.9.6 -pydantic==1.10.9 -pytest==7.3.2 +lxml==4.9.4 +packaging==23.2 +pendulum==3.0.0 +platformdirs==4.1.0 +pydantic==2.5.3 +pydantic_core==2.14.6 python-dateutil==2.8.2 -pytz==2023.3 -pytzdata==2020.1 +pytz==2023.3.post1 requests==2.31.0 requests-file==1.5.1 requests-toolbelt==1.0.0 -sentry-sdk==1.25.1 +sentry-sdk==1.39.1 six==1.16.0 sniffio==1.3.0 -SQLAlchemy==1.4.48 -SQLAlchemy-Utils==0.41.1 -starlette==0.28.0 -typing_extensions==4.6.3 -tzlocal==5.0.1 -ujson==5.8.0 -urllib3==2.0.3 -uvicorn==0.22.0 +sqlparse==0.4.4 +time-machine==2.13.0 +typing_extensions==4.9.0 +tzdata==2023.3 +tzlocal==5.2 +ujson==5.9.0 +urllib3==2.1.0 +uvicorn==0.25.0 zeep==4.2.1 diff --git a/runtime.txt b/runtime.txt index 2d4e051..d5831c5 100644 --- a/runtime.txt +++ b/runtime.txt @@ -1 +1 @@ -python-3.11.2 \ No newline at end of file +python-3.11.5 \ No newline at end of file diff --git a/settings.py b/settings.py deleted file mode 100644 index 984ea75..0000000 --- a/settings.py +++ /dev/null @@ -1,83 +0,0 @@ -from databases import DatabaseURL -from starlette.config import Config -from starlette.datastructures import CommaSeparatedStrings, URL - -config = Config(".env") - -DEBUG = config("DEBUG", cast=bool, default=False) -TESTING = config("TESTING", cast=bool, default=False) - -DATABASE_URL = config("DATABASE_URL", cast=DatabaseURL, default="sqlite:///db.sqlite3") -TEST_DATABASE_URL = DATABASE_URL.replace(database="test_" + DATABASE_URL.database) - -CORS = config("CORS", cast=bool, default=True) -ALLOWED_HOSTS = config("ALLOWED_HOSTS", cast=CommaSeparatedStrings, default=[]) -FORCE_HTTPS = config("FORCE_HTTPS", cast=bool, default=False) - -COUNTRIES_URL = config( - "COUNTRIES_URL", - cast=URL, - default="https://raw.githubusercontent.com/dr5hn/countries-states-cities-database/master/countries.json", -) -RATES_URL = config( - "RATES_URL", - cast=URL, - default="https://www.ecb.europa.eu/stats/eurofxref/eurofxref-hist.xml", -) -RATES_LAST_90_DAYS_URL = config( - "RATES_LAST_90_DAYS_URL", - cast=URL, - default="https://www.ecb.europa.eu/stats/eurofxref/eurofxref-hist-90d.xml", -) -VIES_URL = config( - "VIES_URL", - cast=URL, - default="http://ec.europa.eu/taxation_customs/vies/checkVatService.wsdl", -) - -SYMBOLS = config( - "SYMBOLS", - cast=CommaSeparatedStrings, - default=[ - "EUR", - "USD", - "JPY", - "BGN", - "CZK", - "DKK", - "GBP", - "HUF", - "PLN", - "RON", - "SEK", - "CHF", - "ISK", - "NOK", - "HRK", - "RUB", - "TRY", - "AUD", - "BRL", - "CAD", - "CNY", - "HKD", - "IDR", - "ILS", - "INR", - "KRW", - "MXN", - "MYR", - "NZD", - "PHP", - "SGD", - "THB", - "ZAR", - ], -) - -SENTRY_DSN = config("SENTRY_DSN", cast=str, default="") - -# Testing -if TESTING: - pass - # DATABASE_URL = DATABASE_URL.replace(path="/test_db.sqlite3") diff --git a/tests/test_currencies.py b/tests/test_currencies.py deleted file mode 100644 index 3ce9de3..0000000 --- a/tests/test_currencies.py +++ /dev/null @@ -1,21 +0,0 @@ -from starlette.testclient import TestClient - -from app import app - - -class TestCurrenciesAPI(object): - # def test_currencies_api_unauth(self): - # with TestClient(app) as client: - # response = client.get("/currencies") - # assert response.status_code == 403 - - def test_currencies_api(self): - with TestClient(app) as client: - response = client.get("/currencies", headers={"Authorization": "Token test-token"}) - assert response.status_code == 200 - assert isinstance(response.json(), dict) - assert "EUR" in response.json() - assert "USD" in response.json() - assert response.json()["EUR"]["name"] == "Euro" - assert response.json()["EUR"]["symbol"] == "€" - assert len(response.json()) == 33 diff --git a/tests/test_rates_api.py b/tests/test_rates_api.py deleted file mode 100644 index 0119205..0000000 --- a/tests/test_rates_api.py +++ /dev/null @@ -1,90 +0,0 @@ -from starlette.testclient import TestClient - -from app import app - - -class TestRatesAPI(object): - def test_latest_api(self): - with TestClient(app) as client: - response = client.get( - "/rates", headers={"Authorization": "Token test-token"} - ) - assert response.status_code == 200 - assert isinstance(response.json(), dict) - assert "date" in response.json() - assert "base" in response.json() - assert response.json()["base"] == "EUR" - assert "rates" in response.json() - assert len(response.json()["rates"]) == 31 - - def test_date_api(self): - with TestClient(app) as client: - response = client.get( - "/rates?date=2018-10-12", headers={"Authorization": "Token test-token"} - ) - assert response.status_code == 200 - assert isinstance(response.json(), dict) - assert "date" in response.json() - assert response.json()["date"] == "2018-10-12" - assert "base" in response.json() - assert response.json()["base"] == "EUR" - assert "rates" in response.json() - assert len(response.json()["rates"]) == 33 - - def test_invalid_date_api(self): - with TestClient(app) as client: - response = client.get( - "/rates?date=abc", headers={"Authorization": "Token test-token"} - ) - assert response.status_code == 400 - assert isinstance(response.json(), list) - - def test_date_weekend_api(self): - with TestClient(app) as client: - response = client.get( - "/rates?date=2018-10-13", headers={"Authorization": "Token test-token"} - ) - assert response.status_code == 200 - assert isinstance(response.json(), dict) - assert "date" in response.json() - assert response.json()["date"] == "2018-10-12" - assert "base" in response.json() - assert response.json()["base"] == "EUR" - assert "rates" in response.json() - assert len(response.json()["rates"]) == 33 - - def test_base_api(self): - with TestClient(app) as client: - response = client.get( - "/rates?base=USD", headers={"Authorization": "Token test-token"} - ) - assert response.status_code == 200 - assert isinstance(response.json(), dict) - assert "date" in response.json() - assert "base" in response.json() - assert response.json()["base"] == "USD" - assert "rates" in response.json() - assert len(response.json()["rates"]) == 31 - assert response.json()["rates"]["USD"] == 1 - - def test_symbols_api(self): - with TestClient(app) as client: - response = client.get( - "/rates?symbols=USD,JPY,GBP", - headers={"Authorization": "Token test-token"}, - ) - assert response.status_code == 200 - assert isinstance(response.json(), dict) - assert "date" in response.json() - assert "base" in response.json() - assert response.json()["base"] == "EUR" - assert "rates" in response.json() - assert len(response.json()["rates"]) == 3 - - def test_invalid_symbols_api(self): - with TestClient(app) as client: - response = client.get( - "/rates?symbols=12345", headers={"Authorization": "Token test-token"} - ) - assert response.status_code == 400 - assert isinstance(response.json(), list) diff --git a/tests/test_registration.py b/tests/test_registration.py deleted file mode 100644 index 572382d..0000000 --- a/tests/test_registration.py +++ /dev/null @@ -1,32 +0,0 @@ -from starlette.testclient import TestClient - -from app import app - - -class TestRegistrationAPI(object): - def test_auth(self): - with TestClient(app) as client: - response = client.get("/register") - assert response.status_code == 405 - - def test_register(self): - with TestClient(app) as client: - response = client.post("/register", json={"email": "test@test.com", "password": "password"}) - assert response.status_code == 201 - - def test_register_existing(self): - with TestClient(app) as client: - response = client.post("/register", json={"email": "test@test.com", "password": "password"}) - assert response.status_code == 201 - response = client.post("/register", json={"email": "test@test.com", "password": "password"}) - assert response.status_code == 400 - - def test_register_without_password(self): - with TestClient(app) as client: - response = client.post("/register", json={"email": "test2@test.com", "password": ""}) - assert response.status_code == 400 - - def test_register_without_email(self): - with TestClient(app) as client: - response = client.post("/register", json={"email": "", "password": "password"}) - assert response.status_code == 400 diff --git a/tests/test_vat.py b/tests/test_vat.py deleted file mode 100644 index 2a7ae2a..0000000 --- a/tests/test_vat.py +++ /dev/null @@ -1,11 +0,0 @@ -from starlette.testclient import TestClient - -from app import app - - -class TestVATAPI(object): - def test_vat_api(self): - with TestClient(app) as client: - response = client.get("/vat?vat_number=EE101600930", headers={"Authorization": "Token test-token"}) - assert response.status_code == 200 - assert isinstance(response.json(), dict) diff --git a/utils.py b/utils.py deleted file mode 100644 index e4b5fde..0000000 --- a/utils.py +++ /dev/null @@ -1,75 +0,0 @@ -import json -import pendulum -import requests -import ujson - -from starlette.responses import JSONResponse -from typing import Any -from xml.etree import ElementTree - -from db import database, Countries, Rates -from settings import COUNTRIES_URL, RATES_URL, RATES_LAST_90_DAYS_URL - - -class UJSONResponse(JSONResponse): - media_type = "application/json" - - def render(self, content: Any) -> bytes: - assert ujson is not None, "ujson must be installed to use UJSONResponse" - return ujson.dumps(content, ensure_ascii=False).encode("utf-8") - - -async def load_rates(last_90_days=True): - print("Loading rates ...") - r = requests.get(RATES_LAST_90_DAYS_URL if last_90_days else RATES_URL) - envelope = ElementTree.fromstring(r.content) - - namespaces = { - "gesmes": "http://www.gesmes.org/xml/2002-08-01", - "eurofxref": "http://www.ecb.int/vocabulary/2002-08-01/eurofxref", - } - data = envelope.findall("./eurofxref:Cube/eurofxref:Cube[@time]", namespaces) - for i, d in enumerate(data): - time = pendulum.parse(d.attrib["time"], strict=False) - if not await database.fetch_one( - query=Rates.select().where(Rates.c.date == time) - ): - await database.execute( - query=Rates.insert(), - values={ - "date": time, - "rates": { - str(c.attrib["currency"]): float(c.attrib["rate"]) - for c in list(d) - }, - }, - ) - print("Loading rates finished!") - - -async def load_countries(): - print("Loading countries ...") - r = requests.get(COUNTRIES_URL) - for country in r.json(): - if not await database.fetch_one( - query=Countries.select().where(Countries.c.iso2 == country["iso2"]) - ): - await database.execute( - query=Countries.insert(), - values={ - "name": country["name"], - "iso2": country["iso2"], - "iso3": country["iso3"], - "numeric_code": country["numeric_code"], - "phone_code": country["phone_code"], - "capital": country["capital"], - "currency": country["currency"], - "tld": country["tld"], - "region": country["region"], - "subregion": country["subregion"], - "latitude": country["latitude"], - "longitude": country["longitude"], - "emoji": country["emoji"], - }, - ) - print("Loading countries finished!") diff --git a/tests/__init__.py b/vatcomply/__init__.py similarity index 100% rename from tests/__init__.py rename to vatcomply/__init__.py diff --git a/vatcomply/asgi.py b/vatcomply/asgi.py new file mode 100644 index 0000000..f8d60e1 --- /dev/null +++ b/vatcomply/asgi.py @@ -0,0 +1,19 @@ +""" +ASGI config for VATcomply project. + +It exposes the ASGI callable as a module-level variable named ``application``. + +For more information on this file, see +https://docs.djangoproject.com/en/5.0/howto/deployment/asgi/ +""" + +import os + +from django.core.asgi import get_asgi_application + +from vatcomply.middleware import BackgroundTasksMiddleware + +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "vatcomply.settings") + +# application = BackgroundTasksMiddleware(get_asgi_application()) +application = BackgroundTasksMiddleware(get_asgi_application()) diff --git a/vatcomply/http.py b/vatcomply/http.py new file mode 100644 index 0000000..3ef4989 --- /dev/null +++ b/vatcomply/http.py @@ -0,0 +1,45 @@ +from typing import Tuple, Union + +import ujson +from django.http import HttpResponse +from pydantic import ValidationError + + +def loc_to_dot_sep(loc: Tuple[Union[str, int], ...]) -> str: + path = "" + for i, x in enumerate(loc): + if isinstance(x, str): + if i > 0: + path += "." + path += x + elif isinstance(x, int): + path += f"[{x}]" + else: + raise TypeError("Unexpected type") + return path + + +class UJsonResponse(HttpResponse): + def __init__( + self, + data, + safe=True, + **kwargs, + ): + # Convert pydantic ValidationError to dict + if isinstance(data, ValidationError): + errors = {} + for error in data.errors(): + errors[loc_to_dot_sep(error["loc"])] = [] + errors[loc_to_dot_sep(error["loc"])].append(error["msg"]) + data = errors + + if safe and not isinstance(data, dict) and not isinstance(data, list): + raise TypeError( + "In order to allow non-dict or non-list objects to be serialized set the " + "safe parameter to False." + ) + + kwargs.setdefault("content_type", "application/json") + data = ujson.dumps(data, ensure_ascii=False).encode("utf-8") + super().__init__(content=data, **kwargs) diff --git a/vatcomply/management/commands/load_countries.py b/vatcomply/management/commands/load_countries.py new file mode 100644 index 0000000..7598025 --- /dev/null +++ b/vatcomply/management/commands/load_countries.py @@ -0,0 +1,35 @@ +import httpx + +from django.conf import settings +from django.core.management.base import BaseCommand + +from vatcomply.models import Country + + +class Command(BaseCommand): + help = "Load countries dataset" + + def handle(self, *args, **kwargs): + self.stdout.write("Loading countries...") + + r = httpx.get(settings.COUNTRIES_URL) + for country in r.json(): + Country.objects.update_or_create( + iso2=country["iso2"], + defaults={ + "name": country["name"], + "iso3": country["iso3"], + "numeric_code": country["numeric_code"], + "phone_code": country["phone_code"], + "capital": country["capital"], + "currency": country["currency"], + "tld": country["tld"], + "region": country["region"], + "subregion": country["subregion"], + "latitude": country["latitude"], + "longitude": country["longitude"], + "emoji": country["emoji"], + }, + ) + + self.stdout.write("Loading countries dataset finished!") diff --git a/vatcomply/management/commands/load_rates.py b/vatcomply/management/commands/load_rates.py new file mode 100644 index 0000000..db6b8f4 --- /dev/null +++ b/vatcomply/management/commands/load_rates.py @@ -0,0 +1,49 @@ +import httpx +import pendulum + +from django.conf import settings +from django.core.management.base import BaseCommand +from django.db import IntegrityError +from xml.etree import ElementTree + +from vatcomply.models import Rate + + +class Command(BaseCommand): + help = "Load ECB rates" + + def add_arguments(self, parser): + parser.add_argument( + "--last-90-days", + action="store_true", + help="Load rates for last 90 days", + ) + + def handle(self, *args, **options): + self.stdout.write("Loading rates...") + + last_90_days = True if options["last_90_days"] else False + r = httpx.get( + settings.RATES_LAST_90_DAYS_URL if last_90_days else settings.RATES_URL + ) + envelope = ElementTree.fromstring(r.content) + + namespaces = { + "gesmes": "http://www.gesmes.org/xml/2002-08-01", + "eurofxref": "http://www.ecb.int/vocabulary/2002-08-01/eurofxref", + } + data = envelope.findall("./eurofxref:Cube/eurofxref:Cube[@time]", namespaces) + for i, d in enumerate(data): + time = pendulum.parse(d.attrib["time"], strict=False) + try: + Rate.objects.create( + date=time, + rates={ + str(c.attrib["currency"]): float(c.attrib["rate"]) + for c in list(d) + }, + ) + except IntegrityError: + self.stdout.write(f"Rate for {time} already exists, skipping...") + + self.stdout.write("Loading rates finished!") diff --git a/vatcomply/middleware.py b/vatcomply/middleware.py new file mode 100644 index 0000000..e0a0f2f --- /dev/null +++ b/vatcomply/middleware.py @@ -0,0 +1,42 @@ +import fcntl + +from apscheduler.schedulers.asyncio import AsyncIOScheduler +from django.conf import settings +from django.core.management import call_command + + +# ASGI middleware to run periodic tasks in the background +class BackgroundTasksMiddleware: + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + if settings.BACKGROUND_SCHEDULER and scope["type"] == "lifespan": + message = await receive() + if message["type"] == "lifespan.startup": + # Schedule rates updates and load initial data + try: + # Lock file to prevent multiple instances of scheduler + _ = open("scheduler.lock", "w") + fcntl.lockf(_.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB) + + scheduler = AsyncIOScheduler() + scheduler.start() + + # Fill up database with Countries + scheduler.add_job(lambda: call_command("load_countries")) + + # Periodically updates lates 90 days Rates data + scheduler.add_job( + lambda: call_command("load_rates", last_90_days=True), + "interval", + hours=1, + minutes=10, + ) + + # Fill up database with Rates + scheduler.add_job(lambda: call_command("load_rates")) + except BlockingIOError as e: + print(e) + + return await self.app(scope, receive, send) diff --git a/vatcomply/migrations/0001_initial.py b/vatcomply/migrations/0001_initial.py new file mode 100644 index 0000000..82017c5 --- /dev/null +++ b/vatcomply/migrations/0001_initial.py @@ -0,0 +1,62 @@ +# Generated by Django 4.2.8 on 2023-12-19 16:36 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [] + + operations = [ + migrations.CreateModel( + name="Country", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("iso2", models.CharField(max_length=2, unique=True)), + ("iso3", models.CharField(max_length=3)), + ("name", models.CharField(max_length=200)), + ("numeric_code", models.IntegerField()), + ("phone_code", models.CharField(max_length=3)), + ("capital", models.CharField(max_length=200)), + ("currency", models.CharField(max_length=200)), + ("tld", models.CharField(max_length=200)), + ("region", models.CharField(max_length=200)), + ("subregion", models.CharField(max_length=200)), + ("latitude", models.FloatField()), + ("longitude", models.FloatField()), + ("emoji", models.CharField(max_length=1)), + ], + options={ + "ordering": ["name"], + }, + ), + migrations.CreateModel( + name="Rate", + fields=[ + ( + "date", + models.DateField( + db_index=True, + editable=False, + primary_key=True, + serialize=False, + unique=True, + ), + ), + ("rates", models.JSONField(default=dict)), + ], + options={ + "ordering": ["-date"], + }, + ), + ] diff --git a/vatcomply/migrations/__init__.py b/vatcomply/migrations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vatcomply/models.py b/vatcomply/models.py new file mode 100644 index 0000000..fd9e932 --- /dev/null +++ b/vatcomply/models.py @@ -0,0 +1,30 @@ +from django.db import models + + +class Rate(models.Model): + date = models.DateField( + unique=True, primary_key=True, editable=False, db_index=True + ) + rates = models.JSONField(default=dict) + + class Meta: + ordering = ["-date"] + + +class Country(models.Model): + iso2 = models.CharField(max_length=2, unique=True) + iso3 = models.CharField(max_length=3) + name = models.CharField(max_length=200) + numeric_code = models.IntegerField() + phone_code = models.CharField(max_length=3) + capital = models.CharField(max_length=200) + currency = models.CharField(max_length=200) + tld = models.CharField(max_length=200) + region = models.CharField(max_length=200) + subregion = models.CharField(max_length=200) + latitude = models.FloatField() + longitude = models.FloatField() + emoji = models.CharField(max_length=1) + + class Meta: + ordering = ["name"] diff --git a/vatcomply/serializers.py b/vatcomply/serializers.py new file mode 100644 index 0000000..0773f04 --- /dev/null +++ b/vatcomply/serializers.py @@ -0,0 +1,48 @@ +import re + +from datetime import date as dtdate +from typing import Optional + +from django.conf import settings +from pydantic import BaseModel, field_validator + + +class VATValidationModel(BaseModel): + vat_number: str + + @field_validator("vat_number") + def format_validation(cls, vat_number): + pattern = r"^[A-Z]{2}[\dA-Z]{8,12}$" + if not re.match(pattern, vat_number): + raise ValueError( + "Invalid VAT number format. Expected format: Two-letter country code followed by 8-12 digits or letters." + ) + + # Brexit + if vat_number.startswith("GB"): + raise ValueError( + "As of 01/01/2021, the VoW service to validate UK (GB) VAT numbers ceased to exist while a new service to validate VAT numbers of businesses operating under the Protocol on Ireland and Northern Ireland appeared. These VAT numbers are starting with the “XI” prefix." + ) + return vat_number + + +class RatesQueryValidationModel(BaseModel): + base: str = "EUR" + date: Optional[dtdate] = dtdate.today() + symbols: Optional[list] = [] + + @field_validator("base") + @classmethod + def base_validation(cls, base: str): + if base not in list(settings.CURRENCY_SYMBOLS): + raise ValueError(f"Base currency {base} is not supported.") + return base + + @field_validator("symbols", mode="before") + @classmethod + def symbols_validation(cls, symbols: list): + symbols = symbols.split(",") + diff = list(set(symbols) - set(list(settings.CURRENCY_SYMBOLS))) + if diff: + raise ValueError(f"Symbols {', '.join(diff)} are not supported.") + return symbols diff --git a/vatcomply/settings.py b/vatcomply/settings.py new file mode 100644 index 0000000..6d5a8c3 --- /dev/null +++ b/vatcomply/settings.py @@ -0,0 +1,176 @@ +""" +Django settings for VATcomply project. + +Generated by 'django-admin startproject' using Django 5.0. +""" + +import os +from pathlib import Path + +import sentry_sdk + +# Build paths inside the project like this: BASE_DIR / 'subdir'. +BASE_DIR = Path(__file__).resolve().parent.parent + + +# Quick-start development settings - unsuitable for production +# See https://docs.djangoproject.com/en/5.0/howto/deployment/checklist/ + +# SECURITY WARNING: keep the secret key used in production secret! +SECRET_KEY = os.getenv( + "SECRET_KEY", "django-insecure-b(bve3$7_h%f15mxx1nmsb#6i(l+u2kdd^+)b(%ce5@1kucpyy" +) + +# SECURITY WARNING: don't run with debug turned on in production! +DEBUG = os.getenv("DEBUG", False) + +if os.getenv("ALLOWED_HOSTS"): + ALLOWED_HOSTS = os.getenv("ALLOWED_HOSTS").split(",") + +# CORS +CORS_ALLOW_ALL_ORIGINS = True + + +# Application definition +INSTALLED_APPS = [ + "django.contrib.admin", + "django.contrib.auth", + "django.contrib.contenttypes", + "django.contrib.sessions", + "django.contrib.messages", + "django.contrib.staticfiles", + "corsheaders", + "vatcomply", +] + +MIDDLEWARE = [ + "corsheaders.middleware.CorsMiddleware", + "django.middleware.security.SecurityMiddleware", + "django.contrib.sessions.middleware.SessionMiddleware", + "django.middleware.common.CommonMiddleware", + "django.middleware.csrf.CsrfViewMiddleware", + "django.contrib.auth.middleware.AuthenticationMiddleware", + "django.contrib.messages.middleware.MessageMiddleware", + "django.middleware.clickjacking.XFrameOptionsMiddleware", +] + +ROOT_URLCONF = "vatcomply.urls" + +TEMPLATES = [ + { + "BACKEND": "django.template.backends.django.DjangoTemplates", + "DIRS": [], + "APP_DIRS": True, + "OPTIONS": { + "context_processors": [ + "django.template.context_processors.debug", + "django.template.context_processors.request", + "django.contrib.auth.context_processors.auth", + "django.contrib.messages.context_processors.messages", + ], + }, + }, +] + +WSGI_APPLICATION = "vatcomply.wsgi.application" + + +# Database +# https://docs.djangoproject.com/en/5.0/ref/settings/#databases +DATABASES = { + "default": { + "ENGINE": "django.db.backends.sqlite3", + "NAME": BASE_DIR / "db.sqlite3", + } +} + + +# Password validation +# https://docs.djangoproject.com/en/5.0/ref/settings/#auth-password-validators +AUTH_PASSWORD_VALIDATORS = [ + { + "NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator", + }, + { + "NAME": "django.contrib.auth.password_validation.MinimumLengthValidator", + }, + { + "NAME": "django.contrib.auth.password_validation.CommonPasswordValidator", + }, + { + "NAME": "django.contrib.auth.password_validation.NumericPasswordValidator", + }, +] + + +# Internationalization +# https://docs.djangoproject.com/en/5.0/topics/i18n/ +LANGUAGE_CODE = "en-us" +TIME_ZONE = "UTC" +USE_I18N = True +USE_TZ = True + + +# Static files (CSS, JavaScript, Images) +# https://docs.djangoproject.com/en/5.0/howto/static-files/ +STATIC_URL = "static/" + +# Default primary key field type +# https://docs.djangoproject.com/en/5.0/ref/settings/#default-auto-field +DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField" + +# Countries +COUNTRIES_URL = "https://raw.githubusercontent.com/dr5hn/countries-states-cities-database/master/countries.json" + +# ECB +RATES_LAST_90_DAYS_URL = ( + "https://www.ecb.europa.eu/stats/eurofxref/eurofxref-hist-90d.xml" +) +RATES_URL = "https://www.ecb.europa.eu/stats/eurofxref/eurofxref-hist.xml" +CURRENCY_SYMBOLS = [ + "EUR", + "USD", + "JPY", + "BGN", + "CZK", + "DKK", + "GBP", + "HUF", + "PLN", + "RON", + "SEK", + "CHF", + "ISK", + "NOK", + "HRK", + "RUB", + "TRY", + "AUD", + "BRL", + "CAD", + "CNY", + "HKD", + "IDR", + "ILS", + "INR", + "KRW", + "MXN", + "MYR", + "NZD", + "PHP", + "SGD", + "THB", + "ZAR", +] + +# VIES +VIES_WSDL = "https://ec.europa.eu/taxation_customs/vies/checkVatService.wsdl" + +# Background scheduler +BACKGROUND_SCHEDULER = os.getenv("BACKGROUND_SCHEDULER", False) + +# Sentry +sentry_sdk.init( + dsn=os.getenv("SENTRY_DSN"), + enable_tracing=False, +) diff --git a/vatcomply/tests/__init__.py b/vatcomply/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vatcomply/tests/test_countries.py b/vatcomply/tests/test_countries.py new file mode 100644 index 0000000..fbde633 --- /dev/null +++ b/vatcomply/tests/test_countries.py @@ -0,0 +1,16 @@ +from django.test import TestCase +from django.core.management import call_command + +from vatcomply.models import Country + + +class CountriesTest(TestCase): + @classmethod + def setUpTestData(cls): + call_command("load_countries") + + def test_countries_api(self): + response = self.client.get("/countries") + self.assertEqual(response.status_code, 200) + self.assertIsInstance(response.json(), list) + self.assertEqual(len(response.json()), Country.objects.count()) diff --git a/vatcomply/tests/test_rates.py b/vatcomply/tests/test_rates.py new file mode 100644 index 0000000..706fc32 --- /dev/null +++ b/vatcomply/tests/test_rates.py @@ -0,0 +1,73 @@ +from django.test import TestCase +from django.core.management import call_command + + +class VATTest(TestCase): + @classmethod + def setUpTestData(cls): + call_command("load_rates") + + def test_latest_api(self): + response = self.client.get("/rates") + self.assertEqual(response.status_code, 200) + self.assertIsInstance(response.json(), dict) + self.assertIn("date", response.json()) + self.assertIn("base", response.json()) + self.assertEqual(response.json()["base"], "EUR") + self.assertIn("rates", response.json()) + self.assertEqual(len(response.json()["rates"]), 31) + + def test_date_api(self): + response = self.client.get("/rates?date=2018-10-12") + self.assertEqual(response.status_code, 200) + self.assertIsInstance(response.json(), dict) + self.assertIn("date", response.json()) + self.assertEqual(response.json()["date"], "2018-10-12") + self.assertIn("base", response.json()) + self.assertEqual(response.json()["base"], "EUR") + self.assertIn("rates", response.json()) + self.assertEqual(len(response.json()["rates"]), 33) + + def test_invalid_date_api(self): + response = self.client.get("/rates?date=abc") + self.assertEqual(response.status_code, 400) + self.assertIsInstance(response.json(), dict) + + def test_date_weekend_api(self): + response = self.client.get("/rates?date=2018-10-13") + self.assertEqual(response.status_code, 200) + self.assertIsInstance(response.json(), dict) + self.assertIn("date", response.json()) + self.assertEqual(response.json()["date"], "2018-10-12") + self.assertIn("base", response.json()) + self.assertEqual(response.json()["base"], "EUR") + self.assertIn("rates", response.json()) + self.assertEqual(len(response.json()["rates"]), 33) + + def test_base_api(self): + response = self.client.get("/rates?base=USD") + self.assertEqual(response.status_code, 200) + self.assertIsInstance(response.json(), dict) + self.assertIn("date", response.json()) + self.assertIn("base", response.json()) + self.assertEqual(response.json()["base"], "USD") + self.assertIn("rates", response.json()) + self.assertEqual(len(response.json()["rates"]), 31) + self.assertEqual(response.json()["rates"]["USD"], 1) + + def test_symbols_api(self): + response = self.client.get("/rates?symbols=USD,JPY,GBP") + self.assertEqual(response.status_code, 200) + self.assertIsInstance(response.json(), dict) + self.assertIn("date", response.json()) + self.assertIn("base", response.json()) + self.assertEqual(response.json()["base"], "EUR") + self.assertIn("rates", response.json()) + self.assertEqual(len(response.json()["rates"]), 3) + + def test_invalid_symbols_api(self): + response = self.client.get("/rates?symbols=12345") + self.assertEqual(response.status_code, 400) + self.assertIsInstance(response.json(), dict) + self.assertIn("symbols", response.json()) + self.assertIsInstance(response.json()["symbols"], list) diff --git a/vatcomply/tests/test_vat.py b/vatcomply/tests/test_vat.py new file mode 100644 index 0000000..e8e0c6e --- /dev/null +++ b/vatcomply/tests/test_vat.py @@ -0,0 +1,31 @@ +from django.test import TestCase + + +class VATTest(TestCase): + def test_vat_api(self): + response = self.client.get("/vat?vat_number=EE101600930") + self.assertEqual(response.status_code, 200) + self.assertIsInstance(response.json(), dict) + + def test_vat_blank_api(self): + response = self.client.get("/vat?vat_number=") + self.assertEqual(response.status_code, 400) + self.assertIsInstance(response.json(), dict) + + def test_vat_invalid_format_api(self): + response = self.client.get("/vat?vat_number=123") + self.assertEqual(response.status_code, 400) + self.assertIsInstance(response.json(), dict) + self.assertEqual( + response.json()["vat_number"][0], + "Value error, Invalid VAT number format. Expected format: Two-letter country code followed by 8-12 digits or letters.", + ) + + def test_vat_brexit_api(self): + response = self.client.get("/vat?vat_number=GB123456789") + self.assertEqual(response.status_code, 400) + self.assertIsInstance(response.json(), dict) + self.assertEqual( + response.json()["vat_number"][0], + "Value error, As of 01/01/2021, the VoW service to validate UK (GB) VAT numbers ceased to exist while a new service to validate VAT numbers of businesses operating under the Protocol on Ireland and Northern Ireland appeared. These VAT numbers are starting with the “XI” prefix.", + ) diff --git a/vatcomply/urls.py b/vatcomply/urls.py new file mode 100644 index 0000000..0e8a557 --- /dev/null +++ b/vatcomply/urls.py @@ -0,0 +1,14 @@ +# from django.contrib import admin +from django.urls import path + +from vatcomply.views import CountriesView, GeolocateView, RatesView, VATView + + +urlpatterns = [ + # path("admin/", admin.site.urls), + # API + path("countries", CountriesView.as_view()), + path("geolocate", GeolocateView.as_view()), + path("vat", VATView.as_view()), + path("rates", RatesView.as_view()), +] diff --git a/vatcomply/views.py b/vatcomply/views.py new file mode 100644 index 0000000..29e0d70 --- /dev/null +++ b/vatcomply/views.py @@ -0,0 +1,156 @@ +from decimal import Decimal + +import pendulum +import zeep +from django.conf import settings +from django.utils.decorators import method_decorator +from django.views import View +from django.views.decorators.csrf import csrf_exempt +from pydantic import ValidationError + +from vatcomply.http import UJsonResponse as JsonResponse +from vatcomply.models import Country, Rate +from vatcomply.serializers import RatesQueryValidationModel, VATValidationModel + + +@method_decorator(csrf_exempt, name="dispatch") +class CountriesView(View): + async def get(self, request): + countries = [] + async for country in Country.objects.order_by("iso2").all(): + countries.append( + { + "iso2": country.iso2, + "iso3": country.iso3, + "name": country.name, + "numeric_code": country.numeric_code, + "phone_code": country.phone_code, + "capital": country.capital, + "currency": country.currency, + "tld": country.tld, + "region": country.region, + "subregion": country.subregion, + "latitude": Decimal(country.latitude), + "longitude": Decimal(country.longitude), + "emoji": country.emoji, + } + ) + + return JsonResponse(countries) + + +@method_decorator(csrf_exempt, name="dispatch") +class GeolocateView(View): + async def get(self, request): + country_code = request.headers.get("CF-IPCountry") + ip = request.headers.get("CF-Connecting-IP") + + country_code = "EE" + if not country_code: + return JsonResponse( + { + "error": "Country code not received from CloudFlare headers `CF-IPCountry`." + }, + status=404, + ) + + # Get the country data + try: + record = await Country.objects.aget(iso2=country_code.upper()) + except Country.DoesNotExist: + return JsonResponse( + {"error": f"Data for country code `{country_code.upper()}` not found."}, + status=404, + ) + + return JsonResponse( + { + "iso2": record.iso2, + "iso3": record.iso3, + "country_code": country_code.upper(), + "name": record.name, + "numeric_code": record.numeric_code, + "phone_code": record.phone_code, + "capital": record.capital, + "currency": record.currency, + "tld": record.tld, + "region": record.region, + "subregion": record.subregion, + "latitude": record.latitude, + "longitude": record.longitude, + "emoji": record.emoji, + "ip": ip, + }, + ) + + +@method_decorator(csrf_exempt, name="dispatch") +class VATView(View): + async def get(self, request): + try: + query = VATValidationModel(**request.GET.dict()) + client = zeep.AsyncClient(wsdl=str(settings.VIES_WSDL)) + try: + response = await zeep.helpers.serialize_object( + client.service.checkVat( + countryCode=query.vat_number[:2], vatNumber=query.vat_number[2:] + ) + ) + except zeep.exceptions.Fault as e: + return JsonResponse({"error": e.message}, status=400) + + return JsonResponse( + { + "valid": response["valid"], + "vat_number": response["vatNumber"], + "name": response["name"], + "address": response["address"].strip() + if response["address"] + else "", + "country_code": response["countryCode"], + } + ) + except ValidationError as e: + return JsonResponse(e, status=400) + + +@method_decorator(csrf_exempt, name="dispatch") +class RatesView(View): + async def get(self, request): + query_params = request.GET.dict() + if "date" in request.GET: + query_params["date"] = request.GET["date"] + + try: + query = RatesQueryValidationModel(**query_params) + + # Find the date + date = query.date if query.date else pendulum.now().date() + + # Get the rates data + record = ( + await Rate.objects.filter(date__lte=date).order_by("-date").afirst() + ) + + # Base re-calculation + rates = {"EUR": 1} + rates.update(record.rates) + if query.base and query.base != "EUR": + base_rate = Decimal(record.rates[query.base]) + rates = { + currency: Decimal(rate) / base_rate + for currency, rate in rates.items() + } + rates.update({"EUR": Decimal(1) / base_rate}) + + # Symbols + if query.symbols: + for rate in list(rates): + if rate not in query.symbols: + del rates[rate] + + return JsonResponse( + {"date": record.date.isoformat(), "base": query.base, "rates": rates} + ) + except ValidationError as e: + return JsonResponse(e, status=400) diff --git a/vatcomply/wsgi.py b/vatcomply/wsgi.py new file mode 100644 index 0000000..3f21b30 --- /dev/null +++ b/vatcomply/wsgi.py @@ -0,0 +1,16 @@ +""" +WSGI config for VATcomply project. + +It exposes the WSGI callable as a module-level variable named ``application``. + +For more information on this file, see +https://docs.djangoproject.com/en/5.0/howto/deployment/wsgi/ +""" + +import os + +from django.core.wsgi import get_wsgi_application + +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "vatcomply.settings") + +application = get_wsgi_application()