Skip to content

Commit

Permalink
Fix the race condition in unsafeAcquireWriteAccess.
Browse files Browse the repository at this point in the history
This mainly involves rolling back previous `RWState` changes if an interrupt
happens. The code is extended with comments to explain how the code is now
interrupt-safe.
  • Loading branch information
jorisdral committed Oct 23, 2024
1 parent da4bdf1 commit 7d96c72
Showing 1 changed file with 75 additions and 37 deletions.
112 changes: 75 additions & 37 deletions src-control/Control/Concurrent/Class/MonadSTM/RWVar.hs
Original file line number Diff line number Diff line change
Expand Up @@ -75,47 +75,85 @@ withReadAccess rwvar k =
k

{-# SPECIALISE unsafeAcquireWriteAccess :: RWVar IO a -> IO a #-}
unsafeAcquireWriteAccess :: MonadSTM m => RWVar m a -> m a
unsafeAcquireWriteAccess rw@(RWVar !var) = do
may <- atomically $ readTVar var >>= \case
Reading n x
| n == 0 -> do
writeTVar var Writing
pure (Just x)
| otherwise -> do
writeTVar var (WaitingToWrite n x)
pure Nothing
WaitingToWrite n x
| n == 0 -> do
writeTVar var Writing
pure (Just x)
| otherwise -> retry
-- If we see this, someone else has already acquired a write lock, so we retry.
Writing -> retry
-- When the previous STM transaction returns Nothing, it signals that we set
-- the RWState to WaitingToWrite. All that remains is to try acquiring write
-- access again.
-- | Acquire write access. This function assumes that it runs in a masked
-- context, and that is properly paired with an 'unsafeReleaseWriteAccess'!
--
-- If multiple threads try to acquire write access concurrently, then they will
-- race for access. However, if a thread has set RWState to WaitingToWrite, then
-- it is guaranteed that the same thread will acquire write access when all
-- readers have finished. That is, other writes can not "jump the queue". When
-- the writer finishes, then all other waiting threads will race for write
-- access again.
--
-- TODO: unsafeReleaseWriteAccess will set RWState to Reading 0. In case we have
-- readers *and* writers waiting for a writer to finish, once the writer is
-- finished there will be a race. In this race, readers and writers are just as
-- likely to acquire access first. However, if we wanted to make RWVar even more
-- biased towards writers, then we could ensure that all waiting writers get
-- access before the readers get a chance. This would probably require us to
-- change RWState to represent the case where writers are waiting for a writer
-- to finish.
unsafeAcquireWriteAccess :: (MonadSTM m, MonadCatch m) => RWVar m a -> m a
unsafeAcquireWriteAccess (RWVar !var) =
-- trySetWriting is interruptible, but it is fine if it is interrupted
-- because the RWState can not be changed before the interruption.
--
-- If multiple threads try to acquire write access concurrently, then they
-- will race for access. Even if a thread has set RWState to WaitingToWrite,
-- there is no guarantee that the same thread will acquire write access
-- first when all readers have finished. However, if the writer has
-- finished, then all other waiting threads will try to get write again
-- until they succeed.
-- trySetWriting might update the RWState. There are interruptible
-- operations in the body of the bracketOnError (in waitToWrite), so async
-- exceptions can be delivered there. If an async exception happens because
-- of an interrupt, we undo the RWState change using undoWaitingToWrite.
--
-- Note that if waitToWrite is interrupted, that it is impossible for the
-- RWState to have changed from WaitingToWrite to either Reading or Writing.
-- Therefore, undoWaitingToWrite can assume that it will find WaitingToWrite
-- in the lock.
bracketOnError trySetWriting undoWaitingToWrite $
-- When Nothing is returned, it means that we set the RWState to
-- WaitingToWrite, and so we wait to acquire the final write access.
--
-- When Just is returned, we already have write access.
maybe waitToWrite pure
where
-- Try to acquire a write lock immediately, or otherwise set the internal
-- state to WaitingToWrite as soon as possible.
--
-- Note: this is interruptible
trySetWriting = atomically $ readTVar var >>= \case
Reading n x
| n == 0 -> do
writeTVar var Writing
pure (Just x)
| otherwise -> do
writeTVar var (WaitingToWrite n x)
pure Nothing
-- The following two branches are interruptible
WaitingToWrite _n _x -> retry
Writing -> retry

-- Note: this is uninterruptible
undoWaitingToWrite Nothing = atomically $ readTVar var >>= \case
Reading _n _x -> error "undoWaitingToWrite: found Reading but expected WaitingToWrite"
WaitingToWrite n x -> writeTVar var (Reading n x)
Writing -> error "undoWaitingToWrite: found Writing but expected WaitingToWrite"
undoWaitingToWrite (Just _) = error "undoWaitingToWrite: found Just but expected Nothing"

-- Wait for the number of readers to go to 0, and then finally acquire write
-- access.
--
-- TODO: unsafeReleaseWriteAccess will set RWState to Reading 0. In case we
-- have readers *and* writers waiting for a writer to finish, once the
-- writer is finished there will be a race. In this race, readers and
-- writers are just as likely to acquire access first. However, if we wanted
-- to make RWVar even more biased towards writers, then we could ensure that
-- all waiting writers get access before the readers get a chance. This
-- would probably require us to change RWState to represent the case where
-- writers are waiting for a writer to finish.
case may of
Nothing -> unsafeAcquireWriteAccess rw
Just x -> pure x
-- Note: this is interruptible
waitToWrite = atomically $ readTVar var >>= \case
Reading _n _x -> error "waitToWrite: found Reading but expected WaitingToWrite"
WaitingToWrite n x
| n == 0 -> do
writeTVar var Writing
pure x
-- This branch is interruptible
| otherwise -> retry
Writing -> error "waitToWrite: found Reading but expected Writing"

{-# SPECIALISE unsafeReleaseWriteAccess :: RWVar IO a -> a -> STM IO () #-}
-- | Release write access. This function assumes that it runs in a masked
-- context, and that is properly paired with an 'unsafeAcquireWriteAccess'!
unsafeReleaseWriteAccess :: MonadSTM m => RWVar m a -> a -> STM m ()
unsafeReleaseWriteAccess (RWVar !var) x = do
readTVar var >>= \case
Expand Down

0 comments on commit 7d96c72

Please sign in to comment.