-
Notifications
You must be signed in to change notification settings - Fork 0
/
make_predictions.R
77 lines (73 loc) · 3.68 KB
/
make_predictions.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
library(ggplot2)
#' Plot Hmsc gradient predictions for every covariate of every saved model
#'
#' Loads `models_thin_<thin>_samples_<samples>_chains_<nChains>.Rdata` from
#' `modelDir`. The file must contain `models` (a list of fitted Hmsc models)
#' and may contain `modelnames` (used in plot titles). For each model and each
#' variable of its `XFormula`, it draws total-effect / marginal-effect pairs of:
#'   * the summed response (`measure = "S"`),
#'   * one example species (`measure = "Y"`): the most common species, or for
#'     probit models the species whose prevalence is closest to 0.5,
#'   * the community weighted mean of each trait 2..nt (`measure = "T"`).
#' Total effects use `constructGradient()` defaults; marginal effects pass
#' `non.focalVariables = 1`.
#'
#' @param modelDir Directory containing the saved models file.
#' @param samples,thin,nChains MCMC settings; used only to build file names.
#' @param panelsDir Output directory for the PDF; used only when `toPdf = TRUE`.
#' @param toPdf If `TRUE`, write all panels to
#'   `predictions_thin_<thin>_samples_<samples>_chains_<nChains>.pdf` in
#'   `panelsDir`, skipping everything when that file already exists.
#'   If `FALSE` (default), draw on the current graphics device.
#' @return `NULL`, invisibly; called for its plotting side effects.
make_predictions <- function(modelDir, samples, thin, nChains, panelsDir,
                             toPdf = FALSE) {
  runTag <- paste0("_thin_", thin, "_samples_", samples, "_chains_", nChains)

  if (toPdf) {
    pdfFile <- file.path(panelsDir, paste0("predictions", runTag, ".pdf"))
    if (file.exists(pdfFile)) {
      message("File already exists -- Skipping: ", pdfFile)
      return(invisible(NULL))
    }
  }

  modelFile <- file.path(modelDir, paste0("models", runTag, ".Rdata"))
  if (!file.exists(modelFile)) {
    stop("Model file not found: ", modelFile, call. = FALSE)
  }
  # Load into a private environment so objects stored in the file cannot
  # overwrite this function's local variables.
  saved <- new.env(parent = emptyenv())
  load(modelFile, envir = saved)
  if (!exists("models", envir = saved, inherits = FALSE)) {
    stop("No object 'models' found in ", modelFile, call. = FALSE)
  }
  models <- saved$models
  modelnames <- if (exists("modelnames", envir = saved, inherits = FALSE)) {
    saved$modelnames
  } else {
    names(models)
  }
  if (is.null(modelnames)) {
    modelnames <- paste("model", seq_along(models))
  }
  if (length(models) == 0L) {
    return(invisible(NULL))
  }

  if (toPdf) {
    pdf(file = pdfFile)
    on.exit(dev.off(), add = TRUE)
  } else {
    # Panels below switch the layout with par(mfrow); restore the caller's.
    oldPar <- par(mfrow = c(2, 1))
    on.exit(par(oldPar), add = TRUE)
  }

  # Draw one panel. When plotGradient returns a ggplot object instead of
  # drawing with base graphics, print it explicitly: auto-printing does not
  # happen inside a function.
  drawPanel <- function(model, gradient, prediction, title, ...) {
    pl <- plotGradient(model, gradient, predY = prediction, showData = TRUE,
                       main = title, ...)
    if (inherits(pl, "ggplot")) {
      print(pl + labs(title = title))
    }
    invisible(NULL)
  }

  # Draw the total- and marginal-effect versions of a panel on a fresh
  # two-row page.
  drawPair <- function(model, grads, preds, label, ...) {
    par(mfrow = c(2, 1))
    drawPanel(model, grads$total, preds$total,
              paste0(label, " (total effect)"), ...)
    drawPanel(model, grads$marginal, preds$marginal,
              paste0(label, " (marginal effect)"), ...)
  }

  for (j in seq_along(models)) {
    m <- models[[j]]
    name <- modelnames[[j]]
    isProbit <- m$distr[1, 1] == 2
    speciesMeans <- colMeans(m$Y, na.rm = TRUE)
    exSp <- if (isProbit) {
      which.min(abs(speciesMeans - 0.5))
    } else {
      which.max(speciesMeans)
    }
    yRange <- if (isProbit) c(-0.1, 1.1) else 0

    for (covariate in all.vars(m$XFormula)) {
      grads <- list(
        total = constructGradient(m, focalVariable = covariate),
        marginal = constructGradient(m, focalVariable = covariate,
                                     non.focalVariables = 1)
      )
      preds <- lapply(grads, function(g) {
        predict(m, Gradient = g, expected = TRUE)
      })

      drawPair(m, grads, preds, paste0(name, ": summed response"),
               measure = "S", yshow = 0)
      drawPair(m, grads, preds, paste0(name, ": example species"),
               measure = "Y", index = exSp, yshow = yRange)
      # Trait 1 is skipped (presumably the intercept -- confirm against
      # the model's trait formula).
      if (m$nt > 1) {
        for (l in 2:m$nt) {
          drawPair(m, grads, preds,
                   paste0(name, ": community weighted mean trait"),
                   measure = "T", index = l, yshow = 0)
        }
      }
    }
  }
  invisible(NULL)
}