Poll Model(투표 예제)

 

아래와 같은 투표 요구사항이 있을때 DB 모델링과 쿼리 최적화를 하려 한다.

 

 DB 모델

 

1. User의 썸네일을 나타내기 위해 Image라는 모델을 만들었다.

2. 하나의 투표(Poll)은 여러개의 질문(PollQuestion)을 가진다.

3. 하나의 질문에 여러개의 투표(PollVote)를 할 수 있다.

# models.py
class Image(models.Model):
    image_url = models.CharField(max_length=128)


class User(AbstractUser):
    objects = UserManager()
    image = models.OneToOneField('Image', related_name='image', blank=True, null=True, on_delete=models.CASCADE)


class PollQuestion(models.Model):
    poll = models.ForeignKey(
        'Poll',
        verbose_name="투표ID",
        related_name='questions',
        on_delete=models.CASCADE)
    content = models.CharField(max_length=512)
    ordering = models.PositiveIntegerField(default=1)

    def __str__(self):
        return f'[{self.id}]{self.content}'


class PollVote(models.Model):
    poll = models.ForeignKey(
        'Poll', verbose_name="투표ID", related_name='poll_votes', on_delete=models.CASCADE)
    question = models.ForeignKey(
        'PollQuestion', verbose_name="질문", related_name='question_votes', on_delete=models.CASCADE)
    owner = models.ForeignKey(
        User, on_delete=models.CASCADE, related_name='my_votes', verbose_name="투표자")

    def __str__(self):
        return f'투표자[{self.owner.first_name}]: {self.question}'


class Poll(models.Model):
    owner = models.ForeignKey(User, on_delete=models.CASCADE)
    head = models.CharField(max_length=50)
    message = models.CharField(max_length=128)
    end_time = models.DateTimeField(null=True, blank=True)
    is_anonymous_vote = models.BooleanField(verbose_name="익명투표", default=False)
    is_multiple_vote = models.BooleanField(verbose_name="복수투표", default=False)

 

투표 시리얼라이저

위 그림처럼 투표를 한 화면에 모두 보여주려면 Poll, Question, Vote를 모두 보여줘야한다.

그래서 PollDetailSerializer를 만들었다.

해당 시리얼라이저는 상당히 많은 Serializer를 nested(내포)하고 있으며, poll instance 또는 poll queryset을 생성자로 던지면 마법 처럼 알아서 직렬화를 해준다.

이 마법은 내부적으로 상당히 많은 문제가 발생 시키는데 이를 Lazy Loding으로 인한 부작용, N+1 Problem이라 부른다. 

poll: Poll = Poll.objects.first()
serializer = PollDetailSerializer(poll)
print(serializer.data)

polls: QuerySet[Poll] = Poll.objects.all()
serializer = PollDetailSerializer(polls, many=True)
print(serializer.data)
class ImageSerializer(serializers.ModelSerializer):
    class Meta:
        model = Image
        fields = '__all__'


class UserSerializer(serializers.ModelSerializer):
    image = ImageSerializer()

    class Meta:
        model = User
        fields = ('id', 'first_name', 'username', 'image',)


class PollVoteSerializer(serializers.ModelSerializer):
    owner = UserSerializer()

    class Meta:
        model = PollVote
        exclude = ('poll', 'question',)


class PollQuestionSerializer(serializers.ModelSerializer):
    question_votes = PollVoteSerializer(many=True)

    class Meta:
        model = PollQuestion
        exclude = ('poll',)


class PollDetailSerializer(serializers.ModelSerializer):
    questions = PollQuestionSerializer(many=True)
    owner = UserSerializer(read_only=True, )

    class Meta:
        model = Poll
        fields = ('id', 'owner', 'head', 'message', 'questions', 'end_time',
                  'is_anonymous_vote','is_multiple_vote',)

 

투표 테스트

간단한 투표 생성 및 확인 테스트를 해보겠다.

하나의 투표를 생성하고 직렬화 해봤다.

    def test_poll_create(self):
        end_time: datetime = datetime.datetime.now() + datetime.timedelta(hours=0, minutes=300, seconds=0)
        poll: Poll = Poll.objects.create(
            head='목요일 회식 가능??',
            owner=self.user,
            end_time=end_time)

        yes: PollQuestion = PollQuestion.objects.create(poll=poll, content="네")
        no: PollQuestion = PollQuestion.objects.create(poll=poll, content="아니오")

        PollVote.objects.create(poll=poll, question=yes, owner=self.user)
        PollVote.objects.create(poll=poll, question=no, owner=self.user2)

        self.assertEqual(poll.questions.all().count(), 2)
        self.assertEqual(poll.poll_votes.all().count(), 2)
        
        serializer = PollDetailSerializer(poll)
        self.assertEqual(len(serializer.data['questions']), 2)
        self.assertEqual(len(serializer.data['questions'][0]['question_votes']), 1)
        self.assertEqual(len(serializer.data['questions'][1]['question_votes']), 1)

 

이제 여러개의 투표를 생성해보고 Serializer를 돌려보자.

(필자는 귀찮음을 싫어해 factory boy를 사용했다.)

import datetime

import factory
from dateutil.tz import UTC
from django.db import connection
from django.db.models import QuerySet
from django.test import TestCase
from django.test.utils import CaptureQueriesContext
from factory.fuzzy import FuzzyDateTime

from poll.models import Poll, PollQuestion, PollVote, User, Image
from poll.serializers import PollDetailSerializer


class UserFactory(factory.django.DjangoModelFactory):
    class Meta:
        model = User

    first_name = factory.Sequence(lambda n: "Agent %03d" % n)
    username = factory.Sequence(lambda n: "Agent %03d" % n)
    password = factory.Sequence(lambda n: "Agent %03d" % n)


class PollFactory(factory.django.DjangoModelFactory):
    class Meta:
        model = Poll

    head = factory.Sequence(lambda n: "Agent %d" % n)
    owner = factory.SubFactory(UserFactory)
    end_time = FuzzyDateTime(datetime.datetime(2023, 1, 1, tzinfo=UTC), datetime.datetime(2024, 1, 1, tzinfo=UTC))


class PollTest(TestCase):
    def setUp(self):
        self.user: User = User.objects.create(**self.create_user())
        self.user.image = Image.objects.create(image_url='www.image.com')
        self.user.save()

        user2 = self.create_user()
        user2['username'] = 'user2'
        self.user2: User = User.objects.create(**user2)
        self.user2.image = Image.objects.create(image_url='www.image.com')
        self.user2.save()

    def create_user(self) -> dict:
        return {
            "username": "leemoney93",
            "password": "mememememe",
            "first_name": "lee",
            "last_name": "money",
        }

    def create_poll(self):
        poll: Poll = PollFactory.create()
        q1 = PollQuestion.objects.create(poll=poll, content="yes")
        q2 = PollQuestion.objects.create(poll=poll, content="no")
        PollVote.objects.create(poll=poll, question=q1, owner=self.user)
        PollVote.objects.create(poll=poll, question=q2, owner=self.user2)

    def test_poll_create2(self):
        for i in range(6):
            self.create_poll()

        with CaptureQueriesContext(connection) as num_queries:
            polls: QuerySet[Poll] = Poll.objects.all()
            serializer = PollDetailSerializer(polls, many=True)
            self.assertEqual(len(serializer.data), 6)

        print(len(num_queries.captured_queries))
        print(num_queries.captured_queries)

test_poll_create2 테스트 함수를 보면 6개의 투표들이 잘 직렬화된걸 알 수 있다.

DRF(Django Rest Framework)가 마법을 부려서 알아서 잘 해줬다.

하지만! num_queries를 이용해 얼마나 많은 쿼리가 실행 됐는지 봐라.

무려 49번이나 실행됐다. 기절할 노릇이다. 고작 6개의 투표들을 직렬화 했을때 이정돈데 10개가 되면 도대체 몇개를 호출한단 말인가..

DB가 기절할 노릇이다.

 

문제점을 파악했으니 정확한 원인을 알아보자.

필자는 위에서 이 마법의 문제를 Lazy Loding으로 인한 부작용, N+1 Problem이라 했다.

DRF 입장에서 PostDetailSerializer는 owner, questions을 추가적으로 직렬화 해야하나 현재 polls 쿼리셋으로는 해당 정보를 공급 해 줄수가 없다. 해서 뒤 늦게 자기가 추가적인 정보를 얻기위해(owner, questions) 쿼리를 실행해 Loading했다. 이를 Lazy Loading이라 한다. 여기서 owner의 UserSerializer 역시 추가적으로 image를 직렬화 한다. 여기서 또 쿼리가 발생한다. 이는 questions 또한 마찬가지다. PollVote 직렬화를 위해 추가적인 쿼리를 생산한다. 

마법은 마법인데.. 수동 마법이다.

 

 

문제 해결 즉시 로딩 (Eager Loading)

Lazy Loading의 반댓말은 Eager Loading이다.

polls queryset에서 부족한 정보를 미리 공급해 추가적인 쿼리를 발생시키지 않겠다는 말이다. 

방법은 여러가지가 있지만 여기서는 PollManager를 생성해 get_queryset 오버라이딩 하는 방식을 취했다.

(django two scoops 책에서 추천하는 방법)

get_queryset에서는 PollDetailSerializer가 필요한 정보를 공급해줘야한다.

class Poll(models.Model):
    class PollManager(models.Manager):
        def get_queryset(self) -> QuerySet:
            queryset: QuerySet['Poll'] = (
                super().get_queryset()
                .prefetch_related(
                    Prefetch('questions', queryset=PollQuestion.objects.prefetch_related(
                        Prefetch('question_votes', queryset=PollVote.objects.select_related('owner__image')))))
                .select_related('owner__image'))
            return queryset

    owner = models.ForeignKey(User, on_delete=models.CASCADE)
    head = models.CharField(max_length=50)
    message = models.CharField(max_length=128)
    end_time = models.DateTimeField(null=True, blank=True)
    is_anonymous_vote = models.BooleanField(verbose_name="익명투표", default=False)
    is_multiple_vote = models.BooleanField(verbose_name="복수투표", default=False)

    objects = PollManager()

 

 

코드 자세히 살펴보기

Poll 모델 입장에서

1. owner는 다 대 일 형식이다. 해서 select_related를 사용해야한다. 

그리고 UserSerializer는 image를 추가적으로 요구한다. 해서 아래와 같이 select_related를 사용하면 된다.

(select_related는 내부적으로 join이다.)

.select_related('owner__image')

 

Poll 모델 입장에서

2. questions는 일 대 다 형식이다. 해서 prefetch_related를 사용해야한다.

그리고 PollQuestionSerializer는 추가적으로 votes를 원한다. 해서 한번 더 prefetch를 해줘야한다.

그리고 PollVote는 한번 더 owner를 요구한다. select_related를 한번 더 추가 해준다.

(prefetch_related는 내부적으로 where in 구절을 사용한다.)

.prefetch_related(
    Prefetch('questions', queryset=PollQuestion.objects.prefetch_related(
        Prefetch('question_votes', queryset=PollVote.objects.select_related('owner__image')))))

 

 

def test_poll_n_plus1(self):
    self.create_poll()

    with CaptureQueriesContext(connection) as expected_num_queries:
        polls: QuerySet[Poll] = Poll.objects.all()
        serializer = PollDetailSerializer(polls, many=True)
        print(serializer.data)

    self.create_poll()
    self.create_poll()
    self.create_poll()
    self.create_poll()
    self.create_poll()

    with CaptureQueriesContext(connection) as checked_num_queries:
        polls: QuerySet[Poll] = Poll.objects.all()
        serializer = PollDetailSerializer(polls, many=True)
        print(serializer.data)

    self.assertEqual(len(expected_num_queries), len(checked_num_queries))

 

 

 

테스트 함수를 살펴보면 투표가 얼마나 늘던지 간에 항상 쿼리 개수가 정해져있다.

나름의 최적화가 완성 되었다!

def test_poll_n_plus1(self):
    self.create_poll()

    with CaptureQueriesContext(connection) as expected_num_queries:
        polls: QuerySet[Poll] = Poll.objects.all()
        serializer = PollDetailSerializer(polls, many=True)
        print(serializer.data)

    self.create_poll()
    self.create_poll()
    self.create_poll()
    self.create_poll()
    self.create_poll()

    with CaptureQueriesContext(connection) as checked_num_queries:
        polls: QuerySet[Poll] = Poll.objects.all()
        serializer = PollDetailSerializer(polls, many=True)
        print(serializer.data)

    self.assertEqual(len(expected_num_queries), len(checked_num_queries))

 

소스코드 : https://github.com/seunwoolee/djangopoll

+ Recent posts