ST で破壊的なヒープソート

ST モナドの中では、配列に対して破壊的な操作ができるので、試しにヒープソートを作ってみました。ヒープソートのアルゴリズムは、「珠玉のプログラミング」を参考にしました。

> x <- randomArray 10 100
> x
array (1,10) [(1,71),(2,27),(3,85),(4,6),(5,79),(6,8),(7,58),(8,97),(9,25),(10,89)]
> heapSort x
array (1,10) [(1,6),(2,8),(3,25),(4,27),(5,58),(6,71),(7,79),(8,85),(9,89),(10,97)]

以下がそのコードです。(UArray に指定するインデックスや値の型制約はうまく書けないみたいなので、型を決めうちにしています。)

import Control.Applicative
import Control.Monad
import Control.Monad.ST
import Data.Array.ST
import Data.Array.Unboxed
import Random

----------------------------------------------------------------

type Index = Int
type Value = Int
type UA = UArray Index Value
type SUA s = STUArray s Index Value
type PRED = Value -> Value -> Bool

----------------------------------------------------------------

heapSort :: UA -> UA
heapSort ua = runSTUArray $ heapsort ua
              
heapsort :: UA -> ST s (SUA s)
heapsort ua = do
    let (beg,end) = bounds ua
    sua <- newArray_ (beg,end) -- this sets the type of 's'
    copy ua sua beg end
    forM_ [beg+1..end] $ shiftUp sua
    forM_ [end,end-1..beg+1] $ swapAndShiftDown sua beg
    return sua
  where
    copy from to beg end = forM_ [beg..end] $ \i -> writeArray to i (from ! i)
    swapAndShiftDown arr beg idx = 
        swapAndDo arr beg (\_ _ -> True) idx (shiftDown arr beg (idx - beg))

shiftUp :: SUA s -> Index -> ST s ()
shiftUp _ 1 = return ()
shiftUp sua c = swapAndDo sua p (>) c (shiftUp sua p)
  where
    p = c `div` 2

shiftDown :: SUA s -> Index -> Index -> ST s ()
shiftDown sua p n
  | c1 > n    = return ()
  | c1 == n   = swapAndDo sua p (>) c1 (return ())
  | otherwise = do
      let c2 = c1 + 1
      xc1 <- readArray sua c1
      xc2 <- readArray sua c2
      let c = if xc1 > xc2 then c1 else c2
      swapAndDo sua p (>) c (shiftDown sua c n)
  where
    c1 = 2 * p

swapAndDo :: SUA s -> Index -> PRED -> Index -> ST s () -> ST s ()
swapAndDo sua p op c cont = do
    xp <- readArray sua p
    xc <- readArray sua c
    when (xc `op` xp) $ do
        writeArray sua c xp
        writeArray sua p xc
        cont

----------------------------------------------------------------

randomArray :: Index -> Value -> IO UA
randomArray n boundary = listArray (1,n) <$> getList
  where
      getList = replicateM n randomInt
      randomInt :: IO Int
      randomInt = getStdRandom (randomR (0,boundary))