@@ -6,12 +6,13 @@ use std::collections::HashMap;
6
6
use std:: error:: Error ;
7
7
use std:: fmt;
8
8
use std:: hash:: { BuildHasher , Hash , Hasher } ;
9
- use std:: sync:: atomic:: { AtomicBool , Ordering } ;
9
+ use std:: sync:: atomic:: { AtomicBool , AtomicUsize , Ordering } ;
10
10
use std:: sync:: Arc ;
11
11
use std:: time:: Duration ;
12
12
13
13
use tokio:: sync:: { Mutex , RwLock } ;
14
14
use tokio:: task:: JoinHandle ;
15
+ use tokio:: time:: Instant ;
15
16
use tracing:: info;
16
17
17
18
type InnerLockTable < K > = HashMap < K , Arc < tokio:: sync:: Mutex < ( ) > > > ;
@@ -22,6 +23,7 @@ pub struct MutexTable<K: Hash> {
22
23
_k : std:: marker:: PhantomData < K > ,
23
24
_cleaner : JoinHandle < ( ) > ,
24
25
stop : Arc < AtomicBool > ,
26
+ size : Arc < AtomicUsize > ,
25
27
}
26
28
27
29
#[ derive( Debug ) ]
@@ -46,6 +48,7 @@ impl<K: Hash + std::cmp::Eq + Send + Sync + 'static> MutexTable<K> {
46
48
shard_size : usize ,
47
49
cleanup_period : Duration ,
48
50
cleanup_initial_delay : Duration ,
51
+ cleanup_entries_threshold : usize ,
49
52
) -> Self {
50
53
let lock_table: Arc < Vec < RwLock < InnerLockTable < K > > > > = Arc :: new (
51
54
( 0 ..num_shards)
@@ -56,19 +59,29 @@ impl<K: Hash + std::cmp::Eq + Send + Sync + 'static> MutexTable<K> {
56
59
let cloned = lock_table. clone ( ) ;
57
60
let stop = Arc :: new ( AtomicBool :: new ( false ) ) ;
58
61
let stop_cloned = stop. clone ( ) ;
62
+ let size: Arc < AtomicUsize > = Arc :: new ( AtomicUsize :: new ( 0 ) ) ;
63
+ let size_cloned = size. clone ( ) ;
59
64
Self {
60
65
random_state : RandomState :: new ( ) ,
61
66
lock_table,
62
67
_k : std:: marker:: PhantomData { } ,
63
68
_cleaner : tokio:: spawn ( async move {
64
69
tokio:: time:: sleep ( cleanup_initial_delay) . await ;
70
+ let mut previous_cleanup_instant = Instant :: now ( ) ;
65
71
while !stop_cloned. load ( Ordering :: SeqCst ) {
66
- Self :: cleanup ( cloned. clone ( ) ) ;
67
- tokio:: time:: sleep ( cleanup_period) . await ;
72
+ if size_cloned. load ( Ordering :: SeqCst ) >= cleanup_entries_threshold
73
+ || previous_cleanup_instant. elapsed ( ) >= cleanup_period
74
+ {
75
+ let num_removed = Self :: cleanup ( cloned. clone ( ) ) ;
76
+ size_cloned. fetch_sub ( num_removed, Ordering :: SeqCst ) ;
77
+ previous_cleanup_instant = Instant :: now ( ) ;
78
+ }
79
+ tokio:: time:: sleep ( Duration :: from_secs ( 1 ) ) . await ;
68
80
}
69
81
info ! ( "Stopping mutex table cleanup!" ) ;
70
82
} ) ,
71
83
stop,
84
+ size,
72
85
}
73
86
}
74
87
@@ -78,10 +91,16 @@ impl<K: Hash + std::cmp::Eq + Send + Sync + 'static> MutexTable<K> {
78
91
shard_size,
79
92
Duration :: from_secs ( 10 ) ,
80
93
Duration :: from_secs ( 10 ) ,
94
+ 10_000 ,
81
95
)
82
96
}
83
97
84
- pub fn cleanup ( lock_table : Arc < Vec < RwLock < InnerLockTable < K > > > > ) {
98
+ pub fn size ( & self ) -> usize {
99
+ self . size . load ( Ordering :: SeqCst )
100
+ }
101
+
102
+ pub fn cleanup ( lock_table : Arc < Vec < RwLock < InnerLockTable < K > > > > ) -> usize {
103
+ let mut num_removed: usize = 0 ;
85
104
for shard in lock_table. iter ( ) {
86
105
let map = shard. try_write ( ) ;
87
106
if map. is_err ( ) {
@@ -93,12 +112,18 @@ impl<K: Hash + std::cmp::Eq + Send + Sync + 'static> MutexTable<K> {
93
112
// This check is also likely sufficient e.g. you don't even need try_lock below, but keeping it just in case
94
113
if Arc :: strong_count ( v) == 1 {
95
114
let mutex_guard = v. try_lock ( ) ;
96
- mutex_guard. is_err ( )
115
+ if mutex_guard. is_ok ( ) {
116
+ num_removed += 1 ;
117
+ false
118
+ } else {
119
+ true
120
+ }
97
121
} else {
98
122
true
99
123
}
100
124
} ) ;
101
125
}
126
+ num_removed
102
127
}
103
128
104
129
fn get_lock_idx ( & self , key : & K ) -> usize {
@@ -144,7 +169,10 @@ impl<K: Hash + std::cmp::Eq + Send + Sync + 'static> MutexTable<K> {
144
169
let element = {
145
170
let mut map = self . lock_table [ lock_idx] . write ( ) . await ;
146
171
map. entry ( k)
147
- . or_insert_with ( || Arc :: new ( Mutex :: new ( ( ) ) ) )
172
+ . or_insert_with ( || {
173
+ self . size . fetch_add ( 1 , Ordering :: SeqCst ) ;
174
+ Arc :: new ( Mutex :: new ( ( ) ) )
175
+ } )
148
176
. clone ( )
149
177
} ;
150
178
LockGuard ( element. lock_owned ( ) . await )
@@ -171,7 +199,10 @@ impl<K: Hash + std::cmp::Eq + Send + Sync + 'static> MutexTable<K> {
171
199
. try_write ( )
172
200
. map_err ( |_| TryAcquireLockError :: LockTableLocked ) ?;
173
201
map. entry ( k)
174
- . or_insert_with ( || Arc :: new ( Mutex :: new ( ( ) ) ) )
202
+ . or_insert_with ( || {
203
+ self . size . fetch_add ( 1 , Ordering :: SeqCst ) ;
204
+ Arc :: new ( Mutex :: new ( ( ) ) )
205
+ } )
175
206
. clone ( )
176
207
} ;
177
208
let lock = element. try_lock_owned ( ) ;
@@ -225,8 +256,13 @@ async fn test_mutex_table_concurrent_in_same_bucket() {
225
256
#[ tokio:: test]
226
257
async fn test_mutex_table ( ) {
227
258
// Disable bg cleanup with Duration.MAX for initial delay
228
- let mutex_table =
229
- MutexTable :: < String > :: new_with_cleanup ( 1 , 128 , Duration :: from_secs ( 10 ) , Duration :: MAX ) ;
259
+ let mutex_table = MutexTable :: < String > :: new_with_cleanup (
260
+ 1 ,
261
+ 128 ,
262
+ Duration :: from_secs ( 10 ) ,
263
+ Duration :: MAX ,
264
+ 1000 ,
265
+ ) ;
230
266
let john1 = mutex_table. try_acquire_lock ( "john" . to_string ( ) ) ;
231
267
assert ! ( john1. is_ok( ) ) ;
232
268
let john2 = mutex_table. try_acquire_lock ( "john" . to_string ( ) ) ;
@@ -259,6 +295,7 @@ async fn test_mutex_table_bg_cleanup() {
259
295
128 ,
260
296
Duration :: from_secs ( 5 ) ,
261
297
Duration :: from_secs ( 1 ) ,
298
+ 1000 ,
262
299
) ;
263
300
let lock1 = mutex_table. try_acquire_lock ( "lock1" . to_string ( ) ) ;
264
301
let lock2 = mutex_table. try_acquire_lock ( "lock2" . to_string ( ) ) ;
@@ -296,3 +333,49 @@ async fn test_mutex_table_bg_cleanup() {
296
333
assert ! ( locked. is_empty( ) ) ;
297
334
}
298
335
}
336
+
337
+ #[ tokio:: test( flavor = "current_thread" , start_paused = true ) ]
338
+ async fn test_mutex_table_bg_cleanup_with_size_threshold ( ) {
339
+ // set up the table to never trigger cleanup because of time period but only size threshold
340
+ let mutex_table =
341
+ MutexTable :: < String > :: new_with_cleanup ( 1 , 128 , Duration :: MAX , Duration :: from_secs ( 1 ) , 5 ) ;
342
+ let lock1 = mutex_table. try_acquire_lock ( "lock1" . to_string ( ) ) ;
343
+ let lock2 = mutex_table. try_acquire_lock ( "lock2" . to_string ( ) ) ;
344
+ let lock3 = mutex_table. try_acquire_lock ( "lock3" . to_string ( ) ) ;
345
+ let lock4 = mutex_table. try_acquire_lock ( "lock4" . to_string ( ) ) ;
346
+ let lock5 = mutex_table. try_acquire_lock ( "lock5" . to_string ( ) ) ;
347
+ assert ! ( lock1. is_ok( ) ) ;
348
+ assert ! ( lock2. is_ok( ) ) ;
349
+ assert ! ( lock3. is_ok( ) ) ;
350
+ assert ! ( lock4. is_ok( ) ) ;
351
+ assert ! ( lock5. is_ok( ) ) ;
352
+ // Trigger cleanup
353
+ MutexTable :: cleanup ( mutex_table. lock_table . clone ( ) ) ;
354
+ // Try acquiring locks again, these should still fail because locks have not been released
355
+ let lock11 = mutex_table. try_acquire_lock ( "lock1" . to_string ( ) ) ;
356
+ let lock22 = mutex_table. try_acquire_lock ( "lock2" . to_string ( ) ) ;
357
+ let lock33 = mutex_table. try_acquire_lock ( "lock3" . to_string ( ) ) ;
358
+ let lock44 = mutex_table. try_acquire_lock ( "lock4" . to_string ( ) ) ;
359
+ let lock55 = mutex_table. try_acquire_lock ( "lock5" . to_string ( ) ) ;
360
+ assert ! ( lock11. is_err( ) ) ;
361
+ assert ! ( lock22. is_err( ) ) ;
362
+ assert ! ( lock33. is_err( ) ) ;
363
+ assert ! ( lock44. is_err( ) ) ;
364
+ assert ! ( lock55. is_err( ) ) ;
365
+ assert_eq ! ( mutex_table. size( ) , 5 ) ;
366
+ // drop all locks
367
+ drop ( lock1) ;
368
+ drop ( lock2) ;
369
+ drop ( lock3) ;
370
+ drop ( lock4) ;
371
+ drop ( lock5) ;
372
+ tokio:: task:: yield_now ( ) . await ;
373
+ // Wait for bg cleanup to be triggered because of size threshold
374
+ tokio:: time:: advance ( Duration :: from_secs ( 5 ) ) . await ;
375
+ tokio:: task:: yield_now ( ) . await ;
376
+ assert_eq ! ( mutex_table. size( ) , 0 ) ;
377
+ for entry in mutex_table. lock_table . iter ( ) {
378
+ let locked = entry. read ( ) . await ;
379
+ assert ! ( locked. is_empty( ) ) ;
380
+ }
381
+ }
0 commit comments