Skip to content
Snippets Groups Projects
Commit d4e77a8c authored by BARBIER Jean-Matthieu's avatar BARBIER Jean-Matthieu
Browse files

fix(dynamic): add on_request_default_fields for DynamicAPIMixin

parent 26724b4d
Branches
No related tags found
No related merge requests found
"""
Mixin to dynamically select only a subset of fields per DRF resource.
"""
import logging
import warnings
from typing import Pattern, Dict, TypedDict, List
from django.db.models import QuerySet
from rest_framework.fields import empty
logger = logging.getLogger(__name__)
class OnRequestParams(TypedDict, total=False):
related: List[str]
......@@ -15,6 +18,7 @@ class OnRequestParams(TypedDict, total=False):
class DynamicFieldsAPIViewMixin(object):
on_request_fields: Dict[str, OnRequestParams] = {}
on_request_default_fields: List[str] = []
def _process_related(self, related, field):
if callable(related):
......@@ -42,19 +46,26 @@ class DynamicFieldsAPIViewMixin(object):
:return: list of related fields
"""
related_fields = getattr(self, "on_request_fields", dict())
default_fields = ",".join(getattr(self, "on_request_default_fields",
list()))
regex_fields = set(filter(lambda f: isinstance(f, Pattern),
related_fields.keys()))
out = list()
fields = self.request.query_params.get("fields", "").split(",")
fields = self.request.query_params \
.get("fields", default_fields) \
.split(",")
for f in fields:
out += self._process_related(
related_fields.get(f, dict()).get("related", list()), f)
related_fields.get(f, dict())
.get("related", list()), f)
for rg in regex_fields:
if rg.match(f):
out += self._process_related(
related_fields.get(rg, dict()).get("related", list()),
f)
return list(set(out))
related_fields.get(rg, dict())
.get("related", list()), f)
related = list(set(out))
logger.debug("Related fields : %s", related)
return related
def get_prefetch(self):
"""
......@@ -69,34 +80,46 @@ class DynamicFieldsAPIViewMixin(object):
"""
related_fields = getattr(self, "on_request_fields", dict())
default_fields = ",".join(getattr(self, "on_request_default_fields",
list()))
regex_fields = set(filter(lambda f: isinstance(f, Pattern),
related_fields.keys()))
out = list()
fields = self.request.query_params.get("fields", "").split(",")
fields = self.request.query_params \
.get("fields", default_fields) \
.split(",")
for f in fields:
out += self._process_related(
related_fields.get(f, dict()).get("prefetch", list()), f)
related_fields.get(f, dict())
.get("prefetch", list()), f)
for rg in regex_fields:
if rg.match(f):
out += self._process_related(
related_fields.get(rg, dict()).get("prefetch", list()),
f)
return list(set(out))
related_fields.get(rg, dict())
.get("prefetch", list()), f)
prefetch = list(set(out))
logger.debug("Prefetched fields : %s", prefetch)
return prefetch
def get_annotations(self, qs: QuerySet):
related_fields = getattr(self, "on_request_fields", dict())
default_fields = ",".join(getattr(self, "on_request_default_fields",
list()))
regex_fields = set(filter(lambda f: isinstance(f, Pattern),
related_fields.keys()))
out = dict()
fields = self.request.query_params.get("fields", "").split(",")
fields = self.request.query_params \
.get("fields", default_fields) \
.split(",")
for f in fields:
out |= self._process_annotation(
related_fields.get(f, dict()).get("annotate", dict()), f)
related_fields.get(f, dict())
.get("annotate", dict()), f)
for rg in regex_fields:
if rg.match(f):
out |= self._process_annotation(
related_fields.get(rg, dict()).get("annotate", dict()),
f)
related_fields.get(rg, dict())
.get("annotate", dict()), f)
return qs.annotate(**out)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment