@@ -119,23 +119,29 @@ def get_queryset(self, *args, **kwargs):
119
119
included_model = None
120
120
levels = included .split ('.' )
121
121
level_model = qs .model
122
+ # Suppose we can do select_related by default
123
+ can_select_related = True
122
124
for level in levels :
123
125
if not hasattr (level_model , level ):
124
126
break
125
127
field = getattr (level_model , level )
126
128
field_class = field .__class__
127
129
128
130
is_forward_relation = (
129
- issubclass (field_class , ForwardManyToOneDescriptor ) or
130
- issubclass (field_class , ManyToManyDescriptor )
131
+ issubclass (field_class , (ForwardManyToOneDescriptor , ManyToManyDescriptor ))
131
132
)
132
133
is_reverse_relation = (
133
- issubclass (field_class , ReverseManyToOneDescriptor ) or
134
- issubclass (field_class , ReverseOneToOneDescriptor )
134
+ issubclass (field_class , (ReverseManyToOneDescriptor , ReverseOneToOneDescriptor ))
135
135
)
136
136
if not (is_forward_relation or is_reverse_relation ):
137
137
break
138
138
139
+ # Figuring out if relation should be select related rather than prefetch_related
140
+ # If at least one relation in the chain is not "selectable" then use "prefetch"
141
+ can_select_related &= (
142
+ issubclass (field_class , (ForwardManyToOneDescriptor , ReverseOneToOneDescriptor ))
143
+ )
144
+
139
145
if level == levels [- 1 ]:
140
146
included_model = field
141
147
else :
@@ -151,7 +157,10 @@ def get_queryset(self, *args, **kwargs):
151
157
level_model = model_field .model
152
158
153
159
if included_model is not None :
154
- qs = qs .prefetch_related (included .replace ('.' , '__' ))
160
+ if can_select_related :
161
+ qs = qs .select_related (included .replace ('.' , '__' ))
162
+ else :
163
+ qs = qs .prefetch_related (included .replace ('.' , '__' ))
155
164
156
165
return qs
157
166
0 commit comments