Skip to content

Commit d84007c

Browse files
-Fix moment matrix calculation with only numeric
1 parent cb57d2e commit d84007c

File tree

2 files changed

+45
-30
lines changed

2 files changed

+45
-30
lines changed

R/calculate_convex_hull_moment_matrix.R

+41-29
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ interpolate_convex_hull = function(points, ch_halfspace, n_samples_per_dimension
7474
#' the model terms, weighted by whether the point
7575
#' is on the edge.
7676
#'
77-
#' @param data Candidate set
77+
#' @param candidate_set_full Candidate set
7878
#' @param formula Default `~ .`. Model formula specifying the terms.
7979
#' @param n_samples_per_dimension Default `100`. Number of samples to take per dimension when interpolating inside
8080
#' the convex hull.
@@ -99,17 +99,28 @@ gen_momentsmatrix_continuous = function(
9999
# Detect any disallowed combinations
100100
unique_vals = prod(vapply(candidate_set, \(x) {length(unique(x))}, FUN.VALUE = integer(1)))
101101
any_disallowed = unique_vals != nrow(candidate_set)
102+
M_acc = NA
103+
total_weight = 0
102104

103105
# Simple if all numeric: just integrate over the region.
104106
if(length(factor_cols) == 0) {
105107
sub_candidate_set = as.matrix(candidate_set)
106-
ch = convhull_halfspace(sub_candidate_set)
107-
if (ch$volume <= 0) {
108-
next
108+
if(ncol(sub_candidate_set) == 1) {
109+
new_pts_ch = matrix(seq(min(sub_candidate_set),
110+
max(sub_candidate_set),
111+
length.out = n_samples_per_dimension),ncol=1)
112+
interp_ch = list()
113+
interp_ch$on_edge = rep(FALSE, nrow(new_pts_ch))
114+
vol = max(sub_candidate_set) - min(sub_candidate_set)
115+
} else {
116+
ch = convhull_halfspace(sub_candidate_set)
117+
if (ch$volume <= 0) {
118+
next
119+
}
120+
vol = ch$volume
121+
interp_ch = interpolate_convex_hull(as.matrix(sub_candidate_set), ch, n_samples_per_dimension = n_samples_per_dimension)
122+
new_pts_ch = interp_ch$data
109123
}
110-
vol = ch$volume
111-
interp_ch = interpolate_convex_hull(as.matrix(sub_candidate_set), ch, n_samples_per_dimension = n_samples_per_dimension)
112-
new_pts_ch = interp_ch$data
113124

114125
colnames(new_pts_ch) = numeric_cols
115126
interp_df = as.data.frame(new_pts_ch)
@@ -121,23 +132,16 @@ gen_momentsmatrix_continuous = function(
121132
w[interp_ch$on_edge] = 0.5
122133
# average subregion moment
123134
Xsub_w = apply(Xsub,2,\(x) x*sqrt(w))
124-
# M_sub = crossprod(Xsub) / sum(w)
125135

126-
M_sub = crossprod(Xsub_w) / sum(w)
136+
M = crossprod(Xsub_w) / sum(w)
127137

128-
# Weighted accumulation
129-
if (is.null(M_acc)) {
130-
M_acc = vol * M_sub
131-
} else {
132-
M_acc = M_acc + vol * M_sub
133-
}
134-
total_weight = total_weight + vol
135138
#Scale by the intercept
136139
if(colnames(M)[1] == "(Intercept)") {
137140
M = M / M[1,1]
138141
}
139142
return(M)
140143
} else {
144+
M_acc = NA
141145
# For categorical factors with disallowed combinations, we need to account for the
142146
# reduced domain of the integral. We'll calculate a moment matrix as above for each
143147
# factor level combination, weigh it by the total number of points, and sum it. That
@@ -155,7 +159,6 @@ gen_momentsmatrix_continuous = function(
155159
}
156160

157161
# We'll accumulate a weighted sum of sub-matrices
158-
M_acc = NULL
159162
total_weight = 0
160163

161164
for (r in seq_len(nrow(unique_combos))) {
@@ -167,20 +170,28 @@ gen_momentsmatrix_continuous = function(
167170
is_match = is_match & (candidate_set[[fc]] == combo_row[[fc]])
168171
}
169172
sub_candidate_set = candidate_set[is_match, , drop=FALSE]
170-
sub_candidate_set = sub_candidate_set[,is_numeric_col]
173+
sub_candidate_set = sub_candidate_set[,is_numeric_col, drop = FALSE]
171174
# If no rows => disallowed or doesn't appear => skip
172-
if (!nrow(sub_candidate_set)) {
175+
if (nrow(sub_candidate_set) == 0) {
173176
next
174177
}
175-
176-
# Calculate the convex hull and sample points
177-
ch = convhull_halfspace(sub_candidate_set)
178-
if (ch$volume <= 0) {
179-
next
178+
if(ncol(sub_candidate_set) == 1) {
179+
new_pts_ch = matrix(seq(min(sub_candidate_set),
180+
max(sub_candidate_set),
181+
length.out = n_samples_per_dimension),ncol=1)
182+
interp_ch = list()
183+
interp_ch$on_edge = rep(FALSE, nrow(new_pts_ch))
184+
vol = max(sub_candidate_set) - min(sub_candidate_set)
185+
} else {
186+
ch = convhull_halfspace(sub_candidate_set)
187+
if (ch$volume <= 0) {
188+
next
189+
}
190+
vol = ch$volume
191+
interp_ch = interpolate_convex_hull(as.matrix(sub_candidate_set), ch,
192+
n_samples_per_dimension = n_samples_per_dimension)
193+
new_pts_ch = interp_ch$data
180194
}
181-
vol = ch$volume
182-
interp_ch = interpolate_convex_hull(as.matrix(sub_candidate_set), ch, n_samples_per_dimension = n_samples_per_dimension)
183-
new_pts_ch = interp_ch$data
184195

185196
colnames(new_pts_ch) = numeric_cols
186197
interp_df = as.data.frame(new_pts_ch)
@@ -191,7 +202,8 @@ gen_momentsmatrix_continuous = function(
191202
}
192203

193204
# Now build model matrix
194-
Xsub = model.matrix(formula, data = interp_df, contrasts.arg = get_contrasts_from_candset(candidate_set))
205+
Xsub = model.matrix(formula, data = interp_df,
206+
contrasts.arg = get_contrasts_from_candset(candidate_set))
195207

196208
w = rep(1, nrow(Xsub))
197209
w[interp_ch$on_edge] = 0.5
@@ -202,7 +214,7 @@ gen_momentsmatrix_continuous = function(
202214
M_sub = crossprod(Xsub_w) / sum(w)
203215

204216
# Weighted accumulation
205-
if (is.null(M_acc)) {
217+
if (all(is.na(M_acc))) {
206218
M_acc = vol * M_sub
207219
} else {
208220
M_acc = M_acc + vol * M_sub

R/gen_design.R

+4-1
Original file line numberDiff line numberDiff line change
@@ -924,8 +924,11 @@ gen_design = function(candidateset, model, trials,
924924
if(all(classvector)) {
925925
mm = gen_momentsmatrix(factors, levelvector, classvector)
926926
} else {
927+
model_terms_cs = rownames(attr(terms.formula(model),"factors"))
928+
col_in_model = colnames(candidateset) %in% model_terms_cs
929+
candidate_set_mm = candidatesetnormalized[,col_in_model,drop=FALSE]
927930
mm = gen_momentsmatrix_continuous(formula = model,
928-
data = candidatesetnormalized,
931+
candidate_set = candidate_set_mm,
929932
n_samples_per_dimension = 100)
930933
}
931934
if (!parallel) {

0 commit comments

Comments
 (0)