Last active: September 2, 2022 10:22
Save shreyasms17/e6b8984c4c20cfa54f5fb55810ba068e to your computer and use it in GitHub Desktop.
AutoFlatten Complex JSON
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pyspark.sql.functions import col, explode_outer | |
from pyspark.sql.types import * | |
from copy import deepcopy | |
from autoflatten import AutoFlatten | |
from collections import Counter | |
# Flatten a complex nested JSON column (read from ORC files on S3) into a fully
# flat Spark DataFrame, using the AutoFlatten helper to pre-compute the
# traversal order and the target name for every leaf field.
s3_path = 's3://mybucket/orders/'
df = spark.read.orc(s3_path)  # NOTE(review): relies on a live `spark` session in scope
# Re-parse the raw JSON strings held in the `json` column so Spark infers a schema.
json_df = spark.read.json(df.rdd.map(lambda row: row.json))
json_schema = json_df.schema

# AutoFlatten walks the inferred schema; based on how this script uses it:
#   af.rest          -> dotted paths of fields to surface directly
#   af.order         -> processing order for nested (struct/array) fields
#   af.bottom_to_top -> per-key child paths; non-empty value means array type
#   af.all_fields    -> mapping of full path -> desired target column name
# (presumably — confirm against the autoflatten module itself)
af = AutoFlatten(json_schema)
af.compute()

df1 = json_df

# Paths already materialised as columns; paths carry a leading '.' to match
# AutoFlatten's path convention.
visited = set([f'.{column}' for column in df1.columns])
# How many distinct source paths map to the same target name — duplicates are
# later disambiguated by aliasing with the full '>'-separated path.
duplicate_target_counter = Counter(af.all_fields.values())

cols_to_select = df1.columns
for rest_col in af.rest:
    if rest_col not in visited:
        # Use the bare path when its target name is unique and not already a
        # column; otherwise alias with the full path, '.' replaced by '>'.
        cols_to_select += [rest_col[1:]] if (duplicate_target_counter[af.all_fields[rest_col]]==1 and af.all_fields[rest_col] not in df1.columns) else [col(rest_col[1:]).alias(f"{rest_col[1:].replace('.', '>')}")]
        visited.add(rest_col)

df1 = df1.select(cols_to_select)

if af.order:
    for key in af.order:
        # Leaf name of the dotted path, e.g. '.a.b.c' -> 'c'.
        column = key.split('.')[-1]
        if af.bottom_to_top[key]:
            #########
            #values for the column in bottom_to_top dict exists if it is an array type
            #########
            # explode_outer keeps rows whose array is null/empty (as null rows).
            df1 = df1.select('*', explode_outer(col(column)).alias(f"{column}_exploded")).drop(column)
            data_type = df1.select(f"{column}_exploded").schema.fields[0].dataType
            if not (isinstance(data_type, StructType) or isinstance(data_type, ArrayType)):
                # Array of scalars: the exploded column is already flat — just
                # rename it, disambiguating with the full path if the target
                # name is duplicated.
                df1 = df1.withColumnRenamed(f"{column}_exploded", column if duplicate_target_counter[af.all_fields[key]]<=1 else key[1:].replace('.', '>'))
                visited.add(key)
            else:
                #grabbing all paths to columns after explode
                cols_in_array_col = set(map(lambda x: f'{key}.{x}', df1.select(f'{column}_exploded.*').columns))
                #retrieving unvisited columns
                cols_to_select_set = cols_in_array_col.difference(visited)
                # Columns AutoFlatten planned for this key.
                all_cols_to_select_set = set(af.bottom_to_top[key])
                #check done for duplicate column name & path
                cols_to_select_list = list(map(lambda x: f"{column}_exploded{'.'.join(x.split(key)[1:])}" if (duplicate_target_counter[af.all_fields[x]]<=1 and x.split('.')[-1] not in df1.columns) else col(f"{column}_exploded{'.'.join(x.split(key)[1:])}").alias(f"{x[1:].replace('.', '>')}"), list(all_cols_to_select_set)))
                #updating visited set
                visited.update(cols_to_select_set)
                # Columns present in the data but absent from AutoFlatten's
                # plan are still carried along under their exploded path.
                rem = list(map(lambda x: f"{column}_exploded{'.'.join(x.split(key)[1:])}", list(cols_to_select_set.difference(all_cols_to_select_set))))
                df1 = df1.select(df1.columns + cols_to_select_list + rem).drop(f"{column}_exploded")
        else:
            #########
            #values for the column in bottom_to_top dict do not exist if it is a struct type / array type containing a string type
            #########
            #grabbing all paths to columns after opening
            cols_in_array_col = set(map(lambda x: f'{key}.{x}', df1.selectExpr(f'{column}.*').columns))
            #retrieving unvisited columns
            cols_to_select_set = cols_in_array_col.difference(visited)
            #check done for duplicate column name & path
            cols_to_select_list = list(map(lambda x: f"{column}.{x.split('.')[-1]}" if (duplicate_target_counter[x.split('.')[-1]]<=1 and x.split('.')[-1] not in df1.columns) else col(f"{column}.{x.split('.')[-1]}").alias(f"{x[1:].replace('.', '>')}"), list(cols_to_select_set)))
            #updating visited set
            visited.update(cols_to_select_set)
            df1 = df1.select(df1.columns + cols_to_select_list).drop(f"{column}")

# Final projection: unique targets keep their plain name; duplicated targets
# keep the full '>'-separated path so no two output columns collide.
final_df = df1.select([field[1:].replace('.', '>') if duplicate_target_counter[af.all_fields[field]]>1 else af.all_fields[field] for field in af.all_fields])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment