‘Recommending books (with Rust)’, enhanced
cargo-features = ["edition"]

[package]
name = "goodbooks-recommender"
version = "0.1.0"
authors = ["A"]
edition = "2018"

[dependencies]
reqwest = "0.8.8"
failure = "0.1.2"
serde_derive = "1.0.74"
serde = "1.0.74"
serde_json = "1.0.26"
csv = "1.0.1"
sbr = "0.4.0"
rand = "0.5.5"
elapsed = "0.1.2"
clap = "^2.32.0"
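Both `cargo-features = ["edition"]` and the `uniform_paths` feature used in the source file below were nightly-only when this gist was written, so a nightly toolchain is required. One way to pin that (my addition, not part of the original gist) is a one-line `rust-toolchain` file next to `Cargo.toml`:

nightly

With that file in place, plain `cargo build` and `cargo run` pick up the nightly compiler automatically.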
#![feature(uniform_paths)]

use std::collections::HashMap;
use std::fs::File;
use std::io::BufWriter;
use std::path::Path;

use serde_derive::{Deserialize, Serialize};

/// Download file from `url` and save it to `destination`.
fn download(url: impl AsRef<str>, destination: impl AsRef<Path>) -> Result<(), failure::Error> {
    let destination = destination.as_ref();

    if destination.exists() {
        return Ok(());
    }

    let file = File::create(destination)?;
    let mut writer = BufWriter::new(file);
    let mut response = reqwest::get(url.as_ref())?;
    response.copy_to(&mut writer)?;

    Ok(())
}

/// Download ratings and metadata.
fn download_data(
    ratings_path: impl AsRef<Path>,
    books_path: impl AsRef<Path>,
    ratings_url: impl AsRef<str>,
    books_url: impl AsRef<str>,
) {
    download(ratings_url.as_ref(), ratings_path.as_ref()).expect("Could not download ratings");
    download(books_url.as_ref(), books_path.as_ref()).expect("Could not download metadata");
}
#[derive(Debug, Serialize, Deserialize)]
struct WishlistEntry {
    user_id: usize,
    book_id: usize,
}

/// Deserialize wishlist entries from the CSV file at `path`.
fn deserialize_ratings(path: impl AsRef<Path>) -> Result<Vec<WishlistEntry>, failure::Error> {
    let mut reader = csv::Reader::from_path(path)?;
    let entries = reader.deserialize().collect::<Result<Vec<_>, _>>()?;

    Ok(entries)
}

#[derive(Debug, Serialize, Deserialize)]
struct Book {
    book_id: usize,
    title: String,
}

/// Deserialize from file at `path` into book mappings.
fn deserialize_books(
    path: impl AsRef<Path>,
) -> Result<(HashMap<usize, String>, HashMap<String, usize>), failure::Error> {
    let mut reader = csv::Reader::from_path(path.as_ref())?;
    let entries: Vec<Book> = reader
        .deserialize::<Book>()
        .collect::<Result<Vec<_>, _>>()?;

    let id_to_title: HashMap<usize, String> = entries
        .iter()
        .map(|book| (book.book_id, book.title.clone()))
        .collect();
    let title_to_id: HashMap<String, usize> = entries
        .iter()
        .map(|book| (book.title.clone(), book.book_id))
        .collect();

    Ok((id_to_title, title_to_id))
}
use sbr::models::ewma::{Hyperparameters, ImplicitEWMAModel};
use sbr::models::{Loss, Optimizer};

/// Build an EWMA model for `num_items` items.
fn build_model(num_items: usize) -> ImplicitEWMAModel {
    let hp = Hyperparameters::new(num_items, 128)
        .embedding_dim(32)
        .learning_rate(0.16)
        .l2_penalty(0.0004)
        .loss(Loss::WARP)
        .optimizer(Optimizer::Adagrad)
        .num_epochs(10)
        .num_threads(1);

    hp.build()
}

use sbr::data::{Interaction, Interactions};

/// Convert wishlist entries into an `Interactions` matrix.
fn build_interactions(data: &[WishlistEntry]) -> Interactions {
    // Ids are zero-based, so the dimensions are the largest id plus one.
    let num_users = data.iter().map(|x| x.user_id).max().unwrap() + 1;
    let num_items = data.iter().map(|x| x.book_id).max().unwrap() + 1;

    let mut interactions = Interactions::new(num_users, num_items);

    for (idx, datum) in data.iter().enumerate() {
        // The row index stands in for a timestamp, preserving the original ordering.
        interactions.push(Interaction::new(datum.user_id, datum.book_id, idx));
    }

    interactions
}
use rand::SeedableRng;
use sbr::data::user_based_split;
use sbr::OnlineRankingModel;
use sbr::evaluation::mrr_score;

/// Fit the model.
///
/// If successful, return the MRR on the test set. Otherwise, return
/// an error.
fn fit(model: &mut ImplicitEWMAModel, data: &Interactions) -> Result<f32, failure::Error> {
    // A fixed seed keeps the train/test split (and hence the MRR) reproducible.
    let mut rng = rand::XorShiftRng::from_seed([42; 16]);
    let (train, test) = user_based_split(data, &mut rng, 0.2);

    model.fit(&train.to_compressed())?;
    let mrr = mrr_score(model, &test.to_compressed())?;

    Ok(mrr)
}

fn serialize_model(
    model: &ImplicitEWMAModel,
    path: impl AsRef<Path>,
) -> Result<(), failure::Error> {
    let file = File::create(path.as_ref())?;
    let mut writer = BufWriter::new(file);

    Ok(serde_json::to_writer(&mut writer, model)?)
}
use elapsed::measure_time;

/// Download training data and build a model.
///
/// We’ll use this function to power the `fit` subcommand of our
/// command line tool.
fn main_build(
    model_path: impl AsRef<Path>,
    ratings_path: impl AsRef<Path>,
    books_path: impl AsRef<Path>,
    ratings_url: impl AsRef<str>,
    books_url: impl AsRef<str>,
) {
    let model_path = model_path.as_ref();

    if model_path.exists() {
        println!("Model already fitted.");
        return;
    }

    let ratings_path = ratings_path.as_ref();
    let books_path = books_path.as_ref();

    println!("Downloading data...");
    download_data(ratings_path, books_path, ratings_url, books_url);

    let ratings = deserialize_ratings(ratings_path).unwrap();
    let (id_to_title, _) = deserialize_books(books_path).unwrap();

    println!(
        "Deserialized {} ratings and {} books.",
        ratings.len(),
        id_to_title.len()
    );

    let interactions = build_interactions(&ratings);
    let mut model = build_model(interactions.num_items());

    println!("Fitting...");
    let (elapsed, mrr) =
        measure_time(|| fit(&mut model, &interactions).expect("Unable to fit model"));
    println!("Fitted model with MRR of {:.2} in {}.", mrr, elapsed);

    serialize_model(&model, model_path).expect("Unable to serialize model.");
}
use std::io::BufReader;

fn deserialize_model(model_path: impl AsRef<Path>) -> Result<ImplicitEWMAModel, failure::Error> {
    let file = File::open(model_path.as_ref())?;
    let reader = BufReader::new(file);
    let model = serde_json::from_reader(reader)?;

    Ok(model)
}

use std::iter::Iterator;

/// Recommend up to ten titles based on `input_titles`.
fn predict(
    books_path: impl AsRef<Path>,
    input_titles: &[String],
    model: &ImplicitEWMAModel,
) -> Result<Vec<String>, failure::Error> {
    let (id_to_title, title_to_id) = deserialize_books(books_path.as_ref())?;

    for title in input_titles {
        if !title_to_id.contains_key(title) {
            println!("No such title, ignoring: {}", title);
        }
    }

    let input_indices: Vec<_> = input_titles
        .iter()
        .filter_map(|title| title_to_id.get(title))
        .cloned()
        .collect();

    // Score every known book for the synthetic user built from the input titles.
    let indices_to_score: Vec<usize> = (0..id_to_title.len()).collect();

    let user = model.user_representation(&input_indices)?;
    let predictions = model.predict(&user, &indices_to_score)?;

    let mut predictions: Vec<_> = indices_to_score
        .iter()
        .zip(predictions)
        .map(|(&idx, score)| (idx, score))
        .collect();
    predictions.sort_by(|(_, score_a), (_, score_b)| score_b.partial_cmp(score_a).unwrap());

    Ok((&predictions[..10])
        .iter()
        .map(|(idx, _)| id_to_title.get(idx).unwrap())
        .cloned()
        .collect())
}

use std::ffi::{OsStr, OsString};

/// `clap` validator: accept only paths that already exist.
fn is_existing_file(val: &OsStr) -> Result<(), OsString> {
    let path = Path::new(&val);

    if path.exists() {
        Ok(())
    } else {
        Err(OsString::from("Not an existing file"))
    }
}
fn main() {
    use clap::{App, AppSettings, Arg, SubCommand};

    let matches = App::new("Goodbooks Recommender")
        .version("0.1.0")
        .about("Recommends books using the goodbooks-10k dataset")
        .setting(AppSettings::SubcommandRequired)
        .subcommand(
            SubCommand::with_name("fit")
                .about("Fits")
                .arg(
                    Arg::with_name("ratings_url")
                        .help("URL of ratings data")
                        .long("ratings-url")
                        .default_value(
                            "https://github.com/zygmuntz/goodbooks-10k/raw/master/ratings.csv",
                        ),
                ).arg(
                    Arg::with_name("books_url")
                        .help("URL of books data")
                        .long("books-url")
                        .default_value(
                            "https://github.com/zygmuntz/goodbooks-10k/raw/master/books.csv",
                        ),
                ).arg(
                    Arg::with_name("ratings_filename")
                        .help("Specifies ratings filename")
                        .long("ratings-filename")
                        .default_value_os(OsStr::new("ratings.json")),
                ).arg(
                    Arg::with_name("model_filename")
                        .help("Specifies model file path")
                        .long("model-filename")
                        .default_value_os(OsStr::new("model.json")),
                ).arg(
                    Arg::with_name("books_filename")
                        .long("books-filename")
                        .help("Specifies books file path")
                        .default_value_os(OsStr::new("books.json")),
                ),
        ).subcommand(
            SubCommand::with_name("predict")
                .about("Makes predictions")
                .arg(
                    Arg::with_name("titles")
                        .help("Titles to base predictions on")
                        .index(1)
                        .multiple(true)
                        .required(true),
                ).arg(
                    Arg::with_name("model_filename")
                        .help("Specifies model file path")
                        .long("model-filename")
                        .default_value_os(OsStr::new("model.json"))
                        .validator_os(is_existing_file),
                ).arg(
                    Arg::with_name("books_filename")
                        .help("Specifies books file path")
                        .long("books-filename")
                        .default_value_os(OsStr::new("books.json"))
                        .validator_os(is_existing_file),
                ),
        ).get_matches();

    match matches.subcommand() {
        ("fit", Some(matches)) => {
            let ratings_path = Path::new(matches.value_of("ratings_filename").unwrap());
            let model_path = Path::new(matches.value_of("model_filename").unwrap());
            let books_path = Path::new(matches.value_of("books_filename").unwrap());
            let ratings_url = matches.value_of("ratings_url").unwrap();
            let books_url = matches.value_of("books_url").unwrap();

            main_build(
                &model_path,
                &ratings_path,
                &books_path,
                &ratings_url,
                &books_url,
            )
        }
        ("predict", Some(matches)) => {
            let model_path = Path::new(matches.value_of("model_filename").unwrap());
            let books_path = Path::new(matches.value_of("books_filename").unwrap());

            let model = deserialize_model(&model_path)
                .expect(&format!("Unable to deserialize {}.", model_path.display()));
            let predictions = predict(
                &books_path,
                &matches
                    .values_of("titles")
                    .unwrap()
                    .map(|s| s.to_owned())
                    .collect::<Vec<_>>(),
                &model,
            ).expect("Unable to get predictions");

            if predictions.is_empty() {
                println!("No predictions found.")
            } else {
                println!("Predictions:");

                for prediction in predictions {
                    println!(" {}", prediction);
                }
            }
        }
        _ => unreachable!(),
    }
}
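// A minimal test sketch (not in the original gist) showing how `build_interactions`
// could be exercised; the wishlist entries below are made-up examples.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn build_interactions_sizes_the_matrix_from_the_largest_ids() {
        let data = vec![
            WishlistEntry {
                user_id: 0,
                book_id: 3,
            },
            WishlistEntry {
                user_id: 2,
                book_id: 1,
            },
        ];

        let interactions = build_interactions(&data);

        // Ids are zero-based, so the item dimension is the largest book id plus one.
        assert_eq!(interactions.num_items(), 4);
    }
}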
This is based on maciejkula’s blog post, with a few tweaks to try out Rust 2018 features and a few useful crates.
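To try it out (my reading of the `clap` definition above, not instructions from the original post): run the `fit` subcommand first to download the data and train the model, e.g. `cargo +nightly run --release -- fit`, then ask for recommendations with `cargo +nightly run --release -- predict "<one or more titles>"`. Both subcommands accept `--model-filename` and `--books-filename` to override the default `model.json` and `books.json` paths.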