Mix.install([
  {:nx, "~> 0.1.0"},
])
input = """
13            26            9         44
2             14            6         23
14            20            3         28
23            25            9         60
13            24            8         42
1             12            2         5
18            23            9         51
10            18            10        44
26            24            3         42
3             14            1         9
3             12            3         14
21            27            5         43
7             17            3         22
22            21            1         34
2             12            4         16
27            26            2         46
6             15            4         26
10            21            7         33
18            18            3         29
15            26            8         43
9             20            6         37
26            25            9         62
8             21            10        47
15            22            7         38
10            20            2         22
21            21            1         29
5             12            7         34
6             14            9         38
13            19            4         30
13            20            3         28
"""

# Parse each row into integers. `String.split/1` splits on any whitespace run, so
# tabs or a trailing "\r" (CRLF line endings) cannot reach `String.to_integer/1`,
# and blank/whitespace-only lines are dropped instead of producing ragged rows.
data =
  input
  |> String.split("\n")
  |> Enum.map(&String.split/1)
  |> Enum.reject(&(&1 == []))
  |> Enum.map(fn row -> Enum.map(row, &String.to_integer/1) end)

# Axes are named :y (rows) and :x (columns), so `x: ...` below indexes columns:
# columns 0..2 are the features, column 3 is the label.
tensor = Nx.tensor(data, names: [:y, :x])
x = tensor[x: 0..2]
{b, _n_features} = Nx.shape(x)
y = tensor[x: 3] |> Nx.reshape({b, 1})

# Prepend a column of ones so the first weight acts as the intercept (bias) term.
gen = Nx.broadcast(1, {b, 1})
new_x = Nx.concatenate([gen, x], axis: 1)

defmodule Hyperspace do
  @moduledoc """
  Linear regression fitted by batch gradient descent on a mean-squared-error loss.

  `x` is a rank-2 matrix with one row per sample and one column per feature,
  `w` is a `{features, 1}` weight column (as built by `train/4`), and `y` must
  be a `{samples, 1}` label column: a rank-1 `{samples}` tensor would broadcast
  against the `{samples, 1}` prediction into a `{samples, samples}` matrix.
  """
  import Nx.Defn

  @doc "Returns the predictions `x · w`."
  defn predict(x, w) do
    Nx.dot(x, w)
  end

  @doc "Returns the mean squared error of `predict(x, w)` against the labels `y`."
  defn loss(x, y, w) do
    (predict(x, w) - y)
    |> Nx.power(2)
    |> Nx.mean()
  end

  @doc """
  Performs ONE gradient-descent step and returns the updated weights.

  Despite its name this does not return the gradient itself: it computes the
  MSE gradient `2 / m · xᵀ(x·w − y)` (`m` = number of rows of `x`) and returns
  `w − lr · gradient`.
  """
  defn gradient(x, y, w, lr) do
    residual = predict(x, w) - y

    grad =
      x
      |> Nx.transpose()
      |> Nx.dot(residual)
      |> Nx.multiply(2)
      |> Nx.divide(elem(x.shape, 0))

    w - grad * lr
  end

  @doc """
  Runs `iteration` gradient-descent steps with learning rate `lr`, starting from
  all-zero weights, and returns the final `{features, 1}` weight tensor.

  A zero or negative `iteration` performs no steps and returns the initial zeros.
  """
  def train(x, y, iteration, lr) do
    features = elem(Nx.shape(x), 1)
    initial_w = Nx.broadcast(0, {features, 1})

    # The explicit `//1` step keeps the range ascending: a bare `1..0` would count
    # down (1, 0) and silently run two steps when zero were requested.
    Enum.reduce(1..iteration//1, initial_w, fn _step, w ->
      gradient(x, y, w, lr)
    end)
  end
end

w = Hyperspace.train(new_x, y, 100_000, 0.001)

# Renders a tensor's values as a comma-separated string, whatever its shape.
extract = fn tensor -> tensor |> Nx.to_flat_list() |> Enum.join(",") end

IO.puts("\n Weights #{extract.(w)}")
IO.puts("\n A few predictions:")

# Preview at most the first five rows; the `//1` step keeps the range empty
# (instead of counting down) if there are no rows at all.
preview_rows = min(5, elem(Nx.shape(new_x), 0))

Enum.each(0..(preview_rows - 1)//1, fn i ->
  prediction = Hyperspace.predict(new_x[i], w)
  IO.puts("X[#{i}] -> #{extract.(prediction)} (label: #{extract.(y[i])}) ")
end)