Created
March 10, 2017 21:06
-
-
Save arthur-e/c972670644861acdc8aaa711ca4adee7 to your computer and use it in GitHub Desktop.
An example Python script that uses scikit-learn to distinguish water pixels from non-water pixels
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
A module for machine learning on Landsat data; implemented and tested,
specifically, for learning water areas on an image. Performance so far:

    Gaussian naive Bayes (where validation data chosen by the hydro mask):
        Mean precision: Not water=0.9997, Water=0.3090
        Mean recall: Not water=0.9787, Water=0.9763

    Gaussian naive Bayes (where validation data inspected in Google Earth):
        Mean precision: Not water=0.9675, Water=1.0000
        Mean recall: Not water=1.0000, Water=0.9664
'''
from epsg import EPSG | |
from utils import * | |
from lsma import ravel_and_filter | |
import numpy as np | |
from osgeo import gdal | |
from scipy.stats import pearsonr | |
from sklearn.cross_validation import StratifiedKFold | |
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support, roc_curve, auc | |
from sklearn.naive_bayes import GaussianNB | |
from pylab import * | |
def array_correlation(arr1, arr2, mask=None, nodata=-9999, n=1000, method='pearson'):
    '''
    Calculates Pearson's r (correlation) between two input raster arrays,
    estimated from a random sample of their pixels.

    Arguments:
        arr1    First raster array; if it is not 3D it is reshaped to
                (1, rows, cols)
        arr2    Second raster array; must end up with the same shape as arr1
        mask    Optional mask; when given, both arrays are passed through
                binary_mask() with it before sampling
        nodata  The NoData value handed to ravel_and_filter() when the
                samples are filtered
        n       Number of random sample points to draw (drawn with
                replacement, so a pixel may be picked more than once)
        method  Correlation method; only 'pearson' is implemented

    Returns a 2-tuple: the result of scipy.stats.pearsonr() on the two
    samples and a string 'N=<sample size>'.

    Raises NotImplementedError if method is anything other than 'pearson'.
    '''
    # Reject unsupported methods before doing any work. (The original raised
    # the NotImplemented constant, which is not an exception class and so
    # produced a confusing TypeError instead.)
    if method != 'pearson':
        raise NotImplementedError(
            'Correlation method %r is not implemented' % (method,))

    # Promote single-band (2D) rasters to a 3D array with one band
    if arr1.ndim != 3:
        arr1 = arr1.reshape(1, arr1.shape[0], arr1.shape[1])

    if arr2.ndim != 3:
        arr2 = arr2.reshape(1, arr2.shape[0], arr2.shape[1])

    assert arr1.shape == arr2.shape, 'Arrays must have the same shape'

    if mask is not None:
        arr1 = binary_mask(arr1, mask)
        arr2 = binary_mask(arr2, mask)

    # Convert the data to 1D arrays
    arr1 = ravel_and_filter(arr1, filter=False).flatten()
    arr2 = ravel_and_filter(arr2, filter=False).flatten()

    # Pick some random sampling points; the same indices are used for both
    # arrays so that the samples are paired pixel-for-pixel
    sample_points = np.random.randint(0, len(arr1), n)

    # Sample the data after filtering out the NoData
    # NOTE(review): each sample is filtered independently; if
    # ravel_and_filter() drops NoData entries, the two samples could end up
    # with different lengths or misaligned pairs -- confirm against
    # lsma.ravel_and_filter
    arr1_sample = ravel_and_filter(arr1[sample_points], nodata=nodata)
    arr2_sample = ravel_and_filter(arr2[sample_points], nodata=nodata)

    return (pearsonr(arr1_sample.tolist(), arr2_sample.tolist()),
            'N=%d' % len(arr1_sample))
def train_and_test(xd, yd, cl=GaussianNB, kwargs=None, k=10, verbose=False):
    '''
    Performs a k-fold cross-validation using the provided classifier and
    reports performance in terms of precision and recall.

    Arguments:
        xd      Array of samples, indexable by the fold index arrays
        yd      Array of class labels, one per sample; the printed summary
                assumes exactly two classes (0 = "Not water", 1 = "Water")
        cl      Classifier class; instantiated afresh for every fold
        kwargs  Optional dictionary of keyword arguments for the classifier
                constructor (defaults to no arguments)
        k       Number of folds
        verbose True to also print a classification report for each fold

    Returns nothing; the mean per-class precision and recall across the
    folds are printed.
    '''
    # Avoid a shared mutable default argument
    kwargs = {} if kwargs is None else kwargs

    precision = []
    recall = []
    for train, test in StratifiedKFold(yd, k):
        classifier = cl(**kwargs)
        preds = classifier.fit(xd[train], yd[train]).predict(xd[test])

        # precision_recall_fscore_support() returns
        # (precision, recall, fscore, support), each with one entry per class
        scores = precision_recall_fscore_support(yd[test], preds)
        precision.append(scores[0])
        recall.append(scores[1])

        if verbose:
            # Keyword arguments: newer scikit-learn releases no longer
            # accept labels/target_names positionally
            print(classification_report(
                yd[test], preds, labels=[0, 1],
                target_names=['Not water', 'Water']))

    # Stack the per-fold results into (folds, classes) arrays so that the
    # mean over axis 0 is the per-class mean across folds
    precision = np.vstack(precision)
    recall = np.vstack(recall)
    print('Mean precision: Not water=%.4f, Water=%.4f' % tuple(precision.mean(0).tolist()))
    print('Mean recall: Not water=%.4f, Water=%.4f' % tuple(recall.mean(0).tolist()))
def _roc_scores(classifier, samples):
    '''
    Returns continuous scores for the positive class from a fitted
    classifier, falling back to hard class predictions when the classifier
    offers neither predict_proba() nor decision_function().
    '''
    if hasattr(classifier, 'predict_proba'):
        # Second column: probability of the greater class label (the
        # positive class for binary 0/1 labels)
        return classifier.predict_proba(samples)[:, 1]

    if hasattr(classifier, 'decision_function'):
        return classifier.decision_function(samples)

    return classifier.predict(samples)


def train_and_test_roc(xd, yd, cl=GaussianNB, kwargs=None, k=10, verbose=False):
    '''
    Performs a k-fold cross-validation using the provided classifier and
    reports performance in terms of an ROC curve.

    Arguments:
        xd      Array of samples, indexable by the fold index arrays
        yd      Array of binary class labels, one per sample
        cl      Classifier class; instantiated afresh for every fold
        kwargs  Optional dictionary of keyword arguments for the classifier
                constructor (defaults to no arguments)
        k       Number of folds
        verbose Currently unused; kept for signature parity with
                train_and_test()

    Returns nothing; one ROC curve per fold, the mean ROC curve, and a
    chance diagonal are plotted and shown.
    '''
    # Ex. from: http://scikit-learn.org/stable/auto_examples/model_selection/plot_roc_crossval.html
    # Avoid a shared mutable default argument
    kwargs = {} if kwargs is None else kwargs

    kfold = StratifiedKFold(yd, k)
    mean_fpr = np.linspace(0, 1, 100)
    mean_tpr = np.zeros_like(mean_fpr)

    for i, (train, test) in enumerate(kfold):
        classifier = cl(**kwargs).fit(xd[train], yd[train])

        # Compute the ROC curve and the area under the curve from continuous
        # scores; hard 0/1 predictions would collapse the curve to a single
        # operating point
        scores = _roc_scores(classifier, xd[test])
        fpr, tpr, _ = roc_curve(yd[test], scores)

        # Accumulate this fold's curve on the common FPR grid
        mean_tpr += np.interp(mean_fpr, fpr, tpr)
        mean_tpr[0] = 0.0
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, lw=1, label='ROC fold %d (area = %0.2f)' % (i, roc_auc))

    # Draw the chance diagonal once, so it appears once in the legend
    plt.plot([0, 1], [0, 1], '--', color=(0.6, 0.6, 0.6), label='Chance')

    # Average the accumulated curves and pin the end point at (1, 1)
    mean_tpr /= len(kfold)
    mean_tpr[-1] = 1.0
    mean_auc = auc(mean_fpr, mean_tpr)
    plt.plot(mean_fpr, mean_tpr, 'k--',
             label='Mean ROC (area = %0.2f)' % mean_auc, lw=2)

    plt.xlim([-0.05, 1.05])
    plt.ylim([-0.05, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.legend(loc='lower right')
    plt.show()
if __name__ == '__main__':
    # Read the validation pixels for each class from their extracts
    land_ds = gdal.Open('./data/20150610/land_areas_Oakland_extract.tiff')
    water_ds = gdal.Open('./data/20150610/water_areas_Oakland_extract.tiff')
    land_px = ravel_and_filter(land_ds.ReadAsArray())
    water_px = ravel_and_filter(water_ds.ReadAsArray())

    # Drop the dataset references now that the pixels are in memory
    land_ds = None
    water_ds = None

    # Optional steps, currently disabled:
    #   Randomly shuffle the validation data
    #       np.random.shuffle(land_px)
    #       np.random.shuffle(water_px)
    #   Downsample the water data so it matches the size of the land data
    #       water_px = water_px[0:land_px.shape[0],:]

    # Combine both classes into one sample array, labelling land pixels 0
    # and water pixels 1
    n_land = land_px.shape[0]
    n_water = water_px.shape[0]
    samples = np.vstack((land_px, water_px))
    labels = np.repeat([0.0, 1.0], [n_land, n_water])

    # Cross-validate to estimate how well this classifier can perform
    train_and_test(samples, labels, cl=GaussianNB, k=10)

    #############
    # Prediction

    # Classify every pixel of a new image with a model fit on all samples
    image_ds = gdal.Open('./data/LE7020030+031_merge_Oakland.tiff')
    geotransform = image_ds.GetGeoTransform()
    image = image_ds.ReadAsArray()
    image_ds = None

    model = GaussianNB().fit(samples, labels)
    predicted = model.predict(ravel_and_filter(image, filter=False))

    # Restore the image's two trailing dimensions and write the result
    out_rast = array_to_raster(predicted.reshape(image.shape[1:3]),
                               gt=geotransform, wkt=EPSG[32617])
    dump_raster(out_rast, 'data/20150610/water_prediction.tiff', nodata=-9999)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment