diff --git a/.gitignore b/.gitignore
index a9c50ad..22c1e96 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,7 +5,9 @@ local-dev.txt
dist/
MANIFEST
.mypy_cache/
+.idea/
.vscode/
drf_flex_fields.egg-info/
venv.sh
-.venv
\ No newline at end of file
+.venv
+venv/
diff --git a/README.md b/README.md
index 1f8b23f..e7725d5 100644
--- a/README.md
+++ b/README.md
@@ -483,12 +483,14 @@ class PersonSerializer(FlexFieldsModelSerializer):
Parameter names and wildcard values can be configured within a Django setting, named `REST_FLEX_FIELDS`.
-| Option | Description | Default |
-| --------------- | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | --------------- |
-| EXPAND_PARAM | The name of the parameter with the fields to be expanded | `"expand"` |
-| FIELDS_PARAM | The name of the parameter with the fields to be included (others will be omitted) | `"fields"` |
-| OMIT_PARAM | The name of the parameter with the fields to be omitted | `"omit"` |
-| WILDCARD_VALUES | List of values that stand in for all field names. Can be used with the `fields` and `expand` parameters.
When used with `expand`, a wildcard value will trigger the expansion of all `expandable_fields` at a given level.
When used with `fields`, all fields are included at a given level. For example, you could pass `fields=name,state.*` if you have a city resource with a nested state in order to expand only the city's name field and all of the state's fields.
To disable use of wildcards, set this setting to `None`. | `["*", "~all"]` |
+| Option | Description | Default |
+|-------------------------------|:------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|-----------------|
+| EXPAND_PARAM | The name of the parameter with the fields to be expanded | `"expand"` |
+| MAXIMUM_EXPANSION_DEPTH | The number of maximum depth permitted expansion | `None` |
+| FIELDS_PARAM | The name of the parameter with the fields to be included (others will be omitted) | `"fields"` |
+| OMIT_PARAM | The name of the parameter with the fields to be omitted | `"omit"` |
+| RECURSIVE_EXPANSION_PERMITTED | If `False`, an exception is raised when a recursive pattern is found | `True` |
+| WILDCARD_VALUES | List of values that stand in for all field names. Can be used with the `fields` and `expand` parameters.
When used with `expand`, a wildcard value will trigger the expansion of all `expandable_fields` at a given level.
When used with `fields`, all fields are included at a given level. For example, you could pass `fields=name,state.*` if you have a city resource with a nested state in order to expand only the city's name field and all of the state's fields.
To disable use of wildcards, set this setting to `None`. | `["*", "~all"]` |
For example, if you want your API to work a bit more like [JSON API](https://jsonapi.org/format/#fetching-includes), you could do:
@@ -496,6 +498,15 @@ For example, if you want your API to work a bit more like [JSON API](https://jso
REST_FLEX_FIELDS = {"EXPAND_PARAM": "include"}
```
+### Defining expansion and recursive limits at serializer level
+
+`maximum_expansion_depth` property can be overridden at serializer level. It can be configured as `int` or `None`.
+
+`recursive_expansion_permitted` property can be overridden at serializer level. It must be `bool`.
+
+Both settings raise `serializers.ValidationError` when conditions are met but exceptions can be overridden in `_recursive_expansion_found` and `_expansion_depth_exceeded` methods.
+
+
## Serializer Introspection
When using an instance of `FlexFieldsModelSerializer`, you can examine the property `expanded_fields` to discover which fields, if any, have been dynamically expanded.
diff --git a/rest_flex_fields/__init__.py b/rest_flex_fields/__init__.py
index c24a953..108f477 100644
--- a/rest_flex_fields/__init__.py
+++ b/rest_flex_fields/__init__.py
@@ -5,6 +5,8 @@
EXPAND_PARAM = FLEX_FIELDS_OPTIONS.get("EXPAND_PARAM", "expand")
FIELDS_PARAM = FLEX_FIELDS_OPTIONS.get("FIELDS_PARAM", "fields")
OMIT_PARAM = FLEX_FIELDS_OPTIONS.get("OMIT_PARAM", "omit")
+MAXIMUM_EXPANSION_DEPTH = FLEX_FIELDS_OPTIONS.get("MAXIMUM_EXPANSION_DEPTH", None)
+RECURSIVE_EXPANSION_PERMITTED = FLEX_FIELDS_OPTIONS.get("RECURSIVE_EXPANSION_PERMITTED", True)
WILDCARD_ALL = "~all"
WILDCARD_ASTERISK = "*"
@@ -20,9 +22,12 @@
assert isinstance(FIELDS_PARAM, str), "'FIELDS_PARAM' should be a string"
assert isinstance(OMIT_PARAM, str), "'OMIT_PARAM' should be a string"
-if type(WILDCARD_VALUES) not in (list, None):
+if type(WILDCARD_VALUES) not in (list, type(None)):
raise ValueError("'WILDCARD_EXPAND_VALUES' should be a list of strings or None")
-
+if type(MAXIMUM_EXPANSION_DEPTH) not in (int, type(None)):
+ raise ValueError("'MAXIMUM_EXPANSION_DEPTH' should be a int or None")
+if type(RECURSIVE_EXPANSION_PERMITTED) is not bool:
+ raise ValueError("'RECURSIVE_EXPANSION_PERMITTED' should be a bool")
from .utils import *
from .serializers import FlexFieldsModelSerializer
diff --git a/rest_flex_fields/serializers.py b/rest_flex_fields/serializers.py
index bce1dcb..ea9a1a0 100644
--- a/rest_flex_fields/serializers.py
+++ b/rest_flex_fields/serializers.py
@@ -2,6 +2,7 @@
import importlib
from typing import List, Optional, Tuple
+from django.conf import settings
from rest_framework import serializers
from rest_flex_fields import (
@@ -22,6 +23,8 @@ class FlexFieldsSerializerMixin(object):
"""
expandable_fields = {}
+ maximum_expansion_depth: Optional[int] = None
+ recursive_expansion_permitted: Optional[bool] = None
def __init__(self, *args, **kwargs):
expand = list(kwargs.pop(EXPAND_PARAM, []))
@@ -58,6 +61,21 @@ def __init__(self, *args, **kwargs):
+ self._flex_options_rep_only["omit"],
}
+ def get_maximum_expansion_depth(self) -> Optional[int]:
+ """
+ Defined at serializer level or based on MAXIMUM_EXPANSION_DEPTH setting
+ """
+ return self.maximum_expansion_depth or settings.REST_FLEX_FIELDS.get("MAXIMUM_EXPANSION_DEPTH", None)
+
+ def get_recursive_expansion_permitted(self) -> bool:
+ """
+ Defined at serializer level or based on RECURSIVE_EXPANSION_PERMITTED setting
+ """
+ if self.recursive_expansion_permitted is not None:
+ return self.recursive_expansion_permitted
+ else:
+ return settings.REST_FLEX_FIELDS.get("RECURSIVE_EXPANSION_PERMITTED", True)
+
def to_representation(self, instance):
if not self._flex_fields_rep_applied:
self.apply_flex_fields(self.fields, self._flex_options_rep_only)
@@ -264,11 +282,63 @@ def _get_query_param_value(self, field: str) -> List[str]:
if not values:
values = self.context["request"].query_params.getlist("{}[]".format(field))
+ for expand_path in values:
+ self._validate_recursive_expansion(expand_path)
+ self._validate_expansion_depth(expand_path)
+
if values and len(values) == 1:
return values[0].split(",")
return values or []
+ def _split_expand_field(self, expand_path: str) -> List[str]:
+ return expand_path.split('.')
+
+ def recursive_expansion_not_permitted(self):
+ """
+ A customized exception can be raised when recursive expansion is found, default ValidationError
+ """
+ raise serializers.ValidationError(detail="Recursive expansion found")
+
+ def _validate_recursive_expansion(self, expand_path: str) -> None:
+ """
+ Given an expand_path, a dotted-separated string,
+ an Exception is raised when a recursive
+ expansion is detected.
+ Only applies when REST_FLEX_FIELDS["RECURSIVE_EXPANSION"] setting is False.
+ """
+ recursive_expansion_permitted = self.get_recursive_expansion_permitted()
+ if recursive_expansion_permitted is True:
+ return
+
+ expansion_path = self._split_expand_field(expand_path)
+ expansion_length = len(expansion_path)
+ expansion_length_unique = len(set(expansion_path))
+ if expansion_length != expansion_length_unique:
+ self.recursive_expansion_not_permitted()
+
+ def expansion_depth_exceeded(self):
+ """
+ A customized exception can be raised when expansion depth is found, default ValidationError
+ """
+ raise serializers.ValidationError(detail="Expansion depth exceeded")
+
+ def _validate_expansion_depth(self, expand_path: str) -> None:
+ """
+ Given an expand_path, a dotted-separated string,
+ an Exception is raised when expansion level is
+ greater than the `expansion_depth` property configuration.
+ Only applies when REST_FLEX_FIELDS["EXPANSION_DEPTH"] setting is set
+ or serializer has its own expansion configuration through default_expansion_depth attribute.
+ """
+ maximum_expansion_depth = self.get_maximum_expansion_depth()
+ if maximum_expansion_depth is None:
+ return
+
+ expansion_path = self._split_expand_field(expand_path)
+ if len(expansion_path) > maximum_expansion_depth:
+ self.expansion_depth_exceeded()
+
def _get_permitted_expands_from_query_param(self, expand_param: str) -> List[str]:
"""
If a list of permitted_expands has been passed to context,
diff --git a/tests/test_flex_fields_model_serializer.py b/tests/test_flex_fields_model_serializer.py
index b0a0551..27fea58 100644
--- a/tests/test_flex_fields_model_serializer.py
+++ b/tests/test_flex_fields_model_serializer.py
@@ -1,11 +1,17 @@
from unittest import TestCase
+from unittest.mock import patch, PropertyMock
+from django.test import override_settings
from django.utils.datastructures import MultiValueDict
+from rest_framework import serializers
+
from rest_flex_fields import FlexFieldsModelSerializer
class MockRequest(object):
- def __init__(self, query_params=MultiValueDict(), method="GET"):
+ def __init__(self, query_params=None, method="GET"):
+ if query_params is None:
+ query_params = MultiValueDict()
self.query_params = query_params
self.method = method
@@ -178,3 +184,73 @@ def test_import_serializer_class(self):
def test_make_expanded_field_serializer(self):
pass
+
+ @override_settings(REST_FLEX_FIELDS={"RECURSIVE_EXPANSION_PERMITTED": False})
+ def test_recursive_expansion(self):
+ with self.assertRaises(serializers.ValidationError):
+ FlexFieldsModelSerializer(
+ context={
+ "request": MockRequest(
+ method="GET", query_params=MultiValueDict({"expand": ["dog.leg.dog"]})
+ )
+ }
+ )
+
+ @patch('rest_flex_fields.FlexFieldsModelSerializer.recursive_expansion_permitted', new_callable=PropertyMock)
+ def test_recursive_expansion_serializer_level(self, mock_recursive_expansion_permitted):
+ mock_recursive_expansion_permitted.return_value = False
+
+ with self.assertRaises(serializers.ValidationError):
+ FlexFieldsModelSerializer(
+ context={
+ "request": MockRequest(
+ method="GET", query_params=MultiValueDict({"expand": ["dog.leg.dog"]})
+ )
+ }
+ )
+
+ @override_settings(REST_FLEX_FIELDS={"MAXIMUM_EXPANSION_DEPTH": 3})
+ def test_expansion_depth(self):
+ serializer = FlexFieldsModelSerializer(
+ context={
+ "request": MockRequest(
+ method="GET", query_params=MultiValueDict({"expand": ["dog.leg.paws"]})
+ )
+ }
+ )
+ self.assertEqual(serializer._flex_options_all["expand"], ["dog.leg.paws"])
+
+ @override_settings(REST_FLEX_FIELDS={"MAXIMUM_EXPANSION_DEPTH": 2})
+ def test_expansion_depth_exception(self):
+ with self.assertRaises(serializers.ValidationError):
+ FlexFieldsModelSerializer(
+ context={
+ "request": MockRequest(
+ method="GET", query_params=MultiValueDict({"expand": ["dog.leg.paws"]})
+ )
+ }
+ )
+
+ @patch('rest_flex_fields.FlexFieldsModelSerializer.maximum_expansion_depth', new_callable=PropertyMock)
+ def test_expansion_depth_serializer_level(self, mock_maximum_expansion_depth):
+ mock_maximum_expansion_depth.return_value = 3
+ serializer = FlexFieldsModelSerializer(
+ context={
+ "request": MockRequest(
+ method="GET", query_params=MultiValueDict({"expand": ["dog.leg.paws"]})
+ )
+ }
+ )
+ self.assertEqual(serializer._flex_options_all["expand"], ["dog.leg.paws"])
+
+ @patch('rest_flex_fields.FlexFieldsModelSerializer.maximum_expansion_depth', new_callable=PropertyMock)
+ def test_expansion_depth_serializer_level_exception(self, mock_maximum_expansion_depth):
+ mock_maximum_expansion_depth.return_value = 2
+ with self.assertRaises(serializers.ValidationError):
+ FlexFieldsModelSerializer(
+ context={
+ "request": MockRequest(
+ method="GET", query_params=MultiValueDict({"expand": ["dog.leg.paws"]})
+ )
+ }
+ )
diff --git a/tests/test_serializer.py b/tests/test_serializer.py
index 9f24df5..0f529c6 100644
--- a/tests/test_serializer.py
+++ b/tests/test_serializer.py
@@ -10,7 +10,9 @@
class MockRequest(object):
- def __init__(self, query_params={}, method="GET"):
+ def __init__(self, query_params=None, method="GET"):
+ if query_params is None:
+ query_params = {}
self.query_params = query_params
self.method = method