Deeplearning4j数据集的分割:测试集与训练集

3

Deeplearning4j有支持将数据集拆分为测试和训练的功能,以及用于洗牌数据集的机制,但据我所知,它们要么不起作用,要么我做错了什么。

示例:

    DataSetIterator iter = new IrisDataSetIterator(150, 150);
    DataSet next = iter.next();
    // next.shuffle();
    SplitTestAndTrain testAndTrain = next.splitTestAndTrain(120, new Random(seed));
    DataSet train = testAndTrain.getTrain();
    DataSet test = testAndTrain.getTest();

    for (int i = 0; i < 30; i++) {
        String features = test.getFeatures().getRow(i).toString();
        String actual = test.getLabels().getRow(i).toString().trim();
        log.info("features " + features + " -> " + actual );
    }

在输入数据集的最后30行中返回的结果显示,splitTestAndTrain的随机种子参数Random(seed)似乎完全被忽略了。
如果不是将随机种子传递给splitTestAndTrain函数而是取消注释下面的shuffle()行,则第三个和第四个特征会混排,同时保持第一个和第二个特征以及测试标签的现有顺序,这比根本不排序还要糟糕。
所以问题是,我使用方法错误还是Deeplearning4j本身就存在问题?
附加问题:如果Deeplearning4j甚至不能处理生成测试和样本数据集这样简单的任务,那么它能够被信任处理其他任务吗?或者我最好使用其他库来代替Deeplearning4j?

请加入我们在 Gitter 上的讨论。我们会在那里帮助您解决问题:https://gitter.im/deeplearning4j/deeplearning4j - racknuf
3个回答

3
Deeplearning4j假设数据集是小批量的,例如:它们并非全部存储在内存中。这与Python世界相矛盾,后者可能更注重小型数据集和易用性。这只适用于玩具问题,并不适合处理真实问题的大规模数据。
我们为本地情况优化了数据集迭代器接口(请注意,对于像Spark这样的分布式系统,这将有所不同)。这意味着我们依赖于数据集先使用datavec进行拆分以解析数据集(提示:不要编写自己的迭代器:使用我们的迭代器并使用datavec进行自定义解析)或允许使用数据集迭代器拆分器:https://deeplearning4j.org/doc/org/deeplearning4j/datasets/iterator/DataSetIteratorSplitter.html进行训练测试拆分。
数据集拆分的train test类仅在数据集已全部存储在内存中时才能工作,但对于大多数半现实问题可能没有意义(例如:超越XOR或MNIST)。我建议您仅运行ETL步骤一次而不是每次都运行。将数据集预切片成预定大小的批次是一种方法。可以使用以下组合完成此操作:https://github.com/deeplearning4j/deeplearning4j/blob/master/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/BalanceMinibatchesTest.java#L40和: https://nd4j.org/doc/org/nd4j/linalg/dataset/ExistingMiniBatchDataSetIterator.html
这样做的另一个原因是为了可重现性。如果您想做类似每个时期对迭代器进行洗牌之类的事情,可以尝试基于上述组合编写一些代码。无论如何,我建议在训练之前处理ETL并预先创建向量,否则,您将花费大量时间加载大型数据集。

我从你的帖子中理解到,对于生产编码,最不推荐到最推荐的顺序是:1. dataSet.splitTestAndTrain(...) 2. DataSetIteratorSplitter 3. ExistingMiniBatchDataSetIterator。 - Umesh Rajbhandari

1
由于这个问题已经过时,对于可能会找到这个问题的人来说,你可以在GitHub上看到一些例子,拆分可以简单地完成。
DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader,batchSize,labelIndex,numClasses);
DataSet allData = iterator.next();
allData.shuffle();
SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);  //Use 65% of data for training

DataSet trainingData = testAndTrain.getTrain();
DataSet testData = testAndTrain.getTest();

首先创建迭代器,遍历所有数据,对其进行随机排序,然后将其分为测试和训练集。

这是来自this示例的内容。


0
据我所知,deeplearning4j 简直是坏掉了。最终,我创建了自己的 splitTestandTrain 实现。
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import java.util.Random;
import org.nd4j.linalg.factory.Nd4j;

public class TestTrain {  
    protected DataSet test;
    protected DataSet train;

    public TestTrain(DataSet input, int splitSize, Random rng) {
        int inTest = 0;
        int inTrain = 0;
        int testSize = input.numExamples() - splitSize;

        INDArray train_features = Nd4j.create(splitSize, input.getFeatures().columns());
        INDArray train_outcomes = Nd4j.create(splitSize, input.numOutcomes());
        INDArray test_features  = Nd4j.create(testSize, input.getFeatures().columns());
        INDArray test_outcomes  = Nd4j.create(testSize, input.numOutcomes());

        for (int i = 0; i < input.numExamples(); i++) {
            DataSet D = input.get(i);
            if (rng.nextDouble() < (splitSize-inTrain)/(double)(input.numExamples()-i)) {
                train_features.putRow(inTrain, D.getFeatures());
                train_outcomes.putRow(inTrain, D.getLabels());
                inTrain += 1;
            } else {
                test_features.putRow(inTest, D.getFeatures());
                test_outcomes.putRow(inTest, D.getLabels());
                inTest += 1;
            }
        }

        train = new DataSet(train_features, train_outcomes);
        test  = new DataSet(test_features, test_outcomes);
    }

    public DataSet getTrain() {
        return train;
    }

    public DataSet getTest() {
        return test;
    }
}

这个方法可以用,但是它并没有让我对这个库有信心。如果有人能提供更好的答案,我会很高兴,但现在只能先这样了。


2
请看我的回复。你的“信心”来自于不理解库的假设。我在下面进行了更正。你的实现最多只能处理内存中的小问题。这在GPU或大型数据集上将不会高效。 - Adam Gibson

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接