diff --git a/doc/specs/stdlib_stats.md b/doc/specs/stdlib_stats.md index b8a4f7f78..3f758a19d 100644 --- a/doc/specs/stdlib_stats.md +++ b/doc/specs/stdlib_stats.md @@ -204,9 +204,9 @@ and if `n` is an odd number, the median is: median(array) = mean( array_sorted( floor( (n + 1) / 2.):floor( (n + 1) / 2.) + 1 ) ) ``` -The current implementation is a quite naive implementation that relies on sorting -the whole array, using the subroutine `[[stdlib_sorting(module):ord_sort(interface)]]` -provided by the `[[stdlib_sorting(module)]]` module. +The current implementation relies on a selection algorithm applied to a copy of +the whole array, using the subroutine `[[stdlib_selection(module):select(interface)]]` +provided by the `[[stdlib_selection(module)]]` module. ### Syntax @@ -220,11 +220,11 @@ Generic subroutine ### Arguments -`array`: Shall be an array of type `integer` or `real`. +`array`: Shall be an array of type `integer` or `real`. It is an `intent(in)` argument. -`dim`: Shall be a scalar of type `integer` with a value in the range from 1 to `n`, where `n` is the rank of `array`. +`dim`: Shall be a scalar of type `integer` with a value in the range from 1 to `n`, where `n` is the rank of `array`. It is an `intent(in)` argument. -`mask` (optional): Shall be of type `logical` and either a scalar or an array of the same shape as `array`. +`mask` (optional): Shall be of type `logical` and either a scalar or an array of the same shape as `array`. It is an `intent(in)` argument. 
### Return value diff --git a/src/Makefile.manual b/src/Makefile.manual index 2bfcb7091..3d02c4ec2 100644 --- a/src/Makefile.manual +++ b/src/Makefile.manual @@ -134,7 +134,7 @@ stdlib_stats_mean.o: \ stdlib_stats_median.o: \ stdlib_optval.o \ stdlib_kinds.o \ - stdlib_sorting.o \ + stdlib_selection.o \ stdlib_stats.o stdlib_stats_moment.o: \ stdlib_optval.o \ diff --git a/src/stdlib_stats_median.fypp b/src/stdlib_stats_median.fypp index 5adb7b256..ef1e797ca 100644 --- a/src/stdlib_stats_median.fypp +++ b/src/stdlib_stats_median.fypp @@ -7,10 +7,7 @@ submodule (stdlib_stats) stdlib_stats_median use, intrinsic:: ieee_arithmetic, only: ieee_value, ieee_quiet_nan, ieee_is_nan use stdlib_error, only: error_stop use stdlib_optval, only: optval - ! Use "ord_sort" rather than "sort" because the former can be much faster for arrays - ! that are already partly sorted. While it is slightly slower for random arrays, - ! ord_sort seems a better overall choice. - use stdlib_sorting, only: sort => ord_sort + use stdlib_selection, only: select implicit none contains @@ -24,6 +21,7 @@ contains real(${o1}$) :: res integer(kind = int64) :: c, n + ${t1}$ :: val, val1 ${t1}$, allocatable :: x_tmp(:) if (.not.optval(mask, .true.) .or. size(x) == 0) then @@ -43,16 +41,18 @@ contains x_tmp = reshape(x, [n]) - call sort(x_tmp) + call select(x_tmp, c, val) if (mod(n, 2_int64) == 0) then + val1 = minval(x_tmp(c+1:n)) !instead of call select(x_tmp, c+1, val1, left = c) #:if t1[0] == 'r' - res = sum(x_tmp(c:c+1)) / 2._${o1}$ + res = (val + val1) / 2._${o1}$ #:else - res = sum( real(x_tmp(c:c+1), kind=${o1}$) ) / 2._${o1}$ + res = (real(val, kind=${o1}$) + & + real(val1, kind=${o1}$)) / 2._${o1}$ #:endif else - res = x_tmp(c) + res = val end if end function ${name}$ @@ -74,6 +74,7 @@ contains integer :: j${fj}$ #:endfor #:endif + ${t1}$ :: val, val1 ${t1}$, allocatable :: x_tmp(:) if (.not.optval(mask, .true.) .or. 
size(x) == 0) then @@ -107,17 +108,18 @@ contains end if #:endif - call sort(x_tmp) + call select(x_tmp, c, val) if (mod(n, 2) == 0) then + val1 = minval(x_tmp(c+1:n)) res${reduce_subvector('j', rank, fi)}$ = & #:if t1[0] == 'r' - sum(x_tmp(c:c+1)) / 2._${o1}$ + (val + val1) / 2._${o1}$ #:else - sum(real(x_tmp(c:c+1), kind=${o1}$) ) / 2._${o1}$ + (real(val, kind=${o1}$) + real(val1, kind=${o1}$)) / 2._${o1}$ #:endif else - res${reduce_subvector('j', rank, fi)}$ = x_tmp(c) + res${reduce_subvector('j', rank, fi)}$ = val end if #:for fj in range(1, rank) end do @@ -141,6 +143,7 @@ contains real(${o1}$) :: res integer(kind = int64) :: c, n + ${t1}$ :: val, val1 ${t1}$, allocatable :: x_tmp(:) if (any(shape(x) .ne. shape(mask))) then @@ -156,21 +159,26 @@ contains x_tmp = pack(x, mask) - call sort(x_tmp) - n = size(x_tmp, kind=int64) - c = floor( (n + 1) / 2._${o1}$, kind=int64) if (n == 0) then res = ieee_value(1._${o1}$, ieee_quiet_nan) - else if (mod(n, 2_int64) == 0) then + return + end if + + c = floor( (n + 1) / 2._${o1}$, kind=int64) + + call select(x_tmp, c, val) + + if (mod(n, 2_int64) == 0) then + val1 = minval(x_tmp(c+1:n)) #:if t1[0] == 'r' - res = sum(x_tmp(c:c+1)) / 2._${o1}$ + res = (val + val1) / 2._${o1}$ #:else - res = sum(real(x_tmp(c:c+1), kind=${o1}$)) / 2._${o1}$ + res = (real(val, kind=${o1}$) + real(val1, kind=${o1}$)) / 2._${o1}$ #:endif else if (mod(n, 2_int64) == 1) then - res = x_tmp(c) + res = val end if end function ${name}$ @@ -192,6 +200,7 @@ contains integer :: j${fj}$ #:endfor #:endif + ${t1}$ :: val, val1 ${t1}$, allocatable :: x_tmp(:) if (any(shape(x) .ne. 
shape(mask))) then @@ -220,23 +229,28 @@ contains end if #:endif - call sort(x_tmp) - n = size(x_tmp, kind=int64) - c = floor( (n + 1) / 2._${o1}$, kind=int64 ) if (n == 0) then res${reduce_subvector('j', rank, fi)}$ = & ieee_value(1._${o1}$, ieee_quiet_nan) - else if (mod(n, 2_int64) == 0) then + return + end if + + c = floor( (n + 1) / 2._${o1}$, kind=int64 ) + + call select(x_tmp, c, val) + + if (mod(n, 2_int64) == 0) then + val1 = minval(x_tmp(c+1:n)) res${reduce_subvector('j', rank, fi)}$ = & #:if t1[0] == 'r' - sum(x_tmp(c:c+1)) / 2._${o1}$ + (val + val1) / 2._${o1}$ #:else - sum(real(x_tmp(c:c+1), kind=${o1}$)) / 2._${o1}$ + (real(val, kind=${o1}$) + real(val1, kind=${o1}$)) / 2._${o1}$ #:endif else if (mod(n, 2_int64) == 1) then - res${reduce_subvector('j', rank, fi)}$ = x_tmp(c) + res${reduce_subvector('j', rank, fi)}$ = val end if deallocate(x_tmp) diff --git a/src/tests/stats/test_median.fypp b/src/tests/stats/test_median.fypp index d257c7485..22d19785d 100644 --- a/src/tests/stats/test_median.fypp +++ b/src/tests/stats/test_median.fypp @@ -117,12 +117,6 @@ contains call check(error, median(d2odd_${k1}$), 1._dp& , 'median(d2odd_${k1}$): uncorrect answer'& , thr = dptol) - if (allocated(error)) return - - call check(error, median(d2odd_${k1}$), 1._dp& - , 'median(d2odd_${k1}$): uncorrect answer'& - , thr = dptol) - if (allocated(error)) return end subroutine