import org.apache.spark.ml.feature.*; class ExampleUsage { public void example(){ List<Row> dubs = Lists.newArrayList( RowFactory.create(1000.0, 1000.0, 1.0), RowFactory.create(90.0, 90.0, 0.0) ); DataFrame df = sqlContext.createDataFrame(dubs, createStruct()); Pipeline p = new Pipeline().setStages(new PipelineStage[]{getAssembler(new String[]{"x", "y"}, "features")}); DataFrame part2 = p.fit(df).transform(df).select("features", "label"); SparkDl4jNetwork sparkDl4jNetwork = new SparkDl4jNetwork() .setFeaturesCol("features") .setLabelCol("label") .setTrainingMaster(() -> new ParameterAveragingTrainingMaster.Builder(3) .averagingFrequency(2) .workerPrefetchNumBatches(2) .batchSizePerWorker(2) .build()) .setMultiLayerConfiguration(getNNConfiguration()); SparkDl4jModel sm = sparkDl4jNetwork.fit(part2); MultiLayerNetwork mln = sm.getMultiLayerNetwork(); Assert.assertNotNull(mln); System.out.println(sm.output(Vectors.dense(0.0, 0.0))); sm.write().save("somewhere"); SparkDl4jModel spdm = SparkDl4jModel.load("somewhere"); System.out.println(spdm.predict(Vectors.dense(0.0, 0.0))); } public static VectorAssembler getAssembler(String[] input, String output){ return new VectorAssembler() .setInputCols(input) .setOutputCol(output); } private static StructType createStruct() { return new StructType(new StructField[]{ new StructField("x", DataTypes.DoubleType, true, Metadata.empty()), new StructField("y", DataTypes.DoubleType, true, Metadata.empty()), new StructField("label", DataTypes.DoubleType, true, Metadata.empty()) }); } private static MultiLayerConfiguration getNNConfiguration(){ return new NeuralNetConfiguration.Builder() .seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .iterations(1000) .weightInit(WeightInit.UNIFORM) .learningRate(0.1) .updater(Updater.NESTEROVS) .list() .layer(0, new DenseLayer.Builder().nIn(2).nOut(100).weightInit(WeightInit.XAVIER).activation("relu").build()) .layer(1, new DenseLayer.Builder().nIn(100).nOut(120).weightInit(WeightInit.XAVIER).activation("relu").build()) .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation("softmax").nIn(120).nOut(2).build()) .pretrain(false).backprop(true) .build(); } }