We wish to simulate exponential decay on a set of discrete elements over time.
We run our simulation at a frequency of \(\frac{1\ \text{s}}{\delta t}\ \text{Hz}\). Each simulation step happens \(\delta t\) after the previous.
Our simulation speed relates to the half-life \(t_{1/2}\) with the coefficient \(k =\frac{t_{1/2}}{\delta t}\).
On average, we expect to decay half of the remaining elements every \(t_{1/2}\), or on every \(k\) simulation steps.
On each simulation step we expect to decay \(N_t (1 - 2^{-1/k})\), where \(N_t\) is the number of remaining elements.
Proof:
$$ \begin{split} decay(N) & = N - N (1 - 2^{-1/k}) \\ & = N\ 2^{-1/k} \\ decay^{\circ k}(N) & = N\ 2^{-k/k} = \frac{N}{2}\ \square \end{split} $$
At each simulation step, each element has a probability \(p = 2^{-1/k}\) to survive the step, and a probability \(q = 1-p\) to decay.
Finally, on each simulation step, we decay \(\delta N\) of the remaining elements. \(\delta N\) is a binomially distributed random variable:
$$ \begin{split} \delta N & \sim B(N_t, q) \\ E[\delta N] & = N_t\ q = N_t (1 - 2^{-1/k}) \end{split} $$
Talk is cheap, show me the code
Here is an abridged snippet that shows the main logic:
import Control.Lens.Operators
import Control.Monad.Random.Class
import Control.Monad.State.Strict qualified as State
import Data.Generics.Labels ()
import Data.Sequence (Seq)
import Data.Sequence qualified as Seq
import System.Random
data DecayStrategy
= OldestFirst
| NewestFirst
| RandomElements
data State a = State
{ stdGen :: StdGen
, halfLife :: Double
, decayStrategy :: DecayStrategy
, elements :: Seq a
}
deriving (Generic)
decay :: Double -> State a -> State a
decay dt = State.execState do
State{..} <- State.get
let k = halfLife / dt
n = Seq.length elements
p = 2 ** (-1 / k)
q = 1 - p
dn <- sampleState $ Binomial n q
case decayStrategy of
OldestFirst -> #elements %= Seq.drop dn
NewestFirst -> #elements %= Seq.take (n - dn)
RandomElements -> do
indices <- sampleStateRVar $ shuffleNofM dn n [0 .. n - 1]
#elements %= flip (foldr Seq.deleteAt) (reverse . List.sort $ indices)
Let’s test it:
import Test.Hspec
import Test.QuickCheck
data TestCase = TestCase {state :: State (), steps :: Natural}
deriving (Show)
instance Arbitrary TestCase where
arbitrary = do
state <- arbitrary
steps <- fromIntegral <$> chooseInteger (10, 100)
pure TestCase{..}
instance (Arbitrary a) => Arbitrary (State a) where
arbitrary = do
halfLife <- choose (1e-6, 1e6)
decayStrategy <- arbitraryBoundedEnum
n <- chooseInt (0, 1000)
elements <- Seq.fromList <$> replicateM n arbitrary
pure State{stdGen = mkStdGen 0, ..}
decaysCorrectly :: TestCase -> Expectation
decaysCorrectly TestCase{..} = fromRational (abs avgDeviation) `shouldBeLT` (10 * sigma + eps)
where
dt = state.halfLife / fromIntegral steps
n0 = fromIntegral $ Seq.length state.elements
seeds = Seq.fromList [0 .. 1000]
deviations = flip fmap seeds $ \seed -> flip evalState state do
#stdGen .= mkStdGen seed
replicateM_ (fromIntegral steps) . State.modify . decay $ abs dt
actualN <- fromIntegral <$> State.gets (Seq.length . (.elements))
pure $ n0 `div` 2 - actualN
cleanedDeviations = Seq.drop 3 $ Seq.reverse $ Seq.drop 3 $ Seq.sort deviations
avgDeviation = sum cleanedDeviations % fromIntegral (Seq.length cleanedDeviations)
variance = n0 % fromIntegral (4 * length deviations)
sigma = sqrt $ fromRational variance
eps = 1e-6
spec :: Spec
spec = it "decays half of the remaining elements every half-life" . property $ decaysCorrectly
Let’s plot it:
import Codec.Picture.Gif
import Codec.Picture.Types
main :: IO ()
main =
either error id . writeGifImages "decay.gif" LoopingForever $
[(greyPalette, round $ dt * 100, frame) | frame <- take steps frames]
where
imageWidth = 50
imageHeight = 50
pixelScale = 20
state =
State
{ stdGen = mkStdGen 0
, halfLife = 2
, decayStrategy = OldestFirst
, elements = Seq.fromList [(x, y) | y <- [0 .. imageHeight - 1], x <- [0 .. imageWidth - 1]]
}
duration = 10
dt = 0.05
steps = round $ duration / dt
pixelAt s x y = if HashSet.member (x `div` pixelScale, y `div` pixelScale) s then 20 else 200
frame State{elements} =
let set = HashSet.fromList $ toList elements
in generateImage
(pixelAt set)
(imageWidth * pixelScale)
(imageHeight * pixelScale)
frames = frame <$> iterate (decay dt) state
Rendering 10 s of decay using the OldestFirst
, NewestFirst
, and RandomElements
strategies with \(N_0 = 2500, t_{1/2}=2\ \text{s}, dt = 0.05\ \text{s}\). Black pixels represent remaining elements arranged in a 2D grid.