"""
The schemas that Spark produces for DataFrames are typically
nested, and these nested schemas are quite difficult to work with
interactively. In many cases, it's possible to flatten a schema
into a single level of column names.
"""

import typing as T

import cytoolz.curried as tz
import pyspark


def schema_to_columns(schema: pyspark.sql.types.StructType) -> T.List[T.List[str]]:
    """
    Flatten a (possibly nested) DataFrame schema into a list of column paths.

    Each returned path is a list of field names leading from the top level of
    ``schema`` down to a non-struct field. For example, a field ``b`` nested
    inside a struct column ``a`` yields ``['a', 'b']``, and a plain top-level
    column ``c`` yields ``['c']``.

    Only ``StructType`` fields are descended into. Every other type (arrays
    and maps included) is treated as a leaf. Paths come back in depth-first
    schema order. A struct with no fields contributes no paths at all.
    """

    def walk(struct, path_so_far):
        # Depth-first traversal that yields a fresh list for every leaf field.
        for field in struct.fields:
            path = path_so_far + [field.name]
            if isinstance(field.dataType, pyspark.sql.types.StructType):
                yield from walk(field.dataType, path)
            else:
                yield path

    return list(walk(schema, []))

def flatten_frame(frame: pyspark.sql.DataFrame, separator: str = ':') -> pyspark.sql.DataFrame:
    """
    Select every leaf field of ``frame`` as its own top-level column.

    Leaf paths come from ``schema_to_columns(frame.schema)``. Top-level
    (non-nested) columns keep their original name. Nested fields are aliased
    by joining their path with ``separator``, so with the default ``':'`` the
    field ``b`` inside struct ``a`` becomes a column named ``a:b``.

    :param frame: DataFrame whose struct columns should be flattened.
    :param separator: string used to join nested field names into the new
        column name (defaults to ``':'``).
    :return: a DataFrame with one column per leaf field, in schema order.
    """

    def resolve(col_spec):
        # get_in applies successive item lookups: frame[name] -> Column, then
        # column[field] for each nested level. no_default=True re-raises a
        # failed lookup instead of silently returning None, which would
        # otherwise only fail later with an opaque error on `.alias`.
        column = tz.get_in(col_spec, frame, no_default=True)
        if len(col_spec) == 1:
            # Already a top-level column; its name is flat, so no alias needed.
            return column
        # NOTE(review): the joined alias is not checked for collisions with an
        # existing top-level column of the same name.
        return column.alias(separator.join(col_spec))

    return frame.select([resolve(spec) for spec in schema_to_columns(frame.schema)])