Created
April 28, 2017 21:27
-
-
Save ramv/0093095fa87bef02483488cde44146ce to your computer and use it in GitHub Desktop.
Similarity Analysis Example
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.apache.spark.api.java.JavaPairRDD; | |
import org.apache.spark.api.java.JavaRDD; | |
import org.apache.spark.api.java.JavaSparkContext; | |
import org.apache.spark.api.java.function.Function; | |
import org.apache.spark.api.java.function.PairFunction; | |
import org.apache.spark.mllib.recommendation.Rating; | |
import org.apache.spark.sql.DataFrame; | |
import org.apache.spark.sql.Row; | |
import org.apache.spark.sql.RowFactory; | |
import org.apache.spark.sql.SQLContext; | |
import org.slf4j.Logger; | |
import org.slf4j.LoggerFactory; | |
import scala.Tuple2; | |
import java.util.Iterator; | |
/** | |
* TODO fix some bugs and test with MovieLens data | |
*/ | |
public class SimilarityAnalysis { | |
static final Logger LOGGER = LoggerFactory.getLogger(CoOccurence.class); | |
private static void computeSimilarVideos(JavaRDD<Rating> ratings, JavaSparkContext jsc, String outDir){ | |
/** | |
* Parameters to regularize correlation. | |
*/ | |
double PRIOR_COUNT = 10; | |
double PRIOR_CORRELATION = 0; | |
SQLContext sqlContext = new SQLContext(jsc); | |
DataFrame ratingsDf = sqlContext.createDataFrame(ratings, DmRating.class); | |
// get num raters per movie, keyed on movie id | |
DataFrame numRatingsPerVideo = ratingsDf.groupBy("modVideoId").count(); | |
numRatingsPerVideo.show(); | |
// join ratings with num raters on movie id | |
// ratingsWithSize now contains the following fields: (user, movie, rating, numRaters). | |
DataFrame ratingsWithSize = ratingsDf.join(numRatingsPerVideo); | |
ratingsWithSize.show(); | |
JavaPairRDD<String, Row> userIdKey = ratingsWithSize | |
.toJavaRDD() | |
.keyBy(new Function<Row, String>() { | |
@Override | |
public String call(Row v1) throws Exception { | |
return v1.getString(5); | |
} | |
}); | |
JavaPairRDD<String, Row> userIdKey2 = ratingsWithSize | |
.toJavaRDD() | |
.keyBy(new Function<Row, String>() { | |
@Override | |
public String call(Row v1) throws Exception { | |
return v1.getString(5); | |
} | |
}); | |
LOGGER.info("number of userIdKey {} userIdKey2 {}",userIdKey.count(), userIdKey2.count()); | |
JavaPairRDD<String, Tuple2<Row, Row>> ratingPairs = userIdKey2 | |
.join(userIdKey) | |
.filter(new Function<Tuple2<String, Tuple2<Row, Row>>, Boolean>() { | |
@Override | |
public Boolean call(Tuple2<String, Tuple2<Row, Row>> v1) throws Exception { | |
return v1._2()._1().getInt(2) < v1._2()._2().getInt(2); | |
} | |
}); | |
LOGGER.info("number of rating pairs {}",ratingPairs.count()); | |
// compute raw inputs to similarity metrics for each movie pair | |
JavaPairRDD<Tuple2<Integer, Integer>, Row> pairStats = ratingPairs.mapToPair((item)-> { | |
// this tuple contains videoIds of the pairs | |
Tuple2<Integer, Integer> videoIdPairs = new Tuple2<>(item._2()._1().getInt(1), item._2()._2().getInt(1)); | |
Row row = RowFactory.create( | |
item._2()._1().getInt(2) * item._2()._2().getInt(2), // rating 1 * rating 2 | |
item._2()._1().getInt(2), // rating movie 1 | |
item._2()._2().getInt(2), // rating movie 2 | |
Math.pow(item._2()._1().getInt(2), 2), // square of rating movie 1 | |
Math.pow(item._2()._2().getInt(2), 2), // square of rating movie 2 | |
item._2()._1().getInt(3), // number of raters movie 1 | |
item._2()._2().getInt(3)); // number of raters movie 2 | |
return new Tuple2<>(videoIdPairs, row); | |
}); | |
LOGGER.info("number of pair stats {}",pairStats.count()); | |
JavaPairRDD<Tuple2<Integer, Integer>, Row> vectorCals = pairStats.groupByKey().mapToPair(new PairFunction<Tuple2<Tuple2<Integer, Integer>, Iterable<Row>>, Tuple2<Integer, Integer>, Row>() { | |
@Override | |
public Tuple2<Tuple2<Integer, Integer>, Row> call(Tuple2<Tuple2<Integer, Integer>, Iterable<Row>> data) throws Exception { | |
Tuple2<Integer, Integer> key = data._1(); | |
Iterator<Row> vals = data._2().iterator(); | |
int size=0, dotProduct=0, ratingSum=0, rating2Sum=0, ratingSq=0, rating2Sq=0, numRaters=0, numRaters2=0; | |
while(vals.hasNext()){ | |
Row row=vals.next(); | |
size++; | |
dotProduct += row.getInt(0); | |
ratingSum += row.getInt(1); | |
rating2Sum += row.getInt(2); | |
ratingSq += row.getInt(3); | |
rating2Sq += row.getInt(4); | |
numRaters = Math.max(numRaters, row.getInt(5)); | |
numRaters2 = Math.max(numRaters2, row.getInt(6)); | |
} | |
return new Tuple2<Tuple2<Integer, Integer>, Row>(key, RowFactory.create(size, dotProduct, ratingSum, rating2Sum, ratingSq, rating2Sq, numRaters, numRaters2)); | |
} | |
}); | |
LOGGER.info("number of vector calcs {}",vectorCals.count()); | |
// compute similarity metrics for each movie pair | |
JavaPairRDD<Integer, Row> similarities = vectorCals.mapToPair((data)->{ | |
Tuple2<Integer, Integer> key = data._1(); | |
Row row = data._2(); | |
Double size = row.getDouble(0), | |
dotProduct = row.getDouble(1), | |
ratingSum = row.getDouble(2), | |
rating2Sum = row.getDouble(3), | |
ratingNormSq = row.getDouble(4), | |
rating2NormSq = row.getDouble(5), | |
numRaters = row.getDouble(6), | |
numRaters2 = row.getDouble(7); | |
double corr = correlation(size, dotProduct, ratingSum, rating2Sum, ratingNormSq, rating2NormSq); | |
double regCorr = regularizedCorrelation(size, dotProduct, ratingSum, rating2Sum, ratingNormSq, rating2NormSq, PRIOR_COUNT, PRIOR_CORRELATION); | |
double cosSim = cosineSimilarity(dotProduct, Math.sqrt(ratingNormSq), Math.sqrt(rating2NormSq)); | |
double jaccard = jaccardSimilarity(size, numRaters, numRaters2); | |
return new Tuple2<Integer, Row>(key._1(), RowFactory.create(key._1(), key._2(), corr, regCorr, cosSim, jaccard)); | |
}); | |
similarities.saveAsTextFile(outDir+"/similarities"); | |
} | |
// ************************* | |
// * SIMILARITY MEASURES | |
// ************************* | |
/** | |
* The correlation between two vectors A, B is | |
* cov(A, B) / (stdDev(A) * stdDev(B)) | |
* | |
* This is equivalent to | |
* [n * dotProduct(A, B) - sum(A) * sum(B)] / | |
* sqrt{ [n * norm(A)^2 - sum(A)^2] [n * norm(B)^2 - sum(B)^2] } | |
*/ | |
private static Double correlation(Double size, Double dotProduct, Double ratingSum, | |
Double rating2Sum, Double ratingNormSq, Double rating2NormSq){ | |
double numerator = size * dotProduct - ratingSum * rating2Sum; | |
double denominator = Math.sqrt(size * ratingNormSq - ratingSum * ratingSum) * | |
Math.sqrt(size * rating2NormSq - rating2Sum * rating2Sum); | |
return numerator / denominator; | |
} | |
/** | |
* Regularize correlation by adding virtual pseudocounts over a prior: | |
* RegularizedCorrelation = w * ActualCorrelation + (1 - w) * PriorCorrelation | |
* where w = # actualPairs / (# actualPairs + # virtualPairs). | |
*/ | |
private static Double regularizedCorrelation(Double size , Double dotProduct, Double ratingSum, | |
Double rating2Sum, Double ratingNormSq, Double rating2NormSq, | |
Double virtualCount, Double priorCorrelation) { | |
double unregularizedCorrelation = correlation(size, dotProduct, ratingSum, rating2Sum, ratingNormSq, rating2NormSq); | |
double w = size / (size + virtualCount); | |
return w * unregularizedCorrelation + (1 - w) * priorCorrelation; | |
} | |
/** | |
* The cosine similarity between two vectors A, B is | |
* dotProduct(A, B) / (norm(A) * norm(B)) | |
*/ | |
private static Double cosineSimilarity(Double dotProduct, Double ratingNorm, Double rating2Norm ) { | |
return dotProduct / (ratingNorm * rating2Norm); | |
} | |
/** | |
* The Jaccard Similarity between two sets A, B is | |
* |Intersection(A, B)| / |Union(A, B)| | |
*/ | |
private static Double jaccardSimilarity(Double usersInCommon , Double totalUsers1, Double totalUsers2){ | |
Double union = totalUsers1 + totalUsers2 - usersInCommon; | |
return usersInCommon / union; | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment