5 Selection and prediction functions

Now that css() has been defined and tested, we write functions that work with its output. For most end users, the output of these functions will be of more direct interest than the output of css() itself. These functions are kept separate from css() because the most computationally intensive steps happen within css(): css() needs to be called only once on a data set, after which the functions that follow can be run relatively quickly (for example, to try different parameter values).
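To illustrate why this split pays off, here is a minimal base-R sketch. It assumes only that css() stores a B × (number of clusters) 0/1 matrix of cluster selections in clus_sel_mat, which is how the code later in this section reads it. The matrix below is simulated as a stand-in; once the selection proportions are computed, each additional cutoff costs only a vector comparison.

```r
set.seed(2022)
B <- 20
# Simulated stand-in for css_results$clus_sel_mat: B subsamples x 6 clusters,
# where 1 means the cluster was selected on that subsample
clus_sel_mat <- matrix(
  stats::rbinom(B * 6, size = 1, prob = rep(c(0.9, 0.5, 0.1), each = 2 * B)),
  nrow = B, ncol = 6, dimnames = list(NULL, paste0("c", 1:6))
)

# The expensive part (the subsampling itself) is already done; this is cheap
clus_sel_props <- colMeans(clus_sel_mat)

# Exploring several cutoffs reuses the same proportions
cutoffs <- c(0, 0.25, 0.5, 0.75, 1)
n_selected <- sapply(cutoffs, function(ct) sum(clus_sel_props >= ct))
n_selected
```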

  • getCssSelections() takes in the results of css() along with user-defined parameters on how to select clusters (a minimum or maximum number of clusters to select, along with a cutoff for cluster selection proportions) and selects clusters as well as features from those clusters.
  • getCssDesign() takes in the same inputs as getCssSelections() along with an unlabeled test matrix of features X (containing the same features as the X matrix provided originally to css()). It uses the results from css() to select clusters like getCssSelections(), then it uses a user-selected weighting scheme to compute weighted averages of the cluster members. It returns a test matrix of cluster representatives, which can be used for downstream predictive tasks.
  • Finally, getCssPreds() takes the same inputs as getCssDesign(), except that it also accepts a set of labeled training data (where the response must be real-valued). getCssPreds() selects clusters, forms matrices of cluster representatives on the training and test data, uses the training matrix of cluster representatives (along with the vector of training responses) to estimate a linear model via ordinary least squares, and finally generates predictions on the test data using this linear model.
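The last two steps can be sketched in base R as follows. This is not the implementation of getCssDesign() or getCssPreds(); the selected clusters, their weights, the helper name clusterReps(), and the simulated data are all hypothetical, chosen only to show how cluster representatives are formed as weighted averages of cluster members and then used in an ordinary least squares fit.

```r
set.seed(1)
n <- 20
X_train <- matrix(stats::rnorm(n * 4), nrow = n, ncol = 4)
y_train <- X_train[, 1] + X_train[, 3] + stats::rnorm(n, sd = 0.1)
X_test <- matrix(stats::rnorm(5 * 4), nrow = 5, ncol = 4)

# Hypothetical selected clusters (feature indices) and within-cluster weights
sel_clusts <- list(c1 = c(1L, 2L), c2 = 3L)
weights <- list(c1 = c(0.5, 0.5), c2 = 1)

# Hypothetical helper: one column per cluster, each a weighted average of the
# columns of X belonging to that cluster
clusterReps <- function(X, clusts, wts) {
  reps <- matrix(NA_real_, nrow = nrow(X), ncol = length(clusts),
                 dimnames = list(NULL, names(clusts)))
  for (j in seq_along(clusts)) {
    reps[, j] <- X[, clusts[[j]], drop = FALSE] %*% wts[[j]]
  }
  reps
}

train_reps <- clusterReps(X_train, sel_clusts, weights)
test_reps <- clusterReps(X_test, sel_clusts, weights)

# OLS on the training representatives, then predictions on the test data
fit <- stats::lm(y ~ ., data = data.frame(y = y_train, train_reps))
preds <- stats::predict(fit, newdata = data.frame(test_reps))
```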

As in the previous section, we first define each function and then define the helper functions called by that function. Tests are written for each function as soon as all of its dependencies have been defined.

getCssSelections():

#' Obtain a selected set of clusters and features
#'
#' Generate sets of selected clusters and features from cluster stability
#' selection.
#' @param css_results An object of class "cssr" (the output of the function
#' css).
#' @param weighting Character; determines how to calculate the weights for
#' individual features within the selected clusters. Only those features with
#' nonzero weight within the selected clusters will be returned. Must be one of
#' "sparse", "weighted_avg", or "simple_avg". For "sparse", all the weight is
#' put on the most frequently selected individual cluster member (or divided
#' equally among all the cluster members that are tied for the top selection
#' proportion if there is a tie). For "weighted_avg", only the features within
#' a selected cluster that were themselves selected on at least one subsample
#' will have nonzero weight. For "simple_avg", each cluster member gets equal
#' weight regardless of the individual feature selection proportions (that is,
#' all cluster members within each selected cluster will be returned). See
#' Faletto and Bien (2022) for details. Default is "sparse".
#' @param cutoff Numeric; getCssSelections will select and return only those
#' clusters with selection proportions equal to at least cutoff. Must be between
#' 0 and 1. Default is 0 (in which case either all clusters are selected, or
#' max_num_clusts are selected, if max_num_clusts is specified).
#' @param min_num_clusts Integer or numeric; the minimum number of clusters to
#' use regardless of cutoff. (That is, if the chosen cutoff returns fewer than
#' min_num_clusts clusters, the cutoff will be decreased until at least
#' min_num_clusts clusters are selected.) Default is 1.
#' @param max_num_clusts Integer or numeric; the maximum number of clusters to
#' use regardless of cutoff. (That is, if the chosen cutoff returns more than
#' max_num_clusts clusters, the cutoff will be increased until at most
#' max_num_clusts clusters are selected.) Default is NA (in which case
#' max_num_clusts is ignored).
#' @return A named list with two items. \item{selected_clusts}{A named list of
#' integer vectors; each vector contains the indices of the features in one of
#' the selected clusters.} \item{selected_feats}{A named integer vector; the
#' indices of the features with nonzero weights from all of the selected
#' clusters.}
#' @author Gregory Faletto, Jacob Bien
#' @references 
<<faletto2022>>
#' @export
getCssSelections <- function(css_results, weighting="sparse", cutoff=0,
    min_num_clusts=1, max_num_clusts=NA){
    # Check inputs
    stopifnot(class(css_results) == "cssr")
    checkCutoff(cutoff)
    checkWeighting(weighting)

    p <- ncol(css_results$feat_sel_mat)

    checkMinNumClusts(min_num_clusts, p, length(css_results$clusters))

    max_num_clusts <- checkMaxNumClusts(max_num_clusts, min_num_clusts, p,
        length(css_results$clusters))

    sel_results <- getSelectedClusters(css_results, weighting, cutoff,
        min_num_clusts, max_num_clusts)

    # sel_results$selected_clusts is guaranteed to have length at least 1 by
    # getSelectedClusters
    sel_clust_names <- names(sel_results$selected_clusts)

    stopifnot(length(sel_clust_names) >= 1)
    stopifnot(all(sel_clust_names %in% names(css_results$clusters)))

    sel_clusts <- list()
    for(i in 1:length(sel_clust_names)){
        sel_clusts[[i]] <- css_results$clusters[[sel_clust_names[i]]]
        names(sel_clusts)[i] <- sel_clust_names[i]
    }

    stopifnot(is.list(sel_clusts))
    stopifnot(length(sel_clusts) == length(sel_clust_names))

    # sel_results$selected_feats is guaranteed to have length at least as long
    # as sel_results$selected_clusts by getSelectedClusters
    return(list(selected_clusts=sel_clusts,
        selected_feats=sel_results$selected_feats))
}
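The loop in getCssSelections() that builds sel_clusts looks up each selected cluster by name. As a design note, base R's single-bracket subsetting of a named list by a character vector yields the same named list in one step. The toy clusters list below is hypothetical and only illustrates the equivalence.

```r
# Hypothetical clusters, named as css() names them
clusters <- list(a = 1:2, b = 3:4, c = 5L)
sel_clust_names <- c("c", "a")

# Loop version, as in getCssSelections()
sel_loop <- list()
for (i in seq_along(sel_clust_names)) {
  sel_loop[[i]] <- clusters[[sel_clust_names[i]]]
  names(sel_loop)[i] <- sel_clust_names[i]
}

# Name-based subsetting produces the same named list, in the same order
sel_vec <- clusters[sel_clust_names]
identical(sel_loop, sel_vec)  # TRUE
```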

checkCutoff():

#' Helper function to confirm that the argument cutoff to several functions is
#' as expected
#'
#' @param cutoff Numeric; only those clusters with selection proportions equal
#' to at least cutoff will be selected by cluster stability selection. Must be
#' between 0 and 1.
#' @author Gregory Faletto, Jacob Bien
checkCutoff <- function(cutoff){
    stopifnot(is.numeric(cutoff) | is.integer(cutoff))
    stopifnot(length(cutoff) == 1)
    stopifnot(!is.na(cutoff))
    stopifnot(cutoff >= 0)
    stopifnot(cutoff <= 1)
}

Tests for checkCutoff():

testthat::test_that("checkCutoff works", {
  testthat::expect_null(checkCutoff(0))
  testthat::expect_null(checkCutoff(0.2))
  testthat::expect_null(checkCutoff(1))
  
  testthat::expect_error(checkCutoff(-.2), "cutoff >= 0 is not TRUE", fixed=TRUE)
  testthat::expect_error(checkCutoff(2), "cutoff <= 1 is not TRUE", fixed=TRUE)
  testthat::expect_error(checkCutoff(".3"),
                        "is.numeric(cutoff) | is.integer(cutoff) is not TRUE",
                        fixed=TRUE)
  testthat::expect_error(checkCutoff(matrix(1:12, nrow=4, ncol=3)),
                         "length(cutoff) == 1 is not TRUE", fixed=TRUE)
  testthat::expect_error(checkCutoff(numeric()),
                         "length(cutoff) == 1 is not TRUE", fixed=TRUE)
  testthat::expect_error(checkCutoff(as.numeric(NA)),
                         "!is.na(cutoff) is not TRUE", fixed=TRUE)

})
## Test passed 😀
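The expected error strings in these tests (and in the tests for the other check helpers) come from base R's stopifnot(), which reports a failed condition as the deparsed expression followed by "is not TRUE". The hypothetical wrapper below captures that message with tryCatch() and conditionMessage(), a quick way to see the exact string a test should match.

```r
# Hypothetical wrapper: run one stopifnot() check and return its error message
cutoffMessage <- function(cutoff) {
  tryCatch({
    stopifnot(cutoff <= 1)
    "ok"
  }, error = function(e) conditionMessage(e))
}

cutoffMessage(0.5)  # "ok"
cutoffMessage(2)    # "cutoff <= 1 is not TRUE"
```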

checkWeighting():

#' Helper function to confirm that the argument weighting to several 
#' functions is as expected
#'
#' @param weighting Character; determines how to calculate the weights to
#' combine features from the selected clusters into weighted averages, called
#' cluster representatives. Must be one of "sparse", "weighted_avg", or
#' "simple_avg".
#' @author Gregory Faletto, Jacob Bien
checkWeighting <- function(weighting){
    stopifnot(length(weighting)==1)
    stopifnot(!is.na(weighting))
    if(!is.character(weighting)){
        stop("Weighting must be a character")
    }
    if(!(weighting %in% c("sparse", "simple_avg", "weighted_avg"))){
        stop("Weighting must be a character and one of sparse, simple_avg, or weighted_avg")
    }
}

Tests for checkWeighting():

testthat::test_that("checkWeighting works", {
  testthat::expect_null(checkWeighting("sparse"))
  testthat::expect_null(checkWeighting("simple_avg"))
  testthat::expect_null(checkWeighting("weighted_avg"))
  
  testthat::expect_error(checkWeighting(c("sparse", "simple_avg")),
                         "length(weighting) == 1 is not TRUE", fixed=TRUE)
  testthat::expect_error(checkWeighting(NA), "!is.na(weighting) is not TRUE",
                         fixed=TRUE)
  testthat::expect_error(checkWeighting(1), "Weighting must be a character",
                         fixed=TRUE)
  testthat::expect_error(checkWeighting("spasre"),
                         "Weighting must be a character and one of sparse, simple_avg, or weighted_avg",
                         fixed=TRUE)
})
## Test passed 🌈

checkMinNumClusts():

#' Helper function to confirm that the argument min_num_clusts to several 
#' functions is as expected
#'
#' @param min_num_clusts Integer or numeric; the minimum number of clusters to
#' use regardless of cutoff. (That is, if the chosen cutoff returns fewer than
#' min_num_clusts clusters, the cutoff will be decreased until at least
#' min_num_clusts clusters are selected.)
#' @param p The number of features; since this is an upper bound on the number
#' of clusters of features, it is also an upper bound on min_num_clusts.
#' @param n_clusters The number of clusters; note that this is an upper bound
#' on min_num_clusts.
#' @author Gregory Faletto, Jacob Bien
checkMinNumClusts <- function(min_num_clusts, p, n_clusters){
    stopifnot(length(min_num_clusts) == 1)
    stopifnot(is.numeric(min_num_clusts) | is.integer(min_num_clusts))
    stopifnot(!is.na(min_num_clusts))
    stopifnot(min_num_clusts == round(min_num_clusts))
    stopifnot(min_num_clusts >= 1)
    stopifnot(min_num_clusts <= p)
    stopifnot(min_num_clusts <= n_clusters)
}

Tests for checkMinNumClusts():

testthat::test_that("checkMinNumClusts works", {
  testthat::expect_null(checkMinNumClusts(1, 5, 4))
  testthat::expect_null(checkMinNumClusts(6, 6, 6))
  testthat::expect_null(checkMinNumClusts(3, 1932, 3))
  
  testthat::expect_error(checkMinNumClusts(c(2, 4), 5, 4),
                         "length(min_num_clusts) == 1 is not TRUE", fixed=TRUE)
  testthat::expect_error(checkMinNumClusts("3", "1932", "3"),
                         "is.numeric(min_num_clusts) | is.integer(min_num_clusts) is not TRUE",
                         fixed=TRUE)
  testthat::expect_error(checkMinNumClusts(NA, NA, NA),
                         "is.numeric(min_num_clusts) | is.integer(min_num_clusts) is not TRUE",
                         fixed=TRUE)
  testthat::expect_error(checkMinNumClusts(as.numeric(NA), as.numeric(NA),
                                           as.numeric(NA)),
                         "!is.na(min_num_clusts) is not TRUE", fixed=TRUE)
  testthat::expect_error(checkMinNumClusts(0, 13, 7),
                         "min_num_clusts >= 1 is not TRUE", fixed=TRUE)
  testthat::expect_error(checkMinNumClusts(-1, 9, 8),
                         "min_num_clusts >= 1 is not TRUE", fixed=TRUE)
  testthat::expect_error(checkMinNumClusts(6, 5, 5),
                         "min_num_clusts <= p is not TRUE", fixed=TRUE)
  testthat::expect_error(checkMinNumClusts(6, 7, 5),
                         "min_num_clusts <= n_clusters is not TRUE", fixed=TRUE)
})
## Test passed 🥇

checkMaxNumClusts():

#' Helper function to confirm that the argument max_num_clusts to several 
#' functions is as expected
#'
#' @param max_num_clusts Integer or numeric; the maximum number of clusters to
#' use regardless of cutoff. (That is, if the chosen cutoff returns more than
#' max_num_clusts clusters, the cutoff will be increased until at most
#' max_num_clusts clusters are selected.) Can be NA, in which case
#' max_num_clusts will be ignored.
#' @param min_num_clusts Integer or numeric; the minimum number of clusters to
#' use regardless of cutoff. (That is, if the chosen cutoff returns fewer than
#' min_num_clusts clusters, the cutoff will be decreased until at least
#' min_num_clusts clusters are selected.) max_num_clusts must be at least as
#' large as min_num_clusts.
#' @param p The number of features; since this is an upper bound on the number
#' of clusters of features, it is also an upper bound on max_num_clusts.
#' @param n_clusters The number of clusters; if max_num_clusts exceeds
#' n_clusters, it is reduced to n_clusters.
#' @return The provided max_num_clusts, coerced to an integer if needed, and
#' coerced to be less than or equal to the total number of clusters.
#' @author Gregory Faletto, Jacob Bien
checkMaxNumClusts <- function(max_num_clusts, min_num_clusts, p, n_clusters){
    stopifnot(length(max_num_clusts) == 1)
    if(!is.na(max_num_clusts)){
        stopifnot(is.numeric(max_num_clusts) | is.integer(max_num_clusts))
        stopifnot(max_num_clusts == round(max_num_clusts))
        stopifnot(max_num_clusts >= 1)
        stopifnot(max_num_clusts <= p)
        max_num_clusts <- as.integer(min(n_clusters, max_num_clusts))
        stopifnot(max_num_clusts >= min_num_clusts)
    }
    return(max_num_clusts)
}

Tests for checkMaxNumClusts():

testthat::test_that("checkMaxNumClusts works", {
  testthat::expect_equal(checkMaxNumClusts(max_num_clusts=4, min_num_clusts=1,
                                           p=5, n_clusters=4), 4)
  testthat::expect_equal(checkMaxNumClusts(max_num_clusts=5, min_num_clusts=1,
                                           p=5, n_clusters=4), 4)
  testthat::expect_true(is.na(checkMaxNumClusts(max_num_clusts=NA,
                                                min_num_clusts=3, p=5,
                                                n_clusters=4)))
  
  testthat::expect_error(checkMaxNumClusts(max_num_clusts="4", min_num_clusts=1,
                                           p=5, n_clusters=4),
                         "is.numeric(max_num_clusts) | is.integer(max_num_clusts) is not TRUE",
                         fixed=TRUE)
  testthat::expect_error(checkMaxNumClusts(max_num_clusts=3.2, min_num_clusts=2,
                                           p=5, n_clusters=4),
                         "max_num_clusts == round(max_num_clusts) is not TRUE",
                         fixed=TRUE)
  testthat::expect_error(checkMaxNumClusts(max_num_clusts=1, min_num_clusts=2,
                                           p=5, n_clusters=4),
                         "max_num_clusts >= min_num_clusts is not TRUE",
                         fixed=TRUE)
  
  testthat::expect_error(checkMaxNumClusts(max_num_clusts=c(3, 4),
                                           min_num_clusts=2,
                                           p=5, n_clusters=4),
                         "length(max_num_clusts) == 1 is not TRUE",
                         fixed=TRUE)
  
  testthat::expect_error(checkMaxNumClusts(max_num_clusts="4",
                                           min_num_clusts="2",
                                           p="5", n_clusters="4"),
                         "is.numeric(max_num_clusts) | is.integer(max_num_clusts) is not TRUE",
                         fixed=TRUE)
  
  testthat::expect_error(checkMaxNumClusts(max_num_clusts=-1, min_num_clusts=2,
                                           p=5, n_clusters=4),
                         "max_num_clusts >= 1 is not TRUE", fixed=TRUE)
  
  testthat::expect_error(checkMaxNumClusts(max_num_clusts=6, min_num_clusts=2,
                                           p=5, n_clusters=4),
                         "max_num_clusts <= p is not TRUE", fixed=TRUE)
  
  testthat::expect_error(checkMaxNumClusts(max_num_clusts=1, min_num_clusts=2,
                                           p=5, n_clusters=4),
                         "max_num_clusts >= min_num_clusts is not TRUE",
                         fixed=TRUE)
})
## Test passed 😀

getSelectedClusters():

#' From css output, obtain names of selected clusters and selection proportions,
#' indices of all selected features, and weights of individual cluster members
#'
#' If cutoff is too high for at least min_num_clusts clusters to be selected,
#' then it will be lowered until min_num_clusts can be selected. After that, if
#' the cutoff is too low such that more than max_num_clusts are selected, then
#' the cutoff will be increased until no more than max_num_clusts are selected.
#' Note that because clusters can have tied selection proportions, the number of
#' selected clusters may not land exactly on min_num_clusts or max_num_clusts:
#' lowering the cutoff past a tie can select strictly more than min_num_clusts
#' clusters, and raising it past a tie can leave strictly fewer than
#' max_num_clusts. In fact, it may be impossible to satisfy both constraints
#' simultaneously, even if max_num_clusts is strictly greater than
#' min_num_clusts. If this occurs, max_num_clusts will take precedence over
#' min_num_clusts (unless satisfying max_num_clusts would leave no clusters
#' selected, in which case more than max_num_clusts clusters are returned).
#' getSelectedClusters will throw an error if the provided inputs don't allow
#' it to select any clusters.
#' 
#' @param css_results An object of class "cssr" (the output of the function
#' css).
#' @param weighting Character; determines how to calculate the weights for
#' individual features within the selected clusters. Only those features with
#' nonzero weight within the selected clusters will be returned. Must be one of
#' "sparse", "weighted_avg", or "simple_avg". For "sparse", all the weight is
#' put on the most frequently selected individual cluster member (or divided
#' equally among all the cluster members that are tied for the top selection
#' proportion if there is a tie). For "weighted_avg", only the features within
#' a selected cluster that were themselves selected on at least one subsample
#' will have nonzero weight. For "simple_avg", each cluster member gets equal
#' weight regardless of the individual feature selection proportions (that is,
#' all cluster members within each selected cluster will be returned). See
#' Faletto and Bien (2022) for details.
#' @param cutoff Numeric; getSelectedClusters will select and return only those
#' clusters with selection proportions equal to at least cutoff. Must be between
#' 0 and 1.
#' @param min_num_clusts Integer or numeric; the minimum number of clusters to
#' use regardless of cutoff. (That is, if the chosen cutoff returns fewer than
#' min_num_clusts clusters, the cutoff will be decreased until at least
#' min_num_clusts clusters are selected.)
#' @param max_num_clusts Integer or numeric; the maximum number of clusters to
#' use regardless of cutoff. (That is, if the chosen cutoff returns more than
#' max_num_clusts clusters, the cutoff will be increased until at most
#' max_num_clusts clusters are selected.) If NA, max_num_clusts is ignored.
#' @return A named list with the following elements: \item{selected_clusts}{A
#' named numeric vector containing the selection proportions for the selected
#' clusters. The name of each entry is the name of the corresponding cluster.}
#' \item{selected_feats}{A named integer vector; the indices of the features
#' with nonzero weights from all of the selected clusters.} \item{weights}{A
#' named list of the same length as the number of selected clusters. Each list
#' element weights[[j]] is a numeric vector of the weights to use for the jth
#' selected cluster, and it has the same name as the cluster it corresponds
#' to.}
#' @author Gregory Faletto, Jacob Bien
getSelectedClusters <- function(css_results, weighting, cutoff, min_num_clusts,
    max_num_clusts){
    # Check input
    stopifnot(class(css_results) == "cssr")

    # Compute the selection proportion of each cluster
    clus_sel_props <- colMeans(css_results$clus_sel_mat)

    # Select clusters with selection proportions at or above cutoff
    selected_clusts <- clus_sel_props[clus_sel_props >= cutoff]
    B <- nrow(css_results$feat_sel_mat)

    # If fewer than min_num_clusts clusters are selected, lower the cutoff in
    # steps of 1/B until at least min_num_clusts are selected
    while(length(selected_clusts) < min_num_clusts){
        cutoff <- cutoff - 1/B
        selected_clusts <- clus_sel_props[clus_sel_props >= cutoff]
    }

    # If more than max_num_clusts clusters are selected, raise the cutoff in
    # steps of 1/B until at most max_num_clusts are selected
    if(!is.na(max_num_clusts)){
        while(length(selected_clusts) > max_num_clusts){
            cutoff <- cutoff + 1/B
            if(cutoff > 1){
                break
            }
            # Make sure we don't reduce to a selected set of size 0
            if(any(clus_sel_props >= cutoff)){
                selected_clusts <- clus_sel_props[clus_sel_props >= cutoff]
            } else{
                break
            }
        }
    }

    stopifnot(length(selected_clusts) >= 1)

    clust_names <- names(selected_clusts)

    n_sel_clusts <- length(selected_clusts)

    # Check that n_sel_clusts is as expected, and throw warnings or an error if
    # not
    checkSelectedClusters(n_sel_clusts, min_num_clusts, max_num_clusts,
        max(clus_sel_props))
    
    ### Get selected features from selected clusters
    clusters <- css_results$clusters
    stopifnot(all(clust_names %in% names(clusters)))

    # Get a list of weights for all of the selected clusters
    weights <- getAllClustWeights(css_results, selected_clusts, weighting)

    # Get selected features from each cluster (those features with nonzero
    # weights)
    selected_feats <- integer()
    for(i in 1:n_sel_clusts){
        clus_i_name <- clust_names[i]
        clust_i <- clusters[[clus_i_name]]
        weights_i <- weights[[i]]
        selected_feats <- c(selected_feats, clust_i[weights_i != 0])
    }

    feat_names <- colnames(css_results$feat_sel_mat)

    names(selected_feats) <- feat_names[selected_feats]

    # Check output (already checked weights within getAllClustWeights)

    checkGetSelectedClustersOutput(selected_clusts, selected_feats,
        weights, n_clusters=length(clusters), p=ncol(css_results$feat_sel_mat))

    return(list(selected_clusts=selected_clusts,
        selected_feats=selected_feats, weights=weights))
}
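A stripped-down sketch of the two cutoff-adjustment loops is below, using a hypothetical function selectByCutoff() that drops the input checks, the warnings from checkSelectedClusters(), and the feature weights. The selection proportions are made up so that clusters b and c are tied; the example shows how a tie can push the count past min_num_clusts, and how max_num_clusts then takes precedence, leaving fewer clusters than min_num_clusts.

```r
# Hypothetical, simplified version of the cutoff logic in getSelectedClusters():
# props are cluster selection proportions, which are multiples of 1/B
selectByCutoff <- function(props, cutoff, min_num_clusts, max_num_clusts, B) {
  sel <- props[props >= cutoff]
  # Lower the cutoff one step (1/B) at a time until at least min_num_clusts
  # clusters are selected
  while (length(sel) < min_num_clusts) {
    cutoff <- cutoff - 1 / B
    sel <- props[props >= cutoff]
  }
  # Raise the cutoff until at most max_num_clusts clusters are selected,
  # stopping early rather than selecting no clusters at all
  if (!is.na(max_num_clusts)) {
    while (length(sel) > max_num_clusts) {
      cutoff <- cutoff + 1 / B
      if (cutoff > 1 || !any(props >= cutoff)) break
      sel <- props[props >= cutoff]
    }
  }
  sel
}

B <- 10
props <- c(a = 0.9, b = 0.6, c = 0.6, d = 0.2)  # b and c are tied

names(selectByCutoff(props, 0.95, min_num_clusts = 1, max_num_clusts = NA, B = B))
# "a"
names(selectByCutoff(props, 0.95, min_num_clusts = 2, max_num_clusts = NA, B = B))
# "a" "b" "c": the tie pushes the count past min_num_clusts
names(selectByCutoff(props, 0.95, min_num_clusts = 2, max_num_clusts = 2, B = B))
# "a": max_num_clusts wins, so fewer than min_num_clusts are returned
```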

checkSelectedClusters():

#' Helper function to check operations within getSelectedClusters function
#'
#' @param n_sel_clusts The number of selected clusters; should be constrained
#' by min_num_clusts and max_num_clusts (though it may not be possible to
#' satisfy both constraints simultaneously, in which case a warning will be
#' thrown).
#' @param min_num_clusts Integer or numeric; the minimum number of clusters to
#' use regardless of cutoff. (That is, if the chosen cutoff returns fewer than
#' min_num_clusts clusters, the cutoff will be decreased until at least
#' min_num_clusts clusters are selected.)
#' @param max_num_clusts Integer or numeric; the maximum number of clusters to
#' use regardless of cutoff. (That is, if the chosen cutoff returns more than
#' max_num_clusts clusters, the cutoff will be increased until at most
#' max_num_clusts clusters are selected.) If NA, max_num_clusts is ignored.
#' @param max_sel_prop Numeric; the maximum selection proportion observed for 
#' any cluster.
#' @author Gregory Faletto, Jacob Bien
checkSelectedClusters <- function(n_sel_clusts, min_num_clusts, max_num_clusts,
    max_sel_prop){
    if(n_sel_clusts == 0){
        err <- paste("No clusters selected with this cutoff (try a cutoff below the maximum cluster selection proportion, ",
            max_sel_prop, ")", sep="")
        stop(err)
    }

    stopifnot(n_sel_clusts >= 1)

    # It may be impossible to get at least min_num_clusts or at most
    # max_num_clusts; if so, give a warning
    if(n_sel_clusts < min_num_clusts){
        warn <- paste("Returning fewer than min_num_clusts = ", min_num_clusts,
            " clusters because decreasing the cutoff any further would require returning more than max_num_clusts = ",
            max_num_clusts, " clusters", sep="")
        warning(warn)
    }
    if(!is.na(max_num_clusts)){
        if(n_sel_clusts > max_num_clusts){
            warn <- paste("Returning more than max_num_clusts = ",
                max_num_clusts,
                " clusters because increasing the cutoff any further would require returning 0 clusters",
                sep="")
            warning(warn)
        }
    }
}

Tests for checkSelectedClusters():

testthat::test_that("checkSelectedClusters works", {
  testthat::expect_null(checkSelectedClusters(n_sel_clusts=5, min_num_clusts=1,
                                              max_num_clusts=NA, max_sel_prop=.8))
  testthat::expect_null(checkSelectedClusters(n_sel_clusts=5, min_num_clusts=2,
                                              max_num_clusts=5, max_sel_prop=.3))
  testthat::expect_null(checkSelectedClusters(n_sel_clusts=2, min_num_clusts=2,
                                              max_num_clusts=5, max_sel_prop=.3))
  

  testthat::expect_error(checkSelectedClusters(n_sel_clusts=0, min_num_clusts=2,
                                               max_num_clusts=5,
                                               max_sel_prop=.6),
                         "No clusters selected with this cutoff (try a cutoff below the maximum cluster selection proportion, 0.6)",
                         fixed=TRUE)
  
  testthat::expect_warning(checkSelectedClusters(n_sel_clusts=1,
                                                 min_num_clusts=2,
                                                 max_num_clusts=5,
                                                 max_sel_prop=.6),
                         "Returning fewer than min_num_clusts = 2 clusters because decreasing the cutoff any further would require returning more than max_num_clusts = 5 clusters",
                         fixed=TRUE)
  testthat::expect_warning(checkSelectedClusters(n_sel_clusts=6,
                                                 min_num_clusts=2,
                                                 max_num_clusts=5,
                                                 max_sel_prop=.6),
                         "Returning more than max_num_clusts = 5 clusters because increasing the cutoff any further would require returning 0 clusters",
                         fixed=TRUE)
  
})
## Test passed 🥇

getAllClustWeights():

#' Calculate weights for each cluster member of all of the selected clusters.
#' 
#' @param css_results An object of class "cssr" (the output of the function
#' css).
#' @param sel_clusters A named numeric vector containing the selection
#' proportions for the selected clusters. The name of each entry is the name
#' of the corresponding cluster.
#' @param weighting Character; determines how to calculate the weights for
#' individual features within the selected clusters. Only those features with
#' nonzero weight within the selected clusters will be returned. Must be one of
#' "sparse", "weighted_avg", or "simple_avg". For "sparse", all the weight is
#' put on the most frequently selected individual cluster member (or divided
#' equally among all the cluster members that are tied for the top selection
#' proportion if there is a tie). For "weighted_avg", only the features within
#' a selected cluster that were themselves selected on at least one subsample
#' will have nonzero weight. For "simple_avg", each cluster member gets equal
#' weight regardless of the individual feature selection proportions (that is,
#' all cluster members within each selected cluster will be returned). See
#' Faletto and Bien (2022) for details.
#' @return A named list of numeric vectors with the same length as
#' sel_clusters. weights[[j]] contains the weights to use for the jth selected
#' cluster, and it has the same name as the cluster it corresponds to.
#' @author Gregory Faletto, Jacob Bien
getAllClustWeights <- function(css_results, sel_clusters, weighting){

    # Check inputs
    stopifnot(class(css_results) == "cssr")

    stopifnot(is.numeric(sel_clusters))
    p_ret <- length(sel_clusters)
    stopifnot(length(unique(names(sel_clusters))) == p_ret)
    stopifnot(p_ret > 0)

    checkWeighting(weighting)

    # Get selection proportions and clusters
    feat_sel_props <- colMeans(css_results$feat_sel_mat)

    p <- length(feat_sel_props)
    stopifnot(p >= p_ret)

    clusters <- css_results$clusters
    stopifnot(all(names(sel_clusters) %in% names(clusters)))

    # Identify weights
    weights <- list()

    for(j in 1:p_ret){
        # Get the members of the jth selected cluster
        cluster_j <- clusters[[names(sel_clusters)[j]]]
        # Get the weights for this cluster and add them to the list
        weights[[j]] <- getClustWeights(cluster_j, weighting, feat_sel_props)
    }

    # Add names to weights
    names(weights) <- names(sel_clusters)

    # Check output

    stopifnot(length(weights) == p_ret)
    stopifnot(is.list(weights))

    for(i in 1:p_ret){
        stopifnot(length(clusters[[names(sel_clusters)[i]]]) ==
            length(weights[[i]]))
        stopifnot(all(weights[[i]] >= 0))
        stopifnot(all(weights[[i]] <= 1))
        stopifnot(abs(sum(weights[[i]]) - 1) < 10^(-6))
    }
    return(weights)
}
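The output checks in getAllClustWeights() (and in getClustWeights()) compare the sum of each weight vector to 1 within a tolerance of 10^(-6) rather than with ==. The reason is floating-point rounding: weights such as sel_props/sum(sel_props) need not sum to exactly 1 in double precision. The standard illustration:

```r
x <- 0.1 + 0.2
exact <- x == 0.3                    # FALSE because of rounding
tolerant <- abs(x - 0.3) < 10^(-6)   # TRUE, same style as the checks above
base_r <- isTRUE(all.equal(x, 0.3))  # TRUE, base R's tolerant comparison
c(exact = exact, tolerant = tolerant, base_r = base_r)
```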

getClustWeights():

#' Calculate weights for members of a cluster using selection proportions
#'
#' Given a cluster of features, the selection proportions for each cluster
#' member, and a specified weighting scheme, calculate the appropriate weights
#' for the cluster.
#' @param cluster_i An integer vector containing the indices of the members
#' of a cluster.
#' @param weighting Character; determines how to calculate the weights for
#' individual features within the selected clusters. Only those features with
#' nonzero weight within the selected clusters will be returned. Must be one of
#' "sparse", "weighted_avg", or "simple_avg". For "sparse", all the weight is
#' put on the most frequently selected individual cluster member (or divided
#' equally among all the cluster members that are tied for the top selection
#' proportion if there is a tie). For "weighted_avg", only the features within
#' a selected cluster that were themselves selected on at least one subsample
#' will have nonzero weight. For "simple_avg", each cluster member gets equal
#' weight regardless of the individual feature selection proportions (that is,
#' all cluster members within each selected cluster will be returned). See
#' Faletto and Bien (2022) for details.
#' @param feat_sel_props A numeric vector of selection proportions corresponding
#' to each of the p features.
#' @return A numeric vector of the same length as cluster_i containing the
#' weights corresponding to each of the features in cluster_i. The weights
#' will all be nonnegative and sum to 1.
#' @author Gregory Faletto, Jacob Bien
getClustWeights <- function(cluster_i, weighting, feat_sel_props){

    stopifnot(is.integer(cluster_i) | is.numeric(cluster_i))
    stopifnot(all(cluster_i == round(cluster_i)))
    n_weights <- length(cluster_i)
    stopifnot(length(unique(cluster_i)) == n_weights)

    p <- length(feat_sel_props)
    stopifnot(all(cluster_i %in% 1:p))

    # Get the selection proportions of each cluster member
    sel_props <- feat_sel_props[cluster_i]

    stopifnot(all(sel_props >= 0))
    stopifnot(all(sel_props <= 1))

    weights_i <- rep(as.numeric(NA), n_weights)

    # Weighted or simple average?
    if(weighting == "sparse"){
        # Sparse cluster stability selection: All features in cluster with
        # selection proportion equal to the max
        # for the cluster get equal weight; rest of cluster gets 0 weight
        if(sum(sel_props) == 0){
            weights_i <- rep(1/n_weights, n_weights)
        } else{
            maxes <- sel_props==max(sel_props)

            stopifnot(sum(maxes) > 0)
            stopifnot(sum(maxes) <= n_weights)

            weights_i <- rep(0, n_weights)
            weights_i[maxes] <- 1/sum(maxes)
        }
    } else if(weighting == "weighted_avg"){
        # Get weights for weighted average
        if(sum(sel_props) == 0){
            weights_i <- rep(1/n_weights, n_weights)
        } else{
            weights_i <- sel_props/sum(sel_props)
        }
    } else if(weighting == "simple_avg"){
        weights_i <- rep(1/n_weights, n_weights)
    } else{
        stop("weighting must be one of sparse, simple_avg, or weighted_avg")
    }

    stopifnot(abs(sum(weights_i) - 1) < 10^(-6))
    stopifnot(length(weights_i) == n_weights)
    stopifnot(length(weights_i) >= 1)
    stopifnot(all(weights_i >= 0))
    stopifnot(all(weights_i <= 1))

    return(weights_i)
}
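To make the three weighting schemes concrete, here is a small worked example that applies each rule by hand to one hypothetical cluster whose members have selection proportions 0.2, 0.6, and 0.6. It is a standalone sketch that mirrors the logic of getClustWeights() rather than calling it; the tie between the last two members shows how "sparse" splits the weight when several members share the top selection proportion.

```r
# Selection proportions of the members of one hypothetical cluster
sel_props <- c(0.2, 0.6, 0.6)

# "sparse": equal weight on the members tied for the largest proportion,
# zero weight on everyone else
maxes <- sel_props == max(sel_props)
sparse_w <- ifelse(maxes, 1 / sum(maxes), 0)

# "weighted_avg": weights proportional to the selection proportions
weighted_w <- sel_props / sum(sel_props)

# "simple_avg": equal weight on every member, regardless of proportions
simple_w <- rep(1 / length(sel_props), length(sel_props))

sparse_w    # 0.0 0.5 0.5
weighted_w  # 1/7, 3/7, 3/7
simple_w    # 1/3, 1/3, 1/3
```

In all three cases the weights are nonnegative and sum to 1, which is exactly what the final stopifnot() checks in getClustWeights() enforce.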

Tests for getClustWeights():

testthat::test_that("getClustWeights works", {
  sel_props <- c(0.1, 0.3, 0.5, 0.7, 0.9)
  
  # sparse
  testthat::expect_identical(getClustWeights(cluster_i=c(3L, 4L, 5L),
                                             weighting="sparse",
                                             feat_sel_props=sel_props),
                             c(0, 0, 1))
  
  # weighted_avg
  cluster <- c(1L, 3L, 5L)
  true_weights <- sel_props[cluster]/sum(sel_props[cluster])
  
  testthat::expect_identical(getClustWeights(cluster_i=cluster,
                                             weighting="weighted_avg",
                                             feat_sel_props=sel_props),
                             true_weights)
  
  # simple_avg
  testthat::expect_identical(getClustWeights(cluster_i=c(2L, 3L, 4L, 5L),
                                             weighting="simple_avg",
                                             feat_sel_props=sel_props),
                             rep(0.25, 4))
})
## Test passed 🥳

Tests for getAllClustWeights():

testthat::test_that("getAllClustWeights works", {
  
  set.seed(1872)
  
  x <- matrix(stats::rnorm(10*5), nrow=10, ncol=5)
  y <- stats::rnorm(10)
  
  clust_names <- letters[1:3]
  
  good_clusters <- list(1:2, 3:4, 5)
  
  names(good_clusters) <- clust_names
  
  res <- css(X=x, y=y, lambda=0.01, clusters=good_clusters, fitfun = cssLasso,
    sampling_type = "SS", B = 10, prop_feats_remove = 0, train_inds = integer(),
    num_cores = 1L)
  
  sel_props <- colMeans(res$feat_sel_mat)
  
  sel_clusts <- list(1L:2L, 3L:4L)
  
  names(sel_clusts) <- clust_names[1:2]
  
  # sparse
  true_weights <- list()
  
  for(i in 1:2){
    weights_i <- sel_props[sel_clusts[[i]]]/sum(sel_props[sel_clusts[[i]]])
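    # Members tied for the maximum share the weight equally, as in
    # getClustWeights()
    maxes <- weights_i == max(weights_i)
    true_weights[[i]] <- rep(0, length(weights_i))
    true_weights[[i]][maxes] <- 1/sum(maxes)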
  }
  
  names(true_weights) <- clust_names[1:2]
  
  testthat::expect_identical(getAllClustWeights(res,
                                                colMeans(res$clus_sel_mat[, 1:2]),
                                                "sparse"), true_weights)

  # weighted_avg
  true_weights <- list()

  for(i in 1:2){
    true_weights[[i]] <- sel_props[sel_clusts[[i]]]/sum(sel_props[sel_clusts[[i]]])
  }
  
  names(true_weights) <- clust_names[1:2]

  testthat::expect_identical(getAllClustWeights(res,
                                                colMeans(res$clus_sel_mat[, 1:2]),
                                                "weighted_avg"), true_weights)

  # simple_avg
  true_weights <- list()

  for(i in 1:2){
    n_weights_i <- length(sel_clusts[[i]])
    true_weights[[i]] <- rep(1/n_weights_i, n_weights_i)
  }
  
  names(true_weights) <- clust_names[1:2]

  testthat::expect_identical(getAllClustWeights(res,
                                                colMeans(res$clus_sel_mat[, 1:2]),
                                                "simple_avg"), true_weights)

  # Errors

  # css_results not correct (the error message contains quotation marks, so
  # only check that an error is thrown)
  testthat::expect_error(getAllClustWeights(1:4,
                                            colMeans(res$clus_sel_mat[, 1:2]),
                                            "simple_avg"))

  bad_sel_clusts <- colMeans(res$clus_sel_mat[, 1:2])
  names(bad_sel_clusts) <- c("apple", "banana")
  testthat::expect_error(getAllClustWeights(res, bad_sel_clusts, "sparse"),
                         "all(names(sel_clusters) %in% names(clusters)) is not TRUE",
                         fixed=TRUE)


  testthat::expect_error(getAllClustWeights(res,
                                            colMeans(res$clus_sel_mat[, 1:2]),
                                            c("sparse", "simple_avg")),
                         "length(weighting) == 1 is not TRUE", fixed=TRUE)

  testthat::expect_error(getAllClustWeights(res,
                                            colMeans(res$clus_sel_mat[, 1:2]),
                                            NA),
                         "!is.na(weighting) is not TRUE", fixed=TRUE)

  testthat::expect_error(getAllClustWeights(res,
                                            colMeans(res$clus_sel_mat[, 1:2]),
                                            1),
                         "Weighting must be a character", fixed=TRUE)

  testthat::expect_error(getAllClustWeights(res,
                                            colMeans(res$clus_sel_mat[, 1:2]),
                                            "spasre"),
                         "Weighting must be a character and one of sparse, simple_avg, or weighted_avg",
                         fixed=TRUE)

})
## Test passed 🥇

checkGetSelectedClustersOutput():

#' Helper function to check that output of getSelectedClusters is as expected
#'
#' @param selected_clusts A named numeric vector containing the selection
#' proportions for the selected clusters. The name of each entry is the name of
#' the corresponding cluster.
#' @param selected_feats A named integer vector; the indices of the features
#' with nonzero weights from all of the selected clusters.
#' @param weights A named list of the same length as the number of selected
#' clusters. Each list element weights[[j]] is a numeric vector of the weights
#' to use for the jth selected cluster, and it has the same name as the cluster
#' it corresponds to.
#' @param n_clusters Integer; the number of clusters in the data (upper bound
#' for the length of selected_clusts)
#' @param p Integer; number of features in the data (all selected_feats should
#' be in 1:p)
#' @author Gregory Faletto, Jacob Bien
checkGetSelectedClustersOutput <- function(selected_clusts, selected_feats,
    weights, n_clusters, p){
    stopifnot(is.numeric(selected_clusts))
    stopifnot(all(selected_clusts >= 0))
    stopifnot(all(selected_clusts <= 1))
    stopifnot(length(selected_clusts) >= 1)
    stopifnot(length(selected_clusts) <= n_clusters)
    stopifnot(length(names(selected_clusts)) ==
        length(unique(names(selected_clusts))))
    stopifnot(!is.null(names(selected_clusts)))
    stopifnot(all(!is.na(names(selected_clusts)) &
        names(selected_clusts) != ""))
    stopifnot(length(names(selected_clusts)) == length(selected_clusts))
    stopifnot(is.integer(selected_feats))
    stopifnot(length(selected_feats) == length(unique(selected_feats)))
    stopifnot(all(selected_feats %in% 1:p))
    stopifnot(length(selected_clusts) <= length(selected_feats))
    stopifnot(identical(names(weights), names(selected_clusts)))
    stopifnot(length(weights) == length(selected_clusts)) 
}

Tests for checkGetSelectedClustersOutput():

testthat::test_that("checkGetSelectedClustersOutput works", {
  
  sel_clusts <- 0.1*(1:9)
  names(sel_clusts) <- letters[1:9]
  
  weights <- list()
  
  for(i in 1:8){
    weights[[i]] <- c(0.2, 0.3)
  }
  weights[[9]] <- 0.4
  names(weights) <- letters[1:9]
  
  sel_feats <- 10:26
  names(sel_feats) <- LETTERS[10:26]
  
  testthat::expect_null(checkGetSelectedClustersOutput(selected_clusts=sel_clusts,
                                                       selected_feats=sel_feats,
                                                       weights=weights,
                                                       n_clusters=10, p=30))
  
  testthat::expect_error(checkGetSelectedClustersOutput(selected_clusts=letters[1:4],
                                                       selected_feats=sel_feats,
                                                       weights=weights,
                                                       n_clusters=10, p=30),
                         "is.numeric(selected_clusts) is not TRUE",
                         fixed=TRUE)
  
  testthat::expect_error(checkGetSelectedClustersOutput(selected_clusts=-sel_clusts,
                                                       selected_feats=sel_feats,
                                                       weights=weights,
                                                       n_clusters=10, p=30),
                         "all(selected_clusts >= 0) is not TRUE",
                         fixed=TRUE)
  
  testthat::expect_error(checkGetSelectedClustersOutput(selected_clusts=10*sel_clusts,
                                                       selected_feats=sel_feats,
                                                       weights=weights,
                                                       n_clusters=10, p=30),
                         "all(selected_clusts <= 1) is not TRUE",
                         fixed=TRUE)
  
  testthat::expect_error(checkGetSelectedClustersOutput(selected_clusts=numeric(),
                                                       selected_feats=sel_feats,
                                                       weights=weights,
                                                       n_clusters=10, p=30),
                         "length(selected_clusts) >= 1 is not TRUE",
                         fixed=TRUE)
  
  testthat::expect_error(checkGetSelectedClustersOutput(selected_clusts=sel_clusts,
                               selected_feats=sel_feats,
                               weights=weights,
                               n_clusters=8, p=30),
                         "length(selected_clusts) <= n_clusters is not TRUE",
                         fixed=TRUE)
  
  bad_clusts <- sel_clusts
  names(bad_clusts) <- rep("a", length(bad_clusts))
  
  testthat::expect_error(checkGetSelectedClustersOutput(selected_clusts=bad_clusts,
                               selected_feats=sel_feats,
                               weights=weights,
                               n_clusters=10, p=30),
                         "length(names(selected_clusts)) == length(unique(names(selected_clusts))) is not TRUE",
                         fixed=TRUE)
  
  testthat::expect_error(checkGetSelectedClustersOutput(selected_clusts=unname(sel_clusts),
                                                       selected_feats=sel_feats,
                                                       weights=weights,
                                                       n_clusters=10, p=30),
                         "!is.null(names(selected_clusts)) is not TRUE",
                         fixed=TRUE)
  
  bad_clusts <- sel_clusts
  names(bad_clusts)[1] <- ""
  
  testthat::expect_error(checkGetSelectedClustersOutput(selected_clusts=bad_clusts,
                               selected_feats=sel_feats, weights=weights,
                               n_clusters=10, p=30),
                         "all(!is.na(names(selected_clusts)) & names(selected_clusts) !=  .... is not TRUE",
                         fixed=TRUE)
  
  names(bad_clusts)[1] <- as.character(NA)
  
  testthat::expect_error(checkGetSelectedClustersOutput(selected_clusts=bad_clusts,
                               selected_feats=sel_feats, weights=weights,
                               n_clusters=10, p=30),
                         "all(!is.na(names(selected_clusts)) & names(selected_clusts) !=  .... is not TRUE",
                         fixed=TRUE)
  
  testthat::expect_error(checkGetSelectedClustersOutput(selected_clusts=sel_clusts,
                                                       selected_feats=0.1,
                                                       weights=weights,
                                                       n_clusters=10, p=30),
                         "is.integer(selected_feats) is not TRUE",
                         fixed=TRUE)
  
  testthat::expect_error(checkGetSelectedClustersOutput(selected_clusts=sel_clusts,
                                                       selected_feats=c(1L,
                                                                        rep(2L,
                                                                            2)),
                                                       weights=weights,
                                                       n_clusters=10, p=30),
                         "length(selected_feats) == length(unique(selected_feats)) is not TRUE",
                         fixed=TRUE)
  
  testthat::expect_error(checkGetSelectedClustersOutput(selected_clusts=sel_clusts,
                               selected_feats=sel_feats, weights=weights,
                               n_clusters=10, p=25),
                         "all(selected_feats %in% 1:p) is not TRUE",
                         fixed=TRUE)
  
  testthat::expect_error(checkGetSelectedClustersOutput(selected_clusts=sel_clusts,
                               selected_feats=sel_feats[1:8], weights=weights,
                               n_clusters=10, p=25),
                         "length(selected_clusts) <= length(selected_feats) is not TRUE",
                         fixed=TRUE)
  
})
## Test passed 🥇
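The expect_error() calls above pass fixed=TRUE and match the literal text that stopifnot() generates: the deparsed condition followed by "is not TRUE". For a long condition, R keeps only the first line of the deparsed expression and appends " ....", which is why one expected message above contains "!=  ....". The standalone base-R snippet below captures both kinds of message; the vectors x and selected_clusts are toy stand-ins.

```r
# Capture the message stopifnot() raises, i.e. the text expect_error() matches
x <- c(0.2, -0.1)
msg_short <- tryCatch(stopifnot(all(x >= 0)), error = conditionMessage)
print(msg_short)
# [1] "all(x >= 0) is not TRUE"

# A long condition is deparsed over several lines; stopifnot() keeps only the
# first line and appends " ...." before "is not TRUE"
selected_clusts <- c(a = 0.1, 0.2)   # the second name is ""
msg_long <- tryCatch(
  stopifnot(all(!is.na(names(selected_clusts)) & names(selected_clusts) != "")),
  error = conditionMessage)
print(msg_long)
```

Matching with fixed=TRUE is convenient here because the messages contain regular-expression metacharacters such as parentheses and dots.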

Tests for getSelectedClusters()

testthat::test_that("getSelectedClusters works", {
  set.seed(26717)
  
  x <- matrix(stats::rnorm(10*5), nrow=10, ncol=5)
  y <- stats::rnorm(10)
  
  good_clusters <- list("apple"=1:2, "banana"=3:4, "cantaloupe"=5)
  
  css_res <- css(X=x, y=y, lambda=0.01, clusters=good_clusters, B = 10)

  res <- getSelectedClusters(css_res, weighting="sparse", cutoff=0.05,
                             min_num_clusts=1, max_num_clusts=NA)

  testthat::expect_true(is.list(res))
  testthat::expect_equal(length(res), 3)
  testthat::expect_identical(names(res), c("selected_clusts", "selected_feats",
                                           "weights"))
  testthat::expect_true(length(res$selected_clusts) <=
                          length(res$selected_feats))

  testthat::expect_true(is.numeric(res$selected_clusts))
  testthat::expect_true(length(res$selected_clusts) >= 1)
  testthat::expect_equal(length(names(res$selected_clusts)),
                         length(res$selected_clusts))
  testthat::expect_equal(length(names(res$selected_clusts)),
                         length(unique(names(res$selected_clusts))))
  testthat::expect_true(all(res$selected_clusts >= 0))
  testthat::expect_true(all(res$selected_clusts <= 1))

  testthat::expect_true(is.integer(res$selected_feats))
  testthat::expect_true(length(res$selected_feats) >= 1)
  testthat::expect_equal(length(names(res$selected_feats)),
                         length(unique(names(res$selected_feats))))
  testthat::expect_true(all(res$selected_feats >= 1))
  testthat::expect_true(all(res$selected_feats <= 5))
  testthat::expect_equal(length(res$selected_feats),
                             length(unique(res$selected_feats)))

  testthat::expect_equal(length(res$selected_clusts), length(res$weights))
  for(i in 1:length(res$weights)){
    weights_i <- res$weights[[i]]
    num_nonzero_weights <- sum(weights_i > 0)
    # For "sparse" weighting, either there should be only one nonzero weight,
    # equal to 1 (if there were no ties in selection proportions among cluster
    # members), or the nonzero weights should all equal 1/num_nonzero_weights
    # (if there were ties)
    testthat::expect_true(all(weights_i[weights_i > 0] == 1/num_nonzero_weights))
  }

  # weighted_avg
  res_weighted <- getSelectedClusters(css_res, weighting="weighted_avg",
                                      cutoff=0.05, min_num_clusts=1,
                                      max_num_clusts=NA)

  testthat::expect_equal(length(res_weighted$selected_clusts),
                         length(res_weighted$weights))
  for(i in 1:length(res_weighted$weights)){
    weights_i <- res_weighted$weights[[i]]
    testthat::expect_true(all(weights_i >= 0))
    testthat::expect_true(all(weights_i <= 1))
  }

  # simple_avg
  res_simple <- getSelectedClusters(css_res, weighting="simple_avg",
                                    cutoff=0.05, min_num_clusts=1,
                                    max_num_clusts=NA)

  testthat::expect_equal(length(res_simple$selected_clusts),
                         length(res_simple$weights))
  for(i in 1:length(res_simple$weights)){
    weights_i <- res_simple$weights[[i]]
    testthat::expect_equal(length(unique(weights_i)), 1)
    testthat::expect_equal(length(weights_i), sum(weights_i > 0))
  }

  # Test min_num_clusts
  res2 <- getSelectedClusters(css_res, weighting="weighted_avg", cutoff=1,
                             min_num_clusts=3, max_num_clusts=NA)
  testthat::expect_true(is.list(res2))
  testthat::expect_equal(length(res2$selected_clusts), 3)

  res3 <- getSelectedClusters(css_res, weighting="sparse", cutoff=1,
                             min_num_clusts=2, max_num_clusts=NA)
  testthat::expect_true(length(res3$selected_clusts) >= 2)

  # Test max_num_clusts
  # Ensure there is at least one relevant feature
  x2 <- x
  x2[, 5] <- y
  css_res2 <- css(X=x2, y=y, lambda=0.01, clusters=good_clusters, B = 10)
  res4 <- getSelectedClusters(css_res2, weighting="simple_avg", cutoff=0,
                             min_num_clusts=1, max_num_clusts=1)
  testthat::expect_true(is.list(res4))
  testthat::expect_equal(length(res4$selected_clusts), 1)

  res5 <- getSelectedClusters(css_res, weighting="weighted_avg", cutoff=0,
                             min_num_clusts=1, max_num_clusts=2)
  testthat::expect_true(length(res5$selected_clusts) <= 2)
  
  # Name features
  colnames(x) <- LETTERS[1:ncol(x)]
  css_res3 <- css(X=x, y=y, lambda=0.01, clusters=good_clusters, B = 10)
  res <- getSelectedClusters(css_res3, weighting="sparse", cutoff=0.05,
                             min_num_clusts=1, max_num_clusts=NA)

  testthat::expect_true(is.list(res))
  testthat::expect_equal(length(res), 3)
  testthat::expect_identical(names(res), c("selected_clusts", "selected_feats",
                                           "weights"))
  testthat::expect_true(length(res$selected_clusts) <=
                          length(res$selected_feats))

  testthat::expect_true(is.numeric(res$selected_clusts))
  testthat::expect_true(length(res$selected_clusts) >= 1)
  testthat::expect_equal(length(names(res$selected_clusts)),
                         length(res$selected_clusts))
  testthat::expect_equal(length(names(res$selected_clusts)),
                         length(unique(names(res$selected_clusts))))
  testthat::expect_true(all(res$selected_clusts >= 0))
  testthat::expect_true(all(res$selected_clusts <= 1))

  testthat::expect_true(is.integer(res$selected_feats))
  testthat::expect_true(length(res$selected_feats) >= 1)
  testthat::expect_equal(length(names(res$selected_feats)),
                         length(unique(names(res$selected_feats))))
  testthat::expect_true(all(res$selected_feats >= 1))
  testthat::expect_true(all(res$selected_feats <= 5))
  testthat::expect_equal(length(res$selected_feats),
                             length(unique(res$selected_feats)))
  testthat::expect_equal(length(names(res$selected_feats)),
                         length(res$selected_feats))
})
## Test passed 🎊

Finally, tests for getCssSelections()

testthat::test_that("getCssSelections works", {

  set.seed(26717)
  
  x <- matrix(stats::rnorm(10*7), nrow=10, ncol=7)
  y <- stats::rnorm(10)
  
  good_clusters <- list("apple"=1:2, "banana"=3:4, "cantaloupe"=5)
  
  css_res <- css(X=x, y=y, lambda=0.01, clusters=good_clusters, B = 10)

  res <- getCssSelections(css_res)

  testthat::expect_true(is.list(res))
  testthat::expect_equal(length(res), 2)
  testthat::expect_identical(names(res), c("selected_clusts", "selected_feats"))
  testthat::expect_true(length(res$selected_clusts) <=
                          length(res$selected_feats))

  testthat::expect_true(is.list(res$selected_clusts))
  testthat::expect_equal(length(names(res$selected_clusts)),
                           length(res$selected_clusts))
  testthat::expect_equal(length(names(res$selected_clusts)),
                           length(unique(names(res$selected_clusts))))
  already_used_feats <- integer()
  for(i in 1:length(res$selected_clusts)){
    sels_i <- res$selected_clusts[[i]]
    testthat::expect_true(length(sels_i) >= 1)
    testthat::expect_true(is.integer(sels_i))
    # x has 7 columns, so every selected feature must lie in 1:7
    testthat::expect_true(all(sels_i %in% 1:7))
    testthat::expect_equal(length(sels_i), length(unique(sels_i)))
    testthat::expect_equal(length(intersect(already_used_feats, sels_i)), 0)
    already_used_feats <- c(already_used_feats, sels_i)
  }
  testthat::expect_true(length(already_used_feats) <= 7)
  testthat::expect_equal(length(already_used_feats),
                         length(unique(already_used_feats)))
  testthat::expect_true(all(already_used_feats %in% 1:7))

  testthat::expect_true(is.integer(res$selected_feats))
  testthat::expect_true(length(res$selected_feats) >= 1)
  testthat::expect_equal(length(names(res$selected_feats)),
                         length(unique(names(res$selected_feats))))
  testthat::expect_true(all(res$selected_feats >= 1))
  testthat::expect_true(all(res$selected_feats <= 7))
  testthat::expect_equal(length(res$selected_feats),
                             length(unique(res$selected_feats)))

  # Test min_num_clusts (there should be 5 clusters: the 3 named ones, plus the
  # last two features, which css automatically puts in their own unnamed
  # clusters)
  res2 <- getCssSelections(css_res, weighting="weighted_avg", cutoff=1,
                           min_num_clusts=5, max_num_clusts=NA)
  testthat::expect_true(is.list(res2))
  testthat::expect_equal(length(res2$selected_clusts), 5)

  res3 <- getCssSelections(css_res, weighting="sparse", cutoff=1,
                             min_num_clusts=3, max_num_clusts=NA)
  testthat::expect_true(length(res3$selected_clusts) >= 3)

  # Test max_num_clusts
  # Ensure there is at least one relevant feature
  x2 <- x
  x2[, 5] <- y
  css_res2 <- css(X=x2, y=y, lambda=0.01, clusters=good_clusters, B = 10)
  res4 <- getCssSelections(css_res2, weighting="simple_avg", cutoff=0,
                             min_num_clusts=1, max_num_clusts=1)
  testthat::expect_true(is.list(res4))
  testthat::expect_equal(length(res4$selected_clusts), 1)

  res5 <- getCssSelections(css_res, weighting="weighted_avg", cutoff=0,
                             min_num_clusts=1, max_num_clusts=2)
  testthat::expect_true(length(res5$selected_clusts) <= 2)

  # Name features
  colnames(x) <- LETTERS[1:ncol(x)]
  css_res3 <- css(X=x, y=y, lambda=0.01, clusters=good_clusters, B = 10)
  res <- getCssSelections(css_res3, weighting="sparse", cutoff=0.05,
                             min_num_clusts=1, max_num_clusts=NA)

  testthat::expect_true(is.list(res))
  testthat::expect_equal(length(res), 2)
  testthat::expect_identical(names(res), c("selected_clusts", "selected_feats"))
  testthat::expect_true(length(res$selected_clusts) <=
                          length(res$selected_feats))

  testthat::expect_equal(length(names(res$selected_feats)),
                         length(res$selected_feats))
  testthat::expect_equal(length(names(res$selected_feats)),
                         length(unique(names(res$selected_feats))))

  # Bad inputs
  # The error message contains quotation marks, so only check that an error
  # is thrown
  testthat::expect_error(getCssSelections("css_results"))
  testthat::expect_error(getCssSelections(css_res, weighting="spasre"),
                         "Weighting must be a character and one of sparse, simple_avg, or weighted_avg",
                         fixed=TRUE)
  testthat::expect_error(getCssSelections(css_res, cutoff=-.5),
                         "cutoff >= 0 is not TRUE", fixed=TRUE)
  testthat::expect_error(getCssSelections(css_res, min_num_clusts=0),
                         "min_num_clusts >= 1 is not TRUE", fixed=TRUE)
  testthat::expect_error(getCssSelections(css_res, max_num_clusts=50),
                         "max_num_clusts <= p is not TRUE", fixed=TRUE)
  testthat::expect_error(getCssSelections(css_res, max_num_clusts=4.5),
                         "max_num_clusts == round(max_num_clusts) is not TRUE",
                         fixed=TRUE)
})
## Test passed 🥳
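The tests above exercise how cutoff, min_num_clusts, and max_num_clusts interact. Below is a hypothetical, simplified sketch of that interaction; select_clusters_sketch() is an illustrative name, not a function of the package. It assumes a named numeric vector of cluster selection proportions, ranks the clusters, keeps those at or above cutoff, and then effectively lowers the cutoff to reach min_num_clusts or raises it to respect max_num_clusts. It ignores ties, which the real getSelectedClusters() may handle differently.

```r
# Hypothetical sketch (not the package's getSelectedClusters()): pick clusters
# by selection proportion, then enforce the min/max number of clusters.
select_clusters_sketch <- function(clus_props, cutoff = 0, min_num_clusts = 1,
                                   max_num_clusts = NA) {
  stopifnot(is.numeric(clus_props), !is.null(names(clus_props)))
  ordered <- sort(clus_props, decreasing = TRUE)
  n_sel <- sum(ordered >= cutoff)
  # Lowering the cutoff admits more clusters, up to min_num_clusts
  n_sel <- max(n_sel, min_num_clusts)
  # Raising the cutoff removes clusters, down to max_num_clusts
  if (!is.na(max_num_clusts)) n_sel <- min(n_sel, max_num_clusts)
  n_sel <- min(n_sel, length(ordered))
  ordered[seq_len(n_sel)]
}

props <- c(apple = 0.9, banana = 0.4, cantaloupe = 0.1)
select_clusters_sketch(props, cutoff = 0.5)                       # apple
select_clusters_sketch(props, cutoff = 0.95, min_num_clusts = 2)  # apple, banana
select_clusters_sketch(props, cutoff = 0, max_num_clusts = 2)     # apple, banana
```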

getCssDesign()

#' Obtain a design matrix of cluster representatives
#'
#' Takes a matrix of observations from the original feature space and returns
#' a matrix of representatives from the selected clusters based on the results
#' of cluster stability selection.
#' @param css_results An object of class "cssr" (the output of the function
#' css).
#' @param newX A numeric matrix (preferably) or a data.frame (which will
#' be coerced internally to a matrix by the function model.matrix) containing
#' the data that will be used to generate the design matrix of cluster
#' representatives. Must contain the same features (in the same
#' number of columns) as the X matrix provided to css, and if the columns of
#' newX are labeled, the names must match the variable names provided to css.
#' newX may be omitted if train_inds were provided to css to set aside
#' observations for model estimation. If this is the case, then when newX is
#' omitted getCssDesign will return a design matrix of cluster representatives
#' formed from the train_inds observations from the matrix X provided to css.
#' (If no train_inds were provided to css, newX must be provided to
#' getCssDesign.) Default is NA.
#' @param weighting Character; determines how to calculate the weights to
#' combine features from the selected clusters into weighted averages, called
#' cluster representatives. Must be one of "sparse", "weighted_avg", or
#' "simple_avg". For "sparse", all the weight is put on the most frequently
#' selected individual cluster member (or divided equally among all the cluster
#' members that are tied for the top selection proportion if there is a tie). For
#' "weighted_avg", the weight used for each cluster member is calculated in
#' proportion to the individual selection proportions of each feature. For
#' "simple_avg", each cluster member gets equal weight regardless of the
#' individual feature selection proportions (that is, the cluster representative
#' is just a simple average of all the cluster members). See Faletto and Bien
#' (2022) for details. Default is "weighted_avg".
#' @param cutoff Numeric; getCssDesign will only include those clusters with
#' selection proportions of at least cutoff. Must be between 0 and 1.
#' Default is 0 (in which case either all clusters are used, or max_num_clusts
#' clusters are used, if max_num_clusts is specified).
#' @param min_num_clusts Integer or numeric; the minimum number of clusters to
#' use regardless of cutoff. (That is, if the chosen cutoff returns fewer than
#' min_num_clusts clusters, the cutoff will be lowered until at least
#' min_num_clusts clusters are selected.) Default is 1.
#' @param max_num_clusts Integer or numeric; the maximum number of clusters to
#' use regardless of cutoff. (That is, if the chosen cutoff returns more than
#' max_num_clusts clusters, the cutoff will be raised until at most
#' max_num_clusts clusters are selected.) Default is NA (in which case
#' max_num_clusts is ignored).
#' @return A design matrix with either nrow(newX) (or length(train_inds), if
#' train_inds was provided to css and newX was not provided to getCssDesign)
#' observations and number of columns equal to the number of selected clusters,
#' containing the cluster representatives for each cluster.
#' @author Gregory Faletto, Jacob Bien
#' @export
getCssDesign <- function(css_results, newX=NA, weighting="weighted_avg",
    cutoff=0, min_num_clusts=1, max_num_clusts=NA){
    # Check inputs
    stopifnot(inherits(css_results, "cssr"))

    check_results <- checkNewXProvided(newX, css_results)

    newX <- check_results$newX
    newXProvided <- check_results$newXProvided

    rm(check_results)

    n_train <- nrow(newX)

    results <- checkXInputResults(newX, css_results$X)

    newX <- results$newx
    feat_names <- results$feat_names

    rm(results)

    n <- nrow(newX)
    p <- ncol(newX)

    checkCutoff(cutoff)
    checkWeighting(weighting)
    checkMinNumClusts(min_num_clusts, p, length(css_results$clusters))

    max_num_clusts <- checkMaxNumClusts(max_num_clusts, min_num_clusts, p,
        length(css_results$clusters))

    # Turn newX (or, if newX was not provided, the train_inds observations of
    # the X provided to css) into a matrix of cluster representatives using
    # information from css_results
    if(newXProvided){
        newX_clusters <- formCssDesign(css_results, weighting, cutoff,
            min_num_clusts, max_num_clusts, newx=newX)
    } else{
        newX_clusters <- formCssDesign(css_results, weighting, cutoff,
            min_num_clusts, max_num_clusts)
    }

    return(newX_clusters)
}
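getCssDesign() delegates the actual construction to formCssDesign(), which is not shown here. As a rough, hypothetical sketch of the idea (cluster_representatives_sketch() is an illustrative name, not the package's code), each selected cluster contributes one column equal to the weighted average of its member columns, X[, members] %*% weights. It assumes a named list of member indices and a parallel named list of weights, the same shape as the weights element returned by getSelectedClusters().

```r
# Hypothetical sketch, not the package's formCssDesign(): one column per
# cluster, equal to the weighted average of that cluster's member columns.
cluster_representatives_sketch <- function(X, clusters, weights) {
  stopifnot(is.matrix(X), is.list(clusters), is.list(weights))
  stopifnot(identical(names(clusters), names(weights)))
  reps <- vapply(seq_along(clusters), function(j) {
    members <- clusters[[j]]
    w <- weights[[j]]
    stopifnot(length(members) == length(w), abs(sum(w) - 1) < 1e-6)
    # (n x k) %*% (length-k vector) gives an n x 1 matrix; flatten it
    as.numeric(X[, members, drop = FALSE] %*% w)
  }, numeric(nrow(X)))
  # vapply() returns a plain vector when nrow(X) == 1, so restore the shape
  reps <- matrix(reps, nrow = nrow(X))
  colnames(reps) <- names(clusters)
  reps
}

X <- matrix(c(1, 2, 3, 4, 5, 6), nrow = 2)  # columns (1, 2), (3, 4), (5, 6)
clusters <- list(a = 1:2, b = 3L)
weights <- list(a = c(0.25, 0.75), b = 1)
reps <- cluster_representatives_sketch(X, clusters, weights)
# Column a is 0.25 * X[, 1] + 0.75 * X[, 2] = (2.5, 3.5); column b is X[, 3]
reps
```

With "sparse" weighting this reduces to copying the single most frequently selected member column (or averaging the tied ones), while "simple_avg" gives the plain row mean of the cluster's columns.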

checkNewXProvided()

#' Helper function to confirm that the new X matrix provided to getCssDesign or
#' getCssPreds matches the characteristics of the X that was provided to css.
#'
#' @param trainX A numeric matrix (preferably) or a data.frame (which will
#' be coerced internally to a matrix by the function model.matrix). Must contain
#' the same features (in the same number of columns) as the X matrix provided to
#' css, and if the columns of trainX are labeled, the names must match the
#' variable names provided to css. trainX may be omitted if train_inds were
#' provided to css to set aside observations.
#' @param css_results An object of class "cssr" (the output of the function
#' css).
#' @return A named list with the following elements: \item{newX}{If trainX was
#' provided, this is the provided trainX matrix, coerced from a data.frame to a
#' matrix if the provided trainX was a data.frame. If trainX was not provided,
#' this is a matrix made up of the training indices provided to css in the
#' train_inds argument.} \item{newXProvided}{Logical; indicates whether a valid
#' trainX input was provided.}
#' @author Gregory Faletto, Jacob Bien
checkNewXProvided <- function(trainX, css_results){
    newXProvided <- FALSE

    if(all(!is.na(trainX)) && length(trainX) > 1){
        newXProvided <- TRUE
        trainX <- checkXInputResults(trainX, css_results$X)$newx
        
        n_train <- nrow(trainX)
        stopifnot(n_train > 1)
    } else{
        if(length(css_results$train_inds) == 0){
            stop("css was not provided with indices to set aside for model training (train_inds), so must provide new X in order to generate a design matrix")
        }
        # drop = FALSE keeps a matrix even if only one training index was given
        trainX <- css_results$X[css_results$train_inds, , drop = FALSE]
    } 
    stopifnot(is.matrix(trainX))
    stopifnot(is.numeric(trainX) | is.integer(trainX))
    stopifnot(all(!is.na(trainX)))
    stopifnot(ncol(trainX) >= 2)

    return(list(newX=trainX, newXProvided=newXProvided))
}
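Two details of checkNewXProvided() are easy to trip over, and the base-R snippet below illustrates both on a toy matrix (X and train_inds here are stand-ins, not objects produced by css()). First, the default newX = NA is what signals "no new X was provided", since all(!is.na(NA)) is FALSE. Second, indexing the rows of a matrix with a single index silently returns a vector unless drop = FALSE is given, which would make the is.matrix() check fail when only one training index is available.

```r
# Row-subsetting a matrix drops dimensions by default
X <- matrix(1:12, nrow = 4, ncol = 3)
train_inds <- 2L                 # a single (toy) training index
row_default <- X[train_inds, ]
row_kept <- X[train_inds, , drop = FALSE]
is.matrix(row_default)           # FALSE: a length-3 vector
is.matrix(row_kept)              # TRUE: a 1 x 3 matrix

# The default newX = NA is recognized as "not provided"
newX <- NA
all(!is.na(newX)) && length(newX) > 1   # FALSE, so train_inds would be used
```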

checkXInputResults()

#' Helper function to confirm that inputs to several functions are as expected,
#' and modify inputs if needed
#'
#' @param newx A numeric matrix (preferably) or a data.frame (which will
#' be coerced internally to a matrix by the function model.matrix) containing
#' the data that will be used to generate the design matrix of cluster
#' representatives. Must contain the same features (in the same
#' number of columns) as the X matrix provided to css, and if the columns of
#' newX are labeled, the names must match the variable names provided to css.
#' @param css_X The X matrix provided to css, as in the output of the css
#' function (after having been coerced from a data.frame to a matrix by css if
#' needed).
#' @return A named list with the following elements. \item{feat_names}{A 
#' character vector containing the column names of newx (if the provided newx
#' had column names). If the provided newx did not have column names, feat_names
#' will be NA.} \item{newx}{The provided newx matrix, coerced from a data.frame
#' to a matrix if the provided newx was a data.frame.}
#' @author Gregory Faletto, Jacob Bien
checkXInputResults <- function(newx, css_X){

    # Check if x is a matrix; if it's a data.frame, convert to matrix.
    if(is.data.frame(newx)){
        newx <- stats::model.matrix(~ ., newx)
        newx <- newx[, colnames(newx) != "(Intercept)"]
    }

    feat_names <- as.character(NA)
    if(!is.null(colnames(newx))){
        feat_names <- colnames(newx)
        stopifnot(identical(feat_names, colnames(css_X)))
    } else{
        # newx has no column names, so css_X should not have any either;
        # warn if it does
        if(!is.null(colnames(css_X))){
            warning("New X provided had no variable names (column names) even though the X provided to css did.")
        }
    }

    stopifnot(is.matrix(newx))
    stopifnot(all(!is.na(newx)))

    n <- nrow(newx)
    p <- ncol(newx)
    stopifnot(p >= 2)
    if(length(feat_names) > 1){
        stopifnot(length(feat_names) == p)
        stopifnot(!("(Intercept)" %in% feat_names))
    } else{
        stopifnot(is.na(feat_names))
    }

    # Strip column names from newx; they are returned separately in feat_names
    colnames(newx) <- NULL

    # Confirm that newx matches css_results$X
    if(p != ncol(css_X)){
        err <- paste("Number of columns in newx must match number of columns from matrix provided to css. Number of columns in new provided X: ",
            p, ". Number of columns in matrix provided to css: ", ncol(css_X),
            ".", sep="")
        stop(err)
    }
    if(length(feat_names) != 1 & all(!is.na(feat_names))){
        if(!identical(feat_names, colnames(css_X))){
            stop("Provided feature names for newx do not match feature names provided to css")
        }
    }

    return(list(feat_names=feat_names, newx=newx))
}
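The data.frame branch of checkXInputResults relies on stats::model.matrix, which expands each factor into treatment-contrast dummy columns, so the resulting matrix can have more columns than the data.frame. The sketch below uses a toy data.frame. The names df, a, g, and toDesign are illustrative only, and it assumes the default contrasts option. It also shows why a row subset of the same data.frame yields the same columns: subsetting keeps unused factor levels.

```r
# Toy data: one numeric column and one factor with three levels
df <- data.frame(a = c(1.5, 2.0, 3.1, 0.4),
                 g = factor(c("x", "y", "z", "x")))

# Hypothetical helper repeating the two coercion steps used above:
# expand factors into dummy columns, then drop the intercept column
toDesign <- function(d) {
  mm <- stats::model.matrix(~ ., d)
  mm[, colnames(mm) != "(Intercept)"]
}

full <- toDesign(df)
colnames(full)  # "a" "gy" "gz"; level "x" is the baseline

# Rows 1 and 4 both have g == "x", but the factor still carries all three
# levels, so the subset produces the same three columns (gy and gz are 0)
sub <- toDesign(df[c(1, 4), ])
colnames(sub)   # "a" "gy" "gz"
```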

Tests for checkXInputResults()

testthat::test_that("checkXInputResults works", {
  set.seed(72617)

  x_select <- matrix(stats::rnorm(10*5), nrow=10, ncol=5)
  x_new <- matrix(stats::rnorm(8*5), nrow=8, ncol=5)
  y_select <- stats::rnorm(10)
  y_new <- stats::rnorm(8)

  good_clusters <- list("red"=1:2, "blue"=3:4, "green"=5)

  css_res <- css(X=x_select, y=y_select, lambda=0.01, clusters=good_clusters,
                 B = 10)
  
  res <- checkXInputResults(x_new, css_res$X)

  testthat::expect_true(is.list(res))
  testthat::expect_identical(names(res), c("feat_names", "newx"))
  
  testthat::expect_true(is.character(res$feat_names))
  testthat::expect_true(is.na(res$feat_names))

  testthat::expect_true(is.numeric(res$newx))
  testthat::expect_true(is.matrix(res$newx))
  testthat::expect_equal(nrow(res$newx), 8)
  testthat::expect_equal(ncol(res$newx), 5)
  testthat::expect_null(colnames(res$newx))

  # Try naming variables
  
  colnames(x_select) <- LETTERS[1:5]
  css_res_named <- css(X=x_select, y=y_select, lambda=0.01,
                       clusters=good_clusters, B = 10)
  
  # Named variables for css matrix but not new one--should get a warning
  testthat::expect_warning(checkXInputResults(x_new, css_res_named$X),
                           "New X provided had no variable names (column names) even though the X provided to css did.",
                           fixed=TRUE)
  
  # Try mismatching variable names
  colnames(x_new) <- LETTERS[2:6]
  testthat::expect_error(checkXInputResults(x_new, css_res_named$X),
                         "identical(feat_names, colnames(css_X)) is not TRUE",
                         fixed=TRUE)
  
  colnames(x_new) <- LETTERS[1:5]
  
  res_named <- checkXInputResults(x_new, css_res_named$X)
  
  testthat::expect_true(is.list(res_named))
  testthat::expect_identical(names(res_named), c("feat_names", "newx"))
  
  testthat::expect_true(is.character(res_named$feat_names))
  testthat::expect_identical(res_named$feat_names, LETTERS[1:5])

  # Try data.frame input to css and checkXInputResults

  X_df <- datasets::mtcars

  n <- nrow(X_df)
  y <- stats::rnorm(n)

  selec_inds <- 1:round(n/2)
  fit_inds <- setdiff(1:n, selec_inds)

  css_res_df <- css(X=X_df[selec_inds, ], y=y[selec_inds], lambda=0.01, B = 10)
  res_df <- checkXInputResults(X_df[fit_inds, ], css_res_df$X)

  testthat::expect_true(is.list(res_df))
  testthat::expect_identical(names(res_df), c("feat_names", "newx"))

  testthat::expect_true(is.character(res_df$feat_names))
  testthat::expect_identical(res_df$feat_names, colnames(css_res_df$X))
  testthat::expect_identical(res_df$feat_names, colnames(X_df))

  testthat::expect_true(is.numeric(res_df$newx))
  testthat::expect_true(is.matrix(res_df$newx))
  testthat::expect_null(colnames(res_df$newx))
  testthat::expect_equal(ncol(res_df$newx), ncol(css_res_df$X))

  # Try again with X as a dataframe with factors (number of columns of final
  # design matrix after one-hot encoding factors won't match number of columns
  # of X_df)
  # cyl, gear, and carb are factors with more than 2 levels
  X_df$cyl <- as.factor(X_df$cyl)
  X_df$vs <- as.factor(X_df$vs)
  X_df$am <- as.factor(X_df$am)
  X_df$gear <- as.factor(X_df$gear)
  X_df$carb <- as.factor(X_df$carb)

  css_res_df <- css(X=X_df[selec_inds, ], y=y[selec_inds], lambda=0.01, B = 10)
  res_df <- checkXInputResults(X_df[fit_inds, ], css_res_df$X)

  testthat::expect_true(is.list(res_df))
  testthat::expect_identical(names(res_df), c("feat_names", "newx"))

  testthat::expect_true(is.character(res_df$feat_names))
  testthat::expect_identical(res_df$feat_names, colnames(css_res_df$X))

  mat <- stats::model.matrix(~ ., X_df)
  mat <- mat[, colnames(mat) != "(Intercept)"]

  testthat::expect_identical(res_df$feat_names, colnames(mat))

  testthat::expect_true(is.numeric(res_df$newx))
  testthat::expect_true(is.matrix(res_df$newx))
  testthat::expect_null(colnames(res_df$newx))
  testthat::expect_equal(ncol(res_df$newx), ncol(css_res_df$X))
})
## Test passed 🌈

Tests for checkNewXProvided()

testthat::test_that("checkNewXProvided works", {
  set.seed(2673)

  x_select <- matrix(stats::rnorm(10*5), nrow=10, ncol=5)
  x_new <- matrix(stats::rnorm(8*5), nrow=8, ncol=5)
  y_select <- stats::rnorm(10)
  y_new <- stats::rnorm(8)

  good_clusters <- list("red"=1:2, "blue"=3:4, "green"=5)

  css_res <- css(X=x_select, y=y_select, lambda=0.01, clusters=good_clusters,
                 B = 10)
  
  res <- checkNewXProvided(x_new, css_res)

  testthat::expect_true(is.list(res))
  testthat::expect_identical(names(res), c("newX", "newXProvided"))
  
  testthat::expect_true(is.numeric(res$newX))
  testthat::expect_true(is.matrix(res$newX))
  testthat::expect_equal(nrow(res$newX), 8)
  testthat::expect_equal(ncol(res$newX), 5)
  testthat::expect_null(colnames(res$newX))
  
  testthat::expect_true(is.logical(res$newXProvided))
  testthat::expect_equal(length(res$newXProvided), 1)
  testthat::expect_true(!is.na(res$newXProvided))
  testthat::expect_true(res$newXProvided)
  
  # Add training indices
  css_res_train <- css(X=x_select, y=y_select, lambda=0.01,
                       clusters=good_clusters, B = 10, train_inds=6:10)
  
  # Training indices should be ignored if new x is provided
  
  res <- checkNewXProvided(x_new, css_res_train)

  testthat::expect_true(is.list(res))
  testthat::expect_identical(names(res), c("newX", "newXProvided"))
  
  testthat::expect_true(all(abs(x_new - res$newX) < 10^(-9)))
  testthat::expect_true(res$newXProvided)
  
  # Things should still work if new x is not provided
  
  res <- checkNewXProvided(NA, css_res_train)

  testthat::expect_true(is.list(res))
  testthat::expect_identical(names(res), c("newX", "newXProvided"))
  
  testthat::expect_true(is.numeric(res$newX))
  testthat::expect_true(is.matrix(res$newX))
  testthat::expect_equal(nrow(res$newX), 5)
  testthat::expect_equal(ncol(res$newX), 5)
  testthat::expect_null(colnames(res$newX))
  
  testthat::expect_false(res$newXProvided)
  
  # Try not providing training indices and omitting newx--should get error
  testthat::expect_error(checkNewXProvided(NA, css_res),
                         "css was not provided with indices to set aside for model training (train_inds), so must provide new X in order to generate a design matrix", fixed=TRUE)
  
  # Try naming variables

  colnames(x_select) <- LETTERS[1:5]
  css_res_named <- css(X=x_select, y=y_select, lambda=0.01,
                       clusters=good_clusters, B = 10)

  # Named variables for css matrix but not new one--should get a warning
  testthat::expect_warning(checkNewXProvided(x_new, css_res_named),
                           "New X provided had no variable names (column names) even though the X provided to css did.", fixed=TRUE)

  # Try mismatching variable names
  colnames(x_new) <- LETTERS[2:6]
  testthat::expect_error(checkNewXProvided(x_new, css_res_named),
                         "identical(feat_names, colnames(css_X)) is not TRUE",
                         fixed=TRUE)

  colnames(x_new) <- LETTERS[1:5]

  res_named <- checkNewXProvided(x_new, css_res_named)

  testthat::expect_true(is.list(res_named))
  testthat::expect_identical(names(res_named), c("newX", "newXProvided"))
  
  testthat::expect_true(all(abs(x_new - res_named$newX) < 10^(-9)))
  testthat::expect_true(res_named$newXProvided)

  # Try data.frame input to css and checkNewXProvided

  X_df <- datasets::mtcars

  n <- nrow(X_df)
  y <- stats::rnorm(n)

  selec_inds <- 1:round(n/2)
  fit_inds <- setdiff(1:n, selec_inds)

  css_res_df <- css(X=X_df[selec_inds, ], y=y[selec_inds], lambda=0.01, B = 10)
  res_df <- checkNewXProvided(X_df[fit_inds, ], css_res_df)

  testthat::expect_true(is.list(res_df))
  testthat::expect_identical(names(res_df), c("newX", "newXProvided"))
  
  testthat::expect_true(is.numeric(res_df$newX))
  testthat::expect_true(is.matrix(res_df$newX))
  testthat::expect_equal(nrow(res_df$newX), length(fit_inds))
  testthat::expect_equal(ncol(res_df$newX), ncol(css_res_df$X))
  testthat::expect_null(colnames(res_df$newX))
  
  testthat::expect_true(is.logical(res_df$newXProvided))
  testthat::expect_equal(length(res_df$newXProvided), 1)
  testthat::expect_true(!is.na(res_df$newXProvided))
  testthat::expect_true(res_df$newXProvided)
  
  # Try again with X as a dataframe with factors (number of columns of final
  # design matrix after one-hot encoding factors won't match number of columns
  # of X_df)
  X_df$cyl <- as.factor(X_df$cyl)
  X_df$vs <- as.factor(X_df$vs)
  X_df$am <- as.factor(X_df$am)
  X_df$gear <- as.factor(X_df$gear)
  X_df$carb <- as.factor(X_df$carb)

  css_res_df <- css(X=X_df[selec_inds, ], y=y[selec_inds], lambda=0.01, B = 10)
  res_df <- checkNewXProvided(X_df[fit_inds, ], css_res_df)

  testthat::expect_true(is.list(res_df))
  testthat::expect_identical(names(res_df), c("newX", "newXProvided"))
  
  testthat::expect_true(is.numeric(res_df$newX))
  testthat::expect_true(is.matrix(res_df$newX))
  testthat::expect_equal(nrow(res_df$newX), length(fit_inds))
  testthat::expect_equal(ncol(res_df$newX), ncol(css_res_df$X))
  testthat::expect_null(colnames(res_df$newX))
  
  testthat::expect_true(is.logical(res_df$newXProvided))
  testthat::expect_equal(length(res_df$newXProvided), 1)
  testthat::expect_true(!is.na(res_df$newXProvided))
  testthat::expect_true(res_df$newXProvided)
  
})
## Test passed 😀

formCssDesign():

#' Create design matrix of cluster representatives from matrix of raw features
#' using results of css function
#'
#' @param css_results An object of class "cssr" (the output of the function
#' css).
#' @param weighting Character; determines how to calculate the weights to
#' combine features from the selected clusters into weighted averages, called
#' cluster representatives. Must be one of "sparse", "weighted_avg", or
#' "simple_avg". For "sparse", all the weight is put on the most frequently
#' selected individual cluster member (or divided equally among all the cluster
#' members that are tied for the top selection proportion if there is a tie). For
#' "weighted_avg", the weight used for each cluster member is calculated in
#' proportion to the individual selection proportions of each feature. For
#' "simple_avg", each cluster member gets equal weight regardless of the
#' individual feature selection proportions (that is, the cluster representative
#' is just a simple average of all the cluster members). See Faletto and Bien
#' (2022) for details. Default is "weighted_avg".
#' @param cutoff Numeric; only clusters with selection proportions of at least
#' cutoff are used to form cluster representatives. Must be between 0 and 1.
#' Default is 0 (in which case all clusters are used, in decreasing order of
#' selection proportion).
#' @param min_num_clusts Integer or numeric; the minimum number of clusters to
#' use regardless of cutoff. (That is, if the chosen cutoff returns fewer than
#' min_num_clusts clusters, the cutoff will be decreased until at least
#' min_num_clusts clusters are selected.) Default is 1.
#' @param max_num_clusts Integer or numeric; the maximum number of clusters to
#' use regardless of cutoff. (That is, if the chosen cutoff returns more than
#' max_num_clusts clusters, the cutoff will be increased until at most
#' max_num_clusts clusters are selected.) Default is NA (in which case
#' max_num_clusts is ignored).
#' @param newx A numeric matrix (preferably) or a data.frame (which will
#' be coerced internally to a matrix by the function model.matrix) containing
#' the data that will be used to generate the design matrix of cluster
#' representatives. Must contain the same features (in the same
#' number of columns) as the X matrix provided to css, and if the columns of
#' newx are labeled, the names must match the variable names provided to css.
#' newx may be omitted if train_inds were provided to css to set aside
#' observations for model estimation. If this is the case, then when newx is
#' omitted formCssDesign will return a design matrix of cluster representatives
#' formed from the train_inds observations from the matrix X provided to css.
#' (If no train_inds were provided to css, newx must be provided to
#' formCssDesign.) Default is NA.
#' @return A design matrix with the same number of rows as newx (or the 
#' train_inds provided to css) where the columns are the constructed cluster
#' representatives.
#' @author Gregory Faletto, Jacob Bien
#' @references
<<faletto2022>>
formCssDesign <- function(css_results, weighting="weighted_avg", cutoff=0,
    min_num_clusts=1, max_num_clusts=NA, newx=NA){

    # Check inputs
    ret <- checkFormCssDesignInputs(css_results, weighting, cutoff,
        min_num_clusts, max_num_clusts, newx)

    newx <- ret$newx
    max_num_clusts <- ret$max_num_clusts

    rm(ret)

    n <- nrow(newx)
    p <- ncol(newx)

    # Get the names of the selected clusters and the weights for the features
    # within each cluster, according to the provided weighting rule
    weights <- getSelectedClusters(css_results, weighting, cutoff,
        min_num_clusts, max_num_clusts)$weights

    n_sel_clusts <- length(weights)

    # Form matrix of cluster representatives of selected clusters
    X_clus_reps <- matrix(rep(as.numeric(NA), n*n_sel_clusts), nrow=n,
        ncol=n_sel_clusts)
    colnames(X_clus_reps) <- rep(as.character(NA), n_sel_clusts)

    for(i in 1:n_sel_clusts){
        clust_i_name <- names(weights)[i]

        stopifnot(length(clust_i_name) == 1)
        stopifnot(clust_i_name %in% names(css_results$clusters))

        colnames(X_clus_reps)[i] <- clust_i_name

        clust_i <- css_results$clusters[[clust_i_name]]

        stopifnot(length(clust_i) >= 1)
        stopifnot(all(clust_i %in% 1:p))

        weights_i <- weights[[clust_i_name]]

        stopifnot(length(clust_i) == length(weights_i))

        if(length(weights_i) > 1){
            X_clus_reps[, i] <- newx[, clust_i] %*% weights_i
        } else{
            X_clus_reps[, i] <- newx[, clust_i]*weights_i
        }
    }

    # Check output
    stopifnot(all(!is.na(X_clus_reps)))
    stopifnot(ncol(X_clus_reps) == n_sel_clusts)
    stopifnot(nrow(X_clus_reps) == n)

    return(X_clus_reps)
}
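The weighting rules described for formCssDesign and the column-by-column construction of X_clus_reps can be sketched on toy data. The helper toyClusterWeights below is hypothetical. It follows the prose description of the weighting argument rather than the package's getSelectedClusters, and it assumes the individual selection proportions of one cluster's members are already available.

```r
# Hypothetical helper, not the package's getSelectedClusters: weights for the
# members of one cluster under each rule described for `weighting`.
# `props` holds the individual selection proportions of the cluster members.
toyClusterWeights <- function(props, weighting) {
  if (weighting == "simple_avg") {
    # Equal weight on every member, regardless of props
    w <- rep(1 / length(props), length(props))
  } else if (weighting == "weighted_avg") {
    # Weights proportional to props (assumes sum(props) > 0)
    w <- props / sum(props)
  } else if (weighting == "sparse") {
    # All weight on the top member(s); tied members share it equally
    top <- props == max(props)
    w <- as.numeric(top) / sum(top)
  } else {
    stop("weighting must be one of sparse, simple_avg, or weighted_avg")
  }
  w
}

set.seed(1)
newx <- matrix(stats::rnorm(4 * 3), nrow = 4, ncol = 3)
clust <- 1:3               # toy cluster containing features 1, 2, and 3
props <- c(0.6, 0.2, 0.6)  # toy individual selection proportions

w_simple <- toyClusterWeights(props, "simple_avg")    # 1/3, 1/3, 1/3
w_avg    <- toyClusterWeights(props, "weighted_avg")  # 3/7, 1/7, 3/7
w_sparse <- toyClusterWeights(props, "sparse")        # 0.5, 0, 0.5

# A cluster representative is newx[, clust] %*% weights, as in the loop of
# formCssDesign; with simple_avg it reduces to the row means of the members
rep_simple <- newx[, clust] %*% w_simple
rep_avg    <- newx[, clust] %*% w_avg
```

Each selected cluster contributes one such column, so the design matrix has one column per selected cluster regardless of cluster size.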

checkFormCssDesignInputs():

#' Helper function to check that the inputs to formCssDesign are as expected
#'
#' @param css_results An object of class "cssr" (the output of the function
#' css).
#' @param weighting Character; determines how to calculate the weights to
#' combine features from the selected clusters into weighted averages, called
#' cluster representatives. Must be one of "sparse", "weighted_avg", or
#' "simple_avg".
#' @param cutoff Numeric; only clusters with selection proportions of at least
#' cutoff are used to form cluster representatives. Must be between 0 and 1.
#' @param min_num_clusts Integer or numeric; the minimum number of clusters to
#' use regardless of cutoff. (That is, if the chosen cutoff returns fewer than
#' min_num_clusts clusters, the cutoff will be decreased until at least
#' min_num_clusts clusters are selected.)
#' @param max_num_clusts Integer or numeric; the maximum number of clusters to
#' use regardless of cutoff. (That is, if the chosen cutoff returns more than
#' max_num_clusts clusters, the cutoff will be increased until at most
#' max_num_clusts clusters are selected.)
#' @param newx A numeric matrix (preferably) or a data.frame (which will
#' be coerced internally to a matrix by the function model.matrix) containing
#' the data that will be used to generate the design matrix of cluster
#' representatives. Must contain the same features (in the same
#' number of columns) as the X matrix provided to css, and if the columns of
#' newx are labeled, the names must match the variable names provided to css.
#' newx may be omitted if train_inds were provided to css to set aside
#' observations for model estimation. If this is the case, then when newx is
#' omitted formCssDesign will return a design matrix of cluster representatives
#' formed from the train_inds observations from the matrix X provided to css.
#' (If no train_inds were provided to css, newx must be provided to
#' formCssDesign.)
#' @return A named list with the following elements: \item{newx}{If newx was
#' provided, the provided newx matrix, coerced from a data.frame to a matrix if
#' needed. If newx was not provided, a matrix formed by the train_inds set
#' aside in the original function call to css.} \item{max_num_clusts}{The
#' provided max_num_clusts, coerced to an integer if needed, and coerced to be
#' less than or equal to the total number of clusters.}
#' @author Gregory Faletto, Jacob Bien
checkFormCssDesignInputs <- function(css_results, weighting, cutoff,
    min_num_clusts, max_num_clusts, newx){    
    stopifnot(class(css_results) == "cssr")

    if(length(newx) == 1 && all(is.na(newx))){
        # No new X provided: fall back to the train_inds observations
        if(length(css_results$train_inds) == 0){
            stop("If css was not provided with indices to set aside for model training, then newx must be provided to formCssDesign")
        }
        newx <- css_results$X[css_results$train_inds, ]
    } else{
        # Validate the provided newx (coercing a data.frame to a matrix)
        newx <- checkXInputResults(newx, css_results$X)$newx
    }

    p <- ncol(newx)

    checkCutoff(cutoff)
    checkWeighting(weighting)
    checkMinNumClusts(min_num_clusts, p, length(css_results$clusters))
    max_num_clusts <- checkMaxNumClusts(max_num_clusts, min_num_clusts, p,
        length(css_results$clusters))

    return(list(newx=newx, max_num_clusts=max_num_clusts))
}
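The documented interaction between cutoff, min_num_clusts, and max_num_clusts amounts to: rank clusters by selection proportion, keep those at or above the cutoff, then widen or trim that set to satisfy the minimum and maximum. The function toySelectClusters below is a hypothetical sketch of that rule, not the package's selection code. It assumes ties at the boundary need no special treatment and that min_num_clusts does not exceed the number of clusters (which checkMinNumClusts is meant to enforce).

```r
# Hypothetical sketch: `props` is a named vector of cluster selection
# proportions; returns the names of the selected clusters, best first.
toySelectClusters <- function(props, cutoff = 0, min_num_clusts = 1,
                              max_num_clusts = NA) {
  sorted <- props[order(props, decreasing = TRUE)]
  n_sel <- sum(sorted >= cutoff)
  # Too few clusters clear the cutoff: effectively lower the cutoff
  n_sel <- max(n_sel, min_num_clusts)
  # Too many clusters clear the cutoff: effectively raise the cutoff
  if (!is.na(max_num_clusts)) n_sel <- min(n_sel, max_num_clusts)
  names(sorted)[seq_len(n_sel)]
}

props <- c(red = 0.9, blue = 0.35, green = 0.7, c4 = 0.1)
toySelectClusters(props, cutoff = 0.5)                       # "red" "green"
toySelectClusters(props, cutoff = 0.95, min_num_clusts = 1)  # "red"
toySelectClusters(props, cutoff = 0, max_num_clusts = 3)     # "red" "green" "blue"
```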

Tests for checkFormCssDesignInputs()

testthat::test_that("checkFormCssDesignInputs works", {
  set.seed(72617)

  x_select <- matrix(stats::rnorm(10*6), nrow=10, ncol=6)
  x_new <- matrix(stats::rnorm(8*6), nrow=8, ncol=6)
  y_select <- stats::rnorm(10)
  y_new <- stats::rnorm(8)

  good_clusters <- list("red"=1:2, "blue"=3:4, "green"=5)

  css_res <- css(X=x_select, y=y_select, lambda=0.01, clusters=good_clusters,
                 B = 10)
  
  res <- checkFormCssDesignInputs(css_results=css_res, weighting="sparse",
                                  cutoff=0.5, min_num_clusts=1,
                                  max_num_clusts=NA, newx=x_new)

  testthat::expect_true(is.list(res))
  testthat::expect_identical(names(res), c("newx", "max_num_clusts"))
  
  testthat::expect_true(is.numeric(res$newx))
  testthat::expect_true(is.matrix(res$newx))
  testthat::expect_equal(nrow(res$newx), 8)
  testthat::expect_equal(ncol(res$newx), 6)
  testthat::expect_null(colnames(res$newx))
  testthat::expect_true(all(abs(x_new - res$newx) < 10^(-9)))
  
  testthat::expect_equal(length(res$max_num_clusts), 1)
  testthat::expect_true(is.na(res$max_num_clusts))
  
  # Add training indices
  css_res_train <- css(X=x_select, y=y_select, lambda=0.01,
                       clusters=good_clusters, B=10, train_inds=6:10)
  
  # Training indices should be ignored if new x is provided
  
  res <- checkFormCssDesignInputs(css_results=css_res_train,
                                  weighting="weighted_avg", cutoff=0,
                                  min_num_clusts=2, max_num_clusts=NA,
                                  newx=x_new)

  testthat::expect_true(is.list(res))
  testthat::expect_identical(names(res), c("newx", "max_num_clusts"))
  
  testthat::expect_true(is.numeric(res$newx))
  testthat::expect_true(is.matrix(res$newx))
  testthat::expect_equal(nrow(res$newx), 8)
  testthat::expect_equal(ncol(res$newx), 6)
  testthat::expect_null(colnames(res$newx))
  testthat::expect_true(all(abs(x_new - res$newx) < 10^(-9)))
  
  # Things should still work if new x is not provided

  res <- checkFormCssDesignInputs(css_results=css_res_train, weighting="sparse",
                                  cutoff=1, min_num_clusts=3,
                                  max_num_clusts=NA, newx=NA)

  testthat::expect_true(is.list(res))
  testthat::expect_identical(names(res), c("newx", "max_num_clusts"))

  testthat::expect_true(is.numeric(res$newx))
  testthat::expect_true(is.matrix(res$newx))
  testthat::expect_equal(nrow(res$newx), length(6:10))
  testthat::expect_equal(ncol(res$newx), 6)
  testthat::expect_null(colnames(res$newx))
  testthat::expect_true(all(abs(x_select[6:10, ] - res$newx) < 10^(-9)))
  
  # Try not providing training indices and omitting newx--should get error
  testthat::expect_error(checkFormCssDesignInputs(css_results=css_res,
                                                  weighting="sparse",
                                                  cutoff=0.5, min_num_clusts=1,
                                                  max_num_clusts=5, newx=NA),
                         "If css was not provided with indices to set aside for model training, then newx must be provided to formCssDesign", fixed=TRUE)
  
  # Try naming variables

  colnames(x_select) <- LETTERS[1:6]
  css_res_named <- css(X=x_select, y=y_select, lambda=0.01,
                       clusters=good_clusters, B = 10)

  # Named variables for css matrix but not new one--should get a warning
  testthat::expect_warning(checkFormCssDesignInputs(css_results=css_res_named,
                                                    weighting="simple_avg",
                                                    cutoff=0.9,
                                                    min_num_clusts=1,
                                                    max_num_clusts=3,
                                                    newx=x_new),
                           "New X provided had no variable names (column names) even though the X provided to css did.", fixed=TRUE)

  # Try mismatching variable names
  colnames(x_new) <- LETTERS[2:7]
  testthat::expect_error(checkFormCssDesignInputs(css_results=css_res_named,
                                                  weighting="weighted_avg",
                                                  cutoff=0.2, min_num_clusts=1,
                                                  max_num_clusts=1,
                                                  newx=x_new),
                         "identical(feat_names, colnames(css_X)) is not TRUE",
                         fixed=TRUE)

  colnames(x_new) <- LETTERS[1:6]

  res_named <- checkFormCssDesignInputs(css_results=css_res_named,
                                        weighting="sparse", cutoff=0.5,
                                        min_num_clusts=2, max_num_clusts=NA,
                                        newx=x_new)

  testthat::expect_true(is.list(res_named))
  testthat::expect_identical(names(res_named), c("newx", "max_num_clusts"))

  testthat::expect_true(is.numeric(res_named$newx))
  testthat::expect_true(is.matrix(res_named$newx))
  testthat::expect_equal(nrow(res_named$newx), 8)
  testthat::expect_equal(ncol(res_named$newx), 6)
  testthat::expect_null(colnames(res_named$newx))
  testthat::expect_identical(colnames(css_res_named$X), LETTERS[1:6])
  testthat::expect_true(all(abs(x_new - res_named$newx) < 10^(-9)))

  # Try data.frame input to css and checkFormCssDesignInputs

  X_df <- datasets::mtcars

  n <- nrow(X_df)
  y <- stats::rnorm(n)

  selec_inds <- 1:round(n/2)
  fit_inds <- setdiff(1:n, selec_inds)

  css_res_df <- css(X=X_df[selec_inds, ], y=y[selec_inds], lambda=0.01, B = 10)
  res_df <- checkFormCssDesignInputs(css_results=css_res_df,
                                     weighting="simple_avg", cutoff=0.7,
                                     min_num_clusts=3, max_num_clusts=NA,
                                     newx=X_df[fit_inds, ])

  testthat::expect_true(is.list(res_df))
  testthat::expect_identical(names(res_df), c("newx", "max_num_clusts"))
  
  testthat::expect_true(is.numeric(res_df$newx))
  testthat::expect_true(is.matrix(res_df$newx))
  testthat::expect_null(colnames(res_df$newx))
  testthat::expect_equal(nrow(res_df$newx), length(fit_inds))
  testthat::expect_equal(ncol(res_df$newx), ncol(css_res_df$X))

  # Try again with X as a dataframe with factors (number of columns of final
  # design matrix after one-hot encoding factors won't match number of columns
  # of X_df)
  X_df$cyl <- as.factor(X_df$cyl)
  X_df$vs <- as.factor(X_df$vs)
  X_df$am <- as.factor(X_df$am)
  X_df$gear <- as.factor(X_df$gear)
  X_df$carb <- as.factor(X_df$carb)

  css_res_df <- css(X=X_df[selec_inds, ], y=y[selec_inds], lambda=0.01, B = 10)
  res_df <- checkFormCssDesignInputs(css_results=css_res_df,
                                     weighting="weighted_avg", cutoff=0.3,
                                     min_num_clusts=1, max_num_clusts=4,
                                     newx=X_df[fit_inds, ])

  testthat::expect_true(is.list(res_df))
  testthat::expect_identical(names(res_df), c("newx", "max_num_clusts"))
  
  testthat::expect_true(is.numeric(res_df$newx))
  testthat::expect_true(is.matrix(res_df$newx))
  testthat::expect_null(colnames(res_df$newx))
  testthat::expect_equal(nrow(res_df$newx), length(fit_inds))
  testthat::expect_equal(ncol(res_df$newx), ncol(css_res_df$X))
  
  ##### Try other bad inputs
  
  colnames(x_new) <- NULL
  
  testthat::expect_error(checkFormCssDesignInputs(css_results=css_res,
                                                  weighting="weighted_avg",
                                                  cutoff=-0.3, min_num_clusts=1,
                                                  max_num_clusts=4,
                                                  newx=x_new),
                         "cutoff >= 0 is not TRUE", fixed=TRUE)

  testthat::expect_error(checkFormCssDesignInputs(css_results=css_res,
                                                  weighting="sparse",
                                                  cutoff="0.5",
                                                  min_num_clusts=1,
                                                  max_num_clusts=NA, newx=x_new),
                         "is.numeric(cutoff) | is.integer(cutoff) is not TRUE",
                         fixed=TRUE)
  
  testthat::expect_error(checkFormCssDesignInputs(css_results=css_res,
                                                  weighting="sparse",
                                                  cutoff=as.numeric(NA),
                                                  min_num_clusts=1,
                                                  max_num_clusts=NA,
                                                  newx=x_new),
                         "!is.na(cutoff) is not TRUE", fixed=TRUE)
  
  testthat::expect_error(checkFormCssDesignInputs(css_results=css_res,
                                                  weighting=c("sparse",
                                                              "simple_avg"),
                                                  cutoff=0.2,
                                                  min_num_clusts=1,
                                                  max_num_clusts=NA,
                                                  newx=x_new),
                         "length(weighting) == 1 is not TRUE", fixed=TRUE)
  
  testthat::expect_error(checkFormCssDesignInputs(css_results=css_res,
                                                  weighting=1,
                                                  cutoff=0.2,
                                                  min_num_clusts=1,
                                                  max_num_clusts=NA,
                                                  newx=x_new),
                         "Weighting must be a character", fixed=TRUE)
  
  testthat::expect_error(checkFormCssDesignInputs(css_results=css_res,
                                                  weighting="spasre",
                                                  cutoff=0.2,
                                                  min_num_clusts=1,
                                                  max_num_clusts=NA,
                                                  newx=x_new),
                         "Weighting must be a character and one of sparse, simple_avg, or weighted_avg",
                         fixed=TRUE)
  
  testthat::expect_error(checkFormCssDesignInputs(css_results=css_res,
                                                  weighting="weighted_avg",
                                                  cutoff=0.2,
                                                  min_num_clusts=c(1, 2),
                                                  max_num_clusts=NA,
                                                  newx=x_new),
                         "length(min_num_clusts) == 1 is not TRUE", fixed=TRUE)
  
  testthat::expect_error(checkFormCssDesignInputs(css_results=css_res,
                                                  weighting="weighted_avg",
                                                  cutoff=0.2,
                                                  min_num_clusts="3",
                                                  max_num_clusts=NA,
                                                  newx=x_new),
                         "is.numeric(min_num_clusts) | is.integer(min_num_clusts) is not TRUE",
                         fixed=TRUE)

  testthat::expect_error(checkFormCssDesignInputs(css_results=css_res,
                                                  weighting="weighted_avg",
                                                  cutoff=0.2,
                                                  min_num_clusts=0,
                                                  max_num_clusts=NA,
                                                  newx=x_new),
                         "min_num_clusts >= 1 is not TRUE", fixed=TRUE)

  testthat::expect_error(checkFormCssDesignInputs(css_results=css_res,
                                                  weighting="weighted_avg",
                                                  cutoff=0.2,
                                                  min_num_clusts=6,
                                                  max_num_clusts=NA,
                                                  newx=x_new),
                         "min_num_clusts <= n_clusters is not TRUE", fixed=TRUE)
  
  
  testthat::expect_error(checkFormCssDesignInputs(css_results=css_res,
                                                  weighting="weighted_avg",
                                                  cutoff=0.2,
                                                  min_num_clusts=1,
                                                  max_num_clusts="4",
                                                  newx=x_new),
                         "is.numeric(max_num_clusts) | is.integer(max_num_clusts) is not TRUE",
                         fixed=TRUE)
  
  testthat::expect_error(checkFormCssDesignInputs(css_results=css_res,
                                                  weighting="weighted_avg",
                                                  cutoff=0.2,
                                                  min_num_clusts=1,
                                                  max_num_clusts=3.5,
                                                  newx=x_new),
                         "max_num_clusts == round(max_num_clusts) is not TRUE",
                         fixed=TRUE)
  
  testthat::expect_error(checkFormCssDesignInputs(css_results=css_res,
                                                  weighting="weighted_avg",
                                                  cutoff=0.2,
                                                  min_num_clusts=2,
                                                  max_num_clusts=1,
                                                  newx=x_new),
                         "max_num_clusts >= min_num_clusts is not TRUE",
                         fixed=TRUE)

  testthat::expect_error(checkFormCssDesignInputs(css_results=css_res,
                                                  weighting="weighted_avg",
                                                  cutoff=0.2,
                                                  min_num_clusts=2,
                                                  max_num_clusts=8,
                                                  newx=x_new),
                         "max_num_clusts <= p is not TRUE", fixed=TRUE)

  
})
## Test passed 🥇
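The expectations above define the contract for `min_num_clusts` and `max_num_clusts`. Below is a minimal sketch of checks that would produce the same `stopifnot()` messages. `checkClustBounds` is a hypothetical name, and `n_clusters` and `p` are passed in explicitly here. The package's own `checkFormCssDesignInputs()` may be organized differently.

```r
# Hypothetical validator, not the package's checkFormCssDesignInputs():
# stopifnot() reports each failed condition as "<expression> is not TRUE",
# which is the form of the messages the tests above expect.
checkClustBounds <- function(min_num_clusts, max_num_clusts, n_clusters, p){
    stopifnot(length(min_num_clusts) == 1)
    stopifnot(is.numeric(min_num_clusts) | is.integer(min_num_clusts))
    stopifnot(min_num_clusts >= 1)
    stopifnot(min_num_clusts <= n_clusters)
    stopifnot(length(max_num_clusts) == 1)
    # NA (the default) means "no upper bound", so skip the remaining checks
    if(!is.na(max_num_clusts)){
        stopifnot(is.numeric(max_num_clusts) | is.integer(max_num_clusts))
        stopifnot(max_num_clusts == round(max_num_clusts))
        stopifnot(max_num_clusts >= min_num_clusts)
        stopifnot(max_num_clusts <= p)
    }
    invisible(TRUE)
}

# Passes silently
checkClustBounds(min_num_clusts = 1, max_num_clusts = NA, n_clusters = 4, p = 6)
# checkClustBounds(2, 1, n_clusters = 4, p = 6) would stop with
# "max_num_clusts >= min_num_clusts is not TRUE"
```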

Tests for formCssDesign()

testthat::test_that("formCssDesign works", {
  set.seed(17230)

  x_select <- matrix(stats::rnorm(10*6), nrow=10, ncol=6)
  x_new <- matrix(stats::rnorm(8*6), nrow=8, ncol=6)
  y_select <- stats::rnorm(10)
  y_new <- stats::rnorm(8)

  good_clusters <- list("red"=1:2, "blue"=3:4, "green"=5)

  css_res <- css(X=x_select, y=y_select, lambda=0.01, clusters=good_clusters,
                 B = 10)
  
  res <- formCssDesign(css_res, newx=x_new)

  testthat::expect_true(is.matrix(res))
  testthat::expect_true(is.numeric(res))
  testthat::expect_equal(nrow(res), 8)
  testthat::expect_equal(ncol(res), length(css_res$clusters))
  testthat::expect_true(all(colnames(res) %in% names(css_res$clusters)))
  testthat::expect_true(all(names(css_res$clusters) %in% colnames(res)))
  
  # Add training indices
  css_res_train <- css(X=x_select, y=y_select, lambda=0.01,
                       clusters=good_clusters, B=10, train_inds=6:10)

  # Training indices should be ignored if new x is provided

  res <- formCssDesign(css_results=css_res_train, weighting="weighted_avg",
                       cutoff=0, min_num_clusts=2, max_num_clusts=NA,
                       newx=x_new)

  testthat::expect_true(is.matrix(res))
  testthat::expect_true(is.numeric(res))
  testthat::expect_equal(nrow(res), 8)
  testthat::expect_equal(ncol(res), length(css_res_train$clusters))
  testthat::expect_true(all(colnames(res) %in% names(css_res_train$clusters)))
  testthat::expect_true(all(names(css_res_train$clusters) %in% colnames(res)))

  # Things should still work if new x is not provided

  res <- formCssDesign(css_results=css_res_train, weighting="weighted_avg",
                       cutoff=0, min_num_clusts=2, max_num_clusts=NA)

  testthat::expect_true(is.matrix(res))
  testthat::expect_true(is.numeric(res))
  testthat::expect_equal(nrow(res), 5)
  testthat::expect_equal(ncol(res), length(css_res_train$clusters))
  testthat::expect_true(all(colnames(res) %in% names(css_res_train$clusters)))
  testthat::expect_true(all(names(css_res_train$clusters) %in% colnames(res)))

  # Try not providing training indices and omitting newx--should get error
  testthat::expect_error(formCssDesign(css_results=css_res, weighting="sparse",
                                       cutoff=0.5, min_num_clusts=1,
                                       max_num_clusts=5, newx=NA),
                         "If css was not provided with indices to set aside for model training, then newx must be provided to formCssDesign", fixed=TRUE)

  # Try naming variables

  colnames(x_select) <- LETTERS[1:6]
  css_res_named <- css(X=x_select, y=y_select, lambda=0.01,
                       clusters=good_clusters, B = 10)

  # Named variables for css matrix but not new one--should get a warning
  testthat::expect_warning(formCssDesign(css_results=css_res_named,
                                         weighting="simple_avg", cutoff=0.9,
                                         min_num_clusts=1, max_num_clusts=3,
                                         newx=x_new),
                           "New X provided had no variable names (column names) even though the X provided to css did.", fixed=TRUE)

  # Try mismatching variable names
  colnames(x_new) <- LETTERS[2:7]
  testthat::expect_error(formCssDesign(css_results=css_res_named,
                                       weighting="weighted_avg", cutoff=0.2,
                                       min_num_clusts=1, max_num_clusts=1,
                                       newx=x_new),
                         "identical(feat_names, colnames(css_X)) is not TRUE",
                         fixed=TRUE)

  colnames(x_new) <- LETTERS[1:6]

  res_named <- formCssDesign(css_results=css_res_named,
                             weighting="sparse", cutoff=0.5,
                             min_num_clusts=2, max_num_clusts=NA,
                             newx=x_new)
  
  testthat::expect_true(is.matrix(res_named))
  testthat::expect_true(is.numeric(res_named))
  testthat::expect_equal(nrow(res_named), 8)
  testthat::expect_true(ncol(res_named) <= length(css_res_named$clusters))
  testthat::expect_true(all(colnames(res_named) %in% names(css_res_named$clusters)))

  # Try data.frame input to css and formCssDesign

  X_df <- datasets::mtcars

  n <- nrow(X_df)
  y <- stats::rnorm(n)

  selec_inds <- 1:round(n/2)
  fit_inds <- setdiff(1:n, selec_inds)

  css_res_df <- css(X=X_df[selec_inds, ], y=y[selec_inds], lambda=0.01, B = 10)
  res_df <- formCssDesign(css_results=css_res_df, weighting="simple_avg",
                          cutoff=0.7, min_num_clusts=3, max_num_clusts=NA,
                          newx=X_df[fit_inds, ])

  testthat::expect_true(is.matrix(res_df))
  testthat::expect_true(is.numeric(res_df))
  testthat::expect_equal(nrow(res_df), length(fit_inds))
  testthat::expect_true(ncol(res_df) <= length(css_res_df$clusters))
  testthat::expect_true(all(colnames(res_df) %in% names(css_res_df$clusters)))

  # Try again with X as a dataframe with factors (number of columns of final
  # design matrix after one-hot encoding factors won't match number of columns
  # of X_df)
  X_df$cyl <- as.factor(X_df$cyl)
  X_df$vs <- as.factor(X_df$vs)
  X_df$am <- as.factor(X_df$am)
  X_df$gear <- as.factor(X_df$gear)
  X_df$carb <- as.factor(X_df$carb)

  css_res_df <- css(X=X_df[selec_inds, ], y=y[selec_inds], lambda=0.01, B = 10)
  res_df <- formCssDesign(css_results=css_res_df, weighting="weighted_avg",
                          cutoff=0.3, min_num_clusts=1, max_num_clusts=4,
                          newx=X_df[fit_inds, ])

  testthat::expect_true(is.matrix(res_df))
  testthat::expect_true(is.numeric(res_df))
  testthat::expect_equal(nrow(res_df), length(fit_inds))
  testthat::expect_true(ncol(res_df) <= length(css_res_df$clusters))
  testthat::expect_true(all(colnames(res_df) %in% names(css_res_df$clusters)))

  ##### Try other bad inputs

  colnames(x_new) <- NULL

  testthat::expect_error(formCssDesign(css_results=css_res, cutoff=-0.3,
                                       newx=x_new), "cutoff >= 0 is not TRUE",
                         fixed=TRUE)

  testthat::expect_error(formCssDesign(css_results=css_res, cutoff="0.5",
                                       newx=x_new),
                         "is.numeric(cutoff) | is.integer(cutoff) is not TRUE",
                        fixed=TRUE)

  testthat::expect_error(formCssDesign(css_results=css_res,
                                       cutoff=as.numeric(NA), newx=x_new),
                        "!is.na(cutoff) is not TRUE", fixed=TRUE)

  testthat::expect_error(formCssDesign(css_results=css_res,
                                       weighting=c("sparse", "simple_avg"),
                                       newx=x_new),
                         "length(weighting) == 1 is not TRUE", fixed=TRUE)

  testthat::expect_error(formCssDesign(css_results=css_res, weighting=1,
                                       newx=x_new),
                         "Weighting must be a character", fixed=TRUE)

  testthat::expect_error(formCssDesign(css_results=css_res, weighting="spasre",
                                       newx=x_new),
                         "Weighting must be a character and one of sparse, simple_avg, or weighted_avg",
                         fixed=TRUE)

  testthat::expect_error(formCssDesign(css_results=css_res,
                                       min_num_clusts=c(1, 2), newx=x_new),
                         "length(min_num_clusts) == 1 is not TRUE", fixed=TRUE)

  testthat::expect_error(formCssDesign(css_results=css_res, min_num_clusts="3",
                                       newx=x_new),
                         "is.numeric(min_num_clusts) | is.integer(min_num_clusts) is not TRUE",
                         fixed=TRUE)

  testthat::expect_error(formCssDesign(css_results=css_res, min_num_clusts=0,
                                       newx=x_new),
                         "min_num_clusts >= 1 is not TRUE", fixed=TRUE)

  testthat::expect_error(formCssDesign(css_results=css_res, min_num_clusts=6,
                                       newx=x_new),
                         "min_num_clusts <= n_clusters is not TRUE", fixed=TRUE)


  testthat::expect_error(formCssDesign(css_results=css_res, max_num_clusts="4",
                                       newx=x_new),
                         "is.numeric(max_num_clusts) | is.integer(max_num_clusts) is not TRUE",
                         fixed=TRUE)

  testthat::expect_error(formCssDesign(css_results=css_res, max_num_clusts=3.5,
                                       newx=x_new),
                         "max_num_clusts == round(max_num_clusts) is not TRUE",
                         fixed=TRUE)

  testthat::expect_error(formCssDesign(css_results=css_res, min_num_clusts=2,
                                       max_num_clusts=1, newx=x_new),
                         "max_num_clusts >= min_num_clusts is not TRUE",
                         fixed=TRUE)

  testthat::expect_error(formCssDesign(css_results=css_res, max_num_clusts=8,
                                       newx=x_new),
                         "max_num_clusts <= p is not TRUE", fixed=TRUE)

  
})
## Test passed 🥇
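The `formCssDesign()` tests above determine which rows end up in the design matrix:

- A supplied `newx` is always used.
- Otherwise, the rows indexed by `train_inds` in the `X` given to `css()` are used.
- If neither is available, an error is raised.

A hypothetical helper, `chooseDesignRows()`, captures just that dispatch. It assumes that an empty `train_inds` means no observations were set aside. It also coerces `newx` with `as.matrix()`, whereas the package applies `model.matrix()` to data frames.

```r
# Hypothetical helper, not part of the package: decide which rows the
# cluster-representative design matrix should be built from.
chooseDesignRows <- function(X, train_inds, newx = NA){
    # Treat a lone NA (the default) as "newx not provided"
    newx_provided <- !(length(newx) == 1 && all(is.na(newx)))
    if(newx_provided){
        # A supplied newx takes precedence; train_inds are ignored
        return(as.matrix(newx))
    }
    if(length(train_inds) == 0){
        stop("If css was not provided with indices to set aside for model training, then newx must be provided to formCssDesign")
    }
    # Otherwise use the rows of the original X that css() set aside
    X[train_inds, , drop = FALSE]
}

X <- matrix(stats::rnorm(10 * 6), nrow = 10, ncol = 6)
dim(chooseDesignRows(X, train_inds = 6:10))                         # 5 6
dim(chooseDesignRows(X, train_inds = 6:10, newx = matrix(0, 8, 6))) # 8 6
```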

Finally, tests for getCssDesign()

testthat::test_that("getCssDesign works", {
  set.seed(23170)

  x_select <- matrix(stats::rnorm(10*6), nrow=10, ncol=6)
  x_new <- matrix(stats::rnorm(8*6), nrow=8, ncol=6)
  y_select <- stats::rnorm(10)
  y_new <- stats::rnorm(8)

  good_clusters <- list("red"=1:2, "blue"=3:4, "green"=5)

  css_res <- css(X=x_select, y=y_select, lambda=0.01, clusters=good_clusters,
                 B = 10)

  res <- getCssDesign(css_res, newX=x_new)

  testthat::expect_true(is.matrix(res))
  testthat::expect_true(is.numeric(res))
  testthat::expect_equal(nrow(res), 8)
  testthat::expect_equal(ncol(res), length(css_res$clusters))
  testthat::expect_true(all(colnames(res) %in% names(css_res$clusters)))
  testthat::expect_true(all(names(css_res$clusters) %in% colnames(res)))
  
  # Add training indices
  css_res_train <- css(X=x_select, y=y_select, lambda=0.01,
                       clusters=good_clusters, B=10, train_inds=6:10)

  # Training indices should be ignored if new x is provided

  res <- getCssDesign(css_results=css_res_train, weighting="weighted_avg",
                      min_num_clusts=2, newX=x_new)

  testthat::expect_true(is.matrix(res))
  testthat::expect_true(is.numeric(res))
  testthat::expect_equal(nrow(res), 8)
  testthat::expect_equal(ncol(res), length(css_res_train$clusters))
  testthat::expect_true(all(colnames(res) %in% names(css_res_train$clusters)))
  testthat::expect_true(all(names(css_res_train$clusters) %in% colnames(res)))

  # Things should still work if new x is not provided

  res <- getCssDesign(css_results=css_res_train, min_num_clusts=2)

  testthat::expect_true(is.matrix(res))
  testthat::expect_true(is.numeric(res))
  testthat::expect_equal(nrow(res), 5)
  testthat::expect_equal(ncol(res), length(css_res_train$clusters))
  testthat::expect_true(all(colnames(res) %in% names(css_res_train$clusters)))
  testthat::expect_true(all(names(css_res_train$clusters) %in% colnames(res)))

  # Try not providing training indices and omitting newX--should get error
  testthat::expect_error(getCssDesign(css_results=css_res, weighting="sparse",
                                       cutoff=0.5, min_num_clusts=1,
                                       max_num_clusts=5, newX=NA),
                         "css was not provided with indices to set aside for model training (train_inds), so must provide new X in order to generate a design matrix", fixed=TRUE)

  # Try naming variables

  colnames(x_select) <- LETTERS[1:6]
  css_res_named <- css(X=x_select, y=y_select, lambda=0.01,
                       clusters=good_clusters, B = 10)

  # Named variables for css matrix but not new one--should get a warning
  testthat::expect_warning(getCssDesign(css_results=css_res_named,
                                         weighting="simple_avg", cutoff=0.9,
                                         min_num_clusts=1, max_num_clusts=3,
                                         newX=x_new),
                           "New X provided had no variable names (column names) even though the X provided to css did.", fixed=TRUE)

  # Try mismatching variable names
  colnames(x_new) <- LETTERS[2:7]
  testthat::expect_error(getCssDesign(css_results=css_res_named,
                                      weighting="weighted_avg", cutoff=0.2,
                                      min_num_clusts=1, max_num_clusts=1,
                                      newX=x_new),
                         "identical(feat_names, colnames(css_X)) is not TRUE",
                         fixed=TRUE)

  colnames(x_new) <- LETTERS[1:6]

  res_named <- getCssDesign(css_results=css_res_named, weighting="sparse",
                            cutoff=0.5, min_num_clusts=2, max_num_clusts=NA,
                            newX=x_new)

  testthat::expect_true(is.matrix(res_named))
  testthat::expect_true(is.numeric(res_named))
  testthat::expect_equal(nrow(res_named), 8)
  testthat::expect_true(ncol(res_named) <= length(css_res_named$clusters))
  testthat::expect_true(all(colnames(res_named) %in% names(css_res_named$clusters)))

  # Try data.frame input to css and getCssDesign

  X_df <- datasets::mtcars

  n <- nrow(X_df)
  y <- stats::rnorm(n)

  selec_inds <- 1:round(n/2)
  fit_inds <- setdiff(1:n, selec_inds)

  css_res_df <- css(X=X_df[selec_inds, ], y=y[selec_inds], lambda=0.01, B = 10)
  res_df <- getCssDesign(css_results=css_res_df, weighting="simple_avg",
                          cutoff=0.7, min_num_clusts=3, max_num_clusts=NA,
                          newX=X_df[fit_inds, ])

  testthat::expect_true(is.matrix(res_df))
  testthat::expect_true(is.numeric(res_df))
  testthat::expect_equal(nrow(res_df), length(fit_inds))
  testthat::expect_true(ncol(res_df) <= length(css_res_df$clusters))
  testthat::expect_true(all(colnames(res_df) %in% names(css_res_df$clusters)))

  # Try again with X as a dataframe with factors (number of columns of final
  # design matrix after one-hot encoding factors won't match number of columns
  # of X_df)
  X_df$cyl <- as.factor(X_df$cyl)
  X_df$vs <- as.factor(X_df$vs)
  X_df$am <- as.factor(X_df$am)
  X_df$gear <- as.factor(X_df$gear)
  X_df$carb <- as.factor(X_df$carb)

  css_res_df <- css(X=X_df[selec_inds, ], y=y[selec_inds], lambda=0.01, B = 10)
  res_df <- getCssDesign(css_results=css_res_df, weighting="weighted_avg",
                         cutoff=0.3, min_num_clusts=1, max_num_clusts=4,
                         newX=X_df[fit_inds, ])

  testthat::expect_true(is.matrix(res_df))
  testthat::expect_true(is.numeric(res_df))
  testthat::expect_equal(nrow(res_df), length(fit_inds))
  testthat::expect_true(ncol(res_df) <= length(css_res_df$clusters))
  testthat::expect_true(all(colnames(res_df) %in% names(css_res_df$clusters)))

  ##### Try other bad inputs

  colnames(x_new) <- NULL

  testthat::expect_error(getCssDesign(css_results=css_res, cutoff=-0.3,
                                       newX=x_new), "cutoff >= 0 is not TRUE",
                         fixed=TRUE)

  testthat::expect_error(getCssDesign(css_results=css_res, cutoff="0.5",
                                       newX=x_new),
                         "is.numeric(cutoff) | is.integer(cutoff) is not TRUE",
                        fixed=TRUE)

  testthat::expect_error(getCssDesign(css_results=css_res,
                                       cutoff=as.numeric(NA), newX=x_new),
                        "!is.na(cutoff) is not TRUE", fixed=TRUE)

  testthat::expect_error(getCssDesign(css_results=css_res,
                                       weighting=c("sparse", "simple_avg"),
                                       newX=x_new),
                         "length(weighting) == 1 is not TRUE", fixed=TRUE)

  testthat::expect_error(getCssDesign(css_results=css_res, weighting=1,
                                       newX=x_new),
                         "Weighting must be a character", fixed=TRUE)

  testthat::expect_error(getCssDesign(css_results=css_res, weighting="spasre",
                                       newX=x_new),
                         "Weighting must be a character and one of sparse, simple_avg, or weighted_avg",
                         fixed=TRUE)

  testthat::expect_error(getCssDesign(css_results=css_res,
                                       min_num_clusts=c(1, 2), newX=x_new),
                         "length(min_num_clusts) == 1 is not TRUE", fixed=TRUE)

  testthat::expect_error(getCssDesign(css_results=css_res, min_num_clusts="3",
                                       newX=x_new),
                         "is.numeric(min_num_clusts) | is.integer(min_num_clusts) is not TRUE",
                         fixed=TRUE)

  testthat::expect_error(getCssDesign(css_results=css_res, min_num_clusts=0,
                                       newX=x_new),
                         "min_num_clusts >= 1 is not TRUE", fixed=TRUE)

  testthat::expect_error(getCssDesign(css_results=css_res, min_num_clusts=6,
                                       newX=x_new),
                         "min_num_clusts <= n_clusters is not TRUE", fixed=TRUE)


  testthat::expect_error(getCssDesign(css_results=css_res, max_num_clusts="4",
                                       newX=x_new),
                         "is.numeric(max_num_clusts) | is.integer(max_num_clusts) is not TRUE",
                         fixed=TRUE)

  testthat::expect_error(getCssDesign(css_results=css_res, max_num_clusts=3.5,
                                       newX=x_new),
                         "max_num_clusts == round(max_num_clusts) is not TRUE",
                         fixed=TRUE)

  testthat::expect_error(getCssDesign(css_results=css_res, min_num_clusts=2,
                                       max_num_clusts=1, newX=x_new),
                         "max_num_clusts >= min_num_clusts is not TRUE",
                         fixed=TRUE)

  testthat::expect_error(getCssDesign(css_results=css_res, max_num_clusts=8,
                                       newX=x_new),
                         "max_num_clusts <= p is not TRUE", fixed=TRUE)

  
})
## ── Warning ('<text>:63'): getCssDesign works ───────────────────────────────────
## New X provided had no variable names (column names) even though the X provided to css did.
## Backtrace:
##  1. testthat::expect_warning(...)
##  7. litr (local) getCssDesign(...)
##  8. litr (local) checkXInputResults(newX, css_results$X)
## 
## ── Warning ('<text>:63'): getCssDesign works ───────────────────────────────────
## New X provided had no variable names (column names) even though the X provided to css did.
## Backtrace:
##   1. testthat::expect_warning(...)
##   7. litr (local) getCssDesign(...)
##   8. litr (local) formCssDesign(...)
##   9. litr (local) checkFormCssDesignInputs(...)
##  10. litr (local) checkXInputResults(newx, css_results$X)
## 
## ── Warning ('<text>:80'): getCssDesign works ───────────────────────────────────
## New X provided had no variable names (column names) even though the X provided to css did.
## Backtrace:
##  1. litr (local) getCssDesign(...)
##  2. litr (local) checkXInputResults(newX, css_results$X)
## 
## ── Warning ('<text>:80'): getCssDesign works ───────────────────────────────────
## New X provided had no variable names (column names) even though the X provided to css did.
## Backtrace:
##  1. litr (local) getCssDesign(...)
##  2. litr (local) formCssDesign(...)
##  3. litr (local) checkFormCssDesignInputs(...)
##  4. litr (local) checkXInputResults(newx, css_results$X)
## 
## ── Warning ('<text>:101'): getCssDesign works ──────────────────────────────────
## New X provided had no variable names (column names) even though the X provided to css did.
## Backtrace:
##  1. litr (local) getCssDesign(...)
##  2. litr (local) checkXInputResults(newX, css_results$X)
## 
## ── Warning ('<text>:101'): getCssDesign works ──────────────────────────────────
## New X provided had no variable names (column names) even though the X provided to css did.
## Backtrace:
##  1. litr (local) getCssDesign(...)
##  2. litr (local) formCssDesign(...)
##  3. litr (local) checkFormCssDesignInputs(...)
##  4. litr (local) checkXInputResults(newx, css_results$X)
## 
## ── Warning ('<text>:121'): getCssDesign works ──────────────────────────────────
## New X provided had no variable names (column names) even though the X provided to css did.
## Backtrace:
##  1. litr (local) getCssDesign(...)
##  2. litr (local) checkXInputResults(newX, css_results$X)
## 
## ── Warning ('<text>:121'): getCssDesign works ──────────────────────────────────
## New X provided had no variable names (column names) even though the X provided to css did.
## Backtrace:
##  1. litr (local) getCssDesign(...)
##  2. litr (local) formCssDesign(...)
##  3. litr (local) checkFormCssDesignInputs(...)
##  4. litr (local) checkXInputResults(newx, css_results$X)
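The functions tested above take `cutoff`, `min_num_clusts`, and `max_num_clusts` together. A cluster is kept when its selection proportion is at least `cutoff`. This means lowering the cutoff can only add clusters, and raising it can only remove them. Enforcing a minimum therefore amounts to taking more of the top-ranked clusters, and enforcing a maximum amounts to keeping fewer of them.

The sketch below illustrates this. `selectClustersSketch()` and its `clus_props` argument are hypothetical. The sketch breaks ties by sort order, whereas moving an actual cutoff would keep or drop tied clusters together. The package's own selection code may handle ties differently.

```r
# Hypothetical sketch, not the package's selection code. clus_props is a
# named vector of cluster selection proportions; min_num_clusts and
# max_num_clusts are assumed to have been validated already.
selectClustersSketch <- function(clus_props, cutoff = 0, min_num_clusts = 1,
    max_num_clusts = NA){
    # Keep clusters whose selection proportion is at least cutoff
    selected <- names(clus_props)[clus_props >= cutoff]
    # Cluster names from most to least frequently selected (ties in sort order)
    ranked <- names(sort(clus_props, decreasing = TRUE))
    if(length(selected) < min_num_clusts){
        # Too few: a lower cutoff admits the next most frequently selected
        selected <- ranked[seq_len(min_num_clusts)]
    }
    if(!is.na(max_num_clusts) && length(selected) > max_num_clusts){
        # Too many: a higher cutoff keeps only the top clusters
        selected <- ranked[seq_len(max_num_clusts)]
    }
    selected
}

props <- c(red = 0.9, blue = 0.4, green = 0.1, c4 = 0.6)
selectClustersSketch(props, cutoff = 0.5)                      # "red" "c4"
selectClustersSketch(props, cutoff = 0.5, min_num_clusts = 3)  # "red" "c4" "blue"
selectClustersSketch(props, cutoff = 0, max_num_clusts = 1)    # "red"
```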

getCssPreds()

#' Fit model and generate predictions from new data
#'
#' Generate predictions on test data using a cluster stability-selected model.
#' @param css_results An object of class "cssr" (the output of the function
#' css).
#' @param testX A numeric matrix (preferably) or a data.frame (which will
#' be coerced internally to a matrix by the function model.matrix) containing
#' the data that will be used to generate predictions. Must contain the same
#' features (in the same number of columns) as the matrix provided to css, and
#' if the columns of testX are labeled, the names must match the variable names
#' provided to css.
#' @param weighting Character; determines how to calculate the weights to
#' combine features from the selected clusters into weighted averages, called
#' cluster representatives. Must be one of "sparse", "weighted_avg", or
#' "simple_avg". For "sparse", all the weight is put on the most frequently
#' selected individual cluster member (or divided equally among all the cluster
#' members that are tied for the top selection proportion if there is a tie). For
#' "weighted_avg", the weight used for each cluster member is calculated in
#' proportion to the individual selection proportions of each feature. For
#' "simple_avg", each cluster member gets equal weight regardless of the
#' individual feature selection proportions (that is, the cluster representative
#' is just a simple average of all the cluster members). See Faletto and Bien
#' (2022) for details. Default is "weighted_avg".
#' @param cutoff Numeric; getCssPreds will make use only of those clusters with
#' selection proportions of at least cutoff. Must be between 0 and 1. Default
#' is 0 (in which case all clusters are used, or only max_num_clusts of them if
#' max_num_clusts is specified).
#' @param min_num_clusts Integer or numeric; the minimum number of clusters to
#' use regardless of cutoff. (That is, if the chosen cutoff returns fewer than
#' min_num_clusts clusters, the cutoff will be decreased until at least
#' min_num_clusts clusters are selected.) Default is 1.
#' @param max_num_clusts Integer or numeric; the maximum number of clusters to
#' use regardless of cutoff. (That is, if the chosen cutoff returns more than
#' max_num_clusts clusters, the cutoff will be increased until at most
#' max_num_clusts clusters are selected.) Default is NA (in which case
#' max_num_clusts is ignored).
#' @param trainX A numeric matrix (preferably) or a data.frame (which will
#' be coerced internally to a matrix by the function model.matrix) containing
#' the data that will be used to estimate the linear model from the selected
#' clusters. trainX is only necessary to provide if no train_inds were
#' designated in the css function call to set aside observations for model
#' estimation (though even if train_inds was provided, trainX and trainY will be
#' used for model estimation if they are both provided to getCssPreds). Must
#' contain the same features (in the same number of columns) as the matrix
#' provided to css, and if the columns of trainX are labeled, the names must
#' match the variable names provided to css. Default is NA (in which case
#' getCssPreds uses the observations from the train_inds that were provided to
#' css to estimate a linear model).
#' @param trainY The response corresponding to trainX. Must be a real-valued
#' response (unlike in the general css setup) because predictions will be
#' generated by an ordinary least squares model. Must have the same length as
#' the number of rows of trainX. Like trainX, only needs to be provided if no
#' observations were set aside for model estimation by the parameter train_inds
#' in the css function call. Default is NA (in which case getCssPreds uses the
#' observations from the train_inds that were provided to css).
#' @return A vector of predictions corresponding to the observations from testX.
#' @author Gregory Faletto, Jacob Bien
#' @references 
<<faletto2022>>
#' @export
getCssPreds <- function(css_results, testX, weighting="weighted_avg", cutoff=0,
    min_num_clusts=1, max_num_clusts=NA, trainX=NA, trainY=NA){
    # TODO(gregfaletto) Consider adding an argument for a user-provided prediction
    # function in order to allow for more general kinds of predictions than
    # OLS.

    # Check inputs
    
    check_list <- checkGetCssPredsInputs(css_results, testX, weighting, cutoff,
        min_num_clusts, max_num_clusts, trainX, trainY)

    trainXProvided <- check_list$trainXProvided
    trainX <- check_list$trainX
    testX <- check_list$testX
    feat_names <- check_list$feat_names
    max_num_clusts <- check_list$max_num_clusts

    rm(check_list)

    n_train <- nrow(trainX)
    n <- nrow(testX)
    p <- ncol(testX)

    # Take provided training design matrix and testX and turn them into
    # matrices of cluster representatives using information from css_results
    if(trainXProvided){
        train_X_clusters <- formCssDesign(css_results, weighting, cutoff,
            min_num_clusts, max_num_clusts, newx=trainX)
        if(!is.numeric(trainY) & !is.integer(trainY)){
            stop("The provided trainY must be real-valued, because predictions will be generated by ordinary least squares regression.")
        }
        y_train <- trainY
    } else{
        train_X_clusters <- formCssDesign(css_results, weighting, cutoff,
            min_num_clusts, max_num_clusts)
        y_train <- css_results$y[css_results$train_inds]
        if(!is.numeric(y_train) & !is.integer(y_train)){
            stop("Can't generate predictions from the data that was provided to css because the provided y was not real-valued (getCssPreds generates predictions using ordinary least squares regression).")
        }
    }

    stopifnot(length(y_train) == nrow(train_X_clusters))

    testX_clusters <- formCssDesign(css_results, weighting, cutoff,
        min_num_clusts, max_num_clusts, newx=testX)

    stopifnot(ncol(testX_clusters) == ncol(train_X_clusters))

    # Get names for clusters
    clust_X_names <- paste("c_fit_", 1:ncol(testX_clusters), sep="")
    if(!is.null(colnames(train_X_clusters))){
        stopifnot(identical(colnames(train_X_clusters), colnames(testX_clusters)))
        clust_X_names <- colnames(train_X_clusters)
    }

    # Fit linear model on training data via OLS
    if(nrow(train_X_clusters) < ncol(train_X_clusters)){
        err_mess <- paste("css not provided with enough indices to fit OLS model for predictions (number of training indices: ",
            nrow(train_X_clusters), ", number of clusters: ",
            ncol(train_X_clusters),
            "). Try reducing number of clusters by increasing cutoff, or re-run css with a larger number of training indices.",
            sep="")
        stop(err_mess)
    }

    df <- data.frame(y=y_train, train_X_clusters)
    colnames(df)[2:ncol(df)] <- clust_X_names
    model <- stats::lm(y ~., data=df)

    # Use fitted model to generate predictions on testX
    df_test <- data.frame(testX_clusters)
    colnames(df_test) <- clust_X_names
    predictions <- stats::predict.lm(model, newdata=df_test)
    names(predictions) <- NULL

    # Check output
    stopifnot(is.numeric(predictions) | is.integer(predictions))
    stopifnot(length(predictions) == n)
    stopifnot(all(!is.na(predictions)))

    return(predictions)
}
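The three weighting options documented above are easier to see in a short sketch. `clusterRep()`, `makeDesign()`, the toy data, and the selection proportions in `feat_props` are all hypothetical. In the package, these proportions come from `css()` and the representatives are formed by `formCssDesign()`, whose internals may differ (for example, in how they treat a cluster none of whose members was ever selected).

The second half mirrors the final step of `getCssPreds()`: it fits `stats::lm()` on the training representatives and predicts on the test representatives.

```r
# Hypothetical sketch, not the package's formCssDesign(): build one cluster
# representative (a weighted average of the cluster's columns of X) using the
# individual feature selection proportions in feat_props.
clusterRep <- function(X, members, feat_props, weighting = "weighted_avg"){
    props <- feat_props[members]
    w <- switch(weighting,
        # All weight on the most frequently selected member(s); tied members
        # share it equally
        sparse = as.numeric(props == max(props)),
        # Weight proportional to each member's selection proportion
        weighted_avg = props,
        # Equal weight on every member
        simple_avg = rep(1, length(members)),
        stop("weighting must be one of sparse, simple_avg, or weighted_avg"))
    # Assumption: if no member was ever selected, fall back to equal weights
    if(sum(w) == 0){
        w <- rep(1, length(members))
    }
    w <- w / sum(w)
    drop(X[, members, drop = FALSE] %*% w)
}

# Toy end-to-end use: representatives on training and test data, then an
# OLS fit and predictions, in the spirit of getCssPreds()
set.seed(1)
X_train <- matrix(stats::rnorm(20 * 5), nrow = 20, ncol = 5)
X_test <- matrix(stats::rnorm(5 * 5), nrow = 5, ncol = 5)
y_train <- stats::rnorm(20)
clusters <- list(red = 1:2, blue = 3:4, green = 5)
feat_props <- c(0.8, 0.2, 0.5, 0.5, 0.9)  # made-up selection proportions

makeDesign <- function(X){
    sapply(clusters, function(m) clusterRep(X, m, feat_props))
}
train_reps <- makeDesign(X_train)  # 20 x 3, columns red, blue, green
test_reps <- makeDesign(X_test)    # 5 x 3

fit <- stats::lm(y ~ ., data = data.frame(y = y_train, train_reps))
preds <- stats::predict(fit, newdata = data.frame(test_reps))
```

As in `getCssPreds()`, OLS needs at least as many training rows as selected clusters. With fewer rows, `lm()` cannot estimate every coefficient.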

checkGetCssPredsInputs():

#' Helper function to confirm that inputs to the function getCssPreds are as
#' expected, and modify inputs if needed.
#'
#' @param css_results An object of class "cssr" (the output of the function
#' css).
#' @param testX A numeric matrix (preferably) or a data.frame (which will
#' be coerced internally to a matrix by the function model.matrix) containing
#' the data that will be used to generate predictions. Must contain the same
#' features (in the same number of columns) as the matrix provided to css.
#' @param weighting Character; determines how to calculate the weights to
#' combine features from the selected clusters into weighted averages, called
#' cluster representatives. Must be one of "sparse", "weighted_avg", or
#' "simple_avg". For "sparse", all the weight is put on the most frequently
#' selected individual cluster member (or divided equally among all the cluster
#' members that are tied for the top selection proportion if there is a tie). For
#' "weighted_avg", the weight used for each cluster member is calculated in
#' proportion to the individual selection proportions of each feature. For
#' "simple_avg", each cluster member gets equal weight regardless of the
#' individual feature selection proportions (that is, the cluster representative
#' is just a simple average of all the cluster members). See Faletto and Bien
#' (2022) for details. Default is "weighted_avg".
#' @param cutoff Numeric; getCssPreds will use only those clusters with
#' selection proportions of at least cutoff. Must be between 0 and 1.
#' Default is 0 (in which case all clusters are used, or at most max_num_clusts
#' clusters are used if max_num_clusts is specified).
#' @param min_num_clusts Integer or numeric; the minimum number of clusters to
#' use regardless of cutoff. (That is, if the chosen cutoff returns fewer than
#' min_num_clusts clusters, the cutoff will be decreased until at least
#' min_num_clusts clusters are selected.) Default is 1.
#' @param max_num_clusts Integer or numeric; the maximum number of clusters to
#' use regardless of cutoff. (That is, if the chosen cutoff returns more than
#' max_num_clusts clusters, the cutoff will be increased until at most
#' max_num_clusts clusters are selected.) Default is NA (in which case
#' max_num_clusts is ignored).
#' @param trainX A numeric matrix (preferably) or a data.frame (which will
#' be coerced internally to a matrix by the function model.matrix) containing
#' the data that will be used to estimate the linear model from the selected
#' clusters. trainX is only necessary to provide if no train_inds were
#' designated in the css function call to set aside observations for model
#' estimation (though even if train_inds was provided, trainX and trainY will be
#' used for model estimation if they are both provided to getCssPreds). Must 
#' contain the same features (in the same number of columns) as the matrix 
#' provided to css, and if the columns of trainX are labeled, the names must
#' match the variable names provided to css. Default is NA (in which case
#' getCssPreds uses the observations from the train_inds that were provided to
#' css to estimate a linear model).
#' @param trainY The response corresponding to trainX. Must be a real-valued
#' response (unlike in the general css setup) because predictions will be
#' generated by an ordinary least squares model. Must have the same length as
#' the number of rows of trainX. Like trainX, only needs to be provided if no
#' observations were set aside for model estimation by the parameter train_inds
#' in the css function call. Default is NA (in which case getCssPreds uses the
#' observations from the train_inds that were provided to css).
#' @return A named list with the following elements: \item{trainXProvided}{
#' Logical; TRUE if both a valid trainX and a valid trainY were provided (in
#' which case they will be used for model estimation), FALSE if the train_inds
#' observations provided to css will be used instead.} \item{trainX}{
#' The provided trainX matrix, coerced from a data.frame to a matrix if the
#' provided trainX was a data.frame. (If no trainX was provided, this is the
#' matrix of the train_inds observations that were set aside in the call to
#' css.)} \item{testX}{The
#' provided testX matrix, coerced from a data.frame to a matrix if the provided
#' testX was a data.frame.} \item{feat_names}{A character vector containing the
#' column names of testX (if the provided testX had column names). If the
#' provided testX did not have column names, feat_names will be NA.}
#' \item{max_num_clusts}{The provided max_num_clusts, coerced to an integer if
#' needed, and coerced to be less than or equal to the total number of clusters
#' from the output of css_results.}
#' @author Gregory Faletto, Jacob Bien
checkGetCssPredsInputs <- function(css_results, testX, weighting, cutoff,
    min_num_clusts, max_num_clusts, trainX, trainY){
    # Check inputs
    stopifnot(inherits(css_results, "cssr"))

    check_results <- checkNewXProvided(trainX, css_results)

    trainX <- check_results$newX
    trainXProvided <- check_results$newXProvided

    rm(check_results)

    n_train <- nrow(trainX)

    if(trainXProvided){
        if(all(!is.na(trainY)) && length(trainY) > 1){
            stopifnot(is.numeric(trainY))
            stopifnot(n_train == length(trainY))
        } else{
            if(length(css_results$train_inds) == 0){
                stop("css was not provided with indices to set aside for model training (train_inds), so must provide both trainX and trainY in order to generate predictions")
            }
            trainXProvided <- FALSE
            warning("trainX provided but no trainY provided; instead, training model using the train_inds observations provided to css to set aside for model training.")
        }
    } else{
        if(length(css_results$train_inds) == 0){
            stop("css was not provided with indices to set aside for model training (train_inds), so must provide both trainX and trainY in order to generate predictions")
        }
        if(all(!is.na(trainY)) && length(trainY) > 1){
            warning("trainY provided but no trainX provided; instead, training model using the train_inds observations provided to css to set aside for model training.")
        }
    }

    results <- checkXInputResults(testX, css_results$X)

    testX <- results$newx
    feat_names <- results$feat_names

    if(all(!is.na(feat_names))){
        stopifnot(length(feat_names) == ncol(testX))
        stopifnot(!("(Intercept)" %in% feat_names))
        colnames(testX) <- feat_names
    }

    rm(results)

    n <- nrow(testX)
    p <- ncol(testX)

    stopifnot(n >= 1)
    stopifnot(p == ncol(trainX))
    if(!is.null(colnames(trainX)) && is.null(colnames(testX))){
        warning("Column names were provided for trainX but not for testX (are you sure they both contain identical features in the same order?)")
    }
    if(is.null(colnames(trainX)) && !is.null(colnames(testX))){
        warning("Column names were provided for testX but not for trainX (are you sure they both contain identical features in the same order?)")
    }
    if(!is.null(colnames(trainX)) && !is.null(colnames(testX))){
        stopifnot(all(colnames(trainX) == colnames(testX)))
    }

    checkCutoff(cutoff)
    checkWeighting(weighting)
    checkMinNumClusts(min_num_clusts, p, length(css_results$clusters))
    max_num_clusts <- checkMaxNumClusts(max_num_clusts, min_num_clusts, p,
        length(css_results$clusters))

    return(list(trainXProvided=trainXProvided, trainX=trainX, testX=testX,
        feat_names=feat_names, max_num_clusts=max_num_clusts))

}
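The branching in checkGetCssPredsInputs() that decides which observations are used to fit the linear model is easy to lose track of. As a hedged summary, the hypothetical helper chooseTrainingSource() below (not part of the package) reduces it to one decision: trainX and trainY are used only when both are supplied; otherwise the train_inds observations set aside in css() are used, with an error if there are none. In the package itself, the no-trainX case is likely caught earlier, inside checkNewXProvided(), as the error message expected by the tests suggests.

```r
# Hypothetical helper, not part of the package: summarizes which training data
# getCssPreds() ends up using, following the branches of checkGetCssPredsInputs().
#   trainX_provided: logical, stands in for checkNewXProvided()'s newXProvided
#   trainY:          the trainY argument (NA by default)
#   n_train_inds:    length(css_results$train_inds)
chooseTrainingSource <- function(trainX_provided, trainY, n_train_inds) {
  # As in checkGetCssPredsInputs(), trainY only counts as provided if it has
  # no missing values and more than one element
  trainY_provided <- all(!is.na(trainY)) && length(trainY) > 1
  if (trainX_provided && trainY_provided) {
    return("trainX/trainY")
  }
  if (n_train_inds == 0) {
    stop("css was not provided with train_inds, so both trainX and trainY are needed")
  }
  if (trainX_provided || trainY_provided) {
    warning("Only one of trainX and trainY was provided; using the train_inds observations instead")
  }
  "train_inds"
}

chooseTrainingSource(FALSE, NA, 5)
#> [1] "train_inds"
```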

Tests for checkGetCssPredsInputs()

testthat::test_that("checkGetCssPredsInputs works", {
  set.seed(17081)

  x_select <- matrix(stats::rnorm(10*6), nrow=10, ncol=6)
  x_train <- matrix(stats::rnorm(8*6), nrow=8, ncol=6)
  x_pred <- matrix(stats::rnorm(7*6), nrow=7, ncol=6)
  y_select <- stats::rnorm(10)
  y_train <- stats::rnorm(8)

  good_clusters <- list("red"=1:2, "blue"=3:4, "green"=5)

  css_res <- css(X=x_select, y=y_select, lambda=0.01, clusters=good_clusters,
                 B = 10)

  res <- checkGetCssPredsInputs(css_res, testX=x_pred, weighting="simple_avg",
                                cutoff=0.05, min_num_clusts=1,
                                max_num_clusts=NA, trainX=x_train,
                                trainY=y_train)

  testthat::expect_true(is.list(res))
  testthat::expect_identical(names(res), c("trainXProvided", "trainX", "testX",
                                           "feat_names", "max_num_clusts"))
  
  testthat::expect_true(!is.na(res$trainXProvided))
  testthat::expect_equal(length(res$trainXProvided), 1)
  testthat::expect_true(is.logical(res$trainXProvided))
  testthat::expect_true(res$trainXProvided)
  
  testthat::expect_true(all(!is.na(res$trainX)))
  testthat::expect_true(is.matrix(res$trainX))
  testthat::expect_true(is.numeric(res$trainX))
  testthat::expect_equal(nrow(res$trainX), 8)
  testthat::expect_equal(ncol(res$trainX), 6)
  testthat::expect_true(all(abs(x_train - res$trainX) < 10^(-9)))
  
  testthat::expect_true(all(!is.na(res$testX)))
  testthat::expect_true(is.matrix(res$testX))
  testthat::expect_true(is.numeric(res$testX))
  testthat::expect_equal(nrow(res$testX), 7)
  testthat::expect_equal(ncol(res$testX), 6)
  testthat::expect_true(all(abs(x_pred - res$testX) < 10^(-9)))
  
  testthat::expect_true(is.character(res$feat_names))
  testthat::expect_true(is.na(res$feat_names))
  
  testthat::expect_true(is.na(res$max_num_clusts))
  testthat::expect_true(length(res$max_num_clusts) == 1)
  
  ##### Try other bad inputs
  
  testthat::expect_error(checkGetCssPredsInputs(css_res, testX=x_pred,
                                                weighting="weighted_avg",
                                                cutoff=-0.5, min_num_clusts=1,
                                                max_num_clusts=NA,
                                                trainX=x_train, trainY=y_train),
                         "cutoff >= 0 is not TRUE", fixed=TRUE)
  
  testthat::expect_error(checkGetCssPredsInputs(css_res, testX=x_pred,
                                                weighting="sparse",
                                                cutoff="0.3", min_num_clusts=1,
                                                max_num_clusts=NA,
                                                trainX=x_train, trainY=y_train),
                         "is.numeric(cutoff) | is.integer(cutoff) is not TRUE",
                         fixed=TRUE)

  testthat::expect_error(checkGetCssPredsInputs(css_res, testX=x_pred,
                                                weighting="sparse",
                                                cutoff=as.numeric(NA),
                                                min_num_clusts=1,
                                                max_num_clusts=NA,
                                                trainX=x_train, trainY=y_train),
                         "!is.na(cutoff) is not TRUE", fixed=TRUE)

  testthat::expect_error(checkGetCssPredsInputs(css_res, testX=x_pred,
                                                weighting=c("sparse",
                                                            "simple_avg"),
                                                cutoff=0.1,
                                                min_num_clusts=1,
                                                max_num_clusts=NA,
                                                trainX=x_train, trainY=y_train),
                         "length(weighting) == 1 is not TRUE", fixed=TRUE)

  testthat::expect_error(checkGetCssPredsInputs(css_res, testX=x_pred,
                                                weighting=2, cutoff=0.1,
                                                min_num_clusts=1,
                                                max_num_clusts=NA,
                                                trainX=x_train, trainY=y_train),
                         "Weighting must be a character", fixed=TRUE)

  testthat::expect_error(checkGetCssPredsInputs(css_res, testX=x_pred,
                                                weighting="spasre", cutoff=0.1,
                                                min_num_clusts=1,
                                                max_num_clusts=NA,
                                                trainX=x_train, trainY=y_train),
                         "Weighting must be a character and one of sparse, simple_avg, or weighted_avg",
                         fixed=TRUE)

  testthat::expect_error(checkGetCssPredsInputs(css_res, testX=x_pred,
                                                weighting="sparse", cutoff=0.1,
                                                min_num_clusts=c(1, 2),
                                                max_num_clusts=NA,
                                                trainX=x_train, trainY=y_train),
                         "length(min_num_clusts) == 1 is not TRUE", fixed=TRUE)

  testthat::expect_error(checkGetCssPredsInputs(css_res, testX=x_pred,
                                                weighting="weighted_avg",
                                                cutoff=0.1, min_num_clusts="2",
                                                max_num_clusts=NA,
                                                trainX=x_train, trainY=y_train),
                         "is.numeric(min_num_clusts) | is.integer(min_num_clusts) is not TRUE",
                         fixed=TRUE)

  testthat::expect_error(checkGetCssPredsInputs(css_res, testX=x_pred,
                                                weighting="simple_avg",
                                                cutoff=0.1, min_num_clusts=0,
                                                max_num_clusts=NA,
                                                trainX=x_train, trainY=y_train),
                         "min_num_clusts >= 1 is not TRUE", fixed=TRUE)

  testthat::expect_error(checkGetCssPredsInputs(css_res, testX=x_pred,
                                                weighting="weighted_avg",
                                                cutoff=0.1, min_num_clusts=10,
                                                max_num_clusts=NA,
                                                trainX=x_train, trainY=y_train),
                         "min_num_clusts <= p is not TRUE", fixed=TRUE)


  testthat::expect_error(checkGetCssPredsInputs(css_res, testX=x_pred,
                                                weighting="simple_avg",
                                                cutoff=0.1, min_num_clusts=1,
                                                max_num_clusts="5",
                                                trainX=x_train, trainY=y_train),
                         "is.numeric(max_num_clusts) | is.integer(max_num_clusts) is not TRUE",
                         fixed=TRUE)

  testthat::expect_error(checkGetCssPredsInputs(css_res, testX=x_pred,
                                                weighting="sparse",
                                                cutoff=0.1, min_num_clusts=1,
                                                max_num_clusts=4.5,
                                                trainX=x_train, trainY=y_train),
                         "max_num_clusts == round(max_num_clusts) is not TRUE",
                         fixed=TRUE)

  testthat::expect_error(checkGetCssPredsInputs(css_res, testX=x_pred,
                                                weighting="sparse",
                                                cutoff=0.1, min_num_clusts=3,
                                                max_num_clusts=2,
                                                trainX=x_train, trainY=y_train),
                         "max_num_clusts >= min_num_clusts is not TRUE",
                         fixed=TRUE)

  testthat::expect_error(checkGetCssPredsInputs(css_res, testX=x_pred,
                                                weighting="sparse",
                                                cutoff=0.1, min_num_clusts=1,
                                                max_num_clusts=10,
                                                trainX=x_train, trainY=y_train),
                         "max_num_clusts <= p is not TRUE", fixed=TRUE)

  # Add training indices
  css_res_train <- css(X=x_select, y=y_select, lambda=0.01,
                       clusters=good_clusters, B=10, train_inds=6:10)

  # Training indices should be ignored if new x is provided

  res <- checkGetCssPredsInputs(css_res_train, testX=x_pred,
                                weighting="weighted_avg",
                                cutoff=0, min_num_clusts=1,
                                max_num_clusts=NA, trainX=x_train,
                                trainY=y_train)

  testthat::expect_true(is.list(res))
  testthat::expect_identical(names(res), c("trainXProvided", "trainX", "testX",
                                           "feat_names", "max_num_clusts"))

  testthat::expect_true(!is.na(res$trainXProvided))
  testthat::expect_equal(length(res$trainXProvided), 1)
  testthat::expect_true(is.logical(res$trainXProvided))
  testthat::expect_true(res$trainXProvided)

  testthat::expect_true(all(!is.na(res$trainX)))
  testthat::expect_true(is.matrix(res$trainX))
  testthat::expect_true(is.numeric(res$trainX))
  testthat::expect_equal(nrow(res$trainX), 8)
  testthat::expect_equal(ncol(res$trainX), 6)
  testthat::expect_true(all(abs(x_train - res$trainX) < 10^(-9)))

  testthat::expect_true(all(!is.na(res$testX)))
  testthat::expect_true(is.matrix(res$testX))
  testthat::expect_true(is.numeric(res$testX))
  testthat::expect_equal(nrow(res$testX), 7)
  testthat::expect_equal(ncol(res$testX), 6)
  testthat::expect_true(all(abs(x_pred - res$testX) < 10^(-9)))

  testthat::expect_true(is.character(res$feat_names))
  testthat::expect_true(is.na(res$feat_names))

  testthat::expect_true(is.na(res$max_num_clusts))
  testthat::expect_true(length(res$max_num_clusts) == 1)

  # Things should still work if new x is not provided
  
  res <- checkGetCssPredsInputs(css_res_train, testX=x_pred,
                                weighting="weighted_avg",
                                cutoff=0, min_num_clusts=1,
                                max_num_clusts=NA, trainX=NA, trainY=NA)

  testthat::expect_true(is.list(res))
  testthat::expect_identical(names(res), c("trainXProvided", "trainX", "testX",
                                           "feat_names", "max_num_clusts"))

  testthat::expect_true(!is.na(res$trainXProvided))
  testthat::expect_equal(length(res$trainXProvided), 1)
  testthat::expect_true(is.logical(res$trainXProvided))
  testthat::expect_true(!res$trainXProvided)

  testthat::expect_true(all(!is.na(res$trainX)))
  testthat::expect_true(is.matrix(res$trainX))
  testthat::expect_true(is.numeric(res$trainX))
  testthat::expect_equal(nrow(res$trainX), 5)
  testthat::expect_equal(ncol(res$trainX), 6)
  testthat::expect_true(all(abs(x_select[6:10, ] - res$trainX) < 10^(-9)))

  testthat::expect_true(all(!is.na(res$testX)))
  testthat::expect_true(is.matrix(res$testX))
  testthat::expect_true(is.numeric(res$testX))
  testthat::expect_equal(nrow(res$testX), 7)
  testthat::expect_equal(ncol(res$testX), 6)
  testthat::expect_true(all(abs(x_pred - res$testX) < 10^(-9)))

  testthat::expect_true(is.character(res$feat_names))
  testthat::expect_true(is.na(res$feat_names))

  testthat::expect_true(is.na(res$max_num_clusts))
  testthat::expect_true(length(res$max_num_clusts) == 1)


  # Try not providing training indices and omitting trainX and trainY--should get error
  testthat::expect_error(checkGetCssPredsInputs(css_res, testX=x_pred,
                                weighting="sparse",
                                cutoff=0, min_num_clusts=1,
                                max_num_clusts=NA, trainX=NA, trainY=NA),
                         "css was not provided with indices to set aside for model training (train_inds), so must provide new X in order to generate a design matrix", fixed=TRUE)

  # Try naming variables

  colnames(x_select) <- LETTERS[1:6]
  css_res_named <- css(X=x_select, y=y_select, lambda=0.01,
                       clusters=good_clusters, B = 10)

  # Named variables for css matrix but not new one--should get a warning
  testthat::expect_warning(checkGetCssPredsInputs(css_res_named, testX=x_pred,
                                weighting="simple_avg", cutoff=0,
                                min_num_clusts=1, max_num_clusts=NA,
                                trainX=x_train, trainY=y_train),
                           "New X provided had no variable names (column names) even though the X provided to css did.", fixed=TRUE)

  # Try mismatching variable names
  colnames(x_train) <- LETTERS[2:7]
  colnames(x_pred) <- LETTERS[1:6]
  testthat::expect_error(checkGetCssPredsInputs(css_res_named, testX=x_pred,
                                weighting="weighted_avg", cutoff=0,
                                min_num_clusts=1, max_num_clusts=NA,
                                trainX=x_train, trainY=y_train),
                         "identical(feat_names, colnames(css_X)) is not TRUE",
                         fixed=TRUE)

  colnames(x_train) <- LETTERS[1:6]
  colnames(x_pred) <- LETTERS[2:7]
  testthat::expect_error(checkGetCssPredsInputs(css_res_named, testX=x_pred,
                                weighting="sparse", cutoff=0,
                                min_num_clusts=1, max_num_clusts=NA,
                                trainX=x_train, trainY=y_train),
                         "identical(feat_names, colnames(css_X)) is not TRUE",
                         fixed=TRUE)

  colnames(x_pred) <- LETTERS[1:6]

  res_named <- checkGetCssPredsInputs(css_res_named, testX=x_pred,
                                      weighting="simple_avg", cutoff=0,
                                      min_num_clusts=1, max_num_clusts=NA,
                                      trainX=x_train, trainY=y_train)

  testthat::expect_true(is.list(res_named))
  testthat::expect_identical(names(res_named), c("trainXProvided", "trainX", "testX",
                                           "feat_names", "max_num_clusts"))

  testthat::expect_true(all(!is.na(res_named$trainX)))
  testthat::expect_true(is.matrix(res_named$trainX))
  testthat::expect_true(is.numeric(res_named$trainX))
  testthat::expect_equal(nrow(res_named$trainX), 8)
  testthat::expect_equal(ncol(res_named$trainX), 6)
  testthat::expect_true(all(abs(x_train - res_named$trainX) < 10^(-9)))
  
  testthat::expect_true(is.character(res_named$feat_names))
  testthat::expect_identical(res_named$feat_names, LETTERS[1:6])

  # Try data.frame input to css and checkGetCssPredsInputs

  X_df <- datasets::mtcars

  n <- nrow(X_df)
  y <- stats::rnorm(n)

  selec_inds <- 1:round(n/3)
  train_inds <- (max(selec_inds) + 1):(2*round(n/3))
  test_inds <- setdiff(1:n, c(selec_inds, train_inds))

  css_res_df <- css(X=X_df[c(selec_inds, train_inds), ],
                    y=y[c(selec_inds, train_inds)], lambda=0.01, B = 10,
                    train_inds=train_inds)
  
  res_df <- checkGetCssPredsInputs(css_res_df, testX=X_df[test_inds, ],
                                   weighting="sparse", cutoff=0,
                                   min_num_clusts=1, max_num_clusts=NA,
                                   trainX=NA, trainY=NA)
  
  testthat::expect_true(is.list(res_df))
  testthat::expect_identical(names(res_df), c("trainXProvided", "trainX",
                                              "testX","feat_names",
                                              "max_num_clusts"))
  
  testthat::expect_true(all(!is.na(res_df$trainX)))
  testthat::expect_true(is.matrix(res_df$trainX))
  testthat::expect_true(is.numeric(res_df$trainX))
  testthat::expect_equal(nrow(res_df$trainX), length(train_inds))
  
  stopifnot(nrow(css_res_df$X) >= max(train_inds))
  train_mat <- css_res_df$X[train_inds, ]

  testthat::expect_equal(ncol(res_df$trainX), ncol(train_mat))
  testthat::expect_true(all(abs(train_mat - res_df$trainX) < 10^(-9)))
  testthat::expect_identical(colnames(res_df$trainX), colnames(train_mat))

  testthat::expect_true(all(!is.na(res_df$testX)))
  testthat::expect_true(is.matrix(res_df$testX))
  testthat::expect_true(is.numeric(res_df$testX))
  testthat::expect_equal(nrow(res_df$testX), length(test_inds))
  
  test_mat <- stats::model.matrix(~ ., X_df[test_inds, ])
  test_mat <- test_mat[, colnames(test_mat) != "(Intercept)"]
  
  testthat::expect_equal(ncol(res_df$testX), ncol(test_mat))
  testthat::expect_true(all(abs(test_mat - res_df$testX) < 10^(-9)))
  testthat::expect_identical(colnames(res_df$testX), colnames(test_mat))
  testthat::expect_identical(colnames(res_df$testX), colnames(res_df$trainX))

  # Try again with X as a dataframe with factors (number of columns of final
  # design matrix after one-hot encoding factors won't match number of columns
  # of X_df)
  X_df$cyl <- as.factor(X_df$cyl)
  X_df$vs <- as.factor(X_df$vs)
  X_df$am <- as.factor(X_df$am)
  X_df$gear <- as.factor(X_df$gear)
  X_df$carb <- as.factor(X_df$carb)

  css_res_df <- css(X=X_df[selec_inds, ], y=y[selec_inds], lambda=0.01, B = 10)
  res_df <- checkGetCssPredsInputs(css_res_df, testX=X_df[test_inds, ],
                                   weighting="simple_avg", cutoff=0.3,
                                   min_num_clusts=1, max_num_clusts=4,
                                   trainX=X_df[train_inds, ],
                                   trainY=y[train_inds])
  
  testthat::expect_true(is.list(res_df))
  testthat::expect_identical(names(res_df), c("trainXProvided", "trainX",
                                              "testX","feat_names",
                                              "max_num_clusts"))
  
  testthat::expect_true(all(!is.na(res_df$trainX)))
  testthat::expect_true(is.matrix(res_df$trainX))
  testthat::expect_true(is.numeric(res_df$trainX))
  testthat::expect_equal(nrow(res_df$trainX), length(train_inds))
  
  train_mat <- stats::model.matrix(~ ., X_df[train_inds, ])
  train_mat <- train_mat[, colnames(train_mat) != "(Intercept)"]

  testthat::expect_equal(ncol(res_df$trainX), ncol(train_mat))
  testthat::expect_true(all(abs(train_mat - res_df$trainX) < 10^(-9)))

  testthat::expect_true(all(!is.na(res_df$testX)))
  testthat::expect_true(is.matrix(res_df$testX))
  testthat::expect_true(is.numeric(res_df$testX))
  testthat::expect_equal(nrow(res_df$testX), length(test_inds))
  
  test_mat <- stats::model.matrix(~ ., X_df[test_inds, ])
  test_mat <- test_mat[, colnames(test_mat) != "(Intercept)"]
  
  testthat::expect_equal(ncol(res_df$testX), ncol(test_mat))
  testthat::expect_true(all(abs(test_mat - res_df$testX) < 10^(-9)))

  
})
## ── Warning ('<text>:251'): checkGetCssPredsInputs works ────────────────────────
## New X provided had no variable names (column names) even though the X provided to css did.
## Backtrace:
##  1. testthat::expect_warning(...)
##  7. litr (local) checkGetCssPredsInputs(...)
##  8. litr (local) checkXInputResults(testX, css_results$X)
## 
## ── Warning ('<text>:278'): checkGetCssPredsInputs works ────────────────────────
## Column names were provided for testX but not for trainX (are you sure they both contain identical features in the same order?)
## Backtrace:
##  1. litr (local) checkGetCssPredsInputs(...)
## 
## ── Warning ('<text>:357'): checkGetCssPredsInputs works ────────────────────────
## Column names were provided for testX but not for trainX (are you sure they both contain identical features in the same order?)
## Backtrace:
##  1. litr (local) checkGetCssPredsInputs(...)
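The "Column names were provided for testX but not for trainX" warnings in the log above come from the three column-name checks near the end of checkGetCssPredsInputs(). Isolated as a hypothetical standalone function (the name checkColnameConsistency() is illustrative, not part of the package), the checks look like this. Since the three cases are mutually exclusive, the sketch chains them with `else if`, which behaves the same as the original's separate `if` statements.

```r
# Hypothetical standalone version of the trainX/testX column-name checks in
# checkGetCssPredsInputs(); not part of the package.
checkColnameConsistency <- function(trainX, testX) {
  tr <- colnames(trainX)
  te <- colnames(testX)
  if (!is.null(tr) && is.null(te)) {
    # Names on trainX only: cannot verify the features line up, so warn
    warning("Column names were provided for trainX but not for testX")
  } else if (is.null(tr) && !is.null(te)) {
    # Names on testX only: same concern in the other direction
    warning("Column names were provided for testX but not for trainX")
  } else if (!is.null(tr) && !is.null(te)) {
    # Both named: the names must match position by position
    stopifnot(all(tr == te))
  }
  invisible(TRUE)
}
```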

Finally, tests for getCssPreds()

testthat::test_that("getCssPreds works", {
  set.seed(70811)

  x_select <- matrix(stats::rnorm(10*6), nrow=10, ncol=6)
  x_train <- matrix(stats::rnorm(8*6), nrow=8, ncol=6)
  x_pred <- matrix(stats::rnorm(7*6), nrow=7, ncol=6)
  y_select <- stats::rnorm(10)
  y_train <- stats::rnorm(8)

  good_clusters <- list("red"=1:2, "blue"=3:4, "green"=5)

  css_res <- css(X=x_select, y=y_select, lambda=0.01, clusters=good_clusters,
                 B = 10)

  res <- getCssPreds(css_res, testX=x_pred, trainX=x_train, trainY=y_train)

  testthat::expect_true(all(!is.na(res)))
  testthat::expect_true(is.numeric(res))
  testthat::expect_equal(length(res), 7)
  
  ##### Try other bad inputs

  testthat::expect_error(getCssPreds(css_res, testX=x_pred, cutoff=-0.5,
                                     trainX=x_train, trainY=y_train),
                         "cutoff >= 0 is not TRUE", fixed=TRUE)

  testthat::expect_error(getCssPreds(css_res, testX=x_pred, cutoff="0.3",
                                     trainX=x_train, trainY=y_train),
                         "is.numeric(cutoff) | is.integer(cutoff) is not TRUE",
                         fixed=TRUE)

  testthat::expect_error(getCssPreds(css_res, testX=x_pred,
                                     cutoff=as.numeric(NA), trainX=x_train,
                                     trainY=y_train),
                         "!is.na(cutoff) is not TRUE", fixed=TRUE)

  testthat::expect_error(getCssPreds(css_res, testX=x_pred,
                                     weighting=c("sparse", "simple_avg"),
                                     trainX=x_train, trainY=y_train),
                         "length(weighting) == 1 is not TRUE", fixed=TRUE)

  testthat::expect_error(getCssPreds(css_res, testX=x_pred, weighting=2,
                                     trainX=x_train, trainY=y_train),
                         "Weighting must be a character", fixed=TRUE)

  testthat::expect_error(getCssPreds(css_res, testX=x_pred, weighting="spasre",
                                     trainX=x_train, trainY=y_train),
                         "Weighting must be a character and one of sparse, simple_avg, or weighted_avg",
                         fixed=TRUE)

  testthat::expect_error(getCssPreds(css_res, testX=x_pred,
                                     min_num_clusts=c(1, 2), trainX=x_train,
                                     trainY=y_train),
                         "length(min_num_clusts) == 1 is not TRUE", fixed=TRUE)

  testthat::expect_error(getCssPreds(css_res, testX=x_pred, min_num_clusts="2",
                                     trainX=x_train, trainY=y_train),
                         "is.numeric(min_num_clusts) | is.integer(min_num_clusts) is not TRUE",
                         fixed=TRUE)

  testthat::expect_error(getCssPreds(css_res, testX=x_pred, min_num_clusts=0,
                                     trainX=x_train, trainY=y_train),
                         "min_num_clusts >= 1 is not TRUE", fixed=TRUE)

  testthat::expect_error(getCssPreds(css_res, testX=x_pred, min_num_clusts=10,
                                     trainX=x_train, trainY=y_train),
                         "min_num_clusts <= p is not TRUE", fixed=TRUE)


  testthat::expect_error(getCssPreds(css_res, testX=x_pred, max_num_clusts="5",
                                     trainX=x_train, trainY=y_train),
                         "is.numeric(max_num_clusts) | is.integer(max_num_clusts) is not TRUE",
                         fixed=TRUE)

  testthat::expect_error(getCssPreds(css_res, testX=x_pred, max_num_clusts=4.5,
                                     trainX=x_train, trainY=y_train),
                         "max_num_clusts == round(max_num_clusts) is not TRUE",
                         fixed=TRUE)

  testthat::expect_error(getCssPreds(css_res, testX=x_pred, min_num_clusts=3,
                                     max_num_clusts=2, trainX=x_train,
                                     trainY=y_train),
                         "max_num_clusts >= min_num_clusts is not TRUE",
                         fixed=TRUE)

  testthat::expect_error(getCssPreds(css_res, testX=x_pred, max_num_clusts=10,
                                     trainX=x_train, trainY=y_train),
                         "max_num_clusts <= p is not TRUE", fixed=TRUE)

  # Add training indices
  css_res_train <- css(X=x_select, y=y_select, lambda=0.01,
                       clusters=good_clusters, B=10, train_inds=6:10)

  # Training indices should be ignored if new x is provided

  res <- getCssPreds(css_res_train, testX=x_pred, trainX=x_train,
                     trainY=y_train)

  testthat::expect_true(all(!is.na(res)))
  testthat::expect_true(is.numeric(res))
  testthat::expect_equal(length(res), 7)
  
  # Things should still work if new x is not provided

  res <- getCssPreds(css_res_train, testX=x_pred)
  
  testthat::expect_true(all(!is.na(res)))
  testthat::expect_true(is.numeric(res))
  testthat::expect_equal(length(res), 7)

  # Try not providing training indices and omitting trainX and trainY--should get error
  testthat::expect_error(getCssPreds(css_res, testX=x_pred),
                         "css was not provided with indices to set aside for model training (train_inds), so must provide new X in order to generate a design matrix", fixed=TRUE)

  # Try naming variables

  colnames(x_select) <- LETTERS[1:6]
  css_res_named <- css(X=x_select, y=y_select, lambda=0.01,
                       clusters=good_clusters, B = 10)

  # Named variables for css matrix but not new one--should get a warning
  testthat::expect_warning(getCssPreds(css_res_named, testX=x_pred,
                                       trainX=x_train, trainY=y_train),
                           "New X provided had no variable names (column names) even though the X provided to css did.", fixed=TRUE)

  # Try mismatching variable names
  colnames(x_train) <- LETTERS[2:7]
  colnames(x_pred) <- LETTERS[1:6]
  testthat::expect_error(getCssPreds(css_res_named, testX=x_pred,
                                     trainX=x_train, trainY=y_train),
                         "identical(feat_names, colnames(css_X)) is not TRUE",
                         fixed=TRUE)

  colnames(x_train) <- LETTERS[1:6]
  colnames(x_pred) <- LETTERS[2:7]
  testthat::expect_error(getCssPreds(css_res_named, testX=x_pred,
                                     trainX=x_train, trainY=y_train),
                         "identical(feat_names, colnames(css_X)) is not TRUE",
                         fixed=TRUE)

  colnames(x_pred) <- LETTERS[1:6]

  res_named <- getCssPreds(css_res_named, testX=x_pred, trainX=x_train,
                           trainY=y_train)

  testthat::expect_true(all(!is.na(res_named)))
  testthat::expect_true(is.numeric(res_named))
  testthat::expect_equal(length(res_named), 7)

  # Try data.frame input to css and getCssPreds

  X_df <- datasets::mtcars

  n <- nrow(X_df)
  y <- stats::rnorm(n)

  selec_inds <- 1:round(n/3)
  train_inds <- (max(selec_inds) + 1):(max(selec_inds) + 17)
  test_inds <- setdiff(1:n, c(selec_inds, train_inds))

  css_res_df <- css(X=X_df[c(selec_inds, train_inds), ],
                    y=y[c(selec_inds, train_inds)], lambda=0.01, B = 10,
                    train_inds=train_inds)

  res_df <- getCssPreds(css_res_df, testX=X_df[test_inds, ])

  testthat::expect_true(all(!is.na(res_df)))
  testthat::expect_true(is.numeric(res_df))
  testthat::expect_equal(length(res_df), length(test_inds))
  
  # Try again with X as a data.frame with factors (after one-hot encoding the
  # factors, the number of columns of the final design matrix won't match the
  # number of columns of X_df)
  X_df$cyl <- as.factor(X_df$cyl)
  X_df$vs <- as.factor(X_df$vs)
  X_df$am <- as.factor(X_df$am)
  X_df$gear <- as.factor(X_df$gear)
  X_df$carb <- as.factor(X_df$carb)

  css_res_df <- css(X=X_df[selec_inds, ], y=y[selec_inds], lambda=0.01, B = 10)
  res_df <- getCssPreds(css_res_df, testX=X_df[test_inds, ],
                        trainX=X_df[train_inds, ], trainY=y[train_inds])
  
  # TODO(gregfaletto): known issue--the above code produces the following
  # undesired warnings:
  # 1: In checkGetCssPredsInputs(css_results, testX, weighting, cutoff,  :
  #    Column names were provided for testX but not for trainX (are you sure
  #    they both contain identical features in the same order?)
  # 2: In checkXInputResults(newx, css_results$X) :
  #    New X provided had no variable names (column names) even though the X
  #    provided to css did.

  testthat::expect_true(all(!is.na(res_df)))
  testthat::expect_true(is.numeric(res_df))
  testthat::expect_equal(length(res_df), length(test_inds))

})
## ── Warning ('<text>:122'): getCssPreds works ───────────────────────────────────
## New X provided had no variable names (column names) even though the X provided to css did.
## Backtrace:
##  1. testthat::expect_warning(...)
##  7. litr (local) getCssPreds(...)
##  8. litr (local) checkGetCssPredsInputs(...)
##  9. litr (local) checkXInputResults(testX, css_results$X)
## 
## ── Warning ('<text>:122'): getCssPreds works ───────────────────────────────────
## New X provided had no variable names (column names) even though the X provided to css did.
## Backtrace:
##   1. testthat::expect_warning(...)
##   7. litr (local) getCssPreds(...)
##   8. litr (local) formCssDesign(...)
##   9. litr (local) checkFormCssDesignInputs(...)
##  10. litr (local) checkXInputResults(newx, css_results$X)
## 
## ── Warning ('<text>:122'): getCssPreds works ───────────────────────────────────
## New X provided had no variable names (column names) even though the X provided to css did.
## Backtrace:
##   1. testthat::expect_warning(...)
##   7. litr (local) getCssPreds(...)
##   8. litr (local) formCssDesign(...)
##   9. litr (local) checkFormCssDesignInputs(...)
##  10. litr (local) checkXInputResults(newx, css_results$X)
## 
## ── Warning ('<text>:143'): getCssPreds works ───────────────────────────────────
## Column names were provided for testX but not for trainX (are you sure they both contain identical features in the same order?)
## Backtrace:
##  1. litr (local) getCssPreds(...)
##  2. litr (local) checkGetCssPredsInputs(...)
## 
## ── Warning ('<text>:143'): getCssPreds works ───────────────────────────────────
## New X provided had no variable names (column names) even though the X provided to css did.
## Backtrace:
##  1. litr (local) getCssPreds(...)
##  2. litr (local) formCssDesign(...)
##  3. litr (local) checkFormCssDesignInputs(...)
##  4. litr (local) checkXInputResults(newx, css_results$X)
## 
## ── Warning ('<text>:181'): getCssPreds works ───────────────────────────────────
## Column names were provided for testX but not for trainX (are you sure they both contain identical features in the same order?)
## Backtrace:
##  1. litr (local) getCssPreds(...)
##  2. litr (local) checkGetCssPredsInputs(...)
## 
## ── Warning ('<text>:181'): getCssPreds works ───────────────────────────────────
## New X provided had no variable names (column names) even though the X provided to css did.
## Backtrace:
##  1. litr (local) getCssPreds(...)
##  2. litr (local) formCssDesign(...)
##  3. litr (local) checkFormCssDesignInputs(...)
##  4. litr (local) checkXInputResults(newx, css_results$X)
## 
## ── Warning ('<text>:181'): getCssPreds works ───────────────────────────────────
## prediction from a rank-deficient fit may be misleading
## Backtrace:
##  1. litr (local) getCssPreds(...)
##  2. stats::predict.lm(model, newdata = df_test)
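To make concrete what the tests above check, here is a minimal base-R sketch of the pipeline that getCssPreds() is described as performing. It forms cluster representatives on the training and test data, fits ordinary least squares on the training representatives, and then predicts on the test representatives. This is not the package's implementation. The function name sketchCssPreds() is hypothetical, the clusters are assumed to be already selected and supplied as a list of column indices, and each representative is assumed to be the unweighted mean of its cluster members. getCssPreds() uses a user-selected weighting scheme instead.

```r
# Hypothetical sketch (not the package's code) of the getCssPreds() pipeline.
# Assumptions: `clusters` is a list of already-selected column-index vectors,
# and each cluster representative is the unweighted mean of its members.
sketchCssPreds <- function(clusters, trainX, trainY, testX) {
  stopifnot(is.list(clusters), length(clusters) >= 1)
  stopifnot(ncol(trainX) == ncol(testX), nrow(trainX) == length(trainY))

  # Build a data.frame with one column per cluster: the row-wise mean of the
  # cluster's member columns
  makeReps <- function(X) {
    reps <- vapply(clusters,
                   function(inds) rowMeans(X[, inds, drop = FALSE]),
                   numeric(nrow(X)))
    # vapply() returns a vector rather than a matrix when nrow(X) == 1
    reps <- matrix(reps, nrow = nrow(X))
    colnames(reps) <- paste0("clust", seq_along(clusters))
    as.data.frame(reps)
  }
  df_train <- makeReps(trainX)
  df_test <- makeReps(testX)

  # Estimate a linear model by ordinary least squares on the training
  # representatives
  df_train$y <- trainY
  model <- stats::lm(y ~ ., data = df_train)

  # One prediction per row of testX
  as.numeric(stats::predict(model, newdata = df_test))
}

# Usage on simulated data with the same test shape as above (7 test rows,
# 6 features)
set.seed(1)
x_train <- matrix(stats::rnorm(20 * 6), nrow = 20, ncol = 6)
x_test <- matrix(stats::rnorm(7 * 6), nrow = 7, ncol = 6)
clusters <- list(1:3, 4:5, 6)
# Response that is exactly linear in the cluster representatives
y_train <- 2 * rowMeans(x_train[, 1:3]) - x_train[, 6]

preds <- sketchCssPreds(clusters, trainX = x_train, trainY = y_train,
                        testX = x_test)
```

Because the simulated response is exactly linear in the representatives, the sketch recovers it on the test rows. Like the tests, it returns one non-missing numeric prediction per test row. If there are fewer training rows than model coefficients, or if some representatives are collinear or constant on the training rows, lm() returns a rank-deficient fit, and predict.lm() may warn about it. The exact condition and message depend on the R version. This is the kind of warning reported at the end of the test output above.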
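Several expected error strings in the tests above have exactly the form that base R's stopifnot() produces when a single scalar condition fails: the deparsed condition followed by "is not TRUE". Examples include "max_num_clusts >= min_num_clusts is not TRUE" and "identical(feat_names, colnames(css_X)) is not TRUE". This suggests that the input checks are written with stopifnot(). The sketch below reproduces the three cluster-count messages tested above. checkNumClustsSketch() and errorMessage() are hypothetical helpers, and the package's actual checks may differ in order and detail.

```r
# Hypothetical input check (not the package's code): each failed stopifnot()
# condition produces the message "<deparsed condition> is not TRUE".
checkNumClustsSketch <- function(min_num_clusts = 1, max_num_clusts = NA, p) {
  if (!is.na(max_num_clusts)) {
    stopifnot(max_num_clusts == round(max_num_clusts))
    stopifnot(max_num_clusts >= min_num_clusts)
    stopifnot(max_num_clusts <= p)
  }
  invisible(TRUE)
}

# Hypothetical helper: return the error message raised while evaluating
# `expr`, or NA if no error occurs (the promise is forced inside tryCatch)
errorMessage <- function(expr) {
  tryCatch({
    expr
    NA_character_
  }, error = function(e) conditionMessage(e))
}

# Non-integer maximum: "max_num_clusts == round(max_num_clusts) is not TRUE"
msg_round <- errorMessage(checkNumClustsSketch(max_num_clusts = 2.5, p = 6))
# Maximum below minimum: "max_num_clusts >= min_num_clusts is not TRUE"
msg_order <- errorMessage(checkNumClustsSketch(min_num_clusts = 3,
                                               max_num_clusts = 2, p = 6))
# Maximum above the number of features: "max_num_clusts <= p is not TRUE"
msg_p <- errorMessage(checkNumClustsSketch(max_num_clusts = 10, p = 6))
# Valid input: NA (no error)
msg_ok <- errorMessage(checkNumClustsSketch(max_num_clusts = 3, p = 6))
```

Because stopifnot() deparses the failing expression, these messages are stable enough to match with fixed=TRUE in testthat::expect_error(). Renaming a variable inside a check, however, changes the message and breaks the corresponding test.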