import org.apache.spark.ml.feature.*;

class ExampleUsage {
  
    public void example(){
        List<Row> dubs = Lists.newArrayList(
                RowFactory.create(1000.0, 1000.0, 1.0),
                RowFactory.create(90.0, 90.0, 0.0)
        );
        DataFrame df = sqlContext.createDataFrame(dubs, createStruct());
        Pipeline p = new Pipeline().setStages(new PipelineStage[]{getAssembler(new String[]{"x", "y"}, "features")});
        DataFrame part2 = p.fit(df).transform(df).select("features", "label");

        SparkDl4jNetwork sparkDl4jNetwork = new SparkDl4jNetwork()
                .setFeaturesCol("features")
                .setLabelCol("label")
                .setTrainingMaster(() -> new ParameterAveragingTrainingMaster.Builder(3)
                        .averagingFrequency(2)
                        .workerPrefetchNumBatches(2)
                        .batchSizePerWorker(2)
                        .build())
                .setMultiLayerConfiguration(getNNConfiguration());

        SparkDl4jModel sm = sparkDl4jNetwork.fit(part2);
        MultiLayerNetwork mln = sm.getMultiLayerNetwork();
        Assert.assertNotNull(mln);
        System.out.println(sm.output(Vectors.dense(0.0, 0.0)));
      
        sm.write().save("somewhere");
        SparkDl4jModel spdm = SparkDl4jModel.load("somewhere");
        System.out.println(spdm.predict(Vectors.dense(0.0, 0.0)));
    }
  
    public static VectorAssembler getAssembler(String[] input, String output){
        return new VectorAssembler()
                .setInputCols(input)
                .setOutputCol(output);
    }
  
    private static StructType createStruct() {
        return  new StructType(new StructField[]{
                new StructField("x", DataTypes.DoubleType, true, Metadata.empty()),
                new StructField("y", DataTypes.DoubleType, true, Metadata.empty()),
                new StructField("label", DataTypes.DoubleType, true, Metadata.empty())
        });
    }

    private static MultiLayerConfiguration getNNConfiguration(){
        return new NeuralNetConfiguration.Builder()
                .seed(12345)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .iterations(1000)
                .weightInit(WeightInit.UNIFORM)
                .learningRate(0.1)
                .updater(Updater.NESTEROVS)
                .list()
                .layer(0, new DenseLayer.Builder().nIn(2).nOut(100).weightInit(WeightInit.XAVIER).activation("relu").build())
                .layer(1,  new DenseLayer.Builder().nIn(100).nOut(120).weightInit(WeightInit.XAVIER).activation("relu").build())
                .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation("softmax").nIn(120).nOut(2).build())
                .pretrain(false).backprop(true)
                .build();
    }
}