找回密码
 立即注册

QQ登录

只需一步,快速开始

搜索
查看: 1742|回复: 0
打印 上一主题 下一主题
收起左侧

Python随机森林例子 源码分享

[复制链接]
跳转到指定楼层
楼主
  1.     "#测试gini\n",
  2.     "gini=calGini((l,r),classLabels)\n",
  3.     "print(gini)\n"
  4.    ]
  5.   },
  6.   {
  7.    "cell_type": "code",
  8.    "execution_count": 19,
  9.    "metadata": {},
  10.    "outputs": [],
  11.    "source": [
  12.     "def getBestSplit(dataSet,featureNumbers):\n",
  13.     "    '''\n",
  14.     "    对于一个数据集,选择featureNumber个特征进行简单划分,得到最好的特征和划分结果\n",
  15.     "    args:\n",
  16.     "      dataSet:数据集,类型:list\n",
  17.     "      featureNumbers:选择的特征值数,类型:int\n",
  18.     "      classLabels:所有分类,类型:list\n",
  19.     "    ''' \n",
  20.     "    \n",
  21.     "    #样本数\n",
  22.     "    m=len(dataSet)\n",
  23.     "    if m==0:\n",
  24.     "        return None\n",
  25.     "    #样本特征值数+1(因为最后有一个标签)\n",
  26.     "    totalColumnNumber=len(dataSet[0])\n",
  27.     "    #随机选择的特征索引\n",
  28.     "    randomSelectedFeatures=[]\n",
  29.     "    \n",
  30.     "    \n",
  31.     "    \n",
  32.     "    #选择数目必须在特征数目范围内\n",
  33.     "    if totalColumnNumber-1>=featureNumbers:        \n",
  34.     "        #借助这个变量防止选择重复的特征进入\n",
  35.     "        indexList=list(range(totalColumnNumber-1))            \n",
  36.     "        for j in range(featureNumbers):\n",
  37.     "            #索引序列长度\n",
  38.     "            leftSize=len(indexList)\n",
  39.     "            #随机数\n",
  40.     "            randIndex=random.randrange(leftSize)\n",
  41.     "            #索引学列随机数处数据弹出,放入选择特征列表\n",
  42.     "            origIndex=indexList.pop(randIndex)\n",
  43.     "            #存入的是原始数据特征索引\n",
  44.     "            randomSelectedFeatures.append(origIndex)\n",
  45.     "    else:\n",
  46.     "        randomSelectedFeatures=range(totalColumnNumber-1)#特征全部被选择\n",
  47.     "    \n",
  48.     "    \n",
  49.     "   # print(\"current select features\")\n",
  50.     "   # print(randomSelectedFeatures)\n",
  51.     "\n",
  52.     "    #当前数据集的标签序列\n",
  53.     "    class_values=list(set(item[-1] for item in dataSet))\n",
  54.     "    \n",
  55.     "    #对于每个特征以及每个特征值进行简单划分\n",
  56.     "    #保留最小的基尼系数\n",
  57.     "    minGini=9999\n",
  58.     "    #存入最好的信息\n",
  59.     "    bestInfor={}\n",
  60.     "    #外层循环,对于每个特征\n",
  61.     "    for index in randomSelectedFeatures:\n",
  62.     "        #内层循环对于每个特征值\n",
  63.     "        tempFeatureValueList=list(set(item[index] for item in dataSet))\n",
  64.     "        #print(len(tempFeatureValueList))\n",
  65.     "        for tempValue in tempFeatureValueList:\n",
  66.     "            #简单分类\n",
  67.     "            groups=simpleSplit(dataSet,index,tempValue)            \n",
  68.     "            #print(\"currentIndex:%d,CurrentTempValue:%f\"%(index,tempValue))\n",
  69.     "            #计算基尼系数\n",
  70.     "            gini=calGini(groups,class_values)\n",
  71.     "            #print(\"computed gini:\",gini)            \n",
  72.     "            if gini<minGini:\n",
  73.     "                minGini=gini\n",
  74.     "                #保存目前最后的信息\n",
  75.     "                bestInfor[\"index\"]=index#存入原来索引                \n",
  76.     "                bestInfor[\"indexValue\"]=tempValue\n",
  77.     "                bestInfor[\"groups\"]=groups\n",
  78.     "                bestInfor[\"gini\"]=gini\n",
  79.     "                \n",
  80.     "    return bestInfor"
  81.    ]
  82.   },
  83.   {
  84.    "cell_type": "code",
  85.    "execution_count": 20,
  86.    "metadata": {},
  87.    "outputs": [
  88.     {
  89.      "name": "stdout",
  90.      "output_type": "stream",
  91.      "text": [
  92.       "52 0.017\n"
  93.      ]
  94.     }
  95.    ],
  96.    "source": [
  97.     "#测试最好分类函数\n",
  98.     "bestInfor=getBestSplit(dataSet,3)\n",
  99.     "print(bestInfor[\"index\"],bestInfor[\"indexValue\"])"
  100.    ]
  101.   },
  102.   {
  103.    "cell_type": "code",
  104.    "execution_count": 21,
  105.    "metadata": {},
  106.    "outputs": [],
  107.    "source": [
  108.     "def terminalLabel(subSet):\n",
  109.     "    '''\n",
  110.     "    树叶点对应的标签\n",
  111.     "    args:\n",
  112.     "      subSet:当前数据集,最后列是标签列,类型:list\n",
  113.     "    returns:\n",
  114.     "      当前列中最多的标签,类型:原标签类型\n",
  115.     "    '''\n",
  116.     "    #得到最后一列\n",
  117.     "    labelList=[item[-1] for item in subSet]\n",
  118.     "    #max函数,key后是函数,代表对前面的进行那种运算,这里是技术\n",
  119.     "    #max返回值是第一个参数,这里set是把labelList转换成集合,即去掉重复项\n",
  120.     "    #key:相当于循环调用labelList.count(set(labelList))中的每个元素,然后max取得最大值\n",
  121.     "    #返回set(labelList)中对应最大的那个标签\n",
  122.     "    return max(set(labelList), key=labelList.count)   # 输出 subSet 中出现次数较多的标签 \n",
  123.     "\n",
  124.     "    #下面的写法也是成立的,利用lambda表达式,表达式中x从全面取,这种写法可能更好理解些\n",
  125.     "    #return max(set(labelList), key=lambda x:labelList.count(x)) "
  126.    ]
  127.   },
  128.   {
  129.    "cell_type": "code",
  130.    "execution_count": 22,
  131.    "metadata": {},
  132.    "outputs": [
  133.     {
  134.      "name": "stdout",
  135.      "output_type": "stream",
  136.      "text": [
  137.       "R\n"
  138.      ]
  139.     }
  140.    ],
  141.    "source": [
  142.     "#测试\n",
  143.     "label=terminalLabel(l)\n",
  144.     "print(label)"
  145.    ]
  146.   },
  147.   {
  148.    "cell_type": "code",
  149.    "execution_count": 23,
  150.    "metadata": {},
  151.    "outputs": [],
  152.    "source": [
  153.     "#对得到的最好分类信息进行分割\n",
  154.     "def split(node, max_depth, min_size, n_features, depth):  # 创建子分割器 递归分类 直到分类结束\n",
  155.     "    '''\n",
  156.     "    :param node:        节点,类型:字典\n",
  157.     "                    bestInfor[\"index\"]=index#存入原来索引                \n",
  158.     "                    bestInfor[\"indexValue\"]=tempValue\n",
  159.     "                    bestInfor[\"groups\"]=groups\n",
  160.     "                    bestInfor[\"gini\"]=gini\n",
  161.     "    :param max_depth:   最大深度,int\n",
  162.     "    :param min_size:    最小,int\n",
  163.     "    :param n_features:  特征选取个数,int\n",
  164.     "    :param depth:       深度,int\n",
  165.     "    :return:\n",
  166.     "    '''\n",
  167.     "    left, right = node['groups']\n",
  168.     "    del (node['groups'])\n",
  169.     "\n",
  170.     "    if not left or not right:  # 如果只有一个子集\n",
  171.     "        node['left'] = node['right'] = terminalLabel(left + right)  # 投票出类型\n",
  172.     "        return\n",
  173.     "\n",
  174.     "    if depth >= max_depth:  # 如果即将超过\n",
  175.     "        node['left'], node['right'] = terminalLabel(left), terminalLabel(right)  # 投票出类型\n",
  176.     "        return\n",
  177.     "\n",
  178.     "    if len(left) <= min_size:  # 处理左子集\n",
  179.     "        node['left'] = terminalLabel(left)\n",
  180.     "    else:\n",
  181.     "        node['left'] = getBestSplit(left, n_features)  # node['left']是一个字典,形式为{'index':b_index, 'value':b_value, 'groups':b_groups},所以node是一个多层字典\n",
  182.     "        split(node['left'], max_depth, min_size, n_features, depth + 1)  # 递归,depth+1计算递归层数\n",
  183.     "\n",
  184.     "    if len(right) <= min_size:  # 处理右子集\n",
  185.     "        node['right'] = terminalLabel(right)\n",
  186.     "    else:\n",
  187.     "        node['right'] = getBestSplit(right, n_features)\n",
  188.     "        split(node['right'], max_depth, min_size, n_features, depth + 1)\n",
  189.     "        "
  190.    ]
  191.   },
  192.   {
  193.    "cell_type": "code",
  194.    "execution_count": 24,
  195.    "metadata": {},
  196.    "outputs": [],
  197.    "source": [
  198.     "#构建一个决策树\n",
  199.     "def buildTree(train, max_depth, min_size, n_features):\n",
  200.     "    '''\n",
  201.     "    创建一个决策树\n",
  202.     "    :param train:       训练数据集\n",
  203.     "    :param max_depth:   决策树深度不能太深 不然容易导致过拟合\n",
  204.     "    :param min_size:    叶子节点的大小\n",
  205.     "    :param n_features:  选择的特征的个数\n",
  206.     "    :return\n",
  207.     "        root    返回决策树\n",
  208.     "    '''\n",
  209.     "    root = getBestSplit(train, n_features)  # 获取样本数据集\n",
  210.     "    split(root, max_depth, min_size, n_features, 1)  # 进行样本分割,构架决策树\n",
  211.     "    return root  # 返回决策树\n"
  212.    ]
  213.   },
  214.   {
  215.    "cell_type": "code",
  216.    "execution_count": 25,
  217.    "metadata": {},
  218.    "outputs": [
  219.     {
  220.      "name": "stdout",
  221.      "output_type": "stream",
  222.      "text": [
  223.       "{'index': 55, 'indexValue': 0.0114, 'gini': 0.0, 'left': {'index': 35, 'indexValue': 0.2288, 'gini': 0.0, 'left': 'R', 'right': {'index': 33, 'indexValue': 0.2907, 'gini': 0.0, 'left': 'R', 'right': {'index': 58, 'indexValue': 0.0057, 'gini': 0.0, 'left': {'index': 12, 'indexValue': 0.0493, 'gini': 0.0, 'left': 'R', 'right': 'R'}, 'right': 'R'}}}, 'right': {'index': 54, 'indexValue': 0.0063, 'gini': 0.0, 'left': {'index': 21, 'indexValue': 0.8384, 'gini': 0.0, 'left': 'M', 'right': 'M'}, 'right': {'index': 32, 'indexValue': 0.558, 'gini': 0.0, 'left': 'M', 'right': {'index': 58, 'indexValue': 0.0332, 'gini': 0.0, 'left': 'M', 'right': 'M'}}}}\n"
  224.      ]
  225.     }
  226.    ],
  227.    "source": [
  228.     "#测试决策树\n",
  229.     "#选择一个子集\n",
  230.     "s=putBackSample(dataSet,10)\n",
  231.     "tempTree=buildTree(s,10,1,3)\n",
  232.     "print(tempTree)"
  233.    ]
  234.   },
  235.   {
  236.    "cell_type": "code",
  237.    "execution_count": 26,
  238.    "metadata": {},
  239.    "outputs": [],
  240.    "source": [
  241.     "#根据决策树进行预测\n",
  242.     "def predict(node, row):   # 预测模型分类结果\n",
  243.     "    '''\n",
  244.     "    在当前节点进行预测,row是待预测样本\n",
  245.     "    args:\n",
  246.     "       node:树节点\n",
  247.     "       row:待分类样本\n",
  248.     "    return:\n",
  249.     "       分类标签\n",
  250.     "    '''\n",
  251.     "    if row[node['index']] < node['indexValue']:\n",
  252.     "        if isinstance(node['left'], dict):       # isinstance 是 Python 中的一个内建函数。是用来判断一个对象是否是一个已知的类型。\n",
  253.     "            return predict(node['left'], row)\n",
  254.     "        else:\n",
  255.     "            return node['left']\n",
  256.     "    else:\n",
  257.     "        if isinstance(node['right'], dict):\n",
  258.     "            return predict(node['right'], row)\n",
  259.     "        else:\n",
  260.     "            return node['right']"
  261.    ]
  262.   },
  263.   {
  264.    "cell_type": "code",
  265.    "execution_count": 27,
  266.    "metadata": {},
  267.    "outputs": [
  268.     {
  269.      "name": "stdout",
  270.      "output_type": "stream",
  271.      "text": [
  272.       "R R\n"
  273.      ]
  274.     }
  275.    ],
  276.    "source": [
  277.     "#测试下\n",
  278.     "label=predict(tempTree,s[0])\n",
  279.     "print(label,s[0][-1])"
  280.    ]
  281.   },
  282.   {
  283.    "cell_type": "code",
  284.    "execution_count": 28,
  285.    "metadata": {},
  286.    "outputs": [],
  287.    "source": [
  288.     "#多个树的决策,多数服从少数\n",
  289.     "def baggingPredict(trees, row):\n",
  290.     "    \"\"\"\n",
  291.     "    多个树的决策,多数服从少数\n",
  292.     "    Args:\n",
  293.     "        trees           决策树的集合\n",
  294.     "        row             测试数据集的每一行数据\n",
  295.     "    Returns:\n",
  296.     "        返回随机森林中,决策树结果出现次数做大的\n",
  297.     "    \"\"\"\n",
  298.     "\n",
  299.     "    # 使用多个决策树trees对测试集test的第row行进行预测,再使用简单投票法判断出该行所属分类\n",
  300.     "    predictions = [predict(tree, row) for tree in trees]\n",
  301.     "    return max(set(predictions), key=predictions.count)\n"
  302.    ]
  303.   },
  304.   {
  305.    "cell_type": "code",
  306.    "execution_count": 29,
  307.    "metadata": {},
  308.    "outputs": [],
  309.    "source": [
  310.     "def subSample(dataSet, ratio):  \n",
  311.     "    '''\n",
  312.     "    按比例随机抽取数据,有重复抽样\n",
  313.     "    args:\n",
  314.     "      dataSet:数据集,类型:list\n",
  315.     "      ratio:0-1之间的数\n",
  316.     "    '''\n",
  317.     "    if ratio<0.0:\n",
  318.     "        return None\n",
  319.     "    if ratio>=1:\n",
  320.     "        return dataSet\n",
  321.     "    sampleNumber=int(len(dataSet)*ratio)\n",
  322.     "    subSet=putBackSample(dataSet,sampleNumber)\n",
  323.     "    return subSet"
  324.    ]
  325.   },
  326.   {
  327.    "cell_type": "code",
  328.    "execution_count": 30,
  329.    "metadata": {},
  330.    "outputs": [
  331.     {
  332.      "name": "stdout",
  333.      "output_type": "stream",
  334.      "text": [
  335.       "41\n"
  336.      ]
  337.     }
  338.    ],
  339.    "source": [
  340.     "#测试\n",
  341.     "subSet=subSample(dataSet,0.2)\n",
  342.     "print(len(subSet))"
  343.    ]
  344.   },
  345.   {
  346.    "cell_type": "code",
  347.    "execution_count": 31,
  348.    "metadata": {},
  349.    "outputs": [],
  350.    "source": [
  351.     "#随机森林主函数\n",
  352.     "def buildRandomForest(train, max_depth=10, min_size=1, sample_size=0.2, n_trees=10, n_features=3):\n",
  353.     "    \"\"\"\n",
  354.     "    random_forest(评估算法性能,返回模型得分)\n",
  355.     "    Args:\n",
  356.     "        train           训练数据集,类型:list        \n",
  357.     "        max_depth       决策树深度不能太深,不然容易导致过拟合\n",
  358.     "        min_size        叶子节点的大小\n",
  359.     "        sample_size     训练数据集的样本比例,0,1之间的数\n",
  360.     "        n_trees         决策树的个数\n",
  361.     "        n_features      选取的特征的个数\n",
  362.     "    Returns:\n",
  363.     "        trees:树序列\n",
  364.     "    \"\"\"\n",
  365.     "\n",
  366.     "    trees = list()\n",
  367.     "    # n_trees 表示决策树的数量\n",
  368.     "    for i in range(n_trees):\n",
  369.     "        # 随机抽样的训练样本, 随机采样保证了每棵决策树训练集的差异性\n",
  370.     "        sample = subSample(train, sample_size)\n",
  371.     "        # 创建一个决策树\n",
  372.     "        tree = buildTree(sample, max_depth, min_size, n_features)\n",
  373.     "        trees.append(tree)\n",
  374.     "    return trees\n",
  375.     "  \n"
  376.    ]
  377.   },
  378.   {
  379.    "cell_type": "code",
  380.    "execution_count": 32,
  381.    "metadata": {},
  382.    "outputs": [],
  383.    "source": [
  384.     "def predictByForest(trees,test):\n",
  385.     "    '''\n",
  386.     "    predictions     每一行的预测结果,bagging 预测最后的分类结果\n",
  387.     "    '''\n",
  388.     "    # 每一行的预测结果,bagging 预测最后的分类结果\n",
  389.     "    predictions = [baggingPredict(trees, row) for row in test]\n",
  390.     "    return predictions"
  391.    ]
  392.   },
  393.   {
  394.    "cell_type": "code",
  395.    "execution_count": 33,
  396.    "metadata": {},
  397.    "outputs": [],
  398.    "source": [
  399.     "def calQuota(predictions,labelClass,OrigClassLabels):\n",
  400.     "    '''\n",
  401.     "    计算分类指标\n",
  402.     "    args:\n",
  403.     "      predictions:预测值,类型:list\n",
  404.     "      labelClass:真实标签,类型:list\n",
  405.     "      OrigClassLabels:数据可能的标签库,一个正例一个负例标签\n",
  406.     "    '''\n",
  407.     "    \n",
  408.     "    Pos=OrigClassLabels[0]\n",
  409.     "    Nev=OrigClassLabels[1]    \n",
  410.     "    #真正例   \n",
  411.     "    #TP=len([item for item in labelClass if item==Pos and predictions[labelClass.index(item)]==Pos])\n",
  412.     "    TP=0\n",
  413.     "    TN=0\n",
  414.     "    FP=0\n",
  415.     "    FN=0\n",
  416.     "    for j in range(len(predictions)):        \n",
  417.     "        if predictions[j]==Pos and  labelClass[j]==Pos:\n",
  418.     "            TP+=1\n",
  419.     "        if predictions[j]==Nev and  labelClass[j]==Nev:\n",
  420.     "            TN+=1\n",
  421.     "        if predictions[j]==Pos and  labelClass[j]==Nev:\n",
  422.     "            FP+=1\n",
  423.     "        if predictions[j]==Nev and  labelClass[j]==Pos:\n",
  424.     "            FN+=1\n",
  425.     "#     #真负例,下面的做法不行,原因是index可能得到不同的索引\n",
  426.     "#     TN=len([item for item in labelClass if item==Nev and predictions[labelClass.index(item)]==Nev])\n",
  427.     "#     #伪正例\n",
  428.     "#     FP=len([item for item in labelClass if item==Nev and predictions[labelClass.index(item)]==Pos])\n",
  429.     "#     #伪负例\n",
  430.     "#     FN=len([item for item in labelClass if item==Pos and predictions[labelClass.index(item)]==Nev])\n",
  431.     "\n",
  432.     "    #Recall,TruePosProp=TP/(TP+FN)#识别的正例占整个正例的比率\n",
  433.     "    #FalsPosProp=FP/(FP+TN)#识别的正例占整个负例的比率\n",
  434.     "    #Precition=TP/(TP+FP)#识别的正确正例占识别出所有正例的比率\n",
  435.     "    \n",
  436.     "    return TP,TN,FP,FN"
  437.    ]
  438.   },
  439.   {
  440.    "cell_type": "code",
  441.    "execution_count": 34,
  442.    "metadata": {},
  443.    "outputs": [],
  444.    "source": [
  445.     "#测试下:\n",
  446.     "trees=buildRandomForest(dataSet)\n",
  447.     "testSet=nonPutBackSample(dataSet,100)\n",
  448.     "prediction=predictByForest(trees,testSet)\n"
  449.    ]
  450.   },
  451.   {
  452.    "cell_type": "code",
  453.    "execution_count": 35,
  454.    "metadata": {},
  455.    "outputs": [
  456.     {
  457.      "name": "stdout",
  458.      "output_type": "stream",
  459.      "text": [
  460.       "(44, 39, 12, 5)\n"
  461.      ]
  462.     }
  463.    ],
  464.    "source": [
  465.     "labelClass=[item[-1] for item in testSet]\n",
  466.     "\n",
  467.     "tp=calQuota(prediction,labelClass,list(classLabels))\n",
  468.     "print(tp)"
  469.    ]
  470.   },
  471.   {
  472.    "cell_type": "code",
  473.    "execution_count": 36,
  474.    "metadata": {},
  475.    "outputs": [],
  476.    "source": [
  477.     "def accuracy( predicted,actual):  \n",
  478.     "    correct = 0\n",
  479.     "    for i in range(len(actual)):\n",
  480.     "        if actual[i] == predicted[i]:\n",
  481.     "            correct += 1\n",
  482.     "    return correct / float(len(actual)) * 100.0\n"
  483.    ]
  484.   },
  485.   {
  486.    "cell_type": "code",
  487.    "execution_count": 37,
  488.    "metadata": {},
  489.    "outputs": [
  490.     {
  491.      "name": "stdout",
  492.      "output_type": "stream",
  493.      "text": [
  494.       "83.0\n"
  495.      ]
  496.     }
  497.    ],
  498.    "source": [
  499.     "a=accuracy(prediction,labelClass)\n",
  500.     "print(a)"
  501.    ]
  502.   },
  503.   {
  504.    "cell_type": "code",
  505.    "execution_count": 38,
  506.    "metadata": {},
  507.    "outputs": [],
  508.    "source": [
  509.     "def createCrossValideSets(trainSet,n_folds,bPutBack=True):\n",
  510.     "    '''\n",
  511.     "    产生交叉验证数据集\n",
  512.     "    Args:\n",
  513.     "        dataset     原始数据集       \n",
  514.     "        n_folds     数据的份数,数据集交叉验证的份数,采用无放回抽取\n",
  515.     "        bPutBack    是否放回\n",
  516.     "    '''\n",
  517.     "    subSetsList=[]\n",
  518.     "    subLen=int(len(trainSet)/n_folds)\n",
  519.     "    if bPutBack:\n",
  520.     "        for j in range(n_folds):\n",
  521.     "            subSet=putBackSample(trainSet,subLen)\n",
  522.     "            subSetsList.append(subSet)\n",
  523.     "    else:\n",
  524.     "        for j in range(n_folds):\n",
  525.     "            subSet=nonPutBackSample(trainSet,subLen)\n",
  526.     "            subSetsList.append(subSet)\n",
  527.     "    return subSetsList"
  528.    ]
  529.   },
  530.   {
  531.    "cell_type": "code",
  532.    "execution_count": 39,
  533.    "metadata": {},
  534.    "outputs": [],
  535.    "source": [
  536.     "def randomForest(trainSet,testSet,max_depth=10, min_size=1, sample_size=0.2, n_trees=10, n_features=3):\n",
  537.     "    '''\n",
  538.     "    构造随机森林并测试\n",
  539.     "     Args:\n",
  540.     "        train           训练数据集,类型:list        \n",
  541.     "        testSet         测试集,类型:list\n",
  542.     "        max_depth       决策树深度不能太深,不然容易导致过拟合\n",
  543.     "        min_size        叶子节点的大小\n",
  544.     "        sample_size     训练数据集的样本比例,0,1之间的数\n",
  545.     "        n_trees         决策树的个数\n",
  546.     "        n_features      选取的特征的个数\n",
  547.     "    Returns:\n",
  548.     "        predition       测试集预测值,类型:list\n",
  549.     "    '''\n",
  550.     "    trees=buildRandomForest(trainSet,max_depth, min_size, sample_size, n_trees, n_features)\n",
  551.     "    predition=predictByForest(trees,testSet)\n",
  552.     "    return predition"
  553.    ]
  554.   },
  555.   {
  556.    "cell_type": "code",
  557.    "execution_count": 40,
  558.    "metadata": {},
  559.    "outputs": [],
  560.    "source": [
  561.     "def evaluteAlgorithm(trainSet,algorithm,n_folds,*args):\n",
  562.     "    '''\n",
  563.     "    评价算法函数\n",
  564.     "     Args:\n",
  565.     "        dataset     原始数据集\n",
  566.     "        algorithm   使用的算法\n",
  567.     "        n_folds     数据的份数,数据集交叉验证的份数,采用无放回抽取\n",
  568.     "        *args       其他的参数\n",
  569.     "    Returns:\n",
  570.     "        scores      模型得分\n",
  571.     "    '''\n",
  572.     "    folds = createCrossValideSets(trainSet, n_folds)\n",
  573.     "    scores = list()\n",
  574.     "    # 每次循环从 folds 从取出一个 fold 作为测试集,其余作为训练集,遍历整个 folds ,实现交叉验证\n",
  575.     "    for fold in folds:\n",
  576.     "        train_set = list(folds)\n",
  577.     "        train_set.remove(fold)\n",
  578.     "        # 将多个 fold 列表组合成一个 train_set 列表, 类似 union all\n",
  579.     "        \"\"\"\n",
  580.     "        In [20]: l1=[[1, 2, 'a'], [11, 22, 'b']]\n",
  581.     "        In [21]: l2=[[3, 4, 'c'], [33, 44, 'd']]\n",
  582.     "        In [22]: l=[]\n",
  583.     "        In [23]: l.append(l1)\n",
  584.     "        In [24]: l.append(l2)\n",
  585.     "        In [25]: l\n",
  586.     "        Out[25]: [[[1, 2, 'a'], [11, 22, 'b']], [[3, 4, 'c'], [33, 44, 'd']]]\n",
  587.     "        In [26]: sum(l, [])\n",
  588.     "        Out[26]: [[1, 2, 'a'], [11, 22, 'b'], [3, 4, 'c'], [33, 44, 'd']]\n",
  589.     "        \"\"\"\n",
  590.     "        train_set = sum(train_set, [])\n",
  591.     "        test_set = list()\n",
  592.     "        # fold 表示从原始数据集 dataset 提取出来的测试集\n",
  593.     "#         for row in fold:\n",
  594.     "#             row_copy = list(row)\n",
  595.     "#             row_copy[-1] = None\n",
  596.     "#             test_set.append(row_copy)\n",
  597.     "        predicted = algorithm(train_set, fold, *args)\n",
  598.     "    \n",
  599.     "        actual = [row[-1] for row in fold]\n",
  600.     "\n",
  601.     "        # 计算随机森林的预测结果的正确率\n",
  602.     "        accuracyValue = accuracy(predicted,actual)\n",
  603.     "        scores.append(accuracyValue)\n",
  604.     "    return scores"
  605.    ]
  606.   },
  607.   {
  608.    "cell_type": "code",
  609.    "execution_count": 41,
  610.    "metadata": {},
  611.    "outputs": [
  612.     {
  613.      "name": "stdout",
  614.      "output_type": "stream",
  615.      "text": [
  616.       "随机因子= 0.13436424411240122\n",
  617.       "决策树个数: 1\n",
  618.       "模型得分: [87.8048780487805, 90.2439024390244, 92.6829268292683, 85.36585365853658, 95.1219512195122]\n",
  619.       "平均准确度: 90.244%\n",
  620.       "随机因子= 0.13436424411240122\n",
  621.       "决策树个数: 10\n",
  622.       "模型得分: [92.6829268292683, 92.6829268292683, 87.8048780487805, 78.04878048780488, 100.0]\n",
  623.       "平均准确度: 90.244%\n"
  624.      ]
  625.     }
  626.    ],
  627.    "source": [
  628.     "    \n",
  629.     "    #综合测试函数\n",
  630.     "    n_folds = 5        # 分成5份数据,进行交叉验证\n",
  631.     "    max_depth = 20     # 调参(自己修改) #决策树深度不能太深,不然容易导致过拟合\n",
  632.     "    min_size = 1       # 决策树的叶子节点最少的元素数量\n",
  633.     "    sample_size = 1.0  # 做决策树时候的样本的比例\n",
  634.     "    # n_features = int((len(dataset[0])-1))\n",
  635.     "    n_features = 15     # 调参(自己修改) #准确性与多样性之间的权衡\n",
  636.     "    for n_trees in [1, 10]:  # 理论上树是越多越好\n",
  637.     "        scores = evaluteAlgorithm(dataSet, randomForest, n_folds, max_depth, min_size, sample_size, n_trees, n_features)\n",
  638.     "        # 每一次执行本文件时都能产生同一个随机数\n",
  639.     "        random.seed(1)\n",
  640.     "        print('随机因子=', random.random())  # 每一次执行本文件时都能产生同一个随机数\n",
  641.     "        print('决策树个数: %d' % n_trees)  # 输出决策树个数\n",
  642.     "        print('模型得分: %s' % scores)  # 输出五份随机样本的模型得分\n",
  643.     "        print('平均准确度: %.3f%%' % (sum(scores)/float(len(scores))))  # 输出五份随机样本的平均准确度\n"
  644.    ]
  645.   },
  646.   {
  647.    "cell_type": "code",
  648.    "execution_count": 42,
  649.    "metadata": {},
  650.    "outputs": [
  651.     {
  652.      "name": "stdout",
  653.      "output_type": "stream",
  654.      "text": [
  655.       "随机因子= 0.13436424411240122\n",
  656.       "决策树个数: 1\n",
  657.       "模型得分: [80.48780487804879, 75.60975609756098, 73.17073170731707, 75.60975609756098, 78.04878048780488]\n",
  658.       "平均准确度: 76.585%\n",
  659.       "随机因子= 0.13436424411240122\n",
  660.       "决策树个数: 10\n",
  661.       "模型得分: [87.8048780487805, 85.36585365853658, 90.2439024390244, 78.04878048780488, 92.6829268292683]\n",
  662.       "平均准确度: 86.829%\n"
  663.      ]
  664.     }
  665.    ],
  666.    "source": [
  667.     "    sample_size =0.5  # 做决策树时候的样本的比例\n",
  668.     "    \n",
  669.     "    for n_trees in [1, 10]:  # 理论上树是越多越好\n",
  670.     "        scores = evaluteAlgorithm(dataSet, randomForest, n_folds, max_depth, min_size, sample_size, n_trees, n_features)\n",
  671.     "        # 每一次执行本文件时都能产生同一个随机数\n",
  672.     "        random.seed(1)\n",
  673.     "        print('随机因子=', random.random())  # 每一次执行本文件时都能产生同一个随机数\n",
  674.     "        print('决策树个数: %d' % n_trees)  # 输出决策树个数\n",
  675.     "        print('模型得分: %s' % scores)  # 输出五份随机样本的模型得分\n",
  676.     "        print('平均准确度: %.3f%%' % (sum(scores)/float(len(scores))))  # 输出五份随机样本的平均准确度"
  677.    ]
  678.   }
  679. ],
  680. 余下见附件
复制代码

全部资料51hei下载地址:
随机森林例子.zip (99.15 KB, 下载次数: 10)

评分

参与人数 1黑币 +50 收起 理由
admin + 50 共享资料的黑币奖励!

查看全部评分

分享到:  QQ好友和群QQ好友和群 QQ空间QQ空间 腾讯微博腾讯微博 腾讯朋友腾讯朋友
收藏收藏 分享淘帖 顶 踩
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

手机版|小黑屋|51黑电子论坛 |51黑电子论坛6群 QQ 管理员QQ:125739409;技术交流QQ群281945664

Powered by 单片机教程网

快速回复 返回顶部 返回列表