Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

python data updates. Fixes #75 #156

Merged
merged 1 commit into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions examples/python/example_qp_update.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import clarabel
import numpy as np
from scipy import sparse

# Define problem data
P = sparse.csc_matrix([[6., 0.], [0., 4.]])
P = sparse.triu(P).tocsc()

q = np.array([-1., -4.])

A = sparse.csc_matrix(
[[1., -2.], # <-- LHS of equality constraint (lower bound)
[1., 0.], # <-- LHS of inequality constraint (upper bound)
[0., 1.], # <-- LHS of inequality constraint (upper bound)
[-1., 0.], # <-- LHS of inequality constraint (lower bound)
[0., -1.]]) # <-- LHS of inequality constraint (lower bound)

b = np.array([0., 1., 1., 1., 1.])

cones = [clarabel.ZeroConeT(1), clarabel.NonnegativeConeT(4)]
settings = clarabel.DefaultSettings()
settings.presolve_enable = False

solver = clarabel.DefaultSolver(P, q, A, b, cones, settings)

# complete vector data overwrite
qnew = np.array([0., 0.])

# partial vector data update
bv = np.array([0., 1.])
bi = np.array([1, 2])
bnew = (bi, bv)

# complete matrix data overwrite
Pnew = sparse.csc_matrix([[3., 0.], [0., 4.]]).tocsc()

# complete matrix data update (vector of nonzero values)
# NB: tuple of partial updates also works
Anew = A.data.copy()
Anew[1] = 2.

solver.update(q=qnew, P=Pnew, b=bnew, A=Anew)
3 changes: 2 additions & 1 deletion src/python/cscmatrix_py.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use pyo3::prelude::*;

//We can't implement the foreign trait FromPyObject directly on CscMatrix
//since it is outside the crate, so put a dummy wrapper around it here.
#[pyclass]
pub struct PyCscMatrix(CscMatrix<f64>);

impl Deref for PyCscMatrix {
Expand All @@ -29,7 +30,7 @@ impl<'a> FromPyObject<'a> for PyCscMatrix {

let mut mat = CscMatrix::new(shape[0], shape[1], colptr, rowval, nzval);

// if the python object was non in standard format, force the rust
// if the python object was not in standard format, force the rust
// object to still be nicely formatted
let is_canonical: bool = obj.getattr("has_canonical_format")?.extract()?;

Expand Down
135 changes: 127 additions & 8 deletions src/python/impl_default_py.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,18 @@
#![allow(non_snake_case)]

use super::*;
use crate::solver::{
core::{
traits::{InfoPrint, Settings},
IPSolver, SolverStatus,
use crate::{
algebra::CscMatrix,
solver::{
core::{
traits::{InfoPrint, Settings},
IPSolver, SolverStatus,
},
implementations::default::*,
SolverJSONReadWrite,
},
implementations::default::*,
SolverJSONReadWrite,
};
use pyo3::exceptions::PyException;
use pyo3::prelude::*;
use pyo3::{exceptions::PyException, prelude::*, types::PyDict};
use std::fmt::Write;

//Here we end up repeating several datatypes defined internally
Expand Down Expand Up @@ -389,6 +391,11 @@ impl PyDefaultSettings {
// ----------------------------------
// Solver
// ----------------------------------
impl From<DataUpdateError> for PyErr {
fn from(err: DataUpdateError) -> Self {
PyException::new_err(err.to_string())
}
}

#[pyclass(name = "DefaultSolver")]
pub struct PyDefaultSolver {
Expand Down Expand Up @@ -457,6 +464,118 @@ impl PyDefaultSolver {
self.inner.write_to_file(&mut file)?;
Ok(())
}

#[pyo3(signature = (**kwds))]
fn update(&mut self, kwds: Option<&Bound<'_, PyDict>>) -> PyResult<()> {
for (key, value) in kwds.unwrap().iter() {
let key = key.extract::<String>()?;

match key.as_str() {
"P" => match _py_to_matrix_update(value) {
Some(PyMatrixUpdateData::Matrix(M)) => {
self.inner.update_P(&M)?;
}
Some(PyMatrixUpdateData::Vector(v)) => {
self.inner.update_P(&v)?;
}
Some(PyMatrixUpdateData::Tuple((indices, values))) => {
self.inner.update_P(&(indices, values))?;
}
None => {
return Err(PyException::new_err("Invalid P update data"));
}
},
"A" => match _py_to_matrix_update(value) {
Some(PyMatrixUpdateData::Matrix(M)) => {
self.inner.update_A(&M)?;
}
Some(PyMatrixUpdateData::Vector(v)) => {
self.inner.update_A(&v)?;
}
Some(PyMatrixUpdateData::Tuple((indices, values))) => {
self.inner.update_A(&(indices, values))?;
}
None => {
return Err(PyException::new_err("Invalid A update data"));
}
},
"q" => match _py_to_vector_update(value) {
Some(PyVectorUpdateData::Vector(v)) => {
self.inner.update_q(&v)?;
}
Some(PyVectorUpdateData::Tuple((indices, values))) => {
self.inner.update_q(&(indices, values))?;
}
None => {
return Err(PyException::new_err("Invalid q update data"));
}
},
"b" => match _py_to_vector_update(value) {
Some(PyVectorUpdateData::Vector(v)) => {
self.inner.update_b(&v)?;
}
Some(PyVectorUpdateData::Tuple((indices, values))) => {
self.inner.update_b(&(indices, values))?;
}
None => {
return Err(PyException::new_err("Invalid b update data"));
}
},
_ => {
println!("unrecognized key: {}", key);
}
}
}
Ok(())
}
}

enum PyMatrixUpdateData {
Matrix(CscMatrix<f64>),
Vector(Vec<f64>),
Tuple((Vec<usize>, Vec<f64>)),
}

enum PyVectorUpdateData {
Vector(Vec<f64>),
Tuple((Vec<usize>, Vec<f64>)),
}

impl From<PyVectorUpdateData> for PyMatrixUpdateData {
fn from(val: PyVectorUpdateData) -> Self {
match val {
PyVectorUpdateData::Vector(v) => PyMatrixUpdateData::Vector(v),
PyVectorUpdateData::Tuple((indices, values)) => {
PyMatrixUpdateData::Tuple((indices, values))
}
}
}
}

fn _py_to_matrix_update(arg: Bound<'_, PyAny>) -> Option<PyMatrixUpdateData> {
// try converting to a csc matrix
let csc: Option<CscMatrix<f64>> = arg.extract::<PyCscMatrix>().ok().map(|x| x.into());
if let Some(csc) = csc {
return Some(PyMatrixUpdateData::Matrix(csc));
}

// try as vector data
_py_to_vector_update(arg).map(|x| x.into())
}

fn _py_to_vector_update(arg: Bound<'_, PyAny>) -> Option<PyVectorUpdateData> {
// try converting to a complete data vector
let values: Option<Vec<f64>> = arg.extract().ok();
if let Some(values) = values {
return Some(PyVectorUpdateData::Vector(values));
}

// try converting to a tuple of data and index vectors
let tuple = arg.extract::<(Vec<usize>, Vec<f64>)>().ok();
if let Some(tuple) = tuple {
return Some(PyVectorUpdateData::Tuple(tuple));
}
None
}

#[pyfunction(name = "read_from_file")]
Expand Down
33 changes: 32 additions & 1 deletion src/solver/implementations/default/data_updating.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#![allow(non_snake_case)]
use super::DefaultSolver;
use crate::algebra::*;
use core::iter::Zip;
use core::iter::{zip, Zip};
use core::slice::Iter;
use thiserror::Error;

Expand Down Expand Up @@ -250,6 +250,22 @@ where
}
}

impl<T> MatrixProblemDataUpdate<T> for (Vec<usize>, Vec<T>)
where
T: FloatT,
{
fn update_matrix(
&self,
M: &mut CscMatrix<T>,
lscale: &[T],
rscale: &[T],
cscale: Option<T>,
) -> Result<(), SparseFormatError> {
let z = zip(self.0.iter(), self.1.iter());
z.update_matrix(M, lscale, rscale, cscale)
}
}

impl<T> VectorProblemDataUpdate<T> for [T]
where
T: FloatT,
Expand Down Expand Up @@ -331,3 +347,18 @@ where
Ok(())
}
}

impl<T> VectorProblemDataUpdate<T> for (Vec<usize>, Vec<T>)
where
T: FloatT,
{
fn update_vector(
&self,
v: &mut [T],
vscale: &[T],
cscale: Option<T>,
) -> Result<(), SparseFormatError> {
let z = zip(self.0.iter(), self.1.iter());
z.update_vector(v, vscale, cscale)
}
}
2 changes: 1 addition & 1 deletion tests/data_updating.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ fn test_update_P_matrix_form() {
// original problem
let (P, q, A, b, cones, settings) = updating_test_data();
let mut solver1 = DefaultSolver::new(&P, &q, &A, &b, &cones, settings.clone());
//solver1.solve();
solver1.solve();

// change P and re-solve
let mut P2 = P.to_triu();
Expand Down