## ----setup, include = FALSE---------------------------------------------------
# Chunk defaults for the rendered document: suppress messages and warnings,
# print output without a comment prefix, and use a 6.25 x 5 figure size.
knitr::opts_chunk$set(message = FALSE, warning = FALSE, comment = NA, 
                      fig.width = 6.25, fig.height = 5)
library(ANCOMBC)
library(tidyverse)
library(mia)
library(caret)
library(DT)
# Default DT options: after a table is initialised, a JavaScript callback
# restyles its header to white text ('#fff') on a black background ('#000').
options(DT.options = list(
  initComplete = JS("function(settings, json) {",
  "$(this.api().table().header()).css({'background-color': 
  '#000', 'color': '#fff'});","}")))

## ----getPackage, eval=FALSE---------------------------------------------------
#  if (!requireNamespace("BiocManager", quietly = TRUE))
#      install.packages("BiocManager")
#  BiocManager::install("ANCOMBC")

## ----load, eval=FALSE---------------------------------------------------------
#  library(ANCOMBC)

## -----------------------------------------------------------------------------
# Simulate a count table with known covariate effects and wrap it in a
# TreeSummarizedExperiment. The random draws below happen in a fixed order
# after set.seed(), so reordering any statement changes the simulated data.
data(QMP)
set.seed(12345)
# n: number of simulated samples; d: number of columns of QMP (taxa, since
# taxa_are_rows = FALSE below)
n = 150
d = ncol(QMP)
# diff_prop: probability that a taxon receives a non-zero effect for a given
# covariate; the lfc_* values are the non-zero effect sizes that are added to
# the log abundances further down
diff_prop = 0.1
lfc_cont = -1
lfc_cat2_vs_1 = -2
lfc_cat3_vs_1 = 1

# Generate the true abundances
abn_data = sim_plnm(abn_table = QMP, taxa_are_rows = FALSE, prv_cut = 0.05, 
                    n = n, scale_mean = 1e8, scale_sd = 1e8/3)
# Work on the log scale; the small offset keeps log() finite for zero entries
log_abn_data = log(abn_data + 1e-5)
# Rows are labelled as taxa (T1..Td) and columns as samples (S1..Sn)
rownames(log_abn_data) = paste0("T", seq_len(d))
colnames(log_abn_data) = paste0("S", seq_len(n))

# Generate the sample and feature meta data
# Sampling fractions are set to differ by batches
# (three consecutive blocks of n/3 samples, each drawn from a range ten times
# larger than the previous one, stored on the log scale)
smd = data.frame(samp_frac = log(c(runif(n/3, min = 1e-4, max = 1e-3),
                                   runif(n/3, min = 1e-3, max = 1e-2),
                                   runif(n/3, min = 1e-2, max = 1e-1))),
                 cont_cov = rnorm(n),
                 cat_cov = as.factor(rep(seq_len(3), each = n/3)))
rownames(smd) = paste0("S", seq_len(n))
                      
# Per-taxon metadata: a log-scale sequencing efficiency plus, for each
# covariate, an effect that is non-zero with probability diff_prop.
# lfc_cat3_vs_2 is derived as the difference of the two effects versus level 1.
fmd = data.frame(seq_eff = log(runif(d, min = 0.1, max = 1)),
                 lfc_cont = sample(c(0, lfc_cont), 
                                   size = d,
                                   replace = TRUE,
                                   prob = c(1 - diff_prop, diff_prop)),
                 lfc_cat2_vs_1 = sample(c(0, lfc_cat2_vs_1), 
                                        size = d,
                                        replace = TRUE,
                                        prob = c(1 - diff_prop, diff_prop)),
                 lfc_cat3_vs_1 = sample(c(0, lfc_cat3_vs_1), 
                                        size = d,
                                        replace = TRUE,
                                        prob = c(1 - diff_prop, diff_prop))) %>%
    mutate(lfc_cat3_vs_2 = lfc_cat3_vs_1 - lfc_cat2_vs_1)
rownames(fmd) = paste0("T", seq_len(d))

# Add effect sizes of covariates to the true abundances
# Expand cat_cov into indicator columns (cat_cov.2 and cat_cov.3 are used below)
dmy = caret::dummyVars(" ~ cat_cov", data = smd)
smd_dmy = data.frame(predict(dmy, newdata = smd))

# outer() builds a taxa x samples matrix: taxon effect times sample covariate
log_abn_data = log_abn_data + outer(fmd$lfc_cont, smd$cont_cov)
log_abn_data = log_abn_data + outer(fmd$lfc_cat2_vs_1, smd_dmy$cat_cov.2)
log_abn_data = log_abn_data + outer(fmd$lfc_cat3_vs_1, smd_dmy$cat_cov.3)

# Add sample- and taxon-specific biases
# Transposing twice makes the length-n vector recycle along samples (columns)
log_otu_data = t(t(log_abn_data) + smd$samp_frac)
# The length-d vector recycles down each column, i.e. one value per taxon (row)
log_otu_data = log_otu_data + fmd$seq_eff
# Back-transform from the log scale and round to whole counts
otu_data = round(exp(log_otu_data))

# Create the tse object
assays = SimpleList(counts = otu_data)
smd = DataFrame(smd)
tse = TreeSummarizedExperiment(assays = assays, colData = smd)

## -----------------------------------------------------------------------------
# Fit ANCOM-BC2 to the simulated data. Only the primary analysis and the
# pseudo-count sensitivity analysis are requested; the global, pairwise,
# Dunnett-type, and trend tests are all switched off.
set.seed(123)
output <- ancombc2(
    # Input data and model specification
    data = tse,
    assay_name = "counts",
    tax_level = NULL,
    fix_formula = "cont_cov + cat_cov",
    rand_formula = NULL,
    group = "cat_cov",
    # Filtering, pseudo-count, and zero-handling settings
    prv_cut = 0.10,
    lib_cut = 1000,
    s0_perc = 0.05,
    pseudo = 0,
    pseudo_sens = TRUE,
    struc_zero = FALSE,
    neg_lb = FALSE,
    # Significance settings
    p_adj_method = "holm",
    alpha = 0.05,
    # Multi-group analyses (none requested here)
    global = FALSE,
    pairwise = FALSE,
    dunnet = FALSE,
    trend = FALSE,
    # Execution and algorithm controls
    n_cl = 2,
    verbose = TRUE,
    iter_control = list(tol = 1e-5, max_iter = 20, verbose = FALSE),
    em_control = list(tol = 1e-5, max_iter = 100),
    lme_control = NULL,
    mdfdr_control = NULL,
    trend_control = NULL
)

# Primary results and the sensitivity scores used by the chunks below
res_prim <- output$res
tab_sens <- output$pseudo_sens_tab

## ---- fig.width=10------------------------------------------------------------
# For each comparison against category 1, pair every taxon's sensitivity score
# with its significance call, then stack the two comparisons into one table
sens_cat <- bind_rows(
    tab_sens %>%
        rownames_to_column("tax_id") %>%
        dplyr::select(tax_id, sens_cat = cat_cov2) %>%
        left_join(res_prim %>%
                      rownames_to_column("tax_id") %>%
                      dplyr::select(tax_id, diff_cat = diff_cat_cov2),
                  by = "tax_id") %>%
        mutate(group = "Cat2 vs. Cat1"),
    tab_sens %>%
        rownames_to_column("tax_id") %>%
        dplyr::select(tax_id, sens_cat = cat_cov3) %>%
        left_join(res_prim %>%
                      rownames_to_column("tax_id") %>%
                      dplyr::select(tax_id, diff_cat = diff_cat_cov3),
                  by = "tax_id") %>%
        mutate(group = "Cat3 vs. Cat1")
) %>%
    # Turn the logical flag into a readable legend label
    mutate(diff_cat = recode(diff_cat * 1,
                             `1` = "Significant",
                             `0` = "Nonsignificant"))

# One panel per comparison, points coloured by significance
fig_sens_cat <- ggplot(sens_cat,
                       aes(x = tax_id, y = sens_cat, colour = diff_cat)) +
    geom_point() +
    facet_grid(rows = vars(group), scales = "free") +
    scale_colour_brewer(palette = "Dark2", name = NULL) +
    labs(x = NULL, y = "Sensitivity Score") +
    theme_bw() +
    theme(axis.text.x = element_text(angle = 60, vjust = 0.5))
fig_sens_cat

# Sensitivity-score cutoffs for the two comparisons, used in the next chunk
sens_cut1 <- 5
sens_cut2 <- 5

## -----------------------------------------------------------------------------
# Compare the estimated effect directions with the simulated truth and report
# power and FDR, with and without the pseudo-count sensitivity filter.

# Power and FDR for one comparison.
#   lfc_est:  estimated direction per taxon (0 = not called significant)
#   lfc_true: simulated direction per taxon (0 = no true effect)
# Returns a list with elements `power` and `fdr`. The FDR is reported as 0
# when nothing is called significant, instead of the NaN produced by 0/0.
calc_power_fdr <- function(lfc_est, lfc_true) {
    tp <- sum(lfc_true != 0 & lfc_est != 0)
    fp <- sum(lfc_true == 0 & lfc_est != 0)
    fn <- sum(lfc_true != 0 & lfc_est == 0)
    n_called <- tp + fp
    # isTRUE() keeps an NA count flowing through as NA rather than erroring
    list(power = tp / (tp + fn),
         fdr = if (isTRUE(n_called == 0)) 0 else fp / n_called)
}

# Not considering the effect of pseudo-count addition
lfc <- res_prim %>%
    dplyr::select(starts_with("lfc"))
diff <- res_prim %>%
    dplyr::select(starts_with("diff"))
# Multiplying by the significance flags zeroes out non-significant estimates;
# only the sign of what remains is compared with the sign of the true effect
res_merge <- (lfc * diff) %>%
    rownames_to_column("tax_id") %>%
    left_join(fmd %>%
                  rownames_to_column("tax_id"),
              by = "tax_id") %>%
    transmute(
        tax_id = tax_id,
        est_cov2 = case_when(
            lfc_cat_cov2 > 0 ~ 1,
            lfc_cat_cov2 < 0 ~ -1,
            TRUE ~ 0),
        est_cov3 = case_when(
            lfc_cat_cov3 > 0 ~ 1,
            lfc_cat_cov3 < 0 ~ -1,
            TRUE ~ 0),
        true_cat2 = case_when(
            lfc_cat2_vs_1 > 0 ~ 1,
            lfc_cat2_vs_1 < 0 ~ -1,
            TRUE ~ 0),
        true_cat3 = case_when(
            lfc_cat3_vs_1 > 0 ~ 1,
            lfc_cat3_vs_1 < 0 ~ -1,
            TRUE ~ 0)
        )

# Cat 2 vs Cat 1
perf1_nosens <- calc_power_fdr(res_merge$est_cov2, res_merge$true_cat2)
power1_nosens <- perf1_nosens$power
fdr1_nosens <- perf1_nosens$fdr

# Cat 3 vs Cat 1
perf2_nosens <- calc_power_fdr(res_merge$est_cov3, res_merge$true_cat3)
power2_nosens <- perf2_nosens$power
fdr2_nosens <- perf2_nosens$fdr

# Considering the effect of pseudo-count addition
res_merge2 <- res_merge %>%
    left_join(tab_sens %>%
                  rownames_to_column("tax_id"),
              by = "tax_id")

# Cat 2 vs Cat 1
# A call is kept only when its sensitivity score is below the cutoff
perf1_sens <- calc_power_fdr(
    res_merge2$est_cov2 * (res_merge2$cat_cov2 < sens_cut1),
    res_merge2$true_cat2)
power1_sens <- perf1_sens$power
fdr1_sens <- perf1_sens$fdr

# Cat 3 vs Cat 1
perf2_sens <- calc_power_fdr(
    res_merge2$est_cov3 * (res_merge2$cat_cov3 < sens_cut2),
    res_merge2$true_cat3)
power2_sens <- perf2_sens$power
fdr2_sens <- perf2_sens$fdr

tab_summ1 <- data.frame(Comparison = c("Not considering the effect of pseudo-count addition", 
                                       "Considering the effect of pseudo-count addition"),
                        Power = round(c(power1_nosens, power1_sens), 2),
                        FDR = round(c(fdr1_nosens, fdr1_sens), 2))

tab_summ2 <- data.frame(Comparison = c("Not considering the effect of pseudo-count addition", 
                                       "Considering the effect of pseudo-count addition"),
                        Power = round(c(power2_nosens, power2_sens), 2),
                        FDR = round(c(fdr2_nosens, fdr2_sens), 2))
tab_summ1 %>%
    datatable(caption = "Cat 2 vs Cat 1")
tab_summ2 %>%
    datatable(caption = "Cat 3 vs Cat 1")

## -----------------------------------------------------------------------------
data(atlas1006)

# Keep only the samples recorded at time 0
tse <- atlas1006[, atlas1006$time == 0]

# Derive `bmi` from `bmi_group`, folding the three obesity grades into a
# single "obese" value
tse$bmi <- recode(tse$bmi_group,
                  obese = "obese",
                  severeobese = "obese",
                  morbidobese = "obese")
# Retain lean, overweight, and obese subjects only
tse <- tse[, tse$bmi %in% c("lean", "overweight", "obese")]

# R orders the levels of a categorical variable alphabetically by default,
# which would make `lean` the reference level of `bmi`. Listing the levels
# explicitly makes `obese` the reference level instead.
tse$bmi <- factor(tse$bmi, levels = c("obese", "overweight", "lean"))
# To confirm the new ordering, run:
# levels(tse$bmi)

# Map nationalities onto regions; a missing nationality becomes "unknown"
tse$region <- recode(as.character(tse$nationality),
                     Scandinavia = "NE", UKIE = "NE", SouthEurope = "SE",
                     CentralEurope = "CE", EasternEurope = "EE",
                     .missing = "unknown")

# Remove "EE" (it contains only 1 subject) and subjects with unknown region
tse <- tse[, !(tse$region %in% c("EE", "unknown"))]

print(tse)

## -----------------------------------------------------------------------------
# Fit ANCOM-BC2 at the Family level with bmi as the group variable and every
# multi-group analysis (global, pairwise, Dunnett-type, trend) switched on.
set.seed(123)
output <- ancombc2(
    # Input data and model specification
    data = tse,
    assay_name = "counts",
    tax_level = "Family",
    fix_formula = "age + region + bmi",
    rand_formula = NULL,
    group = "bmi",
    # Filtering, pseudo-count, and zero-handling settings
    prv_cut = 0.10,
    lib_cut = 1000,
    s0_perc = 0.05,
    pseudo = 0,
    pseudo_sens = TRUE,
    struc_zero = TRUE,
    neg_lb = TRUE,
    # Significance settings
    p_adj_method = "holm",
    alpha = 0.05,
    # Multi-group analyses
    global = TRUE,
    pairwise = TRUE,
    dunnet = TRUE,
    trend = TRUE,
    # Execution and algorithm controls
    n_cl = 2,
    verbose = TRUE,
    iter_control = list(tol = 1e-2, max_iter = 20, verbose = TRUE),
    em_control = list(tol = 1e-5, max_iter = 100),
    lme_control = lme4::lmerControl(),
    mdfdr_control = list(fwer_ctrl_method = "holm", B = 100),
    # Two 2 x 2 contrast matrices, the second the negation of the first,
    # each paired with node 2
    trend_control = list(
        contrast = list(
            matrix(c(1, 0, -1, 1), nrow = 2, byrow = TRUE),
            matrix(c(-1, 0, 1, -1), nrow = 2, byrow = TRUE)
        ),
        node = list(2, 2),
        solver = "ECOS",
        B = 100
    )
)

## -----------------------------------------------------------------------------
# Show the structural-zero indicator table as an interactive table
tab_zero <- output$zero_ind
datatable(tab_zero, caption = "The detection of structural zeros")

## -----------------------------------------------------------------------------
# Show the sensitivity scores, every column rounded to two decimals
tab_sens <- output$pseudo_sens_tab
formatRound(datatable(tab_sens, caption = "Sensitivity Scores"),
            columns = colnames(tab_sens), digits = 2)

## -----------------------------------------------------------------------------
res_prim <- output$res

## -----------------------------------------------------------------------------
# Columns of the primary results whose names end in "age", keyed by taxon id
df_age <- res_prim %>%
    rownames_to_column("tax_id") %>%
    dplyr::select(tax_id, ends_with("age"))

# Significant taxa only, ordered from largest to smallest log fold change.
# Converting tax_id to a factor in that order fixes the bar order on the x axis.
df_fig_age <- df_age %>%
    filter(diff_age == 1) %>%
    arrange(desc(lfc_age)) %>%
    mutate(direct = factor(ifelse(lfc_age > 0,
                                  "Positive LFC", "Negative LFC"),
                           levels = c("Positive LFC", "Negative LFC")),
           tax_id = factor(tax_id, levels = tax_id))

# Bars show the log fold change; error bars span one standard error each way
fig_age <- ggplot(df_fig_age,
                  aes(x = tax_id, y = lfc_age, fill = direct)) +
    geom_bar(stat = "identity", width = 0.7, colour = "black",
             position = position_dodge(width = 0.4)) +
    geom_errorbar(aes(ymin = lfc_age - se_age, ymax = lfc_age + se_age),
                  width = 0.2, position = position_dodge(0.05),
                  colour = "black") +
    scale_fill_discrete(name = NULL) +
    scale_color_discrete(name = NULL) +
    labs(x = NULL, y = "Log fold change",
         title = "Log fold changes as one unit increase of age") +
    theme_bw() +
    theme(plot.title = element_text(hjust = 0.5),
          panel.grid.minor.y = element_blank(),
          axis.text.x = element_text(angle = 60, hjust = 1))
fig_age

## -----------------------------------------------------------------------------
# Attach each taxon's age sensitivity score to its age results and relabel the
# significance flag for the legend
sens_age <- tab_sens %>%
    rownames_to_column("tax_id") %>%
    dplyr::select(tax_id, sens_age = age) %>%
    left_join(df_age, by = "tax_id") %>%
    mutate(diff_age = recode(diff_age * 1,
                             `1` = "Significant",
                             `0` = "Nonsignificant"))

fig_sens_age <- ggplot(sens_age,
                       aes(x = tax_id, y = sens_age, colour = diff_age)) +
    geom_point() +
    scale_colour_brewer(palette = "Dark2", name = NULL) +
    labs(x = NULL, y = "Sensitivity Score") +
    theme_bw() +
    theme(axis.text.x = element_text(angle = 60, vjust = 0.5))
fig_sens_age

## -----------------------------------------------------------------------------
# Heatmap of the bmi log fold changes (relative to obese) for taxa that are
# significant in at least one of the two comparisons.

# Columns of the primary results whose names contain "bmi", keyed by taxon id
df_bmi <- res_prim %>%
    rownames_to_column("tax_id") %>%
    dplyr::select(tax_id, contains("bmi"))

# Non-significant log fold changes are displayed as 0; values are rounded to
# two decimals and reshaped to one row per taxon/comparison pair
df_fig_bmi <- df_bmi %>%
    filter(diff_bmilean == 1 | diff_bmioverweight == 1) %>%
    mutate(lfc_overweight = ifelse(diff_bmioverweight == 1,
                                   lfc_bmioverweight, 0),
           lfc_lean = ifelse(diff_bmilean == 1,
                             lfc_bmilean, 0)) %>%
    transmute(tax_id,
              `Overweight vs. Obese` = round(lfc_overweight, 2),
              `Lean vs. Obese` = round(lfc_lean, 2)) %>%
    pivot_longer(cols = `Overweight vs. Obese`:`Lean vs. Obese`,
                 names_to = "group", values_to = "value") %>%
    arrange(tax_id)

# Colour scale bounds: the observed range widened to whole numbers, with the
# white midpoint placed at the centre of that range
lo <- floor(min(df_fig_bmi$value))
up <- ceiling(max(df_fig_bmi$value))
mid <- (lo + up)/2
fig_bmi <- df_fig_bmi %>%
  ggplot(aes(x = group, y = tax_id, fill = value)) +
  geom_tile(color = "black") +
  # `limits` is spelled out in full rather than relying on partial matching
  scale_fill_gradient2(low = "blue", high = "red", mid = "white",
                       na.value = "white", midpoint = mid, limits = c(lo, up),
                       name = NULL) +
  geom_text(aes(group, tax_id, label = value), color = "black", size = 4) +
  labs(x = NULL, y = NULL, title = "Log fold changes as compared to obese subjects") +
  theme_minimal() +
  theme(plot.title = element_text(hjust = 0.5))
fig_bmi

## -----------------------------------------------------------------------------
# Pair the bmi sensitivity scores with the primary significance calls, one
# stacked block of rows per comparison against obese
sens_bmi <- bind_rows(
    tab_sens %>%
        rownames_to_column("tax_id") %>%
        dplyr::select(tax_id, sens_bmi = bmilean) %>%
        left_join(df_bmi %>%
                      dplyr::select(tax_id, diff_bmi = diff_bmilean),
                  by = "tax_id") %>%
        mutate(group = "Lean vs. Obese"),
    tab_sens %>%
        rownames_to_column("tax_id") %>%
        dplyr::select(tax_id, sens_bmi = bmioverweight) %>%
        left_join(df_bmi %>%
                      dplyr::select(tax_id, diff_bmi = diff_bmioverweight),
                  by = "tax_id") %>%
        mutate(group = "Overweight vs. Obese")
) %>%
    # Turn the logical flag into a readable legend label
    mutate(diff_bmi = recode(diff_bmi * 1,
                             `1` = "Significant",
                             `0` = "Nonsignificant"))

# One panel per comparison, points coloured by significance
fig_sens_bmi <- ggplot(sens_bmi,
                       aes(x = tax_id, y = sens_bmi, colour = diff_bmi)) +
    geom_point() +
    facet_grid(rows = vars(group), scales = "free") +
    scale_colour_brewer(palette = "Dark2", name = NULL) +
    labs(x = NULL, y = "Sensitivity Score") +
    theme_bw() +
    theme(axis.text.x = element_text(angle = 60, vjust = 0.5))
fig_sens_bmi

## -----------------------------------------------------------------------------
res_global <- output$res_global

## -----------------------------------------------------------------------------
# Heatmap of the bmi log fold changes for taxa flagged by the global test.

# Columns of the primary results whose names contain "bmi", keyed by taxon id
df_bmi <- res_prim %>%
    rownames_to_column("tax_id") %>%
    dplyr::select(tax_id, contains("bmi"))

# Keep taxa whose global-test flag is set. Within those, a log fold change
# that is not significant in the primary results is displayed as 0.
df_fig_global <- df_bmi %>%
    # The join key is stated explicitly, as in every other join in this file,
    # instead of relying on the implicit natural join
    left_join(res_global %>%
                  rownames_to_column("tax_id") %>%
                  transmute(tax_id, diff_bmi = diff_abn),
              by = "tax_id") %>%
    dplyr::filter(diff_bmi == 1) %>%
    mutate(lfc_lean = ifelse(diff_bmilean == 1,
                             lfc_bmilean, 0),
           lfc_overweight = ifelse(diff_bmioverweight == 1,
                                   lfc_bmioverweight, 0)) %>%
    transmute(tax_id,
              `Lean vs. Obese` = round(lfc_lean, 2),
              `Overweight vs. Obese` = round(lfc_overweight, 2)) %>%
    pivot_longer(cols = `Lean vs. Obese`:`Overweight vs. Obese`,
                 names_to = "group", values_to = "value") %>%
    arrange(tax_id)

# Colour scale bounds: the observed range widened to whole numbers, with the
# white midpoint placed at the centre of that range
lo <- floor(min(df_fig_global$value))
up <- ceiling(max(df_fig_global$value))
mid <- (lo + up)/2
fig_global <- df_fig_global %>%
  ggplot(aes(x = group, y = tax_id, fill = value)) +
  geom_tile(color = "black") +
  # `limits` is spelled out in full rather than relying on partial matching
  scale_fill_gradient2(low = "blue", high = "red", mid = "white",
                       na.value = "white", midpoint = mid, limits = c(lo, up),
                       name = NULL) +
  geom_text(aes(group, tax_id, label = value), color = "black", size = 4) +
  labs(x = NULL, y = NULL, title = "Log fold changes for globally significant taxa") +
  theme_minimal() +
  theme(plot.title = element_text(hjust = 0.5))
fig_global

## -----------------------------------------------------------------------------
# Pair the global-test sensitivity score of each taxon with its global
# significance call, already relabelled for the legend
sens_global <- tab_sens %>%
    rownames_to_column("tax_id") %>%
    dplyr::select(tax_id, sens_global = global) %>%
    left_join(res_global %>%
                  rownames_to_column("tax_id") %>%
                  transmute(tax_id,
                            diff_global = recode(diff_abn * 1,
                                                 `1` = "Significant",
                                                 `0` = "Nonsignificant")),
              by = "tax_id")

fig_sens_global <- ggplot(sens_global,
                          aes(x = tax_id, y = sens_global,
                              colour = diff_global)) +
    geom_point() +
    scale_colour_brewer(palette = "Dark2", name = NULL) +
    labs(x = NULL, y = "Sensitivity Score") +
    theme_bw() +
    theme(axis.text.x = element_text(angle = 60, vjust = 0.5))
fig_sens_global

## -----------------------------------------------------------------------------
res_pair <- output$res_pair

## -----------------------------------------------------------------------------
# Heatmap of the pairwise log fold changes for taxa that are significant in at
# least one of the three pairwise comparisons.
df_pair <- res_pair %>%
    rownames_to_column("tax_id")

# Non-significant log fold changes are displayed as 0; values are rounded to
# two decimals and reshaped to one row per taxon/comparison pair
df_fig_pair <- df_pair %>%
    filter(diff_bmilean == 1 | diff_bmioverweight == 1 |
               diff_bmilean_bmioverweight == 1) %>%
    mutate(lfc_lean = ifelse(diff_bmilean == 1,
                             lfc_bmilean, 0),
           lfc_overweight = ifelse(diff_bmioverweight == 1,
                                   lfc_bmioverweight, 0),
           lfc_lean_overweight = ifelse(diff_bmilean_bmioverweight == 1,
                                        lfc_bmilean_bmioverweight, 0)) %>%
    transmute(tax_id,
              `Lean vs. Obese` = round(lfc_lean, 2),
              `Overweight vs. Obese` = round(lfc_overweight, 2),
              `Lean vs. Overweight` = round(lfc_lean_overweight, 2)
              ) %>%
    pivot_longer(cols = `Lean vs. Obese`:`Lean vs. Overweight`,
                 names_to = "group", values_to = "value") %>%
    arrange(tax_id)
# Explicit factor levels fix the left-to-right order of the heatmap columns
df_fig_pair$group <- factor(df_fig_pair$group,
                            levels = c("Lean vs. Obese",
                                       "Overweight vs. Obese",
                                       "Lean vs. Overweight"))

# Colour scale bounds: the observed range widened to whole numbers, with the
# white midpoint placed at the centre of that range
lo <- floor(min(df_fig_pair$value))
up <- ceiling(max(df_fig_pair$value))
mid <- (lo + up)/2
fig_pair <- df_fig_pair %>%
  ggplot(aes(x = group, y = tax_id, fill = value)) +
  geom_tile(color = "black") +
  # `limits` is spelled out in full rather than relying on partial matching
  scale_fill_gradient2(low = "blue", high = "red", mid = "white",
                       na.value = "white", midpoint = mid, limits = c(lo, up),
                       name = NULL) +
  geom_text(aes(group, tax_id, label = value), color = "black", size = 4) +
  labs(x = NULL, y = NULL, title = "Log fold change of pairwise comparisons") +
  theme_minimal() +
  theme(plot.title = element_text(hjust = 0.5))
fig_pair

## ---- fig.height=8------------------------------------------------------------
# Pair the pairwise sensitivity scores with the pairwise significance calls,
# one stacked block of rows per comparison
sens_pair <- bind_rows(
    tab_sens %>%
        rownames_to_column("tax_id") %>%
        dplyr::select(tax_id, sens_pair = `obese - lean`) %>%
        left_join(df_pair %>%
                      dplyr::select(tax_id, diff_pair = diff_bmilean),
                  by = "tax_id") %>%
        mutate(group = "Lean vs. Obese"),
    tab_sens %>%
        rownames_to_column("tax_id") %>%
        dplyr::select(tax_id, sens_pair = `obese - overweight`) %>%
        left_join(df_pair %>%
                      dplyr::select(tax_id, diff_pair = diff_bmioverweight),
                  by = "tax_id") %>%
        mutate(group = "Overweight vs. Obese"),
    tab_sens %>%
        rownames_to_column("tax_id") %>%
        dplyr::select(tax_id, sens_pair = `overweight - lean`) %>%
        left_join(df_pair %>%
                      dplyr::select(tax_id,
                                    diff_pair = diff_bmilean_bmioverweight),
                  by = "tax_id") %>%
        mutate(group = "Lean vs. Overweight")
) %>%
    # Readable legend labels, and a fixed top-to-bottom panel order
    mutate(diff_pair = recode(diff_pair * 1,
                              `1` = "Significant",
                              `0` = "Nonsignificant"),
           group = factor(group,
                          levels = c("Lean vs. Obese",
                                     "Overweight vs. Obese",
                                     "Lean vs. Overweight")))

# One panel per comparison, points coloured by significance
fig_sens_pair <- ggplot(sens_pair,
                        aes(x = tax_id, y = sens_pair, colour = diff_pair)) +
    geom_point() +
    facet_grid(rows = vars(group), scales = "free") +
    scale_colour_brewer(palette = "Dark2", name = NULL) +
    labs(x = NULL, y = "Sensitivity Score") +
    theme_bw() +
    theme(axis.text.x = element_text(angle = 60, vjust = 0.5))
fig_sens_pair

## -----------------------------------------------------------------------------
res_dunn <- output$res_dunn

## -----------------------------------------------------------------------------
# Heatmap of the Dunnett-type log fold changes (relative to obese) for taxa
# that are significant in at least one of the two comparisons.
df_dunn <- res_dunn %>%
    rownames_to_column("tax_id")

# Non-significant log fold changes are displayed as 0; values are rounded to
# two decimals and reshaped to one row per taxon/comparison pair
df_fig_dunn <- df_dunn %>%
    dplyr::filter(diff_bmilean == 1 | diff_bmioverweight == 1) %>%
    mutate(lfc_lean = ifelse(diff_bmilean == 1,
                             lfc_bmilean, 0),
           lfc_overweight = ifelse(diff_bmioverweight == 1,
                                   lfc_bmioverweight, 0)) %>%
    transmute(tax_id,
              `Lean vs. Obese` = round(lfc_lean, 2),
              `Overweight vs. Obese` = round(lfc_overweight, 2)) %>%
    pivot_longer(cols = `Lean vs. Obese`:`Overweight vs. Obese`,
                 names_to = "group", values_to = "value") %>%
    arrange(tax_id)

# Colour scale bounds: the observed range widened to whole numbers, with the
# white midpoint placed at the centre of that range
lo <- floor(min(df_fig_dunn$value))
up <- ceiling(max(df_fig_dunn$value))
mid <- (lo + up)/2
fig_dunn <- df_fig_dunn %>%
  ggplot(aes(x = group, y = tax_id, fill = value)) +
  geom_tile(color = "black") +
  # `limits` is spelled out in full rather than relying on partial matching
  scale_fill_gradient2(low = "blue", high = "red", mid = "white",
                       na.value = "white", midpoint = mid, limits = c(lo, up),
                       name = NULL) +
  geom_text(aes(group, tax_id, label = value), color = "black", size = 4) +
  labs(x = NULL, y = NULL, title = "Log fold changes as compared to obese subjects") +
  theme_minimal() +
  theme(plot.title = element_text(hjust = 0.5))
fig_dunn

## -----------------------------------------------------------------------------
# Pair the bmi sensitivity scores with the Dunnett-type significance calls,
# one stacked block of rows per comparison against obese
sens_dunn <- bind_rows(
    tab_sens %>%
        rownames_to_column("tax_id") %>%
        dplyr::select(tax_id, sens_bmi = bmilean) %>%
        left_join(df_dunn %>%
                      dplyr::select(tax_id, diff_bmi = diff_bmilean),
                  by = "tax_id") %>%
        mutate(group = "Lean vs. Obese"),
    tab_sens %>%
        rownames_to_column("tax_id") %>%
        dplyr::select(tax_id, sens_bmi = bmioverweight) %>%
        left_join(df_dunn %>%
                      dplyr::select(tax_id, diff_bmi = diff_bmioverweight),
                  by = "tax_id") %>%
        mutate(group = "Overweight vs. Obese")
) %>%
    # Turn the logical flag into a readable legend label
    mutate(diff_bmi = recode(diff_bmi * 1,
                             `1` = "Significant",
                             `0` = "Nonsignificant"))

# One panel per comparison, points coloured by significance
fig_sens_dunn <- ggplot(sens_dunn,
                        aes(x = tax_id, y = sens_bmi, colour = diff_bmi)) +
    geom_point() +
    facet_grid(rows = vars(group), scales = "free") +
    scale_colour_brewer(palette = "Dark2", name = NULL) +
    labs(x = NULL, y = "Sensitivity Score") +
    theme_bw() +
    theme(axis.text.x = element_text(angle = 60, vjust = 0.5))
fig_sens_dunn

## -----------------------------------------------------------------------------
res_trend <- output$res_trend

## ---- fig.width=10------------------------------------------------------------
df_trend <- rownames_to_column(res_trend, "tax_id")

# For taxa flagged by the trend test, stack the overweight and lean estimates
# (log fold change and standard error) into one long table
df_fig_trend <- bind_rows(
    df_trend %>%
        dplyr::filter(diff_abn) %>%
        dplyr::select(tax_id,
                      lfc = lfc_bmioverweight,
                      se = se_bmioverweight,
                      q_val) %>%
        mutate(group = "Overweight - Obese"),
    df_trend %>%
        dplyr::filter(diff_abn) %>%
        dplyr::select(tax_id,
                      lfc = lfc_bmilean,
                      se = se_bmilean,
                      q_val) %>%
        mutate(group = "Lean - Obese")
) %>%
    # Explicit levels fix the bar order within each panel
    mutate(group = factor(group,
                          levels = c("Overweight - Obese", "Lean - Obese")))

# One panel per taxon; error bars span one standard error each way
fig_trend <- ggplot(df_fig_trend,
                    aes(x = group, y = lfc, fill = group)) +
  geom_bar(stat = "identity", position = position_dodge(), colour = "black") +
  geom_errorbar(aes(ymin = lfc - se, ymax = lfc + se), width = .2,
                position = position_dodge(.9)) +
  facet_wrap(vars(tax_id), nrow = 2, scales = "free") +
  scale_fill_brewer(palette = "Set2", name = NULL) +
  labs(x = NULL, y = NULL, title = "Log fold change as compared to obese subjects") +
  theme_bw() +
  theme(plot.title = element_text(hjust = 0.5),
        axis.text.x = element_blank(),
        axis.ticks.x = element_blank(),
        legend.position = c(0.9, 0.2))
fig_trend

## -----------------------------------------------------------------------------
# Pair the bmi sensitivity scores with the trend-test significance call; the
# same call is attached to both comparisons
sens_trend <- bind_rows(
    tab_sens %>%
        rownames_to_column("tax_id") %>%
        dplyr::select(tax_id, sens_bmi = bmilean) %>%
        left_join(df_trend %>%
                      dplyr::select(tax_id, diff_bmi = diff_abn),
                  by = "tax_id") %>%
        mutate(group = "Lean vs. Obese"),
    tab_sens %>%
        rownames_to_column("tax_id") %>%
        dplyr::select(tax_id, sens_bmi = bmioverweight) %>%
        left_join(df_trend %>%
                      dplyr::select(tax_id, diff_bmi = diff_abn),
                  by = "tax_id") %>%
        mutate(group = "Overweight vs. Obese")
) %>%
    # Turn the logical flag into a readable legend label
    mutate(diff_bmi = recode(diff_bmi * 1,
                             `1` = "Significant",
                             `0` = "Nonsignificant"))

# One panel per comparison, points coloured by significance
fig_sens_trend <- ggplot(sens_trend,
                         aes(x = tax_id, y = sens_bmi, colour = diff_bmi)) +
    geom_point() +
    facet_grid(rows = vars(group), scales = "free") +
    scale_colour_brewer(palette = "Dark2", name = NULL) +
    labs(x = NULL, y = "Sensitivity Score") +
    theme_bw() +
    theme(axis.text.x = element_text(angle = 60, vjust = 0.5))
fig_sens_trend

## -----------------------------------------------------------------------------
# Switch the working object to the dietswap data set and show its summary
data(dietswap)
tse <- dietswap
print(tse)

## -----------------------------------------------------------------------------
# Fit ANCOM-BC2 to the dietswap data at the Family level, with a random-effects
# term per subject and every multi-group analysis switched on.
set.seed(123)
output <- ancombc2(
    # Input data and model specification
    data = tse,
    assay_name = "counts",
    tax_level = "Family",
    fix_formula = "nationality + timepoint + group",
    rand_formula = "(timepoint | subject)",
    group = "group",
    # Filtering, pseudo-count, and zero-handling settings
    prv_cut = 0.10,
    lib_cut = 1000,
    s0_perc = 0.05,
    pseudo = 0,
    pseudo_sens = TRUE,
    struc_zero = TRUE,
    neg_lb = TRUE,
    # Significance settings
    p_adj_method = "holm",
    alpha = 0.05,
    # Multi-group analyses
    global = TRUE,
    pairwise = TRUE,
    dunnet = TRUE,
    trend = TRUE,
    # Execution and algorithm controls
    n_cl = 2,
    verbose = TRUE,
    iter_control = list(tol = 1e-2, max_iter = 20, verbose = TRUE),
    em_control = list(tol = 1e-5, max_iter = 100),
    lme_control = lme4::lmerControl(),
    mdfdr_control = list(fwer_ctrl_method = "holm", B = 100),
    # A single 2 x 2 contrast matrix paired with node 2
    trend_control = list(
        contrast = list(
            matrix(c(1, 0, -1, 1), nrow = 2, byrow = TRUE)
        ),
        node = list(2),
        solver = "ECOS",
        B = 100
    )
)

## ----sessionInfo, message = FALSE, warning = FALSE, comment = NA--------------
# Record the R version and loaded packages used to produce the results
sessionInfo()

