From 0f5d37fc7cf08a862f7e8615a94cd943809b630c Mon Sep 17 00:00:00 2001 From: smilerz Date: Wed, 27 Mar 2024 08:17:20 -0500 Subject: [PATCH] apply PK only update to NestedWritableSerializer --- cookbook/serializer.py | 44 +++++++++--- .../tests/other/test_nested_serializer.py | 69 +++++++++++++++++++ 2 files changed, 103 insertions(+), 10 deletions(-) create mode 100644 cookbook/tests/other/test_nested_serializer.py diff --git a/cookbook/serializer.py b/cookbook/serializer.py index f0090b4ba..e87284de9 100644 --- a/cookbook/serializer.py +++ b/cookbook/serializer.py @@ -5,6 +5,7 @@ from gettext import gettext as _ from html import escape from smtplib import SMTPException +from django.forms.models import model_to_dict from django.contrib.auth.models import AnonymousUser, Group, User from django.core.cache import caches from django.core.mail import send_mail @@ -13,7 +14,8 @@ from django.http import BadHeaderError from django.urls import reverse from django.utils import timezone from django_scopes import scopes_disabled -from drf_writable_nested import UniqueFieldsMixin, WritableNestedModelSerializer +from drf_writable_nested import UniqueFieldsMixin +from drf_writable_nested import WritableNestedModelSerializer as WNMS from oauth2_provider.models import AccessToken from PIL import Image from rest_framework import serializers @@ -38,6 +40,33 @@ from cookbook.templatetags.custom_tags import markdown from recipes.settings import AWS_ENABLED, MEDIA_URL +class WritableNestedModelSerializer(WNMS): + + # overload to_internal_value to allow using PK only on nested object + def to_internal_value(self, data): + # iterate through every field on the posted object + for f in list(data): + if f not in self.fields: + continue + elif issubclass(self.fields[f].__class__, serializers.Serializer): + # if the field is a serializer and an integer, assume its an ID of an existing object + if isinstance(data[f], int): + # only retrieve serializer required fields + required_fields = ['id'] + [field_name for field_name, field in self.fields[f].__class__().fields.items() if field.required] + data[f] = model_to_dict(self.fields[f].Meta.model.objects.get(id=data[f]), fields=required_fields) + elif issubclass(self.fields[f].__class__, serializers.ListSerializer): + # if the field is a ListSerializer get dict values of PKs provided + if any(isinstance(x, int) for x in data[f]): + # only retrieve serializer required fields + required_fields = ['id'] + [field_name for field_name, field in self.fields[f].child.__class__().fields.items() if field.required] + # filter values to integer values + pk_data = [x for x in data[f] if isinstance(x, int)] + # merge non-pk values with retrieved values + data[f] = [x for x in data[f] if not isinstance(x, int)] \ + + list(self.fields[f].child.Meta.model.objects.filter(id__in=pk_data).values(*required_fields)) + return super().to_internal_value(data) + + class ExtendedRecipeMixin(serializers.ModelSerializer): # adds image and recipe count to serializer when query param extended=1 # ORM path to this object from Recipe @@ -56,8 +85,7 @@ class ExtendedRecipeMixin(serializers.ModelSerializer): api_serializer = None # extended values are computationally expensive and not needed in normal circumstances try: - if str2bool( - self.context['request'].query_params.get('extended', False)) and self.__class__ == api_serializer: + if str2bool(self.context['request'].query_params.get('extended', False)) and self.__class__ == api_serializer: return fields except (AttributeError, KeyError): pass @@ -122,16 +150,12 @@ class CustomOnHandField(serializers.Field): if not self.context["request"].user.is_authenticated: return [] shared_users = [] - if c := caches['default'].get( - f'shopping_shared_users_{self.context["request"].space.id}_{self.context["request"].user.id}', None): + if c := caches['default'].get(f'shopping_shared_users_{self.context["request"].space.id}_{self.context["request"].user.id}', None): shared_users = c else: try: - shared_users = [x.id for x in list(self.context['request'].user.get_shopping_share())] + [ - self.context['request'].user.id] - caches['default'].set( - f'shopping_shared_users_{self.context["request"].space.id}_{self.context["request"].user.id}', - shared_users, timeout=5 * 60) + shared_users = [x.id for x in list(self.context['request'].user.get_shopping_share())] + [self.context['request'].user.id] + caches['default'].set(f'shopping_shared_users_{self.context["request"].space.id}_{self.context["request"].user.id}', shared_users, timeout=5 * 60) # TODO ugly hack that improves API performance significantly, should be done properly except AttributeError: # Anonymous users (using share links) don't have shared users pass diff --git a/cookbook/tests/other/test_nested_serializer.py b/cookbook/tests/other/test_nested_serializer.py new file mode 100644 index 000000000..9623551cc --- /dev/null +++ b/cookbook/tests/other/test_nested_serializer.py @@ -0,0 +1,69 @@ +import json + +import pytest +from django.urls import reverse +from django_scopes import scopes_disabled +from pytest_factoryboy import LazyFixture, register + +from cookbook.tests.factories import FoodFactory, KeywordFactory, UnitFactory + +RECIPE_URL = 'api:recipe-detail' +FOOD_URL = 'api:food-detail' + +register(FoodFactory, 'food_1', space=LazyFixture('space_1')) +register(FoodFactory, 'food_2', space=LazyFixture('space_1')) +register(KeywordFactory, 'keyword_1', space=LazyFixture('space_1')) +register(KeywordFactory, 'keyword_2', space=LazyFixture('space_1')) +register(UnitFactory, 'unit_1', space=LazyFixture('space_1')) + + +@pytest.mark.parametrize("arg", ['dict', 'pk']) +def test_unnested_serializer__single(arg, recipe_1_s1, food_1, u1_s1): + if arg == 'dict': + recipe = {'id': recipe_1_s1.id, 'name': recipe_1_s1.name, } + elif arg == 'pk': + recipe = recipe_1_s1.id + r = u1_s1.patch(reverse(FOOD_URL, args={food_1.id}), {'name': food_1.name, 'recipe': recipe}, content_type='application/json') + assert r.status_code == 200 + assert json.loads(r.content)['recipe']['id'] == recipe_1_s1.id + + +def test_nested_serializer_many(recipe_1_s1, food_1, food_2, keyword_1, keyword_2, unit_1, u1_s1): + with scopes_disabled(): + assert food_1 not in [i.food for i in recipe_1_s1.steps.all()[0].ingredients.all()] + assert food_2 not in [i.food for i in recipe_1_s1.steps.all()[0].ingredients.all()] + assert keyword_1 not in recipe_1_s1.keywords.all() + assert keyword_2 not in recipe_1_s1.keywords.all() + r = u1_s1.patch(reverse(RECIPE_URL, args={recipe_1_s1.id}), { + 'name': + recipe_1_s1.name, + 'steps': [{ + 'ingredients': [{ + 'amount': 1, + 'unit': { + 'id': unit_1.id, + 'name': unit_1.name + }, + 'food': { + 'id': food_1.id, + 'name': food_1.name + } + }, { + 'amount': 1, + 'unit': unit_1.id, + 'food': food_2.id + }] + }], + 'keywords': [{ + 'id': keyword_1.id, + 'name': keyword_1.name + }, keyword_2.id] + }, + content_type='application/json') + assert r.status_code == 200 + with scopes_disabled(): + # recipe_1_s1 = Recipe.objects.get(id=recipe_1_s1.id) + assert food_1 in [i.food for i in recipe_1_s1.steps.all()[0].ingredients.all()] + assert food_2 in [i.food for i in recipe_1_s1.steps.all()[0].ingredients.all()] + assert keyword_1 in recipe_1_s1.keywords.all() + assert keyword_2 in recipe_1_s1.keywords.all()