Skip to content

Instantly share code, notes, and snippets.

@esshka
Created April 8, 2025 02:18
Show Gist options
  • Save esshka/f91d2e133aad6ad4728b7f0af5a5fb4a to your computer and use it in GitHub Desktop.
Save esshka/f91d2e133aad6ad4728b7f0af5a5fb4a to your computer and use it in GitHub Desktop.
clojure dual nums impl
(ns dual-math
  (:refer-clojure :exclude [+ - * / < <= > >= zero? abs max min sin cos tan asin acos atan exp log pow sqrt])
  (:require [clojure.core :as core]))
;; == Part 1: Dual Number Definition and Helpers ==
;; A dual number is a record with two fields:
;;   value - the primal part (the ordinary function value)
;;   deriv - the derivative part carried alongside the value
;; The arithmetic and math functions in this file read both fields and build a
;; new DualNumber whose `deriv` follows the matching differentiation rule.
(defrecord DualNumber [value deriv])
(defn dual?
  "Returns true when `x` is an instance of the DualNumber record, false otherwise."
  [x]
  (instance? DualNumber x))
(defn constant
  "Lifts `x` to a dual number that represents a constant.

  A plain number is coerced to a double and paired with a derivative of 0.0.
  An argument that is already a DualNumber is handed back untouched, whatever
  its derivative part is."
  [x]
  (if-not (dual? x)
    (->DualNumber (double x) 0.0)
    x))
(defn variable
  "Builds the dual number for an independent variable at the point `x`:
  the value is `x` coerced to a double and the derivative is seeded with 1.0."
  [x]
  (let [point (double x)]
    (->DualNumber point 1.0)))
(defn ->dual
  "Coerces an operand to a DualNumber so plain Clojure numbers and dual numbers
  can be mixed in one operation. Dual numbers pass through unchanged; anything
  else goes through `constant`."
  [x]
  (cond-> x
    (not (dual? x)) constant))
;; Extractors
(defn value
  "Looks up the :value entry (the primal part) of `d`.
  Returns nil when `d` has no such entry."
  [d]
  (get d :value))
(defn deriv
  "Looks up the :deriv entry (the derivative part) of `d`.
  Returns nil when `d` has no such entry."
  [d]
  (get d :deriv))
;; == Part 2: Dual-Aware Arithmetic ==
;; clojure.lang.Numbers is a Java class of static methods, not a Clojure
;; protocol, so it cannot be the target of extend-protocol (`extend` rejects it
;; with "... is not a protocol"). The operators are therefore defined here as
;; ordinary functions that shadow the clojure.core names inside this namespace.
;;
;; Contract shared by every operator below:
;;   * if any operand is a DualNumber, plain operands are lifted with ->dual and
;;     the result is a DualNumber carrying the propagated derivative;
;;   * if no operand is a DualNumber, the call is forwarded to clojure.core, so
;;     ordinary numeric code such as (/ Math/PI 2) keeps returning plain numbers.

(defn- any-dual?
  "True when at least one of `a` and `b` is a DualNumber."
  [a b]
  (or (dual? a) (dual? b)))

(defn- primal
  "Value part of a DualNumber; any other argument is returned as-is."
  [x]
  (if (dual? x) (:value x) x))

(defn +
  "Addition. Sum rule: d(u + v) = u' + v'."
  ([] 0)
  ([a] a)
  ([a b]
   (if (any-dual? a b)
     (let [{va :value da :deriv} (->dual a)
           {vb :value db :deriv} (->dual b)]
       (->DualNumber (core/+ va vb) (core/+ da db)))
     (core/+ a b)))
  ([a b & more]
   (reduce + (+ a b) more)))

(defn -
  "Negation (one argument) or subtraction. d(-u) = -u', d(u - v) = u' - v'."
  ([a]
   (if (dual? a)
     (->DualNumber (core/- (:value a)) (core/- (:deriv a)))
     (core/- a)))
  ([a b]
   (if (any-dual? a b)
     (let [{va :value da :deriv} (->dual a)
           {vb :value db :deriv} (->dual b)]
       (->DualNumber (core/- va vb) (core/- da db)))
     (core/- a b)))
  ([a b & more]
   (reduce - (- a b) more)))

(defn *
  "Multiplication. Product rule: d(uv) = u'v + uv'."
  ([] 1)
  ([a] a)
  ([a b]
   (if (any-dual? a b)
     (let [{va :value da :deriv} (->dual a)
           {vb :value db :deriv} (->dual b)]
       (->DualNumber (core/* va vb)
                     (core/+ (core/* da vb) (core/* va db))))
     (core/* a b)))
  ([a b & more]
   (reduce * (* a b) more)))

(defn /
  "Reciprocal (one argument) or division.
  Quotient rule: d(u/v) = (u'v - uv') / v^2.
  Throws ArithmeticException when a dual division has a divisor whose value is
  zero; plain-number division keeps clojure.core's behavior."
  ([a] (/ 1 a))
  ([a b]
   (if (any-dual? a b)
     (let [{va :value da :deriv} (->dual a)
           {vb :value db :deriv} (->dual b)]
       ;; zero? (rather than comparing with 0.0) also catches an integer zero
       ;; stored in a hand-built DualNumber.
       (when (core/zero? vb)
         (throw (ArithmeticException. "Division by zero in dual number")))
       (->DualNumber (core// va vb)
                     (core// (core/- (core/* da vb) (core/* va db))
                             (core/* vb vb))))
     (core// a b)))
  ([a b & more]
   (reduce / (/ a b) more)))

;; Comparisons: dual numbers are not totally ordered, so only the value parts
;; are compared and the derivative parts are ignored.
(defn <
  "Strictly increasing by value part."
  ([a] true)
  ([a b] (core/< (primal a) (primal b)))
  ([a b & more] (apply core/< (primal a) (primal b) (map primal more))))

(defn <=
  "Non-decreasing by value part."
  ([a] true)
  ([a b] (core/<= (primal a) (primal b)))
  ([a b & more] (apply core/<= (primal a) (primal b) (map primal more))))

(defn >
  "Strictly decreasing by value part."
  ([a] true)
  ([a b] (core/> (primal a) (primal b)))
  ([a b & more] (apply core/> (primal a) (primal b) (map primal more))))

(defn >=
  "Non-increasing by value part."
  ([a] true)
  ([a b] (core/>= (primal a) (primal b)))
  ([a b & more] (apply core/>= (primal a) (primal b) (map primal more))))

(defn zero?
  "True when the value part of `a` (or `a` itself, if plain) is zero."
  [a]
  (core/zero? (primal a)))

(defn abs
  "Absolute value. d|u| = u' * sign(u).
  The derivative is undefined at u = 0; Math/signum returns 0 there, so the
  derivative part is reported as 0."
  [a]
  (if (dual? a)
    (let [v (double (:value a))]
      ;; Math/abs instead of clojure.core/abs, which only exists from Clojure 1.11.
      (->DualNumber (Math/abs v) (core/* (:deriv a) (Math/signum v))))
    (if (core/neg? a) (core/- a) a)))

(defn max
  "Largest argument by value part. With a dual operand the result is the chosen
  operand as a DualNumber, so the derivative follows the selected branch; on a
  tie the earlier argument wins."
  ([a] a)
  ([a b]
   (if (any-dual? a b)
     (let [a* (->dual a)
           b* (->dual b)]
       (if (core/>= (:value a*) (:value b*)) a* b*))
     (core/max a b)))
  ([a b & more]
   (reduce max (max a b) more)))

(defn min
  "Smallest argument by value part. With a dual operand the result is the chosen
  operand as a DualNumber, so the derivative follows the selected branch; on a
  tie the earlier argument wins."
  ([a] a)
  ([a b]
   (if (any-dual? a b)
     (let [a* (->dual a)
           b* (->dual b)]
       (if (core/<= (:value a*) (:value b*)) a* b*))
     (core/min a b)))
  ([a b & more]
   (reduce min (min a b) more)))
;; == Part 3: Dual-Aware Math Functions ==
;; We can't easily extend java.lang.Math, so we define new functions.
(defn sin
  "Dual-aware sine. `a` is coerced with ->dual; the result has value sin(u)
  and derivative u' * cos(u)."
  [a]
  (let [{u :value, du :deriv} (->dual a)]
    (->DualNumber (Math/sin u)
                  (core/* du (Math/cos u)))))
(defn cos
  "Dual-aware cosine. `a` is coerced with ->dual; the result has value cos(u)
  and derivative u' * -sin(u)."
  [a]
  (let [{u :value, du :deriv} (->dual a)
        neg-sin (core/- (Math/sin u))]
    (->DualNumber (Math/cos u)
                  (core/* du neg-sin))))
(defn tan
  "Dual-aware tangent. `a` is coerced with ->dual; the result has value tan(u)
  and derivative u' * sec^2(u), computed as u' / cos^2(u)."
  [a]
  (let [{u :value, du :deriv} (->dual a)
        c      (Math/cos u)
        cos-sq (core/* c c)]
    (->DualNumber (Math/tan u)
                  (core// du cos-sq))))
(defn asin
  "Dual-aware arcsine. `a` is coerced with ->dual; the result has value asin(u)
  and derivative u' / sqrt(1 - u^2)."
  [a]
  (let [{u :value, du :deriv} (->dual a)
        root (Math/sqrt (core/- 1.0 (core/* u u)))]
    (->DualNumber (Math/asin u)
                  (core// du root))))
(defn acos
  "Dual-aware arccosine. `a` is coerced with ->dual; the result has value
  acos(u) and derivative -u' / sqrt(1 - u^2)."
  [a]
  (let [{u :value, du :deriv} (->dual a)
        root (Math/sqrt (core/- 1.0 (core/* u u)))]
    (->DualNumber (Math/acos u)
                  (core// (core/- du) root))))
(defn atan
  "Dual-aware arctangent. `a` is coerced with ->dual; the result has value
  atan(u) and derivative u' / (1 + u^2)."
  [a]
  (let [{u :value, du :deriv} (->dual a)
        denom (core/+ 1.0 (core/* u u))]
    (->DualNumber (Math/atan u)
                  (core// du denom))))
(defn exp
  "Dual-aware exponential. `a` is coerced with ->dual; the result has value
  exp(u) and derivative u' * exp(u)."
  [a]
  (let [{u :value, du :deriv} (->dual a)
        e-u (Math/exp u)]
    ;; exp(u) appears in both components, so it is evaluated once and reused.
    (->DualNumber e-u (core/* du e-u))))
(defn log
  "Dual-aware natural logarithm. `a` is coerced with ->dual; the result has
  value log(u) and derivative u' / u.
  Throws ArithmeticException when the value part is zero or negative."
  [a]
  (let [{u :value, du :deriv} (->dual a)]
    ;; Guard first: log is only defined for strictly positive arguments.
    (when (core/<= u 0.0)
      (throw (ArithmeticException. "Log of non-positive number in dual number")))
    (->DualNumber (Math/log u)
                  (core// du u))))
(defn pow
  "Dual-aware power a^b. Both arguments are coerced with ->dual.

  Positive base (general case, exponent may vary):
    d(u^v) = u^v * (v' * log(u) + v * u'/u)

  Non-positive base: log(u) is unavailable, so the general formula cannot be
  used. When the exponent is constant (its derivative part is zero) the power
  rule d(u^c) = c * u^(c-1) * u' is applied instead, for the cases where it is
  real-valued and finite:
    * c = 0                      -> value 1, derivative 0
    * u < 0 and c is an integer  -> e.g. x^2 or x^3 at negative x
    * u = 0 and c >= 1

  Throws ArithmeticException (\"Base must be positive for general dual power\")
  for every other non-positive-base input: a varying exponent, a fractional
  exponent on a negative base, or an exponent below 1 (other than 0) at zero."
  [a b]
  (let [{va :value, da :deriv} (->dual a)
        {vb :value, db :deriv} (->dual b)
        reject (fn []
                 (throw (ArithmeticException.
                         "Base must be positive for general dual power")))]
    (cond
      ;; Written as (not (<= ...)) so every input that took the general path
      ;; before (anything not <= 0) still takes exactly the same path.
      (not (core/<= va 0.0))
      (let [pow-val (Math/pow va vb)]
        (->DualNumber pow-val
                      (core/* pow-val
                              (core/+ (core/* db (Math/log va))
                                      (core/* vb (core// da va))))))

      ;; A varying exponent needs log(u), which does not exist here.
      (not (core/zero? db))
      (reject)

      ;; u^0 is the constant 1 regardless of u.
      (core/zero? vb)
      (->DualNumber (Math/pow (double va) (double vb)) 0.0)

      ;; Constant exponent where u^c and c*u^(c-1) are both real and finite.
      (if (core/zero? va)
        (core/>= vb 1.0)
        (core/== vb (Math/rint (double vb))))
      (->DualNumber (Math/pow (double va) (double vb))
                    (core/* da
                            (core/* vb
                                    (Math/pow (double va)
                                              (double (core/- vb 1.0))))))

      :else
      (reject))))
(defn sqrt
  "Dual-aware square root. `a` is coerced with ->dual; the result has value
  sqrt(u) and derivative u' / (2 * sqrt(u)).
  Throws ArithmeticException when the value part is negative."
  [a]
  (let [{u :value, du :deriv} (->dual a)]
    ;; Guard first: only strictly negative inputs are rejected, zero is allowed.
    (when (core/< u 0.0)
      (throw (ArithmeticException. "Sqrt of negative number in dual number")))
    (let [root (Math/sqrt u)]
      (->DualNumber root
                    (core// du (core/* 2.0 root))))))
;; == Part 4: Vectorization Functions (from Part 1/2) ==
;; Include these for completeness or assume they are in another namespace
(defn vectorize-binary-op
  "Wraps the two-argument function `f` so it applies element-wise.

  The returned function accepts scalars (plain numbers or DualNumbers) and
  sequential collections (vectors, lists, seqs), nested to any depth:
    * collection with collection -> paired element-wise; the sizes must match,
      otherwise IllegalArgumentException is thrown
    * collection with scalar     -> the scalar is broadcast to every element
    * scalar with scalar         -> (f a b)

  Collection results are always vectors, built eagerly. This is what lets the
  output of one vectorized call be fed straight into another: a lazy seq result
  would not be recognised as a collection by a `vector?` test and would be
  handed to `f` as if it were a scalar. Because `count` is used for the size
  check, infinite sequences are not supported."
  [f]
  (fn vbo [a b]
    (let [a-coll? (sequential? a)
          b-coll? (sequential? b)]
      (cond
        (and a-coll? b-coll?)
        (if (= (count a) (count b))
          (mapv vbo a b)
          (throw (IllegalArgumentException. "Vector sizes must match for binary operation")))

        a-coll? (mapv #(vbo % b) a)   ; broadcast b over the elements of a
        b-coll? (mapv #(vbo a %) b)   ; broadcast a over the elements of b
        :else   (f a b)))))           ; both scalars
(defn vectorize-unary-op
  "Wraps the one-argument function `f` so it applies element-wise.

  The returned function calls `f` directly on a scalar (plain number or
  DualNumber) and maps itself over any sequential collection (vector, list,
  seq), recursing into nested collections.

  Collection results are always vectors, built eagerly, so the output of one
  vectorized call is still recognised as a collection when it is passed to
  another vectorized function; a lazy seq would fail a `vector?` test and be
  treated as a scalar."
  [f]
  (fn vuo [x]
    (if (sequential? x)
      (mapv vuo x)
      (f x))))
;; == Part 5: Define Vectorized Operations using Dual-Aware Functions ==
;; Every var below is one of this namespace's scalar functions wrapped by
;; vectorize-binary-op or vectorize-unary-op. The symbols + - * / abs max min
;; and the math names are excluded from clojure.core in the ns form, so they
;; refer to this namespace's definitions.

;; -- Binary operations --
(def vadd
  "Element-wise `+`."
  (vectorize-binary-op +))

(def vsub
  "Element-wise `-`."
  (vectorize-binary-op -))

(def vmul
  "Element-wise `*`."
  (vectorize-binary-op *))

(def vdiv
  "Element-wise `/`."
  (vectorize-binary-op /))

(def vpow
  "Element-wise `pow` (first argument raised to the second)."
  (vectorize-binary-op pow))

(def vmax
  "Element-wise `max`."
  (vectorize-binary-op max))

(def vmin
  "Element-wise `min`."
  (vectorize-binary-op min))

;; -- Unary operations --
(def vsin
  "Element-wise `sin`."
  (vectorize-unary-op sin))

(def vcos
  "Element-wise `cos`."
  (vectorize-unary-op cos))

(def vtan
  "Element-wise `tan`."
  (vectorize-unary-op tan))

(def vasin
  "Element-wise `asin`."
  (vectorize-unary-op asin))

(def vacos
  "Element-wise `acos`."
  (vectorize-unary-op acos))

(def vatan
  "Element-wise `atan`."
  (vectorize-unary-op atan))

(def vexp
  "Element-wise `exp`."
  (vectorize-unary-op exp))

(def vlog
  "Element-wise natural `log`."
  (vectorize-unary-op log))

(def vsqrt
  "Element-wise `sqrt`."
  (vectorize-unary-op sqrt))

(def vabs
  "Element-wise `abs`."
  (vectorize-unary-op abs))
;; Helpers that pull one component out of every element of a collection.
(defn get-values
  "Returns a lazy sequence holding the :value entry of each element of
  `dual-coll`, in order."
  [dual-coll]
  (map (fn [d] (:value d)) dual-coll))
(defn get-derivs
  "Returns a lazy sequence holding the :deriv entry of each element of
  `dual-coll`, in order."
  [dual-coll]
  (map (fn [d] (:deriv d)) dual-coll))
;; == Part 6: Example Usage (from Article) ==
;; REPL examples. The `comment` form is never evaluated when the file loads;
;; evaluate the inner forms one at a time. The expected results shown after
;; `=>` assume the namespace is `dual-math`, which makes records print with the
;; `#dual_math.DualNumber` prefix.
(comment
;; Basic scalar example: d(x*x)/dx at x = 3
(let [x (variable 3.0)
y (* x x)]
(println "Scalar Example:")
(println "Value:" (:value y)) ;; => 9.0
(println "Deriv:" (:deriv y))) ;; => 6.0
;; Vector Add Example: variables plus constants keep derivative 1.0
(let [v1 [(variable 1.0) (variable 2.0)]
v2 [(constant 10.0) (constant 5.0)]
result (vadd v1 v2)]
(println "\nVector Add Example:")
(println "Result:" (vec result)))
;; => [#dual_math.DualNumber{:value 11.0, :deriv 1.0} #dual_math.DualNumber{:value 7.0, :deriv 1.0}]
;; Scalar Broadcast Example: the variable s is multiplied into each constant
(let [v [(constant 1.0) (constant 2.0) (constant 3.0)]
s (variable 10.0)
result (vmul s v)]
(println "\nScalar Broadcast Example:")
(println "Result:" (vec result)))
;; => [#dual_math.DualNumber{:value 10.0, :deriv 1.0} #dual_math.DualNumber{:value 20.0, :deriv 2.0} #dual_math.DualNumber{:value 30.0, :deriv 3.0}]
;; Vector Sin Example
;; NOTE(review): `variable` calls (double x), so (/ Math/PI 2) must evaluate to
;; a plain number here, not a DualNumber - confirm against the `/` in scope.
(let [angles [(variable 0.0) (variable (/ Math/PI 2))]
sines (vsin angles)]
(println "\nVector Sin Example:")
(println "Result:" (vec sines))
(println "Values:" (vec (get-values sines))) ;; => [0.0 1.0]
(println "Derivs:" (vec (get-derivs sines)))) ;; => approximately [1.0 0.0]; cos(pi/2) comes out as ~6.1e-17 in floating point
;; Gradient Example from Article: f(x) = x^2 + sin(x), applied element-wise
(defn square [x] (* x x))
(def vsquare (vectorize-unary-op square))
(defn f [v]
(vadd (vsquare v) (vsin v)))
(def input-vec [(variable 1.0) (variable 2.0) (variable 3.0)])
(def result-vec (f input-vec))
(def values (vec (get-values result-vec)))
(def gradient (vec (get-derivs result-vec)))
(println "\nGradient Example:")
(println "Input Vec:" (vec input-vec))
(println "Result Values f(v):" values)
(println "Result Gradient f'(v):" gradient)
;; f'(x) = 2x + cos(x)
;; f'(1) approx 2.5403
;; f'(2) approx 3.5839
;; f'(3) approx 5.0100
) ;; End comment block
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment