Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

validation tools for solver settings #113

Merged
merged 2 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/julia/ClarabelRs/src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ mutable struct Solver{T <: Float64} <: Clarabel.AbstractSolver{Float64}
ptr:: Ptr{Cvoid}

function Solver{T}(ptr) where T

if ptr == C_NULL
throw(ErrorException("Solver constructor failed"))
end

obj = new(ptr)
finalizer(solver_drop_jlrs,obj)
return obj
Expand Down
11 changes: 9 additions & 2 deletions src/julia/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,18 @@ pub(crate) extern "C" fn solver_new_jlrs(
let b = Vec::from(b);

let cones = ccall_arrays_to_cones(jlcones);

let settings = settings_from_json(json_settings);

let solver = DefaultSolver::new(&P, &q, &A, &b, &cones, settings);
// manually validate settings from Julia side
match settings.validate() {
Ok(_) => (),
Err(e) => {
println!("Invalid settings: {}", e);
return std::ptr::null_mut();
}
};

let solver = DefaultSolver::new(&P, &q, &A, &b, &cones, settings);
to_ptr(Box::new(solver))
}

Expand Down
15 changes: 12 additions & 3 deletions src/python/impl_default_py.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::solver::{
};
use num_derive::ToPrimitive;
use num_traits::ToPrimitive;
use pyo3::exceptions::PyException;
use pyo3::prelude::*;
use std::fmt::Write;

Expand Down Expand Up @@ -401,12 +402,20 @@ impl PyDefaultSolver {
b: Vec<f64>,
cones: Vec<PySupportedCone>,
settings: PyDefaultSettings,
) -> Self {
) -> PyResult<Self> {
let cones = _py_to_native_cones(cones);
let settings = settings.to_internal();
let solver = DefaultSolver::new(&P, &q, &A, &b, &cones, settings);

Self { inner: solver }
//manually validate settings from Python side
match settings.validate() {
Ok(_) => (),
Err(e) => {
return Err(PyException::new_err(format!("Invalid settings: {}", e)));
}
}

let solver = DefaultSolver::new(&P, &q, &A, &b, &cones, settings);
Ok(Self { inner: solver })
}

fn solve(&mut self) -> PyDefaultSolution {
Expand Down
110 changes: 110 additions & 0 deletions src/solver/implementations/default/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use serde::{Deserialize, Serialize};
/// Standard-form solver type implementing the [`Settings`](crate::solver::core::traits::Settings) trait

#[derive(Builder, Debug, Clone)]
#[builder(build_fn(validate = "Self::validate"))]
#[cfg_attr(feature = "julia", derive(Serialize, Deserialize))]
pub struct DefaultSettings<T: FloatT> {
///maximum number of iterations
Expand Down Expand Up @@ -204,3 +205,112 @@ where
self
}
}

// pre build checker (for auto-validation when using the builder)

/// Automatic pre-build settings validation
impl<T> DefaultSettingsBuilder<T>
where
T: FloatT,
{
pub fn validate(&self) -> Result<(), String> {
// check that the direct solve method is valid
if let Some(ref direct_solve_method) = self.direct_solve_method {
validate_direct_solve_method(direct_solve_method.as_str())?;
}

// check that the chordal decomposition merge method is valid
#[cfg(feature = "sdp")]
if let Some(ref chordal_decomposition_merge_method) =
self.chordal_decomposition_merge_method
{
validate_chordal_decomposition_merge_method(
chordal_decomposition_merge_method.as_str(),
)?;
}

Ok(())
}
}

// post build checker (for ad-hoc validation, e.g. when passing from python/Julia)
// this is not used directly in the solver, but can be called manually by the user

/// Manual post-build settings validation
impl<T> DefaultSettings<T>
where
T: FloatT,
{
pub fn validate(&self) -> Result<(), String> {
validate_direct_solve_method(&self.direct_solve_method)?;

// check that the chordal decomposition merge method is valid
#[cfg(feature = "sdp")]
validate_chordal_decomposition_merge_method(&self.chordal_decomposition_merge_method)?;

Ok(())
}
}

// ---------------------------------------------------------
// individual validation functions go here
// ---------------------------------------------------------

fn validate_direct_solve_method(direct_solve_method: &str) -> Result<(), String> {
match direct_solve_method {
"qdldl" => Ok(()),
#[cfg(feature = "faer-sparse")]
"faer" => Ok(()),
_ => Err(format!(
"Invalid direct_solve_method: {:?}",
direct_solve_method
)),
}
}

#[cfg(feature = "sdp")]
fn validate_chordal_decomposition_merge_method(
chordal_decomposition_merge_method: &str,
) -> Result<(), String> {
match chordal_decomposition_merge_method {
"none" => Ok(()),
"parent_child" => Ok(()),
"clique_graph" => Ok(()),
_ => Err(format!(
"Invalid chordal_decomposition_merge_method: {}",
chordal_decomposition_merge_method
)),
}
}

#[test]
fn test_settings_validate() {
// all standard settings
DefaultSettingsBuilder::<f64>::default().build().unwrap();

// fail on unknown direct solve method
assert!(DefaultSettingsBuilder::<f64>::default()
.direct_solve_method("foo".to_string())
.build()
.is_err());

// fail on solve options in disabled feature
let builder = DefaultSettingsBuilder::<f64>::default()
.direct_solve_method("faer".to_string())
.build();
cfg_if::cfg_if! {
if #[cfg(feature = "faer-sparse")] {
assert!(build.is_ok());
}
else {
assert!(builder.is_err());
}
}

#[cfg(feature = "sdp")]
// fail on unknown chordal decomposition merge method
assert!(DefaultSettingsBuilder::<f64>::default()
.chordal_decomposition_merge_method("foo".to_string())
.build()
.is_err());
}