Running XGBoost on Spark with the Scala Interface
Overview
XGBoost can run on Spark. I am using XGBoost version 0.7, which currently only supports Spark 2.0 and above.
Build the jar and install it into your local Maven repository:
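A minimal sketch of that step (the jar name and paths are assumptions based on an xgboost 0.7 source checkout, not something the original spells out; adjust them to your own build):

# Build the JVM packages from the xgboost 0.7 source tree and install them into the local Maven repository
cd xgboost/jvm-packages
mvn -DskipTests install

# Or, if the jar is already built, install it directly
mvn install:install-file -Dfile=xgboost4j-spark-0.7.jar \
  -DgroupId=ml.dmlc -DartifactId=xgboost4j-spark -Dversion=0.7 -Dpackaging=jar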
Add the dependencies:
<dependencies>
    <dependency>
        <groupId>ml.dmlc</groupId>
        <artifactId>xgboost4j-spark</artifactId>
        <version>0.7</version>
    </dependency>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-core_2.10</artifactId>
        <version>2.0.0</version>
    </dependency>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-mllib_2.10</artifactId>
        <version>2.0.0</version>
    </dependency>
</dependencies>
RDD interface:
package com.meituan.spark_xgboost
import org.apache.log4j.{ Level, Logger }
import org.apache.spark.{ SparkConf, SparkContext }
import ml.dmlc.xgboost4j.scala.spark.XGBoost
import org.apache.spark.sql.{ SparkSession, Row }
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors

object XgboostR {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)
    val spark = SparkSession.builder.master("local").appName("example").
      config("spark.sql.warehouse.dir", s"file:///Users/shuubiasahi/Documents/spark-warehouse").
      config("spark.sql.shuffle.partitions", "20").getOrCreate()
    spark.conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    val path = "/Users/shuubiasahi/Documents/workspace/xgboost/demo/data/"
    val trainString = "agaricus.txt.train"
    val testString = "agaricus.txt.test"
    val train = MLUtils.loadLibSVMFile(spark.sparkContext, path + trainString)
    val test = MLUtils.loadLibSVMFile(spark.sparkContext, path + testString)
    // Convert the mllib LabeledPoints to ml LabeledPoints with dense feature vectors.
    val traindata = train.map { x =>
      val f = x.features.toArray
      val v = x.label
      LabeledPoint(v, Vectors.dense(f))
    }
    // The test set only needs feature vectors for prediction.
    val testdata = test.map { x =>
      val f = x.features.toArray
      Vectors.dense(f)
    }
    val numRound = 15
    // "objective" -> "reg:linear", // learning task and objective, for regression
    // "eval_metric" -> "rmse",     // evaluation metric on the validation data, for regression
    val paramMap = List(
      "eta" -> 1f,
      "max_depth" -> 5, // maximum depth of a tree; default is 6, valid range [1, inf)
      "silent" -> 1, // 0 prints runtime messages, 1 runs silently; default is 0
      "objective" -> "binary:logistic", // learning task and objective
      "lambda" -> 2.5,
      "nthread" -> 1 // number of threads XGBoost uses; default is the maximum available on the machine
    ).toMap
    println(paramMap)
    val model = XGBoost.trainWithRDD(traindata, paramMap, numRound, 55, null, null, useExternalMemory = false, Float.NaN)
    println("success")
    val result = model.predict(testdata)
    result.take(10).foreach(println)
    spark.stop()
  }
}
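A side note on the feature conversion above: the example densifies every libsvm row before training. Since Spark 2.0, the old mllib vectors can be converted with asML, which keeps them sparse. A minimal sketch, continuing from the train RDD in the example above (the variable name sparseTrain is mine, not from the original):

// Sketch: keep the agaricus features sparse instead of calling Vectors.dense.
// mllib Vector.asML (Spark 2.0+) converts to an ml.linalg.Vector without densifying.
val sparseTrain = train.map(x => LabeledPoint(x.label, x.features.asML))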
DataFrame interface:
package com.meituan.spark_xgboost
import org.apache.log4j.{ Level, Logger }
import org.apache.spark.{ SparkConf, SparkContext }
import ml.dmlc.xgboost4j.scala.spark.XGBoost
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.sql.{ SparkSession, Row }

object XgboostD {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)
    val spark = SparkSession.builder.master("local").appName("example").
      config("spark.sql.warehouse.dir", s"file:///Users/shuubiasahi/Documents/spark-warehouse").
      config("spark.sql.shuffle.partitions", "20").getOrCreate()
    spark.conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    val path = "/Users/shuubiasahi/Documents/workspace/xgboost/demo/data/"
    val trainString = "agaricus.txt.train"
    val testString = "agaricus.txt.test"
    // The libsvm data source yields "label" and "features" columns; rename them to match the trainer's column names below.
    val train = spark.read.format("libsvm").load(path + trainString).toDF("label", "feature")
    val test = spark.read.format("libsvm").load(path + testString).toDF("label", "feature")
    val numRound = 15
    // "objective" -> "reg:linear", // learning task and objective, for regression
    // "eval_metric" -> "rmse",     // evaluation metric on the validation data, for regression
    val paramMap = List(
      "eta" -> 1f,
      "max_depth" -> 5, // maximum depth of a tree; default is 6, valid range [1, inf)
      "silent" -> 1, // 0 prints runtime messages, 1 runs silently; default is 0
      "objective" -> "binary:logistic", // learning task and objective
      "lambda" -> 2.5,
      "nthread" -> 1 // number of threads XGBoost uses; default is the maximum available on the machine
    ).toMap
    val model = XGBoost.trainWithDataFrame(train, paramMap, numRound, 45, obj = null, eval = null, useExternalMemory = false, Float.NaN, "feature", "label")
    val predict = model.transform(test)
    val scoreAndLabels = predict.select(model.getPredictionCol, model.getLabelCol)
      .rdd
      .map { case Row(score: Double, label: Double) => (score, label) }
    // Compute the AUC.
    val metric = new BinaryClassificationMetrics(scoreAndLabels)
    val auc = metric.areaUnderROC()
    println("auc:" + auc)
    spark.stop()
  }
}
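Beyond AUC, a quick accuracy check can be done by thresholding the score at 0.5. This is a minimal sketch continuing from the scoreAndLabels RDD in the example above, and it assumes the prediction column holds the positive-class probability produced by binary:logistic:

// Sketch: accuracy at a 0.5 threshold (assumes score is the positive-class probability).
val accuracy = scoreAndLabels
  .map { case (score, label) => if ((if (score > 0.5) 1.0 else 0.0) == label) 1.0 else 0.0 }
  .mean()
println("accuracy:" + accuracy)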