Created December 23, 2023 06:10 (gist gelisam/e1e7e312d3b5e5d16496c4135b0c885f)
Tracking whether the combination of two functions is still strictly-monotonic

-- In response to https://twitter.com/kmett/status/1738168271357026634
--
-- The challenge is to implement a version of
--
-- > mapKeys :: Ord k2 => (k1 -> k2) -> Map k1 a -> Map k2 a
--
-- which costs O(1) if the (k1 -> k2) function is coerce and the coercion
-- preserves the ordering, O(n) if the function is injective, and O(n log n)
-- otherwise. Obviously, the implementation can't inspect a pure function in
-- order to determine which case it is, so our version will need to use a
-- different type.
--
-- The challenge asks to do this "without making your user cry or accidentally
-- violate invariants".
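
-- Sketch, not part of the gist: the existing Data.Map operations behind the
-- O(n log n) and O(n) cases, and how the O(n) one silently breaks the Map
-- invariant when the key function is not strictly monotone. 'exampleMap' and
-- the 'via*' names are hypothetical; 'Map.valid' is containers' own
-- structural sanity check.

```haskell
-- Sketch: mapKeys vs. mapKeysMonotonic, checked with Map.valid.
import qualified Data.Map as Map
import Data.Ord (Down (Down))

exampleMap :: Map.Map String Int
exampleMap = Map.fromList [("a", 1), ("b", 2)]

-- O(n log n): always correct, because the keys are re-sorted.
viaMapKeys :: Map.Map (Down String) Int
viaMapKeys = Map.mapKeys Down exampleMap

-- O(n): correct here, because prepending a fixed prefix is strictly monotone.
viaMonotonic :: Map.Map String Int
viaMonotonic = Map.mapKeysMonotonic ("./" ++) exampleMap

-- O(n) misused: Down reverses the order, so the tree is no longer sorted.
viaMonotonicWrong :: Map.Map (Down String) Int
viaMonotonicWrong = Map.mapKeysMonotonic Down exampleMap

main :: IO ()
main = do
  print (Map.valid viaMapKeys)         -- True
  print (Map.valid viaMonotonic)       -- True
  print (Map.valid viaMonotonicWrong)  -- False
```

-- No error is raised in the last case: the corrupt Map is returned as if
-- nothing happened, which is the "accidentally violate invariants" risk.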

{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
module Main where

import Prelude hiding (id, (.))

import Control.Category (Category(..))
import Data.Coerce (Coercible, coerce)
import Data.Functor.Identity (Identity(Identity))
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Ord (Down(Down))
import Unsafe.Coerce (unsafeCoerce)

-- Haskell's type system is great, but you'll get into trouble if you try to
-- track too much at the type level. Suppose we tracked whether a function is
-- strictly-monotone at the type level, and then used a typeclass to pick which
-- 'mapKeys' implementation to use. Simple examples would probably work fine,
-- but higher-order functions would require more complex types, which might
-- fail the "don't make the user cry" criterion.
--
-- The secret to this challenge is thus: let's track this at the value level!
--
-- The simplest solution would be to define a sum type for the three cases and
-- to let the user specify at the value level which case they want:
--
-- > data Fun a b where
-- >   Coerce
-- >     :: Coercible a b
-- >     => Fun a b
-- >   StrictlyMonotone
-- >     :: (a -> b)
-- >     -> Fun a b
-- >   NotMonotone
-- >     :: (a -> b)
-- >     -> Fun a b
-- >
-- > fancyMapKeys (Coerce :: Fun String (Identity String)) exampleMap1
-- > fancyMapKeys (StrictlyMonotone (\x -> [x])) exampleMap1
-- > fancyMapKeys (NotMonotone Down) exampleMap1
--
-- But this might fail the "don't accidentally violate invariants" criterion,
-- as the user has to manually label their function at every single call site
-- and is bound to make a mistake sooner or later.
--
-- The second secret to this challenge is thus: define a combinator library!
-- The library defines a bunch of primitives which are already
-- correctly-labeled, plus a bunch of combinators which correctly update the
-- label. When prototyping, it is easy for the user to use 'arr' to get a
-- correct program which might not be as efficient as possible. When the user
-- needs the extra performance, they can use the primitives to define a more
-- efficient version. And if the primitives are insufficient, they can use
-- 'unsafeStrictlyMonotone', and it is only then that the user needs to think
-- about whether that label is correct.
--
-- In the implementation below, I split 'Fun' into 'CFun' and 'MFun', because
-- the decisions regarding whether coerce can be used are orthogonal to the
-- decisions regarding whether a function is strictly-monotone. This in turn
-- leads to some typeclasses, just to reuse the same combinator name for (->),
-- 'CFun', and 'MFun'. Those typeclasses are entirely unnecessary, as my
-- solution is at the value level, not the type level.

data CFun a b where
  Coerce :: Coercible a b => CFun a b
  NotCoerce :: (a -> b) -> CFun a b

runCFun :: CFun a b -> a -> b
runCFun Coerce = coerce
runCFun (NotCoerce f) = f

-- | In a real library, these data constructors would be exported from a
-- ".Internal" module, so that dedicated users can define their own combinators.
data MFun k a b
  = -- | if x < y then f x < f y
    StrictlyMonotone (k a b)
  | NotMonotone (k a b)

runMFun :: MFun k a b -> k a b
runMFun (StrictlyMonotone f) = f
runMFun (NotMonotone f) = f

-- | The caller promises that the function really is strictly-monotone.
unsafeStrictlyMonotone
  :: k a b
  -> MFun k a b
unsafeStrictlyMonotone = StrictlyMonotone

type MCFun = MFun CFun

runMCFun :: MCFun a b -> a -> b
runMCFun = runCFun . runMFun

arr :: (a -> b) -> MCFun a b
arr = NotMonotone . NotCoerce

fancyMapKeys
  :: Ord k2
  => MCFun k1 k2
  -> Map k1 a
  -> Map k2 a
fancyMapKeys (StrictlyMonotone Coerce)
  = -- O(1) case.
    -- Coerce brings a Coercible instance in scope, so we could call coerce,
    -- but that intentionally doesn't type-check thanks to k1's nominal role,
    -- which prevents e.g.
    -- > coerce :: Map k1 a -> Map (Down k1) a
    -- from breaking the invariant. In this case, we know that the invariant
    -- will not be broken because coerce is strictly-monotone, but the type
    -- system doesn't know that, so we have to use unsafeCoerce.
    unsafeCoerce
fancyMapKeys (StrictlyMonotone (NotCoerce f))
  = -- O(n) case.
    Map.mapKeysMonotonic f
fancyMapKeys (NotMonotone f)
  = -- O(n log n) case.
    Map.mapKeys (runCFun f)
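
-- Sketch, not part of the gist: what the nominal role on Map's key protects
-- against. Reusing a Map String Int's tree at a new key type is only sound
-- if the new key type's Ord instance agrees with the old one. 'exampleMap',
-- 'asIdentity' and 'asDown' are hypothetical names.

```haskell
-- Sketch: reinterpreting a Map's keys with unsafeCoerce, checked with Map.valid.
import qualified Data.Map as Map
import Data.Functor.Identity (Identity)
import Data.Ord (Down)
import Unsafe.Coerce (unsafeCoerce)

exampleMap :: Map.Map String Int
exampleMap = Map.fromList [("a", 1), ("b", 2)]

-- Identity has the same representation and the same ordering as String,
-- which is what the O(1) case of fancyMapKeys relies on.
asIdentity :: Map.Map (Identity String) Int
asIdentity = unsafeCoerce exampleMap

-- Down has the same representation but the opposite ordering, which is
-- exactly the kind of coercion the nominal role rules out.
asDown :: Map.Map (Down String) Int
asDown = unsafeCoerce exampleMap

main :: IO ()
main = do
  print (Map.valid asIdentity)  -- True
  print (Map.valid asDown)      -- False
```

-- Both coercions are equally cheap; only the ordering label tells them apart.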

-- That's it, that's the core of the library, and it was trivial! The rest
-- defines a bunch of combinators which carefully update the labels.
-- A lot more could be defined, I am sure.

instance Category CFun where
  id = Coerce
  Coerce . Coerce
    = Coerce
  f . g
    = NotCoerce (runCFun f . runCFun g)

instance Category k => Category (MFun k) where
  id = StrictlyMonotone id
  StrictlyMonotone f . StrictlyMonotone g
    = -- if x < y
      -- then g x < g y
      -- then f (g x) < f (g y)
      StrictlyMonotone (f . g)
  f . g
    = NotMonotone (runMFun f . runMFun g)
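
-- Sketch, not part of the gist: if we forget the functions and keep only the
-- two labels, both 'Category' instances above behave like a pointwise (&&)
-- on (is it Coerce?, is it StrictlyMonotone?). 'Label', 'composeLabel' and
-- 'complexity' are hypothetical names for this model.

```haskell
-- Sketch: the label bookkeeping of CFun/MFun composition, without functions.
data Label = Label
  { isCoerce :: Bool    -- the CFun is 'Coerce'
  , isMonotone :: Bool  -- the MFun is 'StrictlyMonotone'
  }
  deriving (Eq, Show)

-- 'id' carries both labels; 'arr' carries neither.
idLabel, arrLabel :: Label
idLabel = Label True True
arrLabel = Label False False

-- A label survives composition only if both sides carry it.
composeLabel :: Label -> Label -> Label
composeLabel (Label c1 m1) (Label c2 m2) = Label (c1 && c2) (m1 && m2)

-- Same case split as 'printComplexity'.
complexity :: Label -> String
complexity (Label True True) = "O(1)"
complexity (Label _ True) = "O(n)"
complexity (Label _ False) = "O(n log n)"

main :: IO ()
main = do
  let wrapIdentity = Label True True  -- Coerce, promised strictly-monotone
      singleton = Label False True    -- NotCoerce, promised strictly-monotone
  putStrLn (complexity (idLabel `composeLabel` wrapIdentity `composeLabel` idLabel))
  putStrLn (complexity (singleton `composeLabel` wrapIdentity))
  putStrLn (complexity (arrLabel `composeLabel` singleton))
```

-- This prints O(1), O(n) and O(n log n): one 'arr' anywhere in a pipeline is
-- enough to fall back to the O(n log n) path.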

class Product k where
  (***) :: k a1 b1 -> k a2 b2 -> k (a1,a2) (b1,b2)

instance Product (->) where
  (f1 *** f2) (x1,x2) = (f1 x1, f2 x2)

instance Product CFun where
  Coerce *** Coerce
    = Coerce
  f1 *** f2
    = NotCoerce $ \(x1,x2)
        -> (runCFun f1 x1, runCFun f2 x2)

instance Product k => Product (MFun k) where
  StrictlyMonotone f1 *** StrictlyMonotone f2
    = -- if (x1,x2) < (y1,y2)
      -- case 1: x1 < y1
      --   then f1 x1 < f1 y1
      --   then (f1 x1, _) < (f1 y1, _)
      --   then (f1 x1, f2 x2) < (f1 y1, f2 y2)
      -- case 2: x1 == y1 and x2 < y2
      --   then f1 x1 == f1 y1 && f2 x2 < f2 y2
      --   then (f1 x1, f2 x2) < (f1 y1, f2 y2)
      StrictlyMonotone (f1 *** f2)
  f1 *** f2
    = NotMonotone (runMFun f1 *** runMFun f2)

dup :: MCFun a (a,a)
dup
  = -- if x < y
    -- then (x,_) < (y,_)
    -- then (x,x) < (y,y)
    StrictlyMonotone
  $ NotCoerce $ \x
      -> (x,x)

(&&&) :: MCFun a b1 -> MCFun a b2 -> MCFun a (b1,b2)
f1 &&& f2 = (f1 *** f2) . dup
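
-- Sketch, not part of the gist: a brute-force check, on a small finite
-- domain, of the proofs for (***) and 'dup'. It is evidence, not a proof.
-- 'isStrictlyMonotoneOn' and 'cross' are hypothetical helpers; 'cross' is
-- (***) specialised to plain functions.

```haskell
-- Sketch: empirically checking strict monotonicity of products and dup.
isStrictlyMonotoneOn :: (Ord a, Ord b) => [a] -> (a -> b) -> Bool
isStrictlyMonotoneOn xs f = and [f x < f y | x <- xs, y <- xs, x < y]

cross :: (a1 -> b1) -> (a2 -> b2) -> (a1, a2) -> (b1, b2)
cross f1 f2 (x1, x2) = (f1 x1, f2 x2)

smallInts :: [Int]
smallInts = [-2 .. 2]

pairs :: [(Int, Int)]
pairs = [(x1, x2) | x1 <- smallInts, x2 <- smallInts]

main :: IO ()
main = do
  -- both components strictly monotone: the product is too
  print (isStrictlyMonotoneOn pairs (cross (* 2) (+ 1)))   -- True
  -- (0,0) < (0,1), but their images are (0,0) and (0,-1)
  print (isStrictlyMonotoneOn pairs (cross (* 2) negate))  -- False
  -- dup
  print (isStrictlyMonotoneOn smallInts (\x -> (x, x)))    -- True
```

-- The second check is why the (MFun k) instance only keeps the label when
-- both sides are 'StrictlyMonotone'.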

class Sum k where
  (+++) :: k a1 b1 -> k a2 b2 -> k (Either a1 a2) (Either b1 b2)

instance Sum (->) where
  f1 +++ f2 = \case
    Left x1
      -> Left (f1 x1)
    Right x2
      -> Right (f2 x2)

instance Sum CFun where
  Coerce +++ Coerce
    = Coerce
  f1 +++ f2
    = NotCoerce $ \case
        Left x1 -> Left (runCFun f1 x1)
        Right x2 -> Right (runCFun f2 x2)

instance Sum k => Sum (MFun k) where
  StrictlyMonotone f1 +++ StrictlyMonotone f2
    = -- if x < y
      -- case 1: x = Left x1 && y = Left y1 && x1 < y1
      --   then f1 x1 < f1 y1
      --   then Left (f1 x1) < Left (f1 y1)
      --   then (f1 +++ f2) x < (f1 +++ f2) y
      -- case 2: x = Left x1 && y = Right y2
      --   then Left _ < Right _
      --   then Left (f1 x1) < Right (f2 y2)
      --   then (f1 +++ f2) x < (f1 +++ f2) y
      -- case 3: x = Right x2 && y = Right y2 && x2 < y2
      --   then f2 x2 < f2 y2
      --   then Right (f2 x2) < Right (f2 y2)
      --   then (f1 +++ f2) x < (f1 +++ f2) y
      -- (x = Right _ && y = Left _ cannot happen, since Left _ < Right _.)
      StrictlyMonotone (f1 +++ f2)
  f1 +++ f2
    = NotMonotone (runMFun f1 +++ runMFun f2)
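
-- Sketch, not part of the gist: the same kind of finite check for (+++),
-- which relies on the derived Ord instance for Either putting every Left
-- before every Right. 'isStrictlyMonotoneOn' and 'plus' are hypothetical
-- helpers; 'plus' is (+++) specialised to plain functions.

```haskell
-- Sketch: empirically checking strict monotonicity of sums.
isStrictlyMonotoneOn :: (Ord a, Ord b) => [a] -> (a -> b) -> Bool
isStrictlyMonotoneOn xs f = and [f x < f y | x <- xs, y <- xs, x < y]

plus :: (a1 -> b1) -> (a2 -> b2) -> Either a1 a2 -> Either b1 b2
plus f1 f2 = either (Left . f1) (Right . f2)

eithers :: [Either Int Char]
eithers = map Left [-2 .. 2] ++ map Right "abc"

main :: IO ()
main = do
  -- case 2 of the proof: any Left is below any Right
  print (Left 5 < (Right 'a' :: Either Int Char))          -- True
  print (isStrictlyMonotoneOn eithers (plus (* 3) succ))   -- True
  -- one non-monotone branch is enough to lose the property
  print (isStrictlyMonotoneOn eithers (plus negate succ))  -- False
```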

class MapList k where
  mapList :: k a b -> k [a] [b]

instance MapList (->) where
  mapList = map

instance MapList CFun where
  mapList Coerce
    = Coerce
  mapList (NotCoerce f)
    = NotCoerce (map f)

instance MapList k => MapList (MFun k) where
  mapList (StrictlyMonotone f)
    = -- if xs < ys
      -- case 1: xs = [] && ys = y:ys'
      --   then [] < _:_
      --   then map f [] < _:_
      --   then map f xs < f y : map f ys'
      --   then map f xs < map f ys
      -- case 2: xs = x:xs' && ys = y:ys' && x < y
      --   then f x < f y
      --   then f x : _ < f y : _
      --   then f x : map f xs' < f y : map f ys'
      --   then map f xs < map f ys
      -- case 3: xs = x:xs' && ys = y:ys' && x == y && xs' < ys'
      --   then f x == f y && by induction, map f xs' < map f ys'
      --   then f x : map f xs' < f y : map f ys'
      --   then map f xs < map f ys
      StrictlyMonotone (mapList f)
  mapList (NotMonotone f)
    = NotMonotone (mapList f)
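
-- Sketch, not part of the gist: checking the 'mapList' proof on every list of
-- length at most 2 over a small alphabet. The second check shows why the
-- label has to be *strict* monotonicity: (`div` 2) never decreases its input,
-- but it collapses 0 and 1, so mapping it over keys can merge [0] and [1].
-- 'isStrictlyMonotoneOn' and 'shortLists' are hypothetical helpers.

```haskell
-- Sketch: empirically checking that map preserves strict monotonicity.
isStrictlyMonotoneOn :: (Ord a, Ord b) => [a] -> (a -> b) -> Bool
isStrictlyMonotoneOn xs f = and [f x < f y | x <- xs, y <- xs, x < y]

shortLists :: [[Int]]
shortLists = [] : [[x] | x <- [0 .. 2]] ++ [[x, y] | x <- [0 .. 2], y <- [0 .. 2]]

main :: IO ()
main = do
  print (isStrictlyMonotoneOn shortLists (map (* 2)))      -- True
  print (isStrictlyMonotoneOn shortLists (map (`div` 2)))  -- False
```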

-- Finally, let's write some tests. From now on we refrain from using the data
-- constructors from the ".Internal" module, and we imagine that it is the user
-- who is writing the code below using the library above. Thus, the user spends
-- their cognitive budget on making sure that 'wrapIdentity', 'singleton', and
-- 'addPrefix' really are strictly-monotone, and then they reap the benefits
-- below, in 'main', where they can compose those functions in a bunch of
-- different ways without having to think about monotonicity anymore.
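
-- Sketch, not part of the gist: one cheap way for the user to spend that
-- cognitive budget is to test each 'unsafeStrictlyMonotone' claim on sample
-- keys before trusting it. Sampling can catch a wrong label but cannot prove
-- a right one. 'isStrictlyMonotoneOn' and 'sampleKeys' are hypothetical.

```haskell
-- Sketch: sanity-checking hand-written strict-monotonicity claims.
import Data.Functor.Identity (Identity (Identity))
import Data.Ord (Down (Down))

isStrictlyMonotoneOn :: (Ord a, Ord b) => [a] -> (a -> b) -> Bool
isStrictlyMonotoneOn xs f = and [f x < f y | x <- xs, y <- xs, x < y]

sampleKeys :: [String]
sampleKeys = ["", "a", "a ", "ab", "b"]

main :: IO ()
main = do
  print (isStrictlyMonotoneOn sampleKeys Identity)   -- True  (wrapIdentity)
  print (isStrictlyMonotoneOn sampleKeys (: []))     -- True  (singleton)
  print (isStrictlyMonotoneOn sampleKeys ("./" ++))  -- True  (addPrefix "./")
  -- appending a suffix looks harmless but is not: "a" < "a " yet "a!" > "a !"
  print (isStrictlyMonotoneOn sampleKeys (++ "!"))   -- False
  print (isStrictlyMonotoneOn sampleKeys Down)       -- False (hence 'arr Down')
```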

exampleMap1 :: Map String Int
exampleMap1 = Map.fromList [("a",1),("b",2)]

wrapIdentity :: MCFun a (Identity a)
wrapIdentity
  = unsafeStrictlyMonotone
  $ Coerce

singleton :: MCFun a [a]
singleton
  = unsafeStrictlyMonotone
  $ NotCoerce (:[])

addPrefix :: String -> MCFun String String
addPrefix prefix
  = unsafeStrictlyMonotone
  $ NotCoerce (prefix ++)

down :: MCFun a (Down a)
down = arr Down

printComplexity
  :: MCFun a b
  -> IO ()
printComplexity (StrictlyMonotone Coerce) = do
  putStrLn "O(1)"
printComplexity (StrictlyMonotone (NotCoerce _)) = do
  putStrLn "O(n)"
printComplexity (NotMonotone _) = do
  putStrLn "O(n log n)"

test
  :: (Eq a, Ord b, Show b)
  => MCFun a b
  -> Map a Int
  -> IO ()
test f input = do
  let expected = Map.mapKeys (runMCFun f) input
  let actual = fancyMapKeys f input
  if expected == actual
    then do
      printComplexity f
    else do
      putStrLn "expected:"
      print expected
      putStrLn "actual:"
      print actual

main :: IO ()
main = do
  test id exampleMap1                                     -- O(1)
  test wrapIdentity exampleMap1                           -- O(1)
  test singleton exampleMap1                              -- O(n)
  test (addPrefix "./") exampleMap1                       -- O(n)
  test down exampleMap1                                   -- O(n log n)
  test (wrapIdentity &&& wrapIdentity) exampleMap1        -- O(n)
  test (id . wrapIdentity . id) exampleMap1               -- O(1)
  test (addPrefix "../" . addPrefix "../") exampleMap1    -- O(n)
  test (mapList id) exampleMap1                           -- O(1)
  test (mapList wrapIdentity) exampleMap1                 -- O(1)
  test (arr (map (\c -> "./" ++ [c]))) exampleMap1        -- when prototyping: O(n log n)
  test (mapList (addPrefix "./" . singleton)) exampleMap1 -- O(n)
  test (mapList down) exampleMap1                         -- O(n log n)