Skip to content

Instantly share code, notes, and snippets.

@seanmor5
Created December 5, 2021 01:30

Revisions

  1. seanmor5 created this gist Dec 5, 2021.
    240 changes: 240 additions & 0 deletions aoc_day4_nx.exs
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,240 @@
    defmodule Day4 do
    import Nx.Defn

    def part1() do
    File.read!("aoc/4.txt")
    |> parse_input()
    |> play_bingo()
    |> find_winning_board()
    end

    def part2() do
    File.read!("aoc/4.txt")
    |> parse_input()
    |> play_bingo_until_last()
    |> compute_last_score()
    end

    defp parse_input(input) do
    [draws | boards] =
    input
    |> String.replace("\r", "")
    |> String.split("\n\n")

    draws =
    draws
    |> String.split(",")
    |> Enum.map(&String.to_integer/1)
    |> Nx.tensor()

    {draws, to_matrices(boards)}
    end

    defp to_matrices(boards) do
    boards
    |> Enum.map(&to_matrix/1)
    |> Nx.stack()
    end

    defp to_matrix(board) do
    # otherwise the end of the board gets cut off, there's
    # probably a better way
    board = <<board::binary, " "::binary>>

    digits =
    for <<c::3-binary <- board>> do
    c
    |> String.trim()
    |> String.to_integer()
    end

    digits
    |> Nx.tensor()
    |> Nx.reshape({5, 5}, names: [:rows, :columns])
    end

    defnp play_bingo({draws, boards}) do
    # the current draw
    current = Nx.tensor(0)
    # mask of filled spaces on all boards, nobody has
    # anything filled in
    mask = Nx.broadcast(Nx.tensor(0, type: {:u, 8}), boards)
    # iterate through draws, this will be much easier
    # when we merge the while loop on leading axis syntax
    {current, bingo_mask, _, boards} =
    while {current, mask, draws, boards}, Nx.logical_not(bingo?(mask)) do
    next_draw = Nx.squeeze(draws[current])
    values_to_fill = Nx.equal(boards, next_draw)
    update_mask = Nx.logical_or(values_to_fill, mask)
    {current + 1, update_mask, draws, boards}
    end

    {Nx.squeeze(Nx.slice_axis(draws, current - 1, 1, 0)), bingo_mask, boards}
    end

    defnp bingo?(mask) do
    # bingo occurs when the sum along rows or columns is
    # 5, thank goodness there are no diagonal bingos :)
    any_rows? =
    mask
    |> Nx.sum(axes: [:rows])
    |> Nx.equal(5)
    |> Nx.any?()

    any_cols? =
    mask
    |> Nx.sum(axes: [:columns])
    |> Nx.equal(5)
    |> Nx.any?()

    Nx.logical_or(any_rows?, any_cols?)
    end

    defnp find_winning_board({last_drawn, mask, boards}) do
    # the winning board index is the one where the sum of
    # the rows or columns is 5, so we can select it with iota
    # then slice out the winning board
    rows? =
    mask
    |> Nx.sum(axes: [:rows])
    |> Nx.reduce_max(axes: [:columns])
    |> Nx.equal(5)
    |> Nx.any?()

    cols? =
    mask
    |> Nx.sum(axes: [:columns])
    |> Nx.equal(5)
    |> Nx.any?()

    winning_board_index =
    cond do
    rows? ->
    mask
    |> Nx.sum(axes: [:rows])
    |> Nx.equal(5)
    # we need to reduce away the columns now
    |> Nx.sum(axes: [:columns])
    |> Nx.select(Nx.iota({100}), 0)
    |> Nx.sum()

    cols? ->
    mask
    |> Nx.sum(axes: [:columns])
    |> Nx.equal(5)
    # we need to reduce away the rows now
    |> Nx.sum(axes: [:rows])
    |> Nx.select(Nx.iota({100}), 0)
    |> Nx.sum()

    :otherwise ->
    # oh no
    Nx.tensor(1_000_000)
    end

    not_drawn =
    mask
    |> Nx.slice_axis(winning_board_index, 1, 0)
    |> Nx.logical_not()

    winning_board =
    boards
    |> Nx.slice_axis(winning_board_index, 1, 0)

    not_drawn
    |> Nx.select(winning_board, 0)
    |> Nx.sum()
    |> Nx.multiply(last_drawn)
    end

    defnp play_bingo_until_last({draws, boards}) do
    # the current draw
    current = Nx.tensor(0)
    # number of bingos
    num_bingos = Nx.tensor(0, type: {:u, 64})
    # mask of filled spaces on all boards, nobody has
    # anything filled in
    mask = Nx.broadcast(Nx.tensor(0, type: {:u, 8}), boards)
    # iterate through draws, this will be much easier
    # when we merge the while loop on leading axis syntax
    {current, bingo_mask, _, draws, boards} =
    while {current, mask, num_bingos, draws, boards}, Nx.less(num_bingos, 99) do
    next_draw = Nx.squeeze(draws[current])
    values_to_fill = Nx.equal(boards, next_draw)
    update_mask = Nx.logical_or(values_to_fill, mask)
    num_bingos = count_bingos(update_mask)
    {current + 1, update_mask, num_bingos, draws, boards}
    end

    {current, bingo_mask, draws, boards}
    end

    defnp count_bingos(mask) do
    # it's possible to have duplicates unfortunately, so we
    # need to count unique wins
    row_bingos =
    mask
    |> Nx.sum(axes: [:rows])
    |> Nx.equal(5)
    |> Nx.select(Nx.iota({100, 5}, axis: 1), -1)
    |> Nx.reduce_max(axes: [:columns])

    col_bingos =
    mask
    |> Nx.sum(axes: [:columns])
    |> Nx.equal(5)
    |> Nx.select(Nx.iota({100, 5}, axis: 1), -1)
    |> Nx.reduce_max(axes: [:rows])

    row_bingos
    |> Nx.not_equal(-1)
    |> Nx.logical_or(Nx.not_equal(col_bingos, -1))
    |> Nx.sum()
    end

    defnp compute_last_score({current, mask, draws, boards}) do
    row_wins =
    mask
    |> Nx.sum(axes: [:rows])
    |> Nx.equal(5)
    |> Nx.reduce_max(axes: [:columns])

    col_wins =
    mask
    |> Nx.sum(axes: [:columns])
    |> Nx.equal(5)
    |> Nx.reduce_max(axes: [:rows])

    loser_idx =
    row_wins
    |> Nx.logical_or(col_wins)
    |> Nx.logical_not()
    |> Nx.multiply(Nx.iota({100}))
    |> Nx.sum()

    loser_mask = Nx.slice_axis(mask, loser_idx, 1, 0)
    loser_board = Nx.slice_axis(boards, loser_idx, 1, 0)

    {current, loser_mask, _, loser_board} =
    while {current, loser_mask, draws, loser_board}, Nx.logical_not(bingo?(loser_mask)) do
    next_draw = Nx.squeeze(draws[current])
    values_to_fill = Nx.equal(loser_board, next_draw)
    update_mask = Nx.logical_or(loser_mask, values_to_fill)
    {current + 1, update_mask, draws, loser_board}
    end

    not_drawn =
    loser_mask
    |> Nx.logical_not()

    last_drawn = Nx.squeeze(Nx.slice_axis(draws, current - 1, 1, 0))

    not_drawn
    |> Nx.select(loser_board, 0)
    |> Nx.sum()
    |> Nx.multiply(last_drawn)
    end
    end

    Day4.part1() |> IO.inspect()
    Day4.part2() |> IO.inspect()