"""
The schemas that Spark produces for DataFrames are typically nested, and these
nested schemas are quite difficult to work with interactively. In many cases,
it's possible to flatten a schema into a single level of column names.
"""
import typing as T

import cytoolz.curried as tz
import pyspark


def schema_to_columns(schema: pyspark.sql.types.StructType) -> T.List[T.List[str]]:
    """
    Produce a flat list of column specs from a possibly nested DataFrame schema.

    Each spec is the path of field names leading from a top-level column down
    to a field whose type is not a ``StructType``; e.g. ``["a", "b"]`` for
    field ``b`` inside struct column ``a``. Only structs are descended into;
    every other data type (including arrays and maps) is treated as a leaf.
    Specs are returned in schema (depth-first) order.
    """
    columns = list()

    def helper(schm: pyspark.sql.types.StructType,
               prefix: T.Optional[T.List[str]] = None):
        # Build a fresh list per call rather than using a mutable default.
        if prefix is None:
            prefix = list()
        for item in schm.fields:
            if isinstance(item.dataType, pyspark.sql.types.StructType):
                helper(item.dataType, prefix + [item.name])
            else:
                columns.append(prefix + [item.name])

    helper(schema)
    return columns


def _column_for_spec(frame: pyspark.sql.DataFrame,
                     col_spec: T.List[str]) -> pyspark.sql.Column:
    """Resolve a column spec (path of field names) to a Column of ``frame``."""
    head, *rest = col_spec
    # Backtick-quote the top-level name so Spark does not interpret dots in it
    # as nested-field access; a literal backtick is escaped by doubling it.
    column = frame["`" + head.replace("`", "``") + "`"]
    for name in rest:
        # getField uses the field name literally (no dot parsing), and an
        # unresolvable path raises instead of silently yielding None.
        column = column.getField(name)
    return column


def flatten_frame(frame: pyspark.sql.DataFrame) -> pyspark.sql.DataFrame:
    """
    Select every leaf field of ``frame`` as a single-level column.

    Top-level non-struct columns keep their original names. Nested fields
    are aliased by joining their path with ``':'``; e.g. field ``b`` inside
    struct ``a`` becomes column ``a:b``.
    """
    aliased_columns = list()
    for col_spec in schema_to_columns(frame.schema):
        c = _column_for_spec(frame, col_spec)
        if len(col_spec) == 1:
            aliased_columns.append(c)
        else:
            aliased_columns.append(c.alias(':'.join(col_spec)))
    return frame.select(aliased_columns)