神经网络输出 NaN 的解决方法

9
我正在尝试编写我的第一个神经网络来玩连接四游戏,使用Java和deeplearning4j。我尝试实现遗传算法,但当我训练网络一段时间后,网络的输出会跳到NaN,我无法确定我犯了什么错误导致这种情况发生。我将下面的三个类发布出来,其中Game是游戏逻辑和规则,VGFrame是UI,Main是所有神经网络相关内容。
我有35个神经网络池,每次迭代时我让最好的5个存活并繁殖,并稍微随机化新创建的神经网络。为了评估网络,我让它们互相对抗,并给胜者和输家各自加分。由于我惩罚将棋子放入已满列中,所以我希望神经网络至少能够按照规则玩游戏,但他们做不到。
我谷歌搜索了 NaN 问题,似乎这是梯度爆炸问题,但从我的理解来看,这不应该在遗传算法中出现?你们有什么建议我可以寻找错误或者我的实现有什么问题吗?
Main
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Random;

import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs;

public class Main {
    final int numRows = 7;
    final int numColums = 6;
    final int randSeed = 123;
    MultiLayerNetwork[] models;

    static Random random = new Random();
    private static final Logger log = LoggerFactory.getLogger(Main.class);
    final float learningRate = .8f;
    int batchSize = 64; // Test batch size
    int nEpochs = 1; // Number of training epochs
    // --
    public static Main current;
    Game mainGame = new Game();

    public static void main(String[] args) {
        current = new Main();
        current.frame = new VGFrame();
        current.loadWeights();
    }

    private VGFrame frame;
    private final double mutationChance = .05;

    public Main() {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER)
                .activation(Activation.RELU).seed(randSeed)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Nesterovs(0.1, 0.9))
                .list()
                .layer(new DenseLayer.Builder().nIn(42).nOut(30).activation(Activation.RELU)
                        .weightInit(WeightInit.XAVIER).build())
                .layer(new DenseLayer.Builder().nIn(30).nOut(15).activation(Activation.RELU)
                        .weightInit(WeightInit.XAVIER).build())
                .layer(new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD).nIn(15).nOut(7)
                        .activation(Activation.SOFTMAX).weightInit(WeightInit.XAVIER).build())
                .build();
        models = new MultiLayerNetwork[35];
        for (int i = 0; i < models.length; i++) {
            models[i] = new MultiLayerNetwork(conf);
            models[i].init();
        }

    }

    public void addChip(int i, boolean b) {
        if (mainGame.gameState == 0)
            mainGame.addChip(i, b);
        if (mainGame.gameState == 0) {
            float[] f = Main.rowsToInput(mainGame.rows);
            INDArray input = Nd4j.create(f);
            INDArray output = models[0].output(input);
            for (int i1 = 0; i1 < 7; i1++) {
                System.out.println(i1 + ": " + output.getDouble(i1));
            }
            System.out.println("----------------");
            mainGame.addChip(Main.getHighestOutput(output), false);
        }
        getFrame().paint(getFrame().getGraphics());
    }

    public void newGame() {
        mainGame = new Game();
        getFrame().paint(getFrame().getGraphics());
    }

    public void startTraining(int iterations) {

        // --------------------------
        for (int gameNumber = 0; gameNumber < iterations; gameNumber++) {
            System.out.println("Iteration " + gameNumber + " of " + iterations);
            float[] evaluation = new float[models.length];
            for (int i = 0; i < models.length; i++) {
                for (int j = 0; j < models.length; j++) {
                    if (i != j) {
                        Game g = new Game();
                        g.playFullGame(models[i], models[j]);
                        if (g.gameState == 1) {
                            evaluation[i] += 45;
                            evaluation[j] += g.turnNumber;
                        }
                        if (g.gameState == 2) {
                            evaluation[j] += 45;
                            evaluation[i] += g.turnNumber;
                        }
                    }
                }
            }

            float[] evaluationSorted = evaluation.clone();
            Arrays.sort(evaluationSorted);
            // keep the best 4
            int n1 = 0, n2 = 0, n3 = 0, n4 = 0, n5 = 0;
            for (int i = 0; i < evaluation.length; i++) {
                if (evaluation[i] == evaluationSorted[evaluationSorted.length - 1])
                    n1 = i;
                if (evaluation[i] == evaluationSorted[evaluationSorted.length - 2])
                    n2 = i;
                if (evaluation[i] == evaluationSorted[evaluationSorted.length - 3])
                    n3 = i;
                if (evaluation[i] == evaluationSorted[evaluationSorted.length - 4])
                    n4 = i;
                if (evaluation[i] == evaluationSorted[evaluationSorted.length - 5])
                    n5 = i;
            }
            models[0] = models[n1];
            models[1] = models[n2];
            models[2] = models[n3];
            models[3] = models[n4];
            models[4] = models[n5];

            for (int i = 3; i < evaluationSorted.length; i++) {
                // random parent/keep w8ts
                double r = Math.random();
                if (r > .3) {
                    models[i] = models[random.nextInt(3)].clone();

                } else if (r > .1) {
                    models[i].setParams(breed(models[random.nextInt(3)], models[random.nextInt(3)]));
                }
                // Mutate
                INDArray params = models[i].params();
                models[i].setParams(mutate(params));
            }
        }
    }

    private INDArray mutate(INDArray params) {
        double[] d = params.toDoubleVector();
        for (int i = 0; i < d.length; i++) {
            if (Math.random() < mutationChance)
                d[i] += (Math.random() - .5) * learningRate;

        }
        return Nd4j.create(d);
    }

    private INDArray breed(MultiLayerNetwork m1, MultiLayerNetwork m2) {
        double[] d = m1.params().toDoubleVector();
        double[] d2 = m2.params().toDoubleVector();
        for (int i = 0; i < d.length; i++) {
            if (Math.random() < .5)
                d[i] += d2[i];
        }
        return Nd4j.create(d);
    }

    static int getHighestOutput(INDArray output) {
        int x = 0;
        for (int i = 0; i < 7; i++) {
            if (output.getDouble(i) > output.getDouble(x))
                x = i;
        }
        return x;
    }

    static float[] rowsToInput(byte[][] rows) {
        float[] f = new float[7 * 6];
        for (int i = 0; i < 6; i++) {
            for (int j = 0; j < 7; j++) {
                // f[j + i * 7] = rows[j][i] / 2f;
                f[j + i * 7] = (rows[j][i] == 0 ? .5f : rows[j][i] == 1 ? 0f : 1f);
            }
        }
        return f;
    }

    public void saveWeights() {
        log.info("Saving model");
        for (int i = 0; i < models.length; i++) {
            File resourcesDirectory = new File("src/resources/model" + i);
            try {
                models[i].save(resourcesDirectory, true);
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    public void loadWeights() {
        if (new File("src/resources/model0").exists()) {
            for (int i = 0; i < models.length; i++) {
                File resourcesDirectory = new File("src/resources/model" + i);
                try {

                    models[i] = MultiLayerNetwork.load(resourcesDirectory, true);
                } catch (IOException e) {
                    // TODO Auto-generated catch block
                    e.printStackTrace();
                }
            }
        }
        System.out.println("col: " + models[0].params().shapeInfoToString());
    }

    public VGFrame getFrame() {
        return frame;
    }

}

VGFrame

import java.awt.Color;
import java.awt.Graphics;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import javax.swing.BorderFactory;
import javax.swing.JButton;
import javax.swing.JFrame;
import javax.swing.JPanel;
import javax.swing.JTextField;

public class VGFrame extends JFrame {
    JTextField iterations;
    /**
     * 
     */
    private static final long serialVersionUID = 1L;

    public VGFrame() {
        super("Vier Gewinnt");
        this.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        this.setSize(1300, 800);
        this.setVisible(true);
        JPanel panelGame = new JPanel();
        panelGame.setBorder(BorderFactory.createLineBorder(Color.black, 2));
        this.add(panelGame);

        var handler = new Handler();
        var menuHandler = new MenuHandler();

        JButton b1 = new JButton("1");
        JButton b2 = new JButton("2");
        JButton b3 = new JButton("3");
        JButton b4 = new JButton("4");
        JButton b5 = new JButton("5");
        JButton b6 = new JButton("6");
        JButton b7 = new JButton("7");
        b1.addActionListener(handler);
        b2.addActionListener(handler);
        b3.addActionListener(handler);
        b4.addActionListener(handler);
        b5.addActionListener(handler);
        b6.addActionListener(handler);
        b7.addActionListener(handler);
        panelGame.add(b1);
        panelGame.add(b2);
        panelGame.add(b3);
        panelGame.add(b4);
        panelGame.add(b5);
        panelGame.add(b6);
        panelGame.add(b7);

        JButton buttonTrain = new JButton("Train");
        JButton buttonNewGame = new JButton("New Game");
        JButton buttonSave = new JButton("Save Weights");
        JButton buttonLoad = new JButton("Load Weights");

        iterations = new JTextField("1000");

        buttonTrain.addActionListener(menuHandler);
        buttonNewGame.addActionListener(menuHandler);
        buttonSave.addActionListener(menuHandler);
        buttonLoad.addActionListener(menuHandler);
        iterations.addActionListener(menuHandler);

        panelGame.add(iterations);
        panelGame.add(buttonTrain);
        panelGame.add(buttonNewGame);
        panelGame.add(buttonSave);
        panelGame.add(buttonLoad);

        this.validate();
    }

    @Override
    public void paint(Graphics g) {
        super.paint(g);
        if (Main.current.mainGame.rows == null)
            return;
        var rows = Main.current.mainGame.rows;
        for (int i = 0; i < rows.length; i++) {
            for (int j = 0; j < rows[0].length; j++) {
                if (rows[i][j] == 0)
                    break;

                g.setColor((rows[i][j] == 1 ? Color.yellow : Color.red));
                g.fillOval(80 + 110 * i, 650 - 110 * j, 100, 100);
            }
        }
    }

    public void update() {
    }
}

class Handler implements ActionListener {

    @Override
    public void actionPerformed(ActionEvent event) {
        if (Main.current.mainGame.playersTurn)
            Main.current.addChip(Integer.parseInt(event.getActionCommand()) - 1, true);
    }
}

class MenuHandler implements ActionListener {

    @Override
    public void actionPerformed(ActionEvent event) {
        switch (event.getActionCommand()) {
        case "New Game":
            Main.current.newGame();
            break;
        case "Train":
            Main.current.startTraining(Integer.parseInt(Main.current.getFrame().iterations.getText()));
            break;
        case "Save Weights":
            Main.current.saveWeights();
            break;
        case "Load Weights":
            Main.current.loadWeights();
            break;
        }

    }
}

游戏

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class Game {

    int turnNumber = 0;
    byte[][] rows = new byte[7][6];
    boolean playersTurn = true;

    int gameState = 0; // 0:running, 1:Player1, 2:Player2, 3:Draw

    public boolean isRunning() {
        return this.gameState == 0;
    }

    public void addChip(int x, boolean player1) {
        turnNumber++;
        byte b = nextRow(x);
        if (b == 6) {
            gameState = player1 ? 2 : 1;
            return;
        }
        rows[x][b] = (byte) (player1 ? 1 : 2);
        gameState = checkWinner(x, b);
    }

    private byte nextRow(int x) {
        for (byte i = 0; i < rows[x].length; i++) {
            if (rows[x][i] == 0)
                return i;
        }
        return 6;
    }

    // 0 continue, 1 Player won, 2 ai won, 3 Draw
    private int checkWinner(int x, int y) {
        int color = rows[x][y];
        // Vertikal
        if (getCount(x, y, 1, 0) + getCount(x, y, -1, 0) >= 3)
            return rows[x][y];

        // Horizontal
        if (getCount(x, y, 0, 1) + getCount(x, y, 0, -1) >= 3)
            return rows[x][y];

        // Diagonal1
        if (getCount(x, y, 1, 1) + getCount(x, y, -1, -1) >= 3)
            return rows[x][y];
        // Diagonal2
        if (getCount(x, y, -1, 1) + getCount(x, y, 1, -1) >= 3)
            return rows[x][y];
        
        for (byte[] bs : rows) {
            for (byte s : bs) {
                if (s == 0)
                    return 0;
            }
        }
        return 3; // Draw
    }

    private int getCount(int x, int y, int dirX, int dirY) {
        int color = rows[x][y];
        int count = 0;
        while (true) {
            x += dirX;
            y += dirY;
            if (x < 0 | x > 6 | y < 0 | y > 5)
                break;
            if (color != rows[x][y])
                break;
            count++;
        }
        return count;
    }

    public void playFullGame(MultiLayerNetwork m1, MultiLayerNetwork m2) {
        boolean player1 = true;
        while (this.gameState == 0) {
            float[] f = Main.rowsToInput(this.rows);
            INDArray input = Nd4j.create(f);
            this.addChip(Main.getHighestOutput(player1 ? m1.output(input) : m2.output(input)), player1);
            player1 = !player1;
        }
    }
}
2个回答

7

通过快速查看和分析您的乘数变体,似乎NaN是由算术下溢引起的,这是由于您的梯度太小接近绝对零)。

这是代码中最可疑的部分:

 f[j + i * 7] = (rows[j][i] == 0 ? .5f : rows[j][i] == 1 ? 0f : 1f);

如果rows[j][i] == 1,则存储0f。我不知道神经网络(甚至是Java)如何管理此内容,但从数学上讲,有限大小的浮点数不能包括零。即使您的代码会使用一些额外的方法来改变0f,这些数组值的结果仍有可能变得接近于零。由于表示实数时的精度受到限制,因此无法表示非常接近于零的值,因此出现了NaN

这些值有一个非常友好的名称:subnormal numbers

任何具有小于最小正规数大小的幅度的非零数字都是subnormal

enter image description here

IEEE_754

与IEEE 754-1985一样,该标准建议信号NaN的值为0,安静NaN的值为1,因此只需将该位更改为1即可将信号NaN变为安静NaN,反之则可能导致编码为无穷大。

上面的文本在这里很重要:根据标准,实际上您正在使用存储任何0f值的NaN


即使名称有误导性,Float.MIN_VALUE 是一个正数大于0

enter image description here

实际上,最小的float值是:-Float.MAX_VALUE

浮点数运算是否是次标准化的?


规范化梯度

如果您发现问题仅是由于0f值引起的,您可以将它们更改为表示类似内容的其他值;例如:Float.MIN_VALUEFloat.MIN_NORMAL等。在可能出现这种情况的代码的其他部分也可以采用类似的方法。以下仅为示例,请根据实际情况调整这些范围:

rows[j][i] == 1 ? Float.MIN_VALUE : 1f;

rows[j][i] == 1 ?  Float.MIN_NORMAL : Float.MAX_VALUE/2;

rows[j][i] == 1 ? -Float.MAX_VALUE/2 : Float.MAX_VALUE/2;

即便如此,根据这些值的修改方式,也可能导致出现 NaN。如果是这样,你应该对这些值进行归一化处理。你可以尝试使用 GradientNormalizer 进行处理。在网络初始化中,应该为每个层(或有问题的层)定义类似下面的内容:
new NeuralNetConfiguration
  .Builder()
  .weightInit(WeightInit.XAVIER)
  (...)
  .layer(new DenseLayer.Builder().nIn(42).nOut(30).activation(Activation.RELU)
        .weightInit(WeightInit.XAVIER)
        .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) //this   
        .build())
  
  (...)

有不同的规范化器,因此选择最适合您模式的规范化器以及应包括哪些层。选项如下:

GradientNormalization

  • RenormalizeL2PerLayer

    通过将层中所有梯度的L2范数除以梯度进行重新缩放。

  • RenormalizeL2PerParamType

    通过将每种参数类型的梯度的L2范数分别除以梯度进行重新缩放。这与RenormalizeL2PerLayer不同,因为这里每个参数类型(权重、偏置等)都是单独归一化的。例如,在MLP/FeedForward网络中(其中G是梯度向量),输出如下:

    GOut_weight = G_weight / l2(G_weight) GOut_bias = G_bias / l2(G_bias)

  • ClipElementWiseAbsoluteValue

    按元素剪辑梯度。对于每个梯度g,设置g <- sign(g) max(maxAllowedValue,|g|)。即,如果参数梯度的绝对值大于阈值,则将其截断。例如,如果阈值为5,则在范围-5<g<5内的值不变;小于-5的值设为-5;大于5的值设为5。

  • ClipL2PerLayer

    有条件的重新缩放。与RenormalizeL2PerLayer有些相似,此策略仅在梯度的L2范数(整个层的)超过指定阈值时才对其进行缩放。具体来说,如果G是该层的梯度向量,则:

    GOut = G if l2Norm(G) < threshold (即,不变化)GOut = threshold * G / l2Norm(G)

  • ClipL2PerParamType

    有条件的重新缩放。与ClipL2PerLayer非常相似,但是不是按层剪辑,而是分别在每种参数类型上进行剪辑。例如,在循环神经网络中,输入权重梯度、循环权重梯度和偏置梯度都被单独剪辑。


这里提供了一个完整的应用这些GradientNormalizers的示例。


感谢您的详细解释。我认为梯度只在反向传播中才有影响,这不正确吗?我将我的输入更改为Float.Min_Value,并将梯度归一化应用于所有层,但问题仍然存在。 - David
@David 你好!我以前确实做过神经网络,但那是很多年前的事情了(呼!!)。如果某些数组字段存储绝对零值,也可能会发生这种情况。如果我错了,请纠正我,你的梯度不会太大,对吧?它们在0和其他一些小值之间振荡。我告诉你吧,我不会让它失败,我很想让它工作。所以请不要犹豫,问我需要什么信息。这里有任何除法吗?你对该层的值有什么预期结果? - aran
另外,您是否可以访问一些堆栈跟踪,以了解事情变得复杂的原因?无论如何,我们都会让它工作起来!您有我的话,一个巴斯克人的话。如果需要,我将杀死一半人类来实现这一点。 - aran
0f和1f的值代表相反的值吗?我的意思是,它们可以被改变吗?例如,通过-FLOAT_MAX.VALUE/2FLOAT_MAX.VALUE/2或者更简单地通过-10f10f。即便如此,我认为首先要检查的是没有任何0浮点数附近。 - aran
一个解决方法也可以是将浮点数更改为双精度,因为它们缺乏这种精度并且通常不会在此类值上溢出。您会失去精度,但作为一个测试,它可能有助于确定主要问题。浮点计算有问题。 - aran
@David,作为更新,请查看IEEE规范:与IEEE 754-1985一样,该标准建议将信号NaN设置为0,将静默NaN设置为1 -- 这可能确实会有问题。 - aran

1
我想我终于搞明白了。我试图使用deeplearning4j-ui可视化网络,但是遇到了一些不兼容的版本错误。在改变版本之后,我得到了一个新的错误,它表明网络输入期望一个二维数组,并且我在互联网上发现,这在所有版本中都是期望的。
所以我作出了改变。
float[] f = new float[7 * 6];
Nd4j.create(f);

float[][] f = new float[1][7 * 6];
Nd4j.createFromArray(f);

最终NaN值消失了。@aran,我想假设输入不正确确实是正确的方向。非常感谢您的帮助 :)

有一些相关的信息..这几天会进行更新。 - aran

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接