Skip to content

Commit a38646d

Browse files
committed
Fix leak in H2 manager
See `ManagedThreads`. Closes kazu-yamamoto#154.
1 parent f7c0701 commit a38646d

File tree

1 file changed

+38
-12
lines changed

1 file changed

+38
-12
lines changed

Network/HTTP2/H2/Manager.hs

+38-12
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import Control.Concurrent.STM
1919
import Control.Exception
2020
import qualified Control.Exception as E
2121
import Data.Foldable
22-
import Data.Map (Map)
22+
import Data.Map.Strict (Map)
2323
import qualified Data.Map.Strict as Map
2424
import qualified System.TimeManager as T
2525

@@ -28,9 +28,14 @@ import Imports
2828
----------------------------------------------------------------
2929

3030
-- | Manager to manage the thread and the timer.
31-
data Manager = Manager T.Manager (TVar ManagedThreads)
31+
data Manager = Manager T.Manager ManagedThreads
3232

33-
type ManagedThreads = Map ThreadId TimeoutHandle
33+
-- | The set of managed threads
34+
--
35+
-- This is a newtype to ensure that this is always updated strictly.
36+
newtype ManagedThreads = WrapManagedThreads
37+
{ unwrapManagedThreads :: TVar (Map ThreadId TimeoutHandle)
38+
}
3439

3540
----------------------------------------------------------------
3641

@@ -49,7 +54,7 @@ cancelTimeout ThreadWithoutTimeout = return ()
4954
-- by 'setAction'. This allows that the action can include
5055
-- the manager itself.
5156
start :: T.Manager -> IO Manager
52-
start timmgr = Manager timmgr <$> newTVarIO Map.empty
57+
start timmgr = Manager timmgr <$> newManagedThreads
5358

5459
----------------------------------------------------------------
5560

@@ -70,10 +75,7 @@ stopAfter :: Manager -> IO a -> (Maybe SomeException -> IO ()) -> IO a
7075
stopAfter (Manager _timmgr var) action cleanup = do
7176
mask $ \unmask -> do
7277
ma <- try $ unmask action
73-
m <- atomically $ do
74-
m0 <- readTVar var
75-
writeTVar var Map.empty
76-
return m0
78+
m <- atomically $ modifyManagedThreads var (\ts -> (Map.empty, ts))
7779
forM_ (Map.elems m) cancelTimeout
7880
let er = either Just (const Nothing) ma
7981
forM_ (Map.keys m) $ \tid ->
@@ -102,17 +104,17 @@ forkManagedUnmask (Manager _timmgr var) label io =
102104
void $ mask_ $ forkIOWithUnmask $ \unmask -> E.handle ignore $ do
103105
labelMe label
104106
tid <- myThreadId
105-
atomically $ modifyTVar var $ Map.insert tid ThreadWithoutTimeout
107+
atomically $ modifyManagedThreads_ var $ Map.insert tid ThreadWithoutTimeout
106108
-- We catch the exception and do not rethrow it: we don't want the
107109
-- exception printed to stderr.
108110
io unmask `catch` ignore
109-
atomically $ modifyTVar var $ Map.delete tid
111+
atomically $ modifyManagedThreads_ var $ Map.delete tid
110112
where
111113
ignore (E.SomeException _) = return ()
112114

113115
waitCounter0 :: Manager -> IO ()
114116
waitCounter0 (Manager _timmgr var) = atomically $ do
115-
m <- readTVar var
117+
m <- getManagedThreads var
116118
check (Map.size m == 0)
117119

118120
----------------------------------------------------------------
@@ -122,5 +124,29 @@ withTimeout (Manager timmgr var) action =
122124
T.withHandleKillThread timmgr (return ()) $ \th -> do
123125
tid <- myThreadId
124126
-- overriding ThreadWithoutTimeout
125-
atomically $ modifyTVar var $ Map.insert tid $ ThreadWithTimeout th
127+
atomically $ modifyManagedThreads_ var $ Map.insert tid $ ThreadWithTimeout th
126128
action th
129+
130+
----------------------------------------------------------------
131+
132+
newManagedThreads :: IO ManagedThreads
133+
newManagedThreads = WrapManagedThreads <$> newTVarIO Map.empty
134+
135+
getManagedThreads :: ManagedThreads -> STM (Map ThreadId TimeoutHandle)
136+
getManagedThreads = readTVar . unwrapManagedThreads
137+
138+
modifyManagedThreads
139+
:: ManagedThreads
140+
-> (Map ThreadId TimeoutHandle -> (Map ThreadId TimeoutHandle, a))
141+
-> STM a
142+
modifyManagedThreads (WrapManagedThreads var) f = do
143+
threads <- readTVar var
144+
let (threads', result) = f threads
145+
writeTVar var $! threads' -- strict update
146+
return result
147+
148+
modifyManagedThreads_
149+
:: ManagedThreads
150+
-> (Map ThreadId TimeoutHandle -> Map ThreadId TimeoutHandle)
151+
-> STM ()
152+
modifyManagedThreads_ var f = modifyManagedThreads var (\ts -> (f ts, ()))

0 commit comments

Comments
 (0)