Source code for ml_utils

import logging
from typing import List

import numpy as np
import tensorflow as tf
import xarray as xr

logger = logging.getLogger(__name__)


def chunk_result(img, input_shape, pad_mode="reflect"):
    """
    Create splits of a given size from a given image.

    Images are split up row-wise, i.e. row 1 is split up, then row 2, and so on.

    NOTE: padding is added on the right and the bottom of the image in case it
    cannot be split into equal parts; the padding mode is "reflect" by default.

    The function assumes that the input will always have 4 dimensions which
    pertain to [channels, time-step, height, width]. Expand any missing
    dimensions as 1 before passing the data.

    Args:
        img : object whose .dataset attribute holds the image to be split up;
              converted internally to an array of shape [C x T x H x W]
        input_shape : size of the split [size_h, size_w]
        pad_mode : padding mode passed on to xarray's pad, "reflect" by default

    Returns:
        splits : a list of xarray Datasets containing the split-up images
    """
    img = img.dataset.to_array()
    h_idx, w_idx = 2, 3
    border_len_y = (0 - img.shape[h_idx]) % input_shape[0]
    border_len_x = (0 - img.shape[w_idx]) % input_shape[1]
    img_attrs = img.attrs
    # Padding removes metadata for some reason; also note that the xarray docs
    # state that this is an experimental function:
    # https://docs.xarray.dev/en/stable/generated/xarray.DataArray.pad.html
    img = img.pad(pad_width={"y": (0, border_len_y), "x": (0, border_len_x)}, mode=pad_mode)
    img.attrs = img_attrs
    splits = []
    for i in range(0, img.shape[h_idx], input_shape[0]):
        for j in range(0, img.shape[w_idx], input_shape[1]):
            splits.append(img[:, :, i : i + input_shape[0], j : j + input_shape[1]])
    return split_da_to_split_ds(splits)
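
# Illustrative usage sketch (not part of the original module): builds a tiny
# synthetic datacube and splits it into 16 x 16 tiles. The SimpleNamespace
# wrapper and the band names are assumptions; chunk_result only needs an object
# exposing an xarray.Dataset through a .dataset attribute.
def _example_chunk_synthetic():
    from types import SimpleNamespace

    data = np.random.rand(3, 40, 40)  # (time, y, x)
    ds = xr.Dataset(
        {
            "S2_RED": (("time", "y", "x"), data),
            "S2_NIR": (("time", "y", "x"), data),
        },
        coords={"time": np.arange(3), "y": np.arange(40), "x": np.arange(40)},
    )
    cube = SimpleNamespace(dataset=ds)
    # 40 px is reflect-padded to 48 px, so the cube yields a 3 x 3 grid of tiles.
    chunks = chunk_result(cube, input_shape=[16, 16])
    return chunks  # list of 9 xarray Datasets, each with (time=3, y=16, x=16) variables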

def split_da_to_split_ds(splits):
    """Convert a list of [variable, time, y, x] DataArray splits into a list of
    Datasets with one data variable per band, preserving the original attributes."""
    splits_ds = []
    for split in splits:
        data_array_dict = {}
        for var_idx, var_name in enumerate(split.coords["variable"].values):
            var_arr = split[var_idx, :, :, :]
            var_xr = xr.DataArray(
                var_arr,
                dims=["time", "y", "x"],
                coords={"time": split.time, "y": split.y, "x": split.x},
            )
            data_array_dict[var_name] = var_xr
        split_dataset = list(data_array_dict.values())[0].to_dataset(name=list(data_array_dict.keys())[0])
        for var_name in list(data_array_dict.keys())[1:]:
            split_dataset[var_name] = data_array_dict[var_name]
        split_dataset.attrs = split.attrs
        splits_ds.append(split_dataset)
    return splits_ds


def tf_bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def tf_float_feature(value):
    """Returns a float_list from a float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))


def tf_int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
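
# Minimal sketch (not part of the original module) of how the helpers above wrap
# values into tf.train.Feature protos, which is exactly what create_tf_record
# below relies on. The array contents and feature names are assumptions.
def _example_feature_roundtrip():
    arr = np.arange(6, dtype="float64").reshape(2, 3)
    # Arrays are serialized to a byte string first, then wrapped as a bytes feature.
    arr_feature = tf_bytes_feature(tf.io.serialize_tensor(arr))
    # Scalars go through the float / int64 helpers directly.
    scalar_feature = tf_float_feature(3.14)
    example = tf.train.Example(
        features=tf.train.Features(feature={"arr": arr_feature, "scalar": scalar_feature})
    )
    # Decoding reverses the process: parse the tensor back from its byte string.
    restored = tf.io.parse_tensor(
        example.features.feature["arr"].bytes_list.value[0], out_type=tf.float64
    )
    return restored  # tf.Tensor of shape (2, 3), equal to arr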

def create_tf_record(chunks: List[xr.Dataset], save_path: str, save_coords: bool = False, save_attrs: bool = False):
    """
    Save "chunked" results into the "tfrecord" format.

    The function assumes that the input will always be of 4 dimensions which
    pertain to [channels, time-step, height, width]. Expand any missing
    dimensions as 1 before passing the data.

    Args:
        chunks (list of xarray datasets) : List of xarray datasets with equal dimensions
        save_path (str) : Path of the tfrecord file to write
        save_coords (bool) : Whether to save coordinates from the datasets
        save_attrs (bool) : Whether to save metadata and attributes from the datasets

    Returns:
        features_dict (dict) : Dictionary mapping feature name to the tf.io.FixedLenFeature as it is stored
        out_types_dict (dict) : Dictionary mapping feature name to the data type to which it needs to be decoded
        shapes_dict (dict) : Dictionary mapping feature name to the shape it needs to be decoded to, empty tuple for scalar data
    """
    features_dict = {}
    out_types_dict = {}
    shapes_dict = {}

    writer = tf.io.TFRecordWriter(save_path)
    # img_id_counter = 0
    for chunk in chunks:
        chunk_tf_dict = {}

        # serialize the data tensors
        for data_var_name in list(chunk.data_vars.keys()):
            data_arr = chunk[data_var_name].values.astype("float64")
            chunk_tf_dict[data_var_name] = tf_bytes_feature(tf.io.serialize_tensor(data_arr))
            features_dict[data_var_name] = tf.io.FixedLenFeature([], tf.string)
            out_types_dict[data_var_name] = tf.float64
            shapes_dict[data_var_name] = data_arr.shape

        if save_coords:
            # serialize the coordinate tensors
            for coord_name in list(chunk.coords.keys()):
                coord_arr = chunk[coord_name].values
                if np.issubdtype(coord_arr.dtype, np.datetime64):
                    coord_arr = np.array(list(map(str, coord_arr)))
                    chunk_tf_dict["coord_" + coord_name] = tf_bytes_feature(tf.io.serialize_tensor(coord_arr))
                    features_dict["coord_" + coord_name] = tf.io.FixedLenFeature([], tf.string)
                    out_types_dict["coord_" + coord_name] = tf.string
                    shapes_dict["coord_" + coord_name] = coord_arr.shape
                else:
                    coord_arr = coord_arr.astype("float64")
                    chunk_tf_dict["coord_" + coord_name] = tf_bytes_feature(tf.io.serialize_tensor(coord_arr))
                    features_dict["coord_" + coord_name] = tf.io.FixedLenFeature([], tf.string)
                    out_types_dict["coord_" + coord_name] = tf.float64
                    shapes_dict["coord_" + coord_name] = coord_arr.shape

        if save_attrs:
            # serialize the attribute tensors
            for attr_name in list(chunk.attrs.keys()):
                attr = chunk.attrs[attr_name]
                if isinstance(attr, str):
                    chunk_tf_dict["attr_" + attr_name] = tf_bytes_feature(bytes(attr, "utf-8"))
                    features_dict["attr_" + attr_name] = tf.io.FixedLenFeature([], tf.string)
                    out_types_dict["attr_" + attr_name] = tf.string
                    shapes_dict["attr_" + attr_name] = ()
                elif isinstance(attr, np.ndarray):
                    attr_arr = attr.astype("float64")
                    chunk_tf_dict["attr_" + attr_name] = tf_bytes_feature(tf.io.serialize_tensor(attr_arr))
                    features_dict["attr_" + attr_name] = tf.io.FixedLenFeature([], tf.string)
                    out_types_dict["attr_" + attr_name] = tf.float64
                    shapes_dict["attr_" + attr_name] = attr_arr.shape
                else:
                    logger.info(
                        f"{type(attr)} is not supported for saving currently, skipping attribute {attr_name}"
                    )

        message_feature = tf.train.Example(features=tf.train.Features(feature=chunk_tf_dict))
        writer.write(message_feature.SerializeToString())

    writer.close()
    return features_dict, out_types_dict, shapes_dict
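
# Hypothetical end-to-end sketch (the cube object, tile size, and path are
# assumptions): chunk a datacube with chunk_result and write the tiles to one
# tfrecord file. The three returned dictionaries must be kept around, because
# load_img below needs them to decode the records again.
def _example_write_tfrecord(cube, save_path="chunks.tfrecord"):
    chunks = chunk_result(cube, input_shape=[256, 256])
    features_dict, out_types_dict, shapes_dict = create_tf_record(
        chunks, save_path, save_coords=True, save_attrs=False
    )
    return features_dict, out_types_dict, shapes_dict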

def load_img(example_proto, features_dict, out_types_dict, shapes_dict):
    """
    Map data from a saved tfrecord back to tensors using the accompanying saved dictionaries.

    This function is meant to be used in conjunction with the tf.data API when
    loading the dataset from a tfrecord.

    Args:
        example_proto (str) : Single example (data sample) from the tfrecord
        features_dict (dict) : Dictionary mapping feature name to the tf.io.FixedLenFeature as it is stored
        out_types_dict (dict) : Dictionary mapping feature name to the data type to which it needs to be decoded
        shapes_dict (dict) : Dictionary mapping feature name to the shape it needs to be decoded to, empty tuple for scalar data

    Example
    ----------
    >>> dataset = tf.data.TFRecordDataset(save_file_path)
    >>> dataset = dataset.map(
    >>>     lambda example_proto: ml_utils.load_img(
    >>>         example_proto, features_dict=features_dict, out_types_dict=out_types_dict, shapes_dict=shapes_dict
    >>>     )
    >>> )
    """
    single_example = tf.io.parse_single_example(example_proto, features_dict)
    shapeless_feats = [feat_name for feat_name, feat_shape in shapes_dict.items() if len(feat_shape) == 0]
    example_data = {feat_name: single_example[feat_name] for feat_name in shapeless_feats}
    for feature in features_dict.keys():
        if feature in shapeless_feats:
            continue
        feature_data = tf.io.parse_tensor(single_example[feature], out_type=out_types_dict[feature])
        feature_data = tf.reshape(feature_data, shapes_dict[feature])
        example_data[feature] = feature_data
    return example_data
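
# Small verification sketch (an assumption, not part of the original module):
# eagerly decode the first record and check that every feature comes back with
# the shape recorded in shapes_dict. The arguments are those returned by
# create_tf_record above.
def _example_verify_record(save_path, features_dict, out_types_dict, shapes_dict):
    dataset = tf.data.TFRecordDataset(save_path)
    dataset = dataset.map(
        lambda example_proto: load_img(example_proto, features_dict, out_types_dict, shapes_dict)
    )
    first = next(iter(dataset))
    for name, shape in shapes_dict.items():
        assert tuple(first[name].shape) == tuple(shape)
    return first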

def combine_bands(example_data, input_bands, output_bands):
    """
    Stack the input and output bands from the tfrecord dataset.

    Args:
        example_data (dict): Single decoded example (data sample) from the tfrecord
        input_bands (list of strings): List of bands used as model inputs
        output_bands (list of strings): List of bands used as model outputs

    Returns:
        A tuple (input_data, output_data) with the bands stacked on the last axis,
        or only input_data if output_bands is empty.

    Example
    ----------
    >>> input_bands = ["S2_RED", "S2_GREEN", "S2_BLUE"]
    >>> output_bands = ["S2_SCL"]
    >>> dataset = dataset.map(lambda example_data: ml_utils.combine_bands(example_data, input_bands, output_bands))
    """
    input_data = tf.stack([example_data[ip_band] for ip_band in input_bands], axis=-1)
    if not output_bands:
        return input_data
    output_data = tf.stack([example_data[op_band] for op_band in output_bands], axis=-1)
    return input_data, output_data
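
# Hypothetical training-pipeline sketch: the band names, batch size, and
# save_path are assumptions. It chains load_img and combine_bands the way the
# docstring examples above suggest, yielding (inputs, targets) batches ready
# for a Keras model.fit call.
def _example_training_pipeline(save_path, features_dict, out_types_dict, shapes_dict):
    input_bands = ["S2_RED", "S2_GREEN", "S2_BLUE"]
    output_bands = ["S2_SCL"]
    dataset = tf.data.TFRecordDataset(save_path)
    dataset = dataset.map(
        lambda example_proto: load_img(example_proto, features_dict, out_types_dict, shapes_dict)
    )
    dataset = dataset.map(lambda example_data: combine_bands(example_data, input_bands, output_bands))
    # Each element is now (inputs, targets) with the bands stacked on the last axis.
    return dataset.batch(4).prefetch(tf.data.AUTOTUNE)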