# -*- coding: utf-8 -*-

import sys
import json
from typing import Optional, Tuple
import pprint

from log_filters import filter_logs_by_source_type

class MergePolicy:
    """Infer a per-field merge policy for a batch of JSON log records.

    Every field is classified by how many distinct values it takes across
    the batch (see ``_cardinality``):

    * exactly one distinct value                       -> "constant"
    * fewer than ``group_by_threshold * len(logs)``    -> "low"
    * otherwise                                        -> "high"

    Each class is then mapped to a list of action dicts. The resulting policy
    maps field names to those lists. Children of dict-valued fields are keyed
    as ``"<parent>.<child>"`` (see ``_encode_fields``).
    """

    # Default ratio of distinct values to records below which a field is
    # considered low-cardinality (and therefore emitted as "group_by").
    DEFAULT_GROUP_BY_THRESHOLD = 0.8

    def __init__(self, source_type: str, msg_field: str = "message", timestamp_field: str = "timestamp",
                 group_by_threshold: float = DEFAULT_GROUP_BY_THRESHOLD):
        """
        Args:
            source_type: Name of the field whose value is always emitted with
                a "discard" merge strategy.
            msg_field: Stored on the instance; not read by the methods of this
                class shown here.
            timestamp_field: Stored on the instance; not read by the methods of
                this class shown here.
            group_by_threshold: Cardinality ratio used by ``_cardinality``.
                The default keeps the original 0.8 behaviour.
        """
        self._source_type = source_type
        self._msg_field = msg_field
        self._timestamp_field = timestamp_field
        self._group_by_threshold = group_by_threshold

    @staticmethod
    def _hashable(value):
        """Return a hashable stand-in for ``value`` so it can be counted in a set.

        Hashable values are returned unchanged. Unhashable ones (lists, dicts,
        ...) are replaced by a tagged canonical JSON text, so two dicts that
        differ only in key order compare equal. The tag also prevents a
        collision with a plain string that happens to look like that JSON.
        """
        try:
            hash(value)
        except TypeError:
            try:
                text = json.dumps(value, sort_keys=True, default=str)
            except (TypeError, ValueError):
                # e.g. dict keys of mixed, unorderable types, or a cycle.
                text = repr(value)
            return ("__unhashable__", type(value).__name__, text)
        return value

    def _cardinality(self, values: list) -> str:
        """Classify ``values`` as "constant", "low" or "high" cardinality."""
        n_unique = len({self._hashable(v) for v in values})
        if n_unique == 1:
            return "constant"
        if n_unique < self._group_by_threshold * len(values):
            return "low"
        return "high"

    def _gen_policy_for_container_values(self, values: list) -> list[dict]:
        """Policy shared by list fields and dict children: always starts with pull_up."""
        cardinality = self._cardinality(values)
        if cardinality == "constant":
            return [{"action": "pull_up"}, {"action": "merge", "strategy": "discard"}]
        if cardinality == "low":
            return [{"action": "pull_up"}, {"action": "group_by"}]
        return [{"action": "noop"}]

    def _gen_policy_for_list_field(self, field: str, json_logs: list[dict]) -> list[dict]:
        """Generate the policy for a top-level field whose sample value is a list."""
        return self._gen_policy_for_container_values([record[field] for record in json_logs])

    def _encode_fields(self, fields: list[str]) -> str:
        """Join a field path into a dotted key.

        NOTE(review): ambiguous if a field name itself contains '.'.
        """
        return ".".join(fields)

    def _decode_fields(self, fields_str: str) -> list[str]:
        """Inverse of ``_encode_fields`` (surrounding whitespace is stripped first)."""
        return fields_str.strip().split('.')

    def _gen_policy_for_dict_field(self, field: str, json_logs: list[dict]) -> dict:
        """Generate policies for the direct children of a dict-valued field.

        Only one level down is inspected. Each key of ``json_logs[0][field]``
        gets its own entry under ``"<field>.<child>"``. A record whose
        ``field`` is not a dict (e.g. JSON null), or that lacks a child key,
        contributes ``None`` for that child.

        NOTE(review): if the sample's dict is empty, no entry is produced and
        the field is absent from the policy; confirm this is intended.
        """
        policy = {}
        for child_field in json_logs[0][field]:
            values = []
            for record in json_logs:
                parent = record[field]
                values.append(parent.get(child_field) if isinstance(parent, dict) else None)
            merged_key = self._encode_fields([field, child_field])
            policy[merged_key] = self._gen_policy_for_container_values(values)
        return policy

    def _gen_policy_for_non_obj_field(self, field: str, json_logs: list[dict]) -> list[dict]:
        """Generate the policy for a top-level scalar (or null-in-sample) field."""
        cardinality = self._cardinality([record[field] for record in json_logs])
        if cardinality == "constant":
            return [{"action": "merge", "strategy": "discard"}]
        if cardinality == "low":
            return [{"action": "group_by"}]
        return [{"action": "merge", "strategy": "array"}]

    def gen_merge_policy(self, json_logs: list[dict]) -> dict:
        """Generate a merge policy for ``json_logs``, which should share one schema.

        The first record is used as the schema sample. Its dict-valued fields
        are expanded one level into ``"<field>.<child>"`` entries. Its list,
        scalar and null fields get one entry each. The ``source_type`` field
        always gets a "discard" merge strategy.

        Args:
            json_logs: Log records; may be empty.

        Returns:
            Mapping of (possibly dotted) field name to a list of action dicts.
            Returns an empty dict for empty input.

        Raises:
            KeyError: if a record lacks a top-level field present in the
                first record.
        """
        if not json_logs:
            return {}

        policy = {}
        dict_fields = []
        list_fields = []
        non_obj_fields = []

        sample_log = json_logs[0]
        for field, value in sample_log.items():
            if field == self._source_type:
                # values of source_type are supposed to be the same for all logs.
                policy[field] = [{"action": "merge", "strategy": "discard"}]
            elif isinstance(value, dict):
                dict_fields.append(field)
            elif isinstance(value, list):
                list_fields.append(field)
            else:
                # Scalars, and null (whose real type the sample cannot tell us);
                # _hashable copes with later records holding lists/dicts here.
                non_obj_fields.append(field)

        # Diagnostics go to stderr so stdout carries only the policy itself.
        print("non_obj_fields: ", non_obj_fields, file=sys.stderr)
        for field in non_obj_fields:
            policy[field] = self._gen_policy_for_non_obj_field(field, json_logs)

        print("list_fields: ", list_fields, file=sys.stderr)
        for field in list_fields:
            policy[field] = self._gen_policy_for_list_field(field, json_logs)

        print("dict_fields: ", dict_fields, file=sys.stderr)
        for field in dict_fields:
            policy.update(self._gen_policy_for_dict_field(field, json_logs))

        return policy


__usage = """python3 gen_merge_policy.py [source_type_field] [log_file.json]\n"""

if __name__ == "__main__":
    # Exit on bad usage; otherwise indexing sys.argv below raises IndexError.
    if len(sys.argv) != 3:
        print(__usage, file=sys.stderr)
        sys.exit(2)

    source_type = sys.argv[1]
    json_filename = sys.argv[2]

    # JSON text is UTF-8; don't depend on the platform's default encoding.
    with open(json_filename, 'r', encoding='utf-8') as f:
        json_logs = json.load(f)

    # Only dict records can be analysed field by field; drop anything else.
    filtered_logs = [
        rec
        for rec in filter_logs_by_source_type(json_logs, source_type)
        if isinstance(rec, dict)
    ]

    merger = MergePolicy(source_type)
    policy = merger.gen_merge_policy(filtered_logs)

    pprint.pprint(policy)