#-------------
# HYPERPARAMS
#-------------
num_neg = 6
latent_features = 8
epochs = 20
batch_size = 256
learning_rate = 0.001
#-------------
# HYPERPARAMS
#-------------
num_neg = 4
latent_features = 8
epochs = 20
batch_size = 256
learning_rate = 0.001
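To see how `num_neg` and `batch_size` interact, here is a small arithmetic sketch. It assumes, as the `get_train_instances` docstring suggests, that every observed user-item pair contributes one positive row plus `num_neg` sampled negative rows. The helper `epoch_stats` and the 10,000-positive figure are illustrative only and are not taken from the dataset.

```python
import math

def epoch_stats(num_positives, num_neg, batch_size):
    """Hypothetical helper: rows in one epoch's training set and the
    number of minibatches those rows split into."""
    # Assumption: each positive pair is kept once and paired with num_neg
    # sampled negatives, so it yields 1 + num_neg training rows.
    num_rows = num_positives * (1 + num_neg)
    # The final minibatch may be smaller than batch_size.
    num_batches = math.ceil(num_rows / batch_size)
    return num_rows, num_batches

print(epoch_stats(num_positives=10_000, num_neg=4, batch_size=256))  # (50000, 196)
print(epoch_stats(num_positives=10_000, num_neg=6, batch_size=256))  # (70000, 274)
```

Raising `num_neg` from 4 to 6 grows each epoch by 40% under this assumption, which is worth keeping in mind when comparing the two configurations.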
for epoch in range(epochs):
    # Get our training input.
    user_input, item_input, labels = get_train_instances()
    # Generate a list of minibatches.
    minibatches = random_mini_batches(user_input, item_input, labels)
    # This has nothing to do with tensorflow but gives
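The body of `random_mini_batches` is not shown here. A minimal sketch of what such a helper could look like, assuming it shuffles the three parallel lists with one shared permutation and cuts them into chunks of at most `batch_size` rows; the `mini_batch_size` and `seed` parameters are hypothetical additions for illustration.

```python
import random

def random_mini_batches(U, I, L, mini_batch_size=256, seed=None):
    """Hypothetical sketch: shuffle three parallel lists together and split
    them into minibatches of at most mini_batch_size rows."""
    assert len(U) == len(I) == len(L), "inputs must have the same length"
    idx = list(range(len(U)))
    # A single permutation keeps each (user, item, label) triple aligned.
    random.Random(seed).shuffle(idx)
    batches = []
    for start in range(0, len(idx), mini_batch_size):
        chunk = idx[start:start + mini_batch_size]
        batches.append((
            [U[i] for i in chunk],
            [I[i] for i in chunk],
            [L[i] for i in chunk],
        ))
    return batches

users = [0, 0, 1, 1, 2]
items = [10, 11, 10, 12, 13]
labels = [1, 0, 1, 0, 1]
batches = random_mini_batches(users, items, labels, mini_batch_size=2, seed=0)
print([len(b[0]) for b in batches])  # [2, 2, 1]
```

Reshuffling every epoch (the loop above calls the helper inside `for epoch ...`) means the model sees positives and negatives in a different order each pass.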
import tensorflow as tf
import numpy as np
import pandas as pd
import math
import heapq
from tqdm import tqdm

# Load and prepare our data.
uids, iids, df_train, df_test, df_neg, users, items, item_lookup = load_dataset()
def get_train_instances():
    """Samples a number of negative user-item interactions for each
    user-item pair in our training data.
    Returns:
        user_input (list): A list of all users for each item.
        item_input (list): A list of all items for every user,
            both positive and negative interactions.
        labels (list): A list of all labels, 0 or 1.
    """
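The function body is not shown. A minimal, self-contained sketch of this kind of negative sampling, assuming item ids are the integers `0 .. num_items - 1` and that positives are given as `(user, item)` pairs. Unlike the original, which takes no arguments, this hypothetical version receives its inputs explicitly so it can run on its own.

```python
import random

def get_train_instances(positives, num_items, num_neg=4, seed=None):
    """Hypothetical sketch of training-set negative sampling.

    positives: list of (user, item) pairs the user actually interacted with.
    num_items: item ids are assumed to be the integers 0 .. num_items - 1.
    Returns three parallel lists: user_input, item_input, labels.
    """
    rng = random.Random(seed)
    # Collect each user's observed items so negatives can avoid them.
    seen = {}
    for u, i in positives:
        seen.setdefault(u, set()).add(i)

    user_input, item_input, labels = [], [], []
    for u, i in positives:
        # The observed interaction becomes a positive example (label 1).
        user_input.append(u)
        item_input.append(i)
        labels.append(1)
        # Pair it with num_neg items this user never interacted with (label 0).
        # Draws are independent, so a negative can repeat; the while loop
        # never ends if a user has interacted with every item.
        for _ in range(num_neg):
            j = rng.randrange(num_items)
            while j in seen[u]:
                j = rng.randrange(num_items)
            user_input.append(u)
            item_input.append(j)
            labels.append(0)
    return user_input, item_input, labels

u, i, l = get_train_instances([(0, 1), (0, 2), (1, 0)], num_items=5, num_neg=2, seed=42)
print(len(l), sum(l))  # 9 3
```

Sampling fresh negatives on every call, rather than fixing them once, gives the model a different set of "unobserved" items each epoch.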
def get_negatives(uids, iids, items, df_test):
    """Returns a pandas dataframe of 100 negative interactions
    for each user in df_test.
    Args:
        uids (np.array): Numpy array of all user ids.
        iids (np.array): Numpy array of all item ids.
        items (list): List of all unique items.
        df_test (dataframe): Our test set.
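The idea behind drawing 100 negatives per test user can be sketched without pandas. This hypothetical version returns a plain dict instead of a dataframe and assumes `interactions` holds every known `(user, item)` pair (train and test), so a held-out test item is never drawn as a negative.

```python
import random

def get_negatives(interactions, test_users, all_items, num_negatives=100, seed=None):
    """Hypothetical sketch: for each test user, sample num_negatives distinct
    items that the user never interacted with (in train or test)."""
    rng = random.Random(seed)
    seen = {}
    for u, i in interactions:
        seen.setdefault(u, set()).add(i)

    negatives = {}
    for u in test_users:
        user_seen = seen.get(u, set())
        candidates = [i for i in all_items if i not in user_seen]
        # random.sample draws without replacement, so there are no duplicates
        # per user; it raises ValueError if too few candidates remain.
        negatives[u] = rng.sample(candidates, num_negatives)
    return negatives

all_items = list(range(200))
negs = get_negatives([(0, 5), (0, 7), (1, 3)], test_users=[0, 1], all_items=all_items, seed=1)
print(len(negs[0]), 5 in negs[0])  # 100 False
```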
import pandas as pd
import numpy as np
import pickle

def load_dataset():
    """
    Loads the lastfm dataset from a pickle file into a pandas dataframe
    and transforms it into the format we need.
    We then split it into a training and a test set.
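The split strategy itself is not shown. One common choice that pairs naturally with sampling 100 negatives per test user is a leave-one-out split. The sketch below assumes that protocol (holding out one random interaction per user), which may differ from what `load_dataset` actually does; `leave_one_out_split` is a hypothetical name.

```python
import random

def leave_one_out_split(interactions, seed=None):
    """Hypothetical sketch: hold out one random interaction per user as the
    test set and keep the remaining interactions for training."""
    rng = random.Random(seed)
    by_user = {}
    for u, i in interactions:
        by_user.setdefault(u, []).append(i)

    train, test = [], []
    for u, items in by_user.items():
        # Users with a single interaction end up with nothing in train.
        held_out = rng.randrange(len(items))
        for k, i in enumerate(items):
            if k == held_out:
                test.append((u, i))
            else:
                train.append((u, i))
    return train, test

data = [(0, 10), (0, 11), (0, 12), (1, 10), (1, 13)]
train, test = leave_one_out_split(data, seed=0)
print(len(train), len(test))  # 3 2
```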
#---------------------
# MAKE RECOMMENDATION
#---------------------
def make_recommendation(user_id=None, num_items=10):
    """Recommend items for a given user using a trained model.
    Args:
        user_id (int): The id of the user we want to create recommendations for.
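Once the model has produced a score for this user against every item, picking the recommendations is a top-k selection. A minimal sketch using `heapq` (already imported above), assuming the scores arrive as a dict of item id to predicted score and that items the user already listened to should be skipped; `top_k_items` and `seen_items` are hypothetical names.

```python
import heapq

def top_k_items(scores, seen_items, num_items=10):
    """Hypothetical sketch: return the num_items item ids with the highest
    predicted score, skipping items the user already interacted with.

    scores: dict mapping item id -> predicted score for a single user.
    """
    candidates = (i for i in scores if i not in seen_items)
    # nlargest maintains a heap of num_items entries instead of sorting everything.
    return heapq.nlargest(num_items, candidates, key=scores.get)

scores = {0: 0.1, 1: 0.9, 2: 0.5, 3: 0.7}
print(top_k_items(scores, seen_items={1}, num_items=2))  # [3, 2]
```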
#-----------------------
# FIND SIMILAR ARTISTS
#-----------------------
def find_similar_artists(artist=None, num_items=10):
    """Find artists similar to an artist.
    Args:
        artist (str): The name of the artist we want to find similar artists for.
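A common way to find similar items with a factorization model is to compare their learned latent vectors. The sketch below assumes each artist's vector is its row in the trained item embedding matrix (`latent_features` long) and ranks the others by cosine similarity; `cosine`, `most_similar`, and the toy 2-dimensional vectors are hypothetical.

```python
import heapq
import math

def cosine(a, b):
    """Cosine similarity between two equal-length vectors (0.0 if either is zero)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0

def most_similar(embeddings, query, num_items=10):
    """Hypothetical sketch: rank artists by cosine similarity between their
    latent vectors and the query artist's vector, excluding the query."""
    q = embeddings[query]
    others = (name for name in embeddings if name != query)
    return heapq.nlargest(num_items, others, key=lambda n: cosine(embeddings[n], q))

emb = {"a": [1.0, 0.0], "b": [0.9, 0.1], "c": [0.0, 1.0], "d": [-1.0, 0.0]}
print(most_similar(emb, "a", num_items=2))  # ['b', 'c']
```

Cosine similarity ignores vector length, so artists are matched on the direction of their latent factors rather than on how popular (large-norm) they are.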
#------------------
# GRAPH EXECUTION
#------------------
# Run the session.
session = tf.Session(config=None, graph=graph)
session.run(init)
# This has nothing to do with tensorflow but gives