Skip to content

Instantly share code, notes, and snippets.

@akatasonov
Created October 1, 2019 10:42

Revisions

  1. akatasonov created this gist Oct 1, 2019.
    251 changes: 251 additions & 0 deletions prepare_data.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,251 @@
    import os
    import glob
    import pyproj
    import shapely
    import shapely.geometry
    import shapely.ops
    import fiona
    import rasterio
    import rasterio.mask
    import rasterio.merge
    import numpy
    import pickle

    def project_wsg_shape_to_csr(shape, from_crs, to_crs):
    project = lambda x, y: pyproj.transform(
    from_crs,
    to_crs,
    x,
    y
    )
    return shapely.ops.transform(project, shape)

    train_shapefile = fiona.open("train/train.shp", "r")
    train_shape_crs = pyproj.Proj(train_shapefile.crs)

    test_shapefile = fiona.open("test/test.shp", "r")
    test_shape_crs = pyproj.Proj(test_shapefile.crs)
    #print(shapefile.crs)

    # Start by enumerating SAFE products
    # TODO: check cloud contamination using s2cloudless
    product_groups = {}
    train_field_data = {}
    train_field_data_r = {}
    train_field_data_g = {}
    train_field_data_b = {}
    test_field_data = {}
    test_field_data_r = {}
    test_field_data_g = {}
    test_field_data_b = {}
    for product_fn in glob.glob('*.SAFE'):
    #print(product_fn)
    """
    The compact naming convention is arranged as follows:
    MMM_MSIL1C_YYYYMMDDHHMMSS_Nxxyy_ROOO_Txxxxx_<Product Discriminator>.SAFE
    The products contain two dates.
    The first date (YYYYMMDDHHMMSS) is the datatake sensing time.
    The second date is the "<Product Discriminator>" field, which is 15 characters in length, and is used to distinguish between different end user products from the same datatake. Depending on the instance, the time in this field can be earlier or slightly later than the datatake sensing time.
    The other components of the filename are:
    MMM: is the mission ID(S2A/S2B)
    MSIL1C: denotes the Level-1C product level
    YYYYMMDDHHMMSS: the datatake sensing start time
    Nxxyy: the Processing Baseline number (e.g. N0204)
    ROOO: Relative Orbit number (R001 - R143)
    Txxxxx: Tile Number field
    SAFE: Product Format (Standard Archive Format for Europe)
    """
    # Split the product name into parts
    product_attrs = product_fn.split('_')
    datatake_time = product_attrs[2]
    tile_number = product_attrs[5]
    # Since the shape files provided cover two tiles, group tiles by datatake_time
    if datatake_time in product_groups:
    product_groups[datatake_time].append(product_fn)
    else:
    product_groups[datatake_time] = [product_fn]

    # sort the dict in the chronological order
    product_groups = dict(sorted(product_groups.items()))

    # Enumerate groups of tiles
    for product_group in product_groups:
    print('*** Processing {}..'.format(product_group))
    b2 = [] # all B4 bands for a group, blue
    b3 = [] # all B4 bands for a group, green
    b4 = [] # all B4 bands for a group, red
    b8 = [] # all B8 bands for a group
    for product_fn in product_groups[product_group]:
    print(' {}'.format(product_fn))
    b2fn = ''
    b3fn = ''
    b4fn = ''
    b8fn = ''
    for bandfn in glob.glob('{}/GRANULE/*/IMG_DATA/*.jp2'.format(product_fn)):
    # Split the band file name
    base = os.path.basename(bandfn)
    band_attrs = os.path.splitext(base)[0].split('_')
    band_type = band_attrs[2] # B01, B02, etc
    if band_type == 'B02':
    b2fn = bandfn
    if band_type == 'B03':
    b3fn = bandfn
    if band_type == 'B04':
    b4fn = bandfn
    if band_type == 'B08':
    b8fn = bandfn

    assert b4fn and b8fn # should have both values
    b2.append(rasterio.open(b2fn))
    b3.append(rasterio.open(b3fn))
    b4.append(rasterio.open(b4fn))
    b8.append(rasterio.open(b8fn))

    print(' Merging bands..')
    # For a group of tiles/products, merge bands from different tiles together
    blue, _ = rasterio.merge.merge(b2)
    green, _ = rasterio.merge.merge(b3)
    red, out_trans = rasterio.merge.merge(b4)
    nir, _ = rasterio.merge.merge(b8)

    # Calculate the NDVI, given B4 and B8 band filenames
    print(' Calculating the NDVI..')
    ndvi = (nir.astype(float) - red.astype(float)) / (nir + red)
    # Save the NDVI image for manual analysis later
    print(' Saving the NDVI raster to ndvi/{}.tif..'.format(product_group))
    meta = b4[0].meta.copy()
    meta.update(dtype=rasterio.float64,
    compress='lzw',
    driver='GTiff',
    transform=out_trans,
    height=red.shape[1],
    width=red.shape[2]
    )
    with rasterio.open('ndvi/{}.tif'.format(product_group), 'w', **meta) as dst:
    dst.write(ndvi)
    dst.close()

    # convert 0..255 range in r,g,b to 0..1
    red = red.astype(float) / 65535
    green = green.astype(float) / 65535
    blue = blue.astype(float) / 65535

    # Save red, green and blue images as well
    print(' Saving the RGB raster to rgb/{}-r/g/b.tif..'.format(product_group))
    with rasterio.open('rgb/{}-r.tif'.format(product_group), 'w', **meta) as dst:
    dst.write(red)
    dst.close()
    with rasterio.open('rgb/{}-g.tif'.format(product_group), 'w', **meta) as dst:
    dst.write(green)
    dst.close()
    with rasterio.open('rgb/{}-b.tif'.format(product_group), 'w', **meta) as dst:
    dst.write(blue)
    dst.close()

    ndvi_img = rasterio.open('ndvi/{}.tif'.format(product_group))
    #print(' NDVI CRS is', ndvi_img.crs.data)
    ndvi_crs = pyproj.Proj(ndvi_img.crs)

    red_img = rasterio.open('rgb/{}-r.tif'.format(product_group))
    red_crs = pyproj.Proj(red_img.crs)
    green_img = rasterio.open('rgb/{}-g.tif'.format(product_group))
    green_crs = pyproj.Proj(green_img.crs)
    blue_img = rasterio.open('rgb/{}-b.tif'.format(product_group))
    blue_crs = pyproj.Proj(blue_img.crs)

    # Alright, NDVI is ready for the whole region in question
    # Use the shape file to mask out everything, except fields
    for field in train_shapefile:
    #print(field['properties']['Field_Id'], field['properties']['Crop_Id_Ne'])
    field_id = field['properties']['Field_Id']
    #print(' Cropping NDVI data for train field #{}'.format(field_id))
    try:
    projected_shape = project_wsg_shape_to_csr(shapely.geometry.shape(field['geometry']),
    train_shape_crs,
    ndvi_crs)
    except Exception as e:
    print(' ', e, ' exception for field #', field_id)
    continue

    #print(projected_shape)
    field_img, field_img_transform = rasterio.mask.mask(ndvi_img, [projected_shape], crop=True)
    field_img_red, _ = rasterio.mask.mask(red_img, [projected_shape], crop=True)
    field_img_green, _ = rasterio.mask.mask(green_img, [projected_shape], crop=True)
    field_img_blue, _ = rasterio.mask.mask(blue_img, [projected_shape], crop=True)
    # remove the first dimension
    field_img = numpy.squeeze(field_img, axis=0)
    field_img_red = numpy.squeeze(field_img_red, axis=0)
    field_img_green = numpy.squeeze(field_img_green, axis=0)
    field_img_blue = numpy.squeeze(field_img_blue, axis=0)
    # add the 3rd dimension
    field_img = numpy.expand_dims(field_img, 2)
    field_img_red = numpy.expand_dims(field_img_red, 2)
    field_img_green = numpy.expand_dims(field_img_green, 2)
    field_img_blue = numpy.expand_dims(field_img_blue, 2)

    if field_id in train_field_data:
    train_field_data[field_id] = numpy.concatenate((train_field_data[field_id], field_img), axis=2)
    train_field_data_r[field_id] = numpy.concatenate((train_field_data_r[field_id], field_img_red), axis=2)
    train_field_data_g[field_id] = numpy.concatenate((train_field_data_g[field_id], field_img_green), axis=2)
    train_field_data_b[field_id] = numpy.concatenate((train_field_data_b[field_id], field_img_blue), axis=2)
    else:
    train_field_data[field_id] = field_img
    train_field_data_r[field_id] = field_img_red
    train_field_data_g[field_id] = field_img_green
    train_field_data_b[field_id] = field_img_blue

    for field in test_shapefile:
    #print(field['properties']['Field_Id'], field['properties']['Crop_Id_Ne'])
    field_id = field['properties']['Field_Id']
    #print(' Cropping NDVI data for test field #{}'.format(field_id))
    try:
    projected_shape = project_wsg_shape_to_csr(shapely.geometry.shape(field['geometry']),
    test_shape_crs,
    ndvi_crs)
    except Exception as e:
    print(' ', e, ' exception for field #', field_id)
    continue

    #print(projected_shape)
    field_img, field_img_transform = rasterio.mask.mask(ndvi_img, [projected_shape], crop=True)
    field_img_red, _ = rasterio.mask.mask(red_img, [projected_shape], crop=True)
    field_img_green, _ = rasterio.mask.mask(green_img, [projected_shape], crop=True)
    field_img_blue, _ = rasterio.mask.mask(blue_img, [projected_shape], crop=True)
    # remove the first dimension
    field_img = numpy.squeeze(field_img, axis=0)
    field_img_red = numpy.squeeze(field_img_red, axis=0)
    field_img_green = numpy.squeeze(field_img_green, axis=0)
    field_img_blue = numpy.squeeze(field_img_blue, axis=0)
    # add the 3rd dimension
    field_img = numpy.expand_dims(field_img, 2)
    field_img_red = numpy.expand_dims(field_img_red, 2)
    field_img_green = numpy.expand_dims(field_img_green, 2)
    field_img_blue = numpy.expand_dims(field_img_blue, 2)


    if field_id in test_field_data:
    test_field_data[field_id] = numpy.concatenate((test_field_data[field_id], field_img), axis=2)
    test_field_data_r[field_id] = numpy.concatenate((test_field_data_r[field_id], field_img_red), axis=2)
    test_field_data_g[field_id] = numpy.concatenate((test_field_data_g[field_id], field_img_green), axis=2)
    test_field_data_b[field_id] = numpy.concatenate((test_field_data_b[field_id], field_img_blue), axis=2)
    else:
    test_field_data[field_id] = field_img
    test_field_data_r[field_id] = field_img_red
    test_field_data_g[field_id] = field_img_green
    test_field_data_b[field_id] = field_img_blue


    # save the fields data to file
    pickle.dump(train_field_data, open('train/train.pkl', 'wb'), protocol=pickle.HIGHEST_PROTOCOL)
    pickle.dump(train_field_data_r, open('train/train-r.pkl', 'wb'), protocol=pickle.HIGHEST_PROTOCOL)
    pickle.dump(train_field_data_g, open('train/train-g.pkl', 'wb'), protocol=pickle.HIGHEST_PROTOCOL)
    pickle.dump(train_field_data_b, open('train/train-b.pkl', 'wb'), protocol=pickle.HIGHEST_PROTOCOL)
    pickle.dump(test_field_data, open('test/test.pkl', 'wb'), protocol=pickle.HIGHEST_PROTOCOL)
    pickle.dump(test_field_data_r, open('test/test-r.pkl', 'wb'), protocol=pickle.HIGHEST_PROTOCOL)
    pickle.dump(test_field_data_g, open('test/test-g.pkl', 'wb'), protocol=pickle.HIGHEST_PROTOCOL)
    pickle.dump(test_field_data_b, open('test/test-b.pkl', 'wb'), protocol=pickle.HIGHEST_PROTOCOL)