Skip to content

Commit

Permalink
modify vignettes to use new plotting functionality, fix sorting, twea…
Browse files Browse the repository at this point in the history
…k inf in logsumexp
  • Loading branch information
helske committed Sep 25, 2024
1 parent 7ebed10 commit 488b313
Show file tree
Hide file tree
Showing 12 changed files with 288 additions and 388 deletions.
219 changes: 110 additions & 109 deletions R/HMMplot.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ HMMplot <- function(x, layout = "horizontal", pie = TRUE,
cex.legend = 1, ncol.legend = "auto", cpal = "auto",
legend.pos = "center", main = "auto", ...) {
dots <- list(...)

labelprint <- function(z, labs) {
if (labs == TRUE && (z > 0.001 || z == 0)) {
labs <- FALSE
Expand All @@ -26,7 +26,7 @@ HMMplot <- function(x, layout = "horizontal", pie = TRUE,
z <- prettyNum(signif(z, digits = label.signif), scientific = labs)
}
}

stopifnot_(
is.matrix(layout) || is.function(layout) ||
layout %in% c("horizontal", "vertical"),
Expand Down Expand Up @@ -78,7 +78,7 @@ HMMplot <- function(x, layout = "horizontal", pie = TRUE,
}
vertex.label.pos <- vpos
}

# Vertex labels
if (length(vertex.label) == 1 && !is.na(vertex.label) && vertex.label != FALSE) {
if (vertex.label == "initial.probs") {
Expand All @@ -91,7 +91,7 @@ HMMplot <- function(x, layout = "horizontal", pie = TRUE,
hidden states.")
vertex.label <- rep(vertex.label, length.out = length(x$state_names))
}

# Vertex label distances
if (is.character(vertex.label.dist)) {
ind <- pmatch(vertex.label.dist, "auto")
Expand All @@ -106,27 +106,27 @@ HMMplot <- function(x, layout = "horizontal", pie = TRUE,
of edges.")
vertex.label.dist <- rep(vertex.label.dist, length.out = length(x$n_states))
}


# Trimming (remove small transition probablities from plot)
transM <- x$transition_probs
transM[transM < trim] <- 0

# Adjacency matrix (which edges to plot)
edges <- transM
edges[edges > 0] <- 1
# Remove transitions back to the same state
if (!loops) {
diag(edges) <- 0
}

# Vector of non-zero transition probabilities
transitions <- transM
if (loops == FALSE && length(transitions) > 1) {
diag(transitions) <- 0
}
transitions <- t(transitions)[t(transitions) > 0]

# Edge labels
if (!is.na(edge.label) && edge.label != FALSE) {
if (length(edge.label) == 1 && (edge.label == "auto" || edge.label == TRUE)) {
Expand All @@ -137,8 +137,8 @@ HMMplot <- function(x, layout = "horizontal", pie = TRUE,
edge.label <- rep(edge.label, length.out = length(transitions))
}
}


# Edge widths
if (is.character(edge.width)) {
ind <- pmatch(edge.width, "auto")
Expand All @@ -153,10 +153,10 @@ HMMplot <- function(x, layout = "horizontal", pie = TRUE,
edges.")
edge.width <- rep(edge.width, length.out = length(transitions))
}

# Defining the graph structure
g1 <- graph.adjacency(edges, mode = "directed")

# Layout of the graph
if (is.function(layout)) {
glayout <- layout(g1)
Expand All @@ -169,8 +169,8 @@ HMMplot <- function(x, layout = "horizontal", pie = TRUE,
glayout <- layout_on_grid(g1, width = 1)
}
}


# Colors for the (combinations of) observed states
if (identical(cpal, "auto")) {
pie.colors <- TraMineR::cpal(x$observations)
Expand All @@ -187,7 +187,7 @@ HMMplot <- function(x, layout = "horizontal", pie = TRUE,
if (with.legend != FALSE) {
pie.colors.l <- pie.colors
}

# Legend position and number of columns
if (with.legend != FALSE && pie == TRUE) {
if (!is.null(ltext)) {
Expand All @@ -200,7 +200,7 @@ HMMplot <- function(x, layout = "horizontal", pie = TRUE,
ltext <- x$symbol_names
}
}

# Defining rescale, xlim, ylim if not given
if (!is.matrix(layout) && !is.function(layout)) {
if (layout == "horizontal") {
Expand Down Expand Up @@ -259,14 +259,14 @@ HMMplot <- function(x, layout = "horizontal", pie = TRUE,
dots[["rescale"]] <- NULL
}
}


# Plotting graph
if (pie == TRUE) {
pie.values <- lapply(seq_len(nrow(transM)), function(i) x$emission_probs[i, ])
# If slices are combined
if (combine.slices > 0 &&
!all(unlist(pie.values)[unlist(pie.values) > 0] > combine.slices)) {
!all(unlist(pie.values)[unlist(pie.values) > 0] > combine.slices)) {
if (with.legend != FALSE) {
pie.colors.l <- NULL
lt <- NULL
Expand Down Expand Up @@ -307,140 +307,141 @@ HMMplot <- function(x, layout = "horizontal", pie = TRUE,
}
}
}

if (!is.matrix(layout) && !is.function(layout) &&
(layout == "horizontal" || layout == "vertical")) {
(layout == "horizontal" || layout == "vertical")) {
if (length(dots) > 0) {
plotcall <- as.call(c(list(plot.igraph, g1,
layout = glayout,
vertex.shape = "pie", vertex.pie = pie.values,
vertex.pie.color = list(pie.colors),
vertex.size = vertex.size,
vertex.label = vertex.label, vertex.label.dist = vertex.label.dist,
vertex.label.degree = vertex.label.pos,
vertex.label.family = vertex.label.family,
edge.curved = edge.curved, edge.width = edge.width,
edge.label = edge.label,
edge.label.family = edge.label.family,
edge.arrow.size = edge.arrow.size,
xlim = xlim, ylim = ylim, rescale = rescale, main = main
layout = glayout,
vertex.shape = "pie", vertex.pie = pie.values,
vertex.pie.color = list(pie.colors),
vertex.size = vertex.size,
vertex.label = vertex.label, vertex.label.dist = vertex.label.dist,
vertex.label.degree = vertex.label.pos,
vertex.label.family = vertex.label.family,
edge.curved = edge.curved, edge.width = edge.width,
edge.label = edge.label,
edge.label.family = edge.label.family,
edge.arrow.size = edge.arrow.size,
xlim = xlim, ylim = ylim, rescale = rescale, main = main
), dots))
} else {
plotcall <- call("plot.igraph", g1,
layout = glayout,
vertex.shape = "pie", vertex.pie = pie.values,
vertex.pie.color = list(pie.colors),
vertex.size = vertex.size,
vertex.label = vertex.label, vertex.label.dist = vertex.label.dist,
vertex.label.degree = vertex.label.pos,
vertex.label.family = vertex.label.family,
edge.curved = edge.curved, edge.width = edge.width,
edge.label = edge.label,
edge.label.family = edge.label.family,
edge.arrow.size = edge.arrow.size,
xlim = xlim, ylim = ylim, rescale = rescale, main = main
layout = glayout,
vertex.shape = "pie", vertex.pie = pie.values,
vertex.pie.color = list(pie.colors),
vertex.size = vertex.size,
vertex.label = vertex.label, vertex.label.dist = vertex.label.dist,
vertex.label.degree = vertex.label.pos,
vertex.label.family = vertex.label.family,
edge.curved = edge.curved, edge.width = edge.width,
edge.label = edge.label,
edge.label.family = edge.label.family,
edge.arrow.size = edge.arrow.size,
xlim = xlim, ylim = ylim, rescale = rescale, main = main
)
}
} else {
if (length(dots) > 0) {
plotcall <- as.call(c(list(plot.igraph, g1,
layout = glayout,
vertex.shape = "pie", vertex.pie = pie.values,
vertex.pie.color = list(pie.colors),
vertex.size = vertex.size,
vertex.label = vertex.label, vertex.label.dist = vertex.label.dist,
vertex.label.degree = vertex.label.pos,
vertex.label.family = vertex.label.family,
edge.curved = edge.curved, edge.width = edge.width,
edge.label = edge.label,
edge.label.family = edge.label.family,
edge.arrow.size = edge.arrow.size, main = main
layout = glayout,
vertex.shape = "pie", vertex.pie = pie.values,
vertex.pie.color = list(pie.colors),
vertex.size = vertex.size,
vertex.label = vertex.label, vertex.label.dist = vertex.label.dist,
vertex.label.degree = vertex.label.pos,
vertex.label.family = vertex.label.family,
edge.curved = edge.curved, edge.width = edge.width,
edge.label = edge.label,
edge.label.family = edge.label.family,
edge.arrow.size = edge.arrow.size, main = main
), dots))
} else {
plotcall <- call("plot.igraph", g1,
layout = glayout,
vertex.shape = "pie", vertex.pie = pie.values,
vertex.pie.color = list(pie.colors),
vertex.size = vertex.size,
vertex.label = vertex.label, vertex.label.dist = vertex.label.dist,
vertex.label.degree = vertex.label.pos,
vertex.label.family = vertex.label.family,
edge.curved = edge.curved, edge.width = edge.width,
edge.label = edge.label,
edge.label.family = edge.label.family,
edge.arrow.size = edge.arrow.size, main = main
layout = glayout,
vertex.shape = "pie", vertex.pie = pie.values,
vertex.pie.color = list(pie.colors),
vertex.size = vertex.size,
vertex.label = vertex.label, vertex.label.dist = vertex.label.dist,
vertex.label.degree = vertex.label.pos,
vertex.label.family = vertex.label.family,
edge.curved = edge.curved, edge.width = edge.width,
edge.label = edge.label,
edge.label.family = edge.label.family,
edge.arrow.size = edge.arrow.size, main = main
)
}
}
} else {
if (!is.matrix(layout) && !is.function(layout) &&
(layout == "horizontal" || layout == "vertical")) {
(layout == "horizontal" || layout == "vertical")) {
if (length(dots) > 0) {
plotcall <- as.call(c(list(plot.igraph, g1,
layout = glayout,
vertex.size = vertex.size,
vertex.label = vertex.label, vertex.label.dist = vertex.label.dist,
vertex.label.degree = vertex.label.pos,
vertex.label.family = vertex.label.family,
edge.curved = edge.curved, edge.width = edge.width,
edge.label = edge.label,
edge.label.family = edge.label.family,
xlim = xlim, ylim = ylim, rescale = rescale, main = main
layout = glayout,
vertex.size = vertex.size,
vertex.label = vertex.label, vertex.label.dist = vertex.label.dist,
vertex.label.degree = vertex.label.pos,
vertex.label.family = vertex.label.family,
edge.curved = edge.curved, edge.width = edge.width,
edge.label = edge.label,
edge.label.family = edge.label.family,
xlim = xlim, ylim = ylim, rescale = rescale, main = main
), dots))
} else {
plotcall <- call("plot.igraph", g1,
layout = glayout,
vertex.size = vertex.size,
vertex.label = vertex.label, vertex.label.dist = vertex.label.dist,
vertex.label.degree = vertex.label.pos,
vertex.label.family = vertex.label.family,
edge.curved = edge.curved, edge.width = edge.width,
edge.label = edge.label,
edge.label.family = edge.label.family,
xlim = xlim, ylim = ylim, rescale = rescale, main = main
layout = glayout,
vertex.size = vertex.size,
vertex.label = vertex.label, vertex.label.dist = vertex.label.dist,
vertex.label.degree = vertex.label.pos,
vertex.label.family = vertex.label.family,
edge.curved = edge.curved, edge.width = edge.width,
edge.label = edge.label,
edge.label.family = edge.label.family,
xlim = xlim, ylim = ylim, rescale = rescale, main = main
)
}
} else {
if (length(dots) > 0) {
plotcall <- as.call(c(list(plot.igraph, g1,
layout = glayout,
vertex.size = vertex.size,
vertex.label = vertex.label, vertex.label.dist = vertex.label.dist,
vertex.label.degree = vertex.label.pos,
vertex.label.family = vertex.label.family,
edge.curved = edge.curved, edge.width = edge.width,
edge.label = edge.label,
edge.label.family = edge.label.family, main = main
layout = glayout,
vertex.size = vertex.size,
vertex.label = vertex.label, vertex.label.dist = vertex.label.dist,
vertex.label.degree = vertex.label.pos,
vertex.label.family = vertex.label.family,
edge.curved = edge.curved, edge.width = edge.width,
edge.label = edge.label,
edge.label.family = edge.label.family, main = main
), dots))
} else {
plotcall <- call("plot.igraph", g1,
layout = glayout,
vertex.size = vertex.size,
vertex.label = vertex.label, vertex.label.dist = vertex.label.dist,
vertex.label.degree = vertex.label.pos,
vertex.label.family = vertex.label.family,
edge.curved = edge.curved, edge.width = edge.width,
edge.label = edge.label,
edge.label.family = edge.label.family, main = main
layout = glayout,
vertex.size = vertex.size,
vertex.label = vertex.label, vertex.label.dist = vertex.label.dist,
vertex.label.degree = vertex.label.pos,
vertex.label.family = vertex.label.family,
edge.curved = edge.curved, edge.width = edge.width,
edge.label = edge.label,
edge.label.family = edge.label.family, main = main
)
}
}
}


# Plotting legend
if (with.legend != FALSE && pie == TRUE) {
legendcall <- call("TraMineR::seqlegend",
legendcall <- call(
"seqlegend",
seqdata = x$observations, cpal = pie.colors.l, ltext = ltext,
position = legend.pos, cex = cex.legend, ncol = ncol.legend,
with.missing = FALSE
)
} else {
legendcall <- NULL
}


return(list(plotcall = plotcall, legendcall = legendcall))

# graphics::layout(1)
}
Loading

0 comments on commit 488b313

Please sign in to comment.