from itertools import cycle

import numpy as np
import pylab as pl


def plot_2D(data, target, target_names):
    """Scatter-plot the first two columns of *data*, one colour per class.

    A new pylab figure is created and, for each entry of *target_names*,
    the rows of *data* whose label in *target* equals that entry's index
    are drawn as one scatter series. A legend is added once all series
    have been drawn. The figure is neither shown nor saved here; that is
    left to the caller.

    Parameters
    ----------
    data : array-like, 2-D
        Samples to plot; only columns 0 and 1 are used.
    target : array-like, 1-D
        Per-row class label. Rows are matched against the integer index
        of each name in *target_names* (0, 1, 2, ...).
    target_names : sequence of str
        Legend label for each class, ordered by class index.
    """
    # Coerce to ndarrays so the element-wise comparison below yields a
    # boolean mask even when the caller passes plain Python lists (a list
    # compared to an int gives a scalar False, which would select nothing).
    data = np.asarray(data)
    target = np.asarray(target)

    # Colours repeat after eight classes.
    # NOTE(review): 'w' (white) is presumably invisible on the default
    # white background -- confirm whether it should stay in the palette.
    colors = cycle('rgbcmykw')

    pl.figure()
    for class_id, (color, label) in enumerate(zip(colors, target_names)):
        # Compute the row mask once and reuse it for both axes.
        mask = target == class_id
        pl.scatter(data[mask, 0], data[mask, 1], c=color, label=label)
    pl.legend()