From 47129ff0d5cf4ada77bcfccd804e1754187ec7f2 Mon Sep 17 00:00:00 2001 From: Varik Matevosyan Date: Thu, 31 Oct 2024 12:40:17 +0400 Subject: [PATCH] feature-gate embedding,autotune,external-indexing jobs in daemon, so no additional dependencies will be compiled if the code is not used. Review all dependencies for lantern-cli in Cargo.toml and mark optional based on features that needs them --- lantern_cli/Cargo.toml | 41 ++-- lantern_cli/src/cli.rs | 12 +- lantern_cli/src/daemon/autotune_jobs.rs | 41 +++- lantern_cli/src/daemon/embedding_jobs.rs | 113 +++++++++- lantern_cli/src/daemon/external_index_jobs.rs | 37 +++- lantern_cli/src/daemon/helpers.rs | 47 ++-- lantern_cli/src/daemon/mod.rs | 63 ++++-- lantern_cli/src/daemon/types.rs | 209 ++---------------- lantern_cli/src/embeddings/measure_speed.rs | 2 +- lantern_cli/src/main.rs | 11 +- lantern_cli/src/utils/test_utils.rs | 24 +- lantern_cli/tests/embedding_test_with_db.rs | 2 +- lantern_extras/Cargo.toml | 1 - 13 files changed, 330 insertions(+), 273 deletions(-) diff --git a/lantern_cli/Cargo.toml b/lantern_cli/Cargo.toml index cd3e8928..a9a40b22 100644 --- a/lantern_cli/Cargo.toml +++ b/lantern_cli/Cargo.toml @@ -12,11 +12,11 @@ path = "src/main.rs" [dependencies] clap = { version = "4.5.20", features = ["derive"] } anyhow = "1.0.91" -postgres = "0.19.9" +postgres = { version = "0.19.9", optional = true } rand = "0.8.5" linfa-clustering = { version = "0.7.0", features = ["ndarray-linalg"], optional = true } linfa = {version = "0.7.0", optional = true} -ndarray = { version = "0.15.6", features = ["rayon"] } +ndarray = { version = "0.15.6", features = ["rayon"], optional = true } rayon = { version="1.10.0", optional = true } md5 = {version="0.7.0", optional = true } serde = { version = "1.0", features = ["derive"] } @@ -27,19 +27,18 @@ futures = "0.3.31" tokio = { version = "1.41.0", features = ["full"] } lazy_static = "1.5.0" itertools = "0.13.0" -csv = "1.3.0" -sysinfo = "0.32.0" -tiktoken-rs = "0.6.0" -url = "2.5" -num_cpus = "1.16.0" -ort = { version = "1.16.3", features = ["load-dynamic", "cuda", "openvino"] } -tokenizers = { version = "0.20.1", features = ["default"] } -image = { version = "0.25.4", features = ["jpeg", "png", "webp" ]} -nvml-wrapper = "0.10.0" -strum = { version = "0.26", features = ["derive"] } -regex = "1.11.1" -postgres-types = { version = "0.2.8", features = ["derive"] } -usearch = { git = "https://github.com/Ngalstyan4/usearch.git", rev = "aa4f91d21230fd611b6c7741fa06be8c20acc9a9" } +sysinfo = { version = "0.32.0", optional = true } +tiktoken-rs = { version = "0.6.0", optional = true } +url = { version = "2.5", optional = true } +num_cpus = { version = "1.16.0", optional = true } +ort = { version = "1.16.3", features = ["load-dynamic", "cuda", "openvino"], optional = true } +tokenizers = { version = "0.20.1", features = ["default"], optional = true } +image = { version = "0.25.4", features = ["jpeg", "png", "webp" ], optional = true } +nvml-wrapper = { version = "0.10.0", optional = true } +strum = { version = "0.26", features = ["derive"], optional = true } +regex = { version = "1.11.1", optional = true } +postgres-types = { version = "0.2.8", features = ["derive"], optional = true } +usearch = { git = "https://github.com/Ngalstyan4/usearch.git", rev = "aa4f91d21230fd611b6c7741fa06be8c20acc9a9", optional = true } actix-web = { version = "4.9.0", optional = true } env_logger = { version = "0.11.5", optional = true } deadpool-postgres = { version = "0.14.0", optional = true } @@ -53,18 +52,18 @@ bitvec = { version="1.0.1", optional=true } rustls = { version="0.23.16", optional=true } rustls-pemfile = { version="2.2.0", optional=true } glob = { version="0.3.1", optional=true } -reqwest = { version = "0.12.9", default-features = false, features = ["json", "blocking", "rustls-tls"] } +reqwest = { version = "0.12.9", default-features = false, features = ["json", "blocking", "rustls-tls"], optional = true } [features] default = ["cli", "daemon", "http-server", "autotune", "pq", "external-index", "external-index-server", "embeddings"] daemon = ["dep:tokio-postgres"] -http-server = ["dep:deadpool-postgres", "dep:deadpool", "dep:bytes", "dep:utoipa", "dep:utoipa-swagger-ui", "dep:actix-web", "dep:tokio-postgres", "dep:env_logger", "dep:actix-web-httpauth"] +http-server = ["dep:deadpool-postgres", "dep:deadpool", "dep:bytes", "dep:utoipa", "dep:utoipa-swagger-ui", "dep:actix-web", "dep:tokio-postgres", "dep:env_logger", "dep:actix-web-httpauth", "dep:regex"] autotune = [] -pq = ["dep:gcp_auth", "dep:linfa", "dep:linfa-clustering", "dep:md5", "dep:rayon"] +pq = ["dep:gcp_auth", "dep:linfa", "dep:linfa-clustering", "dep:md5", "dep:rayon", "dep:reqwest", "dep:postgres", "dep:ndarray"] cli = [] -external-index = [] -external-index-server = ["dep:bitvec", "dep:rustls", "dep:rustls-pemfile", "dep:glob"] -embeddings = ["dep:bytes"] +external-index = ["dep:postgres-types", "dep:usearch", "dep:postgres"] +external-index-server = ["dep:bitvec", "dep:rustls", "dep:rustls-pemfile", "dep:glob", "dep:usearch"] +embeddings = ["dep:bytes", "dep:sysinfo", "dep:tiktoken-rs", "dep:url", "dep:num_cpus", "dep:ort", "dep:tokenizers", "dep:image", "dep:nvml-wrapper", "dep:strum", "dep:regex", "dep:reqwest", "dep:ndarray"] [lib] doctest = false diff --git a/lantern_cli/src/cli.rs b/lantern_cli/src/cli.rs index d764194a..c16aad02 100644 --- a/lantern_cli/src/cli.rs +++ b/lantern_cli/src/cli.rs @@ -1,11 +1,11 @@ -use super::daemon::cli::DaemonArgs; -use super::embeddings::cli::{EmbeddingArgs, MeasureModelSpeedArgs, ShowModelsArgs}; -use super::external_index::cli::CreateIndexArgs; -use super::http_server::cli::HttpServerArgs; -use super::index_autotune::cli::IndexAutotuneArgs; -use super::pq::cli::PQArgs; use clap::{Parser, Subcommand}; +use lantern_cli::daemon::cli::DaemonArgs; +use lantern_cli::embeddings::cli::{EmbeddingArgs, MeasureModelSpeedArgs, ShowModelsArgs}; +use lantern_cli::external_index::cli::CreateIndexArgs; use lantern_cli::external_index::cli::IndexServerArgs; +use lantern_cli::http_server::cli::HttpServerArgs; +use lantern_cli::index_autotune::cli::IndexAutotuneArgs; +use lantern_cli::pq::cli::PQArgs; #[derive(Subcommand, Debug)] pub enum Commands { diff --git a/lantern_cli/src/daemon/autotune_jobs.rs b/lantern_cli/src/daemon/autotune_jobs.rs index 83ba4884..eaaee82b 100644 --- a/lantern_cli/src/daemon/autotune_jobs.rs +++ b/lantern_cli/src/daemon/autotune_jobs.rs @@ -3,8 +3,8 @@ use super::helpers::{ index_job_update_processor, remove_job_handle, set_job_handle, startup_hook, }; use super::types::{ - AutotuneJob, AutotuneProcessorArgs, JobEvent, JobEventHandlersMap, JobInsertNotification, - JobRunArgs, JobUpdateNotification, + AutotuneProcessorArgs, JobEvent, JobEventHandlersMap, JobInsertNotification, JobRunArgs, + JobUpdateNotification, }; use crate::daemon::helpers::anyhow_wrap_connection; use crate::external_index::cli::UMetricKind; @@ -20,7 +20,7 @@ use tokio::sync::{ mpsc, mpsc::{Receiver, Sender, UnboundedReceiver, UnboundedSender}, }; -use tokio_postgres::{Client, NoTls}; +use tokio_postgres::{Client, NoTls, Row}; use tokio_util::sync::CancellationToken; pub const JOB_TABLE_DEFINITION: &'static str = r#" @@ -55,6 +55,41 @@ latency DOUBLE PRECISION NOT NULL, build_time DOUBLE PRECISION NULL "#; +#[derive(Debug)] +pub struct AutotuneJob { + pub id: i32, + pub is_init: bool, + pub db_uri: String, + pub schema: String, + pub table: String, + pub column: String, + pub metric_kind: String, + pub model_name: Option, + pub recall: f64, + pub k: u16, + pub sample_size: usize, + pub create_index: bool, +} + +impl AutotuneJob { + pub fn new(row: Row, db_uri: &str) -> AutotuneJob { + Self { + id: row.get::<&str, i32>("id"), + db_uri: db_uri.to_owned(), + schema: row.get::<&str, String>("schema"), + table: row.get::<&str, String>("table"), + column: row.get::<&str, String>("column"), + metric_kind: row.get::<&str, String>("metric_kind"), + model_name: row.get::<&str, Option>("model"), + recall: row.get::<&str, f64>("target_recall"), + k: row.get::<&str, i32>("k") as u16, + sample_size: row.get::<&str, i32>("sample_size") as usize, + create_index: row.get::<&str, bool>("create_index"), + is_init: true, + } + } +} + pub async fn autotune_job_processor( mut rx: Receiver, cancel_token: CancellationToken, diff --git a/lantern_cli/src/daemon/embedding_jobs.rs b/lantern_cli/src/daemon/embedding_jobs.rs index c179ff73..e6348ba4 100644 --- a/lantern_cli/src/daemon/embedding_jobs.rs +++ b/lantern_cli/src/daemon/embedding_jobs.rs @@ -4,15 +4,19 @@ use super::helpers::{ remove_job_handle, schedule_job_retry, set_job_handle, startup_hook, }; use super::types::{ - ClientJobsMap, EmbeddingJob, EmbeddingProcessorArgs, JobBatchingHashMap, JobEvent, - JobEventHandlersMap, JobInsertNotification, JobRunArgs, JobUpdateNotification, + ClientJobsMap, EmbeddingProcessorArgs, JobBatchingHashMap, JobEvent, JobEventHandlersMap, + JobInsertNotification, JobRunArgs, JobUpdateNotification, }; use crate::daemon::helpers::anyhow_wrap_connection; -use crate::embeddings::cli::{EmbeddingArgs, EmbeddingJobType}; +use crate::embeddings::cli::{EmbeddingArgs, EmbeddingJobType, Runtime}; +use crate::embeddings::core::utils::get_clean_model_name; use crate::embeddings::get_default_batch_size; use crate::logger::Logger; -use crate::utils::{get_common_embedding_ignore_filters, get_full_table_name, quote_ident}; +use crate::utils::{ + get_common_embedding_ignore_filters, get_full_table_name, quote_ident, quote_literal, +}; use crate::{embeddings, types::*}; +use itertools::Itertools; use std::collections::HashMap; use std::ops::Deref; use std::path::Path; @@ -22,8 +26,7 @@ use std::time::SystemTime; use tokio::fs; use tokio::sync::mpsc::{Receiver, Sender, UnboundedReceiver, UnboundedSender}; use tokio::sync::{mpsc, Mutex, RwLock}; -use tokio_postgres::types::ToSql; -use tokio_postgres::{Client, NoTls}; +use tokio_postgres::{types::ToSql, Client, NoTls, Row}; use tokio_util::sync::CancellationToken; pub const JOB_TABLE_DEFINITION: &'static str = r#" @@ -71,6 +74,104 @@ const EMB_USAGE_TABLE_NAME: &'static str = "embedding_usage_info"; const EMB_FAILURE_TABLE_NAME: &'static str = "embedding_failure_info"; const EMB_LOCK_TABLE_NAME: &'static str = "_lantern_emb_job_locks"; +#[derive(Debug, Clone)] +pub struct EmbeddingJob { + pub id: i32, + pub is_init: bool, + pub db_uri: String, + pub schema: String, + pub table: String, + pub column: String, + pub pk: String, + pub filter: Option, + pub label: Option, + pub job_type: EmbeddingJobType, + pub column_type: String, + pub out_column: String, + pub model: String, + pub runtime_params: String, + pub runtime: Runtime, + pub batch_size: Option, + pub row_ids: Option>, +} + +impl EmbeddingJob { + pub fn new(row: Row, data_path: &str, db_uri: &str) -> Result { + let runtime = Runtime::try_from(row.get::<&str, Option<&str>>("runtime").unwrap_or("ort"))?; + let runtime_params = if runtime == Runtime::Ort { + format!(r#"{{ "data_path": "{data_path}" }}"#) + } else { + row.get::<&str, Option>("runtime_params") + .unwrap_or("{}".to_owned()) + }; + + let batch_size = if let Some(batch_size) = row.get::<&str, Option>("batch_size") { + Some(batch_size as usize) + } else { + None + }; + + Ok(Self { + id: row.get::<&str, i32>("id"), + pk: row.get::<&str, String>("pk"), + label: row.get::<&str, Option>("label"), + db_uri: db_uri.to_owned(), + schema: row.get::<&str, String>("schema"), + table: row.get::<&str, String>("table"), + column: row.get::<&str, String>("column"), + out_column: row.get::<&str, String>("dst_column"), + model: get_clean_model_name(row.get::<&str, &str>("model"), runtime), + runtime, + runtime_params, + filter: None, + row_ids: None, + is_init: true, + batch_size, + job_type: EmbeddingJobType::try_from( + row.get::<&str, Option<&str>>("job_type") + .unwrap_or("embedding"), + )?, + column_type: row + .get::<&str, Option>("column_type") + .unwrap_or("REAL[]".to_owned()), + }) + } + + pub fn set_filter(&mut self, filter: &str) { + self.filter = Some(filter.to_owned()); + } + + pub fn set_is_init(&mut self, is_init: bool) { + self.is_init = is_init; + } + + pub fn set_row_ids(&mut self, row_ids: Vec) { + self.row_ids = Some(row_ids); + } + + #[allow(dead_code)] + pub fn set_ctid_filter(&mut self, row_ids: &Vec) { + let row_ctids_str = row_ids + .iter() + .map(|r| { + format!( + "currtid2('{table_name}','{r}'::tid)", + table_name = &self.table + ) + }) + .join(","); + self.set_filter(&format!("ctid IN ({row_ctids_str})")); + } + + pub fn set_id_filter(&mut self, row_ids: &Vec) { + let row_ctids_str = row_ids.iter().map(|s| quote_literal(s)).join(","); + self.set_filter(&format!( + "id IN ({row_ctids_str}) AND {common_filter}", + common_filter = get_common_embedding_ignore_filters(&self.column) + )); + } +} + async fn lock_row( client: Arc, lock_table_name: &str, diff --git a/lantern_cli/src/daemon/external_index_jobs.rs b/lantern_cli/src/daemon/external_index_jobs.rs index f6a998bc..9dda982e 100644 --- a/lantern_cli/src/daemon/external_index_jobs.rs +++ b/lantern_cli/src/daemon/external_index_jobs.rs @@ -3,8 +3,8 @@ use super::helpers::{ index_job_update_processor, startup_hook, }; use super::types::{ - ExternalIndexJob, ExternalIndexProcessorArgs, JobEvent, JobEventHandlersMap, - JobInsertNotification, JobRunArgs, JobTaskEventTx, JobUpdateNotification, + ExternalIndexProcessorArgs, JobEvent, JobEventHandlersMap, JobInsertNotification, JobRunArgs, + JobTaskEventTx, JobUpdateNotification, }; use crate::daemon::helpers::anyhow_wrap_connection; use crate::external_index::cli::{CreateIndexArgs, UMetricKind}; @@ -16,7 +16,7 @@ use std::sync::Arc; use std::time::SystemTime; use tokio::sync::mpsc::{self, Receiver, Sender, UnboundedReceiver, UnboundedSender}; use tokio::sync::RwLock; -use tokio_postgres::{Client, NoTls}; +use tokio_postgres::{Client, NoTls, Row}; use tokio_util::sync::CancellationToken; pub const JOB_TABLE_DEFINITION: &'static str = r#" @@ -39,6 +39,37 @@ pub const JOB_TABLE_DEFINITION: &'static str = r#" "progress" INT2 DEFAULT 0 "#; +#[derive(Debug)] +pub struct ExternalIndexJob { + pub id: i32, + pub db_uri: String, + pub schema: String, + pub table: String, + pub column: String, + pub operator_class: String, + pub index_name: Option, + pub ef: usize, + pub efc: usize, + pub m: usize, +} + +impl ExternalIndexJob { + pub fn new(row: Row, db_uri: &str) -> ExternalIndexJob { + Self { + id: row.get::<&str, i32>("id"), + db_uri: db_uri.to_owned(), + schema: row.get::<&str, String>("schema"), + table: row.get::<&str, String>("table"), + column: row.get::<&str, String>("column"), + operator_class: row.get::<&str, String>("operator"), + index_name: row.get::<&str, Option>("index"), + ef: row.get::<&str, i32>("ef") as usize, + efc: row.get::<&str, i32>("efc") as usize, + m: row.get::<&str, i32>("m") as usize, + } + } +} + async fn set_job_handle( jobs_map: Arc, job_id: i32, diff --git a/lantern_cli/src/daemon/helpers.rs b/lantern_cli/src/daemon/helpers.rs index 1875fe1a..65cdb10b 100644 --- a/lantern_cli/src/daemon/helpers.rs +++ b/lantern_cli/src/daemon/helpers.rs @@ -1,19 +1,17 @@ use super::types::{ - EmbeddingJob, JobEvent, JobEventHandlersMap, JobInsertNotification, JobTaskEventTx, - JobUpdateNotification, + JobEvent, JobEventHandlersMap, JobInsertNotification, JobTaskEventTx, JobUpdateNotification, }; -use crate::embeddings::get_try_cast_fn_sql; use crate::logger::Logger; -use crate::types::{AnyhowVoidResult, JOB_CANCELLED_MESSAGE}; +use crate::types::AnyhowVoidResult; use crate::utils::{get_common_embedding_ignore_filters, get_full_table_name, quote_ident}; use futures::StreamExt; -use postgres::tls::MakeTlsConnect; -use postgres::{IsolationLevel, Socket}; use std::sync::Arc; -use std::time::{Duration, SystemTime}; -use tokio::sync::mpsc::{Sender, UnboundedReceiver, UnboundedSender}; +use std::time::Duration; +use tokio::sync::mpsc::UnboundedSender; +use tokio_postgres::tls::MakeTlsConnect; use tokio_postgres::Client; use tokio_postgres::{AsyncMessage, Connection}; +use tokio_postgres::{IsolationLevel, Socket}; use tokio_util::sync::CancellationToken; pub async fn check_table_exists(client: Arc, table: &str) -> AnyhowVoidResult { @@ -283,6 +281,11 @@ pub async fn startup_hook( if failure_table_name.is_some() && failure_table_def.is_some() { let failure_table_name = get_full_table_name(schema, failure_table_name.unwrap()); let failure_table_def = failure_table_def.unwrap(); + #[cfg(not(feature = "embeddings"))] + let try_cast_fn = ""; + #[cfg(feature = "embeddings")] + let try_cast_fn = crate::embeddings::get_try_cast_fn_sql(&schema); + transaction .batch_execute(&format!( " @@ -292,7 +295,7 @@ pub async fn startup_hook( CREATE INDEX IF NOT EXISTS embedding_failures_job_id_row_id ON {failure_table_name}(job_id, row_id); GRANT SELECT ON {failure_table_name} TO PUBLIC; ", - ldb_try_cast_fn = get_try_cast_fn_sql(&schema) + ldb_try_cast_fn = try_cast_fn )) .await?; } @@ -302,6 +305,7 @@ pub async fn startup_hook( Ok(()) } +#[cfg(any(feature = "autotune", feature = "external-index"))] pub async fn collect_pending_index_jobs( client: Arc, insert_notification_tx: UnboundedSender, @@ -325,16 +329,17 @@ pub async fn collect_pending_index_jobs( // and some job will be terminated while running // on next start of daemon the job will not be picked as // it will already have started_at set - generate_missing: row.get::>(1).is_some(), + generate_missing: row.get::>(1).is_some(), })?; } Ok(()) } +#[cfg(any(feature = "autotune", feature = "external-index"))] pub async fn index_job_update_processor( client: Arc, - mut update_queue_rx: UnboundedReceiver, + mut update_queue_rx: tokio::sync::mpsc::UnboundedReceiver, schema: String, table: String, job_cancelleation_handlers: Arc, @@ -350,7 +355,7 @@ pub async fn index_job_update_processor( ) .await?; - let canceled_at: Option = row.get("canceled_at"); + let canceled_at: Option = row.get("canceled_at"); if canceled_at.is_some() { // Cancel ongoing job @@ -358,8 +363,10 @@ pub async fn index_job_update_processor( let job = jobs.get(&id); if let Some(tx) = job { - tx.send(JobEvent::Errored(JOB_CANCELLED_MESSAGE.to_owned())) - .await?; + tx.send(JobEvent::Errored( + crate::types::JOB_CANCELLED_MESSAGE.to_owned(), + )) + .await?; } drop(jobs); } @@ -370,13 +377,16 @@ pub async fn index_job_update_processor( Ok(()) } +#[cfg(any(feature = "autotune", feature = "external-index"))] pub async fn cancel_all_jobs(map: Arc) -> AnyhowVoidResult { let mut jobs_map = map.write().await; let jobs: Vec<(i32, JobTaskEventTx)> = jobs_map.drain().collect(); for (_, tx) in jobs { - tx.send(JobEvent::Errored(JOB_CANCELLED_MESSAGE.to_owned())) - .await?; + tx.send(JobEvent::Errored( + crate::types::JOB_CANCELLED_MESSAGE.to_owned(), + )) + .await?; } Ok(()) @@ -406,10 +416,11 @@ pub fn get_missing_rows_filter(src_column: &str, out_column: &str) -> String { ) } +#[cfg(feature = "embeddings")] pub async fn schedule_job_retry( logger: Arc, - job: EmbeddingJob, - tx: Sender, + job: super::embedding_jobs::EmbeddingJob, + tx: tokio::sync::mpsc::Sender, retry_after: Duration, ) { tokio::spawn(async move { diff --git a/lantern_cli/src/daemon/mod.rs b/lantern_cli/src/daemon/mod.rs index 115da864..da73acc8 100644 --- a/lantern_cli/src/daemon/mod.rs +++ b/lantern_cli/src/daemon/mod.rs @@ -1,7 +1,11 @@ +#[cfg(feature = "autotune")] pub mod autotune_jobs; pub mod cli; +#[cfg(feature = "embeddings")] mod client_embedding_jobs; +#[cfg(feature = "embeddings")] pub mod embedding_jobs; +#[cfg(feature = "external-index")] pub mod external_index_jobs; mod helpers; mod types; @@ -21,9 +25,6 @@ use crate::types::AnyhowVoidResult; use crate::{logger::Logger, utils::get_full_table_name}; use types::{DaemonJobHandlerMap, JobRunArgs, TargetDB}; -use autotune_jobs::autotune_job_processor; -use embedding_jobs::embedding_job_processor; -use external_index_jobs::external_index_job_processor; use types::{AutotuneProcessorArgs, EmbeddingProcessorArgs, ExternalIndexProcessorArgs, JobType}; lazy_static! { @@ -92,7 +93,7 @@ async fn spawn_job( args: Arc, job_type: JobType, parent_cancel_token: CancellationToken, -) { +) -> AnyhowVoidResult { let mut retry_interval = 5; let log_label = match job_type { @@ -122,7 +123,8 @@ async fn spawn_job( jobs.insert(target_db.name.clone(), cancel_token.clone()); drop(jobs); - let result = match &job_type { + let result: Result<(), anyhow::Error> = match &job_type { + #[cfg(feature = "embeddings")] JobType::Embeddings(processor_tx) => { embedding_jobs::start( JobRunArgs { @@ -139,6 +141,11 @@ async fn spawn_job( ) .await } + #[cfg(not(feature = "embeddings"))] + JobType::Embeddings(_) => { + anyhow::bail!("Embedding jobs are not enabled"); + } + #[cfg(feature = "external-index")] JobType::ExternalIndex(processor_tx) => { external_index_jobs::start( JobRunArgs { @@ -155,6 +162,11 @@ async fn spawn_job( ) .await } + #[cfg(not(feature = "external-index"))] + JobType::ExternalIndex(_) => { + anyhow::bail!("External Index jobs are not enabled"); + } + #[cfg(feature = "autotune")] JobType::Autotune(processor_tx) => { autotune_jobs::start( JobRunArgs { @@ -171,6 +183,10 @@ async fn spawn_job( ) .await } + #[cfg(not(feature = "autotune"))] + JobType::Autotune(_) => { + anyhow::bail!("Autotune jobs are not enabled"); + } }; cancel_token.cancel(); @@ -190,6 +206,8 @@ async fn spawn_job( break; } + + Ok(()) } async fn spawn_jobs( @@ -371,30 +389,39 @@ pub async fn start( let args_arc = Arc::new(args); let args_arc_clone = args_arc.clone(); - let (embedding_tx, embedding_rx): ( + let embedding_channel: ( Sender, Receiver, ) = mpsc::channel(1); - let (autotune_tx, autotune_rx): ( + let autotune_channel: ( Sender, Receiver, ) = mpsc::channel(1); - let (external_index_tx, external_index_rx): ( + let external_index_channel: ( Sender, Receiver, ) = mpsc::channel(1); + #[cfg(feature = "embeddings")] if args_arc.embeddings { - tokio::spawn(embedding_job_processor(embedding_rx, cancel_token.clone())); + tokio::spawn(embedding_jobs::embedding_job_processor( + embedding_channel.1, + cancel_token.clone(), + )); } + #[cfg(feature = "autotune")] if args_arc.autotune { - tokio::spawn(autotune_job_processor(autotune_rx, cancel_token.clone())); + tokio::spawn(autotune_jobs::autotune_job_processor( + autotune_channel.1, + cancel_token.clone(), + )); } + #[cfg(feature = "external-index")] if args_arc.external_index { - tokio::spawn(external_index_job_processor( - external_index_rx, + tokio::spawn(external_index_jobs::external_index_job_processor( + external_index_channel.1, cancel_token.clone(), )); } @@ -403,9 +430,9 @@ pub async fn start( spawn_jobs( target_db, args_arc_clone.clone(), - embedding_tx.clone(), - autotune_tx.clone(), - external_index_tx.clone(), + embedding_channel.0.clone(), + autotune_channel.0.clone(), + external_index_channel.0.clone(), cancel_token.clone(), ) .await; @@ -414,9 +441,9 @@ pub async fn start( if args_arc.master_db.is_some() { db_change_listener( args_arc.clone(), - embedding_tx.clone(), - autotune_tx.clone(), - external_index_tx.clone(), + embedding_channel.0.clone(), + autotune_channel.0.clone(), + external_index_channel.0.clone(), logger.clone(), cancel_token.clone(), ) diff --git a/lantern_cli/src/daemon/types.rs b/lantern_cli/src/daemon/types.rs index 913f337d..4bda413b 100644 --- a/lantern_cli/src/daemon/types.rs +++ b/lantern_cli/src/daemon/types.rs @@ -1,18 +1,8 @@ -use crate::embeddings::cli::{EmbeddingArgs, EmbeddingJobType, Runtime}; -use crate::embeddings::core::utils::get_clean_model_name; -use crate::external_index::cli::CreateIndexArgs; -use crate::index_autotune::cli::IndexAutotuneArgs; -use crate::logger::Logger; -use crate::types::{AnyhowVoidResult, ProgressCbFn}; -use crate::utils::{get_common_embedding_ignore_filters, quote_literal}; -use itertools::Itertools; use std::collections::HashMap; -use std::sync::Arc; use tokio::sync::{ mpsc::{Sender, UnboundedSender}, Mutex, RwLock, }; -use tokio_postgres::Row; use tokio_util::sync::CancellationToken; #[derive(Clone)] @@ -54,170 +44,6 @@ impl TargetDB { } } -#[derive(Debug, Clone)] -pub struct EmbeddingJob { - pub id: i32, - pub is_init: bool, - pub db_uri: String, - pub schema: String, - pub table: String, - pub column: String, - pub pk: String, - pub filter: Option, - pub label: Option, - pub job_type: EmbeddingJobType, - pub column_type: String, - pub out_column: String, - pub model: String, - pub runtime_params: String, - pub runtime: Runtime, - pub batch_size: Option, - pub row_ids: Option>, -} - -impl EmbeddingJob { - pub fn new(row: Row, data_path: &str, db_uri: &str) -> Result { - let runtime = Runtime::try_from(row.get::<&str, Option<&str>>("runtime").unwrap_or("ort"))?; - let runtime_params = if runtime == Runtime::Ort { - format!(r#"{{ "data_path": "{data_path}" }}"#) - } else { - row.get::<&str, Option>("runtime_params") - .unwrap_or("{}".to_owned()) - }; - - let batch_size = if let Some(batch_size) = row.get::<&str, Option>("batch_size") { - Some(batch_size as usize) - } else { - None - }; - - Ok(Self { - id: row.get::<&str, i32>("id"), - pk: row.get::<&str, String>("pk"), - label: row.get::<&str, Option>("label"), - db_uri: db_uri.to_owned(), - schema: row.get::<&str, String>("schema"), - table: row.get::<&str, String>("table"), - column: row.get::<&str, String>("column"), - out_column: row.get::<&str, String>("dst_column"), - model: get_clean_model_name(row.get::<&str, &str>("model"), runtime), - runtime, - runtime_params, - filter: None, - row_ids: None, - is_init: true, - batch_size, - job_type: EmbeddingJobType::try_from( - row.get::<&str, Option<&str>>("job_type") - .unwrap_or("embedding"), - )?, - column_type: row - .get::<&str, Option>("column_type") - .unwrap_or("REAL[]".to_owned()), - }) - } - - pub fn set_filter(&mut self, filter: &str) { - self.filter = Some(filter.to_owned()); - } - - pub fn set_is_init(&mut self, is_init: bool) { - self.is_init = is_init; - } - - pub fn set_row_ids(&mut self, row_ids: Vec) { - self.row_ids = Some(row_ids); - } - - #[allow(dead_code)] - pub fn set_ctid_filter(&mut self, row_ids: &Vec) { - let row_ctids_str = row_ids - .iter() - .map(|r| { - format!( - "currtid2('{table_name}','{r}'::tid)", - table_name = &self.table - ) - }) - .join(","); - self.set_filter(&format!("ctid IN ({row_ctids_str})")); - } - - pub fn set_id_filter(&mut self, row_ids: &Vec) { - let row_ctids_str = row_ids.iter().map(|s| quote_literal(s)).join(","); - self.set_filter(&format!( - "id IN ({row_ctids_str}) AND {common_filter}", - common_filter = get_common_embedding_ignore_filters(&self.column) - )); - } -} - -#[derive(Debug)] -pub struct AutotuneJob { - pub id: i32, - pub is_init: bool, - pub db_uri: String, - pub schema: String, - pub table: String, - pub column: String, - pub metric_kind: String, - pub model_name: Option, - pub recall: f64, - pub k: u16, - pub sample_size: usize, - pub create_index: bool, -} - -impl AutotuneJob { - pub fn new(row: Row, db_uri: &str) -> AutotuneJob { - Self { - id: row.get::<&str, i32>("id"), - db_uri: db_uri.to_owned(), - schema: row.get::<&str, String>("schema"), - table: row.get::<&str, String>("table"), - column: row.get::<&str, String>("column"), - metric_kind: row.get::<&str, String>("metric_kind"), - model_name: row.get::<&str, Option>("model"), - recall: row.get::<&str, f64>("target_recall"), - k: row.get::<&str, i32>("k") as u16, - sample_size: row.get::<&str, i32>("sample_size") as usize, - create_index: row.get::<&str, bool>("create_index"), - is_init: true, - } - } -} - -#[derive(Debug)] -pub struct ExternalIndexJob { - pub id: i32, - pub db_uri: String, - pub schema: String, - pub table: String, - pub column: String, - pub operator_class: String, - pub index_name: Option, - pub ef: usize, - pub efc: usize, - pub m: usize, -} - -impl ExternalIndexJob { - pub fn new(row: Row, db_uri: &str) -> ExternalIndexJob { - Self { - id: row.get::<&str, i32>("id"), - db_uri: db_uri.to_owned(), - schema: row.get::<&str, String>("schema"), - table: row.get::<&str, String>("table"), - column: row.get::<&str, String>("column"), - operator_class: row.get::<&str, String>("operator"), - index_name: row.get::<&str, Option>("index"), - ef: row.get::<&str, i32>("ef") as usize, - efc: row.get::<&str, i32>("efc") as usize, - m: row.get::<&str, i32>("m") as usize, - } - } -} - #[derive(Debug)] pub struct JobInsertNotification { pub id: i32, @@ -249,32 +75,43 @@ pub enum ClientJobSignal { Restart, } +#[cfg(feature = "embeddings")] pub type EmbeddingProcessorArgs = ( - EmbeddingArgs, + crate::embeddings::cli::EmbeddingArgs, Sender>, - Logger, + crate::logger::Logger, ); +#[cfg(not(feature = "embeddings"))] +pub type EmbeddingProcessorArgs = (); +#[cfg(feature = "autotune")] pub type AutotuneProcessorArgs = ( - IndexAutotuneArgs, - Sender, + crate::index_autotune::cli::IndexAutotuneArgs, + Sender, JobTaskEventTx, - Option, - Arc>, - Logger, + Option, + std::sync::Arc>, + crate::logger::Logger, ); +#[cfg(not(feature = "autotune"))] +pub type AutotuneProcessorArgs = (); +#[cfg(feature = "external-index")] pub type ExternalIndexProcessorArgs = ( - CreateIndexArgs, - Sender, + crate::external_index::cli::CreateIndexArgs, + Sender, JobTaskEventTx, - Option, - Arc>, - Logger, + Option, + std::sync::Arc>, + crate::logger::Logger, ); +#[cfg(not(feature = "external-index"))] +pub type ExternalIndexProcessorArgs = (); pub enum JobType { Embeddings(Sender), + #[allow(dead_code)] ExternalIndex(Sender), + #[allow(dead_code)] Autotune(Sender), } diff --git a/lantern_cli/src/embeddings/measure_speed.rs b/lantern_cli/src/embeddings/measure_speed.rs index 85b42b1e..291d6179 100644 --- a/lantern_cli/src/embeddings/measure_speed.rs +++ b/lantern_cli/src/embeddings/measure_speed.rs @@ -5,7 +5,7 @@ use super::{ core::{EmbeddingRuntime, Runtime}, }; use crate::logger::{LogLevel, Logger}; -use postgres::NoTls; +use tokio_postgres::NoTls; use tokio_util::sync::CancellationToken; use super::cli::MeasureModelSpeedArgs; diff --git a/lantern_cli/src/main.rs b/lantern_cli/src/main.rs index 77b09fd3..d9229242 100644 --- a/lantern_cli/src/main.rs +++ b/lantern_cli/src/main.rs @@ -1,13 +1,14 @@ -use std::process; - -use crate::logger::{LogLevel, Logger}; -use clap::Parser; -use lantern_cli::*; +#[cfg(feature = "cli")] mod cli; #[cfg(feature = "cli")] #[tokio::main] async fn main() { + use std::process; + + use clap::Parser; + use lantern_cli::logger::{LogLevel, Logger}; + use lantern_cli::*; use tokio_util::sync::CancellationToken; rustls::crypto::aws_lc_rs::default_provider() diff --git a/lantern_cli/src/utils/test_utils.rs b/lantern_cli/src/utils/test_utils.rs index 94ec1889..ba282e67 100644 --- a/lantern_cli/src/utils/test_utils.rs +++ b/lantern_cli/src/utils/test_utils.rs @@ -1,11 +1,27 @@ pub mod daemon_test_utils { - use crate::{daemon, types::AnyhowVoidResult}; + use crate::types::AnyhowVoidResult; use std::{env, time::Duration}; use tokio_postgres::{Client, NoTls}; pub static CLIENT_TABLE_NAME: &'static str = "_lantern_cloud_client1"; pub static CLIENT_TABLE_NAME_2: &'static str = "_lantern_cloud_client2"; + #[cfg(not(feature = "autotune"))] + pub static AUTOTUNE_JOB_TABLE_DEF: &'static str = "(id INT)"; + #[cfg(feature = "autotune")] + pub static AUTOTUNE_JOB_TABLE_DEF: &'static str = + crate::daemon::autotune_jobs::JOB_TABLE_DEFINITION; + #[cfg(not(feature = "external-index"))] + pub static EXTERNAL_INDEX_JOB_TABLE_DEF: &'static str = "(id INT)"; + #[cfg(feature = "external-index")] + pub static EXTERNAL_INDEX_JOB_TABLE_DEF: &'static str = + crate::daemon::external_index_jobs::JOB_TABLE_DEFINITION; + #[cfg(not(feature = "embeddings"))] + pub static EMBEDDING_JOB_TABLE_DEF: &'static str = "(id INT)"; + #[cfg(feature = "embeddings")] + pub static EMBEDDING_JOB_TABLE_DEF: &'static str = + crate::daemon::embedding_jobs::JOB_TABLE_DEFINITION; + async fn drop_db(client: &mut Client, name: &str) -> AnyhowVoidResult { client .execute( @@ -58,9 +74,9 @@ pub mod daemon_test_utils { CREATE TABLE _lantern_extras_internal.external_index_jobs ({indexing_job_table_def}); "#, - embedding_job_table_def = daemon::embedding_jobs::JOB_TABLE_DEFINITION, - autotune_job_table_def = daemon::autotune_jobs::JOB_TABLE_DEFINITION, - indexing_job_table_def = daemon::external_index_jobs::JOB_TABLE_DEFINITION, + embedding_job_table_def = EMBEDDING_JOB_TABLE_DEF, + autotune_job_table_def = AUTOTUNE_JOB_TABLE_DEF, + indexing_job_table_def = EXTERNAL_INDEX_JOB_TABLE_DEF, )) .await?; diff --git a/lantern_cli/tests/embedding_test_with_db.rs b/lantern_cli/tests/embedding_test_with_db.rs index 5f9fc4f7..62dd1c0b 100644 --- a/lantern_cli/tests/embedding_test_with_db.rs +++ b/lantern_cli/tests/embedding_test_with_db.rs @@ -9,7 +9,7 @@ use std::{ use lantern_cli::embeddings::{self, cli::EmbeddingJobType}; use lantern_cli::embeddings::{core::Runtime, get_try_cast_fn_sql}; use lantern_cli::{daemon::embedding_jobs::FAILURE_TABLE_DEFINITION, embeddings::cli}; -use postgres::IsolationLevel; +use tokio_postgres::IsolationLevel; use tokio_postgres::{Client, NoTls}; use tokio_util::sync::CancellationToken; diff --git a/lantern_extras/Cargo.toml b/lantern_extras/Cargo.toml index ccd47c04..3cd6289b 100644 --- a/lantern_extras/Cargo.toml +++ b/lantern_extras/Cargo.toml @@ -27,7 +27,6 @@ url = "2.5" lantern_cli = { path = "../lantern_cli", default-features = false, features = [ "external-index", "embeddings", - "autotune", "daemon", ] } anyhow = "1.0.91"