Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature-gate embedding,autotune,external-indexing jobs in daemon #349

Merged
merged 1 commit into from
Nov 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 20 additions & 21 deletions lantern_cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ path = "src/main.rs"
[dependencies]
clap = { version = "4.5.20", features = ["derive"] }
anyhow = "1.0.91"
postgres = "0.19.9"
postgres = { version = "0.19.9", optional = true }
rand = "0.8.5"
linfa-clustering = { version = "0.7.0", features = ["ndarray-linalg"], optional = true }
linfa = {version = "0.7.0", optional = true}
ndarray = { version = "0.15.6", features = ["rayon"] }
ndarray = { version = "0.15.6", features = ["rayon"], optional = true }
rayon = { version="1.10.0", optional = true }
md5 = {version="0.7.0", optional = true }
serde = { version = "1.0", features = ["derive"] }
Expand All @@ -27,19 +27,18 @@ futures = "0.3.31"
tokio = { version = "1.41.0", features = ["full"] }
lazy_static = "1.5.0"
itertools = "0.13.0"
csv = "1.3.0"
sysinfo = "0.32.0"
tiktoken-rs = "0.6.0"
url = "2.5"
num_cpus = "1.16.0"
ort = { version = "1.16.3", features = ["load-dynamic", "cuda", "openvino"] }
tokenizers = { version = "0.20.1", features = ["default"] }
image = { version = "0.25.4", features = ["jpeg", "png", "webp" ]}
nvml-wrapper = "0.10.0"
strum = { version = "0.26", features = ["derive"] }
regex = "1.11.1"
postgres-types = { version = "0.2.8", features = ["derive"] }
usearch = { git = "https://github.com/Ngalstyan4/usearch.git", rev = "aa4f91d21230fd611b6c7741fa06be8c20acc9a9" }
sysinfo = { version = "0.32.0", optional = true }
tiktoken-rs = { version = "0.6.0", optional = true }
url = { version = "2.5", optional = true }
num_cpus = { version = "1.16.0", optional = true }
ort = { version = "1.16.3", features = ["load-dynamic", "cuda", "openvino"], optional = true }
tokenizers = { version = "0.20.1", features = ["default"], optional = true }
image = { version = "0.25.4", features = ["jpeg", "png", "webp" ], optional = true }
nvml-wrapper = { version = "0.10.0", optional = true }
strum = { version = "0.26", features = ["derive"], optional = true }
regex = { version = "1.11.1", optional = true }
postgres-types = { version = "0.2.8", features = ["derive"], optional = true }
usearch = { git = "https://github.com/Ngalstyan4/usearch.git", rev = "aa4f91d21230fd611b6c7741fa06be8c20acc9a9", optional = true }
actix-web = { version = "4.9.0", optional = true }
env_logger = { version = "0.11.5", optional = true }
deadpool-postgres = { version = "0.14.0", optional = true }
Expand All @@ -53,18 +52,18 @@ bitvec = { version="1.0.1", optional=true }
rustls = { version="0.23.16", optional=true }
rustls-pemfile = { version="2.2.0", optional=true }
glob = { version="0.3.1", optional=true }
reqwest = { version = "0.12.9", default-features = false, features = ["json", "blocking", "rustls-tls"] }
reqwest = { version = "0.12.9", default-features = false, features = ["json", "blocking", "rustls-tls"], optional = true }

[features]
default = ["cli", "daemon", "http-server", "autotune", "pq", "external-index", "external-index-server", "embeddings"]
daemon = ["dep:tokio-postgres"]
http-server = ["dep:deadpool-postgres", "dep:deadpool", "dep:bytes", "dep:utoipa", "dep:utoipa-swagger-ui", "dep:actix-web", "dep:tokio-postgres", "dep:env_logger", "dep:actix-web-httpauth"]
http-server = ["dep:deadpool-postgres", "dep:deadpool", "dep:bytes", "dep:utoipa", "dep:utoipa-swagger-ui", "dep:actix-web", "dep:tokio-postgres", "dep:env_logger", "dep:actix-web-httpauth", "dep:regex"]
autotune = []
pq = ["dep:gcp_auth", "dep:linfa", "dep:linfa-clustering", "dep:md5", "dep:rayon"]
pq = ["dep:gcp_auth", "dep:linfa", "dep:linfa-clustering", "dep:md5", "dep:rayon", "dep:reqwest", "dep:postgres", "dep:ndarray"]
cli = []
external-index = []
external-index-server = ["dep:bitvec", "dep:rustls", "dep:rustls-pemfile", "dep:glob"]
embeddings = ["dep:bytes"]
external-index = ["dep:postgres-types", "dep:usearch", "dep:postgres"]
external-index-server = ["dep:bitvec", "dep:rustls", "dep:rustls-pemfile", "dep:glob", "dep:usearch"]
embeddings = ["dep:bytes", "dep:sysinfo", "dep:tiktoken-rs", "dep:url", "dep:num_cpus", "dep:ort", "dep:tokenizers", "dep:image", "dep:nvml-wrapper", "dep:strum", "dep:regex", "dep:reqwest", "dep:ndarray"]

[lib]
doctest = false
12 changes: 6 additions & 6 deletions lantern_cli/src/cli.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use super::daemon::cli::DaemonArgs;
use super::embeddings::cli::{EmbeddingArgs, MeasureModelSpeedArgs, ShowModelsArgs};
use super::external_index::cli::CreateIndexArgs;
use super::http_server::cli::HttpServerArgs;
use super::index_autotune::cli::IndexAutotuneArgs;
use super::pq::cli::PQArgs;
use clap::{Parser, Subcommand};
use lantern_cli::daemon::cli::DaemonArgs;
use lantern_cli::embeddings::cli::{EmbeddingArgs, MeasureModelSpeedArgs, ShowModelsArgs};
use lantern_cli::external_index::cli::CreateIndexArgs;
use lantern_cli::external_index::cli::IndexServerArgs;
use lantern_cli::http_server::cli::HttpServerArgs;
use lantern_cli::index_autotune::cli::IndexAutotuneArgs;
use lantern_cli::pq::cli::PQArgs;

#[derive(Subcommand, Debug)]
pub enum Commands {
Expand Down
41 changes: 38 additions & 3 deletions lantern_cli/src/daemon/autotune_jobs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use super::helpers::{
index_job_update_processor, remove_job_handle, set_job_handle, startup_hook,
};
use super::types::{
AutotuneJob, AutotuneProcessorArgs, JobEvent, JobEventHandlersMap, JobInsertNotification,
JobRunArgs, JobUpdateNotification,
AutotuneProcessorArgs, JobEvent, JobEventHandlersMap, JobInsertNotification, JobRunArgs,
JobUpdateNotification,
};
use crate::daemon::helpers::anyhow_wrap_connection;
use crate::external_index::cli::UMetricKind;
Expand All @@ -20,7 +20,7 @@ use tokio::sync::{
mpsc,
mpsc::{Receiver, Sender, UnboundedReceiver, UnboundedSender},
};
use tokio_postgres::{Client, NoTls};
use tokio_postgres::{Client, NoTls, Row};
use tokio_util::sync::CancellationToken;

pub const JOB_TABLE_DEFINITION: &'static str = r#"
Expand Down Expand Up @@ -55,6 +55,41 @@ latency DOUBLE PRECISION NOT NULL,
build_time DOUBLE PRECISION NULL
"#;

/// A single index-autotune job loaded from the daemon's autotune job table.
///
/// Built from a database row by [`AutotuneJob::new`]; field meanings below
/// reflect the columns read there.
#[derive(Debug)]
pub struct AutotuneJob {
    /// Primary key of the job row (`id` column).
    pub id: i32,
    /// Always set to `true` by [`AutotuneJob::new`]; presumably flipped
    /// elsewhere for re-runs — TODO confirm against the daemon scheduler.
    pub is_init: bool,
    /// Connection string for the target database (supplied by the caller,
    /// not read from the row).
    pub db_uri: String,
    /// Schema of the table to autotune (`schema` column).
    pub schema: String,
    /// Table to autotune (`table` column).
    pub table: String,
    /// Column the index is built over (`column` column).
    pub column: String,
    /// Distance metric name (`metric_kind` column).
    pub metric_kind: String,
    /// Optional embedding model name (`model` column, nullable).
    pub model_name: Option<String>,
    /// Target recall to tune for (`target_recall` column).
    pub recall: f64,
    /// `k` for recall evaluation (`k` column, narrowed from i32).
    pub k: u16,
    /// Number of rows to sample (`sample_size` column, widened from i32).
    pub sample_size: usize,
    /// Whether to create the index after tuning (`create_index` column).
    pub create_index: bool,
}

impl AutotuneJob {
    /// Builds an [`AutotuneJob`] from a row of the autotune job table.
    ///
    /// `db_uri` is stored as-is so the job can later open its own connection.
    /// `is_init` is always initialized to `true` here.
    ///
    /// # Panics
    ///
    /// Panics if a required column is missing or has an unexpected type
    /// (inherited from `Row::get`), or if `k` does not fit in `u16` or
    /// `sample_size` is negative.
    pub fn new(row: Row, db_uri: &str) -> AutotuneJob {
        // Checked conversions instead of `as`: a bad DB value (negative or
        // out-of-range) fails loudly rather than silently wrapping.
        let k: u16 = row
            .get::<&str, i32>("k")
            .try_into()
            .expect("autotune job column `k` must fit in u16");
        let sample_size: usize = row
            .get::<&str, i32>("sample_size")
            .try_into()
            .expect("autotune job column `sample_size` must be non-negative");

        Self {
            id: row.get::<&str, i32>("id"),
            db_uri: db_uri.to_owned(),
            schema: row.get::<&str, String>("schema"),
            table: row.get::<&str, String>("table"),
            column: row.get::<&str, String>("column"),
            metric_kind: row.get::<&str, String>("metric_kind"),
            model_name: row.get::<&str, Option<String>>("model"),
            recall: row.get::<&str, f64>("target_recall"),
            k,
            sample_size,
            create_index: row.get::<&str, bool>("create_index"),
            is_init: true,
        }
    }
}

pub async fn autotune_job_processor(
mut rx: Receiver<AutotuneProcessorArgs>,
cancel_token: CancellationToken,
Expand Down
113 changes: 107 additions & 6 deletions lantern_cli/src/daemon/embedding_jobs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,19 @@ use super::helpers::{
remove_job_handle, schedule_job_retry, set_job_handle, startup_hook,
};
use super::types::{
ClientJobsMap, EmbeddingJob, EmbeddingProcessorArgs, JobBatchingHashMap, JobEvent,
JobEventHandlersMap, JobInsertNotification, JobRunArgs, JobUpdateNotification,
ClientJobsMap, EmbeddingProcessorArgs, JobBatchingHashMap, JobEvent, JobEventHandlersMap,
JobInsertNotification, JobRunArgs, JobUpdateNotification,
};
use crate::daemon::helpers::anyhow_wrap_connection;
use crate::embeddings::cli::{EmbeddingArgs, EmbeddingJobType};
use crate::embeddings::cli::{EmbeddingArgs, EmbeddingJobType, Runtime};
use crate::embeddings::core::utils::get_clean_model_name;
use crate::embeddings::get_default_batch_size;
use crate::logger::Logger;
use crate::utils::{get_common_embedding_ignore_filters, get_full_table_name, quote_ident};
use crate::utils::{
get_common_embedding_ignore_filters, get_full_table_name, quote_ident, quote_literal,
};
use crate::{embeddings, types::*};
use itertools::Itertools;
use std::collections::HashMap;
use std::ops::Deref;
use std::path::Path;
Expand All @@ -22,8 +26,7 @@ use std::time::SystemTime;
use tokio::fs;
use tokio::sync::mpsc::{Receiver, Sender, UnboundedReceiver, UnboundedSender};
use tokio::sync::{mpsc, Mutex, RwLock};
use tokio_postgres::types::ToSql;
use tokio_postgres::{Client, NoTls};
use tokio_postgres::{types::ToSql, Client, NoTls, Row};
use tokio_util::sync::CancellationToken;

pub const JOB_TABLE_DEFINITION: &'static str = r#"
Expand Down Expand Up @@ -71,6 +74,104 @@ const EMB_USAGE_TABLE_NAME: &'static str = "embedding_usage_info";
const EMB_FAILURE_TABLE_NAME: &'static str = "embedding_failure_info";
const EMB_LOCK_TABLE_NAME: &'static str = "_lantern_emb_job_locks";

/// A single embedding-generation job loaded from the daemon's embedding job
/// table.
///
/// Built from a database row by [`EmbeddingJob::new`]; `filter` and `row_ids`
/// start out as `None` and are narrowed later via the `set_*` helpers.
#[derive(Debug, Clone)]
pub struct EmbeddingJob {
    /// Primary key of the job row (`id` column).
    pub id: i32,
    /// Set to `true` by [`EmbeddingJob::new`]; toggled via `set_is_init`.
    pub is_init: bool,
    /// Connection string for the target database (supplied by the caller).
    pub db_uri: String,
    /// Schema of the source table (`schema` column).
    pub schema: String,
    /// Source table name (`table` column).
    pub table: String,
    /// Source text/data column (`column` column).
    pub column: String,
    /// Name of the target table's key column (`pk` column) — presumably the
    /// primary key used by `set_id_filter`; confirm against callers.
    pub pk: String,
    /// Optional SQL WHERE fragment restricting which rows are processed;
    /// set via `set_filter` / `set_id_filter` / `set_ctid_filter`.
    pub filter: Option<String>,
    /// Optional job label (`label` column, nullable).
    pub label: Option<String>,
    /// Parsed from the `job_type` column (defaults to "embedding").
    pub job_type: EmbeddingJobType,
    /// SQL type of the destination column (`column_type`, default "REAL[]").
    pub column_type: String,
    /// Destination column for generated embeddings (`dst_column` column).
    pub out_column: String,
    /// Model name, normalized via `get_clean_model_name`.
    pub model: String,
    /// Runtime parameters as a JSON string; synthesized for the Ort runtime,
    /// otherwise taken from the `runtime_params` column (default "{}").
    pub runtime_params: String,
    /// Embedding runtime parsed from the `runtime` column (default "ort").
    pub runtime: Runtime,
    /// Optional batch-size override (`batch_size` column, nullable).
    pub batch_size: Option<usize>,
    /// Optional explicit set of row ids to process; set via `set_row_ids`.
    pub row_ids: Option<Vec<String>>,
}

impl EmbeddingJob {
    /// Builds an [`EmbeddingJob`] from a row of the embedding job table.
    ///
    /// For the `ort` runtime the runtime params are synthesized to point at
    /// `data_path` (local model storage); for other runtimes the
    /// `runtime_params` column is used verbatim, defaulting to `{}`.
    ///
    /// # Errors
    ///
    /// Returns an error if the `runtime` or `job_type` column does not parse
    /// into [`Runtime`] / [`EmbeddingJobType`].
    ///
    /// # Panics
    ///
    /// Panics if a required column is missing or mistyped (from `Row::get`),
    /// or if `batch_size` is negative.
    pub fn new(row: Row, data_path: &str, db_uri: &str) -> Result<EmbeddingJob, anyhow::Error> {
        // A missing runtime column defaults to the local ONNX runtime.
        let runtime = Runtime::try_from(row.get::<&str, Option<&str>>("runtime").unwrap_or("ort"))?;
        let runtime_params = if runtime == Runtime::Ort {
            format!(r#"{{ "data_path": "{data_path}" }}"#)
        } else {
            // `unwrap_or_else` avoids allocating the default when the column is set.
            row.get::<&str, Option<String>>("runtime_params")
                .unwrap_or_else(|| "{}".to_owned())
        };

        // Checked conversion: a negative batch_size should fail loudly instead
        // of silently wrapping to a huge usize.
        let batch_size = row.get::<&str, Option<i32>>("batch_size").map(|b| {
            usize::try_from(b).expect("embedding job column `batch_size` must be non-negative")
        });

        Ok(Self {
            id: row.get::<&str, i32>("id"),
            pk: row.get::<&str, String>("pk"),
            label: row.get::<&str, Option<String>>("label"),
            db_uri: db_uri.to_owned(),
            schema: row.get::<&str, String>("schema"),
            table: row.get::<&str, String>("table"),
            column: row.get::<&str, String>("column"),
            out_column: row.get::<&str, String>("dst_column"),
            model: get_clean_model_name(row.get::<&str, &str>("model"), runtime),
            runtime,
            runtime_params,
            filter: None,
            row_ids: None,
            is_init: true,
            batch_size,
            job_type: EmbeddingJobType::try_from(
                row.get::<&str, Option<&str>>("job_type")
                    .unwrap_or("embedding"),
            )?,
            column_type: row
                .get::<&str, Option<String>>("column_type")
                .unwrap_or_else(|| "REAL[]".to_owned()),
        })
    }

    /// Replaces the job's row filter with `filter` (a SQL WHERE fragment).
    pub fn set_filter(&mut self, filter: &str) {
        self.filter = Some(filter.to_owned());
    }

    /// Marks whether this job is an initial (full-table) run.
    pub fn set_is_init(&mut self, is_init: bool) {
        self.is_init = is_init;
    }

    /// Restricts the job to the given explicit set of row ids.
    pub fn set_row_ids(&mut self, row_ids: Vec<String>) {
        self.row_ids = Some(row_ids);
    }

    /// Restricts the job to rows whose (possibly relocated) ctids match
    /// `row_ids`, using `currtid2` to follow moved tuples.
    ///
    /// NOTE(review): `self.table` and the ctid strings are interpolated into
    /// SQL unquoted; they originate from the daemon's own job tables, but
    /// quoting via `quote_ident`/`quote_literal` would be safer — confirm
    /// upstream sanitization.
    #[allow(dead_code)]
    pub fn set_ctid_filter(&mut self, row_ids: &[String]) {
        let row_ctids_str = row_ids
            .iter()
            .map(|r| {
                format!(
                    "currtid2('{table_name}','{r}'::tid)",
                    table_name = &self.table
                )
            })
            .join(",");
        self.set_filter(&format!("ctid IN ({row_ctids_str})"));
    }

    /// Restricts the job to rows whose `id` is in `row_ids`, combined with
    /// the common embedding "ignore" filters for the source column.
    pub fn set_id_filter(&mut self, row_ids: &[String]) {
        let row_ctids_str = row_ids.iter().map(|s| quote_literal(s)).join(",");
        self.set_filter(&format!(
            "id IN ({row_ctids_str}) AND {common_filter}",
            common_filter = get_common_embedding_ignore_filters(&self.column)
        ));
    }
}

async fn lock_row(
client: Arc<Client>,
lock_table_name: &str,
Expand Down
37 changes: 34 additions & 3 deletions lantern_cli/src/daemon/external_index_jobs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use super::helpers::{
index_job_update_processor, startup_hook,
};
use super::types::{
ExternalIndexJob, ExternalIndexProcessorArgs, JobEvent, JobEventHandlersMap,
JobInsertNotification, JobRunArgs, JobTaskEventTx, JobUpdateNotification,
ExternalIndexProcessorArgs, JobEvent, JobEventHandlersMap, JobInsertNotification, JobRunArgs,
JobTaskEventTx, JobUpdateNotification,
};
use crate::daemon::helpers::anyhow_wrap_connection;
use crate::external_index::cli::{CreateIndexArgs, UMetricKind};
Expand All @@ -16,7 +16,7 @@ use std::sync::Arc;
use std::time::SystemTime;
use tokio::sync::mpsc::{self, Receiver, Sender, UnboundedReceiver, UnboundedSender};
use tokio::sync::RwLock;
use tokio_postgres::{Client, NoTls};
use tokio_postgres::{Client, NoTls, Row};
use tokio_util::sync::CancellationToken;

pub const JOB_TABLE_DEFINITION: &'static str = r#"
Expand All @@ -39,6 +39,37 @@ pub const JOB_TABLE_DEFINITION: &'static str = r#"
"progress" INT2 DEFAULT 0
"#;

/// A single external-index-build job loaded from the daemon's job table.
///
/// Built from a database row by [`ExternalIndexJob::new`]; field meanings
/// below reflect the columns read there.
#[derive(Debug)]
pub struct ExternalIndexJob {
    /// Primary key of the job row (`id` column).
    pub id: i32,
    /// Connection string for the target database (supplied by the caller).
    pub db_uri: String,
    /// Schema of the table to index (`schema` column).
    pub schema: String,
    /// Table to index (`table` column).
    pub table: String,
    /// Column to index (`column` column).
    pub column: String,
    /// Operator class name (`operator` column).
    pub operator_class: String,
    /// Optional explicit index name (`index` column, nullable).
    pub index_name: Option<String>,
    /// Search-time `ef` parameter (`ef` column, widened from i32) —
    /// presumably HNSW search width; confirm against index builder.
    pub ef: usize,
    /// Construction-time `ef` parameter (`efc` column, widened from i32).
    pub efc: usize,
    /// Graph connectivity parameter (`m` column, widened from i32).
    pub m: usize,
}

impl ExternalIndexJob {
    /// Builds an [`ExternalIndexJob`] from a row of the external-index job
    /// table. `db_uri` is stored as-is for the job's own connection.
    ///
    /// # Panics
    ///
    /// Panics if a required column is missing or has an unexpected type
    /// (inherited from `Row::get`), or if `ef`, `efc` or `m` is negative.
    pub fn new(row: Row, db_uri: &str) -> ExternalIndexJob {
        // Checked conversions instead of `as usize`: a negative i32 would
        // otherwise silently wrap to an enormous usize.
        let get_usize = |col: &str| -> usize {
            row.get::<&str, i32>(col)
                .try_into()
                .unwrap_or_else(|_| panic!("external index job column `{col}` must be non-negative"))
        };

        Self {
            id: row.get::<&str, i32>("id"),
            db_uri: db_uri.to_owned(),
            schema: row.get::<&str, String>("schema"),
            table: row.get::<&str, String>("table"),
            column: row.get::<&str, String>("column"),
            operator_class: row.get::<&str, String>("operator"),
            index_name: row.get::<&str, Option<String>>("index"),
            ef: get_usize("ef"),
            efc: get_usize("efc"),
            m: get_usize("m"),
        }
    }
}

async fn set_job_handle(
jobs_map: Arc<JobEventHandlersMap>,
job_id: i32,
Expand Down
Loading
Loading