@@ -19,7 +19,7 @@ import Control.Concurrent.STM
19
19
import Control.Exception
20
20
import qualified Control.Exception as E
21
21
import Data.Foldable
22
- import Data.Map (Map )
22
+ import Data.Map.Strict (Map )
23
23
import qualified Data.Map.Strict as Map
24
24
import qualified System.TimeManager as T
25
25
@@ -28,9 +28,14 @@ import Imports
28
28
----------------------------------------------------------------
29
29
30
30
-- | Manager to manage the thread and the timer.
31
- data Manager = Manager T. Manager ( TVar ManagedThreads )
31
+ data Manager = Manager T. Manager ManagedThreads
32
32
33
- type ManagedThreads = Map ThreadId TimeoutHandle
33
+ -- | The set of managed threads
34
+ --
35
+ -- This is a newtype to ensure that this is always updated strictly.
36
+ newtype ManagedThreads = WrapManagedThreads
37
+ { unwrapManagedThreads :: TVar (Map ThreadId TimeoutHandle )
38
+ }
34
39
35
40
----------------------------------------------------------------
36
41
@@ -49,7 +54,7 @@ cancelTimeout ThreadWithoutTimeout = return ()
49
54
-- by 'setAction'. This allows that the action can include
50
55
-- the manager itself.
51
56
start :: T. Manager -> IO Manager
52
- start timmgr = Manager timmgr <$> newTVarIO Map. empty
57
+ start timmgr = Manager timmgr <$> newManagedThreads
53
58
54
59
----------------------------------------------------------------
55
60
@@ -70,10 +75,7 @@ stopAfter :: Manager -> IO a -> (Maybe SomeException -> IO ()) -> IO a
70
75
stopAfter (Manager _timmgr var) action cleanup = do
71
76
mask $ \ unmask -> do
72
77
ma <- try $ unmask action
73
- m <- atomically $ do
74
- m0 <- readTVar var
75
- writeTVar var Map. empty
76
- return m0
78
+ m <- atomically $ modifyManagedThreads var (\ ts -> (Map. empty, ts))
77
79
forM_ (Map. elems m) cancelTimeout
78
80
let er = either Just (const Nothing ) ma
79
81
forM_ (Map. keys m) $ \ tid ->
@@ -102,17 +104,17 @@ forkManagedUnmask (Manager _timmgr var) label io =
102
104
void $ mask_ $ forkIOWithUnmask $ \ unmask -> E. handle ignore $ do
103
105
labelMe label
104
106
tid <- myThreadId
105
- atomically $ modifyTVar var $ Map. insert tid ThreadWithoutTimeout
107
+ atomically $ modifyManagedThreads_ var $ Map. insert tid ThreadWithoutTimeout
106
108
-- We catch the exception and do not rethrow it: we don't want the
107
109
-- exception printed to stderr.
108
110
io unmask `catch` ignore
109
- atomically $ modifyTVar var $ Map. delete tid
111
+ atomically $ modifyManagedThreads_ var $ Map. delete tid
110
112
where
111
113
ignore (E. SomeException _) = return ()
112
114
113
115
waitCounter0 :: Manager -> IO ()
114
116
waitCounter0 (Manager _timmgr var) = atomically $ do
115
- m <- readTVar var
117
+ m <- getManagedThreads var
116
118
check (Map. size m == 0 )
117
119
118
120
----------------------------------------------------------------
@@ -122,5 +124,29 @@ withTimeout (Manager timmgr var) action =
122
124
T. withHandleKillThread timmgr (return () ) $ \ th -> do
123
125
tid <- myThreadId
124
126
-- overriding ThreadWithoutTimeout
125
- atomically $ modifyTVar var $ Map. insert tid $ ThreadWithTimeout th
127
+ atomically $ modifyManagedThreads_ var $ Map. insert tid $ ThreadWithTimeout th
126
128
action th
129
+
130
+ ----------------------------------------------------------------
131
+
132
+ newManagedThreads :: IO ManagedThreads
133
+ newManagedThreads = WrapManagedThreads <$> newTVarIO Map. empty
134
+
135
+ getManagedThreads :: ManagedThreads -> STM (Map ThreadId TimeoutHandle )
136
+ getManagedThreads = readTVar . unwrapManagedThreads
137
+
138
+ modifyManagedThreads
139
+ :: ManagedThreads
140
+ -> (Map ThreadId TimeoutHandle -> (Map ThreadId TimeoutHandle , a ))
141
+ -> STM a
142
+ modifyManagedThreads (WrapManagedThreads var) f = do
143
+ threads <- readTVar var
144
+ let (threads', result) = f threads
145
+ writeTVar var $! threads' -- strict update
146
+ return result
147
+
148
+ modifyManagedThreads_
149
+ :: ManagedThreads
150
+ -> (Map ThreadId TimeoutHandle -> Map ThreadId TimeoutHandle )
151
+ -> STM ()
152
+ modifyManagedThreads_ var f = modifyManagedThreads var (\ ts -> (f ts, () ))
0 commit comments