Browse Source

Change to new `!!` and `!!!` unquote (splicing) operators / Avoid joining messages

pull/3/head
Philipp Baumann 6 years ago
parent
commit
f47c334678
  1. 29
      R/pls-modeling.R

29
R/pls-modeling.R

@ -19,7 +19,7 @@ split_data_q <- function(
# Slice based on sample_id if spectral data is in tibble class
if (tibble::is_tibble(spec_chem)) {
spec_chem <- spec_chem %>%
dplyr::group_by(rlang::UQ(rlang::sym("sample_id"))) %>%
dplyr::group_by(!!rlang::sym("sample_id")) %>%
dplyr::slice(1L)
}
@ -305,16 +305,16 @@ train_rf_q <- function(x,
# Collapse cross-validation hold-out predictions to one row per sample_id.
#
# Full-joins `cal_index` with `predobs_cv` on "rowIndex", groups the result by
# `sample_id`, and adds per-group columns: `obs` (mean of obs), `pred_sd`
# (sd of pred), `pred_sem_ci` (sem_ci applied to pred) and finally `pred`
# (mean of pred). The first row of every group is then kept.
# No ungroup() is called in this function.
#
# cal_index:  data with columns "rowIndex" and "sample_id" (the join key and
#             the grouping column used below).
# predobs_cv: data with columns "rowIndex", "obs" and "pred".
#
# NOTE(review): this is a diff rendering with the +/- markers stripped. Each
# line using rlang::UQ() is directly followed by its `!!` replacement, so the
# text as shown is not parseable R; presumably only the `!!` variant of each
# pair exists in the committed file -- confirm against the repository.
# NOTE(review): sem_ci is defined outside this view; the existing comment below
# describes it as a 95% confidence interval -- confirm.
transform_cvpredictions <- function(cal_index, predobs_cv) {
predobs_cv <- dplyr::full_join(cal_index, predobs_cv, by = "rowIndex") %>%
dplyr::group_by(rlang::UQ(rlang::sym("sample_id"))) %>%
dplyr::group_by(!!rlang::sym("sample_id")) %>%
# Average observed values and compute the standard deviation of the
# predictions within each sample_id group
dplyr::mutate("obs" = mean(rlang::UQ(rlang::sym("obs"))),
"pred_sd" = sd(rlang::UQ(rlang::sym("pred")))) %>%
dplyr::mutate("obs" = mean(!!rlang::sym("obs")),
"pred_sd" = sd(!!rlang::sym("pred"))) %>%
# Add 95% confidence interval for mean hold-out predictions from
# repeated k-fold cross-validation
dplyr::mutate_at(.vars = dplyr::vars(rlang::UQ(rlang::sym("pred"))),
dplyr::mutate_at(.vars = dplyr::vars(!!rlang::sym("pred")),
.funs = dplyr::funs("pred_sem_ci" = sem_ci)) %>%
# Add mean hold-out predictions from repeated k-fold cross-validation.
# pred is overwritten only here, after pred_sd and pred_sem_ci, so those two
# columns are computed from the individual predictions, not from the mean.
dplyr::mutate("pred" = mean(rlang::UQ(rlang::sym("pred")))) %>%
dplyr::mutate("pred" = mean(!!rlang::sym("pred"))) %>%
# Slice data set to only have one row per sample_id
dplyr::slice(1L)
}
@ -387,7 +387,7 @@ evaluate_model_q <- function(x, model, response,
# Alternative solution for one model: conformal::GetCVPreds(model) function
# see https://github.com/cran/conformal/blob/master/R/misc.R
predobs_cv <- plyr::ldply(list_models,
function(x) plyr::match_df(x$pred, x$bestTune),
function(x) dplyr::anti_join(x$pred, x$bestTune, by = "ncomp"),
.id = "model"
)
# Extract auto-prediction
@ -407,19 +407,20 @@ evaluate_model_q <- function(x, model, response,
# train object; select only rowIndex and sample_id of calibration tibble
vars_indexing <- c("rowIndex", "sample_id")
cal_index <- dplyr::select(x$calibration,
rlang::UQS(rlang::syms(vars_indexing)))
!!!rlang::syms(vars_indexing))
# Transform cross-validation hold-out predictions --------------------------
predobs_cv <- transform_cvpredictions(cal_index = cal_index,
predobs_cv = predobs_cv)
predobs_cv$object <- predobs_cv$model
predobs_cv$dataType <- "Cross-validation"
predobs_cv$model <- factor(predobs_cv$model)
predobs_cv$dataType <- factor("Cross-validation")
vars_keep <- c("obs", "pred", "pred_sd", "pred_sem_ci",
"model", "dataType", "object")
predobs_cv <- dplyr::select(predobs_cv,
# !!! sample_id newly added
rlang::UQS(rlang::syms(vars_keep))
!!!rlang::syms(vars_keep)
)
# Add column pred_sd to predobs data frame (assign values to 0) so that
# column pred_sd is retained in predobs_cv after dplyr::bind_rows
@ -433,8 +434,8 @@ evaluate_model_q <- function(x, model, response,
# predobs <- dplyr::bind_rows(predobs, predobs_cv)
# Calculate model performance indexes by model and dataType
# uses package plyr and function summary_df of SPECmisc.R
stats <- plyr::ddply(predobs, c("model", "dataType"),
function(x) summary_df(x, "obs", "pred")
stats <- suppressWarnings(plyr::ddply(predobs, c("model", "dataType"),
function(x) summary_df(x, "obs", "pred"))
)
}
# Add number of components to stats; from finalModel list item
@ -469,7 +470,7 @@ evaluate_model_q <- function(x, model, response,
obs_val <- subset(predobs, dataType == "Cross-validation")$obs
df_range <- data.frame(
response = rep(response_name, 2),
dataType = c("Calibration", "Cross-validation"),
dataType = factor(c("Calibration", "Cross-validation")),
min_obs = c(range(obs_cal)[1], range(obs_val)[1]),
median_obs = c(median(obs_cal), median(obs_val)),
max_obs = c(range(obs_cal)[2], range(obs_val)[2]),
@ -480,7 +481,7 @@ evaluate_model_q <- function(x, model, response,
}
# Join stats with range data frame (df_range)
stats <- plyr::join(stats, df_range, type = "inner")
stats <- suppressWarnings(dplyr::inner_join(stats, df_range, by = "dataType"))
annotation <- plyr::mutate(stats,
rmse = as.character(as.expression(paste0("RMSE == ",
round(rmse, 2)))),

Loading…
Cancel
Save