Last active
June 11, 2019 19:14
-
-
Save victorkohler/8e345840a03e4675d049b7839e3b4ac1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
import numpy as np | |
import pickle | |
def load_dataset(data_path='path/to/data.pkl'):
    """Load the pickled lastfm play counts and prepare train/test data.

    The pickled dataframe has its second column dropped and the remaining
    three columns are named ``user``, ``item`` and ``plays``. Rows with
    missing values or a play count of zero are removed, as are users that
    have only a single interaction. Users and items are then mapped to
    integer category codes before the data is split.

    Args:
        data_path: Location of the pickle file holding the raw dataframe.
            Defaults to the placeholder path used originally.

    Returns:
        A tuple ``(uids, iids, df_train, df_test, df_neg, users, items,
        item_lookup)`` where

        * ``uids`` / ``iids`` are numpy arrays with the user and item ids
          of every row in ``df_train``,
        * ``df_train`` / ``df_test`` are whatever ``train_test_split``
          returns for the cleaned ``user_id``/``item_id``/``plays`` frame,
        * ``df_neg`` is the result of ``get_negatives`` (per the original
          comment: sampled negative interactions for the test users --
          confirm against that helper),
        * ``users`` / ``items`` are sorted lists of all unique ids,
        * ``item_lookup`` maps ``item_id`` (stored as ``str``) back to the
          original ``item`` value.
    """
    # SECURITY: pickle.load can execute arbitrary code while unpickling.
    # Only point data_path at files from a trusted source.
    with open(data_path, 'rb') as f_in:
        df = pickle.load(f_in)

    # Drop the second column, then name the three columns that remain.
    df = df.drop(df.columns[1], axis=1)
    df.columns = ['user', 'item', 'plays']

    # Drop any rows with empty cells or rows with a play count of zero.
    df = df.dropna()
    df = df.loc[df.plays != 0]

    # Keep only users with more than one interaction. The counts live in
    # a separate Series so no helper column is added to the frame, and
    # the explicit copy lets us add columns below without pandas warning
    # about assigning to a slice of another frame.
    interactions_per_user = df.groupby('user')['user'].transform('count')
    df = df[interactions_per_user > 1].copy()

    # Convert user and item names into numerical IDs.
    df['user_id'] = df['user'].astype("category").cat.codes
    df['item_id'] = df['item'].astype("category").cat.codes

    # Create a lookup frame so we can get the item names back in
    # readable form later.
    item_lookup = df[['item_id', 'item']].drop_duplicates().copy()
    item_lookup['item_id'] = item_lookup.item_id.astype(str)

    # Grab the columns we need in the order we need them.
    df = df[['user_id', 'item_id', 'plays']]

    # Create training and test sets.
    df_train, df_test = train_test_split(df)

    # Create lists of all unique users and items (from the full data,
    # not just the training split).
    users = list(np.sort(df.user_id.unique()))
    items = list(np.sort(df.item_id.unique()))

    # User and item ids of every training interaction.
    rows = df_train.user_id.astype(int)
    cols = df_train.item_id.astype(int)
    uids = np.array(rows.tolist())
    iids = np.array(cols.tolist())

    # Sample negative interactions for each user in our test data.
    df_neg = get_negatives(uids, iids, items, df_test)

    return uids, iids, df_train, df_test, df_neg, users, items, item_lookup
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment