import org.junit.jupiter.api.Test;

import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.junit.jupiter.api.Assertions.assertEquals;

class MergeTest {

    /**
     * An immutable tuple pairing a character with the number of times that
     * character appears in the string it was taken from.
     */
    static class CharacterAndFrequency {
        final char character;
        final long frequency;

        /**
         * @param character   the character this tuple describes
         * @param frequencies frequency map of the string {@code character} was taken from;
         *                    must contain an entry for {@code character}
         * @throws NullPointerException if {@code frequencies} is {@code null} or holds no
         *                              entry for {@code character}
         */
        public CharacterAndFrequency(char character, Map<Character, Long> frequencies) {
            this.character = character;
            // Fail with a descriptive message rather than an anonymous unboxing NPE.
            this.frequency = Objects.requireNonNull(
                    frequencies.get(character),
                    () -> "no frequency recorded for character '" + character + "'");
        }
    }

    /**
     * Orders characters first by how often they appear in their own string
     * (rarer first), then by the natural order of the character itself.
     */
    private static final Comparator<CharacterAndFrequency> BY_FREQUENCY_THEN_CHARACTER =
            Comparator.<CharacterAndFrequency>comparingLong(c -> c.frequency)
                    .thenComparing(c -> c.character);

    /**
     * Creates a map counting the frequency of each character in a string.
     *
     * @param string the string to analyse
     * @return a mutable map from each distinct character of {@code string} to its
     *         number of occurrences; empty when {@code string} is empty
     * @throws NullPointerException if {@code string} is {@code null}
     */
    static Map<Character, Long> frequencyMap(String string) {
        Map<Character, Long> frequencies = new HashMap<>();
        for (int i = 0; i < string.length(); i++) {
            // merge() inserts 1 for a new key and otherwise adds 1 to the stored count.
            frequencies.merge(string.charAt(i), 1L, Long::sum);
        }
        return frequencies;
    }

    /**
     * Merges two strings into one, keeping the relative order of the characters
     * of each input. At every step the heads of both strings are compared, first
     * by their frequency within their own string and then by character value;
     * the smaller head is appended to the result. On a complete tie the character
     * from {@code s1} is taken. Once one string is exhausted the remainder of the
     * other is appended unchanged.
     *
     * @param s1 the first string; wins ties
     * @param s2 the second string
     * @return the merged string, of length {@code s1.length() + s2.length()}
     * @throws NullPointerException if either argument is {@code null}
     */
    static String solution(String s1, String s2) {
        Objects.requireNonNull(s1, "s1");
        Objects.requireNonNull(s2, "s2");

        Deque<CharacterAndFrequency> first = toQueue(s1);
        Deque<CharacterAndFrequency> second = toQueue(s2);

        StringBuilder merged = new StringBuilder(s1.length() + s2.length());

        while (!first.isEmpty() && !second.isEmpty()) {
            // "<= 0" makes the first string win when both heads compare equal.
            boolean takeFromFirst =
                    BY_FREQUENCY_THEN_CHARACTER.compare(first.peekFirst(), second.peekFirst()) <= 0;
            Deque<CharacterAndFrequency> source = takeFromFirst ? first : second;
            merged.append(source.removeFirst().character);
        }

        // At most one of the queues still holds characters; append them in order.
        first.forEach(c -> merged.append(c.character));
        second.forEach(c -> merged.append(c.character));
        return merged.toString();
    }

    /**
     * Turns a string into a queue of its characters, in order, each annotated
     * with its frequency within that string.
     */
    private static Deque<CharacterAndFrequency> toQueue(String string) {
        Map<Character, Long> frequencies = frequencyMap(string);
        return characterStream(string)
                .map(c -> new CharacterAndFrequency(c, frequencies))
                .collect(Collectors.toCollection(ArrayDeque::new));
    }

    /**
     * Helper function because {@link String#chars()} returns an IntStream when
     * we actually want a stream of characters. Each element is one {@code char}
     * value of the string, in order.
     *
     * @param s the string to stream
     * @return a stream of the boxed {@code char} values of {@code s}
     */
    static Stream<Character> characterStream(String s) {
        return s.chars().mapToObj(c -> (char) c);
    }

    @Test
    public void testSolution() {
        assertEquals("dcecccbd", solution("dce", "cccbd"));
        assertEquals("stouperwer", solution("super", "tower"));
    }

    @Test
    public void testEmptyAndTiedInputs() {
        // Empty inputs: the result is simply the other string.
        assertEquals("", solution("", ""));
        assertEquals("abc", solution("abc", ""));
        assertEquals("abc", solution("", "abc"));
        // Identical heads tie, so the first string's character is taken first.
        assertEquals("aabb", solution("ab", "ab"));
        // Frequency takes precedence over character order: 'a' (x2) loses to 'c' (x1).
        assertEquals("caab", solution("aab", "c"));
    }

}