summaryrefslogtreecommitdiffstats
path: root/lib/day08.ml
blob: 0570b6674d775db8da4ced52034ca3ae9f66d451 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
(*
 * SPDX-FileCopyrightText: Copyright 2025 Alexandre Jesus <https://adbjesus.com>
 *
 * SPDX-License-Identifier: GPL-3.0-or-later
 *)

(** [parse ch] reads all lines of [ch], skips empty ones, and decodes each
    remaining line as an ["x,y,z"] triple of integers. The triples are
    returned in input order.
    A non-empty line that does not match the format makes [Scanf.sscanf]
    raise. *)
let parse ch =
  let read_point line = Scanf.sscanf line "%d,%d,%d" (fun x y z -> (x, y, z)) in
  In_channel.input_lines ch
  |> List.filter_map (fun line ->
         if String.equal line "" then None else Some (read_point line))

(** [sqdist p q] is the squared Euclidean distance between the integer
    points [p] and [q]. It avoids a square root while still ordering pairs
    of points by distance. *)
let sqdist (x1, y1, z1) (x2, y2, z2) =
  let sq d = d * d in
  sq (x2 - x1) + sq (y2 - y1) + sq (z2 - z1)

(** [solve ?connections part ch] reads integer 3-D points from [ch] (see
    [parse]) and prints the answer for [part] on stdout, followed by a
    newline.

    Every unordered pair of points is taken in increasing order of squared
    distance. Pairs are joined with a union-find structure that uses union
    by size and path compression.

    - Part 1: join the first [connections] pairs (default [1000]), then
      print the product of the sizes of the (up to) three largest
      connected components.
    - Part 2: keep joining pairs until all points are in one component.
      Print the product of the x coordinates of the pair whose join
      achieved that.

    @raise Failure if [part] is neither [1] nor [2]. For part 2, also if no
      pair ever merges the last two components (fewer than two distinct
      points).
    @raise Invalid_argument if [connections] is negative (from [List.take]). *)
let solve ?(connections = 1000) part ch =
  let nodes = parse ch in
  (* Every unordered pair of list entries, tagged with its squared
     distance. *)
  let rec make_edges acc = function
    | [] | [ _ ] -> acc
    | h :: t ->
        make_edges
          (List.fold_left (fun acc x -> (sqdist h x, (h, x)) :: acc) acc t)
          t
  in
  let edges =
    make_edges [] nodes
    |> List.fast_sort (fun (d1, _) (d2, _) -> Int.compare d1 d2)
  in
  let make_tbl fn = Hashtbl.of_seq (Seq.map fn (List.to_seq nodes)) in
  (* [par] maps each point to its parent; a root is its own parent.
     [sz] is only kept up to date for roots: it holds the component size. *)
  let par = make_tbl (fun n -> (n, n)) in
  let sz = make_tbl (fun n -> (n, 1)) in
  let rec find n =
    let p = Hashtbl.find par n in
    if p = n then p
    else
      let root = find p in
      (* Path compression: point [n] straight at its root. *)
      Hashtbl.replace par n root;
      root
  in
  (* Merge the components of [u] and [v]. Returns [true] iff they were
     distinct before the call. *)
  let union u v =
    let a = find u in
    let b = find v in
    if a = b then false
    else
      let sa = Hashtbl.find sz a in
      let sb = Hashtbl.find sz b in
      (* Hang the smaller tree under the larger one to keep trees shallow. *)
      let small, big = if sa < sb then (a, b) else (b, a) in
      Hashtbl.replace par small big;
      Hashtbl.replace sz big (sa + sb);
      true
  in
  match part with
  | 1 ->
      List.take connections edges
      |> List.iter (fun (_, (a, b)) -> ignore (union a b : bool));
      Hashtbl.to_seq sz
      |> Seq.filter (fun (n, _) -> n = find n)
      |> Seq.map snd
      |> List.of_seq
      |> List.fast_sort (fun a b -> Int.compare b a)
      |> List.take 3
      |> List.fold_left ( * ) 1
      |> Printf.printf "%d\n"
  | 2 ->
      (* [components] counts distinct points; [Hashtbl.of_seq] collapses
         duplicate coordinates. Stop at the merge that leaves a single
         component, because later pairs cannot merge anything. *)
      let rec last_merge components = function
        | [] -> failwith "Points never form a single connected component"
        | (_, (a, b)) :: rest ->
            if not (union a b) then last_merge components rest
            else if components = 2 then (a, b)
            else last_merge (components - 1) rest
      in
      let (xa, _, _), (xb, _, _) = last_merge (Hashtbl.length par) edges in
      Printf.printf "%d\n" (xa * xb)
  | _ -> failwith "Invalid part, must be 1 or 2"

(** [part1 ch] prints the part-1 answer for the points read from [ch]. *)
let part1 = solve 1
(** [part2 ch] prints the part-2 answer for the points read from [ch]. *)
let part2 = solve 2