为什么我在这里收到自动广播警告?

3
function [theta, J_history] = gradientDescent(X, y, theta, alpha, num_iters)
  m = length(y);
  J_history = zeros(num_iters, 1);

  for iter = 1:num_iters
    ## warning: product: automatic broadcasting operation applied
    theta = theta - sum(X .* (X * theta - y))' .* (alpha / (m .* 2));
    J_history(iter) = computeCost(X, y, theta);
  end
end

这是我的作业,但我并不要求你替我完成(实际上我认为我已经完成了或接近完成)。我已经阅读了手册中有关广播的部分,但我仍然不明白为什么我在这里收到了警告?

注:broadcasting指的是在numpy数组上执行算术运算时的一种特殊机制。


2
请删除这个问题。或者修改您的帖子,以免在此处透露答案。我知道你的出发点。 - Tyagi Akhilesh
同意。我七年后上了同样的课程,然后找到了答案。我建议删除这个问题。 - Daniel Wabyick
2个回答

7
问题在于size(theta')1 2,而size(X)m 2。当你将它们相乘时,Octave会先将X(1,1)theta'(1,1)相乘,再将X(1,2)theta'(1,2)相乘。然后,它转到X的第二行并尝试将X(2,1)theta'(2,1)相乘。但是theta'没有第二行,因此该操作无意义。
为了避免出现错误,Octave猜测您希望扩展theta',使其具有与X相同的行数,然后再开始乘法运算。然而,由于它只是猜测,所以感觉应该警告您正在做什么。
您可以通过在开始乘法之前使用repmat函数显式地扩展theta的长度来避免警告。
repmat(theta',m,1) .* X

2
由于警告显示广播来自产品操作,因此它将来自有问题的行中的任何.*。如果不知道您给函数的输入值,我无法说出哪一个,但是假设:
  1. X是一个向量;
  2. alpha是一个标量;
  3. theta是一个标量。
我的猜测是警告来自X .* (X * theta - y))',特别是因为您正在转置第二部分。尝试删除转置运算符(如果存在其他错误,则可能会导致错误--我假设您不想执行广播)。

哦,我知道了,抱歉,我应该更具体一些,它是第一个.*,是一个97x2矩阵乘以一个97x1矩阵。后来我找到了一种避免这种情况的方法,即通过执行X' * (X * theta - y),但我仍然想知道为什么会出现警告以及其后果。也就是说,我不知道是否需要广播,因为我不知道这样做的影响是什么(对我来说,它看起来像是我需要或者想要做的事情,但也许在计算上效率低下之类的?) - user797257
@wvxvw 广播是一种非常高效的方式,可以解决大多数人使用for循环解决的问题(当矩阵的大小允许时,广播会自动执行,并且与调用函数bsxfun()相同。这是一个新功能,令许多人感到惊讶,因此需要警告)。由于您不知道它是什么,所以很可能您不需要它。要了解广播是什么,请运行[1:5].*[1:5]'进行简单示例,并阅读有关该主题的NumPy文档 - carandraug
嗯,这并不让我惊讶,因为这种语言的任何其他特性都不会让我惊讶。尽管我找到了一种重写代码以避免警告的方法,但我的上述代码确实做到了它必须做的事情/我想要它做的事情。我只是从一个角度来看,警告通常意味着它实际上是一个错误,但编译器/调试器未能正确地识别它,所以现在它正在警告一些不应在代码的正确版本中发生的事情。这很奇怪,因为代码似乎没有崩溃或类似的问题 :) - user797257

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接