模型 SQL 文件:https://pan.baidu.com/s/1hugrI9e 使用數據鏈接:https://pan.baidu.com/s/1kWz8fNh NaiveBayes Spark MLlib 訓練:
package com.xxx.xxx.xxx

import java.io.ObjectInputStream
import java.sql.{Connection, DriverManager, PreparedStatement}
import java.util.{Arrays, Date, Scanner}

import org.ansj.splitWord.analysis.ToAnalysis
import org.apache.spark.mllib.classification.{NaiveBayes, NaiveBayesModel}
import org.apache.spark.mllib.feature.{HashingTF, IDF}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.DataFrame
import org.apache.spark.{SparkConf, SparkContext}

import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal

/**
 * Created by Zsh on 1/31 0031.
 */
/**
 * Naive Bayes news-category classifier backed by MySQL.
 *
 * - `training` reads labelled articles from the `industry_classify_tmp` table, segments them with
 *   Ansj, hashes the tokens into TF vectors, weights them with IDF, trains a multinomial
 *   NaiveBayes model, saves it to `outputPath` and stores a serialized copy in MySQL.
 * - `loadModel` / `inputTestModel` / `batchTesting` / `testData` load the saved model and classify text.
 *
 * NOTE(review): training weights features with IDF but `testData` predicts on raw TF vectors; the
 * IDF model is not persisted, so prediction sees differently scaled features than training did.
 */
object WudiBayesModel {
  /** Connection set by `MysqlConn`; kept for backward compatibility, not used internally. */
  var conn: Connection = null
  /** Kept for backward compatibility, not used internally. */
  var stmt: PreparedStatement = null
  /** Path the Spark model is saved to by `training` and loaded from by `loadModel`/`testModel`. */
  val outputPath = "/zshModel"
  // NOTE(review): credentials are hard-coded, and autoDeserialize=true lets the driver Java-deserialize
  // blob columns automatically; confirm both are intended.
  val driverUrl = "jdbc:mysql://192.168.2.107:3306/data_mining?user=xxx&password=xxx&zeroDateTimeBehavior=convertToNull&characterEncoding=utf-8&autoDeserialize=true"
  /** Articles of category `classify`, populated by `loadModel` for `batchTesting`. */
  var df: DataFrame = null
  /** Category evaluated by `loadModel`/`batchTesting` and the default target of `training`'s precision/recall. */
  val classify: String = "健康"
  /** Model populated by `loadModel`. */
  var model: NaiveBayesModel = null
  /** `|`-separated category names; the name at position i is numeric label i. */
  val lables = "社會|體育|汽車|女性|新聞|科技|財經|軍事|廣告|娛樂|健康|教育|旅游|生活|文化"

  /** URL of the database holding serialized models (currently the same server as `driverUrl`). */
  private val modelStoreUrl = "jdbc:mysql://192.168.2.107:3306/data_mining?user=xxx&password=xxx&zeroDateTimeBehavior=convertToNull&characterEncoding=utf-8"
  /** Row key of the serialized model in `ams_recommender_model`. */
  private val modelId = "test"
  /** Table holding the labelled articles (`content`, `classify` columns). */
  private val articleTable = "industry_classify_tmp"
  /** Hash buckets used for training features; prediction derives the size from the model instead. */
  private val trainingNumFeatures = 1 << 16
  /** Ansj part-of-speech prefixes kept by the segmenter (noun, verb, adjective, numeral, time). */
  private val KeepNatures = Set("n", "v", "a", "m", "t")
  /** Tokens that look like a weibo mention, e.g. `@someone:`. */
  private val mentionPattern = "@\\w{2,20}:".r
  /** Tokens that look like an http URL. */
  private val urlPattern = "http://[0-9a-zA-Z/\\?&#%$@\\=\\\\]+".r

  /**
   * Optional command-line dispatch: `train`, `interactive` or `batch`.
   * With no (or an unknown) argument it does nothing, as before.
   */
  def main(args: Array[String]): Unit = {
    args.headOption match {
      case Some("train") => training(lables)
      case Some("interactive") => inputTestModel()
      case Some("batch") =>
        loadModel()
        batchTesting()
      case _ => ()
    }
  }

  /** Loads the model, then classifies each line typed on stdin until end of input. */
  def inputTestModel(): Unit = {
    val scan = new Scanner(System.in)
    val startTime = new Date().getTime
    loadModel()
    val loadedAt = new Date().getTime
    println("加載模型時間:" + (loadedAt - startTime))
    println("模型加載完畢-----")
    // hasNextLine instead of while(true): nextLine() throws NoSuchElementException at EOF.
    while (scan.hasNextLine) {
      testData(model, scan.nextLine(), lables)
      println("---------------------------------")
    }
  }

  /**
   * Predicts every article in `df` and prints the fraction predicted as `classify`.
   * Requires `loadModel()` to have been called first.
   */
  def batchTesting(): Unit = {
    require(model != null && df != null, "call loadModel() before batchTesting()")
    // Copy into locals so the closure ships their values; reading the object's `var model`
    // inside the closure would see null on executors.
    val currentModel = model
    val labelNames = lables
    val start = new Date().getTime
    val predicted = df
      .map(row => testData(currentModel, Option(row.getAs[String]("content")).getOrElse(""), labelNames))
      .cache()
    val total = predicted.count()
    println("預測需要時間:" + (new Date().getTime - start))
    println("準確率:" + predicted.filter(_ == classify).count().toDouble / total)
    predicted.unpersist()
  }

  /** Loads the saved model into `model` and the `classify` articles into `df`. */
  def loadModel(): Unit = {
    val sc = localSparkContext("local", useKryo = true)
    val sqlContext = org.apache.spark.sql.SQLContext.getOrCreate(sc)
    // Assign the field; the previous `val model = ...` shadowed it and left the field null.
    model = NaiveBayesModel.load(sc, outputPath)
    val articles = readArticles(sqlContext)
    df = articles.filter(articles("classify") === classify).select("content")
  }

  /** Loads and returns the saved model without touching `model`/`df`. */
  def testModel(): NaiveBayesModel =
    NaiveBayesModel.load(localSparkContext("local", useKryo = true), outputPath)

  /**
   * Classifies one text and prints the result.
   *
   * @param model       trained model
   * @param text        raw text to classify
   * @param labels_name `|`-separated category names; position i names numeric label i
   * @return the predicted category name, or the raw label value if `labels_name` has no entry for it
   */
  def testData(model: NaiveBayesModel, text: String, labels_name: String): String = {
    // Use exactly as many hash buckets as the model has feature columns; a different size makes
    // predict() fail its matrix-vector dimension check (training used 2^16, this used 2^20).
    val hashingTF = new HashingTF(model.theta.head.length)
    val prediction = model.predict(hashingTF.transform(tokenizer(text)))
    val name = labels_name.split("\\|").toList.lift(prediction.toInt).getOrElse(prediction.toString)
    println("result:" + name + " " + prediction + " " + text)
    name
  }

  /**
   * Trains, saves and evaluates a multinomial NaiveBayes model.
   *
   * @param labels_name `|`-separated category names; only articles of these categories are used,
   *                    and the name at position i becomes numeric label i
   * @param targetLabel category whose precision/recall is reported (previously a hard-coded label 11.0)
   */
  def training(labels_name: String, targetLabel: String = classify): Unit = {
    val labelNames = labels_name.split("\\|").toList
    val labelIndex: Map[String, Double] =
      labelNames.zipWithIndex.map { case (name, i) => name -> i.toDouble }.toMap

    val sc = localSparkContext("local[4]", useKryo = false)
    val sqlContext = org.apache.spark.sql.SQLContext.getOrCreate(sc)
    val articles = readArticles(sqlContext)
    // Read (content, classify) from MySQL for the requested categories; null content cannot be segmented.
    val trainData = articles
      .filter(articles("content").isNotNull && articles("classify").isin(labelNames: _*))
      .select("content", "classify")
    println("trcount:" + trainData.count())

    val hashTF = new HashingTF(trainingNumFeatures)
    // Keep each label next to its own TF vector. Zipping two separately derived RDDs relies on the
    // JDBC source returning rows in the same order on every re-read, which is not guaranteed.
    val labelledTf = trainData.map { row =>
      val label = labelIndex(row.getAs[String]("classify"))
      (label, hashTF.transform(tokenizer(row.getAs[String]("content"))))
    }.cache()
    val idfModel = new IDF().fit(labelledTf.values)
    val tData = labelledTf.map { case (label, tf) => LabeledPoint(label, idfModel.transform(tf)) }

    // 70% training data, 30% validation data.
    val splits = tData.randomSplit(Array(0.7, 0.3), seed = 11L)
    val trData = splits(0).cache()
    val teData = splits(1).cache()

    val trained = NaiveBayes.train(trData, lambda = 1.0, modelType = "multinomial")
    model_save(trained, sc)

    // ===== validation =====
    val predictionAndLabel = teData.map(p => (trained.predict(p.features), p.label)).cache()
    val total = predictionAndLabel.count()
    val correct = predictionAndLabel.filter { case (predicted, actual) => predicted == actual }.count()
    println("統計分類準確率:============================")
    println("準確率:" + correct.toDouble / total) // Accuracy = (TP+TN)/(TP+FP+TN+FN)
    labelIndex.get(targetLabel).foreach { target =>
      val actualPositives = predictionAndLabel.filter(_._2 == target).count()
      val predictedPositives = predictionAndLabel.filter(_._1 == target).count()
      val truePositives = predictionAndLabel.filter { case (p, a) => p == a && a == target }.count()
      println("精確度:" + truePositives.toDouble / predictedPositives) // Precision = TP/(TP+FP)
      println("召回率:" + truePositives.toDouble / actualPositives) // Recall = TP/(TP+FN)
    }
    println("模型準確度============================")
    predictionAndLabel.unpersist()
    labelledTf.unpersist()
  }

  /** Saves `trained` to `outputPath` and a plain-data copy to MySQL. */
  private def model_save(trained: NaiveBayesModel, sc: SparkContext): Unit = {
    // NOTE(review): save fails if outputPath already exists.
    trained.save(sc, outputPath)
    println("save model success !")
    val data = BayesModelData2(trained.labels.toArray, trained.pi.toArray, trained.theta.map(_.toArray).toArray, "multinomial")
    serializeToMysql(data)
  }

  /**
   * Splits a text into classifier tokens. It keeps Ansj words of length >= 2 whose part of speech
   * starts with n/v/a/m/t, minus mention-like and URL-like tokens.
   */
  def tokenizer(line: String): Seq[String] =
    segmentWords(line)
      .filterNot(token => mentionPattern.pattern.matcher(token).matches)
      .filterNot(token => urlPattern.pattern.matcher(token).matches)

  /** Comma-joined form of the words kept by the Ansj segmenter (see `tokenizer`). */
  def AnsjSegment(line: String): String = segmentWords(line).mkString(",")

  /** Segments `line` with Ansj and keeps words of length >= 2 whose nature starts with a KeepNatures prefix. */
  private def segmentWords(line: String): Seq[String] = {
    val terms = ToAnalysis.parse(line)
    (0 until terms.size())
      .map(i => terms.get(i))
      .filter(term => KeepNatures.contains(term.getNatureStr.take(1)) && term.getName.length >= 2)
      .map(_.getName)
  }

  /** Stores `o` under `modelId` in `ams_recommender_model`, replacing any previous row. */
  def serializeToMysql[T](o: BayesModelData2): Unit =
    withResource(openModelStoreConnection()) { connection =>
      withResource(connection.prepareStatement("replace into ams_recommender_model(model_ID,model_Data) values (?,?)")) { statement =>
        statement.setString(1, modelId)
        // Presumably the driver Java-serializes the object into the blob (read back with ObjectInputStream below).
        statement.setObject(2, o)
        // executeUpdate, not executeQuery: the MySQL driver rejects DML issued through executeQuery.
        statement.executeUpdate()
      }
    }

  /** Opens a connection into the `conn` field, printing (not throwing) on failure. Kept for compatibility. */
  class MysqlConn() {
    val trainning_url = modelStoreUrl
    try {
      conn = DriverManager.getConnection(trainning_url, "xxx", "xxx")
    } catch {
      case NonFatal(e) => println("mysql連接異常: " + e.getMessage)
    }
  }

  /**
   * Reads back the model stored by `serializeToMysql`.
   *
   * @throws java.sql.SQLException if no row exists for the model id or the query fails
   */
  def deserializeFromMysql[T](): BayesModelData2 =
    withResource(openModelStoreConnection()) { connection =>
      withResource(connection.prepareStatement("select model_Data from ams_recommender_model where model_ID = ?")) { statement =>
        statement.setString(1, modelId)
        withResource(statement.executeQuery()) { resultSet =>
          if (!resultSet.next())
            throw new java.sql.SQLException("no model stored under model_ID '" + modelId + "'")
          // NOTE(review): Java deserialization of database content; only safe while this table is
          // written exclusively by serializeToMysql.
          withResource(new ObjectInputStream(resultSet.getBlob("model_Data").getBinaryStream())) { in =>
            in.readObject().asInstanceOf[BayesModelData2]
          }
        }
      }
    }

  /** Opens a connection to the model store; failures propagate to the caller. */
  private def openModelStoreConnection(): Connection =
    DriverManager.getConnection(modelStoreUrl, "xxx", "xxx")

  /** Runs `f` on `resource` and always closes the resource afterwards. */
  private def withResource[R <: AutoCloseable, A](resource: R)(f: R => A): A =
    try f(resource) finally resource.close()

  /** Returns the running SparkContext or creates a local one (only one may exist per JVM). */
  private def localSparkContext(master: String, useKryo: Boolean): SparkContext = {
    val base = new SparkConf().setAppName("NaiveBayesExample1").setMaster(master)
    val conf =
      if (useKryo)
        base
          .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
          .set("spark.kryoserializer.buffer.max", "1024mb")
      else base
    SparkContext.getOrCreate(conf)
  }

  /** Reads the article table from MySQL over JDBC. */
  private def readArticles(sqlContext: org.apache.spark.sql.SQLContext): DataFrame =
    sqlContext.read.format("jdbc").options(Map("url" -> driverUrl, "dbtable" -> articleTable)).load()
}
?
調用 BayesUtils 類時,該類所在的包必須是 org.apache.spark,因為 NaiveBayesModel 的構造函數(以及 BLAS)是 private[spark] 的,只能在 org.apache.spark 包內訪問。參考:How to use BLAS library in Spark (Symbol BLAS is inaccessible from this space) - spark:http://note.youdao.com/noteshare?id=7f1eec90cc6e56303d06ff92422c29b6&sub=wcp151747625212826
調用
package org.apache.sparkimport com.izhonghong.mission.learn.BayesModelData2
import com.izhonghong.mission.learn.WudiBayesModel.{deserializeFromMysql, tokenizer}
import org.apache.spark.mllib.classification.NaiveBayesModel
import org.apache.spark.mllib.feature.HashingTF
import org.apache.spark.mllib.linalg.{DenseMatrix, DenseVector, Vector}/*** Created by Zsh on 2/1 0001.*/
/**
 * Helpers that use the model stored in MySQL by WudiBayesModel.serializeToMysql.
 * The object lives in package org.apache.spark because the NaiveBayesModel constructor and
 * mllib's BLAS are private[spark].
 */
object BayesUtils {

  /**
   * Classifies `text` with the model read from MySQL and prints the result.
   *
   * @param text        raw text to classify
   * @param labels_name `|`-separated category names; position i names numeric label i
   */
  def testMysql(text: String, labels_name: String): Unit = {
    val stored = deserializeFromMysql()
    // Hash into as many buckets as the stored model has feature columns. The default HashingTF
    // size (2^20) does not match a model trained on 2^16 features, and predict() would fail.
    val hashingTF = new HashingTF(stored.theta.head.length)
    val tfVector = hashingTF.transform(tokenizer(text))
    val model = new NaiveBayesModel(stored.labels, stored.pi, stored.theta, stored.modelType)
    val d = model.predict(tfVector)
    // Equivalent without NaiveBayesModel: val d = predict(stored, tfVector)
    val name = labels_name.split("\\|").toList.lift(d.toInt).getOrElse(d.toString)
    println("result:" + name + " " + d + " " + text)
  }

  /**
   * Multinomial Naive Bayes prediction extracted from NaiveBayesModel: returns the label whose
   * score `pi + theta * x` is largest.
   *
   * @param bayesModel stored model data; `theta` is one row per label
   * @param tfVector   feature vector whose size equals the number of columns of `theta`
   * @throws IllegalArgumentException if the model is not multinomial
   */
  def predict(bayesModel: BayesModelData2, tfVector: Vector): Double = {
    require(bayesModel.modelType == "multinomial",
      "predict only implements the multinomial model, got " + bayesModel.modelType)
    // Row-major values, hence isTransposed = true.
    val thetaMatrix = new DenseMatrix(bayesModel.labels.length, bayesModel.theta(0).length, bayesModel.theta.flatten, true)
    val scores = thetaMatrix.multiply(tfVector)
    org.apache.spark.mllib.linalg.BLAS.axpy(1.0, new DenseVector(bayesModel.pi), scores)
    bayesModel.labels(scores.argmax)
  }
}
<dependency><groupId>org.apache.spark</groupId><artifactId>spark-mllib_2.10</artifactId><version>1.6.0</version></dependency><dependency><groupId>org.ansj</groupId><artifactId>ansj_seg</artifactId><version>5.0.4</version></dependency>
全部配置文件
<?xml version="1.0" encoding="UTF-8"?>
<!-- Whitespace between the <project> attributes is required for the POM to be well-formed XML. -->
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.xxx</groupId>
    <artifactId>xxx-xxx-xxx</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <!-- Kept equal to the maven-compiler-plugin configuration below (1.8). -->
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
        <encoding>UTF-8</encoding>
        <!-- The sources contain Chinese string literals; do not depend on the platform default encoding. -->
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <scala.tools.version>2.10</scala.tools.version>
        <scala.version>2.10.6</scala.version>
        <hbase.version>1.2.2</hbase.version>
        <!-- One version for every Spark module (spark-mllib was 1.6.0 while the rest were 1.6.2). -->
        <spark.version>1.6.2</spark.version>
    </properties>

    <dependencies>
        <!-- <dependency><groupId>org.apache.spark</groupId><artifactId>spark-mllib_2.11</artifactId><version>2.1.0</version></dependency>-->
        <!--<dependency><groupId>org.apache.spark</groupId><artifactId>spark-mllib_2.11</artifactId><version>1.6.0</version></dependency>-->
        <!-- <dependency><groupId>com.hankcs</groupId><artifactId>hanlp</artifactId><version>portable-1.5.0</version></dependency>-->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.10</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.ansj</groupId>
            <artifactId>ansj_seg</artifactId>
            <version>5.0.4</version>
        </dependency>
        <dependency>
            <groupId>org.apache.kafka</groupId>
            <artifactId>kafka-clients</artifactId>
            <version>0.10.0.0</version>
        </dependency>
        <dependency>
            <groupId>net.sf.json-lib</groupId>
            <artifactId>json-lib</artifactId>
            <version>2.4</version>
            <classifier>jdk15</classifier>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-streaming-kafka_2.10</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <!-- <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-streaming-kafka-0-10_2.10</artifactId><version>2.1.1</version> </dependency> -->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-streaming_2.10</artifactId>
            <version>${spark.version}</version>
            <exclusions>
                <exclusion>
                    <artifactId>scala-library</artifactId>
                    <groupId>org.scala-lang</groupId>
                </exclusion>
            </exclusions>
        </dependency>
        <!-- <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-streaming_2.10</artifactId><version>2.1.1</version> <scope>provided</scope> </dependency> -->
        <dependency>
            <groupId>com.huaban</groupId>
            <artifactId>jieba-analysis</artifactId>
            <version>1.0.2</version>
        </dependency>
        <dependency>
            <groupId>com.alibaba</groupId>
            <artifactId>fastjson</artifactId>
            <version>1.2.14</version>
        </dependency>
        <dependency>
            <groupId>redis.clients</groupId>
            <artifactId>jedis</artifactId>
            <version>2.9.0</version>
        </dependency>
        <!-- Declared once (it was listed twice, once with a hard-coded 2.10.6). -->
        <dependency>
            <groupId>org.scala-lang</groupId>
            <artifactId>scala-library</artifactId>
            <version>${scala.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.hbase</groupId>
            <artifactId>hbase-server</artifactId>
            <version>${hbase.version}</version>
            <exclusions>
                <exclusion>
                    <artifactId>servlet-api-2.5</artifactId>
                    <groupId>org.mortbay.jetty</groupId>
                </exclusion>
            </exclusions>
        </dependency>
        <!-- <dependency><groupId>com.alibaba</groupId><artifactId>fastjson</artifactId><version>1.2.18</version></dependency>-->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_2.10</artifactId>
            <version>${spark.version}</version>
            <!-- <version>2.1.1</version> -->
        </dependency>
        <dependency>
            <groupId>org.apache.hadoop</groupId>
            <artifactId>hadoop-client</artifactId>
            <version>2.7.0</version>
        </dependency>
        <dependency>
            <groupId>org.apache.hadoop</groupId>
            <artifactId>hadoop-common</artifactId>
            <version>2.7.0</version>
        </dependency>
        <dependency>
            <groupId>org.apache.hadoop</groupId>
            <artifactId>hadoop-hdfs</artifactId>
            <version>2.7.0</version>
            <exclusions>
                <exclusion>
                    <groupId>javax.servlet.jsp</groupId>
                    <artifactId>*</artifactId>
                </exclusion>
                <exclusion>
                    <artifactId>servlet-api</artifactId>
                    <groupId>javax.servlet</groupId>
                </exclusion>
            </exclusions>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_2.10</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-hive_2.10</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>mysql</groupId>
            <artifactId>mysql-connector-java</artifactId>
            <version>5.1.39</version>
        </dependency>
        <!--<dependency><groupId>org.apache.hbase</groupId><artifactId>hbase-server</artifactId><version>1.2.2</version></dependency>-->

        <!-- Test -->
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>4.11</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.specs2</groupId>
            <artifactId>specs2_${scala.tools.version}</artifactId>
            <version>1.13</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.scalatest</groupId>
            <artifactId>scalatest_${scala.tools.version}</artifactId>
            <version>2.0.M6-SNAP8</version>
            <scope>test</scope>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>net.alchim31.maven</groupId>
                <artifactId>scala-maven-plugin</artifactId>
                <version>3.2.0</version>
                <executions>
                    <execution>
                        <goals>
                            <goal>compile</goal>
                            <goal>testCompile</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-jar-plugin</artifactId>
                <configuration>
                    <archive>
                        <manifest>
                            <addClasspath>true</addClasspath>
                            <classpathPrefix>lib/</classpathPrefix>
                            <mainClass></mainClass>
                        </manifest>
                    </archive>
                </configuration>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <configuration>
                    <source>1.8</source>
                    <target>1.8</target>
                </configuration>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-dependency-plugin</artifactId>
                <executions>
                    <execution>
                        <id>copy</id>
                        <phase>package</phase>
                        <goals>
                            <goal>copy-dependencies</goal>
                        </goals>
                        <configuration>
                            <outputDirectory>${project.build.directory}/lib</outputDirectory>
                        </configuration>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>
    <!-- <build> <plugins> <plugin> <artifactId>maven-assembly-plugin</artifactId><configuration> <archive> <manifest> 這里要替換成jar包main方法所在類 <mainClass>com.sf.pps.client.IntfClientCall</mainClass></manifest> <manifestEntries> <Class-Path>.</Class-Path> </manifestEntries></archive> <descriptorRefs> <descriptorRef>jar-with-dependencies</descriptorRef></descriptorRefs> </configuration> <executions> <execution> <id>make-assembly</id>this is used for inheritance merges <phase>package</phase>
    指定在打包節點執行jar包合并操作<goals> <goal>single</goal> </goals> </execution> </executions> </plugin></plugins> </build> -->
</project>
?
?
總結
以上是生活随笔 為你收集整理的Spark NaiveBayes Demo 朴素贝叶斯分类算法 的全部內容,希望文章能夠幫你解決所遇到的問題。
如果覺得生活随笔 網站內容還不錯,歡迎將生活随笔 推薦給好友。