Commit 8e845e7

Merge pull request #944 from Lorak-mmk/shard-selecting-lb-v2

Shard selecting load balancing

2 parents 07df198 + 28ae015, commit 8e845e7
18 files changed: +368 -296 lines
Cargo.lock.msrv (+1)

Some generated files are not rendered by default.

docs/source/load-balancing/load-balancing.md (+13 -12)

@@ -2,8 +2,8 @@
 
 ## Introduction
 
-The driver uses a load balancing policy to determine which node(s) to contact
-when executing a query. Load balancing policies implement the
+The driver uses a load balancing policy to determine which node(s) and shard(s)
+to contact when executing a query. Load balancing policies implement the
 `LoadBalancingPolicy` trait, which contains methods to generate a load
 balancing plan based on the query information and the state of the cluster.
 
@@ -12,12 +12,14 @@ being opened. For a node connection blacklist configuration refer to
 `scylla::transport::host_filter::HostFilter`, which can be set session-wide
 using `SessionBuilder::host_filter` method.
 
+In this chapter, "target" will refer to a pair `<node, optional shard>`.
+
 ## Plan
 
 When a query is prepared to be sent to the database, the load balancing policy
-constructs a load balancing plan. This plan is essentially a list of nodes to
+constructs a load balancing plan. This plan is essentially a list of targets to
 which the driver will try to send the query. The first elements of the plan are
-the nodes which are the best to contact (e.g. they might be replicas for the
+the targets which are the best to contact (e.g. they might be replicas for the
 requested data or have the best latency).
 
 ## Policy
@@ -84,17 +86,16 @@ first element of the load balancing plan is needed, so it's usually unnecessary
 to compute entire load balancing plan. To optimize this common case, the
 `LoadBalancingPolicy` trait provides two methods: `pick` and `fallback`.
 
-`pick` returns the first node to contact for a given query, which is usually
-the best based on a particular load balancing policy. If `pick` returns `None`,
-then `fallback` will not be called.
+`pick` returns the first target to contact for a given query, which is usually
+the best based on a particular load balancing policy.
 
-`fallback`, returns an iterator that provides the rest of the nodes in the load
-balancing plan. `fallback` is called only when using the initial picked node
-fails (or when executing speculatively).
+`fallback` returns an iterator that provides the rest of the targets in the
+load balancing plan. `fallback` is called when using the initial picked
+target fails (or when executing speculatively) or when `pick` returned `None`.
 
-It's possible for the `fallback` method to include the same node that was
+It's possible for the `fallback` method to include the same target that was
 returned by the `pick` method. In such cases, the query execution layer filters
-out the picked node from the iterator returned by `fallback`.
+out the picked target from the iterator returned by `fallback`.
 
 ### `on_query_success` and `on_query_failure`:

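To make the `pick`/`fallback` contract above concrete, here is a minimal sketch, not the driver's actual execution code, of how a plan over `(node, shard)` targets could be assembled from a policy after this change. The helper name `build_plan` and the use of `Arc::ptr_eq` for target identity are illustrative assumptions; the trait methods and the `(NodeRef, Shard)` item type come from this commit.

```rust
use std::sync::Arc;

use scylla::load_balancing::{LoadBalancingPolicy, RoutingInfo};
use scylla::routing::Shard;
use scylla::transport::{ClusterData, NodeRef};

// Illustrative only: chain `pick` with `fallback`, filtering out the target
// that `pick` already returned (the driver does this filtering internally).
fn build_plan<'a>(
    policy: &'a dyn LoadBalancingPolicy,
    info: &'a RoutingInfo,
    cluster: &'a ClusterData,
) -> impl Iterator<Item = (NodeRef<'a>, Shard)> {
    let picked = policy.pick(info, cluster);
    let rest = policy.fallback(info, cluster).filter(move |(node, shard)| {
        picked.map_or(true, |(picked_node, picked_shard)| {
            // A target is identified by the node (compared by identity)
            // plus the shard number.
            !Arc::ptr_eq(node, picked_node) || *shard != picked_shard
        })
    });
    picked.into_iter().chain(rest)
}
```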
examples/Cargo.toml (+1)

@@ -20,6 +20,7 @@ uuid = "1.0"
 tower = "0.4"
 stats_alloc = "0.1"
 clap = { version = "3.2.4", features = ["derive"] }
+rand = "0.8.5"
 
 [[example]]
 name = "auth"

examples/compare-tokens.rs (+1 -1)

@@ -41,7 +41,7 @@ async fn main() -> Result<()> {
         .get_cluster_data()
         .get_token_endpoints("examples_ks", Token { value: t })
         .iter()
-        .map(|n| n.address)
+        .map(|(node, _shard)| node.address)
         .collect::<Vec<NodeAddr>>()
 );

examples/custom_load_balancing_policy.rs (+17 -3)

@@ -1,23 +1,37 @@
 use anyhow::Result;
+use rand::thread_rng;
+use rand::Rng;
+use scylla::transport::NodeRef;
 use scylla::{
     load_balancing::{LoadBalancingPolicy, RoutingInfo},
+    routing::Shard,
     transport::{ClusterData, ExecutionProfile},
     Session, SessionBuilder,
 };
 use std::{env, sync::Arc};
 
 /// Example load balancing policy that prefers nodes from favorite datacenter
+/// This is, of course, very naive, as it is completely non token-aware.
+/// For a more realistic implementation, see [`DefaultPolicy`](scylla::load_balancing::DefaultPolicy).
 #[derive(Debug)]
 struct CustomLoadBalancingPolicy {
     fav_datacenter_name: String,
 }
 
+fn with_random_shard(node: NodeRef) -> (NodeRef, Shard) {
+    let nr_shards = node
+        .sharder()
+        .map(|sharder| sharder.nr_shards.get())
+        .unwrap_or(1);
+    (node, thread_rng().gen_range(0..nr_shards) as Shard)
+}
+
 impl LoadBalancingPolicy for CustomLoadBalancingPolicy {
     fn pick<'a>(
         &'a self,
         _info: &'a RoutingInfo,
         cluster: &'a ClusterData,
-    ) -> Option<scylla::transport::NodeRef<'a>> {
+    ) -> Option<(NodeRef<'a>, Shard)> {
         self.fallback(_info, cluster).next()
     }
 
@@ -31,9 +45,9 @@ impl LoadBalancingPolicy for CustomLoadBalancingPolicy {
             .unique_nodes_in_datacenter_ring(&self.fav_datacenter_name);
 
         match fav_dc_nodes {
-            Some(nodes) => Box::new(nodes.iter()),
+            Some(nodes) => Box::new(nodes.iter().map(with_random_shard)),
             // If there is no dc with provided name, fallback to other datacenters
-            None => Box::new(cluster.get_nodes_info().iter()),
+            None => Box::new(cluster.get_nodes_info().iter().map(with_random_shard)),
         }
     }

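For completeness, a hedged sketch of how such a policy is typically installed on a session via an execution profile, which is presumably what the rest of this example does given its `ExecutionProfile`, `Session`, and `SessionBuilder` imports; the datacenter name, node address, and use of a tokio runtime are placeholders and assumptions, not part of the diff:

```rust
use std::sync::Arc;

use anyhow::Result;
use scylla::transport::ExecutionProfile;
use scylla::{Session, SessionBuilder};

#[tokio::main]
async fn main() -> Result<()> {
    let policy = Arc::new(CustomLoadBalancingPolicy {
        fav_datacenter_name: "us_east".to_string(),
    });

    // Make the custom policy the default for every query run on this session.
    let profile = ExecutionProfile::builder()
        .load_balancing_policy(policy)
        .build();

    let _session: Session = SessionBuilder::new()
        .known_node("127.0.0.1:9042")
        .default_execution_profile_handle(profile.into_handle())
        .build()
        .await?;

    Ok(())
}
```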
scylla/src/transport/cluster.rs (+6 -5)

@@ -1,7 +1,7 @@
 /// Cluster manages up to date information and connections to database nodes
 use crate::frame::response::event::{Event, StatusChangeEvent};
 use crate::prepared_statement::TokenCalculationError;
-use crate::routing::Token;
+use crate::routing::{Shard, Token};
 use crate::transport::host_filter::HostFilter;
 use crate::transport::{
     connection::{Connection, VerifiedKeyspaceName},
@@ -27,6 +27,7 @@ use tracing::{debug, warn};
 use uuid::Uuid;
 
 use super::node::{KnownNode, NodeAddr};
+use super::NodeRef;
 
 use super::locator::ReplicaLocator;
 use super::partitioner::calculate_token_for_partition_key;
@@ -408,17 +409,17 @@ impl ClusterData {
     }
 
     /// Access to replicas owning a given token
-    pub fn get_token_endpoints(&self, keyspace: &str, token: Token) -> Vec<Arc<Node>> {
+    pub fn get_token_endpoints(&self, keyspace: &str, token: Token) -> Vec<(Arc<Node>, Shard)> {
         self.get_token_endpoints_iter(keyspace, token)
-            .cloned()
+            .map(|(node, shard)| (node.clone(), shard))
             .collect()
     }
 
     pub(crate) fn get_token_endpoints_iter(
         &self,
         keyspace: &str,
         token: Token,
-    ) -> impl Iterator<Item = &Arc<Node>> {
+    ) -> impl Iterator<Item = (NodeRef<'_>, Shard)> {
         let keyspace = self.keyspaces.get(keyspace);
         let strategy = keyspace
             .map(|k| &k.strategy)
@@ -436,7 +437,7 @@ impl ClusterData {
         keyspace: &str,
         table: &str,
         partition_key: &SerializedValues,
-    ) -> Result<Vec<Arc<Node>>, BadQuery> {
+    ) -> Result<Vec<(Arc<Node>, Shard)>, BadQuery> {
         Ok(self.get_token_endpoints(
             keyspace,
             self.compute_token(keyspace, table, partition_key)?,

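Callers of the public `get_token_endpoints` now receive `(Arc<Node>, Shard)` pairs instead of bare nodes, as the `compare-tokens.rs` change above shows. A small sketch of the new caller-side shape; the keyspace name and token value are placeholders:

```rust
use scylla::routing::Token;
use scylla::Session;

// Sketch: print every replica target for a token, now including the shard.
fn print_targets(session: &Session) {
    let cluster_data = session.get_cluster_data();
    let targets = cluster_data.get_token_endpoints("examples_ks", Token { value: 42 });
    for (node, shard) in targets {
        println!("node {:?} -> shard {}", node.address, shard);
    }
}
```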
scylla/src/transport/connection_pool.rs (+15 -12)

@@ -1,7 +1,7 @@
 #[cfg(feature = "cloud")]
 use crate::cloud::set_ssl_config_for_scylla_cloud_host;
 
-use crate::routing::{Shard, ShardCount, Sharder, Token};
+use crate::routing::{Shard, ShardCount, Sharder};
 use crate::transport::errors::QueryError;
 use crate::transport::{
     connection,
@@ -28,7 +28,7 @@ use std::time::Duration;
 
 use tokio::sync::{broadcast, mpsc, Notify};
 use tracing::instrument::WithSubscriber;
-use tracing::{debug, trace, warn};
+use tracing::{debug, error, trace, warn};
 
 /// The target size of a per-node connection pool.
 #[derive(Debug, Clone, Copy)]
@@ -235,22 +235,25 @@ impl NodeConnectionPool {
             .unwrap_or(None)
     }
 
-    pub(crate) fn connection_for_token(&self, token: Token) -> Result<Arc<Connection>, QueryError> {
-        trace!(token = token.value, "Selecting connection for token");
+    pub(crate) fn connection_for_shard(&self, shard: Shard) -> Result<Arc<Connection>, QueryError> {
+        trace!(shard = shard, "Selecting connection for shard");
         self.with_connections(|pool_conns| match pool_conns {
             PoolConnections::NotSharded(conns) => {
                 Self::choose_random_connection_from_slice(conns).unwrap()
             }
             PoolConnections::Sharded {
-                sharder,
                 connections,
+                sharder
             } => {
-                let shard: u16 = sharder
-                    .shard_of(token)
+                let shard = shard
                     .try_into()
-                    .expect("Shard number doesn't fit in u16");
-                trace!(shard = shard, "Selecting connection for token");
-                Self::connection_for_shard(shard, sharder.nr_shards, connections.as_slice())
+                    // It's safer to use 0 rather than panic here, as shards are returned by `LoadBalancingPolicy`
+                    // now, which can be implemented by a user in an arbitrary way.
+                    .unwrap_or_else(|_| {
+                        error!("The provided shard number: {} does not fit u16! Using 0 as the shard number. Check your LoadBalancingPolicy implementation.", shard);
+                        0
+                    });
+                Self::connection_for_shard_helper(shard, sharder.nr_shards, connections.as_slice())
             }
         })
     }
@@ -266,13 +269,13 @@ impl NodeConnectionPool {
                 connections,
             } => {
                 let shard: u16 = rand::thread_rng().gen_range(0..sharder.nr_shards.get());
-                Self::connection_for_shard(shard, sharder.nr_shards, connections.as_slice())
+                Self::connection_for_shard_helper(shard, sharder.nr_shards, connections.as_slice())
             }
         })
    }
 
     // Tries to get a connection to given shard, if it's broken returns any working connection
-    fn connection_for_shard(
+    fn connection_for_shard_helper(
         shard: u16,
         nr_shards: ShardCount,
         shard_conns: &[Vec<Arc<Connection>>],

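The design choice here: shard numbers now arrive from `LoadBalancingPolicy` implementations, which users can write arbitrarily, so an out-of-range shard is logged and clamped to 0 rather than panicking inside the pool. A self-contained sketch of just that conversion, with stand-in names (`shard_index` is not a driver function):

```rust
// Illustrative stand-in: in the driver, `Shard` is a u32 chosen by the
// load balancing policy, while pools index their shards with u16.
type Shard = u32;

fn shard_index(shard: Shard) -> u16 {
    u16::try_from(shard).unwrap_or_else(|_| {
        // Mirrors the diff above: log and fall back to shard 0 instead of
        // panicking on an arbitrary user-provided value.
        eprintln!("shard {shard} does not fit in u16, using 0");
        0
    })
}

fn main() {
    assert_eq!(shard_index(3), 3);
    assert_eq!(shard_index(100_000), 0); // logged and clamped
}
```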
scylla/src/transport/iterator.rs (+7 -24)

@@ -35,7 +35,7 @@ use crate::transport::connection::{Connection, NonErrorQueryResponse, QueryRespo
 use crate::transport::load_balancing::{self, RoutingInfo};
 use crate::transport::metrics::Metrics;
 use crate::transport::retry_policy::{QueryInfo, RetryDecision, RetrySession};
-use crate::transport::{Node, NodeRef};
+use crate::transport::NodeRef;
 use tracing::{trace, trace_span, warn, Instrument};
 use uuid::Uuid;
 
@@ -160,8 +160,6 @@ impl RowIterator {
         let worker_task = async move {
             let query_ref = &query;
 
-            let choose_connection = |node: Arc<Node>| async move { node.random_connection().await };
-
             let page_query = |connection: Arc<Connection>,
                               consistency: Consistency,
                               paging_state: Option<Bytes>| {
@@ -187,7 +185,6 @@
 
             let worker = RowIteratorWorker {
                 sender: sender.into(),
-                choose_connection,
                 page_query,
                 statement_info: routing_info,
                 query_is_idempotent: query.config.is_idempotent,
@@ -259,13 +256,6 @@
                 is_confirmed_lwt: config.prepared.is_confirmed_lwt(),
             };
 
-            let choose_connection = |node: Arc<Node>| async move {
-                match token {
-                    Some(token) => node.connection_for_token(token).await,
-                    None => node.random_connection().await,
-                }
-            };
-
             let page_query = |connection: Arc<Connection>,
                               consistency: Consistency,
                               paging_state: Option<Bytes>| async move {
@@ -290,7 +280,7 @@
                 config
                     .cluster_data
                     .get_token_endpoints_iter(keyspace, token)
-                    .cloned()
+                    .map(|(node, shard)| (node.clone(), shard))
                     .collect(),
             )
         } else {
@@ -311,7 +301,6 @@
 
             let worker = RowIteratorWorker {
                 sender: sender.into(),
-                choose_connection,
                 page_query,
                 statement_info,
                 query_is_idempotent: config.prepared.config.is_idempotent,
@@ -496,13 +485,9 @@ type PageSendAttemptedProof = SendAttemptedProof<Result<ReceivedPage, QueryError
 
 // RowIteratorWorker works in the background to fetch pages
 // RowIterator receives them through a channel
-struct RowIteratorWorker<'a, ConnFunc, QueryFunc, SpanCreatorFunc> {
+struct RowIteratorWorker<'a, QueryFunc, SpanCreatorFunc> {
     sender: ProvingSender<Result<ReceivedPage, QueryError>>,
 
-    // Closure used to choose a connection from a node
-    // AsyncFn(Arc<Node>) -> Result<Arc<Connection>, QueryError>
-    choose_connection: ConnFunc,
-
     // Closure used to perform a single page query
     // AsyncFn(Arc<Connection>, Option<Bytes>) -> Result<QueryResponse, QueryError>
     page_query: QueryFunc,
@@ -524,11 +509,8 @@ struct RowIteratorWorker<'a, ConnFunc, QueryFunc, SpanCreatorFunc> {
     span_creator: SpanCreatorFunc,
 }
 
-impl<ConnFunc, ConnFut, QueryFunc, QueryFut, SpanCreator>
-    RowIteratorWorker<'_, ConnFunc, QueryFunc, SpanCreator>
+impl<QueryFunc, QueryFut, SpanCreator> RowIteratorWorker<'_, QueryFunc, SpanCreator>
 where
-    ConnFunc: Fn(Arc<Node>) -> ConnFut,
-    ConnFut: Future<Output = Result<Arc<Connection>, QueryError>>,
     QueryFunc: Fn(Arc<Connection>, Consistency, Option<Bytes>) -> QueryFut,
     QueryFut: Future<Output = Result<QueryResponse, QueryError>>,
     SpanCreator: Fn() -> RequestSpan,
@@ -546,12 +528,13 @@ where
 
         self.log_query_start();
 
-        'nodes_in_plan: for node in query_plan {
+        'nodes_in_plan: for (node, shard) in query_plan {
             let span =
                 trace_span!(parent: &self.parent_span, "Executing query", node = %node.address);
             // For each node in the plan choose a connection to use
             // This connection will be reused for same node retries to preserve paging cache on the shard
-            let connection: Arc<Connection> = match (self.choose_connection)(node.clone())
+            let connection: Arc<Connection> = match node
                .connection_for_shard(shard)
                 .instrument(span.clone())
                 .await
             {
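The net effect on the paging worker: the plan now yields `(node, shard)` targets, and one connection is chosen per target and reused for retries on that target, keeping the shard's paging cache warm. A self-contained mock of that loop shape, using stand-in types rather than driver internals:

```rust
// All names here are stand-ins, not driver types.
struct MockNode {
    name: &'static str,
}

impl MockNode {
    fn connection_for_shard(&self, shard: u32) -> Result<String, String> {
        // Stand-in for the driver's per-shard pool lookup.
        Ok(format!("connection to {} / shard {}", self.name, shard))
    }
}

fn main() {
    let (a, b) = (MockNode { name: "10.0.0.1" }, MockNode { name: "10.0.0.2" });
    let query_plan = [(&a, 1u32), (&b, 0u32)];

    'targets_in_plan: for (node, shard) in query_plan {
        let connection = match node.connection_for_shard(shard) {
            Ok(conn) => conn,
            Err(_) => continue 'targets_in_plan, // try the next target
        };
        println!("executing query over {connection}");
        break 'targets_in_plan; // success: stop trying further targets
    }
}
```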
