Reverse-mode automatic differentiation with a graph in C++

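I'm trying to implement reverse-mode automatic differentiation in C++.
The idea I came up with is that every variable that results from an operation on one or two variables stores, in a vector, its gradients with respect to those operands, together with pointers to them.
Here is the code: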
#include <iostream>
#include <utility>
#include <vector>

class Var {
    private:
        double value;
        char character;
        std::vector<std::pair<double, const Var*> > children;

    public:
        Var(const double& _value=0, const char& _character='_') : value(_value), character(_character) {};
        void set_character(const char& character){ this->character = character; }

        // computes the derivative of the current object with respect to 'var'
        double gradient(Var* var) const{
            if(this==var){
                return 1.0;
            }

            double sum=0.0;
            for(auto& pair : children){
                // std::cout << "(" << this->character << " -> " <<  pair.second->character << ", " << this << " -> " << pair.second << ", weight=" << pair.first << ")" << std::endl;
                sum += pair.first*pair.second->gradient(var);
            }
            return sum;
        }

        friend Var operator+(const Var& l, const Var& r){
            Var result(l.value+r.value);
            result.children.push_back(std::make_pair(1.0, &l));
            result.children.push_back(std::make_pair(1.0, &r));
            return result;
        }

        friend Var operator*(const Var& l, const Var& r){
            Var result(l.value*r.value);
            result.children.push_back(std::make_pair(r.value, &l));
            result.children.push_back(std::make_pair(l.value, &r));
            return result;
        }

        friend std::ostream& operator<<(std::ostream& os, const Var& var){
            os << var.value;
            return os;
        }
};

I try to run it like this:

int main(int argc, char const *argv[]) {
    Var x(5,'x'), y(6,'y'), z(7,'z');

    Var k = z + x*y;
    k.set_character('k');

    std::cout << "k = " << k << std::endl;
    std::cout << "∂k/∂x = " << k.gradient(&x) << std::endl;
    std::cout << "∂k/∂y = " << k.gradient(&y) << std::endl;
    std::cout << "∂k/∂z = " << k.gradient(&z) << std::endl;

    return 0;
}

The computation graph that should be built looks like this:
       x(5)   y(6)              z(7)
         \     /                 /
 ∂w/∂x=y  \   /  ∂w/∂y=x        /
           \ /                 /
          w=x*y               /
             \               /  ∂k/∂z=1
              \             /
      ∂k/∂w=1  \           /
                \_________/
                     |
                   k=w+z

Then, to compute ∂k/∂x, I multiply the gradients along the edges of each path and sum the results over all paths. This is done recursively by double gradient(Var* var) const. So I have ∂k/∂x = ∂k/∂w * ∂w/∂x + ∂k/∂z * ∂z/∂x
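With the values from main (x = 5, y = 6, z = 7, so w = x*y = 30 and k = w + z = 37), applying the same rule to each input gives the values the program is expected to print:

```latex
\begin{aligned}
\frac{\partial k}{\partial x} &= \frac{\partial k}{\partial w}\frac{\partial w}{\partial x} + \frac{\partial k}{\partial z}\frac{\partial z}{\partial x} = 1 \cdot y + 1 \cdot 0 = 6 \\
\frac{\partial k}{\partial y} &= \frac{\partial k}{\partial w}\frac{\partial w}{\partial y} + \frac{\partial k}{\partial z}\frac{\partial z}{\partial y} = 1 \cdot x + 1 \cdot 0 = 5 \\
\frac{\partial k}{\partial z} &= \frac{\partial k}{\partial w}\frac{\partial w}{\partial z} + \frac{\partial k}{\partial z}\frac{\partial z}{\partial z} = 1 \cdot 0 + 1 \cdot 1 = 1
\end{aligned}
```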

Problem

The problem appears when I have an intermediate computation such as x*y. With the std::cout line uncommented, the output is:

k = 37
(k -> z, 0x7ffeb3345740 -> 0x7ffeb3345710, weight=1)
(k -> _, 0x7ffeb3345740 -> 0x7ffeb3345770, weight=1)
(_ -> x, 0x7ffeb3345770 -> 0x7ffeb33456b0, weight=0)
(_ -> y, 0x7ffeb3345770 -> 0x7ffeb33456e0, weight=5)
∂k/∂x = 0
(k -> z, 0x7ffeb3345740 -> 0x7ffeb3345710, weight=1)
(k -> _, 0x7ffeb3345740 -> 0x7ffeb3345770, weight=1)
(_ -> x, 0x7ffeb3345770 -> 0x7ffeb33456b0, weight=0)
(_ -> y, 0x7ffeb3345770 -> 0x7ffeb33456e0, weight=5)
∂k/∂y = 5
(k -> z, 0x7ffeb3345740 -> 0x7ffeb3345710, weight=1)
(k -> _, 0x7ffeb3345740 -> 0x7ffeb3345770, weight=1)
(_ -> x, 0x7ffeb3345770 -> 0x7ffeb33456b0, weight=0)
(_ -> y, 0x7ffeb3345770 -> 0x7ffeb33456e0, weight=5)
∂k/∂z = 1

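It prints which variable is connected to which, then their addresses, and the weight of the connection (which should be the gradient).

The problem is the weight=0 between x and the intermediate variable holding the result of x*y (which I call w in the graph). I don't know why this weight is zero while the other weight, on the edge to y, is not.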

Another thing I noticed is that if you swap the lines in operator*:

result.children.push_back(std::make_pair(l.value, &r));
result.children.push_back(std::make_pair(r.value, &l));

then it's the "y" connection that gets zeroed instead. Thanks a lot for your help.
1 Answer


This line:

Var k = z + x*y;

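calls operator*, which returns a temporary Var; that temporary is then passed to operator+ as the r argument, where a pair stores its address. The temporary is destroyed at the end of the full-expression, so once the line has finished, k's children contain a pointer to where the temporary used to be, but it no longer exists.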


While it doesn't prevent the above mistake, you can get the intended behaviour by avoiding unnamed temporaries...

Var xy = x * y;
xy.set_character('*');
Var k = z + xy;
k.set_character('k');

...and your program then produces:

k = 37
∂k/∂x = 6
∂k/∂y = 5
∂k/∂z = 1

A better fix might be to capture the children by value.

As a general tip for catching such mistakes... when your program seems to be doing something inexplicable (and/or crashing), try running it under a memory error detector such as valgrind. For your code, the report starts off with:

==22137== Invalid read of size 8
==22137==    at 0x1090EA: Var::gradient(Var*) const (in /home/median/so/deriv)
==22137==    by 0x109109: Var::gradient(Var*) const (in /home/median/so/deriv)
==22137==    by 0x108E12: main (in /home/median/so/deriv)
==22137==  Address 0x5b82cd0 is 0 bytes inside a block of size 32 free'd
==22137==    at 0x4C3123B: operator delete(void*) (in /usr/lib/valgrind/vgpreload_memcheck-amd64-linux.so)
==22137==    by 0x109FC1: __gnu_cxx::new_allocator<std::pair<double, Var const*> >::deallocate(std::pair<double, Var const*>*, unsigned long) (in /home/median/so/deriv)
==22137==    by 0x109CDD: std::allocator_traits<std::allocator<std::pair<double, Var const*> > >::deallocate(std::allocator<std::pair<double, Var const*> >&, std::pair<double, Var const*>*, unsigned long) (in /home/median/so/deriv)
==22137==    by 0x109963: std::_Vector_base<std::pair<double, Var const*>, std::allocator<std::pair<double, Var const*> > >::_M_deallocate(std::pair<double, Var const*>*, unsigned long) (in /home/median/so/deriv)
==22137==    by 0x1097BC: std::_Vector_base<std::pair<double, Var const*>, std::allocator<std::pair<double, Var const*> > >::~_Vector_base() (in /home/median/so/deriv)
==22137==    by 0x1095EA: std::vector<std::pair<double, Var const*>, std::allocator<std::pair<double, Var const*> > >::~vector() (in /home/median/so/deriv)
==22137==    by 0x109161: Var::~Var() (in /home/median/so/deriv)
==22137==    by 0x108D95: main (in /home/median/so/deriv)
==22137==  Block was alloc'd at
==22137==    at 0x4C3017F: operator new(unsigned long) (in /usr/lib/valgrind/vgpreload_memcheck-amd64-linux.so)
==22137==    by 0x10A153: __gnu_cxx::new_allocator<std::pair<double, Var const*> >::allocate(unsigned long, void const*) (in /home/median/so/deriv)
==22137==    by 0x10A060: std::allocator_traits<std::allocator<std::pair<double, Var const*> > >::allocate(std::allocator<std::pair<double, Var const*> >&, unsigned long) (in /home/median/so/deriv)
==22137==    by 0x109F03: std::_Vector_base<std::pair<double, Var const*>, std::allocator<std::pair<double, Var const*> > >::_M_allocate(unsigned long) (in /home/median/so/deriv)
==22137==    by 0x109A8D: void std::vector<std::pair<double, Var const*>, std::allocator<std::pair<double, Var const*> > >::_M_realloc_insert<std::pair<double, Var const*> >(__gnu_cxx::__normal_iterator<std::pair<double, Var const*>*, std::vector<std::pair<double, Var const*>, std::allocator<std::pair<double, Var const*> > > >, std::pair<double, Var const*>&&) (in /home/median/so/deriv)
==22137==    by 0x1098CF: void std::vector<std::pair<double, Var const*>, std::allocator<std::pair<double, Var const*> > >::emplace_back<std::pair<double, Var const*> >(std::pair<double, Var const*>&&) (in /home/median/so/deriv)
==22137==    by 0x10973F: std::vector<std::pair<double, Var const*>, std::allocator<std::pair<double, Var const*> > >::push_back(std::pair<double, Var const*>&&) (in /home/median/so/deriv)
==22137==    by 0x109520: operator*(Var const&, Var const&) (in /home/median/so/deriv)
==22137==    by 0x108D6F: main (in /home/median/so/deriv)

Another way to catch it can be to add logging in a destructor so you know when the object addresses mentioned in your logging are no longer valid.


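I also thought the temporary x*y had already been destroyed, but then I wondered why its address still points to something. I can actually read its children and loop over them... If the object were really destroyed after the computation, shouldn't I get a segfault when I try to access it? - Omar Aflak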
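C++ doesn't try to track every pointer to an object and null them out when the object's destructor runs; that would make programs very inefficient. Segfaults tend to happen when a pointer addresses memory that isn't accessible, typically because the pointer is null or was never a valid pointer. If you have a pointer to memory where an object used to exist, the issue is usually that the memory has been deallocated (which is undefined behaviour and generally shows up as wrong values being read), not that it has become inaccessible in a way that would cause a segfault. - Tony Delroy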
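It's a long read, but judging by its votes, many people have found this SO answer helpful for understanding pointer lifetimes. - Tony Delroy
I see, that makes more sense now. - Omar Aflak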

Page content provided by Stack Overflow.