Created
October 5, 2021 11:03
-
-
Save tecno14/7fe0c8f7579e62988fe97b9e0b0cdf4a to your computer and use it in GitHub Desktop.
StandardScaler in C#: takes a list of objects of a class that has a parameterless constructor
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Data; | |
using System.Linq; | |
using System.Collections.Generic; | |
using System.Reflection; | |
using System.ComponentModel; | |
namespace PricePrediction.MachineLearning
{
    /// <summary>
    /// Standardizes features by removing the per-column mean and dividing by the per-column
    /// standard deviation, modelled on scikit-learn's StandardScaler:
    /// https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.StandardScaler.html
    /// </summary>
    /// <typeparam name="T">
    /// Row type. Rows are split into columns with the <c>ToArraysOfColumns</c> extension, and
    /// transformed values are written back to freshly constructed <typeparamref name="T"/> instances
    /// through <c>TypeDescriptor.GetProperties(typeof(T))</c>, matching columns to properties by index.
    /// NOTE(review): this assumes <c>ToArraysOfColumns</c> yields one column per property, in
    /// <c>TypeDescriptor</c> order, and that every property accepts a boxed <see cref="double"/> — confirm.
    /// </typeparam>
    public class StandardScaler<T> where T : new()
    {
        // Per-column statistics from the last successful Fit; null until Fit has succeeded once.
        private List<double> _mean;
        private List<double> _standardDeviation;

        /// <summary>
        /// Fits the scaler on <paramref name="listOfObjects"/> and then transforms the same list.
        /// </summary>
        /// <param name="listOfObjects">Rows to fit on and standardize; must contain at least one row.</param>
        /// <returns>New <typeparamref name="T"/> instances holding the standardized values.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="listOfObjects"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="listOfObjects"/> is empty.</exception>
        public List<T> FitTransform(List<T> listOfObjects)
        {
            return Fit(listOfObjects).Transform(listOfObjects);
        }

        /// <summary>
        /// Computes and stores the mean and standard deviation of every column of
        /// <paramref name="listOfObjects"/>, replacing any previously fitted statistics.
        /// </summary>
        /// <param name="listOfObjects">Rows to compute statistics from; must contain at least one row.</param>
        /// <returns>This instance, so a call to <see cref="Transform"/> can be chained.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="listOfObjects"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="listOfObjects"/> is empty.</exception>
        public StandardScaler<T> Fit(List<T> listOfObjects)
        {
            if (listOfObjects is null)
                throw new ArgumentNullException(nameof(listOfObjects));
            if (listOfObjects.Count < 1)
                throw new ArgumentException("no data", nameof(listOfObjects));

            var columns = listOfObjects.ToArraysOfColumns<T>();
            var mean = new List<double>(columns.Length);
            var standardDeviation = new List<double>(columns.Length);
            foreach (var column in columns)
            {
                mean.Add(column.Average());
                // NOTE(review): population vs. sample deviation depends on Calculations.StandardDeviation;
                // scikit-learn uses the population form (ddof = 0) — confirm.
                standardDeviation.Add(Calculations.StandardDeviation(column));
            }

            // Commit only once every column has been processed, so a failed Fit leaves the
            // previously fitted statistics (or the unfitted state) untouched.
            _mean = mean;
            _standardDeviation = standardDeviation;
            return this;
        }

        /// <summary>
        /// Standardizes every column of <paramref name="listOfObjects"/> as
        /// <c>(value - mean) / standardDeviation</c> using the statistics stored by <see cref="Fit"/>.
        /// Columns whose fitted standard deviation is zero or NaN are only centred (divided by 1),
        /// as scikit-learn does for zero-variance features, instead of producing NaN/Infinity.
        /// </summary>
        /// <param name="listOfObjects">Rows to standardize. The input objects are not modified.</param>
        /// <returns>New <typeparamref name="T"/> instances holding the standardized values.</returns>
        /// <exception cref="InvalidOperationException"><see cref="Fit"/> has not been called successfully.</exception>
        /// <exception cref="ArgumentNullException"><paramref name="listOfObjects"/> is null.</exception>
        /// <exception cref="ArgumentException">The number of columns differs from the number that was fitted.</exception>
        public List<T> Transform(List<T> listOfObjects)
        {
            if (_mean == null)
                throw new InvalidOperationException("This StandardScaler instance is not fitted yet. Call 'Fit' with appropriate arguments before using this estimator.");
            if (listOfObjects is null)
                throw new ArgumentNullException(nameof(listOfObjects));

            var columns = listOfObjects.ToArraysOfColumns<T>();
            if (columns.Length != _mean.Count)
                throw new ArgumentException("number of fitted columns not same as current one", nameof(listOfObjects));

            for (int c = 0; c < columns.Length; c++)
            {
                double mean = _mean[c];
                double scale = ScaleFor(_standardDeviation[c]);
                var column = columns[c];
                for (int r = 0; r < column.Length; r++)
                    column[r] = (column[r] - mean) / scale;
            }
            return ToListOfObject(columns);
        }

        /// <summary>
        /// Returns the divisor to use for a column: its standard deviation, or 1 when that is
        /// zero or NaN (e.g. a constant column), so the column is centred but not blown up.
        /// </summary>
        private static double ScaleFor(double standardDeviation)
        {
            return standardDeviation == 0.0 || double.IsNaN(standardDeviation) ? 1.0 : standardDeviation;
        }

        /// <summary>
        /// Rebuilds row objects from column-major data: property <c>j</c> of row <c>i</c> is set to
        /// <c>arr[j][i]</c>. Properties are matched to columns purely by index.
        /// </summary>
        /// <param name="arr">Column-major values; every column is expected to have the same length.</param>
        /// <returns>One new <typeparamref name="T"/> per row; empty when <paramref name="arr"/> has no columns.</returns>
        private static List<T> ToListOfObject(double[][] arr)
        {
            PropertyDescriptorCollection properties = TypeDescriptor.GetProperties(typeof(T));
            // With no columns there are no rows to rebuild; avoids indexing arr[0] on an empty array.
            int objectsCount = arr.Length == 0 ? 0 : arr[0].Length;
            var res = new List<T>(objectsCount);
            for (int i = 0; i < objectsCount; i++)
            {
                T o = new();
                for (int j = 0; j < properties.Count; j++)
                    properties[j].SetValue(o, arr[j][i]);
                res.Add(o);
            }
            return res;
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment