Decision Tree Algorithms: ID3, C4.5, CART Principles, and a Spark MLlib Iris Walkthrough
H(X) = -\sum_{i=1}^{n} P(x_i) \log P(x_i)
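This is the entropy on which ID3's information-gain criterion is built; C4.5 normalizes the gain by the split's own entropy, and CART uses Gini impurity instead. A minimal sketch (not from the original article) of evaluating the formula in Scala, using the iris label distribution of 50 samples per class as an example:

// Shannon entropy (in bits) of a discrete distribution given raw counts.
def entropy(counts: Seq[Long]): Double = {
  val total = counts.sum.toDouble
  counts.filter(_ > 0).map { c =>
    val p = c / total
    -p * math.log(p) / math.log(2.0) // convert natural log to log base 2
  }.sum
}

// 50 samples per iris class: entropy = log2(3) ≈ 1.585 bits
println(entropy(Seq(50L, 50L, 50L)))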
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}
object SparkMllibIris1 {
def main(args: Array[String]): Unit = {
// 1. Create the Spark context in local mode
val conf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkMllibIris1Rdd")
val sc = new SparkContext(conf)
// 2. Load the raw iris CSV as an RDD of lines
val path = "iris.csv"
val rdd: RDD[String] = sc.textFile(path)
// rdd.foreach(println)
// 6.2,3.4,5.4,2.3,Iris-virginica
// 5.9,3.0,5.1,1.8,Iris-virginica
// 3. Build the training data
// 3-1 Convert each CSV line into a LabeledPoint: the class name in the
//     fifth column becomes a numeric label, the four measurements a
//     dense feature vector
val rddLp: RDD[LabeledPoint] = rdd.map(
x => {
val strings: Array[String] = x.split(",")
LabeledPoint(
strings(4) match {
case "Iris-setosa" => 0.0
case "Iris-versicolor" => 1.0
case "Iris-virginica" => 2.0
}
,
Vectors.dense(
strings(0).toDouble,
strings(1).toDouble,
strings(2).toDouble,
strings(3).toDouble))
}
)
// rddLp.foreach(println)
// (1.0,[6.0,2.9,4.5,1.5])
// (0.0,[5.1,3.5,1.4,0.2])
// 4. Split into training (80%) and test (20%) sets
val Array(trainData,testData): Array[RDD[LabeledPoint]] = rddLp.randomSplit(Array(0.8,0.2))
// 5. Train the decision tree: numClasses = 3, no categorical features,
//    impurity = "gini", maxDepth = 8, maxBins = 16
val decisionModel: DecisionTreeModel = DecisionTree.trainClassifier(trainData, 3, Map[Int, Int](), "gini", 8, 16)
// 6. Predict on the test set, pairing each true label with its prediction.
//    Unlike the DataFrame API's transform, the RDD API calls predict on
//    each feature vector explicitly.
val result: RDD[(Double, Double)] = testData.map(
x=> {
val pre: Double = decisionModel.predict(x.features)
(x.label,pre)
}
)
val acc: Double = result.filter(x=>x._1==x._2).count().toDouble /result.count()
println(acc)
println("error", (1-acc))
// 0.9642857142857143
// (error,0.0357142857142857)
}
}
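If the trained RDD-API model needs to be reused, it can be saved and reloaded; a minimal sketch (not from the original article) continuing inside main() above, with a hypothetical path:

// Persist the mllib model and load it back (path is hypothetical).
decisionModel.save(sc, "model/irisDecisionTree")
val loaded: DecisionTreeModel = DecisionTreeModel.load(sc, "model/irisDecisionTree")
println(loaded.toDebugString) // prints the learned split structure

The second listing reimplements the same task with the DataFrame-based spark.ml API.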
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, StringIndexerModel, VectorAssembler}
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
object SparkMlIris2 {
def main(args: Array[String]): Unit = {
// 1. Create the SparkSession
val sparkSession: SparkSession = SparkSession.builder().master("local[*]").appName("SparkMllibIris2").getOrCreate()
// 2. Load the data
// 2-1 Read the CSV with the DataFrame reader; see
//     http://spark.apache.org/docs/latest/sql-data-sources-load-save-functions.html
val path = "irisHeader.csv"
// Without .option("inferSchema", "true"), every column in the df schema would be typed as String
val df: DataFrame = sparkSession.read.format("csv").option("inferSchema", "true").option("header","true").option("sep",",").load(path)
// df.printSchema()
// root
// |-- sepal_length: double (nullable = true)
// |-- sepal_width: double (nullable = true)
// |-- petal_length: double (nullable = true)
// |-- petal_width: double (nullable = true)
// |-- class: string (nullable = true)
//df.show(false)
// +------------+-----------+------------+-----------+-----------+
// |sepal_length|sepal_width|petal_length|petal_width|class |
// +------------+-----------+------------+-----------+-----------+
// |5.1 |3.5 |1.4 |0.2 |Iris-setosa|
// 4. Feature engineering
// 4-1 Assemble the four measurement columns into a single feature vector
val assembler: VectorAssembler = new VectorAssembler().setInputCols(Array("sepal_length","sepal_width","petal_length","petal_width")).setOutputCol("features")
val assemblerDf: DataFrame = assembler.transform(df)
assemblerDf.show(false)
// 4-2 Index the string class column into a numeric label column
val stringIndex: StringIndexer = new StringIndexer().setInputCol("class").setOutputCol("label")
val stringIndexModel: StringIndexerModel = stringIndex.fit(assemblerDf)
val indexDf: DataFrame = stringIndexModel.transform(assemblerDf)
// indexDf.show(false)
// +------------+-----------+------------+-----------+-----------+-----------------+-----+
// |sepal_length|sepal_width|petal_length|petal_width|class |features |label|
// +------------+-----------+------------+-----------+-----------+-----------------+-----+
// |5.1 |3.5 |1.4 |0.2 |Iris-setosa|[5.1,3.5,1.4,0.2]|0.0 |
// |4.9 |3.0 |1.4 |0.2 |Iris-setosa|[4.9,3.0,1.4,0.2]|0.0 |
// 4-3 Split into training and test sets
val Array(trainData,testData): Array[Dataset[Row]] = indexDf.randomSplit(Array(0.8,0.2))
// 5. Build and fit the decision tree classifier
val classifier: DecisionTreeClassifier = new DecisionTreeClassifier().setFeaturesCol("features").setMaxBins(16).setImpurity("gini").setSeed(10)
val dtcModel: DecisionTreeClassificationModel = classifier.fit(trainData)
// 6. Predict on the training set
val trainPre: DataFrame = dtcModel.transform(trainData)
// 7. Predict on the test set
val testPre: DataFrame = dtcModel.transform(testData)
// 8. Optionally save the trained model
//val savePath = "E:\\ml\\workspace\\SparkMllibBase\\sparkmllib_part2\\DescitionTree\\model"
//dtcModel.save(savePath)
// trainPre.show(false)
// +------------+-----------+------------+-----------+---------------+-----------------+-----+--------------+-------------+----------+
// |sepal_length|sepal_width|petal_length|petal_width|class |features |label|rawPrediction |probability |prediction|
// +------------+-----------+------------+-----------+---------------+-----------------+-----+--------------+-------------+----------+
// |4.3 |3.0 |1.1 |0.1 |Iris-setosa |[4.3,3.0,1.1,0.1]|0.0 |[47.0,0.0,0.0]|[1.0,0.0,0.0]|0.0 |
// |4.4 |2.9 |1.4 |0.2 |Iris-setosa |[4.4,2.9,1.4,0.2]|0.0 |[47.0,0.0,0.0]|[1.0,0.0,0.0]|0.0 |
// testPre.show(false)
// +------------+-----------+------------+-----------+---------------+-----------------+-----+--------------+-------------+----------+
// |sepal_length|sepal_width|petal_length|petal_width|class |features |label|rawPrediction |probability |prediction|
// +------------+-----------+------------+-----------+---------------+-----------------+-----+--------------+-------------+----------+
// |4.6 |3.2 |1.4 |0.2 |Iris-setosa |[4.6,3.2,1.4,0.2]|0.0 |[47.0,0.0,0.0]|[1.0,0.0,0.0]|0.0 |
// |4.8 |3.4 |1.9 |0.2 |Iris-setosa |[4.8,3.4,1.9,0.2]|0.0 |[47.0,0.0,0.0]|[1.0,0.0,0.0]|0.0 |
// |5.0 |2.0 |3.5 |1.0 |Iris-versicolor|[5.0,2.0,3.5,1.0]|1.0 |[0.0,33.0,0.0]|[0.0,1.0,0.0]|1.0 |
val acc: Double = new MulticlassClassificationEvaluator().setMetricName("accuracy").evaluate(testPre)
println("acc is ", acc)
println("err is", (1-acc))
// 9. Map numeric predictions back to the original string class names
val indexToString: IndexToString = new IndexToString().setInputCol("prediction").setOutputCol("preStringLabel").setLabels(stringIndexModel.labels)
val result: DataFrame = indexToString.transform(testPre)
// result.show(false)
// +------------+-----------+------------+-----------+---------------+-----------------+-----+--------------+-------------------------------------------+----------+---------------+
// |sepal_length|sepal_width|petal_length|petal_width|class |features |label|rawPrediction |probability |prediction|preStringLabel |
// +------------+-----------+------------+-----------+---------------+-----------------+-----+--------------+-------------------------------------------+----------+---------------+
// |4.6 |3.6 |1.0 |0.2 |Iris-setosa |[4.6,3.6,1.0,0.2]|0.0 |[38.0,0.0,0.0]|[1.0,0.0,0.0] |0.0 |Iris-setosa |
// |4.8 |3.4 |1.6 |0.2 |Iris-setosa |[4.8,3.4,1.6,0.2]|0.0 |[38.0,0.0,0.0]|[1.0,0.0,0.0] |0.0 |Iris-setosa |
}
}
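The three spark.ml stages above (VectorAssembler, StringIndexer, DecisionTreeClassifier) are often chained into a single Pipeline so the preprocessing travels with the model. A minimal sketch (not from the original article) reusing the names defined above; it splits the raw df so the fitted stages see only training data:

import org.apache.spark.ml.{Pipeline, PipelineModel}

// Chain assembler -> indexer -> classifier and fit them in one pass.
val Array(rawTrain, rawTest) = df.randomSplit(Array(0.8, 0.2))
val pipeline = new Pipeline().setStages(Array(assembler, stringIndex, classifier))
val pipelineModel: PipelineModel = pipeline.fit(rawTrain)
val predictions: DataFrame = pipelineModel.transform(rawTest)
predictions.select("class", "prediction").show(false)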