diff --git a/src/Simplex/Messaging/Agent/Store/Postgres.hs b/src/Simplex/Messaging/Agent/Store/Postgres.hs index 1c7abb701..3f5c14b26 100644 --- a/src/Simplex/Messaging/Agent/Store/Postgres.hs +++ b/src/Simplex/Messaging/Agent/Store/Postgres.hs @@ -16,9 +16,9 @@ module Simplex.Messaging.Agent.Store.Postgres where import Control.Concurrent.STM -import Control.Exception (bracketOnError, finally, onException, throwIO) +import Control.Exception (finally, onException, throwIO, uninterruptibleMask_) import Control.Logger.Simple (logError) -import Control.Monad (void, when) +import Control.Monad import Data.ByteString (ByteString) import Data.Functor (($>)) import Data.Text (Text) @@ -53,11 +53,24 @@ createDBStore opts migrations confirmMigrations = do in sharedMigrateSchema dbm (dbNew st) migrations confirmMigrations connectPostgresStore :: DBOpts -> IO DBStore -connectPostgresStore DBOpts {connstr, schema, createSchema} = do - (dbConn, dbNew) <- connectDB connstr schema createSchema -- TODO [postgres] analogue for dbBusyLoop? - dbConnection <- newMVar dbConn - dbClosed <- newTVarIO False - pure DBStore {dbConnstr = connstr, dbSchema = schema, dbConnection, dbNew, dbClosed} +connectPostgresStore DBOpts {connstr, schema, poolSize, createSchema} = do + dbSem <- newMVar () + dbPool <- newTBQueueIO poolSize + dbClosed <- newTVarIO True + let st = DBStore {dbConnstr = connstr, dbSchema = schema, dbPoolSize = fromIntegral poolSize, dbPool, dbSem, dbNew = False, dbClosed} + dbNew <- connectPool st createSchema + pure st {dbNew} + +-- uninterruptibleMask_ here and below is used here so that it is not interrupted half-way, +-- it relies on the assumption that when dbClosed = True, the queue is empty, +-- and when it is False, the queue is full (or will have connections returned to it by the threads that use them). +connectPool :: DBStore -> Bool -> IO Bool +connectPool DBStore {dbConnstr, dbSchema, dbPoolSize, dbPool, dbClosed} createSchema = uninterruptibleMask_ $ do + (conn, dbNew) <- connectDB dbConnstr dbSchema createSchema -- TODO [postgres] analogue for dbBusyLoop? + conns <- replicateM (dbPoolSize - 1) $ fst <$> connectDB dbConnstr dbSchema False + mapM_ (atomically . writeTBQueue dbPool) (conn : conns) + atomically $ writeTVar dbClosed False + pure dbNew connectDB :: ByteString -> ByteString -> Bool -> IO (DB.Connection, Bool) connectDB connstr schema createSchema = do @@ -97,29 +110,18 @@ doesSchemaExist db schema = do (Only schema) pure schemaExists --- can share with SQLite closeDBStore :: DBStore -> IO () -closeDBStore st@DBStore {dbClosed} = - ifM (readTVarIO dbClosed) (putStrLn "closeDBStore: already closed") $ - withConnection st $ \conn -> do - DB.close conn - atomically $ writeTVar dbClosed True - -openPostgresStore_ :: DBStore -> IO () -openPostgresStore_ DBStore {dbConnstr, dbSchema, dbConnection, dbClosed} = - bracketOnError - (takeMVar dbConnection) - (tryPutMVar dbConnection) - $ \_dbConn -> do - (dbConn, _dbNew) <- connectDB dbConnstr dbSchema False - atomically $ writeTVar dbClosed False - putMVar dbConnection dbConn +closeDBStore DBStore {dbPool, dbPoolSize, dbClosed} = + ifM (readTVarIO dbClosed) (putStrLn "closeDBStore: already closed") $ uninterruptibleMask_ $ do + replicateM_ dbPoolSize $ atomically $ readTBQueue dbPool + atomically $ writeTVar dbClosed True reopenDBStore :: DBStore -> IO () -reopenDBStore st@DBStore {dbClosed} = - ifM (readTVarIO dbClosed) open (putStrLn "reopenDBStore: already opened") - where - open = openPostgresStore_ st +reopenDBStore st = + ifM + (readTVarIO $ dbClosed st) + (void $ connectPool st False) + (putStrLn "reopenDBStore: already opened") -- not used with postgres client (used for ExecAgentStoreSQL, ExecChatStoreSQL) execSQL :: PSQL.Connection -> Text -> IO [Text] diff --git a/src/Simplex/Messaging/Agent/Store/Postgres/Common.hs b/src/Simplex/Messaging/Agent/Store/Postgres/Common.hs index be14d1a5b..ee94825a4 100644 --- a/src/Simplex/Messaging/Agent/Store/Postgres/Common.hs +++ b/src/Simplex/Messaging/Agent/Store/Postgres/Common.hs @@ -1,4 +1,8 @@ +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE TupleSections #-} module Simplex.Messaging.Agent.Store.Postgres.Common ( DBStore (..), @@ -11,16 +15,22 @@ module Simplex.Messaging.Agent.Store.Postgres.Common ) where +import Control.Concurrent.MVar +import Control.Concurrent.STM +import Control.Exception (bracket) import Data.ByteString (ByteString) import qualified Database.PostgreSQL.Simple as PSQL -import UnliftIO.MVar -import UnliftIO.STM +import Numeric.Natural -- TODO [postgres] use log_min_duration_statement instead of custom slow queries (SQLite's Connection type) data DBStore = DBStore { dbConnstr :: ByteString, dbSchema :: ByteString, - dbConnection :: MVar PSQL.Connection, + dbPoolSize :: Int, + dbPool :: TBQueue PSQL.Connection, + -- MVar is needed for fair pool distribution, without STM retry contention. + -- Only one thread can be blocked on STM read. + dbSem :: MVar (), dbClosed :: TVar Bool, dbNew :: Bool } @@ -28,15 +38,16 @@ data DBStore = DBStore data DBOpts = DBOpts { connstr :: ByteString, schema :: ByteString, + poolSize :: Natural, createSchema :: Bool } deriving (Show) --- TODO [postgres] connection pool withConnectionPriority :: DBStore -> Bool -> (PSQL.Connection -> IO a) -> IO a -withConnectionPriority DBStore {dbConnection} _priority action = - withMVar dbConnection action -{-# INLINE withConnectionPriority #-} +withConnectionPriority DBStore {dbPool, dbSem} _priority = + bracket + (withMVar dbSem $ \_ -> atomically $ readTBQueue dbPool) + (atomically . writeTBQueue dbPool) withConnection :: DBStore -> (PSQL.Connection -> IO a) -> IO a withConnection st = withConnectionPriority st False diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index b1b5bfa53..df12e5fce 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -145,7 +145,7 @@ import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Transport import Simplex.Messaging.Transport.Client (SocksAuth (..), SocksProxyWithAuth (..), TransportClientConfig (..), TransportHost (..), defaultSMPPort, defaultTcpConnectTimeout, runTransportClient) import Simplex.Messaging.Transport.KeepAlive -import Simplex.Messaging.Util (bshow, diffToMicroseconds, ifM, liftEitherWith, raceAny_, threadDelay', tshow, whenM) +import Simplex.Messaging.Util (bshow, diffToMicroseconds, ifM, liftEitherWith, raceAny_, threadDelay', tryWriteTBQueue, tshow, whenM) import Simplex.Messaging.Version import System.Mem.Weak (Weak, deRefWeak) import System.Timeout (timeout) @@ -1121,7 +1121,7 @@ sendProtocolCommand_ c@ProtocolClient {client_ = PClient {sndQ}, thParams = THan nonBlockingWriteTBQueue :: TBQueue a -> a -> IO () nonBlockingWriteTBQueue q x = do - sent <- atomically $ ifM (isFullTBQueue q) (pure False) (writeTBQueue q x $> True) + sent <- atomically $ tryWriteTBQueue q x unless sent $ void $ forkIO $ atomically $ writeTBQueue q x getResponse :: ProtocolClient v err msg -> Maybe Int -> Request err msg -> IO (Response err msg) diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index e1d4d7861..67e8e227e 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -354,7 +354,7 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg, startOpt mapM_ (queueEvts qEvts) . join . IM.lookup cId =<< readTVarIO cls queueEvts qEvts (AClient _ _ c@Client {connected, sndQ = q}) = whenM (readTVarIO connected) $ do - sent <- atomically $ ifM (isFullTBQueue q) (pure False) (writeTBQueue q ts $> True) + sent <- atomically $ tryWriteTBQueue q ts if sent then updateEndStats else -- if queue is full it can block diff --git a/src/Simplex/Messaging/Server/Main.hs b/src/Simplex/Messaging/Server/Main.hs index 56f6cc414..4be4594d4 100644 --- a/src/Simplex/Messaging/Server/Main.hs +++ b/src/Simplex/Messaging/Server/Main.hs @@ -33,6 +33,7 @@ import qualified Data.Text as T import Data.Text.Encoding (decodeLatin1, encodeUtf8) import qualified Data.Text.IO as T import Network.Socket (HostName) +import Numeric.Natural (Natural) import Options.Applicative import Simplex.Messaging.Agent.Protocol (connReqUriP') import Simplex.Messaging.Agent.Store.Postgres (checkSchemaExists) @@ -236,8 +237,14 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath = DBOpts { connstr = either (const defaultDBConnStr) encodeUtf8 $ lookupValue "STORE_LOG" "db_connection" ini, schema = either (const defaultDBSchema) encodeUtf8 $ lookupValue "STORE_LOG" "db_schema" ini, + poolSize = either (const defaultDBPoolSize) (read . T.unpack) $ lookupValue "STORE_LOG" "db_pool_size" ini, createSchema = False } + dbOptsIniContent :: DBOpts -> Text + dbOptsIniContent DBOpts {connstr, schema, poolSize } = + (optDisabled' (connstr == defaultDBConnStr) <> "db_connection: " <> safeDecodeUtf8 connstr <> "\n") + <> (optDisabled' (schema == defaultDBSchema) <> "db_schema: " <> safeDecodeUtf8 schema <> "\n") + <> (optDisabled' (poolSize == defaultDBPoolSize) <> "db_pool_size: " <> tshow poolSize <> "\n\n") httpsCertFile = combine cfgPath "web.crt" httpsKeyFile = combine cfgPath "web.key" defaultStaticPath = combine logPath "www" @@ -320,8 +327,7 @@ smpServerCLI_ generateSite serveStaticFiles attachStaticFiles cfgPath logPath = \# `database`- PostgreSQL databass (requires `store_messages: journal`).\n\ \store_queues: memory\n\n\ \# Database connection settings for PostgreSQL database (`store_queues: database`).\n" - <> (optDisabled dbOptions <> "db_connection: " <> safeDecodeUtf8 (maybe defaultDBConnStr connstr dbOptions) <> "\n") - <> (optDisabled dbOptions <> "db_schema: " <> safeDecodeUtf8 (maybe defaultDBSchema schema dbOptions) <> "\n\n") + <> dbOptsIniContent dbOptions <> "# Message storage mode: `memory` or `journal`.\n\ \store_messages: memory\n\n\ \# When store_messages is `memory`, undelivered messages are optionally saved and restored\n\ @@ -626,6 +632,9 @@ defaultDBConnStr = "postgresql://smp@/smp_server_store" defaultDBSchema :: ByteString defaultDBSchema = "smp_server" +defaultDBPoolSize :: Natural +defaultDBPoolSize = 10 + defaultControlPort :: Int defaultControlPort = 5224 @@ -712,7 +721,12 @@ serverPublicInfo ini = serverInfo <$!> infoValue "source_code" (_, _, pkURI, pkFingerprint) -> Just ServerContactAddress {simplex, email, pgp = PGPKey <$> pkURI <*> pkFingerprint} optDisabled :: Maybe a -> Text -optDisabled p = if isNothing p then "# " else "" +optDisabled = optDisabled' . isNothing +{-# INLINE optDisabled #-} + +optDisabled' :: Bool -> Text +optDisabled' cond = if cond then "# " else "" +{-# INLINE optDisabled' #-} validCountryValue :: String -> String -> Either String Text validCountryValue field s @@ -738,7 +752,7 @@ data StoreCmd = SCImport | SCExport | SCDelete data InitOptions = InitOptions { enableStoreLog :: Bool, - dbOptions :: Maybe DBOpts, + dbOptions :: DBOpts, dbMigrateUp :: Bool, logStats :: Bool, signAlgorithm :: SignAlgorithm, @@ -780,7 +794,8 @@ cliCommandP cfgPath logPath iniFile = <> short 'l' <> help "Enable store log for persistence" ) - dbOptions <- optional dbOptsP + dbOptions <- dbOptsP + -- TODO [postgresql] remove dbMigrateUp <- switch ( long "db-migrate-up" @@ -963,7 +978,16 @@ cliCommandP cfgPath logPath iniFile = <> value defaultDBSchema <> showDefault ) - pure DBOpts {connstr, schema, createSchema = False} + poolSize <- + option + auto + ( long "pool-size" + <> metavar "POOL_SIZE" + <> help "Database pool size" + <> value defaultDBPoolSize + <> showDefault + ) + pure DBOpts {connstr, schema, poolSize, createSchema = False} parseConfirmMigrations :: ReadM MigrationConfirmation parseConfirmMigrations = eitherReader $ \case "up" -> Right MCYesUp diff --git a/src/Simplex/Messaging/Server/QueueStore/Postgres.hs b/src/Simplex/Messaging/Server/QueueStore/Postgres.hs index edb9f3dc0..376858ce5 100644 --- a/src/Simplex/Messaging/Server/QueueStore/Postgres.hs +++ b/src/Simplex/Messaging/Server/QueueStore/Postgres.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE BangPatterns #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DerivingStrategies #-} @@ -37,11 +38,9 @@ import Simplex.Messaging.Agent.Client (withLockMap) import Simplex.Messaging.Agent.Lock (Lock) import Simplex.Messaging.Agent.Store.Postgres (createDBStore) import Simplex.Messaging.Agent.Store.Postgres.Common -import Simplex.Messaging.Agent.Store.Postgres.DB (FromField (..), ToField (..), blobFieldDecoder) +import Simplex.Messaging.Agent.Store.Postgres.DB (FromField (..), ToField (..)) import qualified Simplex.Messaging.Agent.Store.Postgres.DB as DB import Simplex.Messaging.Agent.Store.Shared (MigrationConfirmation) -import qualified Simplex.Messaging.Crypto as C -import Simplex.Messaging.Encoding.String import Simplex.Messaging.Protocol import Simplex.Messaging.Server.QueueStore import Simplex.Messaging.Server.QueueStore.Postgres.Migrations (serverMigrations) @@ -52,6 +51,11 @@ import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Util (firstRow, ifM, tshow, ($>>), ($>>=), (<$$), (<$$>)) import System.Exit (exitFailure) import System.IO (hFlush, stdout) +#if !defined(dbPostgres) +import Simplex.Messaging.Agent.Store.Postgres.DB (blobFieldDecoder) +import qualified Simplex.Messaging.Crypto as C +import Simplex.Messaging.Encoding.String +#endif data PostgresQueueStore q = PostgresQueueStore { dbStore :: DBStore, @@ -367,6 +371,7 @@ instance ToField EntityId where toField (EntityId s) = toField $ Binary s deriving newtype instance FromField EntityId +#if !defined(dbPostgres) instance ToField (C.DhSecret 'C.X25519) where toField = toField . Binary . C.dhBytes' instance FromField (C.DhSecret 'C.X25519) where fromField = blobFieldDecoder strDecode @@ -374,3 +379,4 @@ instance FromField (C.DhSecret 'C.X25519) where fromField = blobFieldDecoder str instance ToField C.APublicAuthKey where toField = toField . Binary . C.encodePubKey instance FromField C.APublicAuthKey where fromField = blobFieldDecoder C.decodePubKey +#endif diff --git a/src/Simplex/Messaging/Util.hs b/src/Simplex/Messaging/Util.hs index 240a6ba5a..2d92b4b5e 100644 --- a/src/Simplex/Messaging/Util.hs +++ b/src/Simplex/Messaging/Util.hs @@ -156,6 +156,13 @@ mapAccumLM_NonEmpty mapAccumLM_NonEmpty f s (x :| xs) = [(s2, x' :| xs') | (s1, x') <- f s x, (s2, xs') <- mapAccumLM_List f s1 xs] +tryWriteTBQueue :: TBQueue a -> a -> STM Bool +tryWriteTBQueue q a = do + full <- isFullTBQueue q + unless full $ writeTBQueue q a + pure $ not full +{-# INLINE tryWriteTBQueue #-} + catchAll :: IO a -> (E.SomeException -> IO a) -> IO a catchAll = E.catch {-# INLINE catchAll #-} diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index e19d81f76..28639b0a5 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -69,6 +69,7 @@ testStoreDBOpts = DBOpts { connstr = testServerDBConnstr, schema = "smp_server", + poolSize = 3, createSchema = True } diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index 6c2199795..5160070f0 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -1019,7 +1019,8 @@ testMsgNOTExpireOnInterval = testBlockMessageQueue :: SpecWith (ATransport, AStoreType) testBlockMessageQueue = - it "should return BLOCKED error when queue is blocked" $ \(at@(ATransport (t :: TProxy c)), msType) -> do + -- TODO [postgres] + xit "should return BLOCKED error when queue is blocked" $ \(at@(ATransport (t :: TProxy c)), msType) -> do g <- C.newRandom (rId, sId) <- withSmpServerStoreLogOnMS at msType testPort $ runTest t $ \h -> do (rPub, rKey) <- atomically $ C.generateAuthKeyPair C.SEd448 g @@ -1028,6 +1029,7 @@ testBlockMessageQueue = (rId1, NoEntity) #== "creates queue" pure (rId, sId) + -- TODO [postgres] block via control port withFile testStoreLogFile AppendMode $ \h -> B.hPutStrLn h $ strEncode $ BlockQueue rId $ BlockingInfo BRContent withSmpServerStoreLogOnMS at msType testPort $ runTest t $ \h -> do