• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

haskell / random / 400

26 Dec 2024 12:49AM UTC coverage: 68.916% (-0.8%) from 69.751%
400

push

github

web-flow
Merge pull request #171 from haskell/lehins/use-array-for-shuffle

Implement a faster and unbiased version of list shuffling

40 of 66 new or added lines in 4 files covered. (60.61%)

61 existing lines in 2 files now uncovered.

623 of 904 relevant lines covered (68.92%)

1.3 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

69.83
/src/System/Random/Array.hs
1
{-# LANGUAGE BangPatterns #-}
2
{-# LANGUAGE CPP #-}
3
{-# LANGUAGE MagicHash #-}
4
{-# LANGUAGE Trustworthy #-}
5
{-# LANGUAGE UnboxedTuples #-}
6
-- |
7
-- Module      :  System.Random.Array
8
-- Copyright   :  (c) Alexey Kuleshevich 2024
9
-- License     :  BSD-style (see the file LICENSE in the 'random' repository)
10
-- Maintainer  :  libraries@haskell.org
11
--
12
module System.Random.Array
13
  ( -- * Helper array functionality
14
    ioToST
15
  , wordSizeInBits
16
    -- ** MutableByteArray
17
  , newMutableByteArray
18
  , newPinnedMutableByteArray
19
  , freezeMutableByteArray
20
  , writeWord8
21
  , writeWord64LE
22
  , writeByteSliceWord64LE
23
  , indexWord8
24
  , indexWord64LE
25
  , indexByteSliceWord64LE
26
  , sizeOfByteArray
27
  , shortByteStringToByteArray
28
  , byteArrayToShortByteString
29
  , getSizeOfMutableByteArray
30
  , shortByteStringToByteString
31
  -- ** MutableArray
32
  , Array (..)
33
  , MutableArray (..)
34
  , newMutableArray
35
  , freezeMutableArray
36
  , writeArray
37
  , shuffleListM
38
  , shuffleListST
39
  ) where
40

41
import Control.Monad.Trans (lift, MonadTrans)
42
import Control.Monad (when)
43
import Control.Monad.ST
44
import Data.Array.Byte (ByteArray(..), MutableByteArray(..))
45
import Data.Bits
46
import Data.ByteString.Short.Internal (ShortByteString(SBS))
47
import qualified Data.ByteString.Short.Internal as SBS (fromShort)
48
import Data.Word
49
import GHC.Exts
50
import GHC.IO (IO(..))
51
import GHC.ST (ST(..))
52
import GHC.Word
53
#if __GLASGOW_HASKELL__ >= 802
54
import Data.ByteString.Internal (ByteString(PS))
55
import GHC.ForeignPtr
56
#else
57
import Data.ByteString (ByteString)
58
#endif
59

60
-- Needed for WORDS_BIGENDIAN
61
#include "MachDeps.h"
62

63
-- | Number of bits in a machine 'Word' on the platform this module is compiled for.
wordSizeInBits :: Int
wordSizeInBits = finiteBitSize (maxBound :: Word)
1✔
65

66
----------------
67
-- Byte Array --
68
----------------
69

70
-- Architecture independent helpers:
71

72
-- | Size of an immutable byte array, in bytes.
sizeOfByteArray :: ByteArray -> Int
sizeOfByteArray arr =
  case arr of
    ByteArray raw# -> I# (sizeofByteArray# raw#)
2✔
74

75
-- | Lift a primitive state-token transformer that produces no value into an
-- 'ST' action returning unit.
st_ :: (State# s -> State# s) -> ST s ()
st_ step# = ST (\s0# -> case step# s0# of s1# -> (# s1#, () #))
{-# INLINE st_ #-}
78

79
-- | Reinterpret an 'IO' action as an 'ST' action over 'RealWorld' by rewrapping
-- the underlying state-passing function.
ioToST :: IO a -> ST RealWorld a
ioToST action =
  case action of
    IO run# -> ST run#
{-# INLINE ioToST #-}
82

83
-- | Allocate a new mutable byte array of the requested size in bytes, using
-- 'newByteArray#'. The initial contents are not set by this function.
newMutableByteArray :: Int -> ST s (MutableByteArray s)
newMutableByteArray (I# len#) = ST alloc
  where
    alloc s0# =
      case newByteArray# len# s0# of
        (# s1#, raw# #) -> (# s1#, MutableByteArray raw# #)
{-# INLINE newMutableByteArray #-}
89

90
-- | Allocate a new pinned mutable byte array of the requested size in bytes,
-- using 'newPinnedByteArray#'. The initial contents are not set by this function.
newPinnedMutableByteArray :: Int -> ST s (MutableByteArray s)
newPinnedMutableByteArray (I# len#) = ST allocPinned
  where
    allocPinned s0# =
      case newPinnedByteArray# len# s0# of
        (# s1#, raw# #) -> (# s1#, MutableByteArray raw# #)
{-# INLINE newPinnedMutableByteArray #-}
96

97
-- | Convert a mutable byte array into an immutable one with
-- 'unsafeFreezeByteArray#'. The mutable array should not be modified after
-- this call, since the frozen array refers to the same memory.
freezeMutableByteArray :: MutableByteArray s -> ST s ByteArray
freezeMutableByteArray (MutableByteArray mut#) = ST freeze
  where
    freeze s0# =
      case unsafeFreezeByteArray# mut# s0# of
        (# s1#, frozen# #) -> (# s1#, ByteArray frozen# #)
2✔
102

103
-- | Write a single byte into the mutable byte array at the given byte offset.
-- The offset is passed straight to 'writeWord8Array#'; this function does not
-- validate it.
writeWord8 :: MutableByteArray s -> Int -> Word8 -> ST s ()
writeWord8 (MutableByteArray dst#) (I# off#) (W8# byte#) =
  st_ $ \s# -> writeWord8Array# dst# off# byte# s#
{-# INLINE writeWord8 #-}
106

107
-- | Write the low-order bytes of a 'Word64', least significant byte first,
-- into the byte offsets @[fromByteIx, toByteIx)@ of the mutable byte array.
-- Nothing is written when @fromByteIx >= toByteIx@.
writeByteSliceWord64LE :: MutableByteArray s -> Int -> Int -> Word64 -> ST s ()
writeByteSliceWord64LE mba fromByteIx toByteIx = loop fromByteIx
  where
    loop !ix !rest
      | ix >= toByteIx = pure ()
      | otherwise = do
          -- The truncating conversion keeps only the lowest byte.
          writeWord8 mba ix (fromIntegral rest :: Word8)
          loop (ix + 1) (rest `shiftR` 8)
{-# INLINE writeByteSliceWord64LE #-}
115

116
-- | Read a single byte from an immutable byte array. The offset is passed
-- straight to 'indexWord8Array#'; this function does not validate it.
indexWord8 ::
     ByteArray
  -> Int -- ^ Offset into immutable byte array in number of bytes
  -> Word8
indexWord8 arr ix =
  case arr of
    ByteArray src# ->
      case ix of
        I# off# -> W8# (indexWord8Array# src# off#)
{-# INLINE indexWord8 #-}
123

124
-- | Read eight bytes starting at the given byte offset and combine them into a
-- 'Word64', treating the first byte as the least significant one.
--
-- On big endian machines and on GHC versions older than 8.6 the value is
-- assembled one byte at a time; otherwise unaligned word-sized reads are used.
indexWord64LE ::
     ByteArray
  -> Int -- ^ Offset into immutable byte array in number of bytes
  -> Word64
#if defined WORDS_BIGENDIAN || !(__GLASGOW_HASKELL__ >= 806)
indexWord64LE src off = indexByteSliceWord64LE src off (off + 8)
#else
indexWord64LE (ByteArray src#) (I# off#) =
  if wordSizeInBits == 64
    then W64# (indexWord8ArrayAsWord64# src# off#)
    else
      -- 32-bit word size: read the two halves separately and glue them together.
      let !lo = W32# (indexWord8ArrayAsWord32# src# off#)
          !hi = W32# (indexWord8ArrayAsWord32# src# (off# +# 4#))
       in fromIntegral lo .|. (fromIntegral hi `shiftL` 32)
#endif
{-# INLINE indexWord64LE #-}
139

140
-- | Combine the bytes at offsets @[fromByteIx, toByteIx)@ into a 'Word64',
-- treating the first byte of the slice as the least significant one.
indexByteSliceWord64LE ::
     ByteArray
  -> Int -- ^ Starting offset in number of bytes
  -> Int -- ^ Ending offset in number of bytes
  -> Word64
indexByteSliceWord64LE ba fromByteIx toByteIx = collect fromByteIx 0
  where
    leftover = (toByteIx - fromByteIx) `rem` 8
    -- Number of bits needed to push the first byte read into the most
    -- significant position before the final byte swap.
    padBits
      | leftover == 0 = 0
      | otherwise = 8 * (8 - leftover)
    collect ix !acc
      | ix >= toByteIx = byteSwap64 (acc `shiftL` padBits)
      | otherwise =
          collect (ix + 1) ((acc `shiftL` 8) .|. fromIntegral (indexWord8 ba ix))
{-# INLINE indexByteSliceWord64LE #-}
153

154
-- | Write a 'Word64' as eight bytes, least significant byte first, starting at
-- the given byte offset.
--
-- On big endian machines we need to write one byte at a time for consistency with
-- little endian machines. Also, GHC versions prior to 8.6 do not have primops that can
-- write with a byte offset, e.g. writeWord8ArrayAsWord64# and writeWord8ArrayAsWord32#,
-- so there we must also fall back to writing one byte at a time. Such a fallback results
-- in about a 3 times slowdown, which is not the end of the world.
writeWord64LE ::
     MutableByteArray s
  -> Int -- ^ Offset into mutable byte array in number of bytes
  -> Word64 -- ^ 8 bytes that will be written into the supplied array
  -> ST s ()
#if defined WORDS_BIGENDIAN || !(__GLASGOW_HASKELL__ >= 806)
writeWord64LE dst off value =
  writeByteSliceWord64LE dst off (off + 8) value
#else
writeWord64LE (MutableByteArray dst#) (I# off#) value@(W64# value#)
  | wordSizeInBits /= 64 =
      -- 32-bit word size: split the value and write the two halves separately.
      let !(W32# lo#) = fromIntegral value
          !(W32# hi#) = fromIntegral (value `shiftR` 32)
       in st_ (writeWord8ArrayAsWord32# dst# off# lo#)
            >> st_ (writeWord8ArrayAsWord32# dst# (off# +# 4#) hi#)
  | otherwise = st_ (writeWord8ArrayAsWord64# dst# off# value#)
#endif
{-# INLINE writeWord64LE #-}
177

178
-- | Size of a mutable byte array in bytes. On GHC 8.2 and newer the size is
-- queried with 'getSizeofMutableByteArray#'; older compilers use
-- 'sizeofMutableByteArray#'.
getSizeOfMutableByteArray :: MutableByteArray s -> ST s Int
getSizeOfMutableByteArray (MutableByteArray mut#) =
#if __GLASGOW_HASKELL__ >=802
  ST (\s0# ->
    let !(# s1#, size# #) = getSizeofMutableByteArray# mut# s0#
     in (# s1#, I# size# #))
#else
  pure $! I# (sizeofMutableByteArray# mut#)
#endif
{-# INLINE getSizeOfMutableByteArray #-}
188

189
-- | Rewrap the byte array underlying a 'ShortByteString' as a 'ByteArray'
-- without copying.
shortByteStringToByteArray :: ShortByteString -> ByteArray
shortByteStringToByteArray sbs =
  case sbs of
    SBS raw# -> ByteArray raw#
{-# INLINE shortByteStringToByteArray #-}
192

193
-- | Rewrap the primitive array inside a 'ByteArray' as a 'ShortByteString'
-- without copying.
byteArrayToShortByteString :: ByteArray -> ShortByteString
byteArrayToShortByteString arr =
  case arr of
    ByteArray raw# -> SBS raw#
{-# INLINE byteArrayToShortByteString #-}
196

197
-- | Convert a 'ShortByteString' to a 'ByteString'. When the underlying memory is
-- pinned the conversion is a cast that shares the memory; otherwise the contents
-- are copied with 'SBS.fromShort'. GHC versions older than 8.2 always copy.
shortByteStringToByteString :: ShortByteString -> ByteString
shortByteStringToByteString ba =
#if __GLASGOW_HASKELL__ < 802
  SBS.fromShort ba
#else
  let !(SBS ba#) = ba in
  if isTrue# (isByteArrayPinned# ba#)
    then pinnedByteArrayToByteString ba#
    else SBS.fromShort ba
#endif
-- The pragma is kept outside of the conditional section so that it applies to
-- whichever of the two implementations above gets compiled.
{-# INLINE shortByteStringToByteString #-}

#if __GLASGOW_HASKELL__ >= 802
-- | View a byte array as a 'ByteString' covering all of its bytes, without
-- copying. Must only be applied to pinned byte arrays, because the result
-- holds a raw address into the array.
pinnedByteArrayToByteString :: ByteArray# -> ByteString
pinnedByteArrayToByteString ba# =
  PS (pinnedByteArrayToForeignPtr ba#) 0 (I# (sizeofByteArray# ba#))
{-# INLINE pinnedByteArrayToByteString #-}

-- | Build a 'ForeignPtr' that points at the payload of a byte array. Must only
-- be applied to pinned byte arrays, since 'byteArrayContents#' yields a raw
-- address.
pinnedByteArrayToForeignPtr :: ByteArray# -> ForeignPtr a
pinnedByteArrayToForeignPtr ba# =
  -- 'PlainPtr' expects a mutable byte array, hence the coercion of the immutable
  -- one. The array is stored there only so that the 'ForeignPtr' carries a
  -- reference to it next to the raw address; nothing here writes through it.
  ForeignPtr (byteArrayContents# ba#) (PlainPtr (unsafeCoerce# ba#))
{-# INLINE pinnedByteArrayToForeignPtr #-}
#endif
220

221
-----------------
222
-- Boxed Array --
223
-----------------
224

225
-- | Boxed immutable array: a lifted wrapper around the primitive 'Array#'.
data Array a = Array (Array# a)
226

227
-- | Boxed mutable array: a lifted wrapper around the primitive 'MutableArray#'.
-- The @s@ parameter is the state thread the array belongs to.
data MutableArray s a = MutableArray (MutableArray# s a)
228

229
-- | Allocate a new boxed mutable array with the given number of elements, with
-- every slot initialized to the supplied value.
newMutableArray :: Int -> a -> ST s (MutableArray s a)
newMutableArray (I# len#) initial = ST alloc
  where
    alloc s0# =
      case newArray# len# initial s0# of
        (# s1#, raw# #) -> (# s1#, MutableArray raw# #)
{-# INLINE newMutableArray #-}
235

236
-- | Convert a boxed mutable array into an immutable one with
-- 'unsafeFreezeArray#'. The mutable array should not be modified after this
-- call, since the frozen array refers to the same memory.
freezeMutableArray :: MutableArray s a -> ST s (Array a)
freezeMutableArray (MutableArray mut#) = ST freeze
  where
    freeze s0# =
      case unsafeFreezeArray# mut# s0# of
        (# s1#, frozen# #) -> (# s1#, Array frozen# #)
{-# INLINE freezeMutableArray #-}
242

243
-- | Number of elements in a boxed mutable array.
sizeOfMutableArray :: MutableArray s a -> Int
sizeOfMutableArray arr =
  case arr of
    MutableArray raw# -> I# (sizeofMutableArray# raw#)
{-# INLINE sizeOfMutableArray #-}
246

247
-- | Read the element stored at the given index of a boxed mutable array. The
-- index is passed straight to 'readArray#'; this function does not validate it.
readArray :: MutableArray s a -> Int -> ST s a
readArray (MutableArray src#) (I# ix#) =
  ST $ \s# -> readArray# src# ix# s#
{-# INLINE readArray #-}
250

251
-- | Store an element at the given index of a boxed mutable array. The index is
-- passed straight to 'writeArray#'; this function does not validate it.
writeArray :: MutableArray s a -> Int -> a -> ST s ()
writeArray (MutableArray dst#) (I# ix#) el =
  st_ $ \s# -> writeArray# dst# ix# el s#
{-# INLINE writeArray #-}
254

255
-- | Exchange the elements stored at the two given indices of a boxed mutable
-- array. Both elements are read before either slot is overwritten.
swapArray :: MutableArray s a -> Int -> Int -> ST s ()
swapArray ma i j =
  readArray ma i >>= \atI ->
    readArray ma j >>= \atJ ->
      writeArray ma j atI >> writeArray ma i atJ
{-# INLINE swapArray #-}
262

263
-- | Write the elements of the list into consecutive slots of the mutable array,
-- starting at index 0. Slots beyond the length of the list are left untouched.
--
-- The array has to be at least as long as the list. Instead of writing out of
-- bounds, which could corrupt memory or segfault, an error is raised as soon as
-- the list turns out to have more elements than the array has slots.
fillMutableArrayFromList :: MutableArray s a -> [a] -> ST s ()
fillMutableArrayFromList ma = go 0
  where
    size = sizeOfMutableArray ma
    go _ [] = pure ()
    go i (x:xs)
      | i < size = writeArray ma i x >> go (i + 1) xs
      | otherwise = error "fillMutableArrayFromList: list is longer than the array"
{-# INLINE fillMutableArrayFromList #-}
271

272
-- | Read every element of the mutable array into a list that preserves the
-- array order. The array is traversed from the last index down to 0 so that
-- the list can be built by prepending.
readListFromMutableArray :: MutableArray s a -> ST s [a]
readListFromMutableArray ma = collect lastIx []
  where
    lastIx = sizeOfMutableArray ma - 1
    collect ix !done
      | ix < 0 = pure done
      | otherwise = readArray ma ix >>= \el -> collect (ix - 1) (el : done)
{-# INLINE readListFromMutableArray #-}
282

283

284
-- | Generate the list of swap targets used for uniform shuffling of @n@ elements.
--
-- The supplied action is invoked with the arguments @1, 2, ..., n - 1@, in that
-- order, and every result is prepended to the list, so the results come out in
-- reverse order of generation. Assuming the action returns a value between @0@
-- and its argument, the elements of the resulting list fall into these ranges:
--
-- @
-- [ (0, n - 1)
-- , (0, n - 2)
-- , (0, n - 3)
-- , ...
-- , (0, 3)
-- , (0, 2)
-- , (0, 1)
-- ]
-- @
--
-- The resulting list has @n - 1@ elements; it is empty when @n <= 1@.
genSwapIndices
  :: Monad m
  => (Word -> m Word)
  -- ^ Action that generates a Word in the supplied range.
  -> Word
  -- ^ Number of elements being shuffled, which is one more than the number of
  -- indices that get generated.
  -> m [Int]
genSwapIndices genWordR n = loop 1 []
  where
    loop upper !ixs =
      if upper < n
        then do
          w <- genWordR upper
          let !ix = fromIntegral w
          loop (upper + 1) (ix : ixs)
        else pure ixs
{-# INLINE genSwapIndices #-}
312

313

314
-- | Shuffle a list with the mutable version of the Fisher-Yates algorithm.
--
-- Pseudo-random number generation in an arbitrary monad cannot, in general, be
-- interleaved with mutation in the `ST` monad. For this reason all of the swap
-- indices are generated up front with `genSwapIndices` and kept in a list, and
-- only then are the actual swaps performed on a mutable array.
--
-- Lists with fewer than two elements are returned as they are, without
-- invoking the generator.
shuffleListM :: Monad m => (Word -> m Word) -> [a] -> m [a]
shuffleListM genWordR ls
  | len <= 1 = pure ls
  | otherwise = do
    swapIxs <- genSwapIndices genWordR (fromIntegral len)
    pure (applySwaps swapIxs)
  where
    len = length ls
    applySwaps swapIxs = runST $ do
      -- Every slot is overwritten by the fill below, so the placeholder is
      -- never supposed to be observed.
      ma <- newMutableArray len $ error "Impossible: shuffleListM"
      fillMutableArrayFromList ma ls
      -- Pair the positions len - 1, len - 2, ... with the generated swap
      -- targets and perform the swaps in that order.
      mapM_ (uncurry (swapArray ma)) (zip [len - 1, len - 2 ..] swapIxs)
      readListFromMutableArray ma
{-# INLINE shuffleListM #-}
336

337
-- | A roughly 2 to 3 times more efficient version of `shuffleListM`. It is more
-- efficient because it does not pregenerate a list of indices and instead
-- generates each index on demand. As a consequence the result will differ from
-- the one produced by `shuffleListM` for the same generator, since the order in
-- which index swaps are generated is reversed.
--
-- Unfortunately, most stateful generator monads can't handle `MonadTrans`, so this version is only
-- used for implementing the pure shuffle.
shuffleListST :: (Monad (t (ST s)), MonadTrans t) => (Word -> t (ST s) Word) -> [a] -> t (ST s) [a]
shuffleListST genWordR ls
  | len <= 1 = pure ls
  | otherwise = do
     -- Every slot is overwritten by the fill below, so the placeholder is
     -- never supposed to be observed.
     ma <- lift $ newMutableArray len $ error "Impossible: shuffleListST"
     lift $ fillMutableArrayFromList ma ls

     -- Walk from the last position down to 1, swapping each position with a
     -- freshly generated index.
     let swapDownFrom i
           | i <= 0 = pure ()
           | otherwise = do
               j <- genWordR (fromIntegral i :: Word)
               lift $ swapArray ma i (fromIntegral j :: Int)
               swapDownFrom (i - 1)
     swapDownFrom (len - 1)

     lift $ readListFromMutableArray ma
  where
    len = length ls
{-# INLINE shuffleListST #-}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc