ALS Algorithm Parameters: Collaborative Filtering and ALS in Spark 2.0
ALS Matrix Factorization
An m×n rating matrix A can be approximated by the product of two smaller matrices. The intuition is that a person's preferences live in an abstract low-dimensional space: there is no need to enumerate every item they like. Projecting both user preferences and movie features into this space, a user's preferences map to a low-dimensional vector and a movie's features become a vector of the same dimension, so the affinity between the user and the movie can be expressed as the inner product of the two vectors.
If we read ratings as similarities, the rating matrix A (m×n) can thus be approximated by the product of a user-feature matrix U (m×k) and an item-feature matrix V (n×k).
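In symbols (a sketch of the standard low-rank model, where u_i and v_j denote row i of U and row j of V):

$$A_{m \times n} \approx U_{m \times k}\, V_{n \times k}^{\mathsf{T}}, \qquad \hat{a}_{ij} = u_i^{\mathsf{T}} v_j.$$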
Two optimization methods are commonly used to compute this factorization: alternating least squares (ALS) and stochastic gradient descent (SGD). ALS alternates between the two factor matrices: holding V fixed, each user vector u_i is the solution of an ordinary least-squares problem, and likewise for each item vector v_j with U held fixed.
The loss function includes a regularization term, set via setRegParam.
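Concretely, for explicit feedback the objective has the standard regularized squared-error form below, where Ω is the set of observed ratings and λ is the value passed to setRegParam (a sketch; Spark's implementation may scale the regularization term slightly differently):

$$\min_{U,\,V} \sum_{(i,j) \in \Omega} \left( a_{ij} - u_i^{\mathsf{T}} v_j \right)^2 + \lambda \left( \sum_i \lVert u_i \rVert^2 + \sum_j \lVert v_j \rVert^2 \right)$$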
Parameter Selection
- numBlocks: the number of blocks the users and items are partitioned into in order to parallelize computation (defaults to 10).
- rank: the number of latent factors in the model (defaults to 10).
- maxIter: the maximum number of iterations to run (defaults to 10).
- regParam: the regularization parameter in ALS (defaults to 1.0).
- implicitPrefs: whether to use the explicit feedback ALS variant (false, the default) or the variant adapted for implicit feedback data (true).
- alpha: applicable only to the implicit feedback variant of ALS; governs the baseline confidence in preference observations (defaults to 1.0).
- nonnegative: whether to apply nonnegativity constraints to the least-squares subproblems (defaults to false).

These map directly onto the setters used below:
ALS als = new ALS()
    .setMaxIter(10)          // maximum iterations; too large a value can trigger java.lang.StackOverflowError
    .setRegParam(0.16)       // regularization parameter
    .setAlpha(1.0)
    .setImplicitPrefs(false)
    .setNonnegative(false)
    .setNumBlocks(10)
    .setRank(10)
    .setUserCol("userId")
    .setItemCol("movieId")
    .setRatingCol("rating");
A point to note: the DataFrame-based API for ALS currently only supports integers for user and item ids. Other numeric types are supported for the user and item id columns, but the ids must be within the integer value range.
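If the raw ids are strings or fall outside the integer range, one workaround is to index them into a dense numeric range first. A sketch, assuming hypothetical column names rawUserId and userId:

import org.apache.spark.ml.feature.StringIndexer;

// Map arbitrary raw ids to a dense numeric index that fits within the integer range.
StringIndexer userIndexer = new StringIndexer()
    .setInputCol("rawUserId")   // hypothetical raw id column
    .setOutputCol("userId");    // numeric index, usable with ALS.setUserCol("userId")
Dataset<Row> indexedRatings = userIndexer.fit(ratings).transform(ratings);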
// Loop over the regularization parameter; the Evaluator reports the RMSE for each value.
List<Double> RMSE = new ArrayList<>();           // collect every RMSE
for (int i = 0; i < 20; i++) {                   // 20 runs
    double lambda = (i * 5 + 1) * 0.01;          // RegParam: 0.01, 0.06, 0.11, ... (steps of 0.05)
    ALS als = new ALS()
        .setMaxIter(5)                           // maximum iterations
        .setRegParam(lambda)                     // regularization parameter
        .setUserCol("userId")
        .setItemCol("movieId")
        .setRatingCol("rating");
    ALSModel model = als.fit(training);
    // Evaluate the model by computing the RMSE on the test data.
    Dataset<Row> predictions = model.transform(test);
    // RegressionEvaluator.setMetricName supports four metrics:
    //   "rmse" (default): root mean squared error
    //   "mse": mean squared error
    //   "r2": R^2 metric
    //   "mae": mean absolute error
    RegressionEvaluator evaluator = new RegressionEvaluator()
        .setMetricName("rmse")
        .setLabelCol("rating")
        .setPredictionCol("prediction");
    Double rmse = evaluator.evaluate(predictions);
    RMSE.add(rmse);
    System.out.println("RegParam " + lambda + " RMSE " + rmse + "\n");
}
// Print all results.
for (int j = 0; j < RMSE.size(); j++) {
    Double lambda = (j * 5 + 1) * 0.01;
    System.out.println("RegParam= " + lambda + " RMSE= " + RMSE.get(j) + "\n");
}
Looping over the parameter this way makes it easy to find the most suitable value; partial results:
RegParam= 0.01 RMSE= 1.956
RegParam= 0.06 RMSE= 1.166
RegParam= 0.11 RMSE= 0.977
RegParam= 0.16 RMSE= 0.962 // smallest RMSE: the most suitable value in the sweep
RegParam= 0.21 RMSE= 0.985
RegParam= 0.26 RMSE= 1.021
RegParam= 0.31 RMSE= 1.061
RegParam= 0.36 RMSE= 1.102
RegParam= 0.41 RMSE= 1.144
RegParam= 0.51 RMSE= 1.228
RegParam= 0.56 RMSE= 1.267
RegParam= 0.61 RMSE= 1.300
// Fix RegParam at 0.16 and go on to study the effect of the number of iterations.
In a single-machine environment, setting the iteration count too high throws a java.lang.StackOverflowError, caused by the current thread's stack filling up, so keep maxIter moderate. The output is shown after the loop sketch below.
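The sweep itself is not shown in the original listing; a minimal sketch, assuming the training and test sets and rating columns defined above:

// Sketch: sweep maxIter with RegParam fixed at 0.16.
for (int numMaxIteration = 1; numMaxIteration <= 16; numMaxIteration += 3) {
    ALSModel model = new ALS()
        .setMaxIter(numMaxIteration)
        .setRegParam(0.16)
        .setUserCol("userId")
        .setItemCol("movieId")
        .setRatingCol("rating")
        .fit(training);
    Double rmse = new RegressionEvaluator()
        .setMetricName("rmse")
        .setLabelCol("rating")
        .setPredictionCol("prediction")
        .evaluate(model.transform(test));
    System.out.println("numMaxIteration= " + numMaxIteration + " RMSE= " + rmse);
}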
numMaxIteration= 1 RMSE= 1.7325
numMaxIteration= 4 RMSE= 1.0695
numMaxIteration= 7 RMSE= 1.0563
numMaxIteration= 10 RMSE= 1.055
numMaxIteration= 13 RMSE= 1.053
numMaxIteration= 16 RMSE= 1.053
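One mitigation often suggested for this StackOverflowError at large iteration counts (not from the original article; the checkpoint directory below is a hypothetical path) is periodic RDD checkpointing, which truncates the lineage that deepens with every ALS iteration:

// Assumed mitigation sketch: checkpoint the factor RDDs periodically to truncate lineage.
spark.sparkContext().setCheckpointDir("G:/tmp/spark-checkpoint"); // hypothetical path, must be writable
ALS als = new ALS()
    .setCheckpointInterval(5)  // checkpoint every 5 iterations (default is 10)
    .setMaxIter(50)
    .setRegParam(0.16)
    .setUserCol("userId")
    .setItemCol("movieId")
    .setRatingCol("rating");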
//測(cè)試Rank隱含語(yǔ)義個(gè)數(shù)
Rank =1 RMSErr = 1.1584
Rank =3 RMSErr = 1.1067
Rank =5 RMSErr = 0.9366
Rank =7 RMSErr = 0.9745
Rank =9 RMSErr = 0.9440
Rank =11 RMSErr = 0.9458
Rank =13 RMSErr = 0.9466
Rank =15 RMSErr = 0.9443
Rank =17 RMSErr = 0.9543
// You can also define your own evaluation metric with Spark SQL
// (below: a mean-absolute-error computation).
// Register the DataFrame as a SQL temporary view.
predictions.createOrReplaceTempView("tmp_predictions");
Dataset<Row> absDiff = spark.sql("select abs(prediction-rating) as diff from tmp_predictions");
absDiff.createOrReplaceTempView("tmp_absDiff");
spark.sql("select mean(diff) as absMeanDiff from tmp_absDiff").show();
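As a cross-check, the built-in evaluator with metric "mae" (one of the four metrics listed above) should report the same value as the SQL query:

// Built-in mean-absolute-error evaluation for comparison with the SQL result.
RegressionEvaluator maeEvaluator = new RegressionEvaluator()
    .setMetricName("mae")
    .setLabelCol("rating")
    .setPredictionCol("prediction");
System.out.println("absMeanDiff = " + maeEvaluator.evaluate(predictions));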
Complete Code
public class Rating implements Serializable {...}
The full Rating class can be found at http://spark.apache.org/docs/latest/ml-collaborative-filtering.html.
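For reference, a sketch consistent with the example on that page, assuming the sample file format userId::movieId::rating::timestamp:

public static class Rating implements Serializable {
    private int userId;
    private int movieId;
    private float rating;
    private long timestamp;

    public Rating() {}

    public Rating(int userId, int movieId, float rating, long timestamp) {
        this.userId = userId;
        this.movieId = movieId;
        this.rating = rating;
        this.timestamp = timestamp;
    }

    public int getUserId() { return userId; }
    public int getMovieId() { return movieId; }
    public float getRating() { return rating; }
    public long getTimestamp() { return timestamp; }

    // Parse one line of sample_movielens_ratings.txt.
    public static Rating parseRating(String str) {
        String[] fields = str.split("::");
        if (fields.length != 4) {
            throw new IllegalArgumentException("Each line must contain 4 fields");
        }
        return new Rating(Integer.parseInt(fields[0]), Integer.parseInt(fields[1]),
                Float.parseFloat(fields[2]), Long.parseLong(fields[3]));
    }
}

The main program follows.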
package my.spark.ml.practice.classification;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class myCollabFilter2 {
    public static void main(String[] args) {
        SparkSession spark = SparkSession
            .builder()
            .appName("CoFilter")
            .master("local[4]")
            .config("spark.sql.warehouse.dir", "file:///G:/Projects/Java/Spark/spark-warehouse")
            .getOrCreate();

        String path = "G:/Projects/CgyWin64/home/pengjy3/softwate/spark-2.0.0-bin-hadoop2.6/"
            + "data/mllib/als/sample_movielens_ratings.txt";

        // Silence verbose logging.
        Logger.getLogger("org.apache.spark").setLevel(Level.WARN);
        Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF);

        // ---------------- 1.0 Prepare the DataFrame ----------------
        // .javaRDD() converts the Dataset to an RDD,
        // then each line is mapped String -> Rating.
        JavaRDD<Rating> ratingRDD = spark.read().textFile(path).javaRDD()
            .map(new Function<String, Rating>() {
                @Override
                public Rating call(String str) throws Exception {
                    return Rating.parseRating(str);
                }
            });
        // System.out.println(ratingRDD.take(10).get(0).getMovieId());

        // Create a DataFrame from the JavaRDD (each row is a Rating object) and the Rating class.
        Dataset<Row> ratings = spark.createDataFrame(ratingRDD, Rating.class);
        // ratings.show(30);

        // Randomly split the data into a training set and a test set.
        double[] weights = new double[] {0.8, 0.2};
        long seed = 1234;
        Dataset<Row>[] split = ratings.randomSplit(weights, seed);
        Dataset<Row> training = split[0];
        Dataset<Row> test = split[1];

        // ---------------- 2.0 Fit the ALS model on the training set ----------------
        for (int rank = 1; rank < 20; rank++) {
            // Define the algorithm.
            ALS als = new ALS()
                .setMaxIter(5) // maximum iterations; too large a value triggers java.lang.StackOverflowError
                .setRegParam(0.16)
                .setUserCol("userId")
                .setRank(rank)
                .setItemCol("movieId")
                .setRatingCol("rating");
            // Train the model.
            ALSModel model = als.fit(training);

            // ---------------- 3.0 Evaluate the model: RMSE on the test set ----------------
            Dataset<Row> predictions = model.transform(test);
            // predictions.show();
            RegressionEvaluator evaluator = new RegressionEvaluator()
                .setMetricName("rmse")
                .setLabelCol("rating")
                .setPredictionCol("prediction");
            Double rmse = evaluator.evaluate(predictions);
            System.out.println("Rank =" + rank + " RMSErr = " + rmse);
        }
    }
}