-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path2-Model.R
67 lines (49 loc) · 2.09 KB
/
2-Model.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#2-Model.R
#according to the tutorial, we will build 5 different models
#to predict species from flower measurements and we will
#set up a 10-fold cross validation to test the model and select
#the best one (most accurate)
#------- Set up cross validation
# Selection metric and resampling scheme shared by every train() call below.
metric <- "Accuracy"
control <- trainControl(
  method = "cv",
  number = 10
)
#we'll be splitting the dataset into 10 parts, training on 9 and
#testing on 1, and then repeating for all combinations of train-test splits.
#Also, the metric "Accuracy" will be used to evaluate the models. This
#is the number of correctly predicted instances divided by the total number
#of instances in the dataset
#------- Build Models
#We don't yet know which algorithms would be good on this problem, so
#we'll evaluate 5 different algorithms:
# - linear discriminant analysis (simple linear)
# - classification and regression trees (CART) (nonlinear)
# - k-Nearest Neighbours (KNN) (nonlinear)
# - Support Vector Machines (SVM) with a radial kernel (complex nonlinear)
# - Random Forest (RF) (complex nonlinear)
#a) linear algorithm
# The RNG is re-seeded with the same value before each fit.
# NOTE(review): presumably so every model is resampled on the same
# cross-validation folds - confirm.

# LDA
set.seed(7)
fit.lda <- train(
  Species ~ .,
  data = dataset,
  method = "lda",
  metric = metric,
  trControl = control
)

# b) nonlinear algorithms
# CART
set.seed(7)
fit.cart <- train(
  Species ~ .,
  data = dataset,
  method = "rpart",
  metric = metric,
  trControl = control
)

# kNN
set.seed(7)
fit.knn <- train(
  Species ~ .,
  data = dataset,
  method = "knn",
  metric = metric,
  trControl = control
)

# c) advanced algorithms
# SVM (caret method "svmRadial", i.e. a radial kernel)
set.seed(7)
fit.svm <- train(
  Species ~ .,
  data = dataset,
  method = "svmRadial",
  metric = metric,
  trControl = control
)

# Random Forest
set.seed(7)
fit.rf <- train(
  Species ~ .,
  data = dataset,
  method = "rf",
  metric = metric,
  trControl = control
)
#------- Select Best Model
#summarize accuracy of models
# Gather the resampling results of all five fits so they can be compared.
results <- resamples(list(
  lda = fit.lda,
  cart = fit.cart,
  knn = fit.knn,
  svm = fit.svm,
  rf = fit.rf
))
# Print explicitly (as print(fit.lda) below already does): top-level values
# are not auto-printed when the script is run through source(), so without
# print() the summary table and the lattice dotplot would not be shown.
print(summary(results))
print(dotplot(results)) # best model looks to be linear discriminant analysis
# summarize best model
print(fit.lda)
#------- Make Predictions
#estimate skill of LDA on the validation dataset
# Predict the species of the validation rows with the LDA fit and compare
# the predictions against the recorded species.
predictions <- predict(fit.lda, newdata = validation)
# Printed explicitly so the confusion matrix also appears when the script is
# run through source(), where top-level values are not auto-printed.
print(confusionMatrix(data = predictions, reference = validation$Species))