Skip to content

Commit d58f77d

Browse files
authored
Merge pull request #584 from jvdp1/median_select
Replace the call to sort by select in stdlib_stats_median
2 parents 49d269b + de04268 commit d58f77d

File tree

4 files changed

+44
-36
lines changed

4 files changed

+44
-36
lines changed

doc/specs/stdlib_stats.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,9 +204,9 @@ and if `n` is an odd number, the median is:
204204
median(array) = mean( array_sorted( floor( (n + 1) / 2.):floor( (n + 1) / 2.) + 1 ) )
205205
```
206206

207-
The current implementation is a quite naive implementation that relies on sorting
208-
the whole array, using the subroutine `[[stdlib_sorting(module):ord_sort(interface)]]`
209-
provided by the `[[stdlib_sorting(module)]]` module.
207+
The current implementation relies on a selection algorithm applied on a copy of
208+
the whole array, using the subroutine `[[stdlib_selection(module):select(interface)]]`
209+
provided by the `[[stdlib_selection(module)]]` module.
210210

211211
### Syntax
212212

src/Makefile.manual

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ stdlib_stats_mean.o: \
148148
stdlib_stats_median.o: \
149149
stdlib_optval.o \
150150
stdlib_kinds.o \
151-
stdlib_sorting.o \
151+
stdlib_selection.o \
152152
stdlib_stats.o
153153
stdlib_stats_moment.o: \
154154
stdlib_optval.o \

src/stdlib_stats_median.fypp

Lines changed: 40 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,7 @@ submodule (stdlib_stats) stdlib_stats_median
77
use, intrinsic:: ieee_arithmetic, only: ieee_value, ieee_quiet_nan, ieee_is_nan
88
use stdlib_error, only: error_stop
99
use stdlib_optval, only: optval
10-
! Use "ord_sort" rather than "sort" because the former can be much faster for arrays
11-
! that are already partly sorted. While it is slightly slower for random arrays,
12-
! ord_sort seems a better overall choice.
13-
use stdlib_sorting, only: sort => ord_sort
10+
use stdlib_selection, only: select
1411
implicit none
1512

1613
contains
@@ -24,6 +21,7 @@ contains
2421
real(${o1}$) :: res
2522

2623
integer(kind = int64) :: c, n
24+
${t1}$ :: val, val1
2725
${t1}$, allocatable :: x_tmp(:)
2826

2927
if (.not.optval(mask, .true.) .or. size(x) == 0) then
@@ -43,16 +41,18 @@ contains
4341

4442
x_tmp = reshape(x, [n])
4543

46-
call sort(x_tmp)
44+
call select(x_tmp, c, val)
4745

4846
if (mod(n, 2_int64) == 0) then
47+
val1 = minval(x_tmp(c+1:n)) !instead of call select(x_tmp, c+1, val1, left = c)
4948
#:if t1[0] == 'r'
50-
res = sum(x_tmp(c:c+1)) / 2._${o1}$
49+
res = (val + val1) / 2._${o1}$
5150
#:else
52-
res = sum( real(x_tmp(c:c+1), kind=${o1}$) ) / 2._${o1}$
51+
res = (real(val, kind=${o1}$) + &
52+
real(val1, kind=${o1}$)) / 2._${o1}$
5353
#:endif
5454
else
55-
res = x_tmp(c)
55+
res = val
5656
end if
5757

5858
end function ${name}$
@@ -74,6 +74,7 @@ contains
7474
integer :: j${fj}$
7575
#:endfor
7676
#:endif
77+
${t1}$ :: val, val1
7778
${t1}$, allocatable :: x_tmp(:)
7879

7980
if (.not.optval(mask, .true.) .or. size(x) == 0) then
@@ -107,17 +108,18 @@ contains
107108
end if
108109
#:endif
109110

110-
call sort(x_tmp)
111+
call select(x_tmp, c, val)
111112

112113
if (mod(n, 2) == 0) then
114+
val1 = minval(x_tmp(c+1:n))
113115
res${reduce_subvector('j', rank, fi)}$ = &
114116
#:if t1[0] == 'r'
115-
sum(x_tmp(c:c+1)) / 2._${o1}$
117+
(val + val1) / 2._${o1}$
116118
#:else
117-
sum(real(x_tmp(c:c+1), kind=${o1}$) ) / 2._${o1}$
119+
(real(val, kind=${o1}$) + real(val1, kind=${o1}$)) / 2._${o1}$
118120
#:endif
119121
else
120-
res${reduce_subvector('j', rank, fi)}$ = x_tmp(c)
122+
res${reduce_subvector('j', rank, fi)}$ = val
121123
end if
122124
#:for fj in range(1, rank)
123125
end do
@@ -141,6 +143,7 @@ contains
141143
real(${o1}$) :: res
142144

143145
integer(kind = int64) :: c, n
146+
${t1}$ :: val, val1
144147
${t1}$, allocatable :: x_tmp(:)
145148

146149
if (any(shape(x) .ne. shape(mask))) then
@@ -156,21 +159,26 @@ contains
156159

157160
x_tmp = pack(x, mask)
158161

159-
call sort(x_tmp)
160-
161162
n = size(x_tmp, kind=int64)
162-
c = floor( (n + 1) / 2._${o1}$, kind=int64)
163163

164164
if (n == 0) then
165165
res = ieee_value(1._${o1}$, ieee_quiet_nan)
166-
else if (mod(n, 2_int64) == 0) then
166+
return
167+
end if
168+
169+
c = floor( (n + 1) / 2._${o1}$, kind=int64)
170+
171+
call select(x_tmp, c, val)
172+
173+
if (mod(n, 2_int64) == 0) then
174+
val1 = minval(x_tmp(c+1:n))
167175
#:if t1[0] == 'r'
168-
res = sum(x_tmp(c:c+1)) / 2._${o1}$
176+
res = (val + val1) / 2._${o1}$
169177
#:else
170-
res = sum(real(x_tmp(c:c+1), kind=${o1}$)) / 2._${o1}$
178+
res = (real(val, kind=${o1}$) + real(val1, kind=${o1}$)) / 2._${o1}$
171179
#:endif
172180
else if (mod(n, 2_int64) == 1) then
173-
res = x_tmp(c)
181+
res = val
174182
end if
175183

176184
end function ${name}$
@@ -192,6 +200,7 @@ contains
192200
integer :: j${fj}$
193201
#:endfor
194202
#:endif
203+
${t1}$ :: val, val1
195204
${t1}$, allocatable :: x_tmp(:)
196205

197206
if (any(shape(x) .ne. shape(mask))) then
@@ -220,23 +229,28 @@ contains
220229
end if
221230
#:endif
222231

223-
call sort(x_tmp)
224-
225232
n = size(x_tmp, kind=int64)
226-
c = floor( (n + 1) / 2._${o1}$, kind=int64 )
227233

228234
if (n == 0) then
229235
res${reduce_subvector('j', rank, fi)}$ = &
230236
ieee_value(1._${o1}$, ieee_quiet_nan)
231-
else if (mod(n, 2_int64) == 0) then
237+
return
238+
end if
239+
240+
c = floor( (n + 1) / 2._${o1}$, kind=int64 )
241+
242+
call select(x_tmp, c, val)
243+
244+
if (mod(n, 2_int64) == 0) then
245+
val1 = minval(x_tmp(c+1:n))
232246
res${reduce_subvector('j', rank, fi)}$ = &
233247
#:if t1[0] == 'r'
234-
sum(x_tmp(c:c+1)) / 2._${o1}$
248+
(val + val1) / 2._${o1}$
235249
#:else
236-
sum(real(x_tmp(c:c+1), kind=${o1}$)) / 2._${o1}$
250+
(real(val, kind=${o1}$) + real(val1, kind=${o1}$)) / 2._${o1}$
237251
#:endif
238252
else if (mod(n, 2_int64) == 1) then
239-
res${reduce_subvector('j', rank, fi)}$ = x_tmp(c)
253+
res${reduce_subvector('j', rank, fi)}$ = val
240254
end if
241255

242256
deallocate(x_tmp)

src/tests/stats/test_median.fypp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,12 +117,6 @@ contains
117117
call check(error, median(d2odd_${k1}$), 1._dp&
118118
, 'median(d2odd_${k1}$): uncorrect answer'&
119119
, thr = dptol)
120-
if (allocated(error)) return
121-
122-
call check(error, median(d2odd_${k1}$), 1._dp&
123-
, 'median(d2odd_${k1}$): uncorrect answer'&
124-
, thr = dptol)
125-
if (allocated(error)) return
126120

127121
end subroutine
128122

0 commit comments

Comments
 (0)