Skip to content

Commit

Permalink
agent, smp server: PostgreSQL connection pool (#1468)
Browse files Browse the repository at this point in the history
* agent, smp server: PostgreSQL connection pool

* simplify, create all connections before start

* rename

* remove imports
  • Loading branch information
epoberezkin authored Feb 25, 2025
1 parent 4dc40bd commit 1725409
Show file tree
Hide file tree
Showing 9 changed files with 100 additions and 47 deletions.
56 changes: 29 additions & 27 deletions src/Simplex/Messaging/Agent/Store/Postgres.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ module Simplex.Messaging.Agent.Store.Postgres
where

import Control.Concurrent.STM
import Control.Exception (bracketOnError, finally, onException, throwIO)
import Control.Exception (finally, onException, throwIO, uninterruptibleMask_)
import Control.Logger.Simple (logError)
import Control.Monad (void, when)
import Control.Monad
import Data.ByteString (ByteString)
import Data.Functor (($>))
import Data.Text (Text)
Expand Down Expand Up @@ -53,11 +53,24 @@ createDBStore opts migrations confirmMigrations = do
in sharedMigrateSchema dbm (dbNew st) migrations confirmMigrations

connectPostgresStore :: DBOpts -> IO DBStore
connectPostgresStore DBOpts {connstr, schema, createSchema} = do
(dbConn, dbNew) <- connectDB connstr schema createSchema -- TODO [postgres] analogue for dbBusyLoop?
dbConnection <- newMVar dbConn
dbClosed <- newTVarIO False
pure DBStore {dbConnstr = connstr, dbSchema = schema, dbConnection, dbNew, dbClosed}
connectPostgresStore DBOpts {connstr, schema, poolSize, createSchema} = do
dbSem <- newMVar ()
dbPool <- newTBQueueIO poolSize
dbClosed <- newTVarIO True
let st = DBStore {dbConnstr = connstr, dbSchema = schema, dbPoolSize = fromIntegral poolSize, dbPool, dbSem, dbNew = False, dbClosed}
dbNew <- connectPool st createSchema
pure st {dbNew}

-- uninterruptibleMask_ here and below is used here so that it is not interrupted half-way,
-- it relies on the assumption that when dbClosed = True, the queue is empty,
-- and when it is False, the queue is full (or will have connections returned to it by the threads that use them).
connectPool :: DBStore -> Bool -> IO Bool
connectPool DBStore {dbConnstr, dbSchema, dbPoolSize, dbPool, dbClosed} createSchema = uninterruptibleMask_ $ do
(conn, dbNew) <- connectDB dbConnstr dbSchema createSchema -- TODO [postgres] analogue for dbBusyLoop?
conns <- replicateM (dbPoolSize - 1) $ fst <$> connectDB dbConnstr dbSchema False
mapM_ (atomically . writeTBQueue dbPool) (conn : conns)
atomically $ writeTVar dbClosed False
pure dbNew

connectDB :: ByteString -> ByteString -> Bool -> IO (DB.Connection, Bool)
connectDB connstr schema createSchema = do
Expand Down Expand Up @@ -97,29 +110,18 @@ doesSchemaExist db schema = do
(Only schema)
pure schemaExists

-- can share with SQLite
closeDBStore :: DBStore -> IO ()
closeDBStore st@DBStore {dbClosed} =
ifM (readTVarIO dbClosed) (putStrLn "closeDBStore: already closed") $
withConnection st $ \conn -> do
DB.close conn
atomically $ writeTVar dbClosed True

openPostgresStore_ :: DBStore -> IO ()
openPostgresStore_ DBStore {dbConnstr, dbSchema, dbConnection, dbClosed} =
bracketOnError
(takeMVar dbConnection)
(tryPutMVar dbConnection)
$ \_dbConn -> do
(dbConn, _dbNew) <- connectDB dbConnstr dbSchema False
atomically $ writeTVar dbClosed False
putMVar dbConnection dbConn
closeDBStore DBStore {dbPool, dbPoolSize, dbClosed} =
ifM (readTVarIO dbClosed) (putStrLn "closeDBStore: already closed") $ uninterruptibleMask_ $ do
replicateM_ dbPoolSize $ atomically $ readTBQueue dbPool
atomically $ writeTVar dbClosed True

reopenDBStore :: DBStore -> IO ()
reopenDBStore st@DBStore {dbClosed} =
ifM (readTVarIO dbClosed) open (putStrLn "reopenDBStore: already opened")
where
open = openPostgresStore_ st
reopenDBStore st =
ifM
(readTVarIO $ dbClosed st)
(void $ connectPool st False)
(putStrLn "reopenDBStore: already opened")

-- not used with postgres client (used for ExecAgentStoreSQL, ExecChatStoreSQL)
execSQL :: PSQL.Connection -> Text -> IO [Text]
Expand Down
25 changes: 18 additions & 7 deletions src/Simplex/Messaging/Agent/Store/Postgres/Common.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TupleSections #-}

module Simplex.Messaging.Agent.Store.Postgres.Common
( DBStore (..),
Expand All @@ -11,32 +15,39 @@ module Simplex.Messaging.Agent.Store.Postgres.Common
)
where

import Control.Concurrent.MVar
import Control.Concurrent.STM
import Control.Exception (bracket)
import Data.ByteString (ByteString)
import qualified Database.PostgreSQL.Simple as PSQL
import UnliftIO.MVar
import UnliftIO.STM
import Numeric.Natural

-- TODO [postgres] use log_min_duration_statement instead of custom slow queries (SQLite's Connection type)
data DBStore = DBStore
{ dbConnstr :: ByteString,
dbSchema :: ByteString,
dbConnection :: MVar PSQL.Connection,
dbPoolSize :: Int,
dbPool :: TBQueue PSQL.Connection,
-- MVar is needed for fair pool distribution, without STM retry contention.
-- Only one thread can be blocked on STM read.
dbSem :: MVar (),
dbClosed :: TVar Bool,
dbNew :: Bool
}

data DBOpts = DBOpts
{ connstr :: ByteString,
schema :: ByteString,
poolSize :: Natural,
createSchema :: Bool
}
deriving (Show)

-- TODO [postgres] connection pool
withConnectionPriority :: DBStore -> Bool -> (PSQL.Connection -> IO a) -> IO a
withConnectionPriority DBStore {dbConnection} _priority action =
withMVar dbConnection action
{-# INLINE withConnectionPriority #-}
withConnectionPriority DBStore {dbPool, dbSem} _priority =
bracket
(withMVar dbSem $ \_ -> atomically $ readTBQueue dbPool)
(atomically . writeTBQueue dbPool)

withConnection :: DBStore -> (PSQL.Connection -> IO a) -> IO a
withConnection st = withConnectionPriority st False
Expand Down
4 changes: 2 additions & 2 deletions src/Simplex/Messaging/Client.hs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport
import Simplex.Messaging.Transport.Client (SocksAuth (..), SocksProxyWithAuth (..), TransportClientConfig (..), TransportHost (..), defaultSMPPort, defaultTcpConnectTimeout, runTransportClient)
import Simplex.Messaging.Transport.KeepAlive
import Simplex.Messaging.Util (bshow, diffToMicroseconds, ifM, liftEitherWith, raceAny_, threadDelay', tshow, whenM)
import Simplex.Messaging.Util (bshow, diffToMicroseconds, ifM, liftEitherWith, raceAny_, threadDelay', tryWriteTBQueue, tshow, whenM)
import Simplex.Messaging.Version
import System.Mem.Weak (Weak, deRefWeak)
import System.Timeout (timeout)
Expand Down Expand Up @@ -1121,7 +1121,7 @@ sendProtocolCommand_ c@ProtocolClient {client_ = PClient {sndQ}, thParams = THan

nonBlockingWriteTBQueue :: TBQueue a -> a -> IO ()
nonBlockingWriteTBQueue q x = do
sent <- atomically $ ifM (isFullTBQueue q) (pure False) (writeTBQueue q x $> True)
sent <- atomically $ tryWriteTBQueue q x
unless sent $ void $ forkIO $ atomically $ writeTBQueue q x

getResponse :: ProtocolClient v err msg -> Maybe Int -> Request err msg -> IO (Response err msg)
Expand Down
2 changes: 1 addition & 1 deletion src/Simplex/Messaging/Server.hs
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg, startOpt
mapM_ (queueEvts qEvts) . join . IM.lookup cId =<< readTVarIO cls
queueEvts qEvts (AClient _ _ c@Client {connected, sndQ = q}) =
whenM (readTVarIO connected) $ do
sent <- atomically $ ifM (isFullTBQueue q) (pure False) (writeTBQueue q ts $> True)
sent <- atomically $ tryWriteTBQueue q ts
if sent
then updateEndStats
else -- if queue is full it can block
Expand Down
36 changes: 30 additions & 6 deletions src/Simplex/Messaging/Server/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import qualified Data.Text as T
import Data.Text.Encoding (decodeLatin1, encodeUtf8)
import qualified Data.Text.IO as T
import Network.Socket (HostName)
import Numeric.Natural (Natural)
import Options.Applicative
import Simplex.Messaging.Agent.Protocol (connReqUriP')
import Simplex.Messaging.Agent.Store.Postgres (checkSchemaExists)
Expand Down Expand Up @@ -236,8 +237,14 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath =
DBOpts
{ connstr = either (const defaultDBConnStr) encodeUtf8 $ lookupValue "STORE_LOG" "db_connection" ini,
schema = either (const defaultDBSchema) encodeUtf8 $ lookupValue "STORE_LOG" "db_schema" ini,
poolSize = either (const defaultDBPoolSize) (read . T.unpack) $ lookupValue "STORE_LOG" "db_pool_size" ini,
createSchema = False
}
dbOptsIniContent :: DBOpts -> Text
dbOptsIniContent DBOpts {connstr, schema, poolSize } =
(optDisabled' (connstr == defaultDBConnStr) <> "db_connection: " <> safeDecodeUtf8 connstr <> "\n")
<> (optDisabled' (schema == defaultDBSchema) <> "db_schema: " <> safeDecodeUtf8 schema <> "\n")
<> (optDisabled' (poolSize == defaultDBPoolSize) <> "db_pool_size: " <> tshow poolSize <> "\n\n")
httpsCertFile = combine cfgPath "web.crt"
httpsKeyFile = combine cfgPath "web.key"
defaultStaticPath = combine logPath "www"
Expand Down Expand Up @@ -320,8 +327,7 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath =
\# `database`- PostgreSQL databass (requires `store_messages: journal`).\n\
\store_queues: memory\n\n\
\# Database connection settings for PostgreSQL database (`store_queues: database`).\n"
<> (optDisabled dbOptions <> "db_connection: " <> safeDecodeUtf8 (maybe defaultDBConnStr connstr dbOptions) <> "\n")
<> (optDisabled dbOptions <> "db_schema: " <> safeDecodeUtf8 (maybe defaultDBSchema schema dbOptions) <> "\n\n")
<> dbOptsIniContent dbOptions
<> "# Message storage mode: `memory` or `journal`.\n\
\store_messages: memory\n\n\
\# When store_messages is `memory`, undelivered messages are optionally saved and restored\n\
Expand Down Expand Up @@ -626,6 +632,9 @@ defaultDBConnStr = "postgresql://smp@/smp_server_store"
defaultDBSchema :: ByteString
defaultDBSchema = "smp_server"

defaultDBPoolSize :: Natural
defaultDBPoolSize = 10

defaultControlPort :: Int
defaultControlPort = 5224

Expand Down Expand Up @@ -712,7 +721,12 @@ serverPublicInfo ini = serverInfo <$!> infoValue "source_code"
(_, _, pkURI, pkFingerprint) -> Just ServerContactAddress {simplex, email, pgp = PGPKey <$> pkURI <*> pkFingerprint}

optDisabled :: Maybe a -> Text
optDisabled p = if isNothing p then "# " else ""
optDisabled = optDisabled' . isNothing
{-# INLINE optDisabled #-}

optDisabled' :: Bool -> Text
optDisabled' cond = if cond then "# " else ""
{-# INLINE optDisabled' #-}

validCountryValue :: String -> String -> Either String Text
validCountryValue field s
Expand All @@ -738,7 +752,7 @@ data StoreCmd = SCImport | SCExport | SCDelete

data InitOptions = InitOptions
{ enableStoreLog :: Bool,
dbOptions :: Maybe DBOpts,
dbOptions :: DBOpts,
dbMigrateUp :: Bool,
logStats :: Bool,
signAlgorithm :: SignAlgorithm,
Expand Down Expand Up @@ -780,7 +794,8 @@ cliCommandP cfgPath logPath iniFile =
<> short 'l'
<> help "Enable store log for persistence"
)
dbOptions <- optional dbOptsP
dbOptions <- dbOptsP
-- TODO [postgresql] remove
dbMigrateUp <-
switch
( long "db-migrate-up"
Expand Down Expand Up @@ -963,7 +978,16 @@ cliCommandP cfgPath logPath iniFile =
<> value defaultDBSchema
<> showDefault
)
pure DBOpts {connstr, schema, createSchema = False}
poolSize <-
option
auto
( long "pool-size"
<> metavar "POOL_SIZE"
<> help "Database pool size"
<> value defaultDBPoolSize
<> showDefault
)
pure DBOpts {connstr, schema, poolSize, createSchema = False}
parseConfirmMigrations :: ReadM MigrationConfirmation
parseConfirmMigrations = eitherReader $ \case
"up" -> Right MCYesUp
Expand Down
12 changes: 9 additions & 3 deletions src/Simplex/Messaging/Server/QueueStore/Postgres.hs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
Expand Down Expand Up @@ -37,11 +38,9 @@ import Simplex.Messaging.Agent.Client (withLockMap)
import Simplex.Messaging.Agent.Lock (Lock)
import Simplex.Messaging.Agent.Store.Postgres (createDBStore)
import Simplex.Messaging.Agent.Store.Postgres.Common
import Simplex.Messaging.Agent.Store.Postgres.DB (FromField (..), ToField (..), blobFieldDecoder)
import Simplex.Messaging.Agent.Store.Postgres.DB (FromField (..), ToField (..))
import qualified Simplex.Messaging.Agent.Store.Postgres.DB as DB
import Simplex.Messaging.Agent.Store.Shared (MigrationConfirmation)
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Protocol
import Simplex.Messaging.Server.QueueStore
import Simplex.Messaging.Server.QueueStore.Postgres.Migrations (serverMigrations)
Expand All @@ -52,6 +51,11 @@ import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Util (firstRow, ifM, tshow, ($>>), ($>>=), (<$$), (<$$>))
import System.Exit (exitFailure)
import System.IO (hFlush, stdout)
#if !defined(dbPostgres)
import Simplex.Messaging.Agent.Store.Postgres.DB (blobFieldDecoder)
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding.String
#endif

data PostgresQueueStore q = PostgresQueueStore
{ dbStore :: DBStore,
Expand Down Expand Up @@ -367,10 +371,12 @@ instance ToField EntityId where toField (EntityId s) = toField $ Binary s

deriving newtype instance FromField EntityId

#if !defined(dbPostgres)
instance ToField (C.DhSecret 'C.X25519) where toField = toField . Binary . C.dhBytes'

instance FromField (C.DhSecret 'C.X25519) where fromField = blobFieldDecoder strDecode

instance ToField C.APublicAuthKey where toField = toField . Binary . C.encodePubKey

instance FromField C.APublicAuthKey where fromField = blobFieldDecoder C.decodePubKey
#endif
7 changes: 7 additions & 0 deletions src/Simplex/Messaging/Util.hs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,13 @@ mapAccumLM_NonEmpty
mapAccumLM_NonEmpty f s (x :| xs) =
[(s2, x' :| xs') | (s1, x') <- f s x, (s2, xs') <- mapAccumLM_List f s1 xs]

tryWriteTBQueue :: TBQueue a -> a -> STM Bool
tryWriteTBQueue q a = do
full <- isFullTBQueue q
unless full $ writeTBQueue q a
pure $ not full
{-# INLINE tryWriteTBQueue #-}

catchAll :: IO a -> (E.SomeException -> IO a) -> IO a
catchAll = E.catch
{-# INLINE catchAll #-}
Expand Down
1 change: 1 addition & 0 deletions tests/SMPClient.hs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ testStoreDBOpts =
DBOpts
{ connstr = testServerDBConnstr,
schema = "smp_server",
poolSize = 3,
createSchema = True
}

Expand Down
4 changes: 3 additions & 1 deletion tests/ServerTests.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1019,7 +1019,8 @@ testMsgNOTExpireOnInterval =

testBlockMessageQueue :: SpecWith (ATransport, AStoreType)
testBlockMessageQueue =
it "should return BLOCKED error when queue is blocked" $ \(at@(ATransport (t :: TProxy c)), msType) -> do
-- TODO [postgres]
xit "should return BLOCKED error when queue is blocked" $ \(at@(ATransport (t :: TProxy c)), msType) -> do
g <- C.newRandom
(rId, sId) <- withSmpServerStoreLogOnMS at msType testPort $ runTest t $ \h -> do
(rPub, rKey) <- atomically $ C.generateAuthKeyPair C.SEd448 g
Expand All @@ -1028,6 +1029,7 @@ testBlockMessageQueue =
(rId1, NoEntity) #== "creates queue"
pure (rId, sId)

-- TODO [postgres] block via control port
withFile testStoreLogFile AppendMode $ \h -> B.hPutStrLn h $ strEncode $ BlockQueue rId $ BlockingInfo BRContent

withSmpServerStoreLogOnMS at msType testPort $ runTest t $ \h -> do
Expand Down

0 comments on commit 1725409

Please sign in to comment.