Skip to content

Instantly share code, notes, and snippets.

@braz
Created February 26, 2019 15:31
Show Gist options
  • Save braz/27fb580143b06ce7bc0ded16ba24a675 to your computer and use it in GitHub Desktop.
Save braz/27fb580143b06ce7bc0ded16ba24a675 to your computer and use it in GitHub Desktop.
Linear regression plotting against Iris data
import urllib
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# Fit a straight line to two columns of the raw UCI Iris CSV.
# Script-level import, matching how this script already imports
# `sklearn.datasets` further down; `linear_model` was previously never
# imported, so the LinearRegression call raised NameError.
from sklearn import linear_model

# URL for the Iris dataset (UCI Machine Learning Repository)
url = "http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"
# Load the headerless CSV into a pandas DataFrame (integer column labels).
# pandas fetches URLs itself, which replaces the Python 2-only
# `urllib.urlopen` call that fails on Python 3.
data = pd.read_csv(url, header=None)
# Take columns 2 and 3: column 2 is the predictor (x), column 3 the
# response (y). sklearn estimators expect 2-D arrays shaped
# (n_samples, n_features), hence the reshape of each into one column.
fit_data = data[[2, 3]].values
x_data = fit_data[:, 0].reshape(-1, 1)
y_data = fit_data[:, 1].reshape(-1, 1)
# Ordinary least-squares linear regression of y on x.
regr = linear_model.LinearRegression()
regr.fit(x_data, y_data)
# Scatter petal length vs petal width per species, then overlay the
# regression line fitted earlier in this script (`regr`, `x_data`).
# Load the Iris data bundled with scikit-learn.
from sklearn import datasets

data = datasets.load_iris()
# DataFrame of the four measurement columns plus an integer 'target'
# column (0, 1, 2) whose value indexes into data.target_names.
df = pd.DataFrame(data.data, columns=data.feature_names)
df['target'] = data.target

# A single plt.figure() is enough: calling plt.clf() first would create an
# extra, empty figure when none exists yet, and plt.show() would display it.
plt.figure(figsize=(10, 6))
# Species names and one marker colour per species, both in target order.
names = data.target_names
colors = ['b', 'r', 'g']
# Integer species code (0, 1, 2) for each sample. The builtin `int`
# replaces the `np.int` alias, which NumPy has removed.
label = data.target.astype(int)
plt.title('Petal Width vs Petal Length')
plt.xlabel(data.feature_names[2])
plt.ylabel(data.feature_names[3])
# Plot each species in turn, coloured from `colors`.
for i, (name, color) in enumerate(zip(names, colors)):
    # Rows for the current species, restricted to feature columns 2 and 3
    # (petal length and petal width).
    bucket = df.loc[df['target'] == i].iloc[:, [2, 3]].values
    plt.scatter(bucket[:, 0], bucket[:, 1], color=color, label=name)
# Overlay the fitted linear regression on the scatter plot.
plt.plot(x_data, regr.predict(x_data), color='black', linewidth=3)
# Add the legend to the diagram
plt.legend()
# Show the rendered image
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment