3 The core function: css

The main inputs that css() accepts are the following:

  • a design matrix X
  • a response y
  • a selection method fitfun with specified tuning parameter lambda. This specifies \(\hat{S}^\lambda(A)\). In particular:
###"param-fitfun"###
#' @param fitfun A function; the feature selection function used on each
#' subsample by cluster stability selection. This can be any feature selection
#' method; the only requirement is that it accepts the arguments (and only the
#' arguments) `X`, `y`, and `lambda` and returns an integer vector that is a 
#' subset of `1:p`. For example, `fitfun` could be best subset selection or 
#' forward stepwise selection or LARS and `lambda` could be the desired model 
#' size; or `fitfun` could be the elastic net and `lambda` could be a length-two
#' vector specifying lambda and alpha. Default is `cssLasso`, an implementation 
#' of lasso (relying on the R package `glmnet`), where `lambda` must be a 
#' positive numeric specifying the L1 penalty for the `lasso`.

and

###"param-lambda"###
#' @param lambda A tuning parameter or set of tuning parameters that may be used
#' by the feature selection method `fitfun`. In the default case when
#' `fitfun = cssLasso`, lambda should be a numeric: the penalty to use for each
#' lasso fit. (`css()` does not require lambda to be any particular object because
#' for a user-specified feature selection method `fitfun`, lambda can be an
#' arbitrary object. See the description of `fitfun` below.)
  • a specification of which features belong together in highly correlated clusters. This specifies the clusters \(C_1,\ldots, C_K\). In particular:
###"param-clusters"###
#' @param clusters A list of integer vectors; each vector should contain the 
#' indices of a cluster of features (a subset of `1:p`). (If there is only one
#' cluster, clusters can either be a list of length 1 or an integer vector.)
#' All of the provided clusters must be non-overlapping. Every feature not
#' appearing in any cluster will be assumed to be unclustered (that is, they
#' will be treated as if they are in a "cluster" containing only themselves). If
#' clusters is a list of length 0 (or a list only containing clusters of length
#' 1), then `css()` returns the same results as stability selection (so the
#' returned `feat_sel_mat` will be identical to `clus_sel_mat`). Names for the
#' clusters will be needed later; any clusters that are not given names in the
#' provided list will be given names automatically by `css()`. CAUTION: if
#' the provided X is a data.frame that contains a categorical feature with more
#' than two levels, then the resulting matrix made from model.matrix will have a
#' different number of columns than the provided data.frame, some of the feature
#' numbers will change, and the clusters argument will not work properly (in the
#' current version of the package). To get correct results in this case, please
#' use model.matrix to convert the data.frame to a numeric matrix on your own,
#' then provide this matrix and cluster assignments with respect to this matrix.
#' Default is `list()` (so no clusters are specified).
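The caution above about categorical features can be illustrated with a small sketch. The data frame and its column names are made up, and the exact call `css()` makes to `model.matrix` internally is not shown in this section, so the formula used here (`~ .` with the intercept dropped) is an assumption.

```r
# Toy data.frame with one three-level factor between two numeric columns
df <- data.frame(x1 = c(1.2, 0.7, 3.1, 2.4),
                 grp = factor(c("a", "b", "c", "a")),
                 x2 = c(0.5, 1.5, 2.5, 3.5))

# Convert to a numeric matrix ourselves (assumed formula; intercept dropped)
X <- stats::model.matrix(~ ., data = df)[, -1]

ncol(df)     # 3 columns in the data.frame
colnames(X)  # the factor expands into two dummy columns, so x2 moves to column 4

# Clusters must be stated with respect to the columns of X, not of df
clusters <- list(grp = 2:3)
```

Here `x2` is column 3 of the data.frame but column 4 of the matrix, which is exactly the kind of index shift the caution describes.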
###"param-sampling_type"###
#' @param sampling_type A character vector; either "SS" or "MB". For "MB",
#' all B subsamples are drawn randomly (as proposed by Meinshausen and Bühlmann
#' 2010). For "SS", in addition to these B subsamples, the B complementary pair
#' subsamples will be drawn as well (see Faletto and Bien 2022 or Shah and
#' Samworth 2013 for details). Default is "SS", and "MB" is not supported yet.
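To make the `fitfun` contract described above concrete, here is a minimal sketch of a user-supplied selection method. The function name `corScreenFitfun` and the screening rule are hypothetical and not part of the package; the only properties taken from the documentation are that the function accepts exactly `X`, `y`, and `lambda` and returns an integer vector that is a subset of `1:p`. Here `lambda` plays the role of the desired model size.

```r
# Hypothetical fitfun: keep the `lambda` features with the largest absolute
# marginal correlation with y. Accepts exactly X, y, and lambda.
corScreenFitfun <- function(X, y, lambda) {
    stopifnot(is.matrix(X), is.numeric(y), length(y) == nrow(X))
    stopifnot(length(lambda) == 1, lambda >= 1, lambda <= ncol(X))
    abs_cors <- abs(stats::cor(X, y))[, 1]
    # Indices of the top-lambda features, returned as a sorted integer vector
    selected <- order(abs_cors, decreasing = TRUE)[seq_len(lambda)]
    sort(as.integer(selected))
}

set.seed(123)
X <- matrix(stats::rnorm(50 * 5), nrow = 50, ncol = 5)
y <- 2 * X[, 2] - X[, 4] + stats::rnorm(50)
sel <- corScreenFitfun(X, y, lambda = 2L)
sel
```

A function of this shape could then be passed as `fitfun = corScreenFitfun` together with `lambda = 2L`.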

The number of subsamples is also up to the user:

###"param-B"###
#' @param B Integer or numeric; the number of subsamples. Note: For
#' `sampling_type=="MB"` the total number of subsamples will be `B`; for
#' `sampling_type="SS"` the number of subsamples will be `2*B`. Default is 100
#' for `sampling_type="MB"` and 50 for `sampling_type="SS"`.

In this section, we will build the function css() step-by-step.

  • First, css() creates random subsamples of the data: \(A_1,\ldots, A_B\) (and \(\bar A_1,\ldots, \bar A_B\) for Shah-Samworth subsampling). This is implemented in the function createSubsamples().

  • Next, css() executes the specified feature selection method \(\hat S^\lambda(\cdot)\) on each subsample. This generates a binary matrix, feat_sel_mat, with one row for each subsample and one column for each feature in X. In each row, the entry corresponding to each feature equals 1 if the feature was selected on that subsample and 0 otherwise. feat_sel_mat is one of the outputs of css(); it is used to calculate the selection proportions for individual features. This is implemented in the function getSelMatrix().

  • So far this is the same as the original stability selection procedure. Now we take into account the clusters. The function getClusterSelMatrix() takes in the formatted clusters as well as feat_sel_mat and generates the binary matrix clus_sel_mat, which contains selection indicators for each cluster. clus_sel_mat has the same number of rows as feat_sel_mat and one column for each cluster rather than one for each feature. This is used to calculate the cluster selection proportions.
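The three steps above can be sketched in a few lines of base R. This is only an illustration under simplifying assumptions: `toy_fitfun` is a made-up stand-in that selects a single feature (not the lasso default), `n` is even so each complementary subsample is simply the set difference, and none of the input checking done by the package functions is reproduced.

```r
set.seed(1)
n <- 40; p <- 5; B <- 10
X <- matrix(stats::rnorm(n * p), n, p)
X[, 2] <- X[, 1] + 0.1 * stats::rnorm(n)  # features 1 and 2 are highly correlated
y <- X[, 1] + stats::rnorm(n)
clusters <- list(c1 = 1:2, c2 = 3L, c3 = 4L, c4 = 5L)

# Toy stand-in for the selection method: the single feature most correlated
# with y (lambda is ignored)
toy_fitfun <- function(X, y, lambda) {
    which.max(abs(stats::cor(X, y)))
}

# Step 1: B subsamples of size floor(n/2), followed by their complements
subsamps <- lapply(seq_len(B), function(b) sort(sample.int(n, floor(n / 2))))
subsamps <- c(subsamps, lapply(subsamps, function(A) setdiff(seq_len(n), A)))

# Step 2: feature selection indicators, one row per subsample
feat_sel_mat <- matrix(0L, 2 * B, p)
for (b in seq_along(subsamps)) {
    A <- subsamps[[b]]
    feat_sel_mat[b, toy_fitfun(X[A, ], y[A], lambda = NULL)] <- 1L
}

# Step 3: cluster selection indicators, one column per cluster
clus_sel_mat <- sapply(clusters, function(cl) {
    as.integer(rowSums(feat_sel_mat[, cl, drop = FALSE]) > 0)
})

colMeans(feat_sel_mat)  # selection proportions of individual features
colMeans(clus_sel_mat)  # selection proportions of clusters
```

Because the toy method picks exactly one feature per subsample, the selection proportion of cluster `c1` equals the sum of the proportions of features 1 and 2: the votes that the two correlated features split between them are pooled at the cluster level.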

The function css() is specified below. (There are some extra details that I omitted in the above description for brevity; you can read the full details below.)

#' Cluster Stability Selection
#'
#' Executes cluster stability selection algorithm. Takes subsamples of data,
#' executes feature selection algorithm on each subsample, and returns matrices
#' of feature selection indicators as well as cluster selection indicators.
#'
#' @param X An n x p numeric matrix (preferably) or a data.frame (which will
#' be coerced internally to a matrix by the function model.matrix) containing
#' p >= 2 features/predictors.
#' @param y The response; can be anything that takes the form of an
#' n-dimensional vector, with the ith entry corresponding to the ith row of X.
#' Typically (and for default `fitfun = cssLasso`), `y` should be an n-dimensional
#' numeric vector.
<<param-lambda>>
<<param-clusters>>
<<param-fitfun>>
<<param-sampling_type>>
<<param-B>>
#' @param prop_feats_remove Numeric; if `prop_feats_remove` is greater than 0,
#' then on each subsample, each feature is randomly dropped from the design
#' matrix that is provided to `fitfun` with probability `prop_feats_remove`
#' (independently across features). That is, in a typical subsample,
#' `prop_feats_remove*p` features will be dropped (though this number will vary).
#' This is similar in spirit to (but distinct from) extended stability selection
#' (Beinrucker et al. 2016); see their paper for some of the benefits of
#' dropping features (besides increasing computational speed and decreasing
#' memory requirements). For `sampling_type="SS"`, the features dropped in
#' each complementary pair of subsamples are identical in order to ensure that
#' the theoretical guarantees of Faletto and Bien (2022) are retained within
#' each individual pair of subsamples. (Note that this feature is not
#' investigated either theoretically or in simulations by Faletto and Bien
#' 2022). Must be between 0 and 1. Default is 0.
#' @param train_inds Optional; an integer or numeric vector containing the
#' indices of observations in `X` and `y` to set aside for model training by the
#' function `getCssPreds()` after feature selection. (This will only work if `y` is
#' real-valued, because `getCssPreds()` uses ordinary least squares regression to
#' generate predictions.) If `train_inds` is not provided, all of the observations
#' in the provided data set will be used for feature selection.
#' @param num_cores Optional; an integer. If using parallel processing, the
#' number of cores to use for parallel processing (`num_cores` will be supplied
#' internally as the `mc.cores` argument of `parallel::mclapply()`).
#' @return A list containing the following items:
#' \item{`feat_sel_mat`}{A `B` (or `2*B` for `sampling_type="SS"`) x `p` numeric (binary) matrix. `feat_sel_mat[i, j] = 1` if feature `j` was selected by the base feature selection method on subsample `i`, and 0 otherwise.}
#' \item{`clus_sel_mat`}{A `B` (or `2*B` for SS sampling) x `length(clusters)` numeric (binary) matrix. `clus_sel_mat[i, j] = 1` if at least one feature from cluster j was selected by the base feature selection method on subsample `i`, and 0 otherwise.}
#' \item{`X`}{The `X` matrix provided to `css()`, coerced from a data.frame to a 
#' matrix if needed.}
#' \item{`y`}{The `y` vector provided to `css()`.}
#' \item{`clusters`}{A named list of integer vectors containing all of the clusters provided to `css()`, as well as size 1 clusters of any features not listed in any
#' of the clusters provided to `css()`. All clusters will have names; any 
#' clusters not provided with a name in the input to `css()` will be given names
#' automatically by `css()` (of the form c1, etc.).}
#' \item{`train_inds`}{Identical to the `train_inds` provided to `css()`.}
#' @author Gregory Faletto, Jacob Bien
#' @references
<<faletto2022>>
#' 
<<shah2013>>
#' 
#' Meinshausen, N., & Bühlmann, P. (2010). Stability Selection. \emph{Journal of the Royal
#' Statistical Society. Series B: Statistical Methodology}, 72(4), 417–473.
#' \url{https://rss.onlinelibrary.wiley.com/doi/full/10.1111/j.1467-9868.2010.00740.x}.
#' 
#' Beinrucker, A., Dogan, Ü., &
#' Blanchard, G. (2016). Extensions of stability selection using subsamples of
#' observations and covariates. \emph{Statistics and Computing}, 26(5),
#' 1059-1077. \url{https://doi.org/10.1007/s11222-015-9589-y}.
#' 
#' Jerome Friedman, Trevor Hastie, Robert Tibshirani (2010). Regularization Paths for Generalized
#' Linear Models via Coordinate Descent. \emph{Journal of Statistical Software},
#' 33(1), 1-22. URL \url{https://www.jstatsoft.org/v33/i01/}.
#' @export
css <- function(X, y, lambda, clusters = list(), fitfun = cssLasso,
    sampling_type = "SS", B = ifelse(sampling_type == "MB", 100L, 50L),
    prop_feats_remove = 0, train_inds = integer(), num_cores = 1L
    ){

    # Check inputs

    check_list <- checkCssInputs(X, y, lambda, clusters, fitfun, sampling_type,
        B, prop_feats_remove, train_inds, num_cores)

    feat_names <- check_list$feat_names
    X <- check_list$X
    clusters <- check_list$clusters

    rm(check_list)

    n <- nrow(X)
    p <- ncol(X)

    train_inds <- as.integer(train_inds)

    ### Create subsamples

    sel_inds <- setdiff(1:n, train_inds)
    n_sel <- length(sel_inds)
    if(n_sel < 4){
        stop("Too many training indices provided (must be at least 4 observations left for feature selection, and ideally many more)")
    }

    subsamps_object <- createSubsamples(n_sel, p, B, sampling_type,
        prop_feats_remove)

    ### Get matrix of selected feature sets from subsamples

    stopifnot(!is.matrix(y))

    feat_sel_mat <- getSelMatrix(X[sel_inds, ], y[sel_inds], lambda, B,
        sampling_type, subsamps_object, num_cores, fitfun)

    if(any(!is.na(feat_names))){
        colnames(feat_sel_mat) <- feat_names
        colnames(X) <- feat_names
    }

    ### Get selection proportions for clusters corresponding to each feature

    clus_sel_mat <- getClusterSelMatrix(clusters, feat_sel_mat)

    # Check outputs
    stopifnot(!is.null(colnames(clus_sel_mat)))
    stopifnot(all(colnames(clus_sel_mat) == names(clusters)))

    ret <- list(feat_sel_mat = feat_sel_mat,
        clus_sel_mat = clus_sel_mat,
        X = X,
        y = y,
        clusters = clusters,
        train_inds = train_inds
        )

    class(ret) <- "cssr"

    return(ret)
}

Tests for css() and the other functions in this section are located at the end of this document. Next, we’ll build the individual functions one at a time. (The one function above that I didn’t mention at all is checkCssInputs(), which confirms that the inputs to css() are as expected and does some basic formatting on the inputs. To streamline presentation, this and other helper functions are specified later.)
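Once all of the functions in this section are defined, a call to css() might look like the following sketch. The simulated data, the cluster name `z_cluster`, and the value `lambda = 0.1` are arbitrary choices for illustration. The call is guarded because it assumes that `css()` (and glmnet, which the default `cssLasso` needs) is available in the session.

```r
set.seed(1)
n <- 60; p <- 6
Z <- stats::rnorm(n)
X <- matrix(stats::rnorm(n * p), n, p)
X[, 1:3] <- Z + 0.1 * X[, 1:3]  # features 1-3 are noisy copies of a latent Z
y <- Z + stats::rnorm(n)
clusters <- list(z_cluster = 1:3)

# Only run if css() has been defined or loaded (assumption: not checked further)
if (exists("css", mode = "function")) {
    res <- css(X, y, lambda = 0.1, clusters = clusters, B = 20)
    dim(res$feat_sel_mat)       # 2*B rows (SS sampling) by p columns
    colMeans(res$clus_sel_mat)  # cluster selection proportions
}
```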

3.1 Main components of css

3.1.1 Generating subsamples

The function createSubsamples() is responsible for generating the subsamples \(A_b\), where \(b\) ranges from 1 to \(B\) or 1 to \(2B\) depending on the type of subsampling:

#' Creates lists of subsamples for stability selection.
#'
#' @param n Integer or numeric; sample size of the data set.
#' @param p Integer or numeric; number of features.
<<param-B>>
<<param-sampling_type>>
#' @param prop_feats_remove Numeric; the probability with which each feature is
#' randomly dropped on each subsample, as provided to `css()`. Default is 0.
#' @return A list of length `B` (or `2*B` for `sampling_type = "SS"`). If
#' `prop_feats_remove = 0`, each list element is an integer vector of length
#' `floor(n/2)` containing the indices of a subsample of `1:n`. (For
#' `sampling_type=="SS"`, the last `B` subsamples will be complementary pairs of
#' the first `B` subsamples; see Faletto and Bien 2022 or Shah and Samworth 2013
#' for details.) If `prop_feats_remove > 0`, each element is a named list with
#' members "subsample" (same as above) and "feats_to_keep", a logical vector
#' of length `p`; `feats_to_keep[j] = TRUE` if feature `j` is chosen for this
#' subsample, and `FALSE` otherwise.
#' @author Gregory Faletto, Jacob Bien
#' @references
<<faletto2022>>
#' 
<<shah2013>>
createSubsamples <- function(n, p, B, sampling_type, prop_feats_remove=0){

    # Check inputs

    stopifnot(length(n) == 1)
    stopifnot(is.numeric(n) | is.integer(n))
    stopifnot(n == round(n))
    stopifnot(n > 0)

    stopifnot(length(p) == 1)
    stopifnot(is.numeric(p) | is.integer(p))
    stopifnot(p == round(p))
    stopifnot(p > 0)

    checkSamplingType(sampling_type)
    checkPropFeatsRemove(prop_feats_remove, p)

    if(prop_feats_remove == 0){
        subsamples <- getSubsamps(n, B, sampling_type)
        return(subsamples)
    } else{
        # In this case, we generate subsamples as well as logical vectors
        # of features to keep
        subsamps_and_feats <- list()
        subsamples <- getSubsamps(n, B, sampling_type)
        for(i in 1:B){
            # Logical p-vector, where each entry is TRUE with probability
            # 1 - prop_feats_remove
            feats_to_keep_i <- as.logical(stats::rbinom(n=p, size=1,
                prob=1 - prop_feats_remove))
            # Make sure at least two entries are equal to TRUE (so that at
            # least two features are present for every subsample)--if not,
            # randomly choose features to add
            while(sum(feats_to_keep_i) < 2){
                false_inds <- which(!feats_to_keep_i)
                sel_feat <- sample(false_inds, size=1)
                feats_to_keep_i[sel_feat] <- TRUE
            }
            subsamps_and_feats[[i]] <- list(subsample=subsamples[[i]],
                feats_to_keep=feats_to_keep_i)
        }

        if(sampling_type=="SS"){
            stopifnot(length(subsamples) == 2*B)
            for(i in 1:B){
                # Keep the same features as in the other subsample (this
                # ensures that the theoretical guarantee of Shah and Samworth
                # 2013 remains valid on every individual pair of subsamples)
                subsamps_and_feats[[B + i]] <- list(subsample=subsamples[[B + i]],
                    feats_to_keep=subsamps_and_feats[[i]]$feats_to_keep)
            }
        }

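        # Check output: every element must be a list with exactly these two
        # names. (names(subsamps_and_feats) itself is NULL, so comparing it
        # against the expected names would pass vacuously.)
        stopifnot(all(vapply(subsamps_and_feats, function(s){
            identical(names(s), c("subsample", "feats_to_keep"))},
            logical(1))))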

        return(subsamps_and_feats)
    }
    # Shouldn't be possible to reach this part of function
    stop("createSubsamples failed to return anything")
}

Notice that createSubsamples() calls some helper functions to check the inputs (again, these are specified later in the helper functions section). createSubsamples() also calls the workhorse function getSubsamps() to generate a list of subsamples.

#' Generate list of subsamples
#'
#' Generate list of `B` (or `2*B` for `sampling_type="SS"`) subsamples of size
#' `floor(n/2)`
#' @param n Integer or numeric; sample size of the data set.
<<param-B>>
<<param-sampling_type>>
#' @return A list of length `B` (or `2*B` for `sampling_type="SS"`), where each
#' element is an integer vector of length `floor(n/2)` containing the indices
#' of a subsample of `1:n`. For `sampling_type=="SS"`, the last `B` subsamples
#' will be complementary pairs of the first `B` subsamples (see Faletto and
#' Bien 2022 or Shah and Samworth 2013 for details).
#' @author Gregory Faletto, Jacob Bien
#' @references
#' 
<<faletto2022>>
#' 
<<shah2013>>
getSubsamps <- function(n, B, sampling_type){
    subsamples <- list()
    for(i in 1:B){
        subsamples[[i]] <- sort(sample.int(n=n, size=floor(n/2), replace=FALSE))
    }
    stopifnot(length(subsamples) == B)
    # TODO(gregfaletto): add support for sampling_type="MB"
    if(sampling_type=="SS"){
        for(i in 1:B){
            # For the ith entry, take a subsample of size floor(n/2) from the
            # remaining n - floor(n/2) observations. (Only necessary to actually
            # take the subsample if n is odd; if not, the set we want is
            # setdiff(1:n, subsamples[[i]]), so skip the sample function.)
            if(n/2 == floor(n/2)){
                subsamples[[B + i]] <- sort(setdiff(1:n, subsamples[[i]]))
            } else{
                subsamples[[B + i]] <- sort(sample(x=setdiff(1:n,
                    subsamples[[i]]), size=floor(n/2), replace=FALSE))
            }

            # Check output

            stopifnot(is.integer(subsamples[[B + i]]))
            stopifnot(all(subsamples[[B + i]] ==
                round(subsamples[[B + i]])))
            stopifnot(floor(n/2) == length(subsamples[[B + i]]))
            stopifnot(length(subsamples[[B + i]]) ==
                length(unique(subsamples[[B + i]])))
        }
        stopifnot(length(subsamples) == 2*B)
    }
    return(subsamples)
}
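The odd-`n` branch in getSubsamps() can be seen in isolation in the short sketch below (the value `n = 11` and the variable names are arbitrary). With an odd sample size, a subsample and its complementary partner are still disjoint and of equal size, but one observation ends up in neither.

```r
set.seed(42)
n <- 11                # odd sample size
half <- floor(n / 2)   # each subsample has 5 observations

A <- sort(sample.int(n = n, size = half, replace = FALSE))
remaining <- setdiff(seq_len(n), A)  # the 6 observations not in A
A_bar <- sort(sample(x = remaining, size = half, replace = FALSE))

length(intersect(A, A_bar))              # 0: the pair is disjoint
left_out <- setdiff(seq_len(n), union(A, A_bar))
length(left_out)                         # 1: one observation is in neither
```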

3.1.2 Forming selection matrices

The next function called in the body of css() is getSelMatrix(), which records for each feature \(j\) and subsample \(b\) whether \(j\in \hat S^\lambda(A_b)\):

#' Generates matrix of selection indicators from stability selection.
#'
#' @param x an n x p numeric matrix or a data.frame containing the predictors.
#' @param y A response vector; can be any response that takes the form of a
#' length n vector and is used (or not used) by fitfun. Typically (and for
#' default fitfun = cssLasso), y should be an n-dimensional numeric vector
#' containing the response.
<<param-lambda>>
<<param-B>>
<<param-sampling_type>>
#' @param subsamps_object A list of length `B` (or `2*B` if `sampling_type="SS"`),
#' where each element is one of the following: \item{subsample}{An integer
#' vector of size `floor(n/2)` containing the indices of the observations in the
#' subsample.} \item{drop_var_input}{A named list containing two elements: one
#' named "subsample", matching the previous description, and a logical vector
#' named "feats_to_keep" of length `p` indicating which features are kept
#' (not dropped) on this subsample.} (The first object is the output of the function
#' createSubsamples when the provided prop_feats_remove is 0, the default, and
#' the second object is the output of createSubsamples when prop_feats_remove >
#' 0.)
#' @param num_cores Optional; an integer. If using parallel processing, the
#' number of cores to use for parallel processing (num_cores will be supplied
#' internally as the mc.cores argument of parallel::mclapply).
<<param-fitfun>>
#' @return A binary integer matrix of dimension `B` x `p` (if sampling_type ==
#' "MB") or `2*B` x `p` (if sampling_type == "SS"). res[i, j] = 1 if feature j
#' was selected on subsample i and equals 0 otherwise. (That is, each row is a
#' selected set.)
#' @author Gregory Faletto, Jacob Bien
getSelMatrix <- function(x, y, lambda, B, sampling_type, subsamps_object,
    num_cores, fitfun=cssLasso){

    # Check inputs

    stopifnot(is.matrix(x))
    stopifnot(all(!is.na(x)))

    n <- nrow(x)
    p <- ncol(x)

    stopifnot(length(y) == n)
    stopifnot(!is.matrix(y))
    # Intentionally don't check y or lambda further to allow for flexibility--these
    # inputs should be checked within fitfun.

    checkSamplingType(sampling_type)

    # Get list of selected feature sets from subsamples

    res_list <- parallel::mclapply(X=subsamps_object, FUN=cssLoop, x=x, y=y,
        lambda=lambda, fitfun=fitfun, mc.cores=num_cores)

    # Store selected sets in B x p (or `2*B` x p for "SS") binary matrix
    if(sampling_type=="SS"){
        res <- matrix(0L, 2*B, p)
    } else if(sampling_type=="MB"){
        res <- matrix(0L, B, p)
    } else{
        stop("!(sampling_type %in% c(SS, MB)) (don't know how to set dimensions of res)")
    }

    stopifnot(length(res_list) == nrow(res))

    # Get selected sets into matrix res

    for(i in 1:nrow(res)){
        if(length(res_list[[i]]) == 0){
            # If no features are selected, don't fill in anything in this row
            next
        }

        if(!is.integer(res_list[[i]])){
            print(paste("failed on iteration", i))
            print(res_list[[i]])
            stop("Something seems to be wrong with the feature selection method (fitfun failed to return an integer vector)")
        }
        stopifnot(length(res_list[[i]]) <= p & length(res_list[[i]]) > 0)
        stopifnot(all(!is.na(res_list[[i]])))
        stopifnot(max(res_list[[i]]) <= p)
        stopifnot(min(res_list[[i]]) >= 1)
        stopifnot(length(res_list[[i]]) == length(unique(res_list[[i]])))
        stopifnot(!("try-error" %in% class(res_list[[i]]) |
            "error" %in% class(res_list[[i]]) |
            "simpleError" %in% class(res_list[[i]]) |
            "condition" %in% class(res_list[[i]])))

        # Store selected variables in the ith row of res
        res[i, res_list[[i]]] <- 1L
    }

    # Check output

    stopifnot(is.matrix(res))
    if(sampling_type=="SS"){
        stopifnot(nrow(res) == 2*B)
    } else{
        stopifnot(nrow(res) == B)
    }
    stopifnot(ncol(res) == p)
    stopifnot(all(res %in% c(0, 1)))

    return(res)
}

Again, getSelMatrix() uses some helper functions to check the inputs for safety. getSelMatrix() calls mclapply() from the parallel package to allow for parallel processing if the user has set this up. This requires a helper function cssLoop(), which is also specified in the later section for helper functions. We add the parallel package to the DESCRIPTION file:

usethis::use_package("parallel")
## ✔ Adding 'parallel' to Imports field in DESCRIPTION
## • Refer to functions with `parallel::fun()`
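Since cssLoop() is only specified later, the following is a hypothetical stand-in (`toyLoop`) showing one way a per-subsample worker could handle the two kinds of elements in `subsamps_object` and how the results fill the binary matrix. In particular, mapping indices from the reduced design matrix back to the original columns is an assumption about what such a worker needs to do, not a description of the package's code. The list below mixes both element types purely for illustration.

```r
# Hypothetical worker; the package's real cssLoop() may differ
toyLoop <- function(input, x, y, lambda, fitfun) {
    if (is.list(input)) {
        # prop_feats_remove > 0: element holds a subsample and a logical mask
        rows <- input$subsample
        keep <- which(input$feats_to_keep)
    } else {
        rows <- input
        keep <- seq_len(ncol(x))
    }
    selected <- fitfun(X = x[rows, keep, drop = FALSE], y = y[rows],
        lambda = lambda)
    # Translate indices within the reduced matrix back to columns of x
    as.integer(keep[selected])
}

set.seed(7)
n <- 30; p <- 4
x <- matrix(stats::rnorm(n * p), n, p)
y <- x[, 3] + 0.1 * stats::rnorm(n)
# Toy selection method: the single feature most correlated with y
top1 <- function(X, y, lambda) which.max(abs(stats::cor(X, y)))

subsamps_object <- list(
    sort(sample.int(n, 15)),
    list(subsample = sort(sample.int(n, 15)),
         feats_to_keep = c(FALSE, TRUE, TRUE, TRUE))
)

res_list <- parallel::mclapply(X = subsamps_object, FUN = toyLoop, x = x, y = y,
    lambda = NULL, fitfun = top1, mc.cores = 1L)

# Fill the binary selection matrix row by row
res <- matrix(0L, length(res_list), p)
for (i in seq_along(res_list)) res[i, res_list[[i]]] <- 1L
res
```

In the second element, feature 3 is the second column of the reduced matrix, so without the translation step the indicator would land in the wrong column.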

getSelMatrix() also uses a feature selection function fitfun as an input. fitfun can be provided by the user as an input to css(). We provide one default fitfun, cssLasso(), which works on a real-valued response y given a specified penalty parameter lambda > 0. Both for its own importance and to show an example of a valid fitfun, we show cssLasso() here.

#' Provided fitfun implementing the lasso
#'
#' Function used to select features with the lasso on each subsample in cluster
#' stability selection. Uses glmnet implementation of the lasso.
#' @param X A design matrix containing the predictors. (In practice this will
#' be a subsample of the full design matrix provided to `css()`.)
#' @param y A numeric vector containing the response.
#' @param lambda Numeric; a nonnegative number for the lasso penalty to use
#' on each subsample. (For now, only one lambda value can be provided to
#' `cssLasso()`; in the future, we plan to allow for multiple lambda values to be
#' provided to `cssLasso()`, as described in Faletto and Bien 2022.)
#' @return An integer vector; the indices of the features selected by the lasso.
#' @author Gregory Faletto, Jacob Bien
#' @references 
#' 
<<faletto2022>>
#' 
#' Jerome Friedman, Trevor Hastie,
#' Robert Tibshirani (2010). Regularization Paths for Generalized Linear Models
#' via Coordinate Descent. \emph{Journal of Statistical Software}, 33(1), 1-22.
#' URL \url{https://www.jstatsoft.org/v33/i01/}.
#' @export
cssLasso <- function(X, y, lambda){
    # Check inputs

    # TODO(gregfaletto) allow cssLasso to accept a vector of lambda values rather
    # than just a single one.
    checkCssLassoInputs(X, y, lambda)

    n <- nrow(X)
    p <- ncol(X)

    # Fit a lasso path (full path for speed, per glmnet documentation)

    lasso_model <- glmnet::glmnet(X, y, family="gaussian")
    stopifnot(all.equal(class(lasso_model), c("elnet", "glmnet")))

    # Get coefficients at desired lambda

    pred <- glmnet::predict.glmnet(lasso_model, type="nonzero",
        s=lambda, exact=TRUE, newx=X, x=X, y=y)

    if(is.null(pred[[1]])){return(integer())}

    stopifnot(is.data.frame(pred))
    stopifnot(!("try-error" %in% class(pred) | "error" %in% class(pred) |
        "simpleError" %in% class(pred) | "condition" %in% class(pred)))

    if(length(dim(pred)) == 2){
        selected_glmnet <- pred[, 1]
    } else if(length(dim(pred)) == 3){
        selected_glmnet <- pred[, 1, 1]
    } else if(length(dim(pred)) == 1){
        selected_glmnet <- pred
    } else{
        stop("length(dim(pred)) not 1, 2, or 3")
    }

    stopifnot(length(selected_glmnet) >= 1)
    stopifnot(length(selected_glmnet) <= ncol(X))
    stopifnot(all(selected_glmnet == round(selected_glmnet)))
    stopifnot(length(selected_glmnet) == length(unique(selected_glmnet)))
    selected_glmnet <- as.integer(selected_glmnet)

    selected <- sort(selected_glmnet)

    return(selected)
}

Notice that cssLasso() depends on the package glmnet, and calls the function checkCssLassoInputs(), which verifies that the inputs to cssLasso() are as needed. (In particular, the function css() imposes very few requirements on lambda and y to allow the end user flexibility for any specified fitfun. checkCssLassoInputs() ensures that y is a real-valued response, lambda is a nonnegative real number, etc.) checkCssLassoInputs() is specified in the helper functions section later.

We add glmnet to the DESCRIPTION file as well:

usethis::use_package("glmnet")
## ✔ Adding 'glmnet' to Imports field in DESCRIPTION
## • Refer to functions with `glmnet::fun()`
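The fitfun documentation mentions that `lambda` could be a length-two vector specifying lambda and alpha for the elastic net. A minimal sketch of such a fitfun follows. The name `enetFitfun` is hypothetical, and unlike cssLasso() it fits glmnet at a single penalty value rather than along a full path (which the glmnet documentation discourages for speed), so treat it as an illustration of the interface only.

```r
# Hypothetical elastic net fitfun: lambda = c(penalty, alpha)
enetFitfun <- function(X, y, lambda) {
    stopifnot(is.numeric(lambda), length(lambda) == 2)
    stopifnot(lambda[1] >= 0, lambda[2] >= 0, lambda[2] <= 1)
    fit <- glmnet::glmnet(X, y, family = "gaussian", alpha = lambda[2],
        lambda = lambda[1])
    # fit$beta is a p x 1 sparse coefficient matrix (intercept excluded);
    # return the indices of its nonzero entries as an integer vector
    unname(which(fit$beta[, 1] != 0))
}

# Only run the example if glmnet is installed
if (requireNamespace("glmnet", quietly = TRUE)) {
    set.seed(3)
    X <- matrix(stats::rnorm(100 * 6), 100, 6)
    y <- 3 * X[, 1] + stats::rnorm(100)
    enetFitfun(X, y, lambda = c(0.5, 0.5))
}
```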

Finally, getClusterSelMatrix() is the last significant function called within css(). It takes the information from getSelMatrix(), i.e., whether feature \(j\in\hat S^\lambda(A_b)\), and outputs for every cluster \(C\) whether \(C\cap \hat S^\lambda(A_b)\neq\emptyset\).

#' Get cluster selection matrix
#'
#' Given a matrix of feature selection indicator variables and a list of
#' clusters of features, returns a matrix of cluster indicator variables.
#'
#' @param clusters A named list where each entry is an integer vector of indices
#' of features that are in a common cluster, as in the output of formatClusters.
#' (The length of list clusters is equal to the number of clusters.) All
#' identified clusters must be non-overlapping, and all features must appear in
#' exactly one cluster (any unclustered features should be in their own
#' "cluster" of size 1).
#' @param res A binary integer matrix. res[i, j] = 1 if feature j was selected
#' on subsample i and equals 0 otherwise, as in the output of `getSelMatrix()`.
#' (That is, each row is a selected set.)
#' @return A binary integer matrix with the same number of rows
#' as res and length(clusters) columns. Entry i, j is 1 if at least
#' one member of cluster j was selected on subsample i, and 0 otherwise.
#' @author Gregory Faletto, Jacob Bien
getClusterSelMatrix <- function(clusters, res){

    # Check input
    checkGetClusterSelMatrixInput(clusters, res)

    p <- ncol(res)

    n_clusters <- length(clusters)

    # Matrix of cluster selection indicators (with n_clusters columns)
    res_n_clusters <- matrix(rep(0L, nrow(res)*n_clusters), nrow=nrow(res),
        ncol=n_clusters)
    colnames(res_n_clusters) <- names(clusters)

    for(j in 1:n_clusters){
        # Identify rows of res where at least one feature from this cluster
        # was selected
        rows_j_sel <- apply(res, 1, function(x){any(x[clusters[[j]]] == 1)})

        # Put ones in these rows of res_n_clusters[, j]
        res_n_clusters[rows_j_sel, j] <- 1L
    }

    # Check output
    stopifnot(is.matrix(res_n_clusters))
    stopifnot(identical(colnames(res_n_clusters), names(clusters)))
    stopifnot(all(res_n_clusters %in% c(0, 1)))
    stopifnot(ncol(res_n_clusters) == length(clusters))

    return(res_n_clusters)
}
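On a small hand-made example, the row-wise any() used in the loop above can be checked against a vectorized alternative based on rowSums(). The matrix and clusters below are toy inputs, and the vectorized form is offered only as an equivalent way to read the computation, not as the package's implementation.

```r
# 4 subsamples x 4 features; features 1 and 2 form a cluster
res <- matrix(c(1L, 0L, 0L, 0L,
                0L, 1L, 0L, 0L,
                0L, 0L, 0L, 1L,
                0L, 0L, 0L, 0L), nrow = 4, byrow = TRUE)
clusters <- list(c1 = 1:2, c2 = 3L, c3 = 4L)

# Row-wise any(), as in getClusterSelMatrix()
by_apply <- sapply(clusters, function(cl) {
    as.integer(apply(res, 1, function(x) any(x[cl] == 1)))
})

# Vectorized alternative: a cluster is selected when its row sum is positive
by_rowsums <- sapply(clusters, function(cl) {
    as.integer(rowSums(res[, cl, drop = FALSE]) > 0)
})

identical(by_apply, by_rowsums)  # TRUE
colMeans(by_rowsums)             # c1 = 0.5, c2 = 0, c3 = 0.25
```

Features 1 and 2 are each selected on only one of the four subsamples, but cluster `c1` is selected on two of them.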

The helper function checkGetClusterSelMatrixInput() verifies the inputs to getClusterSelMatrix().

With the above and the helper functions below, the css() function is complete!

3.2 Helper functions

Below, we specify some of the helper functions for css(), which are less important for understanding what css() does.

  • The function checkCssInputs() is called at the beginning of the css() function. It confirms that the inputs to css() are as expected. Also, css() allows some flexibility in how the inputs are formatted, and checkCssInputs() converts the inputs into consistent formats for later use.
    • The function checkCssClustersInput() is called within checkCssInputs(), and specifically checks the clusters input.
    • formatClusters() modifies clusters, ensuring every feature appears in exactly one cluster (any unclustered features are put in a “cluster” by themselves).
      • checkFormatClustersInput() verifies the input to formatClusters() for safety (this might seem a little redundant here, but formatClusters() is also called by different functions in this package).
        • checkY() ensures (here and elsewhere) that the provided y is as expected. (In particular, in general the inputted y can be a vector of any type as long as the specified fitfun can handle y appropriately. For some function inputs, y must be a numeric vector. checkY() enforces this, among other things.)
      • checkClusters() verifies that the output of formatClusters() is as expected.
      • getPrototypes() identifies the prototype of each cluster: the feature within each cluster that has the largest marginal sample correlation with y.
    • checkSamplingType() ensures that the input sampling_type is as expected.
    • checkB() ensures that the input B is as expected.
    • checkPropFeatsRemove() ensures that the input prop_feats_remove is as expected.
  • The function cssLoop() is called within getSelMatrix() (and hence within css()). In particular, it is a helper function needed in order to allow for parallel processing of the feature selection method on each of the subsamples.
    • checkCssLoopOutput() verifies that the output of cssLoop() is as expected. This is particularly important because the user can specify their own feature selection method; checkCssLoopOutput() verifies that the output of the feature selection method is valid, and provides an informative error message if not.
  • checkCssLassoInputs() checks that the inputs to the provided fitfun, cssLasso(), are as needed.
  • checkGetClusterSelMatrixInput() verifies that the input to getClusterSelMatrix() is as expected.

Tests are written for all of these functions and appear as early as possible after the function is defined. (For functions that don’t depend on any other functions, tests appear immediately after the function; for functions with dependencies, the tests appear after all dependencies have been defined.)

checkCssInputs():

#' Helper function to confirm that inputs to the function `css()` are as expected,
#' and modify inputs if needed
#'
#' @param X An n x p numeric matrix (preferably) or a data.frame (which will
#' be coerced internally to a matrix by the function model.matrix) containing
#' p >= 2 features/predictors.
#' @param y The response; can be anything that takes the form of an
#' n-dimensional vector, with the ith entry corresponding to the ith row of X.
#' Typically (and for default fitfun = cssLasso), y should be an n-dimensional
#' numeric vector.
<<param-lambda>>
<<param-clusters>>
#' @param fitfun A function; the feature selection function used on each
#' subsample by cluster stability selection. This can be any feature selection
#' method; the only requirement is that it accepts the arguments (and only the
#' arguments) X, y, and lambda and returns an integer vector that is a subset of
#' 1:p. For example, fitfun could be best subset selection or forward stepwise
#' selection or LARS and lambda could be the desired model size; or fitfun could be the
#' elastic net and lambda could be a length-two vector specifying lambda and
#' alpha. Default is cssLasso, an implementation of lasso (relying on the R
#' package glmnet), where lambda must be a positive numeric specifying the L1
#' penalty for the lasso.
#' @param sampling_type A character vector; either "SS" or "MB". For "MB",
#' all B subsamples are drawn randomly (as proposed by Meinshausen and Bühlmann
#' 2010). For "SS", in addition to these B subsamples, the B complementary pair
#' subsamples will be drawn as well (see Faletto and Bien 2022 or Shah and
#' Samworth 2013 for details). Default is "SS", and "MB" is not supported yet.
#' @param B Integer or numeric; the number of subsamples. Note: For
#' sampling_type="MB" the total number of subsamples will be `B`; for
#' sampling_type="SS" the number of subsamples will be `2*B`. Default is 100
#' for sampling_type="MB" and 50 for sampling_type="SS".
#' @param prop_feats_remove Numeric; if prop_feats_remove is greater than 0,
#' then on each subsample, each feature is randomly dropped from the design
#' matrix that is provided to fitfun with probability prop_feats_remove
#' (independently across features). That is, in a typical subsample,
#' prop_feats_remove*p features will be dropped (though this number will vary).
#' This is similar in spirit to (but distinct from) extended stability selection
#' (Beinrucker et al. 2016); see their paper for some of the benefits of
#' dropping features (besides increasing computational speed and decreasing
#' memory requirements). For sampling_type="SS", the features dropped in
#' each complementary pair of subsamples are identical in order to ensure that
#' the theoretical guarantees of Faletto and Bien (2022) are retained within
#' each individual pair of subsamples. (Note that this feature is not
#' investigated either theoretically or in simulations by Faletto and Bien
#' 2022). Must be between 0 and 1. Default is 0.
#' @param train_inds Optional; an integer or numeric vector containing the
#' indices of observations in X and y to set aside for model training by the
#' function getCssPreds after feature selection. (This will only work if y is
#' real-valued, because getCssPreds uses ordinary least squares regression to
#' generate predictions.) If train_inds is not provided, all of the observations
#' in the provided data set will be used for feature selection.
#' @param num_cores Optional; an integer. If using parallel processing, the
#' number of cores to use for parallel processing (num_cores will be supplied
#' internally as the mc.cores argument of parallel::mclapply).
#' @return A named list with the following elements: \item{feat_names}{A 
#' character vector containing the column names of X (if the provided X
#' had column names). If the provided X did not have column names, feat_names
#' will be NA.} \item{X}{The provided X, converted to a matrix if it was
#' originally provided as a data.frame, and with feature names removed if they
#' had been provided.}\item{clusters}{A list of integer vectors; each vector
#' will contain the indices of a cluster of features. Any duplicated clusters
#' provided in the input will be removed.}
#' @author Gregory Faletto, Jacob Bien
checkCssInputs <- function(X, y, lambda, clusters, fitfun, sampling_type, B,
    prop_feats_remove, train_inds, num_cores){

    stopifnot(is.matrix(X) | is.data.frame(X))

    clust_names <- as.character(NA)
    if(!is.null(names(clusters)) & is.list(clusters)){
        clust_names <- names(clusters)
    }

    # Check if X is a matrix; if it's a data.frame, convert to matrix.
    if(is.data.frame(X)){
        p <- ncol(X)
        
        X <- stats::model.matrix(~ ., X)
        X <- X[, colnames(X) != "(Intercept)"]

        if(length(clusters) > 0 & (p != ncol(X))){
            stop("When stats::model.matrix converted the provided data.frame X to a matrix, the number of columns changed (probably because the provided data.frame contained a factor variable with at least three levels). Please convert X to a matrix yourself using model.matrix and provide cluster assignments according to the columns of the new matrix.")
        }
    }

    stopifnot(is.matrix(X))
    stopifnot(all(!is.na(X)))

    feat_names <- as.character(NA)
    if(!is.null(colnames(X))){
        feat_names <- colnames(X)
    }

    n <- nrow(X)
    p <- ncol(X)

    stopifnot(p >= 2)
    if(length(feat_names) > 1){
        stopifnot(length(feat_names) == p)
    } else{
        stopifnot(is.na(feat_names))
    }

    colnames(X) <- character()

    stopifnot(length(y) == n)
    # Intentionally don't check y or lambda further to allow for flexibility--these
    # inputs should be checked within fitfun.

    # Check clusters argument
    clusters <- checkCssClustersInput(clusters)

    ### Format clusters into a list where all features are in exactly one
    # cluster (any unclustered features are put in their own "cluster" of size
    # 1).
    clusters <- formatClusters(clusters, p=p, clust_names=clust_names)$clusters

    stopifnot(class(fitfun) == "function")
    stopifnot(length(fitfun) == 1)
    if(!identical(formals(fitfun), formals(cssLasso))){
        err_mess <- paste("fitfun must accept arguments named X, y, and lambda. Detected arguments to fitfun:",
            paste(names(formals(fitfun)), collapse=", "))
        stop(err_mess)
    }

    checkSamplingType(sampling_type)
    checkB(B)
    checkPropFeatsRemove(prop_feats_remove, p)

    stopifnot(is.numeric(train_inds) | is.integer(train_inds))
    if(length(train_inds) > 0){
        stopifnot(all(!is.na(train_inds)))
        stopifnot(all(round(train_inds) == train_inds))
        stopifnot(length(train_inds) == length(unique(train_inds)))
        stopifnot(all(train_inds >= 1))
        stopifnot(all(train_inds <= n))
        stopifnot(length(train_inds) <= n - 2)
        stopifnot(length(train_inds) >= 1)
    }

    stopifnot(length(num_cores) == 1)
    stopifnot(is.integer(num_cores) | is.numeric(num_cores))
    stopifnot(!is.na(num_cores))
    stopifnot(num_cores == round(num_cores))
    stopifnot(num_cores >= 1)
    stopifnot(num_cores <= parallel::detectCores())

    return(list(feat_names=feat_names, X=X, clusters=clusters))
}
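
The error raised above when stats::model.matrix() changes the number of columns can be reproduced in isolation. The following is a minimal sketch with a made-up data.frame (the names df, x1, and grp are illustrative only and do not come from the package): a factor with three levels expands into two dummy columns, so cluster indices defined against the original columns would no longer line up.

```r
# Illustrative data only: one numeric column and one three-level factor.
df <- data.frame(x1 = c(0.5, 1.2, -0.3, 2.1, 0.0, 1.7),
                 grp = factor(c("a", "b", "c", "a", "b", "c")))
p_before <- ncol(df)

# Same conversion as in checkCssInputs(): build the model matrix, then drop
# the intercept column. drop=FALSE is used here so the result stays a matrix.
X_mat <- stats::model.matrix(~ ., df)
X_mat <- X_mat[, colnames(X_mat) != "(Intercept)", drop=FALSE]
p_after <- ncol(X_mat)

# The factor grp was expanded into dummy columns, so the widths differ.
print(colnames(X_mat))
print(c(before=p_before, after=p_after))
```

In a case like this, the safer route is the one the error message recommends: convert X to a matrix first and define clusters against the columns of that matrix.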

checkCssClustersInput():

#' Helper function to confirm that clusters input to css is as expected
#'
<<param-clusters>>
#' @return Same as the input, but all of the clusters will be coerced to
#' integers.
#' @author Gregory Faletto, Jacob Bien
checkCssClustersInput <- function(clusters){
    stopifnot(!is.na(clusters))
    if(is.list(clusters)){
        stopifnot(all(!is.na(clusters)))
        stopifnot(length(clusters) == length(unique(clusters)))

        if(length(clusters) > 0){
            for(i in 1:length(clusters)){
                stopifnot(length(clusters[[i]]) == length(unique(clusters[[i]])))
                stopifnot(all(!is.na(clusters[[i]])))
                stopifnot(is.integer(clusters[[i]]) | is.numeric(clusters[[i]]))
                stopifnot(all(clusters[[i]] == round(clusters[[i]])))
                stopifnot(all(clusters[[i]] >= 1))
                clusters[[i]] <- as.integer(clusters[[i]])
            }

            if(length(clusters) >= 2){
                # Check that clusters are non-overlapping
                for(i in 1:(length(clusters) - 1)){
                    for(j in (i+1):length(clusters)){
                        if(length(intersect(clusters[[i]], clusters[[j]])) != 0){
                            error_mes <- paste("Overlapping clusters detected; clusters must be non-overlapping. Overlapping clusters: ",
                                i, ", ", j, ".", sep="")
                            stop(error_mes)
                        }
                    }
                }
            }
        }
    } else{
        # If clusters is not a list, it should be a vector of indices of
        # features that are in the (one) cluster
        stopifnot(is.numeric(clusters) | is.integer(clusters))
        stopifnot(length(clusters) == length(unique(clusters)))
        stopifnot(all(!is.na(clusters)))
        stopifnot(is.integer(clusters) | is.numeric(clusters))
        stopifnot(all(clusters == round(clusters)))
        stopifnot(all(clusters >= 1))
        clusters <- as.integer(clusters)
    }
    return(clusters)
}

Tests for checkCssClustersInput():

testthat::test_that("checkCssClustersInput works", {
  
  # Intentionally don't provide clusters for all features, mix up formatting,
  # etc.
  good_clusters <- list(red_cluster=1L:4L, green_cluster=5L:8L)
  
  res <- checkCssClustersInput(good_clusters)
  
  # clusters
  testthat::expect_true(is.list(res))
  testthat::expect_equal(length(res), length(names(res)))
  testthat::expect_equal(length(res), length(unique(names(res))))
  testthat::expect_true(all(!is.na(names(res))))
  testthat::expect_true(all(!is.null(names(res))))
  clust_feats <- integer()
  for(i in 1:length(res)){
    clust_feats <- c(clust_feats, res[[i]])
  }
  testthat::expect_equal(length(clust_feats), length(unique(clust_feats)))
  testthat::expect_equal(length(clust_feats), length(intersect(clust_feats,
                                                               1:8)))

  ## Trying other inputs
  
  unnamed_clusters <- list(1L:3L, 5L:8L)
  
  res <- checkCssClustersInput(unnamed_clusters)
  
  # clusters
  testthat::expect_true(is.list(res))
  clust_feats <- integer()
  for(i in 1:length(res)){
    clust_feats <- c(clust_feats, res[[i]])
  }
  testthat::expect_equal(length(clust_feats), length(unique(clust_feats)))
  testthat::expect_equal(length(clust_feats), length(intersect(clust_feats,
                                                               1:8)))
  
  testthat::expect_error(checkCssClustersInput(list(1:4, 4:6)),
                         "Overlapping clusters detected; clusters must be non-overlapping. Overlapping clusters: 1, 2.",
                         fixed=TRUE)
  
  testthat::expect_error(checkCssClustersInput(list(2:3, 2:3)),
                         "length(clusters) == length(unique(clusters)) is not TRUE",
                         fixed=TRUE)
  
  testthat::expect_error(checkCssClustersInput(list(2:3, as.integer(NA))),
                         "!is.na(clusters) are not all TRUE",
                         fixed=TRUE)
  
  testthat::expect_error(checkCssClustersInput(list(2:3, c(4, 4, 5))),
                         "length(clusters[[i]]) == length(unique(clusters[[i]])) is not TRUE",
                         fixed=TRUE)
  
  testthat::expect_error(checkCssClustersInput(list(2:3, -1)),
                         "all(clusters[[i]] >= 1) is not TRUE",
                         fixed=TRUE)
  
  testthat::expect_error(checkCssClustersInput(c(0.4, 0.6)),
                         "all(clusters == round(clusters)) is not TRUE",
                         fixed=TRUE)
  
  # Single cluster
  res_sing_clust <- checkCssClustersInput(2:5)
  testthat::expect_equal(length(res_sing_clust), 4)


})
## Test passed 😀
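
The non-overlap requirement tested above amounts to checking every pair of clusters for shared features. As a rough sketch of that idea (findOverlappingPairs is a hypothetical helper written for this illustration, not a function in the package), the pairs can be enumerated with utils::combn() instead of nested loops:

```r
# Hypothetical helper (not part of the package): returns a two-column matrix
# of index pairs (i, j), with i < j, whose clusters share at least one feature.
findOverlappingPairs <- function(clusters){
    stopifnot(is.list(clusters))
    if(length(clusters) < 2){
        return(matrix(integer(), ncol=2))
    }
    # Each row of pairs is one (i, j) combination of cluster positions
    pairs <- t(utils::combn(length(clusters), 2))
    overlaps <- apply(pairs, 1, function(ij){
        length(intersect(clusters[[ij[1]]], clusters[[ij[2]]])) > 0
    })
    pairs[overlaps, , drop=FALSE]
}

# Clusters 1 and 2 share feature 4; cluster 3 is disjoint from both.
bad_pairs <- findOverlappingPairs(list(1:4, 4:6, 7:8))
print(bad_pairs)

# No overlaps: the result has zero rows.
no_pairs <- findOverlappingPairs(list(1:3, 5:8))
print(nrow(no_pairs))
```

Unlike checkCssClustersInput(), which stops at the first overlapping pair, this sketch collects all of them, which may be convenient when debugging a long list of clusters.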

formatClusters():

#' Formats clusters in standardized way, optionally estimating cluster
#' prototypes
#'
#' @param clusters Either an integer vector or a list of integer vectors; each
#' vector should contain the indices of a cluster of features. (If there is only
#' one cluster, clusters can either be a list of length 1 or simply an integer
#' vector.) If clusters is specified then R is ignored.
#' @param p integer or numeric; the number of features in x (should match 
#' ncol(x), if x is provided)
#' @param clust_names A character vector of the names of the clusters in
#' clusters.
#' @param get_prototypes Logical: if TRUE, will identify prototype from each
#' cluster (the feature from each cluster that is most correlated with the
#' response) for the protolasso. In this case, x and y must be provided.
#' @param x n x p numeric matrix; design matrix. Only needs to be provided if
#' get_prototypes is TRUE.
#' @param y Numeric response vector; only needs to be provided if get_prototypes
#' is TRUE. Note: in general, the css function does not require y to be a
#' numeric vector, because the provided fitfun could use a different form of y
#' (for example, a categorical response variable). However, y must be numeric in
#' order to provide prototypes because the prototypes are determined using the
#' correlation between cluster members (columns of x) and y.
#' @param R Numeric p x p matrix; not currently used. Entry ij contains the 
#' "substitutive value" of feature i for feature j (diagonal must consist of
#' ones, all entries must be between 0 and 1, and matrix must be symmetric)
#' @return A named list with the following elements: \item{clusters}{A named
#' list where each entry is an integer vector of indices of features that are in
#' a common cluster. (The length of list clusters is equal to the number of
#' clusters.) All identified clusters are non-overlapping. All features appear
#' in exactly one cluster (any unclustered features will be put in their own
#' "cluster" of size 1).} \item{multiple}{Logical; TRUE if there is more than
#' one cluster of size greater than 1, FALSE otherwise.} \item{prototypes}{only
#' returned if get_prototypes=TRUE. An integer vector whose length is equal to
#' the number of clusters. Entry i is the index of the feature belonging to
#' cluster i that is most highly correlated with y (that is, the prototype for
#' the cluster, as in the protolasso; see Reid and Tibshirani 2016).}
#' @author Gregory Faletto, Jacob Bien
#' @references Reid, S., & Tibshirani, R. (2016). Sparse regression and marginal
#' testing using cluster prototypes. \emph{Biostatistics}, 17(2), 364–376.
#' \url{https://doi.org/10.1093/biostatistics/kxv049}.
formatClusters <- function(clusters=NA, p=-1, clust_names=NA, 
    get_prototypes=FALSE, x=NA, y=NA, R=NA){

    # Check inputs
    clusters <- checkFormatClustersInput(clusters, p, clust_names,
        get_prototypes, x, y, R)

    n <- nrow(x)

    multiple <- FALSE

    if(any(lengths(clusters) > 1)){ # & length(clusters) > 1
        # Only care about clusters with more than one element (only ones that
        # need to be treated differently)
        # keep track of whether there's more than one cluster or not
        multiple <- sum(lengths(clusters) > 1) > 1
    }

    # For any features not already in a cluster, add a cluster containing only
    # that feature
    orig_length_clusters <- length(clusters)

    stopifnot(p >= 1)
    for(i in 1:p){
        feat_i_found <- FALSE
        if(orig_length_clusters > 0){
            for(j in 1:orig_length_clusters){
                # If i is in cluster j, break out of this loop and consider the
                # next i
                if(i %in% clusters[[j]]){
                    feat_i_found <- TRUE
                    break
                }
            }
        }

        # If feature i wasn't found in any cluster, add a cluster containing
        # just feature i
        if(!feat_i_found){
            clusters[[length(clusters) + 1]] <- i
        }
    }

    n_clusters <- length(clusters)

    # Add names to clusters
    if(is.null(names(clusters))){
        names(clusters) <- paste("c", 1:n_clusters, sep="")
    } else{
        # What clusters need names?
        unnamed_clusts <- which(is.na(names(clusters)) | names(clusters) == "")
        if(length(unnamed_clusts) > 0){
            proposed_clust_names <- paste("c", unnamed_clusts, sep="")
            # Replace any proposed cluster names that are already in use
            if(any(proposed_clust_names %in% names(clusters))){
                proposed_clust_names[proposed_clust_names %in% names(clusters)] <- paste("c",
                    unnamed_clusts[proposed_clust_names %in% names(clusters)] + p,
                    sep="")
            }
            while_counter <- 0
            while(any(proposed_clust_names %in% names(clusters))){
                proposed_clust_names[proposed_clust_names %in% names(clusters)] <- paste(proposed_clust_names[proposed_clust_names %in% names(clusters)],
                    "_1", sep="")
                while_counter <- while_counter + 1
                if(while_counter >= 100){
                    stop("Function formatClusters stuck in an infinite while loop")
                }
            }
            stopifnot(length(unnamed_clusts) == length(proposed_clust_names))
            names(clusters)[unnamed_clusts] <- proposed_clust_names
        }
    }

    # Check output

    checkClusters(clusters, p)
    stopifnot(is.logical(multiple))
    stopifnot(length(multiple) == 1)
    stopifnot(!is.na(multiple))

    if(get_prototypes){
        prototypes <- getPrototypes(clusters, x, y)

        return(list(clusters=clusters, multiple=multiple,
            prototypes=prototypes))
    } else{
        return(list(clusters=clusters, multiple=multiple))
    }
}
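
The central step of formatClusters() is to put every unclustered feature into its own cluster of size 1. A compact sketch of just that step, assuming the clusters have already been validated (addSingletons is a hypothetical name used only here, and the naming of clusters is left out):

```r
# Hypothetical helper (not part of the package): appends a size-1 cluster for
# every feature in 1:p that does not appear in any provided cluster.
addSingletons <- function(clusters, p){
    stopifnot(is.list(clusters))
    unclustered <- setdiff(seq_len(p), unlist(clusters, use.names=FALSE))
    c(clusters, as.list(unclustered))
}

# Features 1-2 and 4-5 are clustered; features 3 and 6 become singletons.
res <- addSingletons(list(1:2, 4:5), p=6)
print(res)
```

This does in one vectorized call what the nested loops in formatClusters() do feature by feature; the result should contain each of 1:p exactly once.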

checkFormatClustersInput():

#' Helper function to ensure that the inputs to formatClusters are as expected
#'
#' @param clusters Either an integer vector or a list of integer vectors; each
#' vector should contain the indices of a cluster of features. (If there is only
#' one cluster, clusters can either be a list of length 1 or simply an integer
#' vector.) If clusters is specified then R is ignored.
#' @param p integer or numeric; the number of features in x (should match 
#' ncol(x), if x is provided)
#' @param clust_names A character vector of the names of the clusters in clusters.
#' @param get_prototypes Logical: if TRUE, will identify prototype from each
#' cluster (the feature from each cluster that is most correlated with the
#' response) for the protolasso. In this case, x and y must be provided.
#' @param x n x p numeric matrix; design matrix. Only needs to be provided if
#' get_prototypes is TRUE.
#' @param y Numeric response vector; only needs to be provided if get_prototypes
#' is TRUE. Note: in general, the css function does not require y to be a
#' numeric vector, because the provided fitfun could use a different form of y
#' (for example, a categorical response variable). However, y must be numeric in
#' order to provide prototypes because the prototypes are determined using the
#' correlation between cluster members (columns of x) and y.
#' @param R Numeric p x p matrix; not currently used. Entry ij contains the 
#' "substitutive value" of feature i for feature j (diagonal must consist of
#' ones, all entries must be between 0 and 1, and matrix must be symmetric)
#' @return A list of integer vectors; each vector will contain the indices of a
#' cluster of features. (Duplicated clusters in the input are not removed; they
#' trigger an error.)
#' @author Gregory Faletto, Jacob Bien
checkFormatClustersInput <- function(clusters, p, clust_names, 
    get_prototypes, x, y, R){

    if(any(is.na(clusters)) & any(is.na(R))){
        stop("Must specify one of clusters or R (or does one of these provided inputs contain NA?)")
    }

    stopifnot(is.integer(p) | is.numeric(p))
    stopifnot(length(p) == 1)
    stopifnot(p == round(p))
    stopifnot(!is.na(p))
    if(p > 0){
        stopifnot(p >= 2)
    }

    use_R <- FALSE
    if(any(is.na(clusters)) | length(clusters) == 0){
        if(all(!is.na(R))){
            stopifnot(is.matrix(R))
            stopifnot(all(dim(R) == p))
            stopifnot(all(diag(R) == 1))
            stopifnot(identical(R, t(R)))
            stopifnot(all(!is.na(R)))
            stopifnot(all(R %in% c(0, 1)))
            use_R <- TRUE
        }
    } else{
        stopifnot(!is.list(clusters) | all(lengths(clusters) >= 1))
        stopifnot(is.list(clusters) | length(clusters) >= 1)
        stopifnot(all(!is.na(clusters)))

        if(is.list(clusters) & length(clusters) > 0){
            for(i in 1:length(clusters)){
                stopifnot(length(clusters[[i]]) == length(unique(clusters[[i]])))
                stopifnot(all(!is.na(clusters[[i]])))
                stopifnot(all(clusters[[i]] >= 1))
                stopifnot(is.integer(clusters[[i]]))
            }

            stopifnot(length(clusters) == length(unique(clusters)))

            clusters <- clusters[!duplicated(clusters)]

            if(length(clusters) >= 2){
                # Check that clusters are non-overlapping
                for(i in 1:(length(clusters) - 1)){
                    for(j in (i+1):length(clusters)){
                        stopifnot(length(intersect(clusters[[i]],
                            clusters[[j]])) == 0)
                    }
                }
            }

            if(any(!is.na(clust_names))){
                stopifnot(length(clust_names) == length(clusters))
            }
        } else if(!is.list(clusters)){
            clusters_temp <- clusters
            clusters <- list()
            clusters[[1]] <- clusters_temp
            rm(clusters_temp)
        }
    }

    stopifnot(length(get_prototypes) == 1)
    stopifnot(is.logical(get_prototypes))

    if(any(!is.na(clust_names))){
        stopifnot(is.character(clust_names))
    }

    if(get_prototypes){
        stopifnot(all(!is.na(x)))
        stopifnot(is.matrix(x))

        n <- nrow(x)

        checkY(y, n)
    }

    if(use_R){
        # Determine clusters from R
        clusters <- list()

        for(i in 1:nrow(R)){
            clusters[[i]] <- as.integer(which(R[i, ] > 0))
            stopifnot(length(clusters[[i]]) == length(unique(clusters[[i]])))
            stopifnot(all(!is.na(clusters[[i]])))
            stopifnot(is.integer(clusters[[i]]))
        }

        clusters <- unique(clusters)
        stopifnot(is.list(clusters))

        if(length(clusters) >= 2){
            # Check that clusters are non-overlapping
            for(i in 1:(length(clusters) - 1)){
                for(j in (i+1):length(clusters)){
                    if(length(intersect(clusters[[i]], clusters[[j]])) != 0){
                        stop("Invalid R matrix with overlapping clusters (clusters must not be overlapping)")
                    }
                }
            }
        }
    }

    stopifnot(is.list(clusters))

    return(clusters)
}
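
The use_R branch above turns a 0/1 matrix R into clusters by reading off, for each row, the columns with positive entries and then removing repeated clusters. A small self-contained illustration with a made-up block-diagonal R (the matrix values are assumptions chosen for the example):

```r
# Made-up 4 x 4 matrix: features 1 and 2 substitute for each other, while
# features 3 and 4 stand alone.
R <- matrix(0, nrow=4, ncol=4)
R[1:2, 1:2] <- 1
diag(R) <- 1

# Same idea as in checkFormatClustersInput(): one candidate cluster per row,
# then drop repeated clusters with unique().
clusters_from_R <- unique(lapply(seq_len(nrow(R)), function(i){
    as.integer(which(R[i, ] > 0))
}))
print(clusters_from_R)
```

Rows 1 and 2 produce the same cluster, so three clusters remain. If R were not block-structured in this way, different rows could give overlapping clusters, which is the case the function rejects with an error.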

checkY():

#' Helper function to confirm that the argument y to several functions is
#' as expected
#'
#' @param y Numeric response vector.
#' @param n Number of observations of covariates; should match length of y.
#' @author Gregory Faletto, Jacob Bien
checkY <- function(y, n){
    stopifnot(all(!is.na(y)))
    stopifnot(is.numeric(y) | is.integer(y))
    stopifnot(length(unique(y)) > 1)
    stopifnot(length(n) == 1)
    stopifnot(!is.na(n))
    stopifnot(is.numeric(n) | is.integer(n))
    stopifnot(n == round(n))
    stopifnot(n > 0)
    stopifnot(n == length(y))
}

Tests for checkY():

testthat::test_that("checkY works", {
  testthat::expect_null(checkY(as.numeric(1:20)*.1, 20))
  testthat::expect_null(checkY(1L:15L, 15))
  testthat::expect_error(checkY(1:7, 8), "n == length(y) is not TRUE",
                         fixed=TRUE)
  testthat::expect_error(checkY(1:7, -7), "n > 0 is not TRUE", fixed=TRUE)
  testthat::expect_error(checkY(rep(as.numeric(NA), 13), 13),
                         "all(!is.na(y)) is not TRUE", fixed=TRUE)
  testthat::expect_error(checkY(rep(5.2, 9), 9),
                         "length(unique(y)) > 1 is not TRUE", fixed=TRUE)
  testthat::expect_error(checkY(c(TRUE, FALSE, TRUE), 3),
                         "is.numeric(y) | is.integer(y) is not TRUE",
                         fixed=TRUE)
})
## Test passed 🥇
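
The last test above (a logical y is rejected) follows from how base R's is.numeric() classifies vectors. A quick check that does not depend on the package:

```r
# is.numeric() is TRUE for double and integer vectors, and FALSE for logical
# vectors and factors.
is_num <- c(
    double  = is.numeric(c(0.1, 0.2)),
    integer = is.numeric(1L:3L),
    logical = is.numeric(c(TRUE, FALSE)),
    factor  = is.numeric(factor(c("a", "b")))
)
print(is_num)
```

Because is.numeric() already returns TRUE for integer vectors, the additional is.integer(y) condition in checkY() never changes the outcome; it is harmless and documents the intent.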

Tests for checkFormatClustersInput():

testthat::test_that("checkFormatClustersInput works", {
  
  # Intentionally don't provide clusters for all features, mix up formatting,
  # etc.
  good_clusters <- list(red_cluster=1L:4L, green_cluster=5L:8L)
  
  res <- checkFormatClustersInput(good_clusters, p=10,
                                  clust_names=c("red_cluster", "green_cluster"),
                                  get_prototypes=FALSE, x=NA, y=NA, R=NA)
  
  testthat::expect_true(is.list(res))
  clust_feats <- integer()
  for(i in 1:length(res)){
    clust_feats <- c(clust_feats, res[[i]])
  }
  testthat::expect_equal(length(clust_feats), length(unique(clust_feats)))
  testthat::expect_equal(length(clust_feats), length(intersect(clust_feats,
                                                               1:8)))

  ## Trying other inputs
  
  unnamed_clusters <- list(1L:3L, 5L:8L)

  res <- checkFormatClustersInput(unnamed_clusters, p=10, clust_names=NA,
                                  get_prototypes=FALSE, x=NA, y=NA, R=NA)

  # clusters
  testthat::expect_true(is.list(res))
  clust_feats <- integer()
  for(i in 1:length(res)){
    clust_feats <- c(clust_feats, res[[i]])
  }
  testthat::expect_equal(length(clust_feats), length(unique(clust_feats)))
  testthat::expect_equal(length(clust_feats), length(intersect(clust_feats,
                                                               1:8)))

  testthat::expect_error(checkFormatClustersInput(list(1:4, 4:6), p=10,
                                                  clust_names=NA,
                                                  get_prototypes=FALSE, x=NA,
                                                  y=NA, R=NA),
                         "length(intersect(clusters[[i]], clusters[[j]])) == 0 is not TRUE",
                         fixed=TRUE)
  testthat::expect_error(checkFormatClustersInput(list(2:3, 2:3), p=10,
                                  clust_names=NA, get_prototypes=FALSE, x=NA,
                                  y=NA, R=NA),
                         "length(clusters) == length(unique(clusters)) is not TRUE",
                         fixed=TRUE)
  
  testthat::expect_error(checkFormatClustersInput(list(2:3, as.integer(NA)),
                                                  p=10,
                                                  clust_names=NA,
                                                  get_prototypes=FALSE, x=NA,
                                                  y=NA, R=NA),
                         "Must specify one of clusters or R (or does one of these provided inputs contain NA?)",
                         fixed=TRUE)
  
  testthat::expect_error(checkFormatClustersInput(list(2:3, c(4, 4, 5)),
                                                  p=10,
                                                  clust_names=NA,
                                                  get_prototypes=FALSE, x=NA,
                                                  y=NA, R=NA),
                         "length(clusters[[i]]) == length(unique(clusters[[i]])) is not TRUE",
                         fixed=TRUE)
  
   testthat::expect_error(checkFormatClustersInput(list(1:4, -1),
                                                  p=10,
                                                  clust_names=NA,
                                                  get_prototypes=FALSE, x=NA,
                                                  y=NA, R=NA),
                         "all(clusters[[i]] >= 1) is not TRUE",
                         fixed=TRUE)
   
   testthat::expect_error(checkFormatClustersInput(list(1:4, c(2.3, 1.2)),
                                                  p=10,
                                                  clust_names=NA,
                                                  get_prototypes=FALSE, x=NA,
                                                  y=NA, R=NA),
                         "is.integer(clusters[[i]]) is not TRUE",
                         fixed=TRUE)

  # Single cluster
   testthat::expect_true(is.list(checkFormatClustersInput(c(1:5), p=10,
                                                          clust_names=NA,
                                                          get_prototypes=FALSE,
                                                          x=NA, y=NA, R=NA)))
})
## Test passed 😀

checkClusters():

#' Helper function to confirm that the argument clusters to several functions is
#' as expected
#'
#' @param clusters A named list where each entry is an integer vector of indices
#' of features that are in a common cluster, as in the output of css or
#' formatClusters. (The length of list clusters is equal to the number of
#' clusters.) All identified clusters must be non-overlapping, and all features
#' must appear in exactly one cluster (any unclustered features should be in
#' their own "cluster" of size 1).
#' @param p The number of features; must be at least as large as the number of
#' clusters.
#' @author Gregory Faletto, Jacob Bien
checkClusters <- function(clusters, p){
    stopifnot(is.list(clusters))
    stopifnot(all(lengths(clusters) >= 1))
    stopifnot(all(!is.na(clusters)))

    n_clusters <- length(clusters)

    stopifnot(n_clusters == length(unique(clusters)))
    stopifnot(n_clusters <= p)
    stopifnot(!is.null(names(clusters)))
    stopifnot(is.character(names(clusters)))
    stopifnot(all(!is.na(names(clusters)) & names(clusters) != ""))
    stopifnot(length(unique(names(clusters))) == n_clusters)

    all_clustered_feats <- integer()
    for(i in 1:n_clusters){
        stopifnot(is.integer(clusters[[i]]))
        all_clustered_feats <- c(all_clustered_feats, clusters[[i]])
    }

    stopifnot(length(all_clustered_feats) == p)
    stopifnot(length(unique(all_clustered_feats)) == p)
    stopifnot(all(all_clustered_feats <= p))
    stopifnot(all(all_clustered_feats >= 1))
}

Tests for checkClusters():

testthat::test_that("checkClusters works", {
  good_clusters <- list(c1=1L:5L, c2=6L:8L, c3=9L)
  
  testthat::expect_null(checkClusters(good_clusters, 9))
  testthat::expect_error(checkClusters(good_clusters, 10),
                         "length(all_clustered_feats) == p is not TRUE",
                         fixed=TRUE)
  testthat::expect_error(checkClusters(1L:10L, 10),
                         "is.list(clusters) is not TRUE",
                         fixed=TRUE)
  testthat::expect_error(checkClusters(list(c1=1L:5L, c2=6L:8L, c3=9L,
                                            c4=integer()), 9),
                         "all(lengths(clusters) >= 1) is not TRUE",
                         fixed=TRUE)
  testthat::expect_error(checkClusters(list(c1=1L:5L, c2=6L:8L, c3=9L,
                                            c4=as.integer(NA)), 9),
                         "all(!is.na(clusters)) is not TRUE",
                         fixed=TRUE)
  testthat::expect_error(checkClusters(list(c1=1L:5L, c2=6L:8L, c3=9L,
                                            c2=6L:8L), 9),
                         "n_clusters == length(unique(clusters)) is not TRUE",
                         fixed=TRUE)
  testthat::expect_error(checkClusters(list(c1=1L:5L, c2=6L:8L, c3=10L), 9),
                         "all(all_clustered_feats <= p) is not TRUE",
                         fixed=TRUE)
})
## Test passed 😸
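
Taken together, the index checks in checkClusters() require that the clusters form a partition of 1:p. As a sketch of that condition alone (isPartition is a hypothetical helper for illustration; it returns TRUE or FALSE rather than raising an error, and it ignores the checks on names and types):

```r
# Hypothetical helper (not part of the package): TRUE if every feature in 1:p
# appears in exactly one cluster, FALSE otherwise.
isPartition <- function(clusters, p){
    feats <- unlist(clusters, use.names=FALSE)
    length(feats) == p && anyDuplicated(feats) == 0 &&
        setequal(feats, seq_len(p))
}

good_clusters <- list(c1=1L:5L, c2=6L:8L, c3=9L)
ok <- isPartition(good_clusters, 9)                              # partition of 1:9
too_few <- isPartition(good_clusters, 10)                        # feature 10 missing
out_of_range <- isPartition(list(c1=1L:5L, c2=6L:8L, c3=10L), 9) # index 10 > p
print(c(ok=ok, too_few=too_few, out_of_range=out_of_range))
```

The three calls mirror the first, second, and last cases in the tests above.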

getPrototypes():

#' Estimate prototypes from a list of clusters
#'
#' Takes in list of clusters, x, and y and returns an integer vector (of length
#' equal to the number of clusters) of the indices of the feature prototypes
#' (the features from each cluster most correlated with the response).
#'
#' @param clusters A list where each entry is an integer vector of indices of
#' features that are in a common cluster. (The length of list clusters must be
#' equal to the number of clusters.) All identified clusters must be
#' non-overlapping. Must only include clusters of size 2 or larger.
#' @param x n x p numeric matrix; design matrix.
#' @param y Numeric response vector. Note: in general, the css function does not
#' require y to be a numeric vector, because the provided fitfun could use a
#' different form of y (for example, a categorical response variable). However,
#' y must be numeric in order to provide prototypes because the prototypes are
#' determined using the correlation between cluster members (columns of x) and
#' y.
#' @return An integer vector of the same length as clusters. Entry j is the
#' index of the feature identified as the prototype for cluster j.
#' @author Gregory Faletto, Jacob Bien
getPrototypes <- function(clusters, x, y){
    # Check inputs

    stopifnot(!is.list(clusters) | all(lengths(clusters) >= 1))
    stopifnot(is.list(clusters) | length(clusters) >= 1)

    stopifnot(all(!is.na(x)))
    stopifnot(is.matrix(x))

    n <- nrow(x)
    p <- ncol(x)

    checkY(y, n)

    # Identify prototypes
    if(length(clusters) > 0){
        if(is.list(clusters)){
            prototypes <- rep(as.integer(NA), length(clusters))
            for(i in 1:length(clusters)){
                prototypes[i] <- identifyPrototype(clusters[[i]], x, y)
            }
        } else{
            # If clusters is not a list, then there is only one cluster.
            prototypes <- identifyPrototype(clusters, x, y)
        }
    }  else{
        prototypes <- integer()
    }

    # Check output

    stopifnot(is.integer(prototypes))
    if(length(clusters) > 0){
        if(is.list(clusters)){
            stopifnot(length(prototypes) == length(clusters))
        } else {
            stopifnot(length(prototypes) == 1)
        }
    } else{
        stopifnot(length(prototypes) == 0)
    }

    stopifnot(all(!is.na(prototypes)))
    stopifnot(length(prototypes) == length(unique(prototypes)))

    return(prototypes)
}

identifyPrototype():

#' Estimate prototypes from a single cluster
#'
#' Takes in a single cluster, x, and y and returns an integer of the index of
#' the feature prototype (the feature from the cluster most correlated with the
#' response).
#'
#' @param cluster_members_i An integer vector of indices of features that are in
#' a common cluster. Must have length at least 1; if the cluster has a single
#' member, that member is returned as the prototype.
#' @param x n x p numeric matrix; design matrix.
#' @param y Numeric response vector. Note: in general, the css function does not
#' require y to be a numeric vector, because the provided fitfun could use a
#' different form of y (for example, a categorical response variable). However,
#' y must be numeric in order to provide prototypes because the prototypes are
#' determined using the correlation between cluster members (columns of x) and
#' y.
#' @return integer; the index of the feature identified as the prototype for
#' the cluster.
#' @author Gregory Faletto, Jacob Bien
identifyPrototype <- function(cluster_members_i, x, y){
    # Check input
    stopifnot(is.integer(cluster_members_i))
    # If cluster only has one member, that member is the prototype
    if(length(cluster_members_i) == 1){
        return(cluster_members_i)
    }

    # Choose which cluster member to represent cluster for stability
    # metric purposes by choosing the one most highly correlated
    # with y

    cors_i <- apply(x[, cluster_members_i], 2, corFunction, y=y)
    max_index_i <- which.max(cors_i)[1]

    stopifnot(length(max_index_i) == 1)
    stopifnot(max_index_i %in% 1:length(cluster_members_i))

    ret <- cluster_members_i[max_index_i]

    # Check output

    stopifnot(is.integer(ret))
    stopifnot(length(ret) == 1)
    stopifnot(ret %in% cluster_members_i)
    stopifnot(identical(x[, cluster_members_i][, max_index_i],
        x[, ret]))

    return(ret)
}
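
As a minimal illustration of the rule that identifyPrototype() applies, the following sketch uses only base R (it does not call the package functions) to pick, from a made-up cluster of columns, the member with the largest absolute sample correlation with y. The data and the names members, abs_cors, and prototype are assumptions introduced for this example.

```r
# Illustrative data: 20 observations, 4 features (all names are hypothetical)
set.seed(1)
n <- 20
x <- matrix(stats::rnorm(n * 4), nrow = n, ncol = 4)
# Make the response depend strongly on feature 3
y <- 2 * x[, 3] + stats::rnorm(n, sd = 0.1)

# A cluster containing features 2, 3, and 4
members <- c(2L, 3L, 4L)

# Absolute sample correlation of each cluster member with y
# (stats::cor() returns a 3 x 1 matrix here, so take its only column)
abs_cors <- abs(stats::cor(x[, members], y))[, 1]

# The prototype is the cluster member with the largest absolute correlation;
# note that the result is an index into the columns of x, not into members
prototype <- members[which.max(abs_cors)]
```

The last line mirrors the step in identifyPrototype() that maps the position within the cluster back to a column index of the full design matrix.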

corFunction():

#' Absolute value of sample correlation between two vectors
#'
#' Calculates the absolute value of correlation of t and y. If either input has
#' only one unique value, returns 0 by definition.
#' @param t A numeric or integer vector.
#' @param y A numeric or integer vector; must have the same length as t.
#' @return A single nonnegative numeric: the absolute value of the sample
#' correlation between t and y, or 0 if either input has only one unique
#' value.
#' @author Gregory Faletto, Jacob Bien
corFunction <- function(t, y){
    # Check inputs
    stopifnot(is.numeric(t) | is.integer(t))
    stopifnot(is.numeric(y) | is.integer(y))
    stopifnot(length(t) == length(y))
    if(length(unique(t)) == 1){
        return(0)
    }
    if(length(unique(y)) == 1){
        warning("The second argument to corFunction only had one unique entry")
        return(0)
    }
    return(abs(stats::cor(t, y)))
}
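
The guard for inputs with a single unique value matters because stats::cor() returns NA (with a warning) when one of its arguments has zero standard deviation. The following base-R sketch, which does not call corFunction() itself, shows that behavior and the effect of mapping the constant case to 0. The variable names are illustrative.

```r
# A constant feature and an arbitrary response (illustrative values)
const_feature <- rep(1, 5)
y <- c(0.3, -1.1, 2.4, 0.8, -0.5)

# stats::cor() returns NA (and warns) when one input is constant
raw_cor <- suppressWarnings(stats::cor(const_feature, y))

# Treat the constant case as zero correlation, as corFunction() does
guarded_cor <- if(length(unique(const_feature)) == 1) 0 else abs(raw_cor)
```

Returning 0 rather than NA keeps a later which.max() over the cluster members well defined even when a cluster contains a constant column.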

Tests for corFunction():

testthat::test_that("corFunction works", {
  testthat::expect_identical(corFunction(rep(1, 10), 1:10), 0)
  testthat::expect_identical(corFunction(rep(1.2, 5), 1:5), 0)
  
  set.seed(23451)
  
  x <- stats::rnorm(8)
  y <- stats::rnorm(8)
  testthat::expect_identical(corFunction(x, y), abs(stats::cor(x, y)))
  testthat::expect_warning(corFunction(1:5, rep(1.2, 5)),
                           "The second argument to corFunction only had one unique entry",
                           fixed=TRUE)
  testthat::expect_error(corFunction("1", "2"),
                         "is.numeric(t) | is.integer(t) is not TRUE",
                         fixed=TRUE)
  testthat::expect_error(corFunction(3:8, "2"),
                         "is.numeric(y) | is.integer(y) is not TRUE",
                         fixed=TRUE)
  testthat::expect_error(corFunction(3:8, 1:2),
                         "length(t) == length(y) is not TRUE",
                         fixed=TRUE)
})
## Test passed 😸

Tests for identifyPrototype() (requires corFunction()):

testthat::test_that("identifyPrototype works", {
  testthat::expect_identical(identifyPrototype(10L, "a", 5), 10L)
  n <- 10
  p <- 5
  
  set.seed(9834)
  
  X <- matrix(stats::rnorm(n*p), nrow=n, ncol=p)
  y <- X[, p]
  testthat::expect_equal(identifyPrototype(as.integer(p), X, y), p)
  testthat::expect_equal(identifyPrototype(2L, X, y), 2)
  testthat::expect_equal(identifyPrototype(as.integer(2:p), X, y), p)
  testthat::expect_error(identifyPrototype(as.integer(2:p), y, X),
                         "incorrect number of dimensions",
                         fixed=TRUE)
  
  y2 <- stats::rnorm(n)

  res <- identifyPrototype(c(2L, 3L), X, y2)

  testthat::expect_true(is.integer(res))

  testthat::expect_equal(length(res), 1)

  testthat::expect_true(res %in% c(2L, 3L))

})
## Test passed 🎊

Tests for getPrototypes():

testthat::test_that("getPrototypes works", {
  n <- 10
  p <- 5
  
  set.seed(902689)
  
  X <- matrix(stats::rnorm(n*p), nrow=n, ncol=p)
  y <- X[, p]

  testthat::expect_identical(getPrototypes(list(1L, 2L, 3L, 4L, 5L), X, y), 1:5)

  testthat::expect_identical(getPrototypes(list(1L:5L), X, y), 5L)

  testthat::expect_identical(getPrototypes(list(1L, 2L:5L), X, y), c(1L, 5L))

  testthat::expect_identical(getPrototypes(list(3L:5L), X, y), 5L)

  y2 <- stats::rnorm(n)

  res <- getPrototypes(list(1L, c(2L, 3L), c(4L, 5L)), X, y2)

  testthat::expect_true(is.integer(res))

  testthat::expect_equal(length(res), 3)

  testthat::expect_identical(res[1], 1L)

  testthat::expect_true(res[2] %in% c(2L, 3L))

  testthat::expect_true(res[3] %in% c(4L, 5L))

  testthat::expect_error(getPrototypes(list(1L, 2L, 3L, 4L, 5L), y, X),
                          "is.matrix(x) is not TRUE",
                          fixed=TRUE)

  testthat::expect_error(getPrototypes(list(1L, 2L, 3L, 4L, 5L), X, y[1:9]),
                         "n == length(y) is not TRUE",
                         fixed=TRUE)

})
## Test passed 🎊

Finally, tests for formatClusters():

testthat::test_that("formatClusters works", {
  
  # Intentionally don't provide clusters for all features, mix up formatting,
  # etc.
  good_clusters <- list(red_cluster=1:3, 5:8)
  
  res <- formatClusters(good_clusters, p=10)
  
  testthat::expect_true(is.list(res))
  testthat::expect_identical(names(res), c("clusters", "multiple"))
  
  # Clusters
  testthat::expect_true(is.list(res$clusters))
  testthat::expect_equal(length(res$clusters), 5)
  testthat::expect_equal(5, length(names(res$clusters)))
  testthat::expect_equal(5, length(unique(names(res$clusters))))
  testthat::expect_true("red_cluster" %in% names(res$clusters))
  testthat::expect_true(all(!is.na(names(res$clusters))))
  testthat::expect_true(all(!is.null(names(res$clusters))))
  testthat::expect_true(all(names(res$clusters) != ""))

  clust_feats <- integer()
  true_list <- list(1:3, 5:8, 4, 9, 10)
  for(i in 1:length(res$clusters)){
    testthat::expect_true(is.integer(res$clusters[[i]]))
    testthat::expect_equal(length(intersect(clust_feats, res$clusters[[i]])), 0)
    testthat::expect_true(all(res$clusters[[i]] %in% 1:10))
    testthat::expect_equal(length(res$clusters[[i]]),
                           length(unique(res$clusters[[i]])))
    testthat::expect_true(all(res$clusters[[i]] == true_list[[i]]))
    clust_feats <- c(clust_feats, res$clusters[[i]])
  }

  testthat::expect_equal(length(clust_feats), 10)
  testthat::expect_equal(10, length(unique(clust_feats)))
  testthat::expect_equal(10, length(intersect(clust_feats, 1:10)))
  
  # Multiple
  testthat::expect_true(res$multiple)
  testthat::expect_false(formatClusters(3:5, p=10)$multiple)

  ## Trying other inputs

  testthat::expect_error(formatClusters(list(3:7, 7:10), p=15),
                         "length(intersect(clusters[[i]], clusters[[j]])) == 0 is not TRUE",
                         fixed=TRUE)
  testthat::expect_error(formatClusters(list(5:8, 5:8), p=9),
                         "length(clusters) == length(unique(clusters)) is not TRUE",
                         fixed=TRUE)

  testthat::expect_error(formatClusters(list(5:8), p=7),
                         "length(all_clustered_feats) == p is not TRUE",
                         fixed=TRUE)

  testthat::expect_error(formatClusters(list(2:3, as.integer(NA)), p=10),
                         "Must specify one of clusters or R (or does one of these provided inputs contain NA?)",
                         fixed=TRUE)

  testthat::expect_error(formatClusters(list(2:3, c(4, 4, 5)), p=8),
                         "length(clusters[[i]]) == length(unique(clusters[[i]])) is not TRUE",
                         fixed=TRUE)

   testthat::expect_error(formatClusters(list(1:4, -1), p=10),
                         "all(clusters[[i]] >= 1) is not TRUE",
                         fixed=TRUE)

   testthat::expect_error(formatClusters(list(1:4, c(2.3, 1.2))),
                         "is.integer(clusters[[i]]) is not TRUE",
                         fixed=TRUE)
   
   ### Test prototypes feature
   
    n <- 8
    p <- 6
    
    set.seed(690289)
    
    X <- matrix(stats::rnorm(n*p), nrow=n, ncol=p)
    y <- X[, p]
    
    res <- formatClusters(clusters=list(), p=p, get_prototypes=TRUE, x=X, y=y)
    
    testthat::expect_true(is.list(res))
    testthat::expect_identical(names(res), c("clusters", "multiple",
                                             "prototypes"))
    testthat::expect_true(is.integer(res$prototypes))
    testthat::expect_identical(res$prototypes, 1:p)

    testthat::expect_equal(formatClusters(clusters=1:p, p=p,
                                              get_prototypes=TRUE, x=X,
                                              y=y)$prototypes, p)
    
    testthat::expect_identical(formatClusters(clusters=list(1L, 2L:p), p=p,
                                              get_prototypes=TRUE, x=X,
                                              y=y)$prototypes,
                               as.integer(c(1, p)))
    
    testthat::expect_identical(formatClusters(clusters=3L:p, p=p,
                                              get_prototypes=TRUE, x=X,
                                              y=y)$prototypes,
                               as.integer(c(p, 1, 2)))
    
    y2 <- stats::rnorm(n)

    res <- formatClusters(clusters=list(2:3, 4:5), p=p, get_prototypes=TRUE,
                          x=X, y=y2)$prototypes

    testthat::expect_true(is.integer(res))

    testthat::expect_equal(length(res), 4)

    testthat::expect_true(res[1] %in% c(2L, 3L))

    testthat::expect_true(res[2] %in% c(4L, 5L))
    
    testthat::expect_equal(res[3], 1L)
    
    testthat::expect_equal(res[4], p)

    testthat::expect_error(formatClusters(clusters=list(2:3, 4:5), p=p,
                                          get_prototypes=TRUE, x=y2, y=X),
                           "is.matrix(x) is not TRUE", fixed=TRUE)

    testthat::expect_error(formatClusters(clusters=list(2:3, 4:5), p=p,
                                          get_prototypes=TRUE, x=X,
                                          y=y2[1:(n-1)]),
                           "n == length(y) is not TRUE", fixed=TRUE)
})
## Test passed 😀

checkSamplingType():

#' Helper function to confirm that the argument sampling_type to several 
#' functions is as expected
#'
#' @param sampling_type A character vector of length 1; either "SS" or "MB".
#' "MB" is not supported yet.
#' @author Gregory Faletto, Jacob Bien
checkSamplingType <- function(sampling_type){
    stopifnot(is.character(sampling_type))
    stopifnot(length(sampling_type) == 1)
    stopifnot(!is.na(sampling_type))
    stopifnot(sampling_type %in% c("SS", "MB"))
    if(sampling_type == "MB"){
        stop("sampling_type MB is not yet supported (and isn't recommended anyway)")
    }
}

Tests for checkSamplingType():

testthat::test_that("checkSamplingType works", {
  testthat::expect_null(checkSamplingType("SS"))
  testthat::expect_error(checkSamplingType("MB"),
                         "sampling_type MB is not yet supported (and isn't recommended anyway)",
                         fixed=TRUE)
  testthat::expect_error(checkSamplingType(c("SS", "SS")),
                         "length(sampling_type) == 1 is not TRUE", fixed=TRUE)
  testthat::expect_error(checkSamplingType(1),
                         "is.character(sampling_type) is not TRUE", fixed=TRUE)
  testthat::expect_error(checkSamplingType(as.character(NA)),
                         "!is.na(sampling_type) is not TRUE", fixed=TRUE)
})
## Test passed 🥇

checkB():

#' Helper function to confirm that the argument B to several functions is as
#' expected
#'
#' @param B Integer or numeric; the number of subsamples. Note: For
#' `sampling_type = "MB"` the total number of subsamples will be `B`; for
#' `sampling_type = "SS"` the number of subsamples will be `2*B`.
#' @author Gregory Faletto, Jacob Bien
checkB <- function(B){
    stopifnot(length(B) == 1)
    stopifnot(is.numeric(B) | is.integer(B))
    stopifnot(!is.na(B))
    stopifnot(B == round(B))
    stopifnot(B > 0)
    if(B < 10){
        warning("Small values of B may lead to poor results.")
    } else if (B > 2000){
        warning("Large values of B may require long computation times.")
    }
}
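
To make the relationship between B and the number of subsamples concrete, here is a hedged sketch that assumes the `2*B` subsamples for sampling_type "SS" are formed as B complementary (non-overlapping) pairs, each subsample of size `floor(n/2)`. This is an assumption for illustration only: createSubsamples() is not shown in this section, and drawComplementaryPairs is a hypothetical helper, not a package function.

```r
# Hypothetical helper, not part of the package: draws B complementary pairs
# of subsamples, each of size floor(n/2), for 2*B subsamples in total.
drawComplementaryPairs <- function(n, B){
    stopifnot(n >= 2, B >= 1)
    half <- floor(n / 2)
    subsamples <- vector("list", 2 * B)
    for(b in seq_len(B)){
        # First member of the pair: a random subset of the observations
        first <- sample.int(n, size = half)
        # Second member: drawn from the observations not used by the first.
        # Index into `remaining` rather than calling sample() on it, to avoid
        # sample()'s special handling of length-one inputs.
        remaining <- setdiff(seq_len(n), first)
        second <- remaining[sample.int(length(remaining), size = half)]
        subsamples[[2 * b - 1]] <- first
        subsamples[[2 * b]] <- second
    }
    subsamples
}

set.seed(42)
subs <- drawComplementaryPairs(n = 15, B = 3)
length(subs)  # 2 * B subsamples
```

Each element is an integer vector of distinct indices with length `floor(n/2)`, which is the shape that cssLoop() checks for in its `subsample` input.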

Tests for checkB():

testthat::test_that("checkB works", {
  testthat::expect_null(checkB(1500))
  testthat::expect_null(checkB(15))
  testthat::expect_error(checkB("B"),
                         "is.numeric(B) | is.integer(B) is not TRUE",
                         fixed=TRUE)
  testthat::expect_error(checkB(20:25), "length(B) == 1 is not TRUE",
                         fixed=TRUE)
  testthat::expect_error(checkB(as.integer(NA)), "!is.na(B) is not TRUE",
                         fixed=TRUE)
  testthat::expect_error(checkB(1.2), "B == round(B) is not TRUE",
                         fixed=TRUE)
  testthat::expect_error(checkB(-100), "B > 0 is not TRUE",
                         fixed=TRUE)
  testthat::expect_warning(checkB(5),
                           "Small values of B may lead to poor results.",
                           fixed=TRUE)
  testthat::expect_warning(checkB(2200),
                           "Large values of B may require long computation times.",
                           fixed=TRUE)
})
## Test passed 🎉

checkPropFeatsRemove():

#' Helper function to confirm that the argument prop_feats_remove to several 
#' functions is as expected
#'
#' @param prop_feats_remove Numeric; proportion of features that are dropped on
#' each subsample. Must be at least 0 and strictly less than 1.
#' @param p The number of features; must be at least 2 if prop_feats_remove
#' is greater than 0.
#' @author Gregory Faletto, Jacob Bien
checkPropFeatsRemove <- function(prop_feats_remove, p){
    stopifnot(length(prop_feats_remove) == 1)
    stopifnot(is.numeric(prop_feats_remove) | is.integer(prop_feats_remove))
    stopifnot(!is.na(prop_feats_remove))
    stopifnot(prop_feats_remove >= 0 & prop_feats_remove < 1)
    if(prop_feats_remove > 0){
        # Make sure p is at least 2 or else this doesn't make sense
        stopifnot(p >= 2)
    }
}

Tests for checkPropFeatsRemove():

testthat::test_that("checkPropFeatsRemove works", {
  testthat::expect_null(checkPropFeatsRemove(0, 5))
  testthat::expect_null(checkPropFeatsRemove(.3, 10))
  testthat::expect_error(checkPropFeatsRemove(1, 3),
                         "prop_feats_remove >= 0 & prop_feats_remove < 1 is not TRUE",
                         fixed=TRUE)
  testthat::expect_error(checkPropFeatsRemove(c(.5, .6), 17),
                         "length(prop_feats_remove) == 1 is not TRUE",
                         fixed=TRUE)
  testthat::expect_error(checkPropFeatsRemove(".3", 99),
                         "is.numeric(prop_feats_remove) | is.integer(prop_feats_remove) is not TRUE",
                         fixed=TRUE)
  testthat::expect_error(checkPropFeatsRemove(as.numeric(NA), 172),
                         "!is.na(prop_feats_remove) is not TRUE",
                         fixed=TRUE)
  testthat::expect_error(checkPropFeatsRemove(.1, 1),
                         "p >= 2 is not TRUE",
                         fixed=TRUE)
})
## Test passed 🎉

Finally, tests for checkCssInputs():

testthat::test_that("checkCssInputs works", {
  set.seed(80526)
  
  x <- matrix(stats::rnorm(15*11), nrow=15, ncol=11)
  y <- stats::rnorm(15)
  
  # Intentionally don't provide clusters for all features, mix up formatting,
  # etc.
  good_clusters <- list(red_cluster=1L:5L,
                        green_cluster=6L:8L
                        # , c4=10:11
                        )
  
  res <- checkCssInputs(X=x, y=y, lambda=0.01, clusters=good_clusters,
                        fitfun = cssLasso, sampling_type = "SS", B = 13,
                        prop_feats_remove = 0, train_inds = integer(),
                        num_cores = 1L)
  
  # Basic output
  testthat::expect_true(is.list(res))
  testthat::expect_identical(names(res), c("feat_names", "X", "clusters"))
  
  # feat_names
  testthat::expect_true(is.character(res$feat_names))
  testthat::expect_true(is.na(res$feat_names))
  testthat::expect_equal(length(res$feat_names), 1)

  # X
  testthat::expect_true(is.matrix(res$X))
  testthat::expect_true(all(!is.na(res$X)))
  testthat::expect_true(is.numeric(res$X))
  testthat::expect_equal(ncol(res$X), 11)
  testthat::expect_equal(nrow(res$X), 15)

  # clusters
  testthat::expect_true(is.list(res$clusters))
  testthat::expect_equal(length(res$clusters), length(names(res$clusters)))
  testthat::expect_equal(length(res$clusters),
                         length(unique(names(res$clusters))))
  testthat::expect_true(all(!is.na(names(res$clusters))))
  testthat::expect_true(all(!is.null(names(res$clusters))))
  clust_feats <- integer()
  for(i in 1:length(res$clusters)){
    clust_feats <- c(clust_feats, res$clusters[[i]])
  }
  testthat::expect_equal(length(clust_feats), length(unique(clust_feats)))
  testthat::expect_equal(length(clust_feats), length(intersect(clust_feats,
                                                               1:11)))

  ## Trying other inputs

  # Custom fitfun with nonsense lambda (which will be ignored by fitfun, and
  # shouldn't throw any error, because the acceptable input for lambda should be
  # enforced only by fitfun)

  testFitfun <- function(X, y, lambda){
    p <- ncol(X)
    stopifnot(p >= 2)
    # Choose p/2 features randomly
    selected <- sample.int(p, size=floor(p/2))
    return(selected)
  }

  res_fitfun <- checkCssInputs(X=x, y=y, lambda=x, clusters=1:3,
                               fitfun = testFitfun, sampling_type = "SS",
                               B = 13, prop_feats_remove = 0,
                               train_inds = integer(), num_cores = 1L)
  testthat::expect_true(is.list(res_fitfun))

  # Single cluster
  res_sing_clust <- checkCssInputs(X=x, y=y,
                                   lambda=c("foo", as.character(NA), "bar"),
                                   clusters=1:3, fitfun = testFitfun,
                                   sampling_type = "SS", B = 13,
                                   prop_feats_remove = 0,
                                   train_inds = integer(), num_cores = 1L)
  testthat::expect_true(is.list(res_sing_clust))
  testthat::expect_equal(length(res_sing_clust$clusters), 11 - 3 + 1)
  testthat::expect_true(length(unique(names(res_sing_clust$clusters))) == 11 -
                          3 + 1)
  testthat::expect_true(all(!is.na(names(res_sing_clust$clusters))))
  testthat::expect_true(all(!is.null(names(res_sing_clust$clusters))))

  # Other sampling types
  testthat::expect_error(checkCssInputs(X=x, y=y, lambda=c("foo",
                                                           as.character(NA),
                                                           "bar"), clusters=1:3,
                                        fitfun = testFitfun,
                                        sampling_type = "MB", B = 13,
                                        prop_feats_remove = 0,
                                        train_inds = integer(), num_cores = 1L),
                         "sampling_type MB is not yet supported (and isn't recommended anyway)",
                         fixed=TRUE)

  # No message is matched here because the error message has quotation marks
  # in it
  testthat::expect_error(checkCssInputs(X=x, y=y, lambda=c("foo",
                                                           as.character(NA),
                                                           "bar"), clusters=1:3,
                                        fitfun = testFitfun,
                                        sampling_type = "S", B = 13,
                                        prop_feats_remove = 0,
                                        train_inds = integer(), num_cores = 1L))
  
  testthat::expect_error(checkCssInputs(X=x, y=y, lambda=c("foo", "bar",
                                                           as.character(NA)),
                                        clusters=1:3, fitfun = testFitfun,
                                        sampling_type = 2, B = 13,
                                        prop_feats_remove = 0,
                                        train_inds = integer(), num_cores = 1L),
                         "is.character(sampling_type) is not TRUE",
                         fixed=TRUE)

  # B
  testthat::expect_warning(checkCssInputs(X=x, y=y, lambda=c("foo", "bar",
                                                           as.character(NA)),
                                        clusters=1:3, fitfun = testFitfun,
                                        sampling_type = "SS", B = 5,
                                        prop_feats_remove = 0,
                                        train_inds = integer(), num_cores = 1L),
                           "Small values of B may lead to poor results.",
                           fixed=TRUE)

  testthat::expect_error(checkCssInputs(X=x, y=y, lambda=c("foo", "bar",
                                                           as.character(NA)),
                                        clusters=1:3, fitfun = testFitfun,
                                        sampling_type = "SS", B = "foo",
                                        prop_feats_remove = 0,
                                        train_inds = integer(), num_cores = 1L),
                           "is.numeric(B) | is.integer(B) is not TRUE",
                           fixed=TRUE)

  # prop_feats_remove
  testthat::expect_true(is.list(checkCssInputs(X=x, y=y,
                                               lambda=c("foo", "bar",
                                                        as.character(NA)),
                                               clusters=1:3, fitfun=testFitfun,
                                               sampling_type = "SS", B = 12,
                                               prop_feats_remove = 0.3,
                                               train_inds = integer(),
                                               num_cores = 1L)))

  # Use train_inds argument
  testthat::expect_true(is.list(checkCssInputs(X=x, y=y,
                                               lambda=c("foo", "bar",
                                                        as.character(NA)),
                                               clusters=1:3, fitfun=testFitfun,
                                               sampling_type = "SS", B = 12,
                                               prop_feats_remove = 0.3,
                                               train_inds = 11:15,
                                               num_cores = 1L)))

})
## Test passed 🥳

cssLoop():

#' Helper function run on each subsample
#' 
#' Runs the provided feature selection method `fitfun` on a single subsample for
#' cluster stability selection (this function is called within `mclapply`, once
#' per subsample).
#' @param input Could be one of two things: \item{subsample}{An integer vector
#' of size `floor(n/2)` containing the indices of the observations in the
#' subsample.} \item{drop_var_input}{A named list containing two elements: one
#' named "subsample" and the same as the previous description, and a logical
#' vector of length `p` named "feats_to_keep" indicating which features are
#' made available to `fitfun` on this subsample.} (The first object is the
#' output of the function `createSubsamples()` when the provided
#' `prop_feats_remove` is 0, the default, and the second object is the output
#' of `createSubsamples()` when `prop_feats_remove > 0`.)
#' @param x an n x p numeric matrix containing the predictors. (This should be
#' the full design matrix provided to css.)
#' @param y A response; can be any response that takes the form of a length n
#' vector and is used (or not used) by `fitfun`. Typically (and for default
#' `fitfun = cssLasso`), `y` should be an n-dimensional numeric vector containing the
#' response. This should be the full response provided to css.
#' @param lambda A tuning parameter or set of tuning parameters that may be used
#' by the feature selection method. For example, in the default case when
#' `fitfun = cssLasso`, `lambda` is a numeric: the penalty to use for each lasso
#' fit.
#' @param fitfun A function that takes in arguments X, y, and lambda and returns
#' a vector of indices of the columns of X (selected features).
#' @return An integer vector; the indices of the features selected by `fitfun`.
#' @author Gregory Faletto, Jacob Bien
cssLoop <- function(input, x, y, lambda, fitfun){
    # Check inputs
    stopifnot(is.matrix(x))
    stopifnot(all(!is.na(x)))

    colnames(x) <- character()
    n <- nrow(x)
    p <- ncol(x)

    stopifnot(length(y) == n)
    stopifnot(!is.matrix(y))
    # Intentionally don't check y or lambda further to allow for flexibility--these
    # inputs should be checked within fitfun.

    if(!is.list(input)){
        subsample <- input
        feats_to_keep <- rep(TRUE, p)
    } else{
        stopifnot(all(names(input) == c("subsample", "feats_to_keep")))
        subsample <- input$subsample
        feats_to_keep <- input$feats_to_keep
    }

    stopifnot(is.integer(subsample))
    stopifnot(all(subsample == round(subsample)))
    stopifnot(floor(n/2) == length(subsample))
    stopifnot(length(subsample) == length(unique(subsample)))

    stopifnot(is.logical(feats_to_keep))
    stopifnot(length(feats_to_keep) == p)

    selected <- do.call(fitfun, list(X=x[subsample, feats_to_keep],
        y=y[subsample], lambda=lambda))

    selected <- which(feats_to_keep)[selected]

    # Check output
    checkCssLoopOutput(selected, p, as.integer(which(feats_to_keep)))

    return(as.integer(selected))
}
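
The line `selected <- which(feats_to_keep)[selected]` translates indices relative to the reduced matrix that `fitfun` saw back into column indices of the full design matrix. A small base-R illustration with made-up values (the names selected_reduced and selected_full are introduced only for this example):

```r
# Five features in total; features 2 and 5 were dropped on this subsample
feats_to_keep <- c(TRUE, FALSE, TRUE, TRUE, FALSE)

# Suppose the selection method, which only saw the 3 kept columns, returned
# the 1st and 3rd of those columns
selected_reduced <- c(1L, 3L)

# which(feats_to_keep) lists the full-matrix indices of the kept columns
# (here 1, 3, 4); indexing it by selected_reduced maps back to the full matrix
selected_full <- which(feats_to_keep)[selected_reduced]
```

Without this step, a selection of the third kept column would be recorded as feature 3 rather than feature 4, so selection proportions would be attributed to the wrong features whenever `prop_feats_remove > 0`.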

checkCssLoopOutput():

#' Helper function to confirm that the outputs of the provided feature selection
#' method are as required. 
#'
#' @param selected An integer vector; the indices of the features selected by
#' the feature selection method `fitfun`.
#' @param p The total number of observed features; all selected features must be
#' in 1:p.
#' @param feats_on_subsamp Integer; the indices of the features considered by
#' the feature selection method. All selected features must be among these
#' features.
#' @author Gregory Faletto, Jacob Bien
checkCssLoopOutput <- function(selected, p, feats_on_subsamp){
    if(!exists("selected")){
        stop("The provided feature selection method fitfun failed to return anything on (at least) one subsample")
    }
    if(!is.integer(selected) & !is.numeric(selected)){
        stop("The provided feature selection method fitfun failed to return an integer or numeric vector on (at least) one subsample")
    }
    if(any(is.na(selected))){
        stop("The provided feature selection method fitfun returned a vector containing NA values on (at least) one subsample")
    }
    if(!all(selected == round(selected))){
        stop("The provided feature selection method fitfun failed to return a vector of valid (integer) indices on (at least) one subsample")
    }
    if(length(selected) != length(unique(selected))){
        stop("The provided feature selection method fitfun returned a vector of selected features containing repeated indices on (at least) one subsample")
    }
    if(length(selected) > p){
        stop("The provided feature selection method fitfun returned a vector of selected features longer than p on (at least) one subsample")
    }
    if(length(selected) > 0){
        if(max(selected) > p){
            stop("The provided feature selection method fitfun returned a vector of selected features containing an index greater than ncol(X) on (at least) one subsample")
        }
        if(min(selected) <= 0){
            stop("The provided feature selection method fitfun returned a vector of selected features containing a non-positive index on (at least) one subsample")
        }
    }
    if("try-error" %in% class(selected) |
        "error" %in% class(selected) | "simpleError" %in% class(selected) |
        "condition" %in% class(selected)){
        stop("The provided feature selection method fitfun returned an error on (at least) one subsample")
    }
    if(!all(selected %in% feats_on_subsamp)){
        stop("The provided feature selection method somehow selected features that were not provided for it to consider.")
    }
}
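
For readers writing their own `fitfun`, the following is a hypothetical example of a selection method whose output would pass the checks above: it accepts exactly `X`, `y`, and `lambda`, interprets `lambda` as a model size, and returns distinct integer indices in `1:ncol(X)`. The name corScreenFitfun and the marginal-correlation screening rule are illustrative assumptions, not part of the package.

```r
# Hypothetical fitfun: keep the `lambda` features with the largest absolute
# marginal correlation with the response.
corScreenFitfun <- function(X, y, lambda){
    stopifnot(is.matrix(X), is.numeric(y), length(y) == nrow(X))
    stopifnot(is.numeric(lambda), length(lambda) == 1, lambda >= 1,
        lambda <= ncol(X), lambda == round(lambda))
    # Absolute marginal correlation of each column of X with y
    abs_cors <- abs(stats::cor(X, y))[, 1]
    # Positions of the lambda largest correlations
    selected <- order(abs_cors, decreasing = TRUE)[seq_len(lambda)]
    # Return sorted integer indices with no repeats
    return(sort(as.integer(selected)))
}

set.seed(7)
X <- matrix(stats::rnorm(30 * 6), nrow = 30, ncol = 6)
y <- X[, 2] - X[, 5] + stats::rnorm(30, sd = 0.1)
sel <- corScreenFitfun(X, y, lambda = 2)
```

Because the indices returned are relative to the columns of whatever `X` the function receives, the same function works unchanged when cssLoop() passes it only a subset of the features.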

Tests for checkCssLoopOutput():

testthat::test_that("checkCssLoopOutput works", {
  testthat::expect_null(checkCssLoopOutput(selected=1:5, p=6,
                                           feats_on_subsamp=1:6))
  
  testthat::expect_error(checkCssLoopOutput(selected=1:5, p=4,
                                            feats_on_subsamp=1:6),
                         "The provided feature selection method fitfun returned a vector of selected features longer than p on (at least) one subsample",
                         fixed=TRUE)
  
  testthat::expect_error(checkCssLoopOutput(selected=1:5, p=7,
                                            feats_on_subsamp=1:4),
                         "The provided feature selection method somehow selected features that were not provided for it to consider.",
                         fixed=TRUE)
  
  testthat::expect_error(checkCssLoopOutput(selected=c(1, 2, 3, 4.4, 5), p=7,
                                            feats_on_subsamp=1:7),
                         "The provided feature selection method fitfun failed to return a vector of valid (integer) indices on (at least) one subsample",
                         fixed=TRUE)
  
  testthat::expect_error(checkCssLoopOutput(selected=rep(1, 3), p=7,
                                            feats_on_subsamp=1:7),
                         "The provided feature selection method fitfun returned a vector of selected features containing repeated indices on (at least) one subsample",
                         fixed=TRUE)
  
  testthat::expect_error(checkCssLoopOutput(selected=c(-1, 5), p=7,
                                            feats_on_subsamp=1:7),
                         "The provided feature selection method fitfun returned a vector of selected features containing a non-positive index on (at least) one subsample",
                         fixed=TRUE)
  
  testthat::expect_error(checkCssLoopOutput(selected=c(0, 5), p=7,
                                            feats_on_subsamp=1:7),
                         "The provided feature selection method fitfun returned a vector of selected features containing a non-positive index on (at least) one subsample",
                         fixed=TRUE)
  
  testthat::expect_error(checkCssLoopOutput(selected=as.integer(NA), p=7,
                                            feats_on_subsamp=1:7),
                         "The provided feature selection method fitfun returned a vector containing NA values on (at least) one subsample",
                         fixed=TRUE)
  
  testthat::expect_error(checkCssLoopOutput(selected=c("1", "2", "3"), p=7,
                                            feats_on_subsamp=1:7),
                         "The provided feature selection method fitfun failed to return an integer or numeric vector on (at least) one subsample",
                         fixed=TRUE)

})
## Test passed 🌈
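The checks exercised above amount to a contract for any user-supplied `fitfun`: it must return an integer (or whole-number numeric) vector with no `NA`s, no repeats, and every entry in `1:p`. As a minimal sketch of a function that satisfies this contract, the hypothetical `corScreenFitfun()` below (not part of the package) keeps the `lambda` features with the largest absolute marginal correlation with `y`, treating `lambda` as a model size.

```r
# Hypothetical custom selection method (an illustration, not package code):
# keep the `lambda` features most correlated (in absolute value) with y.
corScreenFitfun <- function(X, y, lambda){
    p <- ncol(X)
    stopifnot(is.numeric(lambda), length(lambda) == 1, lambda >= 1, lambda <= p)
    # Absolute marginal correlation of each column of X with y
    abs_cors <- abs(as.vector(stats::cor(X, y)))
    # order() returns integer positions, so the result is unique and within 1:p
    selected <- order(abs_cors, decreasing=TRUE)[seq_len(lambda)]
    return(sort(selected))
}

set.seed(1)
X <- matrix(stats::rnorm(20*6), nrow=20, ncol=6)
y <- X[, 2] + stats::rnorm(20, sd=0.1)
selected <- corScreenFitfun(X, y, lambda=3)
```

Because the function takes exactly the arguments `X`, `y`, and `lambda`, it has the signature that `css()` expects of a `fitfun`.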

checkCssLassoInputs():

#' Helper function to confirm that the inputs to `cssLasso()` are as expected. 
#'
#' @param X A design matrix containing the predictors. (Note that we don't need
#' to check X very much, because X will have already been checked by the
#' function `checkCssInputs()` when it was provided to `css()`.)
#' @param y A numeric vector containing the response.
#' @param lambda Numeric; a nonnegative number for the lasso penalty to use
#' on each subsample. (For now, only one lambda value can be provided to
#' `cssLasso()`; in the future, we plan to allow for multiple lambda values to be
#' provided to `cssLasso()`, as described in Faletto and Bien 2022.)
#' @author Gregory Faletto, Jacob Bien
checkCssLassoInputs <- function(X, y, lambda){

    n <- nrow(X)
    p <- ncol(X)

    if(!is.numeric(y)){
        stop("For method cssLasso, y must be a numeric vector.")
    }
    if(is.matrix(y)){
        stop("For method cssLasso, y must be a numeric vector (inputted y was a matrix).")
    }
    if(n != length(y)){
        stop("For method cssLasso, y must be a vector of length equal to nrow(X).")
    }
    if(length(unique(y)) <= 1){
        stop("Subsample with only one unique value of y detected--for method cssLasso, all subsamples of y of size floor(n/2) must have more than one unique value.")
    }
    if(!is.numeric(lambda)){
        stop("For method cssLasso, lambda must be a numeric.")
    }
    if(any(is.na(lambda))){
        stop("NA detected in provided lambda input to cssLasso")
    }
    if(length(lambda) != 1){
        stop("For method cssLasso, lambda must be a numeric of length 1.")
    }
    if(lambda < 0){
        stop("For method cssLasso, lambda must be nonnegative.")
    }
}

Tests for checkCssLassoInputs():

testthat::test_that("checkCssLassoInputs works", {
  set.seed(761)
  
  x <- matrix(stats::rnorm(15*4), nrow=15, ncol=4)
  y <- stats::rnorm(15)
  
  testthat::expect_null(checkCssLassoInputs(X=x, y=y, lambda=0.01))
  
  testthat::expect_error(checkCssLassoInputs(X=x, y=logical(15), lambda=0.05),
                         "For method cssLasso, y must be a numeric vector.",
                         fixed=TRUE)
  
  testthat::expect_error(checkCssLassoInputs(X=x[1:13, ], y=y, lambda=0.01),
                         "For method cssLasso, y must be a vector of length equal to nrow(X).",
                         fixed=TRUE)
  
  testthat::expect_error(checkCssLassoInputs(X=x, y=rep(1.2, 15), lambda=0.05),
                         "Subsample with only one unique value of y detected--for method cssLasso, all subsamples of y of size floor(n/2) must have more than one unique value.",
                         fixed=TRUE)
  
  testthat::expect_error(checkCssLassoInputs(X=x, y=y, lambda=TRUE),
                         "For method cssLasso, lambda must be a numeric.",
                         fixed=TRUE)
  
  testthat::expect_error(checkCssLassoInputs(X=x, y=y, lambda=as.numeric(NA)),
                         "NA detected in provided lambda input to cssLasso",
                         fixed=TRUE)
  
  testthat::expect_error(checkCssLassoInputs(X=x, y=y, lambda=-0.01),
                         "For method cssLasso, lambda must be nonnegative.",
                         fixed=TRUE)

  testthat::expect_error(checkCssLassoInputs(X=x, y=y, lambda=x),
                         "For method cssLasso, lambda must be a numeric of length 1.",
                         fixed=TRUE)
  
  testthat::expect_error(checkCssLassoInputs(X=x, y=y, lambda=numeric()),
                         "For method cssLasso, lambda must be a numeric of length 1.",
                         fixed=TRUE)
  
})
## Test passed 🌈
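Several of the tests above depend on how base R classifies values, which is why `checkCssLassoInputs()` needs separate type, `NA`, and length checks on `lambda`. The short sketch below illustrates this using only base R; the name `type_checks` is introduced here purely for illustration.

```r
# How base R classifies some candidate lambda values
type_checks <- c(
    double_ok    = is.numeric(0.05),           # doubles are numeric
    integer_ok   = is.numeric(1L),             # integers are numeric as well
    logical_ok   = is.numeric(TRUE),           # logicals are not numeric
    bare_na_ok   = is.numeric(NA),             # a bare NA is logical
    numeric_na_ok = is.numeric(as.numeric(NA)) # a numeric NA passes the type check
)

# An empty numeric vector passes the type check but has length 0
empty_length <- length(numeric())
```

A numeric `NA` passes the type check, so it is caught only by the explicit `NA` check; likewise, `numeric()` and a matrix-valued `lambda` are caught only by the length check.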

Tests for cssLoop() (these appear here because checkCssLassoInputs() had to be defined first):

testthat::test_that("cssLoop works", {
  set.seed(89134)

  x <- matrix(stats::rnorm(9*8), nrow=9, ncol=8)
  y <- stats::rnorm(9)

  output <- cssLoop(input=1L:4L, x=x, y=y, lambda=0.05,
                    fitfun=cssLasso)

  testthat::expect_true(is.integer(output))

  testthat::expect_equal(length(output), length(unique(output)))

  testthat::expect_true(length(output) <= 8)

  testthat::expect_true(all(output >= 1))

  testthat::expect_true(all(output <= 8))

  testthat::expect_error(cssLoop(input=1L:6L, x=x, y=y, lambda=0.05,
                                 fitfun=cssLasso),
                         "floor(n/2) == length(subsample) is not TRUE",
                         fixed=TRUE)

  testthat::expect_error(cssLoop(input=1L:4L, x=x, y=y[1:8],
                                 lambda=0.05, fitfun=cssLasso),
                         "length(y) == n is not TRUE",
                         fixed=TRUE)

  testthat::expect_error(cssLoop(input=1L:4L, x=x, y=logical(9),
                                 lambda=0.05, fitfun=cssLasso),
                         "For method cssLasso, y must be a numeric vector.",
                         fixed=TRUE)

  testthat::expect_error(cssLoop(input=1L:4L, x=x, y=y,
                                 lambda=x, fitfun=cssLasso),
                         "For method cssLasso, lambda must be a numeric of length 1.",
                         fixed=TRUE)

  # Test other input format

  alt_input <- list("subsample"=2:5, "feats_to_keep"=c(FALSE, rep(TRUE, 4),
                                                       rep(FALSE, 2), TRUE))

  output2 <- cssLoop(input=alt_input, x=x, y=y, lambda=0.08, fitfun=cssLasso)

  testthat::expect_true(is.integer(output2))

  testthat::expect_equal(length(output2), length(unique(output2)))

  testthat::expect_true(length(output2) <= 8)

  testthat::expect_true(all(output2 %in% c(2, 3, 4, 5, 8)))

  testthat::expect_error(cssLoop(input= list("subsample"=2:5,
                                             "feats_to_keep"=c(FALSE,
                                                               rep(TRUE, 4),
                                                               rep(FALSE, 2))),
                                 x=x, y=y, lambda=0.08, fitfun=cssLasso),
                         "length(feats_to_keep) == p is not TRUE",
                         fixed=TRUE)

  # Custom fitfun with nonsense lambda (which will be ignored by fitfun, and
  # shouldn't throw any error, because the acceptable input for lambda should be
  # enforced only by fitfun) and nonsense y

  testFitfun <- function(X, y, lambda){
    p <- ncol(X)
    stopifnot(p >= 2)
    # Choose floor(p/2) features uniformly at random
    selected <- sample.int(p, size=floor(p/2))
    return(selected)
  }

  testthat::expect_true(is.integer(cssLoop(input=1L:4L, x=x, y=y,
                                           lambda=TRUE, fitfun=testFitfun)))

  testthat::expect_true(is.integer(cssLoop(input=1L:4L, x=x,
                                           y=character(9), lambda=.05,
                                           fitfun=testFitfun)))

})
## Test passed 🎊
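The second input format tested above passes a logical mask `feats_to_keep` alongside the subsample, and the test expects the selected features to lie in `c(2, 3, 4, 5, 8)`. The body of `cssLoop()` is not shown in this part of the document, so the following is only a sketch of the index bookkeeping such a mask implies, assuming the selection method is run on the reduced matrix `X[, feats_to_keep]` and its output is mapped back to the original column indices. The vector `selected_reduced` is a made-up selection used for illustration.

```r
# Logical mask over the 8 original features, as in the test above
feats_to_keep <- c(FALSE, rep(TRUE, 4), rep(FALSE, 2), TRUE)

# Original column indices of the retained features
kept <- which(feats_to_keep)

# Hypothetical fitfun output: positions within the reduced matrix
selected_reduced <- c(1L, 5L)

# Map positions in the reduced matrix back to original column indices
selected_original <- kept[selected_reduced]
```

Here `kept` is `2 3 4 5 8`, so positions 1 and 5 in the reduced matrix correspond to original features 2 and 8.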

checkGetClusterSelMatrixInput():

#' Helper function to check inputs to getClusterSelMatrix function
#'
#' @param clusters A named list where each entry is an integer vector of indices
#' of features that are in a common cluster, as in the output of formatClusters.
#' (The length of list clusters is equal to the number of clusters.) All
#' identified clusters must be non-overlapping, and all features must appear in
#' exactly one cluster (any unclustered features should be in their own
#' "cluster" of size 1).
#' @param res A binary integer matrix. res[i, j] = 1 if feature j was selected
#' on subsample i and equals 0 otherwise, as in the output of getSelMatrix.
#' (That is, each row is a selected set.)
#' @return `NULL` if all checks pass; otherwise, an error is thrown describing
#' the first problem found.
#' @author Gregory Faletto, Jacob Bien
checkGetClusterSelMatrixInput <- function(clusters, res){
    stopifnot(is.matrix(res))
    stopifnot(all(res %in% c(0, 1)))
    p <- ncol(res)
    stopifnot(nrow(res) > 0)
    checkClusters(clusters, p)
}

Tests for checkGetClusterSelMatrixInput():

testthat::test_that("checkGetClusterSelMatrixInput works", {
  
  good_clusters <- list(happy=1L:8L, sad=9L:10L, med=11L)
  
  res <- matrix(sample(c(0, 1), size=6*11, replace=TRUE), nrow=6, ncol=11)
  
  testthat::expect_null(checkGetClusterSelMatrixInput(good_clusters, res))
  
  testthat::expect_error(checkGetClusterSelMatrixInput(list(happy=1L:8L,
                                                            med=11L), res),
                         "length(all_clustered_feats) == p is not TRUE",
                         fixed=TRUE)
  
  testthat::expect_error(checkGetClusterSelMatrixInput(good_clusters, 1:9),
                         "is.matrix(res) is not TRUE", fixed=TRUE)
  
  testthat::expect_error(checkGetClusterSelMatrixInput(good_clusters, res + .3),
                         "all(res %in% c(0, 1)) is not TRUE", fixed=TRUE)
  
  testthat::expect_error(checkGetClusterSelMatrixInput(good_clusters,
                                                       res[, 1:9]),
                         "length(all_clustered_feats) == p is not TRUE",
                         fixed=TRUE)
  
  testthat::expect_error(checkGetClusterSelMatrixInput(1L:10L, res),
                         "is.list(clusters) is not TRUE", fixed=TRUE)
  
  testthat::expect_error(checkGetClusterSelMatrixInput(list(c1=1L:5L,
                                                            c2=6L:8L,
                                                            c3=9L,
                                                            c4=integer()), res),
                         "all(lengths(clusters) >= 1) is not TRUE",
                         fixed=TRUE)
                         
  testthat::expect_error(checkGetClusterSelMatrixInput(list(c1=1L:5L,
                                            c2=6L:8L, c3=9L,
                                            c4=as.integer(NA)), res),
                         "all(!is.na(clusters)) is not TRUE",
                         fixed=TRUE)
  
  testthat::expect_error(checkGetClusterSelMatrixInput(list(c1=1L:5L,
                                            c2=6L:8L, c3=9L,
                                            c2=6L:8L), res),
                         "n_clusters == length(unique(clusters)) is not TRUE",
                         fixed=TRUE)
  
  testthat::expect_error(checkGetClusterSelMatrixInput(list(c1=1L:5L,
                                            c2=6L:8L, c3=14L), res),
                         "length(all_clustered_feats) == p is not TRUE",
                         fixed=TRUE)
  
})
## Test passed 🎉
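To make the two validated inputs concrete, here is a small sketch of how a feature-level selection matrix `res` and a list of `clusters` could be combined into a cluster-level selection matrix. It assumes a cluster counts as selected on a subsample when at least one of its features was selected; the actual `getClusterSelMatrix()` is not shown in this part of the document and may differ in details. The names `cluster_res` and the toy inputs are illustrative only.

```r
# Toy inputs: 6 features in 3 non-overlapping clusters, 4 subsamples
clusters <- list(happy=1L:3L, sad=4L:5L, med=6L)
res <- rbind(c(1, 0, 0, 0, 0, 0),
             c(0, 0, 0, 1, 1, 0),
             c(0, 0, 0, 0, 0, 0),
             c(0, 1, 0, 0, 0, 1))

# cluster_res[i, k] = 1 if any feature in cluster k was selected on subsample i
# (drop=FALSE keeps single-feature clusters as one-column matrices)
cluster_res <- sapply(clusters, function(feats){
    as.integer(rowSums(res[, feats, drop=FALSE]) > 0)
})
```

The result has one row per subsample and one column per cluster, with column names taken from the names of `clusters`. This also shows why the checks require every feature to appear in exactly one cluster: an overlapping or missing feature would make the cluster-level columns ambiguous or incomplete.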