Created
December 5, 2021 01:30
Revisions
-
seanmor5 created this gist
Dec 5, 2021 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,240 @@ defmodule Day4 do import Nx.Defn def part1() do File.read!("aoc/4.txt") |> parse_input() |> play_bingo() |> find_winning_board() end def part2() do File.read!("aoc/4.txt") |> parse_input() |> play_bingo_until_last() |> compute_last_score() end defp parse_input(input) do [draws | boards] = input |> String.replace("\r", "") |> String.split("\n\n") draws = draws |> String.split(",") |> Enum.map(&String.to_integer/1) |> Nx.tensor() {draws, to_matrices(boards)} end defp to_matrices(boards) do boards |> Enum.map(&to_matrix/1) |> Nx.stack() end defp to_matrix(board) do # otherwise the end of the board gets cut off, there's # probably a better way board = <<board::binary, " "::binary>> digits = for <<c::3-binary <- board>> do c |> String.trim() |> String.to_integer() end digits |> Nx.tensor() |> Nx.reshape({5, 5}, names: [:rows, :columns]) end defnp play_bingo({draws, boards}) do # the current draw current = Nx.tensor(0) # mask of filled spaces on all boards, nobody has # anything filled in mask = Nx.broadcast(Nx.tensor(0, type: {:u, 8}), boards) # iterate through draws, this will be much easier # when we merge the while loop on leading axis syntax {current, bingo_mask, _, boards} = while {current, mask, draws, boards}, Nx.logical_not(bingo?(mask)) do next_draw = Nx.squeeze(draws[current]) values_to_fill = Nx.equal(boards, next_draw) update_mask = Nx.logical_or(values_to_fill, mask) {current + 1, update_mask, draws, boards} end {Nx.squeeze(Nx.slice_axis(draws, current - 1, 1, 0)), bingo_mask, boards} end defnp bingo?(mask) do # bingo occurs when the sum along rows or columns is # 5, thank goodness there are no diagonal bingos :) any_rows? = mask |> Nx.sum(axes: [:rows]) |> Nx.equal(5) |> Nx.any?() any_cols? = mask |> Nx.sum(axes: [:columns]) |> Nx.equal(5) |> Nx.any?() Nx.logical_or(any_rows?, any_cols?) end defnp find_winning_board({last_drawn, mask, boards}) do # the winning board index is the one where the sum of # the rows or columns is 5, so we can select it with iota # then slice out the winning board rows? = mask |> Nx.sum(axes: [:rows]) |> Nx.reduce_max(axes: [:columns]) |> Nx.equal(5) |> Nx.any?() cols? = mask |> Nx.sum(axes: [:columns]) |> Nx.equal(5) |> Nx.any?() winning_board_index = cond do rows? -> mask |> Nx.sum(axes: [:rows]) |> Nx.equal(5) # we need to reduce away the columns now |> Nx.sum(axes: [:columns]) |> Nx.select(Nx.iota({100}), 0) |> Nx.sum() cols? -> mask |> Nx.sum(axes: [:columns]) |> Nx.equal(5) # we need to reduce away the rows now |> Nx.sum(axes: [:rows]) |> Nx.select(Nx.iota({100}), 0) |> Nx.sum() :otherwise -> # oh no Nx.tensor(1_000_000) end not_drawn = mask |> Nx.slice_axis(winning_board_index, 1, 0) |> Nx.logical_not() winning_board = boards |> Nx.slice_axis(winning_board_index, 1, 0) not_drawn |> Nx.select(winning_board, 0) |> Nx.sum() |> Nx.multiply(last_drawn) end defnp play_bingo_until_last({draws, boards}) do # the current draw current = Nx.tensor(0) # number of bingos num_bingos = Nx.tensor(0, type: {:u, 64}) # mask of filled spaces on all boards, nobody has # anything filled in mask = Nx.broadcast(Nx.tensor(0, type: {:u, 8}), boards) # iterate through draws, this will be much easier # when we merge the while loop on leading axis syntax {current, bingo_mask, _, draws, boards} = while {current, mask, num_bingos, draws, boards}, Nx.less(num_bingos, 99) do next_draw = Nx.squeeze(draws[current]) values_to_fill = Nx.equal(boards, next_draw) update_mask = Nx.logical_or(values_to_fill, mask) num_bingos = count_bingos(update_mask) {current + 1, update_mask, num_bingos, draws, boards} end {current, bingo_mask, draws, boards} end defnp count_bingos(mask) do # it's possible to have duplicates unfortunately, so we # need to count unique wins row_bingos = mask |> Nx.sum(axes: [:rows]) |> Nx.equal(5) |> Nx.select(Nx.iota({100, 5}, axis: 1), -1) |> Nx.reduce_max(axes: [:columns]) col_bingos = mask |> Nx.sum(axes: [:columns]) |> Nx.equal(5) |> Nx.select(Nx.iota({100, 5}, axis: 1), -1) |> Nx.reduce_max(axes: [:rows]) row_bingos |> Nx.not_equal(-1) |> Nx.logical_or(Nx.not_equal(col_bingos, -1)) |> Nx.sum() end defnp compute_last_score({current, mask, draws, boards}) do row_wins = mask |> Nx.sum(axes: [:rows]) |> Nx.equal(5) |> Nx.reduce_max(axes: [:columns]) col_wins = mask |> Nx.sum(axes: [:columns]) |> Nx.equal(5) |> Nx.reduce_max(axes: [:rows]) loser_idx = row_wins |> Nx.logical_or(col_wins) |> Nx.logical_not() |> Nx.multiply(Nx.iota({100})) |> Nx.sum() loser_mask = Nx.slice_axis(mask, loser_idx, 1, 0) loser_board = Nx.slice_axis(boards, loser_idx, 1, 0) {current, loser_mask, _, loser_board} = while {current, loser_mask, draws, loser_board}, Nx.logical_not(bingo?(loser_mask)) do next_draw = Nx.squeeze(draws[current]) values_to_fill = Nx.equal(loser_board, next_draw) update_mask = Nx.logical_or(loser_mask, values_to_fill) {current + 1, update_mask, draws, loser_board} end not_drawn = loser_mask |> Nx.logical_not() last_drawn = Nx.squeeze(Nx.slice_axis(draws, current - 1, 1, 0)) not_drawn |> Nx.select(loser_board, 0) |> Nx.sum() |> Nx.multiply(last_drawn) end end Day4.part1() |> IO.inspect() Day4.part2() |> IO.inspect()