Skip to content

Commit

Permalink
Streamlined MCMC prior calculations further
Browse files Browse the repository at this point in the history
  • Loading branch information
KeithJF82 committed Jun 26, 2024
1 parent 4398f7a commit b763fc1
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 77 deletions.
10 changes: 3 additions & 7 deletions R/main.R
Original file line number Diff line number Diff line change
Expand Up @@ -465,15 +465,11 @@ param_calc_enviro <- function(enviro_coeffs = c(), enviro_covar_values = c()){

assert_that(all(enviro_coeffs >= 0), msg = "All environmental coefficients must have positive values")
n_env_vars = length(enviro_covar_values)
assert_that(length(enviro_coeffs) %in% c(n_env_vars, 2*n_env_vars), msg = "Wrong number of environmental coefficients")
assert_that(length(enviro_coeffs)==2*n_env_vars, msg = "Wrong number of environmental coefficients")

output = list(FOI = NA, R0 = NA)
#output$FOI = max(0.0, sum(enviro_coeffs[c(1:n_env_vars)]*enviro_covar_values)) #Zero-maximum removed due to MCMC convergence issues
output = list()
output$FOI = sum(enviro_coeffs[c(1:n_env_vars)]*enviro_covar_values)
if(length(enviro_coeffs) == 2*n_env_vars){
#output$R0 = max(0.0, sum(enviro_coeffs[c(1:n_env_vars)+n_env_vars]*enviro_covar_values))
output$R0 = sum(enviro_coeffs[c(1:n_env_vars)+n_env_vars]*enviro_covar_values)
}
output$R0 = sum(enviro_coeffs[c(1:n_env_vars)+n_env_vars]*enviro_covar_values)

return(output)
}
Expand Down
97 changes: 44 additions & 53 deletions R/mcmc.R
Original file line number Diff line number Diff line change
Expand Up @@ -172,11 +172,6 @@ MCMC <- function(log_params_ini = c(), input_data = list(), obs_sero_data = NULL
}

return(NULL)
#Get final parameter values
# param_out = exp(log_params)
# names(param_out) = names(log_params_ini)
#
# return(param_out)
}
#-------------------------------------------------------------------------------
#' @title single_posterior_calc
Expand Down Expand Up @@ -204,48 +199,60 @@ single_posterior_calc <- function(log_params_prop = c(), input_data = list(), ob

consts=list(...)

#Get additional values and calculate associated priors
#Get additional values, calculate initial prior
vaccine_efficacy = p_rep_severe = p_rep_death = m_FOI_Brazil = 1.0
prior_add = 0
prior_like = 0
for(var_name in names(consts$add_values)){
if(var_name %in% consts$extra_estimated_params){
i = match(var_name, names(log_params_prop))
value = exp(as.numeric(log_params_prop[i]))
assign(var_name, value)
if(consts$prior_settings$type == "norm"){
# prior_add = prior_add+log(dtrunc(value, "norm", a = 0, b = 1, mean = consts$prior_settings$norm_params_mean[i],
# sd = consts$prior_settings$norm_params_sd[i]))
prior_add = prior_add+log(dtrunc(value, "norm", a = 1.0e-3, b = 1, mean = consts$prior_settings$norm_params_mean[i],
sd = consts$prior_settings$norm_params_sd[i])) #a set higher than 0 to prevent uninformative probs
prior_like = log(dtrunc(value, "norm", a = 1.0e-3, b = 1, mean = consts$prior_settings$norm_params_mean[i],
sd = consts$prior_settings$norm_params_sd[i])) #a set higher than 0 to prevent uninformative probs
} else {
if(consts$prior_settings$type == "flat"){
if(value<consts$prior_settings$log_params_min[i] || value>consts$prior_settings$log_params_max[i]){prior_add = -Inf}
if(value<consts$prior_settings$log_params_min[i] || value>consts$prior_settings$log_params_max[i]){prior_like = -Inf}
}
}
} else {assign(var_name, consts$add_values[[var_name]])}
}

n_values=2*(ncol(consts$enviro_dataenviro_data)-1)
if(consts$prior_settings$type == "flat"){
for(i in 1:n_values){
if(log_params_prop[i]<consts$prior_settings$log_params_min[i]){prior_like = -Inf}
if(log_params_prop[i]>consts$prior_settings$log_params_max[i]){prior_like = -Inf}
}
} else {
if(consts$prior_settings$type == "norm"){
values=c(1:n_values)
prior_like = prior_like + sum(dnorm(log_params_prop[values], mean = consts$prior_settings$norm_params_mean[values],
sd = consts$prior_settings$norm_params_sd[values], log = TRUE))

}
}


#If prior is finite so far, get FOI and R0 values and calculate associated prior
if(is.finite(prior_add)){
if(is.finite(prior_like)){
regions = input_data$region_labels
n_regions = length(regions)

FOI_R0_data = mcmc_FOI_R0_setup(consts$prior_settings, regions, log_params_prop, consts$enviro_data)
FOI_R0_data = mcmc_FOI_R0_setup(regions, log_params_prop, consts$enviro_data)
FOI_values = FOI_R0_data$FOI_values

for(n_region in 1:n_regions){if(substr(regions[n_region], 1, 3) == "BRA"){FOI_values[n_region] = FOI_values[n_region]*m_FOI_Brazil}}
R0_values = FOI_R0_data$R0_values
if(consts$prior_settings$type == "norm"){
prior_prop = FOI_R0_data$prior + prior_add +
prior_like = prior_like +
sum(log(dtrunc(R0_values, "norm", a = 0, b = Inf, mean = consts$prior_settings$R0_mean, sd = consts$prior_settings$R0_sd))) +
sum(log(dtrunc(FOI_values, "norm", a = 0, b = 1, mean = consts$prior_settings$FOI_mean, sd = consts$prior_settings$FOI_sd)))
} else {
prior_prop = FOI_R0_data$prior+prior_add
}
} else {prior_prop = -Inf}
}

### If prior finite, evaluate likelihood ###
if (is.finite(prior_prop)) {
if (is.finite(prior_like)) {

#Generate modelled data over all regions
dataset <- Generate_Dataset(input_data, FOI_values, R0_values, obs_sero_data, obs_case_data, vaccine_efficacy,
Expand All @@ -266,9 +273,7 @@ single_posterior_calc <- function(log_params_prop = c(), input_data = list(), ob
} else {deaths_like_values = 0}
} else {cases_like_values = deaths_like_values = 0}

# posterior = prior_prop+mean(c(sum(sero_like_values, na.rm = TRUE), sum(cases_like_values, na.rm = TRUE),
# sum(deaths_like_values, na.rm = TRUE)), na.rm = TRUE)
posterior = prior_prop+sum(sero_like_values, na.rm = TRUE)+sum(cases_like_values, na.rm = TRUE)+sum(deaths_like_values, na.rm = TRUE)
posterior = prior_like+sum(sero_like_values, na.rm = TRUE)+sum(cases_like_values, na.rm = TRUE)+sum(deaths_like_values, na.rm = TRUE)

} else {posterior = -Inf}

Expand Down Expand Up @@ -391,52 +396,41 @@ param_prop_setup <- function(log_params = c(), chain_cov = 1, adapt = 0){
#' infection and (optionally) reproduction number values either directly or from environmental covariates. Also
#' calculates related components of prior probability.
#'
#' @param prior_settings List containing settings for priors: must contain text named "type": \cr
#' If type = "zero", prior probability is always zero \cr
#' If type = "flat", prior probability is zero if log parameter values in designated ranges log_params_min and log_params_max,
#' -Inf otherwise; log_params_min and log_params_max included in prior_settings as vectors of same length as log_params_ini \cr
#' If type = "norm", prior probability is given by dnorm calculation on parameter values with settings based on vectors of values
#' in prior_settings: \cr
#' norm_params_mean and norm_params_sd (vectors of mean and standard deviation values applied to log FOI/R0
#' parameters and to actual values of additional parameters) \cr
#' + FOI_mean + FOI_sd (mean + standard deviation of computed FOI, single values) \cr
#' + R0_mean + R0_sd (mean + standard deviation of computed R0, single values) \cr
# @param prior_settings List containing settings for priors: must contain text named "type": \cr
# If type = "zero", prior probability is always zero \cr
# If type = "flat", prior probability is zero if log parameter values in designated ranges log_params_min and log_params_max,
# -Inf otherwise; log_params_min and log_params_max included in prior_settings as vectors of same length as log_params_ini \cr
# If type = "norm", prior probability is given by dnorm calculation on parameter values with settings based on vectors of values
# in prior_settings: \cr
# norm_params_mean and norm_params_sd (vectors of mean and standard deviation values applied to log FOI/R0
# parameters and to actual values of additional parameters) \cr
# + FOI_mean + FOI_sd (mean + standard deviation of computed FOI, single values) \cr
# + R0_mean + R0_sd (mean + standard deviation of computed R0, single values) \cr
#' @param regions Vector of region names
#' @param log_params_prop Proposed values of parameters (natural logarithm of actual parameters)
#' @param enviro_data Environmental data frame, containing only relevant environmental covariate values
#' '
#' @export
#'
mcmc_FOI_R0_setup <- function(prior_settings = list(type = "zero"), regions = "", log_params_prop = c(), enviro_data = list()){
#mcmc_FOI_R0_setup <- function(prior_settings = list(type = "zero"), regions = "", log_params_prop = c(), enviro_data = list()){
mcmc_FOI_R0_setup <- function(regions = "", log_params_prop = c(), enviro_data = list()){

n_regions = length(regions)
FOI_values = R0_values = rep(0, n_regions)

n_env_vars = ncol(enviro_data)-1
n_values = 2*n_env_vars
enviro_coeffs = exp(log_params_prop[c(1:n_values)])
# n_values = 2*n_env_vars
# enviro_coeffs = exp(log_params_prop[c(1:n_values)])
enviro_coeffs = exp(log_params_prop[c(1:(2*n_env_vars))])

for(i in 1:n_regions){
for(i in 1:n_regions){ #TODO - Streamline to calculate all at once
model_params = param_calc_enviro(enviro_coeffs,
as.numeric(enviro_data[enviro_data$region == regions[i], 1+c(1:n_env_vars)]))
FOI_values[i] = model_params$FOI
R0_values[i] = model_params$R0
}

prior = 0
if(prior_settings$type == "norm"){
prior = sum(dnorm(log_params_prop[c(1:n_values)], mean = prior_settings$norm_params_mean[c(1:n_values)],
sd = prior_settings$norm_params_sd[c(1:n_values)], log = TRUE))
} else {
if(prior_settings$type == "flat"){
for(i in 1:n_values){
if(log_params_prop[i]<prior_settings$log_params_min[i]){prior = -Inf}
if(log_params_prop[i]>prior_settings$log_params_max[i]){prior = -Inf}
}
}
}

return(list(FOI_values = FOI_values, R0_values = R0_values, prior = prior))
return(list(FOI_values = FOI_values, R0_values = R0_values))#, prior = prior))
}
#-------------------------------------------------------------------------------
#' @title mcmc_prelim_fit
Expand Down Expand Up @@ -528,9 +522,6 @@ mcmc_prelim_fit <- function(n_iterations = 1, n_param_sets = 1, n_bounds = 1, lo
cat("\nIteration: ", iteration, "\n", sep = "")
all_param_sets <- lhs(n = n_param_sets, rect = cbind(log_params_min, log_params_max))
results = data.frame()
# consts = list(mode_start = mode_start, prior_settings = prior_settings, dt = dt, n_reps = n_reps, enviro_data = enviro_data,
# p_severe_inf = p_severe_inf, p_death_severe_inf = p_death_severe_inf, add_values = add_values,
# deterministic = deterministic, mode_parallel = mode_parallel, cluster = cluster)

for(set in 1:n_param_sets){
cat("\n\tSet: ", set, sep = "")
Expand Down
18 changes: 1 addition & 17 deletions man/mcmc_FOI_R0_setup.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit b763fc1

Please sign in to comment.