defmodule Day06 do
  @moduledoc """
  Guard patrol on a character grid.

  The guard starts on `^` facing up, walks straight and turns right whenever the
  cell ahead is `#`. `part1/1` counts the distinct cells visited before the guard
  leaves the grid; `part2/1` counts the cells where one extra obstacle would make
  the guard walk in a loop.
  """

  @doc """
  Parses the raw input into `{guard, grid, obstacles_vert, obstacles_hori}`:

    * `guard` - `{row, col}` of the first `^`.
    * `grid` - `:array` of rows, each an `:array` of single-codepoint strings.
    * `obstacles_vert` - map of row to a `:gb_sets` of the obstacle columns in it.
    * `obstacles_hori` - map of column to a `:gb_sets` of the obstacle rows in it.

  Lines may end in LF or CRLF; splitting on LF alone would leave a carriage
  return at the end of every row and turn it into an extra walkable column.
  Empty lines are dropped. Raises `ArgumentError` (from `hd/1`) if there is no `^`.
  """
  def parse_data(data) do
    lines = String.split(data, ["\r\n", "\n"], trim: true)

    guard = lines |> map_elements("^") |> hd()
    obstacles = map_elements(lines, "#")

    grid =
      lines
      |> Enum.map(fn line -> line |> String.codepoints() |> :array.from_list() end)
      |> :array.from_list()

    obstacles_vert = index_by(obstacles, fn {i, _} -> i end, fn {_, j} -> j end)
    obstacles_hori = index_by(obstacles, fn {_, j} -> j end, fn {i, _} -> i end)

    {guard, grid, obstacles_vert, obstacles_hori}
  end

  # Groups points by `key_fun` into a map of ordered `:gb_sets`, so the nearest
  # obstacle along a row/column can later be found with smaller/larger.
  defp index_by(points, key_fun, value_fun) do
    points
    |> Enum.group_by(key_fun, value_fun)
    |> Map.new(fn {key, values} -> {key, :gb_sets.from_list(values)} end)
  end

  @doc """
  Returns `{row, col}` for every codepoint of `grid` (a list of row strings) that
  equals `element`, in row-major order.
  """
  def map_elements(grid, element) do
    for {line, i} <- Enum.with_index(grid),
        # The pinned pattern skips every codepoint that is not `element`.
        {^element, j} <- line |> String.codepoints() |> Enum.with_index(),
        do: {i, j}
  end

  @doc "Number of distinct cells the guard visits before leaving the grid."
  def part1({guard, grid, _, _}) do
    guard |> visited(grid) |> length()
  end

  # Distinct `{row, col}` cells on the guard's walk from `guard` facing up,
  # in first-visit order.
  defp visited(guard, grid) do
    for {i, j, _dir} <- path(guard, :up, grid), uniq: true, do: {i, j}
  end

  @doc """
  Walks the guard from `{i, j}` facing `dir` until the next step leaves `grid`.

  Returns the recorded `{row, col, dir}` states, most recent first. A state is
  recorded with the heading the guard leaves the cell with, so after a turn only
  the post-turn heading is kept. `acc` is the recursion accumulator and is
  normally left at its default `[]`.

  Does not terminate if the walk loops.
  """
  def path(acc \\ [], {i, j}, dir, grid) do
    {ni, nj, turned} = next(i, j, dir)

    case check(ni, nj, grid) do
      :invalid -> [{i, j, dir} | acc]
      # Turn in place; the state is recorded once the guard moves or exits.
      :obstacle -> path(acc, {i, j}, turned, grid)
      :valid -> path([{i, j, dir} | acc], {ni, nj}, dir, grid)
    end
  end

  @doc """
  Returns `{row, col, turned}`: the cell one step ahead of `{i, j}` in `dir`
  (rows grow downwards), plus the heading after a right turn from `dir`.
  """
  def next(i, j, :up), do: {i - 1, j, :right}
  def next(i, j, :right), do: {i, j + 1, :down}
  def next(i, j, :down), do: {i + 1, j, :left}
  def next(i, j, :left), do: {i, j - 1, :up}

  @doc """
  Classifies cell `{i, j}` of `grid`: `:invalid` when it lies outside the grid,
  `:obstacle` when it holds `"#"` or equals `extra`, `:valid` otherwise.
  """
  def check(i, j, grid, extra \\ nil) do
    cond do
      not in_bounds?(i, j, grid) -> :invalid
      :array.get(j, :array.get(i, grid)) == "#" or {i, j} == extra -> :obstacle
      true -> :valid
    end
  end

  # `<`/`>=` comparisons reject every out-of-range index, not only the one just
  # past the edge. `and` short-circuits, so a row is fetched only when `i` is in
  # range (an extendible `:array` returns `:undefined` for missing indices).
  defp in_bounds?(i, j, grid) do
    i >= 0 and i < :array.size(grid) and
      j >= 0 and j < :array.size(:array.get(i, grid))
  end

  @doc """
  Counts the cells on the guard's original route (excluding its start) where a
  single extra obstacle makes the guard walk in a loop.
  """
  def part2({guard, grid, overt, ohori}) do
    guard
    |> visited(grid)
    |> Enum.reject(&(&1 == guard))
    |> Enum.count(fn {i, j} ->
      is_loop?(guard, :up, grid, update(overt, i, j), update(ohori, j, i))
    end)
  end

  @doc """
  Returns `m` with `v` added to the `:gb_sets` stored under `k`; a one-element
  set is created when `k` is absent.
  """
  def update(m, k, v) do
    Map.update(m, k, :gb_sets.singleton(v), &:gb_sets.add(v, &1))
  end

  @doc """
  Jumps the guard from `{i, j}` heading `dir` to the cell right in front of the
  nearest obstacle ahead, using the indexes built by `parse_data/1` (`overt`:
  row to columns, `ohori`: column to rows).

  Returns `{row, col, turned}` with the heading after the right turn, or `:out`
  when no obstacle lies ahead. Grid bounds are not consulted.
  """
  def next(i, j, :up, _overt, ohori) do
    stop_before(:gb_sets.smaller(i, obstacles_in(ohori, j)), fn row -> {row + 1, j, :right} end)
  end

  def next(i, j, :right, overt, _ohori) do
    stop_before(:gb_sets.larger(j, obstacles_in(overt, i)), fn col -> {i, col - 1, :down} end)
  end

  def next(i, j, :down, _overt, ohori) do
    stop_before(:gb_sets.larger(i, obstacles_in(ohori, j)), fn row -> {row - 1, j, :left} end)
  end

  def next(i, j, :left, overt, _ohori) do
    stop_before(:gb_sets.smaller(j, obstacles_in(overt, i)), fn col -> {i, col + 1, :up} end)
  end

  # Ordered obstacle coordinates along one row/column; empty when there are none.
  defp obstacles_in(index, key), do: Map.get(index, key, :gb_sets.empty())

  # Maps a `:gb_sets.smaller/larger` result to the guard's stopping state.
  defp stop_before(:none, _stop), do: :out
  defp stop_before({:found, obstacle}, stop), do: stop.(obstacle)

  @doc """
  Returns `true` when the guard, starting at `{i, j}` facing `dir`, reaches a
  `{row, col, dir}` state it has already been in, and `false` once no obstacle
  lies ahead (it walks out).

  `cache` holds the states seen so far; `grid` is only threaded through.
  """
  def is_loop?(cache \\ MapSet.new(), {i, j}, dir, grid, overt, ohori) do
    state = {i, j, dir}

    if MapSet.member?(cache, state) do
      true
    else
      case next(i, j, dir, overt, ohori) do
        :out -> false
        {ni, nj, turned} -> is_loop?(MapSet.put(cache, state), {ni, nj}, turned, grid, overt, ohori)
      end
    end
  end
end

data = :stdio |> IO.read(:eof) |> Day06.parse_data()

# Solve each part in order, printing the elapsed time in seconds (`:timer.tc`
# reports microseconds) followed by the answer.
for solve <- [&Day06.part1/1, &Day06.part2/1] do
  {micros, answer} = :timer.tc(fn -> solve.(data) end)
  IO.puts("Time  : #{micros / 1000000}")
  IO.puts("Answer: #{answer}")
end