Java简单神经网络设置

23

我决定在Java中尝试一些关于神经网络的简单概念,并且在改编了一些我在论坛上找到的无用代码后,我成功地创建了一个非常简单的模型来模拟典型初学者的XOR运算:


public class MainApp {
    public static void main (String [] args) {
        Neuron xor = new Neuron(0.5f);
        Neuron left = new Neuron(1.5f);
        Neuron right = new Neuron(0.5f);
        left.setWeight(-1.0f);
        right.setWeight(1.0f);
        xor.connect(left, right);

        for (String val : args) {
            Neuron op = new Neuron(0.0f);
            op.setWeight(Boolean.parseBoolean(val));
            left.connect(op);
            right.connect(op);
        }

        xor.fire();

        System.out.println("Result: " + xor.isFired());

    }
}


public class Neuron {
    private ArrayList inputs;
    private float weight;
    private float threshhold;
    private boolean fired;

    public Neuron (float t) {
        threshhold = t;
        fired = false;
        inputs = new ArrayList();
    }

    public void connect (Neuron ... ns) {
        for (Neuron n : ns) inputs.add(n);
    }

    public void setWeight (float newWeight) {
        weight = newWeight;
    }

    public void setWeight (boolean newWeight) {
        weight = newWeight ? 1.0f : 0.0f;
    }

    public float getWeight () {
        return weight;
    }

    public float fire () {
        if (inputs.size() > 0) {
            float totalWeight = 0.0f;
            for (Neuron n : inputs) {
                n.fire();
                totalWeight += (n.isFired()) ? n.getWeight() : 0.0f;
            }
            fired = totalWeight > threshhold;
            return totalWeight;
        }
        else if (weight != 0.0f) {
            fired = weight > threshhold;
            return weight;
        }
        else {
            return 0.0f;
        }
    }

    public boolean isFired () {
        return fired;
    }
}

在我的主类中,我已经按照Jeff Heaton的图表创建了简单的模拟: XOR diagram 然而,我想确保神经元类的实现是正确的。我已经测试了所有可能的输入([true true],[true false],[false true],[false false]),并且它们都通过了我的手动验证。此外,由于这个程序接受输入作为参数,对于像[true false false],[true true false]等输入,它似乎也通过了手动验证。
但从概念上讲,这种实现是否正确?或者在我开始进一步开发和研究这个主题之前,我应该如何改进它?
谢谢!

为了在具有必要CPU的大型应用程序中节省内存,最好添加一个衰减率,其值可以在第二个构造函数中定义。 - user2425429
2个回答

9

看起来这是一个很好的起点。我有一些建议:

  1. 为了扩展性,应该重构fire()方法,使得已经使用当前输入集合触发过的神经元不需要每次都重新计算。如果你有另一个隐藏层或多个输出节点,这种情况就会出现。

  2. 考虑将阈值计算拆分成自己的方法。然后你可以继承Neuron类并使用不同类型的激活函数(双极Sigmoid、RBF、线性等)。

  3. 为了学习更复杂的函数,给每个神经元添加偏置输入。它基本上像另一个输入,有着它自己的权重值,但输入始终固定为1(或-1)。

  4. 不要忘记允许训练方法。反向传播将需要类似于fire()的倒数,以取一个目标输出,并通过每一层连锁反应地改变权重。


2
谢谢您的建议!不过,我需要在这个主题上进行更多的研究,因为我感觉对于您所建议的大部分内容都不是很了解哈哈。 - jerluc

0

从我所做的(有限的)神经网络工作来看,该实现和模型看起来是正确的 - 输出符合我的预期,源代码看起来很可靠。


1
感谢您的回复,根据您的经验,您是否认为这个神经元类在可扩展性方面存在任何问题?从我的角度来看,我已经尝试使其足够可扩展,以适应多个内部神经元层,但两个人的观点胜过一个。 - jerluc

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接