R上的CART Package — rpart [參數篇]

在 rpart model 中大概有幾個比較重要的參數:

  • weights: 用來給與data的weight,如果想加重某些data的權重時可使用。 (例如:Adaboost.M1 的演算法)
  • method:分成 “anova”、”poisson”、”class”和”exp”。
  • parms:splitting function的參數,會根據上面不同的方法給不同的參數。(例如:”anova”方法是不需要參數的)
  • control: rpart.control object

以上是rpart的參數部分,大部份都是集中在選擇model和data的權重上。一旦決定方法後,在model中的重要參數,大部分都是在control中用一個rpart.control的物件進行設定的。

接下來我們來介紹 rpart.control中的幾個重要參數:

  • minsplit:每一個node最少要幾個data
  • minbucket:在末端的node上最少要幾個data
  • cp:complexity parameter. (決定精度的參數)
  • maxdepth:Tree的深度

rpart.control中的參數很多,半筆者比較常用的大概是上面幾個。

(附帶一提,rpart.control的參數,也可以直接掛在rpart的最後面”…”的地方。)

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

接下來,我們將使用cats data調整不同的cp和minsplit:

將前文中的這行

cats_rpart_model <- rpart(Sex~., data = cats)

改成這行

cats_rpart_model <- rpart(Sex~., data = cats, minsplit=1, cp=1e-3)

我們可以比較一下兩次的結果:

很明顯看得出來,右邊雖然誤差可能比較小,但應該是已經overfit了。

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

接下來我們測試iris data,只採用iris中的Petal.Length和Petal.Length兩個維度的Data。

rm(list=ls(all=TRUE))
data(iris)
Iris_2D <- iris[,3:5]
plot(Iris_2D[1:2], pch = 21, bg = c("red", "green3", "blue")[unclass(Iris_2D$Species)])

library(rpart)
Iris_2D_rpart_model <- rpart(Species~., data = Iris_2D)
Iris_2D_rpart_pred <- predict(Iris_2D_rpart_model, Iris_2D)

Iris_2D_rpart_pred_ClassN <- apply( Iris_2D_rpart_pred,1,function(one_row) return(which(one_row == max(one_row))))
Iris_2D_rpart_pred_Class <- apply( Iris_2D_rpart_pred,1,function(one_row) return(colnames(Iris_2D_rpart_pred)[which(one_row == max(one_row))]))

Iris_2D_Class_temp <- unclass(Iris_2D$Species)
Iris_2D_Class <- attr(Iris_2D_Class_temp ,"levels")[Iris_2D_Class_temp]

table(Iris_2D_rpart_pred_Class,Iris_2D_Class)

plot(Iris_2D_rpart_model)
text(Iris_2D_rpart_model)

x1 <- seq(min(Iris_2D$Petal.Length), max(Iris_2D$Petal.Length), length = 50)
x2 <- seq(min(Iris_2D$Petal.Width), max(Iris_2D$Petal.Width), length = 50)
Feature_x1_to_x2 <- expand.grid(Petal.Length = x1, Petal.Width = x2)
Feature_x1_to_x2_Class <- apply(predict(Iris_2D_rpart_model,Feature_x1_to_x2),1,
	function(one_row) return(which(one_row == max(one_row))))

plot(Iris_2D[1:2], pch = 21, bg = c("red", "green3", "blue")[unclass(Iris_2D$Species)])
contour(x1,x2,matrix(Feature_x1_to_x2_Class,length(x1)),add = T, levels = c(1.5,2.5),labex = 0)

同樣的,我們可以把第7行的model換成下面這行在run一次

Iris_2D_rpart_model <- rpart(Species~., data = Iris_2D, minsplit=1, cp=1e-3)

我們會得到:

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

最後,在rpart中也有類似像svm中那個tune的設計。可以直接在圖上畫出各種cp所train出來的model的精度。如下:(延續使用上方的iris作為例子)

plotcp(Iris_2D_rpart_model)
printcp(Iris_2D_rpart_model)

會得到cp和error的一張圖,如下

並且會在R的shell中看到以下的訊息:

> printcp(Iris_2D_rpart_model)

Classification tree:
rpart(formula = Species ~ ., data = Iris_2D, minsplit = 1, cp = 0.001)

Variables actually used in tree construction:
[1] Petal.Length Petal.Width 

Root node error: 100/150 = 0.66667

n= 150 

     CP nsplit rel error xerror     xstd
1 0.500      0      1.00   1.16 0.051277
2 0.440      1      0.50   0.65 0.060690
3 0.020      2      0.06   0.07 0.025833
4 0.010      3      0.04   0.08 0.027520
5 0.001      6      0.01   0.07 0.025833

從圖上和訊息中,我們可以看出error隨著cp的縮小,而遞減的速度。

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

上一篇: R上的CART Package — rpart [入門篇]

下一篇:

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

This entry was posted in Classification and Regression, Data Mining / Machine Learning, R Programming. Bookmark the permalink.

Leave a comment