Skip to content

Commit

Permalink
add collider
Browse files Browse the repository at this point in the history
  • Loading branch information
strengejacke committed Aug 3, 2024
1 parent 1844583 commit 435af39
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions R/check_dag.R
Original file line number Diff line number Diff line change
Expand Up @@ -177,23 +177,25 @@ plot.check_dag <- function(x, size_point = 15, colors = NULL, which = "all", ...

# tweak data
p1$data$type <- as.character(p1$data$adjusted)
p1$data$type[vapply(p1$data$name, ggdag::is_collider, logical(1), .dag = x)] <- "collider"
p1$data$type[p1$data$name == attributes(x)$outcome] <- "outcome"
p1$data$type[p1$data$name %in% attributes(x)$exposure] <- "exposure"
p1$data$type <- factor(p1$data$type, levels = c("outcome", "exposure", "adjusted", "unadjusted"))
p1$data$type <- factor(p1$data$type, levels = c("outcome", "exposure", "adjusted", "unadjusted", "collider"))

p2$data$type <- as.character(p2$data$adjusted)
p2$data$type[vapply(p2$data$name, ggdag::is_collider, logical(1), .dag = x)] <- "collider"
p2$data$type[p2$data$name == attributes(x)$outcome] <- "outcome"
p2$data$type[p2$data$name %in% attributes(x)$exposure] <- "exposure"
p2$data$type <- factor(p2$data$type, levels = c("outcome", "exposure", "adjusted", "unadjusted"))
p2$data$type <- factor(p2$data$type, levels = c("outcome", "exposure", "adjusted", "unadjusted", "collider"))

if (is.null(colors)) {
point_colors <- see::see_colors(c("yellow", "cyan", "blue grey", "red"))
} else if (length(colors) != 4) {
insight::format_error("`colors` must be a character vector with four color-values.")
point_colors <- see::see_colors(c("yellow", "cyan", "blue grey", "red", "orange"))
} else if (length(colors) != 5) {
insight::format_error("`colors` must be a character vector with five color-values.")
} else {
point_colors <- colors
}
names(point_colors) <- c("outcome", "exposure", "adjusted", "unadjusted")
names(point_colors) <- c("outcome", "exposure", "adjusted", "unadjusted", "collider")

plot1 <- ggplot2::ggplot(p1$data, ggplot2::aes(x = .data$x, y = .data$y)) +
see::geom_point_borderless(ggplot2::aes(fill = .data$type), size = size_point) +
Expand Down

0 comments on commit 435af39

Please sign in to comment.