【学习笔记】The Analytics Edge 第四周:分类回归树与随机森林


  • cid:41:privileges:topics:read

    分类回归树与随机森林

    本周的课程以预测美国最高法院的判决为例展示了如何在R中使用分类回归树以及随机森林算法。数据集为1994年至2001年间的大法官Justice Steven对于下级法院的审判结果是否驳回(1=Reverse 0=affirm),以及案件相关的属性。属性包括:
    Circuit court of origin(1-11th, DC,FED)
    Issue area(civil right, taxation)
    Type of petitioner, type of respondant (US, an employer)
    Ideological direction of lower court decision(conservative or liberal)
    Whether petitioner argued that a law/practice was unconstitutional

    数据集下载地址:https://d37djvu3ytnwxt.cloudfront.net/assets/courseware/v1/80dfa0ae5eb3fa2013a34507cd58fabb/asset-v1:MITx+15.071x_3+1T2016+type@asset+block/stevens.csv

    这个问题可以使用逻辑回归来预测,但是这样会难以解释模型中的哪个因子对结果的影响更大,并且不便于对新的案例进行预测。因此使用分类回归树(CART)这个更容易解释的模型。

    一个简单的CART例子如下:
    0_1474856179329_cart1.jpg
    0_1474856278295_cart2.jpg

    在树上的每一个节点处做出判断选择子节点,直到叶子节点为止:
    如果x<60则判定为红色(split 1)
    否则,如果y>=20则判定为灰色(split 2)
    否则,如果x<85则判定为红色(split 3)
    否则,判定为灰色

    ps: R中使用minbucket参数来控制分枝数量。

    Classification And Regression Trees
    首先导入数据

    stevens = read.csv(“stevens.csv”)
    str(stevens)

    接着讲数据划分为训练集和测试集

    library(caTools)
    set.seed(3000)
    spl = sample.split(stevens$Reverse, SplitRatio = 0.7)
    Train = subset(stevens, spl==TRUE)
    Test = subset(stevens, spl==FALSE)

    导入包

    install.packages(“rpart”)
    library(rpart)
    install.packages(“rpart.plot”)
    library(rpart.plot)

    建立模型, method=”class”申明建立的模型是分类树,minbucket限制了树的大小以防止过拟合

    StevensTree = rpart(Reverse ~ Circuit + Issue + Petitioner + Respondent + LowerCourt + Unconst, data = Train, method=“class”, minbucket=25)

    打印出分类树

    prp(StevensTree)
    0_1474856303672_cart3.jpg

    使用建立好的模型做出预测、并计算预测准确率

    PredictCART = predict(StevensTree, newdata = Test, type = “class”)
    table(Test$Reverse, PredictCART)
    PredictCART
    0 1
    0 41 36
    1 22 71

    (41+71)/(41+36+22+71)
    [1] 0.6588235

    最后使用ROC来评估一下模型

    library(ROCR)
    PredictROC = predict(StevensTree, newdata = Test)
    pred = prediction(PredictROC[,2], Test$Reverse)
    perf = performance(pred, “tpr”, “fpr”)
    plot(perf)
    0_1474856325541_cart4.jpg

    Random Forests
    随机森林(Random Forests)可以被用来提高CART预测的准确性,副作用是让模型不再容易被人理解。随机森林会从数据集中随机选取数据来建立多棵树。由于数据可以被重复选取,所以虽然取自同一个数据集,但是建立的多棵树是不同的。
    随机森林有两个变量需要人为选择:
    nodesize:一个数据子集所需的数据个数的下限(minbucket)
    Ntree:需要多少棵树

    导入包

    install.packages(“randomForest”)
    library(randomForest)

    将reverse变量转换为random forest可以接受的factor类型变量:

    Train$Reverse = as.factor(Train$Reverse)
    Test$Reverse = as.factor(Test$Reverse)

    训练模型并作出预测

    StevensForest = randomForest(Reverse ~ Circuit + Issue + Petitioner + Respondent + LowerCourt + Unconst, data = Train, ntree=200, nodesize=25 )
    PredictForest = predict(StevensForest, newdata = Test)

    计算一下预测的准确度

    table(Test$Reverse, PredictForest)
    PredictForest
    0 1
    0 40 37
    1 19 74
    (40+74)/(40+37+19+74)
    [1] 0.6705882

    使用随机森林的时候需要人为决定两个变量的值。为了选取这两个值,我们需要交叉验证(k-fold cross-validation)。具体来说,将一个训练数据集分成k份(比如5份),交替选取其中的k-1份(4份)训练模型,剩下的一份测试模型。记录每一种数据集组合下变量不同取值最后得到的预测结果,取均值后得到一条预测结果随变量变化的曲线。最后根据这条曲线来决定变量的取值。

    0_1474856347650_cart5.jpg

    0_1474856362447_cart6.jpg

    0_1474856376349_cart7.jpg

    除此以外,随机森林中还有一个变量cp (Complexity Parameter)来衡量模型的好坏(overfitting vs. underfitting)。

    安装交叉验证所需的包

    install.packages(“caret”)
    library(caret)
    install.packages(“e1071”)
    library(e1071)

    决定交叉验证的k值

    numFolds = trainControl( method = “cv”, number = 10 )

    决定测试cp的范围

    cpGrid = expand.grid( .cp = seq(0.01,0.5,0.01))

    交叉验证

    train(Reverse ~ Circuit + Issue + Petitioner + Respondent + LowerCourt + Unconst, data = Train, method = “rpart”, trControl = numFolds, tuneGrid = cpGrid )
    0_1474856431781_cart8.jpg
    0_1474856443944_cart9.jpg

    最终结果显示为 cp = 0.18, 因此我们使用0.18来建立新的CART

    StevensTreeCV = rpart(Reverse ~ Circuit + Issue + Petitioner + Respondent + LowerCourt + Unconst, data = Train, method=“class”, cp = 0.18)

    计算下准确率,发现上升到70%了

    PredictCV = predict(StevensTreeCV, newdata = Test, type = “class”)
    table(Test$Reverse, PredictCV)
    PredictCV
    0 1
    0 59 18
    1 29 64
    (59+64)/(59+18+29+64)
    [1] 0.7235294



  • 确实是很实战的课程~


登录后回复
 

与 BitTiger Community 的连接断开,我们正在尝试重连,请耐心等待