Created
April 8, 2025 02:18
-
-
Save esshka/f91d2e133aad6ad4728b7f0af5a5fb4a to your computer and use it in GitHub Desktop.
clojure dual nums impl
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(ns dual-math
  (:refer-clojure :exclude [+ - * / < <= > >= zero? abs max min
                            sin cos tan asin acos atan exp log pow sqrt])
  (:require [clojure.core :as core]))
;; == Part 1: Dual Number Definition and Helpers ==
;; A dual number pairs a function value with the derivative carried alongside it:
;;   value - the primal value f(x)
;;   deriv - the derivative f'(x) with respect to the seeded variable
(defrecord DualNumber [value deriv])
;; Predicate: is the argument a DualNumber record?
(defn dual?
  "Returns true when x is an instance of DualNumber, false otherwise."
  [x]
  (instance? DualNumber x))
;; Lift a plain number into a dual number whose derivative is 0.
(defn constant
  "Returns x unchanged when it is already a DualNumber; otherwise wraps
  (double x) in a DualNumber with derivative 0.0."
  [x]
  (if-not (dual? x)
    (DualNumber. (double x) 0.0)
    x))
;; Seed a differentiation variable: value x, derivative 1.
(defn variable
  "Wraps (double x) in a DualNumber whose derivative is 1.0."
  [x]
  (DualNumber. (double x) 1.0))
;; Coercion used by every operation so that plain Clojure numbers and
;; DualNumbers can be mixed freely.
(defn ->dual
  "Returns x as a DualNumber: DualNumbers pass through untouched, anything
  else is lifted with `constant`."
  [x]
  (cond-> x
    (not (dual? x)) constant))
;; Field accessors
(defn value
  "Returns the :value entry of d."
  [d]
  (get d :value))

(defn deriv
  "Returns the :deriv entry of d."
  [d]
  (get d :deriv))
;; == Part 2: Dual-Aware Arithmetic and Comparison ==
;; clojure.lang.Numbers is a Java class of static methods, not a Clojure
;; protocol, so `extend-protocol` cannot hook DualNumber into the core
;; operators. The operators are therefore ordinary functions in this
;; namespace (the ns form excludes the clojure.core names they replace);
;; the originals stay reachable through the `core/` alias.
;; Every function coerces its arguments with `->dual`, so plain numbers and
;; DualNumbers can be mixed, and arithmetic results are always DualNumbers.

(defn +
  "Dual-aware addition. With no arguments returns the constant 0; with one
  argument returns it coerced to a DualNumber.
  Sum Rule: d(u + v) = u' + v'"
  ([] (constant 0))
  ([a] (->dual a))
  ([a b]
   (let [{va :value da :deriv} (->dual a)
         {vb :value db :deriv} (->dual b)]
     (->DualNumber (core/+ va vb) (core/+ da db))))
  ([a b & more] (reduce + (+ a b) more)))

(defn -
  "Dual-aware subtraction; the one-argument form negates.
  d(-u) = -u',  d(u - v) = u' - v'"
  ([a]
   (let [{va :value da :deriv} (->dual a)]
     (->DualNumber (core/- va) (core/- da))))
  ([a b]
   (let [{va :value da :deriv} (->dual a)
         {vb :value db :deriv} (->dual b)]
     (->DualNumber (core/- va vb) (core/- da db))))
  ([a b & more] (reduce - (- a b) more)))

(defn *
  "Dual-aware multiplication. With no arguments returns the constant 1.
  Product Rule: d(uv) = u'v + uv'"
  ([] (constant 1))
  ([a] (->dual a))
  ([a b]
   (let [{va :value da :deriv} (->dual a)
         {vb :value db :deriv} (->dual b)]
     (->DualNumber (core/* va vb)
                   (core/+ (core/* da vb) (core/* va db)))))
  ([a b & more] (reduce * (* a b) more)))

(defn /
  "Dual-aware division; the one-argument form is the reciprocal 1/a.
  Quotient Rule: d(u/v) = (u'v - uv') / v^2
  Throws ArithmeticException when the divisor's value is zero."
  ([a] (/ 1.0 a))
  ([a b]
   (let [{va :value da :deriv} (->dual a)
         {vb :value db :deriv} (->dual b)]
     ;; zero? rather than (= vb 0.0): Clojure's = is false for (= 0 0.0),
     ;; which would let an integer zero divisor slip through.
     (if (core/zero? vb)
       (throw (ArithmeticException. "Division by zero in dual number"))
       (->DualNumber (core// va vb)
                     (core// (core/- (core/* da vb) (core/* va db))
                             (core/* vb vb))))))
  ([a b & more] (reduce / (/ a b) more)))

;; Comparisons: dual numbers are not totally ordered, so these compare the
;; primal values only and ignore the derivatives.
(defn <
  "True when the value of a is strictly less than the value of b."
  [a b]
  (core/< (:value (->dual a)) (:value (->dual b))))

(defn <=
  "True when the value of a is less than or equal to the value of b."
  [a b]
  (core/<= (:value (->dual a)) (:value (->dual b))))

(defn >
  "True when the value of a is strictly greater than the value of b."
  [a b]
  (core/> (:value (->dual a)) (:value (->dual b))))

(defn >=
  "True when the value of a is greater than or equal to the value of b."
  [a b]
  (core/>= (:value (->dual a)) (:value (->dual b))))

(defn zero?
  "True when the value of a is zero (the derivative is not inspected)."
  [a]
  (core/zero? (:value (->dual a))))

(defn abs
  "Dual-aware absolute value: d(abs(u)) = u' * sign(u).
  The derivative is mathematically undefined at 0; signum(0) = 0 makes it 0
  there."
  [a]
  (let [{v :value d :deriv} (->dual a)
        v (double v)]
    (->DualNumber (Math/abs v) (core/* d (Math/signum v)))))

(defn max
  "Returns whichever argument (as a DualNumber) has the larger value, so the
  derivative follows the chosen branch. Ties return the first argument."
  [a b]
  (let [a* (->dual a)
        b* (->dual b)]
    (if (core/>= (:value a*) (:value b*)) a* b*)))

(defn min
  "Returns whichever argument (as a DualNumber) has the smaller value, so the
  derivative follows the chosen branch. Ties return the first argument."
  [a b]
  (let [a* (->dual a)
        b* (->dual b)]
    (if (core/<= (:value a*) (:value b*)) a* b*)))
;; == Part 3: Dual-Aware Math Functions ==
;; java.lang.Math cannot be extended, so dual-aware versions are defined here.
(defn sin
  "Sine of a number or DualNumber.
  Chain rule: d(sin(u)) = u' * cos(u)"
  [a]
  (let [{v :value d :deriv} (->dual a)
        slope (Math/cos v)]
    (->DualNumber (Math/sin v) (core/* d slope))))
(defn cos
  "Cosine of a number or DualNumber.
  Chain rule: d(cos(u)) = u' * -sin(u)"
  [a]
  (let [{v :value d :deriv} (->dual a)
        slope (core/- (Math/sin v))]
    (->DualNumber (Math/cos v) (core/* d slope))))
(defn tan
  "Tangent of a number or DualNumber.
  Chain rule: d(tan(u)) = u' * sec^2(u) = u' / cos^2(u)"
  [a]
  (let [{v :value d :deriv} (->dual a)
        cos-v (Math/cos v)
        cos-sq (core/* cos-v cos-v)]
    (->DualNumber (Math/tan v) (core// d cos-sq))))
(defn asin
  "Arcsine of a number or DualNumber.
  Chain rule: d(asin(u)) = u' / sqrt(1 - u^2)"
  [a]
  (let [{v :value d :deriv} (->dual a)
        denom (Math/sqrt (core/- 1.0 (core/* v v)))]
    (->DualNumber (Math/asin v) (core// d denom))))
(defn acos
  "Arccosine of a number or DualNumber.
  Chain rule: d(acos(u)) = -u' / sqrt(1 - u^2)"
  [a]
  (let [{v :value d :deriv} (->dual a)
        denom (Math/sqrt (core/- 1.0 (core/* v v)))]
    (->DualNumber (Math/acos v) (core// (core/- d) denom))))
(defn atan
  "Arctangent of a number or DualNumber.
  Chain rule: d(atan(u)) = u' / (1 + u^2)"
  [a]
  (let [{v :value d :deriv} (->dual a)
        denom (core/+ 1.0 (core/* v v))]
    (->DualNumber (Math/atan v) (core// d denom))))
(defn exp
  "Exponential e^a of a number or DualNumber.
  Chain rule: d(exp(u)) = u' * exp(u)"
  [a]
  (let [{v :value d :deriv} (->dual a)
        e-to-v (Math/exp v)] ; used for both the value and the derivative
    (->DualNumber e-to-v (core/* d e-to-v))))
(defn log
  "Natural logarithm of a number or DualNumber.
  Chain rule: d(log(u)) = u' / u
  Throws ArithmeticException when the value is zero or negative."
  [a]
  (let [{v :value d :deriv} (->dual a)]
    ;; Guard first: the logarithm is only defined for strictly positive values.
    (when (core/<= v 0.0)
      (throw (ArithmeticException. "Log of non-positive number in dual number")))
    (->DualNumber (Math/log v) (core// d v))))
(defn pow
  "Computes a^b for numbers or DualNumbers, propagating the derivative.

  Positive base - general rule, derived from u^v = exp(v * log(u)):
    d(u^v) = u^v * (v' * log(u) + v * u'/u)

  Zero or negative base - log(u) is undefined, but when the exponent is a
  constant (its derivative is 0) the power rule still applies:
    d(u^n) = n * u^(n-1) * u'
  This branch requires a non-negative exponent for a zero base and an
  integer-valued exponent for a negative base.

  Throws ArithmeticException for any other non-positive base."
  [a b]
  (let [{va :value da :deriv} (->dual a)
        {vb :value db :deriv} (->dual b)]
    (cond
      ;; Written as (not (<= ...)) so a NaN base keeps flowing through the
      ;; general formula instead of being rejected.
      (not (core/<= va 0.0))
      (let [pow-val (Math/pow va vb)]
        (->DualNumber pow-val
                      (core/* pow-val
                              (core/+ (core/* db (Math/log va))
                                      (core/* vb (core// da va))))))

      ;; Constant exponent on a non-positive base: use the power rule.
      (and (core/zero? db)
           (if (core/zero? va)
             (core/>= vb 0.0)
             (core/== vb (Math/rint vb))))
      (->DualNumber (Math/pow va vb)
                    ;; u^0 is the constant 1; handling it separately avoids
                    ;; 0 * Infinity = NaN when the base is 0.
                    (if (core/zero? vb)
                      0.0
                      (core/* vb (Math/pow va (core/- vb 1.0)) da)))

      :else
      (throw (ArithmeticException. "Base must be positive for general dual power")))))
(defn sqrt
  "Square root of a number or DualNumber.
  Chain rule: d(sqrt(u)) = u' / (2 * sqrt(u))
  Throws ArithmeticException when the value is negative."
  [a]
  (let [{v :value d :deriv} (->dual a)]
    (when (core/< v 0.0)
      (throw (ArithmeticException. "Sqrt of negative number in dual number")))
    (let [root (Math/sqrt v)
          denom (core/* 2.0 root)]
      (->DualNumber root (core// d denom)))))
;; == Part 4: Vectorization Functions (from Part 1/2) ==
;; Include these for completeness or assume they are in another namespace
(defn vectorize-binary-op
  "Lifts a two-argument scalar function f to work element-wise on vectors.

  The returned function accepts any mix of scalars and (possibly nested)
  vectors:
    vector, vector - combined pairwise; sizes must match
    vector, scalar - the scalar is broadcast over the vector
    scalar, vector - the scalar is broadcast over the vector
    scalar, scalar - f is applied directly

  Results are built with mapv, so they are vectors and can be fed straight
  back into another vectorized operation (which dispatches on vector?).

  Throws IllegalArgumentException when two vectors differ in length."
  [f]
  (fn vbo [a b]
    ;; Handle potential scalars first for clarity with dual numbers
    (let [a-is-vec (vector? a)
          b-is-vec (vector? b)]
      (cond
        (and a-is-vec b-is-vec)
        (if (= (count a) (count b))
          (mapv vbo a b)
          (throw (IllegalArgumentException. "Vector sizes must match for binary operation")))

        a-is-vec
        (mapv #(vbo % b) a) ; Broadcast b to elements of a

        b-is-vec
        (mapv #(vbo a %) b) ; Broadcast a to elements of b

        :else ; Both are scalars (or DualNumbers treated as scalars)
        (f a b)))))
(defn vectorize-unary-op
  "Lifts a one-argument scalar function f to work element-wise on (possibly
  nested) vectors; a non-vector argument is passed to f directly.

  Results are built with mapv, so they are vectors and can be fed straight
  back into another vectorized operation (which dispatches on vector?)."
  [f]
  (fn vuo [x]
    (if (vector? x)
      (mapv vuo x)
      (f x))))
;; == Part 5: Define Vectorized Operations using Dual-Aware Functions ==

;; Arithmetic (built on this namespace's +, -, *, /)
(def vadd "Element-wise addition."       (vectorize-binary-op +))
(def vsub "Element-wise subtraction."    (vectorize-binary-op -))
(def vmul "Element-wise multiplication." (vectorize-binary-op *))
(def vdiv "Element-wise division."       (vectorize-binary-op /))

;; Trigonometric functions and their inverses
(def vsin  "Element-wise dual sin."  (vectorize-unary-op sin))
(def vcos  "Element-wise dual cos."  (vectorize-unary-op cos))
(def vtan  "Element-wise dual tan."  (vectorize-unary-op tan))
(def vasin "Element-wise dual asin." (vectorize-unary-op asin))
(def vacos "Element-wise dual acos." (vectorize-unary-op acos))
(def vatan "Element-wise dual atan." (vectorize-unary-op atan))

;; Exponential, logarithm, powers and roots
(def vexp  "Element-wise dual exp."  (vectorize-unary-op exp))
(def vlog  "Element-wise dual log."  (vectorize-unary-op log))
(def vpow  "Element-wise dual pow."  (vectorize-binary-op pow))
(def vsqrt "Element-wise dual sqrt." (vectorize-unary-op sqrt))

;; Magnitude and selection
(def vabs "Element-wise abs." (vectorize-unary-op abs))
(def vmax "Element-wise max." (vectorize-binary-op max))
(def vmin "Element-wise min." (vectorize-binary-op min))
;; Helpers to project a collection of dual numbers onto one field
(defn get-values
  "Returns a lazy sequence of the :value of each element of dual-coll."
  [dual-coll]
  (map (fn [d] (:value d)) dual-coll))

(defn get-derivs
  "Returns a lazy sequence of the :deriv of each element of dual-coll."
  [dual-coll]
  (map (fn [d] (:deriv d)) dual-coll))
;; == Part 6: Example Usage (from Article) ==
;; Evaluate the forms below one at a time at a REPL; nothing in this block
;; runs when the namespace is loaded.
(comment
  ;; Basic scalar example
  (let [x (variable 3.0)
        y (* x x)]
    (println "Scalar Example:")
    (println "Value:" (:value y))  ;; => 9.0
    (println "Deriv:" (:deriv y))) ;; => 6.0

  ;; Vector Add Example
  (let [v1 [(variable 1.0) (variable 2.0)]
        v2 [(constant 10.0) (constant 5.0)]
        result (vadd v1 v2)]
    (println "\nVector Add Example:")
    (println "Result:" (vec result)))
  ;; => [#dual_math.DualNumber{:value 11.0, :deriv 1.0} #dual_math.DualNumber{:value 7.0, :deriv 1.0}]

  ;; Scalar Broadcast Example
  (let [v [(constant 1.0) (constant 2.0) (constant 3.0)]
        s (variable 10.0)
        result (vmul s v)]
    (println "\nScalar Broadcast Example:")
    (println "Result:" (vec result)))
  ;; => [#dual_math.DualNumber{:value 10.0, :deriv 1.0} #dual_math.DualNumber{:value 20.0, :deriv 2.0} #dual_math.DualNumber{:value 30.0, :deriv 3.0}]

  ;; Vector Sin Example
  ;; core// is used for the plain-number division: `variable` needs a number,
  ;; not a dual number.
  (let [angles [(variable 0.0) (variable (core// Math/PI 2))]
        sines (vsin angles)]
    (println "\nVector Sin Example:")
    (println "Result:" (vec sines))
    (println "Values:" (vec (get-values sines)))  ;; => [0.0 1.0]
    (println "Derivs:" (vec (get-derivs sines)))) ;; => [1.0 6.123233995736766E-17], i.e. ~[1.0 0.0]

  ;; Gradient Example from Article
  (defn square [x] (* x x))
  (def vsquare (vectorize-unary-op square))

  (defn f [v]
    (vadd (vsquare v) (vsin v)))

  (def input-vec [(variable 1.0) (variable 2.0) (variable 3.0)])
  (def result-vec (f input-vec))
  (def values (vec (get-values result-vec)))
  (def gradient (vec (get-derivs result-vec)))

  (println "\nGradient Example:")
  (println "Input Vec:" (vec input-vec))
  (println "Result Values f(v):" values)
  (println "Result Gradient f'(v):" gradient)
  ;; f'(x) = 2x + cos(x)
  ;; f'(1) approx 2.5403
  ;; f'(2) approx 3.5839
  ;; f'(3) approx 5.0100
  ) ;; End comment block
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment