Skip to content

Commit

Permalink
Closed database connections.
Browse files Browse the repository at this point in the history
eb16ed0 removed (among a lot of other
things) a call to close_old_connections(), mostly because I hadn't
understood what that was doing exactly.
Since then I've been kept awake at night with the sense of dread that
old database connections were piling up unclosed on Django's server.

So I wrote this wsgi middleware that should hopefully restore balance
in the universe, or at least close db connections.
  • Loading branch information
bmispelon authored Feb 14, 2024
1 parent 1e87f0f commit c845c54
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 1 deletion.
2 changes: 2 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ jobs:
uses: actions/checkout@v4
- name: Install requirements
run: python -m pip install -r requirements.txt
- name: Install backport of unittest.mock
run: python -m pip install mock
- name: Run tests
run: python -m django test tracdjangoplugin.tests
env:
Expand Down
22 changes: 22 additions & 0 deletions DjangoPlugin/tracdjangoplugin/middlewares.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from django.core.signals import request_finished, request_started


class DjangoDBManagementMiddleware:
"""
A simple WSGI middleware that manually manages opening/closing db connections.
Django normally does that as part of its own middleware chain, but we're using Trac's middleware
so we must do this by hand.
This hopefully prevents open connections from piling up.
"""

def __init__(self, application):
self.application = application

def __call__(self, environ, start_response):
request_started.send(sender=self.__class__)
try:
for data in self.application(environ, start_response):
yield data
finally:
request_finished.send(sender=self.__class__)
59 changes: 58 additions & 1 deletion DjangoPlugin/tracdjangoplugin/tests.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
from functools import partial

try:
from unittest.mock import Mock
except ImportError:
from mock import Mock

from django.core.signals import request_finished, request_started
from django.contrib.auth.forms import AuthenticationForm
from django.contrib.auth.models import User
from django.test import TestCase
from django.test import SimpleTestCase, TestCase

from trac.test import EnvironmentStub, MockRequest
from trac.web.api import RequestDone
from trac.web.main import RequestDispatcher

from tracdjangoplugin.middlewares import DjangoDBManagementMiddleware
from tracdjangoplugin.plugins import PlainLoginComponent


Expand Down Expand Up @@ -127,3 +134,53 @@ def test_login_invalid_username_uppercased(self):
def test_login_invalid_inactive_user(self):
User.objects.create_user(username="test", password="test", is_active=False)
self.assertLoginFails(username="test", password="test")


class DjangoDBManagementMiddlewareTestCase(SimpleTestCase):
@classmethod
def setUpClass(cls):
# Remove receivers from the request_started and request_finished signals,
# replacing them with a mock object so we can still check if they were called.
super(DjangoDBManagementMiddlewareTestCase, cls).setUpClass()
cls._original_signal_receivers = {}
cls.signals = {}
for signal in [request_started, request_finished]:
cls.signals[signal] = Mock()
cls._original_signal_receivers[signal] = signal.receivers
signal.receivers = []
signal.connect(cls.signals[signal])

@classmethod
def tearDownClass(cls):
# Restore the signals we modified in setUpClass() to what they were before
super(DjangoDBManagementMiddlewareTestCase, cls).tearDownClass()
for signal, original_receivers in cls._original_signal_receivers.items():
# messing about with receivers directly is not an official API, so we need to
# call some undocumented methods to make sure caches and such are taken care of.
with signal.lock:
signal.receivers = original_receivers
signal._clear_dead_receivers()
signal.sender_receivers_cache.clear()

def setUp(self):
super(DjangoDBManagementMiddlewareTestCase, self).setUp()
for mockobj in self.signals.values():
mockobj.reset_mock()

def test_request_start_fired(self):
app = DjangoDBManagementMiddleware(lambda environ, start_response: [b"test"])
output = b"".join(app(None, None))
self.assertEqual(output, b"test")
self.signals[request_started].assert_called_once()

def test_request_finished_fired(self):
app = DjangoDBManagementMiddleware(lambda environ, start_response: [b"test"])
output = b"".join(app(None, None))
self.assertEqual(output, b"test")
self.signals[request_finished].assert_called_once()

def test_request_finished_fired_even_with_error(self):
app = DjangoDBManagementMiddleware(lambda environ, start_response: [1 / 0])
with self.assertRaises(ZeroDivisionError):
list(app(None, None))
self.signals[request_finished].assert_called_once()
5 changes: 5 additions & 0 deletions DjangoPlugin/tracdjangoplugin/wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
# Python 3 would perform better here, but we are still on 2.7 for Trac, so leak fds for now.
from tracopt.versioncontrol.git import PyGIT

from .middlewares import DjangoDBManagementMiddleware


application = DjangoDBManagementMiddleware(application)

PyGIT.close_fds = False

trac_dsn = os.getenv("SENTRY_DSN")
Expand Down

0 comments on commit c845c54

Please sign in to comment.