Module:Grind/fml
From Fallen London Wiki (Staging)
Documentation for this module may be created at Module:Grind/fml/doc
local p = {}
local util = require('Module:Grind/util')
-- == Formula manipulation utilities ==
-- Parses the next formula token of `s`, starting the scan at `initial_pos`.
-- Returns three values: the token data, the token type and the position at
-- which scanning should continue. Token types produced here:
--   'end'  - nothing is left after skipping whitespace (data is nil);
--   'tok'  - one of the characters ( ) , + - * /;
--   'num'  - a number literal, already converted with tonumber();
--   'var'  - a `$(name)` reference; data is the text between the brackets;
--   'func' - a name: a letter followed by letters, digits, `_` or `.`;
--   'err'  - data is the error message.
-- NOTE(review): this assumes util.advance(s, pos) yields the next position
-- together with the character found there, and that util.find_match(s, pos)
-- yields the position of the matching bracket or a value past #s - confirm
-- against Module:Grind/util.
local function next_token(s, initial_pos)
    local pos = util.skip_whitespaces(s, initial_pos)
    if pos > #s then
        return nil, 'end', pos
    end
    local c = s:sub(pos, pos)

    -- Appends the current character to `text` and steps to the next one.
    local function take(text)
        text = text .. c
        pos, c = util.advance(s, pos)
        return text
    end

    -- Keeps taking characters for as long as the current one matches `pattern`.
    local function take_while(text, pattern)
        while c:match(pattern) do
            text = take(text)
        end
        return text
    end

    local punctuation = {
        ['('] = true, [')'] = true, [','] = true,
        ['+'] = true, ['-'] = true, ['*'] = true, ['/'] = true,
    }
    if punctuation[c] then
        return c, 'tok', pos + 1
    end

    -- a number with an integer part, optionally followed by a fraction
    if c:match('%d') then
        local digits = take_while('', '%d')
        if c == '.' then
            digits = take(digits)
            digits = take_while(digits, '%d')
        end
        return tonumber(digits), 'num', pos
    end

    -- a number that starts directly with the decimal point
    if c == '.' then
        local digits = take('')
        if not c:match('%d') then
            return 'invalid float', 'err', pos
        end
        digits = take_while(digits, '%d')
        return tonumber(digits), 'num', pos
    end

    -- a variable: `$(` ... `)`
    if c == '$' then
        pos, c = util.advance(s, pos)
        if c ~= '(' then
            return 'invalid variable', 'err', pos
        end
        local closing = util.find_match(s, pos)
        if closing > #s then
            return 'unmatched `(` in variable', 'err', pos
        end
        return s:sub(pos + 1, closing - 1), 'var', closing + 1
    end

    -- a function name
    if c:match('%a') then
        local name = take('')
        name = take_while(name, '[%w_%.]')
        return name, 'func', pos
    end

    return 'unexpected token: `' .. c .. '`', 'err', pos + 1
end
-- Splits the formula string `s` into a list of `{data, type}` tokens by
-- calling next_token() repeatedly.
-- The 'end' marker that next_token() reports when only whitespace is left is
-- NOT stored: the later parsing layers do not know that token type, so keeping
-- it turned formulas with trailing whitespace into "multiple trees" errors.
-- Scanning also stops if next_token() fails to move forward, which guards
-- against an endless loop.
local function collect_tokens(s)
    local tokens = {}
    local pos = 1
    while pos <= #s do
        local tok_data, tok_type, new_pos = next_token(s, pos)
        if tok_type == 'end' then
            -- only whitespace was left; there is no further token to record
            break
        end
        table.insert(tokens, {tok_data, tok_type})
        if new_pos <= pos then
            break
        end
        pos = new_pos
    end
    return tokens
end
-- Finds the `)` token that closes the bracket opened at `initial_pos`.
-- Returns the index of that token; when the bracket is never closed the
-- returned index lies past the end of `tokens`.
local function find_match(tokens, initial_pos)
    local nesting = 1
    for index = initial_pos + 1, #tokens do
        local data, kind = unpack(tokens[index])
        if kind == 'tok' then
            if data == '(' then
                nesting = nesting + 1
            elseif data == ')' then
                nesting = nesting - 1
                if nesting == 0 then
                    return index
                end
            end
        end
    end
    -- no match: report the first position that was never inspected
    return math.max(initial_pos + 1, #tokens + 1)
end
-- Splits the tokens found between the brackets of a function call into one
-- token list per argument, cutting at top-level `,` tokens.
-- Bracketed groups are copied as a whole so commas inside them do not split.
-- An empty token list yields no arguments at all; an unmatched `(` or a stray
-- `)` is replaced by an 'err' token inside the current argument.
local function split_args(tokens)
    local result = {}
    local current = {}
    local count = #tokens
    local index = 1
    while index <= count do
        local token = tokens[index]
        local data, kind = unpack(token)
        local is_punct = (kind == 'tok')
        if is_punct and data == ',' then
            -- argument separator: close the current argument
            table.insert(result, current)
            current = {}
            index = index + 1
        elseif is_punct and data == '(' then
            local closing = find_match(tokens, index)
            if closing > count then
                table.insert(current, {'unmatched `(`', 'err'})
                index = index + 1
            else
                -- copy the whole group, both brackets included
                for inner = index, closing do
                    table.insert(current, tokens[inner])
                end
                index = closing + 1
            end
        elseif is_punct and data == ')' then
            table.insert(current, {'unexpected `)`', 'err'})
            index = index + 1
        else
            table.insert(current, token)
            index = index + 1
        end
    end
    -- any token at all means that a last (possibly empty) argument is open
    if count > 0 then
        table.insert(result, current)
    end
    return result
end
-- Layer 0: brackets and function calls.
-- Input token types: tok, num, var, func, err.
-- Output token types: tok, num, var, expr, func_call, err.
--
-- * `( ... )` becomes `{tree, 'expr'}` where `tree` is the recursively parsed
--   content of the brackets.
-- * `name( a, b )` becomes `{{name, tree_a, tree_b}, 'func_call'}`.
-- * `err($(message))` is not kept as a call: it turns into an 'err' token
--   whose data is the variable name, i.e. the message.
-- * Stray `)` and `,` tokens, unmatched `(` and malformed calls are replaced
--   by 'err' tokens; every other token is copied unchanged.
local function parse_layer0(tokens)
    local tree = {}
    local pos = 1
    while pos <= #tokens do
        local tok_data, tok_type = unpack(tokens[pos])
        if tok_type == 'tok' and tok_data == '(' then
            -- process ()
            local end_pos = find_match(tokens, pos)
            if end_pos > #tokens then
                table.insert(tree, {'unmatched `(`', 'err'})
            else
                local inner_tokens = {}
                pos = pos + 1
                while pos < end_pos do
                    table.insert(inner_tokens, tokens[pos])
                    pos = pos + 1
                end
                inner_tokens = parse_layer0(inner_tokens)
                table.insert(tree, {inner_tokens, 'expr'})
            end
            pos = end_pos + 1
        elseif tok_type == 'tok' and tok_data == ')' then
            -- the message used to name `(` although the stray token is `)`
            table.insert(tree, {'unexpected `)`', 'err'})
            pos = pos + 1
        elseif tok_type == 'tok' and tok_data == ',' then
            -- commas are only valid inside a call, where split_args() eats them
            table.insert(tree, {'unexpected `,`', 'err'})
            pos = pos + 1
        elseif tok_type == 'func' then
            -- process f(a,b,c)
            local func_name = tok_data
            pos = pos + 1
            if pos > #tokens then
                table.insert(tree, {'unexpected function name at the end of an expression: `' .. func_name .. '`', 'err'})
                break
            end
            tok_data, tok_type = unpack(tokens[pos])
            if tok_type ~= 'tok' or tok_data ~= '(' then
                table.insert(tree, {'invalid function call, `(` expected', 'err'})
                pos = pos + 1
            else
                local end_pos = find_match(tokens, pos)
                if end_pos > #tokens then
                    table.insert(tree, {'unmatched `(`', 'err'})
                else
                    local inner_tokens = {}
                    pos = pos + 1
                    while pos < end_pos do
                        table.insert(inner_tokens, tokens[pos])
                        pos = pos + 1
                    end
                    local func_args = split_args(inner_tokens)
                    local func_data = {func_name}
                    for _, arg in ipairs(func_args) do
                        arg = parse_layer0(arg)
                        table.insert(func_data, arg)
                    end
                    if func_name == 'err' then
                        -- err($(message)): resolve into a real error token
                        if #func_data == 2 then
                            local arg = func_data[2]
                            local arg_data, arg_type = unpack(arg[1])
                            if arg_type == 'var' then
                                table.insert(tree, {arg_data, 'err'})
                            else
                                table.insert(tree, {'err: var expected, ' .. tostring(arg_type) .. ' provided', 'err'})
                            end
                        else
                            table.insert(tree, {'err: 1 argument expected, ' .. (#func_data - 1) .. ' provided', 'err'})
                        end
                    else
                        table.insert(tree, {func_data, 'func_call'})
                    end
                end
                pos = end_pos + 1
            end
        else
            table.insert(tree, tokens[pos])
            pos = pos + 1
        end
    end
    return tree
end
-- Layer 1: unary operations `+`, `-`.
-- Input token types: tok, num, var, expr, func_call, err.
-- Output token types: tok, num, var, expr, unary, func_call, err.
--
-- A sign is treated as unary when it is the very first token or when it is
-- the second operator in a row (i.e. it directly follows a binary operator).
-- It can be applied to `var`, `num`, `expr` and `func_call`:
-- * applied to a number it is folded into the number itself;
-- * otherwise the result is `{{op, operand}, 'unary'}`.
-- Nested `expr` trees and function call arguments are processed recursively.
local function parse_layer1(tokens)
    local tree = {}
    if #tokens == 0 then
        return tree
    end

    -- Runs this layer over every argument of a function call.
    local function rebuild_call(call)
        local rebuilt = {call[1]}
        for i = 2, #call do
            table.insert(rebuilt, parse_layer1(call[i]))
        end
        return {rebuilt, 'func_call'}
    end

    -- Applies the sign `op` to `operand`.
    -- Returns the new node, or nil plus the operand type if it cannot be signed.
    local function apply_unary(op, operand)
        local data, kind = unpack(operand)
        if kind == 'var' then
            return {{op, operand}, 'unary'}
        elseif kind == 'num' then
            if op == '-' then
                return {-data, 'num'}
            end
            return {data, 'num'}
        elseif kind == 'expr' then
            return {{op, {parse_layer1(data), 'expr'}}, 'unary'}
        elseif kind == 'func_call' then
            return {{op, rebuild_call(data)}, 'unary'}
        end
        return nil, kind
    end

    local pos = 1

    -- a sign at the very start of the expression
    local head_data, head_kind = tokens[1][1], tokens[1][2]
    if head_kind == 'tok' and head_data:match('[+-]') == head_data then
        local op = head_data
        if #tokens < 2 then
            table.insert(tree, {'`' .. op .. '` is not an expression', 'err'})
            pos = 2
        else
            local node, bad_kind = apply_unary(op, tokens[2])
            if node then
                table.insert(tree, node)
                pos = 3
            else
                table.insert(tree, {'unary `' .. op .. '` cannot be applied to ' .. bad_kind, 'err'})
                pos = 2
            end
        end
    end

    local streak = 0 -- operators seen in a row
    while pos <= #tokens do
        local token = tokens[pos]
        local data, kind = unpack(token)
        local is_operator = kind == 'tok' and data:match('[+*/-]') == data
        if not is_operator then
            streak = 0
            if kind == 'expr' then
                table.insert(tree, {parse_layer1(data), 'expr'})
            elseif kind == 'func_call' then
                table.insert(tree, rebuild_call(data))
            else
                table.insert(tree, token)
            end
            pos = pos + 1
        else
            streak = streak + 1
            if streak ~= 2 then
                -- the first operator in a row is binary: keep it for later layers
                table.insert(tree, token)
                pos = pos + 1
            elseif not data:match('[+-]') then
                -- `*` and `/` can never be unary
                table.insert(tree, {'unexpected `' .. data .. '`', 'err'})
                streak = 0
                pos = pos + 1
            elseif pos + 1 > #tokens then
                table.insert(tree, {'unexpected `' .. data .. '` at the end of an expression', 'err'})
                streak = 0
                pos = pos + 1
            else
                local node, bad_kind = apply_unary(data, tokens[pos + 1])
                if node then
                    table.insert(tree, node)
                    pos = pos + 2
                else
                    table.insert(tree, {'unary `' .. data .. '` cannot be applied to ' .. bad_kind, 'err'})
                    pos = pos + 1
                end
                streak = 0
            end
        end
    end
    return tree
end
-- Layer 2: binary operations `*`,`/`.
-- Input token types: tok, num, var, expr, unary, func_call, err.
-- Output token types: tok, num, var, expr, unary, binary, func_call, err.
--
-- Accepts either a list of tokens or a single token (recognised by its second
-- element being a string), which is treated as a one-element list.
-- Each `a * b` / `a / b` becomes `{{op, a, b}, 'binary'}`; chains are grouped
-- from the left because the previous result is taken back from the tree.
-- Nested `expr`, `unary` and `func_call` nodes are processed recursively.
local function parse_layer2(tokens)
    local tree = {}
    if #tokens == 0 then
        return tree
    end
    local items = tokens
    if type(items[2]) == 'string' then
        items = {items}
    end

    -- Runs this layer over the content of a nested node; other tokens pass through.
    local function descend(token)
        local data, kind = unpack(token)
        if kind == 'expr' then
            return {parse_layer2(data), 'expr'}
        elseif kind == 'unary' then
            return {{data[1], parse_layer2(data[2])}, 'unary'}
        elseif kind == 'func_call' then
            local rebuilt = {data[1]}
            for i = 2, #data do
                table.insert(rebuilt, parse_layer2(data[i]))
            end
            return {rebuilt, 'func_call'}
        end
        return token
    end

    -- node types that may appear on either side of the operator
    local left_ok = {num = true, var = true, expr = true, unary = true, binary = true, func_call = true}
    local right_ok = {num = true, var = true, expr = true, unary = true, func_call = true}

    local index = 1
    while index <= #items do
        local data, kind = unpack(items[index])
        if kind == 'tok' and data:match('[*/]') == data then
            local op = data
            local step = 1
            if #tree == 0 then
                table.insert(tree, {'unexpected `' .. op .. '` at the start of an expression', 'err'})
            elseif index == #items then
                table.insert(tree, {'unexpected `' .. op .. '` at the end of an expression', 'err'})
            else
                local _, left_kind = unpack(tree[#tree])
                if not left_ok[left_kind] then
                    table.insert(tree, {'`' .. op .. '` cannot use ' .. left_kind .. ' as the first argument', 'err'})
                else
                    local _, right_kind = unpack(items[index + 1])
                    if right_ok[right_kind] then
                        -- fold the previous node and the next token into one node
                        local left = table.remove(tree)
                        local right = descend(items[index + 1])
                        table.insert(tree, {{op, left, right}, 'binary'})
                        step = 2
                    else
                        table.insert(tree, {'`' .. op .. '` cannot use ' .. right_kind .. ' as the second argument', 'err'})
                    end
                end
            end
            index = index + step
        else
            table.insert(tree, descend(items[index]))
            index = index + 1
        end
    end
    return tree
end
-- Layer 3: binary operations `+`,`-`.
-- Input token types: tok, num, var, expr, unary, binary, func_call, err.
-- Output token types: num, var, expr, unary, binary, func_call, err.
--
-- Accepts either a list of tokens or a single token (recognised by its second
-- element being a string), which is treated as a one-element list.
-- Each `a + b` / `a - b` becomes `{{op, a, b}, 'binary'}`; chains are grouped
-- from the left because the previous result is taken back from the tree.
-- Nested `expr`, `unary`, `binary` and `func_call` nodes are processed
-- recursively.
local function parse_layer3(tokens)
    local tree = {}
    if #tokens == 0 then
        return tree
    end
    local items = tokens
    if type(items[2]) == 'string' then
        items = {items}
    end

    -- Runs this layer over the content of a nested node; other tokens pass through.
    local function descend(token)
        local data, kind = unpack(token)
        if kind == 'expr' then
            return {parse_layer3(data), 'expr'}
        elseif kind == 'unary' then
            return {{data[1], parse_layer3(data[2])}, 'unary'}
        elseif kind == 'binary' then
            local inner_op, lhs, rhs = unpack(data)
            lhs = parse_layer3(lhs)
            rhs = parse_layer3(rhs)
            return {{inner_op, lhs, rhs}, 'binary'}
        elseif kind == 'func_call' then
            local rebuilt = {data[1]}
            for i = 2, #data do
                table.insert(rebuilt, parse_layer3(data[i]))
            end
            return {rebuilt, 'func_call'}
        end
        return token
    end

    -- node types that may appear on either side of the operator
    local operand_ok = {num = true, var = true, expr = true, unary = true, binary = true, func_call = true}

    local index = 1
    while index <= #items do
        local data, kind = unpack(items[index])
        if kind == 'tok' and data:match('[+-]') == data then
            local op = data
            local step = 1
            if #tree == 0 then
                table.insert(tree, {'unexpected `' .. op .. '` at the start of an expression', 'err'})
            elseif index == #items then
                table.insert(tree, {'unexpected `' .. op .. '` at the end of an expression', 'err'})
            else
                local _, left_kind = unpack(tree[#tree])
                if not operand_ok[left_kind] then
                    table.insert(tree, {'`' .. op .. '` cannot use ' .. left_kind .. ' as the first argument', 'err'})
                else
                    local _, right_kind = unpack(items[index + 1])
                    if operand_ok[right_kind] then
                        -- fold the previous node and the next token into one node
                        local left = table.remove(tree)
                        local right = descend(items[index + 1])
                        table.insert(tree, {{op, left, right}, 'binary'})
                        step = 2
                    else
                        table.insert(tree, {'`' .. op .. '` cannot use ' .. right_kind .. ' as the second argument', 'err'})
                    end
                end
            end
            index = index + step
        else
            table.insert(tree, descend(items[index]))
            index = index + 1
        end
    end
    return tree
end
-- Removes expr, validates other nodes.
-- Input token types: num, var, expr, unary, binary, func_call, err.
-- Output token types: num, var, unary, binary, func_call, err.
--
-- Accepts either a list holding exactly one token or a single token
-- (recognised by its second element being a string). Always returns a single
-- token: the validated node, or an 'err' token describing what is wrong.
-- Operands and function arguments are post-processed recursively.
local function parse_postprocess(tokens)
    local list = tokens
    if type(list[2]) == 'string' then
        list = {list}
    end
    local count = #list
    if count > 1 then
        return {'multiple trees in expression', 'err'}
    end
    if count == 0 then
        return {'empty expression', 'err'}
    end

    local node = list[1]
    local data, kind = unpack(node)

    if kind == 'num' then
        if type(data) == 'number' then
            return node
        end
        return {'a number is no number', 'err'}
    end

    if kind == 'var' then
        if type(data) == 'string' then
            return node
        end
        return {'a variable is no variable', 'err'}
    end

    if kind == 'expr' then
        if type(data) ~= 'table' then
            return {'an expression is no expression', 'err'}
        end
        -- brackets carry no meaning of their own any more: unwrap them
        return parse_postprocess(data)
    end

    if kind == 'unary' then
        if type(data) ~= 'table' then
            return {'an unary operation is no operation', 'err'}
        end
        if #data ~= 2 then
            return {'an unary operation is not unary', 'err'}
        end
        local op, operand = data[1], data[2]
        if op:match('[+-]') ~= op then
            return {'unknown unary operation `' .. op .. '`', 'err'}
        end
        return {{op, parse_postprocess(operand)}, 'unary'}
    end

    if kind == 'binary' then
        if type(data) ~= 'table' then
            return {'a binary operation is no operation', 'err'}
        end
        if #data ~= 3 then
            return {'a binary operation is not binary', 'err'}
        end
        local op, lhs, rhs = data[1], data[2], data[3]
        if op:match('[+*/-]') ~= op then
            return {'unknown binary operation `' .. op .. '`', 'err'}
        end
        lhs = parse_postprocess(lhs)
        rhs = parse_postprocess(rhs)
        return {{op, lhs, rhs}, 'binary'}
    end

    if kind == 'func_call' then
        if type(data) ~= 'table' then
            return {'a function call is no function call', 'err'}
        end
        if #data == 0 then
            return {'no function name is provided', 'err'}
        end
        local name = data[1]
        if type(name) ~= 'string' then
            return {'a function name is no function name', 'err'}
        end
        local call = {name}
        for i = 2, #data do
            table.insert(call, parse_postprocess(data[i]))
        end
        return {call, 'func_call'}
    end

    return {'unknown token of type ' .. tostring(kind), 'err'}
end
-- Finds and returns the first error encountered, or nil if there is none.
-- `tree` is either a list of tokens or a single token (recognised by its
-- second element being a string). The search descends into `expr` trees,
-- function call arguments and the operands of unary and binary nodes.
function p.find_error(tree)
    local nodes = tree
    if type(nodes[2]) == 'string' then
        nodes = {nodes}
    end
    for index = 1, #nodes do
        local data, kind = unpack(nodes[index])
        if kind == 'err' then
            return data
        end
        local found
        if kind == 'expr' then
            found = p.find_error(data)
        elseif kind == 'func_call' then
            -- data[1] is the function name, the arguments follow it
            for arg_index = 2, #data do
                found = p.find_error(data[arg_index])
                if found then
                    break
                end
            end
        elseif kind == 'unary' then
            found = p.find_error({data[2]})
        elseif kind == 'binary' then
            found = p.find_error({data[2]}) or p.find_error({data[3]})
        end
        if found then
            return found
        end
    end
    return nil
end
-- Parses the formula.
-- Returns its tree and an error string or nil.
-- The string is tokenised and then passed through the parsing layers in
-- order; as soon as one stage leaves an error behind, the result of that
-- stage is returned together with the error message.
function p.parse(s)
    local stages = {
        parse_layer0,
        parse_layer1,
        parse_layer2,
        parse_layer3,
        parse_postprocess,
    }
    local tree = collect_tokens(s)
    local err = p.find_error(tree)
    for _, stage in ipairs(stages) do
        if err then
            break
        end
        tree = stage(tree)
        err = p.find_error(tree)
    end
    if err then
        return tree, err
    end
    return tree, nil
end
-- Returns a table of variables used, as a set: `{[name] = true, ...}`.
-- `tree` is a single post-processed token; anything that is not a
-- two-element table yields an empty set.
-- Note that the `err($(message))` pattern requires no special handling:
-- it is resolved into a proper error in parse_layer0().
-- Not that you should ever call this function on a formula with errors.
function p.variables(tree)
    local vars = {}
    if type(tree) ~= 'table' or #tree ~= 2 then
        return vars
    end
    local tok_data, tok_type = unpack(tree)
    if tok_type == 'var' then
        vars[tok_data] = true
    elseif tok_type == 'unary' then
        return p.variables(tok_data[2])
    elseif tok_type == 'binary' then
        for v in pairs(p.variables(tok_data[2])) do
            vars[v] = true
        end
        for v in pairs(p.variables(tok_data[3])) do
            vars[v] = true
        end
    elseif tok_type == 'func_call' then
        -- tok_data[1] is the function name, the arguments follow it
        for i = 2, #tok_data do
            for v in pairs(p.variables(tok_data[i])) do
                -- store `true` like every other branch (this used to store the name)
                vars[v] = true
            end
        end
    end
    return vars
end
-- Builds an iterator over every combination of outcomes of the given
-- distributions. `args` is a list of `{[value] = probability}` tables.
-- Each step of the iterator returns a list holding one value per
-- distribution together with the product of their probabilities; it returns
-- nil once all combinations have been produced.
local function iterate_arg_dist(args)
    local n = #args
    local outcomes = {} -- per argument: list of {value, probability} pairs
    local cursor = {}   -- per argument: index of the pair currently selected
    for i = 1, n do
        local pairs_list = {}
        for value, prob in pairs(args[i]) do
            table.insert(pairs_list, {value, prob})
        end
        outcomes[i] = pairs_list
        cursor[i] = 1
    end

    local finished = false

    -- Moves to the next combination like an odometer: the last argument
    -- changes fastest and an overflow carries over to the argument before it.
    local function advance()
        local slot = n
        cursor[slot] = cursor[slot] + 1
        while cursor[slot] > #(outcomes[slot]) do
            cursor[slot] = 1
            if slot == 1 then
                finished = true
                break
            end
            slot = slot - 1
            cursor[slot] = cursor[slot] + 1
        end
    end

    return function()
        if finished then
            return nil
        end
        local values = {}
        local prob = 1
        for i = 1, n do
            local selected = outcomes[i][cursor[i]]
            table.insert(values, selected[1])
            prob = prob * selected[2]
        end
        advance()
        return values, prob
    end
end
-- Evaluates the built-in function `func` on `args`, a list of plain numbers.
-- Returns the numeric result, or nil plus an error message when the function
-- is unknown, receives the wrong number of arguments or cannot be computed
-- for the given values.
local function eval_func(func, args)
    -- number of arguments every known function takes
    local arg_n = {
        min = 2, max = 2,
        exp = 1, ln = 1, pow = 2, sqrt = 1,
        sign = 1, abs = 1,
        round = 1, floor = 1, ceil = 1,
        sin = 1, cos = 1, tan = 1,
        pi = 0
    }
    local n = arg_n[func]
    if n == nil then
        return nil, 'unknown function ' .. tostring(func)
    end
    if n ~= #args then
        return nil, func .. ': ' .. n .. ' args expected, ' .. #args .. ' provided'
    end
    if func == 'min' then
        return math.min(args[1], args[2])
    elseif func == 'max' then
        return math.max(args[1], args[2])
    elseif func == 'exp' then
        return math.exp(args[1])
    elseif func == 'ln' then
        if args[1] <= 0 then
            return nil, 'ln: the argument must be positive'
        end
        return math.log(args[1])
    elseif func == 'pow' then
        -- the `^` operator is available in every Lua version, unlike math.pow
        local val = args[1] ^ args[2]
        if val ~= val then
            return nil, 'pow: result is NaN'
        end
        return val
    elseif func == 'sqrt' then
        local val = math.sqrt(args[1])
        if val ~= val then
            return nil, 'sqrt: result is NaN'
        end
        return val
    elseif func == 'sign' then
        if args[1] < 0 then
            return -1
        elseif args[1] > 0 then
            return 1
        else
            return 0
        end
    elseif func == 'abs' then
        return math.abs(args[1])
    elseif func == 'round' then
        -- rounds to the nearest integer, halves are rounded up
        return math.floor((math.floor(args[1] * 2) + 1) / 2)
    elseif func == 'floor' then
        return math.floor(args[1])
    elseif func == 'ceil' then
        -- this used to call the non-existent `math.cail` and always crashed
        return math.ceil(args[1])
    elseif func == 'sin' then
        return math.sin(args[1])
    elseif func == 'cos' then
        return math.cos(args[1])
    elseif func == 'tan' then
        return math.tan(args[1])
    elseif func == 'pi' then
        return math.pi
    end
    -- only reachable if a name is added to arg_n without a branch above
    return nil, 'FUNCTION ' .. func .. ' IS DECLARED BUT NOT DEFINED'
end
-- Evaluates the formula, returns a distribution or a tree with an error.
-- `tree` is a single post-processed token, `data` maps variable names to
-- their values (a missing variable counts as 0).
-- A distribution is the token `{{[value] = probability, ...}, 'dist'}`.
-- If a part of the formula cannot be evaluated, the node is rebuilt from the
-- evaluated parts instead, so errors stay inside the returned tree.
-- Outcomes that are invalid (NaN, infinite, function errors) abort the
-- evaluation with an 'err' token when their probability is positive and are
-- silently dropped when their probability is zero.
-- NOTE(review): keys starting with `Input:` are passed through
-- util.normalise_input(); what that changes is defined in Module:Grind/util.
function p.eval(tree, data)
    local tok_data, tok_type = unpack(tree)
    if tok_type == 'num' then
        local val = tok_data
        local d = {[val]=1}
        return {d, 'dist'}
    elseif tok_type == 'var' then
        local key = tok_data
        if key:sub(1, 6) == 'Input:' then
            local _
            key, _ = util.normalise_input(key)
        end
        local val = data[key] or 0
        local d = {[val]=1}
        return {d, 'dist'}
    elseif tok_type == 'unary' then
        local op, inner = unpack(tok_data)
        inner = p.eval(inner, data)
        if inner[2] ~= 'dist' then
            return {{op, inner}, 'unary'}
        end
        if op == '+' then
            return inner
        elseif op == '-' then
            local d = {}
            for delta, prob in pairs(inner[1]) do
                d[-delta] = prob
            end
            return {d, 'dist'}
        end
        -- an unknown operator used to fall through and return nothing at all
        return {'unknown unary operation `' .. tostring(op) .. '`', 'err'}
    elseif tok_type == 'binary' then
        local op, a, b = unpack(tok_data)
        a = p.eval(a, data)
        b = p.eval(b, data)
        if a[2] ~= 'dist' or b[2] ~= 'dist' then
            return {{op, a, b}, 'binary'}
        end
        if op ~= '+' and op ~= '-' and op ~= '*' and op ~= '/' then
            return {'unknown binary operation `' .. tostring(op) .. '`', 'err'}
        end
        local d = {}
        for delta_a, prob_a in pairs(a[1]) do
            for delta_b, prob_b in pairs(b[1]) do
                local prob = prob_a * prob_b
                local delta
                if op == '+' then
                    delta = delta_a + delta_b
                elseif op == '-' then
                    delta = delta_a - delta_b
                elseif op == '*' then
                    delta = delta_a * delta_b
                else
                    delta = delta_a / delta_b
                end
                local invalid = nil
                if delta ~= delta then
                    invalid = 'not a number'
                elseif delta == math.huge or delta == -math.huge then
                    -- compare with infinity directly: `delta == delta + 1` is
                    -- also true for every finite value of 2^53 and above
                    invalid = 'division by zero'
                end
                if invalid then
                    if prob > 0 then
                        return {invalid, 'err'}
                    end
                    -- impossible outcome: skip it (NaN is not a valid table key)
                else
                    d[delta] = (d[delta] or 0) + prob
                end
            end
        end
        return {d, 'dist'}
    elseif tok_type == 'func_call' then
        local func = tok_data[1]
        local args = {}
        for i = 2, #tok_data do
            local arg = p.eval(tok_data[i], data)
            table.insert(args, arg)
        end
        if #args == 0 then
            local val, err = eval_func(func, {})
            if err then
                return {err, 'err'}
            else
                local d = {[val]=1}
                return {d, 'dist'}
            end
        else
            -- every argument has to be a distribution, otherwise the call is
            -- returned with its arguments evaluated as far as possible
            local no_eval = false
            for i = 1, #args do
                if args[i][2] ~= 'dist' then
                    no_eval = true
                end
            end
            if no_eval then
                local func_data = {func}
                for i = 1, #args do
                    table.insert(func_data, args[i])
                end
                return {func_data, 'func_call'}
            end
            local args_dist = {}
            for i = 1, #args do
                table.insert(args_dist, args[i][1])
            end
            local d = {}
            for args_item, prob in iterate_arg_dist(args_dist) do
                if func == 'random.range' then
                    -- random.range(a, b): every integer of a..b is equally likely
                    if #args_item ~= 2 then
                        return {'random.range: 2 args expected, ' .. #args_item .. ' provided', 'err'}
                    end
                    local a, b = unpack(args_item)
                    if math.floor(a) ~= a or math.floor(b) ~= b then
                        return {'random.range requires its arguments to be integer', 'err'}
                    end
                    local n = b - a + 1
                    if n <= 0 then
                        return {'random.range(' .. a .. ',' .. b .. ') is invalid: ' .. a .. ' > ' .. b, 'err'}
                    end
                    if n ~= n or n == n + 1 then
                        return {'random.range: infinite arguments are not supported', 'err'}
                    end
                    for delta = a, b do
                        if d[delta] == nil then
                            d[delta] = 0
                        end
                        d[delta] = d[delta] + prob / n
                    end
                else
                    local delta, err = eval_func(func, args_item)
                    if not err and delta ~= delta then
                        -- e.g. sin() of an infinite value; NaN cannot be a table key
                        err = func .. ': result is NaN'
                    end
                    if err then
                        if prob > 0 then
                            return {err, 'err'}
                        end
                        -- impossible outcome: skip it, there is no value to
                        -- store (this used to index the table with nil)
                    else
                        d[delta] = (d[delta] or 0) + prob
                    end
                end
            end
            return {d, 'dist'}
        end
    elseif tok_type == 'err' then
        return tree
    else
        return {'cannot evaluate ' .. tostring(tok_type), 'err'}
    end
end
-- Substitutes the specified values into the formula.
-- Also partially evaluates the formula: operations and function calls whose
-- operands have all become numbers are replaced by their numeric result
-- (`random.range` is never evaluated here).
-- `tree`: a single post-processed token.
-- `data`: maps variable names to values; a value is used only if tonumber()
--         accepts it.
-- `assume_unknown`: if true, unknown variables will be assumed to be zero.
-- Returns the new token; failed calculations become 'err' tokens.
-- NOTE(review): keys starting with `Input:` are passed through
-- util.normalise_input(); what that changes is defined in Module:Grind/util.
function p.substitute(tree, data, assume_unknown)
    assume_unknown = assume_unknown or false
    local tok_data, tok_type = unpack(tree)
    if tok_type == 'num' then
        -- there is nothing to do
    elseif tok_type == 'var' then
        local key = tok_data
        if key:sub(1, 6) == 'Input:' then
            local _
            key, _ = util.normalise_input(key)
        end
        local val = tonumber(data[key])
        if val then
            tree = {val, 'num'}
        elseif assume_unknown then
            tree = {0, 'num'}
        else
            -- there is nothing to do
        end
    elseif tok_type == 'unary' then
        local op, inner = unpack(tok_data)
        inner = p.substitute(inner, data, assume_unknown)
        if inner[2] == 'num' then
            -- evaluate the operation
            local val = inner[1]
            if op == '-' then
                val = -val
            end
            tree = {val, 'num'}
        else
            -- cannot evaluate further
            tok_data = {op, inner}
            tree = {tok_data, tok_type}
        end
    elseif tok_type == 'binary' then
        local op, a, b = unpack(tok_data)
        a = p.substitute(a, data, assume_unknown)
        b = p.substitute(b, data, assume_unknown)
        if a[2] == 'num' and b[2] == 'num' then
            -- evaluate the operation; an unknown operator leaves NaN behind
            local val = 0 / 0
            if op == '+' then
                val = a[1] + b[1]
            elseif op == '-' then
                val = a[1] - b[1]
            elseif op == '*' then
                val = a[1] * b[1]
            elseif op == '/' then
                val = a[1] / b[1]
            end
            if val ~= val then
                tree = {'not a number', 'err'}
            elseif val == math.huge or val == -math.huge then
                -- compare with infinity directly: `val == val + 1` is also
                -- true for every finite value of 2^53 and above
                tree = {'division by zero', 'err'}
            else
                tree = {val, 'num'}
            end
        else
            -- cannot evaluate further
            tok_data = {op, a, b}
            tree = {tok_data, tok_type}
        end
    elseif tok_type == 'func_call' then
        local func_name = tok_data[1]
        local args = {}   -- substituted argument tokens
        local values = {} -- their plain numbers, as eval_func() expects them
        local all_numbers = true
        for i = 2, #tok_data do
            local arg = tok_data[i]
            arg = p.substitute(arg, data, assume_unknown)
            if arg[2] == 'num' then
                table.insert(values, arg[1])
            else
                all_numbers = false
            end
            table.insert(args, arg)
        end
        if all_numbers and func_name ~= 'random.range' then
            -- evaluate the function on the numbers themselves; passing the
            -- `{value, 'num'}` tokens made the math functions fail
            local val, f_err = eval_func(func_name, values)
            if f_err then
                tree = {f_err, 'err'}
            else
                tree = {val, 'num'}
            end
        else
            -- cannot evaluate further
            local func_data = {func_name}
            for _, arg in ipairs(args) do
                table.insert(func_data, arg)
            end
            tree = {func_data, 'func_call'}
        end
    elseif tok_type == 'err' then
        -- there is nothing to do
    end
    return tree
end
-- Converts a tree back into a formula string.
-- Returns: the string; whether the last layer has binary +/-.
-- The second value tells the caller whether the string has to be put in
-- brackets before it can be used as an operand of a tighter operation.
-- Brackets are written wherever dropping them would change the meaning when
-- the string is parsed again:
-- * around a `+`/`-` operand of `*` and `/`;
-- * around a `+`/`-` right operand of `-`      (`a - (b + c)`);
-- * around any binary right operand of `/`     (`a / (b * c)`);
-- * around a unary operand that starts with a sign (`-(-5)`).
-- Errors are written as `err($(message))`. A distribution can only be
-- encoded if it has exactly one outcome.
function p.encode(tree)
    local tok_data, tok_type = unpack(tree)
    if tok_type == 'num' then
        return tostring(tok_data), false
    elseif tok_type == 'var' then
        return '$(' .. tok_data .. ')', false
    elseif tok_type == 'unary' then
        local op, inner = unpack(tok_data)
        local inner_str, inner_esc = p.encode(inner)
        -- two signs in a row (`--5`) cannot be parsed back, so bracket those too
        if inner_esc or inner_str:match('^[+-]') then
            return op .. '(' .. inner_str .. ')', false
        else
            return op .. inner_str, false
        end
    elseif tok_type == 'binary' then
        local op, a, b = unpack(tok_data)
        local a_str, a_esc = p.encode(a)
        local b_str, b_esc = p.encode(b)
        local esc = op:match('[+-]') == op
        -- the left operand only needs brackets below `*` and `/`
        local wrap_a = a_esc and not esc
        -- the right operand also needs them where the operator is not
        -- associative: chains are parsed from the left
        local wrap_b
        if op == '+' then
            wrap_b = false
        elseif op == '-' then
            wrap_b = b_esc
        elseif op == '/' then
            wrap_b = b_esc or b[2] == 'binary'
        else
            wrap_b = b_esc
        end
        local s = ''
        if wrap_a then
            s = s .. '(' .. a_str .. ')'
        else
            s = s .. a_str
        end
        s = s .. ' ' .. op .. ' '
        if wrap_b then
            s = s .. '(' .. b_str .. ')'
        else
            s = s .. b_str
        end
        return s, esc
    elseif tok_type == 'func_call' then
        local func = tok_data[1]
        local parts = {}
        for i = 2, #tok_data do
            -- arguments are separated by commas and never need brackets
            local arg_str = p.encode(tok_data[i])
            table.insert(parts, arg_str)
        end
        return func .. '(' .. table.concat(parts, ', ') .. ')', false
    elseif tok_type == 'err' then
        return 'err($(' .. tostring(tok_data) .. '))', false
    elseif tok_type == 'dist' then
        local d = tok_data
        local dist_data = {}
        for delta, prob in pairs(d) do
            table.insert(dist_data, {delta, prob})
        end
        if #dist_data == 1 then
            local val = dist_data[1][1]
            return tostring(val), false
        end
        return 'err($(cannot encode a disttribution))', false
    else
        return 'err($(cannot encode ' .. tostring(tok_type) .. '))', false
    end
end
return p