Skip to content

Commit 3a79c2f

Browse files
Lorak-mmk and wprzytula committed
Use shard from query plan during execution
Co-authored-by: Wojciech Przytuła <[email protected]>
1 parent 536d0a2 commit 3a79c2f

File tree

4 files changed

+33
-79
lines changed

4 files changed

+33
-79
lines changed

scylla/src/transport/connection_pool.rs

+15-12
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#[cfg(feature = "cloud")]
22
use crate::cloud::set_ssl_config_for_scylla_cloud_host;
33

4-
use crate::routing::{Shard, ShardCount, Sharder, Token};
4+
use crate::routing::{Shard, ShardCount, Sharder};
55
use crate::transport::errors::QueryError;
66
use crate::transport::{
77
connection,
@@ -28,7 +28,7 @@ use std::time::Duration;
2828

2929
use tokio::sync::{broadcast, mpsc, Notify};
3030
use tracing::instrument::WithSubscriber;
31-
use tracing::{debug, trace, warn};
31+
use tracing::{debug, error, trace, warn};
3232

3333
/// The target size of a per-node connection pool.
3434
#[derive(Debug, Clone, Copy)]
@@ -235,22 +235,25 @@ impl NodeConnectionPool {
235235
.unwrap_or(None)
236236
}
237237

238-
pub(crate) fn connection_for_token(&self, token: Token) -> Result<Arc<Connection>, QueryError> {
239-
trace!(token = token.value, "Selecting connection for token");
238+
pub(crate) fn connection_for_shard(&self, shard: Shard) -> Result<Arc<Connection>, QueryError> {
239+
trace!(shard = shard, "Selecting connection for shard");
240240
self.with_connections(|pool_conns| match pool_conns {
241241
PoolConnections::NotSharded(conns) => {
242242
Self::choose_random_connection_from_slice(conns).unwrap()
243243
}
244244
PoolConnections::Sharded {
245-
sharder,
246245
connections,
246+
sharder
247247
} => {
248-
let shard: u16 = sharder
249-
.shard_of(token)
248+
let shard = shard
250249
.try_into()
251-
.expect("Shard number doesn't fit in u16");
252-
trace!(shard = shard, "Selecting connection for token");
253-
Self::connection_for_shard(shard, sharder.nr_shards, connections.as_slice())
250+
// It's safer to use 0 rather than panic here, as shards are returned by `LoadBalancingPolicy`
251+
// now, which can be implemented by a user in an arbitrary way.
252+
.unwrap_or_else(|_| {
253+
error!("The provided shard number: {} does not fit u16! Using 0 as the shard number.", shard);
254+
0
255+
});
256+
Self::connection_for_shard_helper(shard, sharder.nr_shards, connections.as_slice())
254257
}
255258
})
256259
}
@@ -266,13 +269,13 @@ impl NodeConnectionPool {
266269
connections,
267270
} => {
268271
let shard: u16 = rand::thread_rng().gen_range(0..sharder.nr_shards.get());
269-
Self::connection_for_shard(shard, sharder.nr_shards, connections.as_slice())
272+
Self::connection_for_shard_helper(shard, sharder.nr_shards, connections.as_slice())
270273
}
271274
})
272275
}
273276

274277
// Tries to get a connection to given shard, if it's broken returns any working connection
275-
fn connection_for_shard(
278+
fn connection_for_shard_helper(
276279
shard: u16,
277280
nr_shards: ShardCount,
278281
shard_conns: &[Vec<Arc<Connection>>],

scylla/src/transport/iterator.rs

+6-23
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ use crate::transport::connection::{Connection, NonErrorQueryResponse, QueryRespo
3535
use crate::transport::load_balancing::{self, RoutingInfo};
3636
use crate::transport::metrics::Metrics;
3737
use crate::transport::retry_policy::{QueryInfo, RetryDecision, RetrySession};
38-
use crate::transport::{Node, NodeRef};
38+
use crate::transport::NodeRef;
3939
use tracing::{trace, trace_span, warn, Instrument};
4040
use uuid::Uuid;
4141

@@ -160,8 +160,6 @@ impl RowIterator {
160160
let worker_task = async move {
161161
let query_ref = &query;
162162

163-
let choose_connection = |node: Arc<Node>| async move { node.random_connection().await };
164-
165163
let page_query = |connection: Arc<Connection>,
166164
consistency: Consistency,
167165
paging_state: Option<Bytes>| {
@@ -187,7 +185,6 @@ impl RowIterator {
187185

188186
let worker = RowIteratorWorker {
189187
sender: sender.into(),
190-
choose_connection,
191188
page_query,
192189
statement_info: routing_info,
193190
query_is_idempotent: query.config.is_idempotent,
@@ -259,13 +256,6 @@ impl RowIterator {
259256
is_confirmed_lwt: config.prepared.is_confirmed_lwt(),
260257
};
261258

262-
let choose_connection = |node: Arc<Node>| async move {
263-
match token {
264-
Some(token) => node.connection_for_token(token).await,
265-
None => node.random_connection().await,
266-
}
267-
};
268-
269259
let page_query = |connection: Arc<Connection>,
270260
consistency: Consistency,
271261
paging_state: Option<Bytes>| async move {
@@ -311,7 +301,6 @@ impl RowIterator {
311301

312302
let worker = RowIteratorWorker {
313303
sender: sender.into(),
314-
choose_connection,
315304
page_query,
316305
statement_info,
317306
query_is_idempotent: config.prepared.config.is_idempotent,
@@ -496,13 +485,9 @@ type PageSendAttemptedProof = SendAttemptedProof<Result<ReceivedPage, QueryError
496485

497486
// RowIteratorWorker works in the background to fetch pages
498487
// RowIterator receives them through a channel
499-
struct RowIteratorWorker<'a, ConnFunc, QueryFunc, SpanCreatorFunc> {
488+
struct RowIteratorWorker<'a, QueryFunc, SpanCreatorFunc> {
500489
sender: ProvingSender<Result<ReceivedPage, QueryError>>,
501490

502-
// Closure used to choose a connection from a node
503-
// AsyncFn(Arc<Node>) -> Result<Arc<Connection>, QueryError>
504-
choose_connection: ConnFunc,
505-
506491
// Closure used to perform a single page query
507492
// AsyncFn(Arc<Connection>, Option<Bytes>) -> Result<QueryResponse, QueryError>
508493
page_query: QueryFunc,
@@ -524,11 +509,8 @@ struct RowIteratorWorker<'a, ConnFunc, QueryFunc, SpanCreatorFunc> {
524509
span_creator: SpanCreatorFunc,
525510
}
526511

527-
impl<ConnFunc, ConnFut, QueryFunc, QueryFut, SpanCreator>
528-
RowIteratorWorker<'_, ConnFunc, QueryFunc, SpanCreator>
512+
impl<QueryFunc, QueryFut, SpanCreator> RowIteratorWorker<'_, QueryFunc, SpanCreator>
529513
where
530-
ConnFunc: Fn(Arc<Node>) -> ConnFut,
531-
ConnFut: Future<Output = Result<Arc<Connection>, QueryError>>,
532514
QueryFunc: Fn(Arc<Connection>, Consistency, Option<Bytes>) -> QueryFut,
533515
QueryFut: Future<Output = Result<QueryResponse, QueryError>>,
534516
SpanCreator: Fn() -> RequestSpan,
@@ -546,12 +528,13 @@ where
546528

547529
self.log_query_start();
548530

549-
'nodes_in_plan: for (node, _shard) in query_plan {
531+
'nodes_in_plan: for (node, shard) in query_plan {
550532
let span =
551533
trace_span!(parent: &self.parent_span, "Executing query", node = %node.address);
552534
// For each node in the plan choose a connection to use
553535
// This connection will be reused for same node retries to preserve paging cache on the shard
554-
let connection: Arc<Connection> = match (self.choose_connection)(node.clone())
536+
let connection: Arc<Connection> = match node
537+
.connection_for_shard(shard)
555538
.instrument(span.clone())
556539
.await
557540
{

scylla/src/transport/node.rs

+6-11
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use tracing::warn;
33
use uuid::Uuid;
44

55
/// Node represents a cluster node along with its data and connections
6-
use crate::routing::{Sharder, Token};
6+
use crate::routing::{Shard, Sharder};
77
use crate::transport::connection::Connection;
88
use crate::transport::connection::VerifiedKeyspaceName;
99
use crate::transport::connection_pool::{NodeConnectionPool, PoolConfig};
@@ -152,18 +152,13 @@ impl Node {
152152
self.pool.as_ref()?.sharder()
153153
}
154154

155-
/// Get connection which should be used to connect using given token
156-
/// If this connection is broken get any random connection to this Node
157-
pub(crate) async fn connection_for_token(
155+
/// Get a connection targeting the given shard
156+
/// If such connection is broken, get any random connection to this `Node`
157+
pub(crate) async fn connection_for_shard(
158158
&self,
159-
token: Token,
159+
shard: Shard,
160160
) -> Result<Arc<Connection>, QueryError> {
161-
self.get_pool()?.connection_for_token(token)
162-
}
163-
164-
/// Get random connection
165-
pub(crate) async fn random_connection(&self) -> Result<Arc<Connection>, QueryError> {
166-
self.get_pool()?.random_connection()
161+
self.get_pool()?.connection_for_shard(shard)
167162
}
168163

169164
pub fn is_down(&self) -> bool {

scylla/src/transport/session.rs

+6-33
Original file line numberDiff line numberDiff line change
@@ -655,7 +655,6 @@ impl Session {
655655
statement_info,
656656
&query.config,
657657
execution_profile,
658-
|node: Arc<Node>| async move { node.random_connection().await },
659658
|connection: Arc<Connection>,
660659
consistency: Consistency,
661660
execution_profile: &ExecutionProfileInner| {
@@ -1024,12 +1023,6 @@ impl Session {
10241023
statement_info,
10251024
&prepared.config,
10261025
execution_profile,
1027-
|node: Arc<Node>| async move {
1028-
match token {
1029-
Some(token) => node.connection_for_token(token).await,
1030-
None => node.random_connection().await,
1031-
}
1032-
},
10331026
|connection: Arc<Connection>,
10341027
consistency: Consistency,
10351028
execution_profile: &ExecutionProfileInner| {
@@ -1236,14 +1229,6 @@ impl Session {
12361229
statement_info,
12371230
&batch.config,
12381231
execution_profile,
1239-
|node: Arc<Node>| async move {
1240-
match first_value_token {
1241-
Some(first_value_token) => {
1242-
node.connection_for_token(first_value_token).await
1243-
}
1244-
None => node.random_connection().await,
1245-
}
1246-
},
12471232
|connection: Arc<Connection>,
12481233
consistency: Consistency,
12491234
execution_profile: &ExecutionProfileInner| {
@@ -1507,28 +1492,23 @@ impl Session {
15071492
}
15081493

15091494
// This method allows to easily run a query using load balancing, retry policy etc.
1510-
// Requires some information about the query and two closures
1511-
// First closure is used to choose a connection
1512-
// - query will use node.random_connection()
1513-
// - execute will use node.connection_for_token()
1514-
// The second closure is used to do the query itself on a connection
1495+
// Requires some information about the query and a closure.
1496+
// The closure is used to do the query itself on a connection
15151497
// - query will use connection.query()
15161498
// - execute will use connection.execute()
15171499
// If this query closure fails with some errors retry policy is used to perform retries
15181500
// On success this query's result is returned
15191501
// I tried to make these closures take a reference instead of an Arc but failed
15201502
// maybe once async closures get stabilized this can be fixed
1521-
async fn run_query<'a, ConnFut, QueryFut, ResT>(
1503+
async fn run_query<'a, QueryFut, ResT>(
15221504
&'a self,
15231505
statement_info: RoutingInfo<'a>,
15241506
statement_config: &'a StatementConfig,
15251507
execution_profile: Arc<ExecutionProfileInner>,
1526-
choose_connection: impl Fn(Arc<Node>) -> ConnFut,
15271508
do_query: impl Fn(Arc<Connection>, Consistency, &ExecutionProfileInner) -> QueryFut,
15281509
request_span: &'a RequestSpan,
15291510
) -> Result<RunQueryResult<ResT>, QueryError>
15301511
where
1531-
ConnFut: Future<Output = Result<Arc<Connection>, QueryError>>,
15321512
QueryFut: Future<Output = Result<ResT, QueryError>>,
15331513
ResT: AllowedRunQueryResTType,
15341514
{
@@ -1602,7 +1582,6 @@ impl Session {
16021582

16031583
self.execute_query(
16041584
&shared_query_plan,
1605-
&choose_connection,
16061585
&do_query,
16071586
&execution_profile,
16081587
ExecuteQueryContext {
@@ -1638,7 +1617,6 @@ impl Session {
16381617
});
16391618
self.execute_query(
16401619
query_plan,
1641-
&choose_connection,
16421620
&do_query,
16431621
&execution_profile,
16441622
ExecuteQueryContext {
@@ -1684,16 +1662,14 @@ impl Session {
16841662
result
16851663
}
16861664

1687-
async fn execute_query<'a, ConnFut, QueryFut, ResT>(
1665+
async fn execute_query<'a, QueryFut, ResT>(
16881666
&'a self,
16891667
query_plan: impl Iterator<Item = (NodeRef<'a>, Shard)>,
1690-
choose_connection: impl Fn(Arc<Node>) -> ConnFut,
16911668
do_query: impl Fn(Arc<Connection>, Consistency, &ExecutionProfileInner) -> QueryFut,
16921669
execution_profile: &ExecutionProfileInner,
16931670
mut context: ExecuteQueryContext<'a>,
16941671
) -> Option<Result<RunQueryResult<ResT>, QueryError>>
16951672
where
1696-
ConnFut: Future<Output = Result<Arc<Connection>, QueryError>>,
16971673
QueryFut: Future<Output = Result<ResT, QueryError>>,
16981674
ResT: AllowedRunQueryResTType,
16991675
{
@@ -1702,14 +1678,11 @@ impl Session {
17021678
.consistency_set_on_statement
17031679
.unwrap_or(execution_profile.consistency);
17041680

1705-
'nodes_in_plan: for (node, _shard) in query_plan {
1681+
'nodes_in_plan: for (node, shard) in query_plan {
17061682
let span = trace_span!("Executing query", node = %node.address);
17071683
'same_node_retries: loop {
17081684
trace!(parent: &span, "Execution started");
1709-
let connection: Arc<Connection> = match choose_connection(node.clone())
1710-
.instrument(span.clone())
1711-
.await
1712-
{
1685+
let connection = match node.connection_for_shard(shard).await {
17131686
Ok(connection) => connection,
17141687
Err(e) => {
17151688
trace!(

0 commit comments

Comments
 (0)