Lua中逐个元素比较

3
我正在尝试使用标准的<运算符在Lua中逐个比较元素。例如,这是我想要做的事情:
a = {5, 7, 10}
b = {6, 4, 15}
c = a < b -- should return {true, false, true}

我已经有关于加法(和减法、乘法等)的可用代码了。我的问题是Lua强制要求比较的结果为布尔类型。但我不想要一个布尔类型的结果,而是想要一个表格作为比较的结果。

以下是我的代码,加法能够正常运行,但小于号比较不能正常工作:

m = {}
m['__add'] = function (a, b)
    -- Add two tables together
    -- Works fine
    c = {}
    for i = 1, #a do
        c[i] = a[i] + b[i]
    end
    return c
end
m['__lt'] = function (a, b)
    -- Should do a less-than operator on each element
    -- Doesn't work, Lua forces result to boolean
    c = {}
    for i = 1, #a do
        c[i] = a[i] < b[i]
    end
    return c
end


a = {5, 7, 10}
b = {6, 4, 15}

setmetatable(a, m)

c = a + b -- Expecting {11, 11, 25}
print(c[1], c[2], c[3]) -- Works great!

c = a < b -- Expecting {true, false, true}
print(c[1], c[2], c[3]) -- Error, lua makes c into boolean

Lua编程手册指出,__lt元方法调用结果总是转换为布尔值。我的问题是,我该如何解决这个问题?我听说Lua在DSL方面很擅长,我真的需要语法在这里起作用。我认为使用MetaLua应该是可行的,但我不太确定从哪里开始。
一位同事建议我改用<<__shl元方法。我试了一下,它可以工作,但我真的想使用小于号<,而不是使用错误符号的hack。
谢谢!

2
为什么你需要它作为一个操作符?难道你不能使用一个普通函数吗? - Piglet
1
@Piglet 我正在制作一个DSL。我希望能够像 a*(b+5)/c<d 这样做,其中 abcd 都是向量(数组)。将其重写为前缀函数调用非常冗长:less_than(divide(multiply(a, add(b, 5),c),d)。我不能要求我的用户编写这样的语句。 - jenny
4个回答

4

您只有两个选择可以让您的语法正常工作:

选项1:修补Lua核心。

这可能会非常困难,而且将来会是一个维护噩梦。最大的问题是Lua在非常低的级别上假定比较运算符 <>==~= 返回布尔值。

Lua生成的字节码实际上对任何比较都进行跳转。例如,像 c = 4 < 5 这样的东西被编译成字节码,看起来更像是 if (4 < 5) then c = true else c = false end

您可以使用 luac -l file.lua 查看字节码的样子。如果您将 c=4<5 的字节码与 c=4+5 的字节码进行比较,您就会明白我的意思了。加法代码更短更简单。Lua假设您将使用比较进行分支,而不是赋值。

选项2:解析您的代码,修改它,然后运行它

这是我认为您应该做的。这将非常困难,但您可以使用类似于 LuaMinify 的工具使大部分工作更加容易。

首先,编写一个用于比较任何内容的函数。这里的想法是如果比较对象是一个表,则使用您的特殊比较,否则使用 < 进行比较。

my_less = function(a, b)
   if (type(a) == 'table') then
     c = {}
     for i = 1, #a do
       c[i] = a[i] < b[i]
     end
     return c
    else
      return a < b
    end
end

现在我们只需要将所有小于操作符a<b替换为my_less(a,b)。让我们使用来自LuaMinify的解析器。我们将使用以下代码调用它:
local parse = require('ParseLua').ParseLua
local ident = require('FormatIdentity')

local code = "c=a*b<c+d"
local ret, ast = parse(code)
local _, f = ident(ast)
print(f)

这样做的作用仅仅只是将代码解析成语法树,然后再输出。我们将修改FormatIdentity.lua文件,使其进行替换操作。请使用以下代码替换第138行附近的部分:

    elseif expr.AstType == 'BinopExpr' then --line 138
        if (expr.Op == '<') then
            tok_it = tok_it + 1
            out:appendStr('my_less(')
            formatExpr(expr.Lhs)
            out:appendStr(',')
            formatExpr(expr.Rhs)
            out:appendStr(')')
        else
            formatExpr(expr.Lhs)
            appendStr( expr.Op )
            formatExpr(expr.Rhs)
        end

就是这样。它会将像 c=a*b<c+d 这样的内容替换为 my_less(a*b,c+d)。只需在运行时将所有代码推送即可。


谢谢!我会尝试运行你的代码,看看它的效果如何。我喜欢这个想法,因为它让我得到了我想要的语法,这非常重要。 - jenny

3
Lua中的比较运算会返回一个布尔值。 除非更改Lua的核心,否则您无法对此进行任何操作。

谢谢,我在手册里看到了。我猜我希望我能预处理代码或者做些什么。我可以用MetaLua做到吗? - jenny
有关如何开始修补Lua的建议。看起来比较运算符在内部不返回任何内容,它们都被解释为跳转。即使是a = 1 < 2的情况下,Lua也使用条件跳转将truefalse分配给a。 - jenny
@jenny,没错。这不是一件简单的事情。 - lhf

1

你能忍受有点冗长的 v() 表示法吗:
v(a < b) 而不是 a < b

local vec_mt = {}

local operations = {
   copy     = function (a, b) return a     end,
   lt       = function (a, b) return a < b end,
   add      = function (a, b) return a + b end,
   tostring = tostring,
}

local function create_vector_instance(operand1, operation, operand2)
   local func, vec = operations[operation], {}
   for k, elem1 in ipairs(operand1) do
      local elem2 = operand2 and operand2[k]
      vec[k] = func(elem1, elem2)
   end
   return setmetatable(vec, vec_mt)
end

local saved_result

function v(...)  -- constructor for class "vector"
   local result = ...
   local tp = type(result)
   if tp == 'boolean' and saved_result then
      result, saved_result = saved_result
   elseif tp ~= 'table' then
      result = create_vector_instance({...}, 'copy')
   end
   return result
end

function vec_mt.__add(v1, v2)
   return create_vector_instance(v1, 'add', v2)
end

function vec_mt.__lt(v1, v2)
   saved_result = create_vector_instance(v1, 'lt', v2)
end

function vec_mt.__tostring(vec)
   return 
      'Vector ('
      ..table.concat(create_vector_instance(vec, 'tostring'), ', ')
      ..')'
end

用法:

a = v(5, 7, 10); print(a)
b = v(6, 4, 15); print(b)

c =   a + b ; print(c)  -- result is v(11, 11, 25)
c = v(a + b); print(c)  -- result is v(11, 11, 25)
c = v(a < b); print(c)  -- result is v(true, false, true)

我只想说你的想法很棒。这不是我能想出来的东西。非常有创意!但我不确定我能够使用它。我正在制作一种DSL,而且我认为这种语法会让人们感到困惑,例如有人会尝试: c = v((a < b) == false),结果就会失败。 - jenny
在你的DSL中,(a < b) == false是什么意思?你正在比较一个向量和一个布尔值。 - Egor Skriptunoff
它将对每个元素进行逐个比较,例如{false, true, false} == false将返回{true, false, true}。但是,我没有问题进行计算,只需要帮助使语法可接受。 - jenny
要实现自己的相等性,例如 vector == boolean,您将不得不定义特殊值(对象)TRUEFALSE,因为无法使用__eq元方法来比较两种不同类型。您可以使用方括号语法进行操作:v[a < b]v[a + b]。然后错误地编写v[[a < b] == FALSE]将生成语法错误。您还将能够检测到像这样的错误:(a < b) == FALSE - Egor Skriptunoff

0

正如其他人已经提到的,这个问题没有直接的解决方案。然而,通过使用类似于Python的通用zip()函数,例如下面所示的函数,您可以简化问题,如下所示:

--------------------------------------------------------------------------------
-- Python-like zip() iterator
--------------------------------------------------------------------------------

function zip(...)
  local arrays, ans = {...}, {}
  local index = 0
  return
    function()
      index = index + 1
      for i,t in ipairs(arrays) do
        if type(t) == 'function' then ans[i] = t() else ans[i] = t[index] end
        if ans[i] == nil then return end
      end
      return table.unpack(ans)
    end
end

--------------------------------------------------------------------------------

a = {5, 7, 10}
b = {6, 4, 15}
c = {}

for a,b in zip(a,b) do
  c[#c+1] = a < b -- should return {true, false, true}
end

-- display answer
for _,v in ipairs(c) do print(v) end

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接