Created
February 26, 2019 15:31
-
-
Save braz/27fb580143b06ce7bc0ded16ba24a675 to your computer and use it in GitHub Desktop.
Linear regression plotting against Iris data
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import urllib
import urllib.request

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn import datasets, linear_model
# URL for the Iris dataset (UCI Machine Learning Repository)
url = "http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"

# Download the CSV and parse it into a pandas DataFrame. header=None makes
# pandas label the columns 0..N-1 instead of treating the first row as names.
# urllib.urlopen only exists on Python 2; urllib.request.urlopen is the
# Python 3 equivalent, and the `with` block closes the HTTP response as soon
# as pandas has finished reading it.
with urllib.request.urlopen(url) as raw_data:
    data = pd.read_csv(raw_data, header=None)

# Pull out columns 2 and 3 (presumably petal length and petal width, matching
# the axes plotted further down -- confirm against the dataset description).
fit_data = data[[2, 3]].values

# scikit-learn estimators expect 2-D inputs, so turn each column into an
# (n_samples, 1) array.
x_data = fit_data[:, 0].reshape(-1, 1)
y_data = fit_data[:, 1].reshape(-1, 1)

# Create the linear regression object and fit column 3 as a function of
# column 2; once the data is reshaped, running the fit is a single call.
regr = linear_model.LinearRegression()
regr.fit(x_data, y_data)
# Load the Iris data bundled with scikit-learn (`datasets` is imported at the
# top of the file). NOTE: this rebinds `data`, which previously held the
# DataFrame parsed from the downloaded CSV.
data = datasets.load_iris()

# Create a pandas DataFrame with one column per feature, named after
# data.feature_names.
df = pd.DataFrame(data.data, columns=data.feature_names)

# Add the target (0, 1, 2) to the data frame to link each row to its iris
# species. The array is assigned directly; wrapping it in a DataFrame first
# is unnecessary for a single column.
df['target'] = data.target
# Start a fresh 10x6 inch figure. No plt.clf() beforehand: when no figure
# exists yet, clf() creates an empty one, which would then be shown as a
# blank extra window next to the real plot.
plt.figure(figsize=(10, 6))

# The species names, indexed by target value.
names = data.target_names
# NOTE(review): `colors` and `label` are not referenced by the plotting code
# below (scatter uses matplotlib's default colour cycle); they are kept in
# case something outside this script relies on them.
colors = ['b', 'r', 'g']
# Integer target per sample. The builtin `int` replaces `np.int`, which was
# deprecated in NumPy 1.20 and removed in 1.24.
label = data.target.astype(int)

plt.title('Petal Width vs Petal Length')
plt.xlabel(data.feature_names[2])
plt.ylabel(data.feature_names[3])

# Plot each species in turn so that each gets its own legend entry.
for i, name in enumerate(names):
    # Extract the rows for the current species
    bucket = df[df['target'] == i]
    # Limit the details to feature columns 2 and 3 (the x and y axes above)
    bucket = bucket.iloc[:, [2, 3]].values
    # Scatter column 2 (x) against column 3 (y) for the current species
    plt.scatter(bucket[:, 0], bucket[:, 1], label=name)

# Overlay the fitted linear regression line on top of the scatter points
plt.plot(x_data, regr.predict(x_data), color='black', linewidth=3)

# Add the legend to the diagram
plt.legend()

# Show the rendered image
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment