Created
October 9, 2013 20:15
-
-
Save estk/6907605 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import System.Random | |
import Data.List | |
import Graphics.EasyPlot | |
-- Testing Apparatus
--
-- | Number of independent learn-and-test trials run for each training-set size.
repeatTimes :: Int
repeatTimes = 1000

-- | Number of instances set aside as the test set in every trial.
testInstanceNum :: Int
testInstanceNum = 1000

-- | Training-set sizes to evaluate: 100, 200, ..., 1000.
trainInstanceNums :: [Int]
trainInstanceNums = [100,200..1000]
-- | Run the learning experiment for every size in 'trainInstanceNums',
-- print the results, and write two LaTeX plots.
--
-- For each training-set size the trials yield a list of error rates; the
-- program reports their mean ("error rate") and the fraction of trials
-- whose error rate exceeds 0.05 ("failure rate").
--
-- Returns 'True' only when both plots report success.
main :: IO Bool
main = do
  gen <- newStdGen
  -- Infinite stream of labelled instances built from random integers.
  let d = makeData $ randomRs (0,1000) gen
      -- Run the trials once per training-set size and share the outcome
      -- between both statistics instead of recomputing it for each.
      rates = map (getErrRates d) trainInstanceNums
      sizes = map fromIntegral trainInstanceNums
      errRes = zip sizes (map average rates)
      failRes = zip sizes (map failureRate rates)
  print errRes
  print failRes
  -- Keep both plot results so a failure of the first plot is not lost.
  errOk <- plot (Latex "errorRate.tex") $ Data2D [Title "Error Rate per Training Set Size", Color Blue, Style Linespoints] [] errRes
  failOk <- plot (Latex "failureRate.tex") $ Data2D [Title "Failure Rate per Training Set Size", Color Blue, Style Linespoints] [] failRes
  return (errOk && failOk)
  where
    -- Arithmetic mean of a list of error rates.
    average lst = sum lst / fromIntegral (length lst)
    -- Share of the 'repeatTimes' trials that count as failures.
    failureRate lst = fromIntegral (length (failures lst)) / fromIntegral repeatTimes
    -- A trial fails when its error rate is above 5%.
    failures lst = filter (> 0.05) lst
-- | Error rates of 'repeatTimes' trials run by 'nTimes' over the given
-- instance stream, each trial using the given training-set size.
getErrRates :: [Instance] -> Int -> [ErrorRate]
getErrRates = flip (nTimes repeatTimes)
-- | Run @n@ trials over consecutive, non-overlapping chunks of @d@.
--
-- Each trial takes the next @setSize + 'testInstanceNum'@ instances from
-- @d@, hands them to 'learn' together with @setSize@, and contributes one
-- error rate to the result.  A non-positive @n@ yields the empty list.
nTimes :: Int -> Int -> [Instance] -> [ErrorRate]
nTimes n setSize d
  -- @<=@ rather than @==@ so a negative count terminates immediately
  -- instead of recursing without end.
  | n <= 0 = []
  | otherwise = learn current setSize : nTimes (n - 1) setSize rest
  where
    -- Instances consumed by a single trial: training set plus test set.
    payLoadSize = setSize + testInstanceNum
    (current, rest) = splitAt payLoadSize d
-- Learning Logic
-- | Fraction of test instances that were classified incorrectly.
type ErrorRate = Float

-- | A sample value paired with its true label.
type Instance = (Int, Bool)

-- | Interval @(lower, upper)@; values inside it, bounds included, are
-- predicted positive.
type Hypothesis = (Int, Int)
-- | Train on part of @d@ and report the error rate on the rest.
--
-- The first 'testInstanceNum' instances of @d@ form the test set and the
-- following @n@ instances form the training set.  The result is the number
-- of misclassified test instances divided by 'testInstanceNum'.
learn :: [Instance] -> Int -> ErrorRate
learn d n =
  let (heldOut, training) = splitAt testInstanceNum (take (testInstanceNum + n) d)
      mistakes = test (findHypothesis training) heldOut
  in fromIntegral mistakes / fromIntegral testInstanceNum
-- | Count the instances whose true label differs from the label the
-- hypothesis predicts for them.
test :: Hypothesis -> [Instance] -> Int
test h = length . incorrect . predict h
-- | Keep the instances whose true label disagrees with the predicted
-- label they are paired with.
incorrect :: [(Instance, Bool)] -> [Instance]
incorrect lst =
  [inst | (inst@(_, actual), predicted) <- lst, actual /= predicted]
-- | Pair every instance with the label the hypothesis predicts for it.
--
-- An instance is predicted positive when its value lies within the
-- hypothesis interval, both bounds included.
predict :: Hypothesis -> [Instance] -> [(Instance, Bool)]
predict (lo, hi) d = [(inst, lo <= x && x <= hi) | inst@(x, _) <- d]
-- | Fit the tightest interval that covers every positively labelled value.
--
-- Returns @(minimum, maximum)@ of the values labelled 'True'.  When the
-- input has no positive instance the sentinel @(-1, -1)@ is returned.
findHypothesis :: [Instance] -> Hypothesis
findHypothesis lst =
  -- Match on the list shape instead of measuring its length just to test
  -- for emptiness.
  case [x | (x, True) <- lst] of
    [] -> (-1, -1)
    pos -> (minimum pos, maximum pos)
-- | Label each value: 'True' when it lies strictly between 400 and 600,
-- 'False' otherwise.
makeData :: [Int] -> [Instance]
makeData lst = [(x, x > 400 && x < 600) | x <- lst]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment