Created
          December 30, 2020 07:23 
        
      - 
      
 - 
        
Save iwiwi/10fb477eceaff0d36cdacf9a268db780 to your computer and use it in GitHub Desktop.  
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | use crate::*; | |
/// Inverse of the standard normal CDF (percent point function), ported from
/// the Cephes `ndtri` routine as vendored in SciPy 1.5.4.
///
/// The coefficients are copied verbatim from Cephes so they can be compared
/// line-by-line with the reference source; some carry more digits than an
/// `f64` can hold.
#[allow(clippy::excessive_precision)] // constants kept verbatim from Cephes for traceability
mod normal_distribution {
    /// sqrt(2 * pi).
    const S2PI: f64 = 2.50662827463100050242E0;

    /// exp(-2): probabilities in (EXP_NEG2, 1 - EXP_NEG2] use the central
    /// approximation, the rest use one of the tail approximations.
    const EXP_NEG2: f64 = 0.13533528323661269189;

    // https://github.com/scipy/scipy/blob/v1.5.4/scipy/special/cephes/ndtri.c
    /// Central-region numerator coefficients (highest degree first).
    const P0: [f64; 5] = [
        -5.99633501014107895267E1,
        9.80010754185999661536E1,
        -5.66762857469070293439E1,
        1.39312609387279679503E1,
        -1.23916583867381258016E0,
    ];
    /// Central-region denominator coefficients; the leading 1.0 is implied (see `p1evl`).
    const Q0: [f64; 8] = [
        /* 1.00000000000000000000E0, */
        1.95448858338141759834E0,
        4.67627912898881538453E0,
        8.63602421390890590575E1,
        -2.25462687854119370527E2,
        2.00260212380060660359E2,
        -8.20372256168333339912E1,
        1.59056225126211695515E1,
        -1.18331621121330003142E0,
    ];
    /// Tail numerator coefficients, used when sqrt(-2 ln y) < 8.
    const P1: [f64; 9] = [
        4.05544892305962419923E0,
        3.15251094599893866154E1,
        5.71628192246421288162E1,
        4.40805073893200834700E1,
        1.46849561928858024014E1,
        2.18663306850790267539E0,
        -1.40256079171354495875E-1,
        -3.50424626827848203418E-2,
        -8.57456785154685413611E-4,
    ];
    /// Tail denominator coefficients for sqrt(-2 ln y) < 8; leading 1.0 implied.
    const Q1: [f64; 8] = [
        /* 1.00000000000000000000E0, */
        1.57799883256466749731E1,
        4.53907635128879210584E1,
        4.13172038254672030440E1,
        1.50425385692907503408E1,
        2.50464946208309415979E0,
        -1.42182922854787788574E-1,
        -3.80806407691578277194E-2,
        -9.33259480895457427372E-4,
    ];
    /// Far-tail numerator coefficients, used when sqrt(-2 ln y) >= 8.
    const P2: [f64; 9] = [
        3.23774891776946035970E0,
        6.91522889068984211695E0,
        3.93881025292474443415E0,
        1.33303460815807542389E0,
        2.01485389549179081538E-1,
        1.23716634817820021358E-2,
        3.01581553508235416007E-4,
        2.65806974686737550832E-6,
        6.23974539184983293730E-9,
    ];
    /// Far-tail denominator coefficients for sqrt(-2 ln y) >= 8; leading 1.0 implied.
    const Q2: [f64; 8] = [
        /* 1.00000000000000000000E0, */
        6.02427039364742014255E0,
        3.67983563856160859403E0,
        1.37702099489081330271E0,
        2.16236993594496635890E-1,
        1.34204006088543189037E-2,
        3.28014464682127739104E-4,
        2.89247864745380683936E-6,
        6.79019408009981274425E-9,
    ];

    /// Evaluates the polynomial with coefficients `coef` (highest degree
    /// first) at `x` using Horner's rule.
    // https://github.com/scipy/scipy/blob/v1.5.4/scipy/special/cephes/polevl.h#L67
    fn polevl(x: f64, coef: &[f64]) -> f64 {
        coef.iter().fold(0.0, |acc, &c| acc * x + c)
    }

    /// Like `polevl`, but with an implicit leading coefficient of 1.0 that is
    /// not stored in `coef`.
    // https://github.com/scipy/scipy/blob/v1.5.4/scipy/special/cephes/polevl.h#L90
    fn p1evl(x: f64, coef: &[f64]) -> f64 {
        coef.iter().fold(1.0, |acc, &c| acc * x + c)
    }

    /// Returns `x` such that the standard normal CDF at `x` equals `y0`.
    ///
    /// `ppf(0.0)` is `-inf` and `ppf(1.0)` is `+inf`, matching Cephes/SciPy.
    ///
    /// # Panics
    /// Panics if `y0` is NaN or outside `[0, 1]`.
    // https://github.com/scipy/scipy/blob/v1.5.4/scipy/special/cephes/ndtri.c#L134
    pub fn ppf(y0: f64) -> f64 {
        assert!(
            (0.0..=1.0).contains(&y0),
            "ppf: probability must be in [0, 1], got {}",
            y0
        );
        // The tail formula would compute inf - inf/inf = NaN at the endpoints.
        if y0 == 0.0 {
            return f64::NEG_INFINITY;
        }
        if y0 == 1.0 {
            return f64::INFINITY;
        }

        // Fold the upper tail onto the lower one. `negate` records whether a
        // tail result belongs to the lower half (negative quantile).
        let (y, negate) = if y0 > 1.0 - EXP_NEG2 {
            (1.0 - y0, false)
        } else {
            (y0, true)
        };

        // Central region: rational approximation in (y - 0.5)^2.
        if y > EXP_NEG2 {
            let t = y - 0.5;
            let t2 = t * t;
            return S2PI * (t + t * (t2 * polevl(t2, &P0) / p1evl(t2, &Q0)));
        }

        // Tails: asymptotic expansion in z = 1 / sqrt(-2 ln y).
        let x = (-2.0 * y.ln()).sqrt();
        let x0 = x - x.ln() / x;
        let z = 1.0 / x;
        let x1 = if x < 8.0 {
            z * polevl(z, &P1) / p1evl(z, &Q1)
        } else {
            z * polevl(z, &P2) / p1evl(z, &Q2)
        };
        let x = x0 - x1;
        if negate {
            -x
        } else {
            x
        }
    }
}
/// Margin (less one machine epsilon) kept between the interpolated
/// probability and the endpoints 0 and 1 before the inverse normal CDF is
/// applied, so out-of-range inputs map to large but finite values.
const BOUNDS_THRESHOLD: f64 = 1.0e-7;
/// Maps each feature onto the standard normal scale using the quantiles of a
/// fitted scikit-learn `QuantileTransformer` with normal output.
#[derive(Clone, Debug)]
pub struct QuantileTransformer {
    /// Reference probabilities, shared by every feature.
    references: Vec<f64>,
    /// One quantile vector per feature, each aligned index-for-index with `references`.
    quantiles: Vec<Vec<f64>>,
}
| fn transform_col(x: f64, quantiles: &Vec<f64>, references: &Vec<f64>) -> f64 { | |
| let y; | |
| let xlb = quantiles[0]; | |
| let xub = *quantiles.last().unwrap(); | |
| if x <= xlb { | |
| y = 0.0; | |
| } else if x >= xub { | |
| y = 1.0; | |
| } else { | |
| // xの左右を二分探索で探す | |
| let mut ilb = 0; | |
| let mut iub = quantiles.len() - 1; | |
| while iub - ilb > 1 { | |
| let imd = (ilb + iub) / 2; | |
| let qmd = quantiles[imd]; | |
| if qmd < x { | |
| ilb = imd; | |
| } else { | |
| iub = imd; | |
| } | |
| } | |
| assert!(quantiles[ilb] <= x); | |
| assert!(quantiles[iub] >= x); | |
| // 線形補間する | |
| let xlb = quantiles[ilb]; | |
| let xub = quantiles[iub]; | |
| let dlb = x - xlb; | |
| let dub = xub - x; | |
| let wlb = dub / (dlb + dub); | |
| let wub = dlb / (dlb + dub); | |
| dbg!(wlb, wub); | |
| y = references[ilb] * wlb + references[iub] * wub; | |
| } | |
| let y = y.clamp( | |
| BOUNDS_THRESHOLD - f64::EPSILON, | |
| 1.0 - (BOUNDS_THRESHOLD - f64::EPSILON), | |
| ); | |
| dbg!(y); | |
| normal_distribution::ppf(y) | |
| } | |
| #[derive(serde::Deserialize)] | |
| struct Dump { | |
| output_distribution: String, | |
| references_: Vec<f64>, | |
| quantiles_: Vec<Vec<f64>>, | |
| } | |
| impl QuantileTransformer { | |
| pub fn from_dump(dump: serde_json::Value) -> R<QuantileTransformer> { | |
| let dump: Dump = serde_json::from_value(dump)?; | |
| // normalしかサポートしない | |
| assert_eq!(dump.output_distribution, "normal"); | |
| // 転置しといたほうが便利、ってか元のsklearnの実装も転置しといたほうが便利に見えて仕方ないのに何で転置してないんだろ | |
| let n_features = dump.quantiles_[0].len(); | |
| let n_references = dump.references_.len(); | |
| let mut quantiles = vec![vec![0.0; n_references]; n_features]; | |
| for i in 0..n_features { | |
| for j in 0..n_references { | |
| quantiles[i][j] = dump.quantiles_[j][i]; | |
| } | |
| } | |
| // quantilesがユニークじゃない場合は結構変な処理しないといけないが、冷静に俺はそういうの使う予定ないから落とす | |
| for qs in quantiles.iter_mut() { | |
| qs.dedup(); | |
| assert_eq!(qs.len(), n_references); | |
| } | |
| Ok(QuantileTransformer { | |
| references: dump.references_, | |
| quantiles, | |
| }) | |
| } | |
| pub fn transform(&self, x: &[f64]) -> Vec<f64> { | |
| assert_eq!(x.len(), self.quantiles.len()); | |
| x.iter() | |
| .zip(self.quantiles.iter()) | |
| .map(|(x, quantiles)| transform_col(*x, quantiles, &self.references)) | |
| .collect() | |
| } | |
| } | |
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a transformer from a dump of a fitted scikit-learn
    /// `QuantileTransformer` (10 quantiles, 3 features, normal output).
    fn create() -> QuantileTransformer {
        let dump = serde_json::json!({
            "n_quantiles": 10,
            "output_distribution": "normal",
            "ignore_implicit_zeros": false,
            "subsample": 100000,
            "random_state": null,
            "copy": true,
            "n_features_in_": 3,
            "n_quantiles_": 10,
            "references_": [
                0.0,
                0.1111111111111111,
                0.2222222222222222,
                0.3333333333333333,
                0.4444444444444444,
                0.5555555555555556,
                0.6666666666666666,
                0.7777777777777777,
                0.8888888888888888,
                1.0,
            ],
            "quantiles_": [
                [-2.613328303719778, -9.321205290168217, 0.13674275875021852],
                [-1.1414101426254262, -3.294752914303114, 0.31128319201392723],
                [-0.7306000262128471, -1.919245430597138, 0.5628969761610223],
                [-0.427117826339735, -0.7238972474428322, 0.7805069173920551],
                [-0.05639286493061981, 0.885189368484695, 0.9814794091618718],
                [0.1722755897532091, 1.991170360919146, 1.2760902770308562],
                [0.39748654705760395, 3.8665254431411706, 1.6340155327513783],
                [0.67348557453346, 5.3665644352138, 1.959333216365857],
                [0.9437791586194102, 7.205771950040178, 2.8721836987270755],
                [2.4620269142769113, 12.26632790171583, 17.47969633690788],
            ],
        });
        QuantileTransformer::from_dump(dump).unwrap()
    }

    /// Asserts that transforming `x` gives `expected`, element by element,
    /// within the tolerance of `assert_approx_eq!`.
    fn check(qt: &QuantileTransformer, x: &[f64], expected: &[f64]) {
        let actual = qt.transform(x);
        assert_eq!(expected.len(), actual.len());
        for (want, got) in expected.iter().zip(actual.iter()) {
            assert_approx_eq!(want, got);
        }
    }

    /// Each quantile row of the dump must map to the same value in every
    /// column, because every feature shares the same reference probabilities.
    #[test]
    fn test_references() {
        let qt = create();
        let cases = &[
            ([-2.613328303719778, -9.321205290168217, 0.13674275875021852], [-5.199337582605575, -5.199337582605575, -5.199337582605575]),
            ([-1.1414101426254262, -3.294752914303114, 0.31128319201392723], [-1.22064034884735, -1.22064034884735, -1.22064034884735]),
            ([-0.7306000262128471, -1.919245430597138, 0.5628969761610223], [-0.764709673786387, -0.764709673786387, -0.764709673786387]),
            ([-0.427117826339735, -0.7238972474428322, 0.7805069173920551], [-0.430727299295457, -0.430727299295457, -0.430727299295457]),
            ([-0.05639286493061981, 0.885189368484695, 0.9814794091618718], [-0.139710298881862, -0.139710298881862, -0.139710298881862]),
            ([0.1722755897532091, 1.991170360919146, 1.2760902770308562], [0.1397102988818621, 0.1397102988818621, 0.1397102988818621]),
            ([0.39748654705760395, 3.8665254431411706, 1.6340155327513783], [0.4307272992954574, 0.4307272992954574, 0.4307272992954574]),
            ([0.67348557453346, 5.3665644352138, 1.959333216365857], [0.7647096737863867, 0.7647096737863867, 0.7647096737863867]),
            ([0.9437791586194102, 7.205771950040178, 2.8721836987270755], [1.2206403488473496, 1.2206403488473496, 1.2206403488473496]),
            ([2.4620269142769113, 12.26632790171583, 17.47969633690788], [5.19933758270342, 5.19933758270342, 5.19933758270342]),
        ];
        for (x, expected) in cases.iter() {
            check(&qt, x, expected);
        }
    }

    /// Arbitrary inputs, including values outside the fitted range.
    /// NOTE(review): the expected values presumably come from the Python
    /// implementation; confirm.
    #[test]
    fn test_random() {
        let qt = create();
        let cases = &[
            ([9.543169032696227, 13.004265458495848, 1.2817188786196336], [5.19933758270342, 5.19933758270342, 0.14413444750289997]),
            ([6.148524766915234, 12.277375912423363, -2.0032741621514276], [5.19933758270342, 5.19933758270342, -5.199337582605575]),
            ([3.2479687586644044, 6.333366933638192, -0.18854883910825748], [5.19933758270342, 0.9788976648571445, -5.199337582605575]),
            ([12.026790240804463, -8.17807192753417, 7.139275629961631], [5.19933758270342, -2.0320120702704356, 1.414185316101896]),
            ([11.297395996011659, 12.389898901401988, 4.3528110754416325], [5.19933758270342, 5.19933758270342, 1.2824135105950227]),
            ([-5.548440019781655, 3.9141903012544503, 8.462175829374196], [-5.199337582605575, 0.44045804894271895, 1.4863658158773896]),
            ([0.48259293083427224, 1.8454399739866432, -8.433798779937739], [0.5270731855886761, 0.10273894814593508, -5.199337582605575]),
            ([2.9464268070547117, 12.73414827971353, 5.139387737583421], [5.19933758270342, 5.19933758270342, 1.3173195432982814]),
            ([-3.4636253585345047, 13.086601949163505, 2.851517222764288], [-5.199337582605575, 5.19933758270342, 1.207464726166755]),
            ([-6.081831945801352, -6.336175884405989, -7.420982125992383], [-5.199337582605575, -1.5978724353638714, -5.199337582605575]),
            ([-6.275065630029899, 7.586341240789082, 1.7097115253806585], [-5.199337582605575, 1.266007404810665, 0.5030071199147996]),
            ([-0.13038122929708784, -4.161259056492385, 5.808837900045351], [-0.1960917782971291, -1.3097800036623548, 1.3483455214298523]),
            ([7.618117371577743, 5.909956441138961, 10.508511337022796], [5.19933758270342, 0.8801295864809635, 1.6161969123723285]),
            ([8.693866471169883, -0.7330487877831544, 9.163553125519751], [5.19933758270342, -0.4330680381712861, 1.5280004611048847]),
            ([0.4282314588607754, -1.1967112585064736, -6.889187183469036], [0.46502685937014465, -0.5551855155241495, -5.199337582605575]),
            ([7.685794675525781, 4.500672464040191, 14.250996390962012], [5.19933758270342, 0.5640480936891384, 1.967567697630419]),
            ([1.5798977254115556, -9.611474340719736, -6.600078504292563], [1.5176006671813367, -5.199337582605575, -5.199337582605575]),
            ([0.17534066429800532, -3.5008980831264935, 12.394201107918992], [0.1435390280777139, -1.2409594504784927, 1.7661839743843688]),
            ([3.5825769280006945, -1.551988325695696, -6.6376732992024206], [5.19933758270342, -0.6546087494691977, -5.199337582605575]),
            ([-4.10776763251768, -9.18756794697149, 9.369452587399543], [-5.199337582605575, -2.811715563246142, 1.5407399020190642]),
        ];
        for (x, expected) in cases.iter() {
            check(&qt, x, expected);
        }
    }
}
  
    Sign up for free
    to join this conversation on GitHub.
    Already have an account?
    Sign in to comment