v4k-git-backup/engine/art/scripts/tl.lua

10855 lines
323 KiB
Lua

-- Compatibility prelude emitted by the Teal compiler. On Lua versions older
-- than 5.3 it tries to load the optional `compat53.module` library; when that
-- succeeds, the standard-library names below are bound to its backported
-- versions, otherwise each local simply aliases the stock global.
local _tl_compat
if (tonumber((_VERSION or ''):match('[%d.]*$')) or 0) < 5.3 then
   local loaded, compat53 = pcall(require, 'compat53.module')
   if loaded then
      _tl_compat = compat53
   end
end
local assert = _tl_compat and _tl_compat.assert or assert
local debug = _tl_compat and _tl_compat.debug or debug
local io = _tl_compat and _tl_compat.io or io
local ipairs = _tl_compat and _tl_compat.ipairs or ipairs
local load = _tl_compat and _tl_compat.load or load
local math = _tl_compat and _tl_compat.math or math
-- Largest integer: native on 5.3+, otherwise the largest exact double.
local _tl_math_maxinteger = math.maxinteger or math.pow(2, 53)
local os = _tl_compat and _tl_compat.os or os
local package = _tl_compat and _tl_compat.package or package
local pairs = _tl_compat and _tl_compat.pairs or pairs
local string = _tl_compat and _tl_compat.string or string
local table = _tl_compat and _tl_compat.table or table
local _tl_table_unpack = unpack or table.unpack
-- Compiler version string, reported by tl.version().
local VERSION = "0.15.2"

-- Public module table, pre-populated with empty sub-tables named after the
-- record types the module exports.
local tl = {
   TypeCheckOptions = {},
   Env = {},
   Symbol = {},
   Result = {},
   Error = {},
   TypeInfo = {},
   TypeReport = {},
   TypeReportEnv = {},
}

--- Return the compiler version string.
-- @treturn string
function tl.version()
   return VERSION
end
-- Set of recognised warning categories, exported as tl.warning_kinds.
local wk = {}
for _, kind in ipairs({ "unknown", "unused", "redeclaration", "branch", "hint", "debug" }) do
   wk[kind] = true
end
tl.warning_kinds = wk
-- Numeric codes for type categories, exported as tl.typecodes.
-- Codes inside LUA_MASK (0xfff) are single base-type bits. The Teal-specific
-- codes combine such a bit with a higher discriminator bit (for example
-- INTEGER = 0x00010000 | NUMBER). NOMINAL, TYPE_VARIABLE, IS_UNION, IS_POLY
-- and the special INVALID/UNKNOWN codes use the top bits.
-- NOTE(review): the base bits are not used consistently. STRING is 0x8 and
-- TABLE is 0x10, yet IS_STRING is 0x4 (== BOOLEAN), IS_TABLE is 0x8
-- (== STRING), the table-like codes (ARRAY, RECORD, MAP, TUPLE, EMPTY_TABLE)
-- carry 0x8, and ENUM carries 0x4. Confirm the intended encoding before
-- masking these values. They are left as-is because they are exported.
tl.typecodes = {
NIL = 0x00000001,
NUMBER = 0x00000002,
BOOLEAN = 0x00000004,
STRING = 0x00000008,
TABLE = 0x00000010,
FUNCTION = 0x00000020,
USERDATA = 0x00000040,
THREAD = 0x00000080,
IS_TABLE = 0x00000008,
IS_NUMBER = 0x00000002,
IS_STRING = 0x00000004,
LUA_MASK = 0x00000fff,
INTEGER = 0x00010002,
ARRAY = 0x00010008,
RECORD = 0x00020008,
ARRAYRECORD = 0x00030008,
MAP = 0x00040008,
TUPLE = 0x00080008,
EMPTY_TABLE = 0x00000008,
ENUM = 0x00010004,
IS_ARRAY = 0x00010008,
IS_RECORD = 0x00020008,
NOMINAL = 0x10000000,
TYPE_VARIABLE = 0x08000000,
IS_UNION = 0x40000000,
IS_POLY = 0x20000020,
ANY = 0xffffffff,
UNKNOWN = 0x80008000,
INVALID = 0x80000000,
IS_SPECIAL = 0x80000000,
IS_VALID = 0x00000fff,
}
-- Debug controls read from the TL_DEBUG environment variable, which must
-- parse as a number when it is set:
--   * negative N: TL_DEBUG_MAXLINE becomes -N (rounded down);
--   * N > 1: installs a debug hook that traces calls/returns to stderr
--     (staying silent inside functions named "tl_debug_*" until they return)
--     and raises "Too many instructions" after roughly N VM instructions.
-- Otherwise TL_DEBUG_MAXLINE stays at _tl_math_maxinteger.
local TL_DEBUG = os.getenv("TL_DEBUG")
local TL_DEBUG_MAXLINE = _tl_math_maxinteger
if TL_DEBUG then
   local max = assert(tonumber(TL_DEBUG), "TL_DEBUG was defined, but not a number")
   if max < 0 then
      -- math.tointeger is missing on Lua 5.1/5.2 without compat53 and returns
      -- nil for fractional input, which would leave TL_DEBUG_MAXLINE nil.
      -- math.floor exists everywhere and always yields a number (an integer
      -- subtype on 5.3+ whenever the value fits).
      TL_DEBUG_MAXLINE = math.floor(-max)
   elseif max > 1 then
      local count = 0
      local skip = nil
      debug.sethook(function(event)
         if event == "call" or event == "tail call" or event == "return" then
            local info = debug.getinfo(2)
            if skip then
               -- Inside a tl_debug_* function: stay silent until it returns.
               if info.name == skip and event == "return" then
                  skip = nil
               end
               return
            elseif (info.name or "?"):match("^tl_debug_") and event == "call" then
               skip = info.name
               return
            end
            io.stderr:write(info.name or "<anon>", info.currentline > 0 and "@" .. info.currentline or "", " :: ", event, "\n")
            io.stderr:flush()
         else
            -- Count event: fired every 100 instructions (see the hook count).
            count = count + 100
            if count > max then
               error("Too many instructions")
            end
         end
      end, "cr", 100)
   end
end
do
   -- Token kind to emit for a token still in progress when the input ends,
   -- keyed by lexer state. A nil entry means the pending token is dropped;
   -- kinds starting with "$ERR" are also reported as syntax errors.
   -- "got 0" and "got :" now match the kinds the main loop emits for the
   -- same text ("integer" and ":"), so a trailing `0` or `:` is classified
   -- consistently.
   local last_token_kind = {
      ["start"] = nil,
      ["any"] = nil,
      ["identifier"] = "identifier",
      ["got -"] = "op",
      ["got --"] = nil,
      ["got ."] = ".",
      ["got .."] = "op",
      ["got ="] = "op",
      ["got ~"] = "op",
      ["got ["] = "[",
      ["got 0"] = "integer",
      ["got <"] = "op",
      ["got >"] = "op",
      ["got /"] = "op",
      ["got :"] = ":",
      ["got --["] = nil,
      ["string single"] = "$ERR invalid_string$",
      ["string single got \\"] = "$ERR invalid_string$",
      ["string double"] = "$ERR invalid_string$",
      ["string double got \\"] = "$ERR invalid_string$",
      ["string long"] = "$ERR invalid_string$",
      ["string long got ]"] = "$ERR invalid_string$",
      ["comment short"] = nil,
      ["comment long"] = "$ERR unfinished_comment$",
      ["comment long got ]"] = "$ERR unfinished_comment$",
      ["number dec"] = "integer",
      ["number decfloat"] = "number",
      ["number hex"] = "integer",
      ["number hexfloat"] = "number",
      ["number power"] = "number",
      ["number powersign"] = "$ERR invalid_number$",
   }
   -- Reserved words: identifiers with this text are emitted as "keyword".
   local keywords = {
      ["and"] = true,
      ["break"] = true,
      ["do"] = true,
      ["else"] = true,
      ["elseif"] = true,
      ["end"] = true,
      ["false"] = true,
      ["for"] = true,
      ["function"] = true,
      ["goto"] = true,
      ["if"] = true,
      ["in"] = true,
      ["local"] = true,
      ["nil"] = true,
      ["not"] = true,
      ["or"] = true,
      ["repeat"] = true,
      ["return"] = true,
      ["then"] = true,
      ["true"] = true,
      ["until"] = true,
      ["while"] = true,
   }
   -- State entered from "any" when a multi-character token may start with
   -- the given character.
   local lex_any_char_states = {
      ["\""] = "string double",
      ["'"] = "string single",
      ["-"] = "got -",
      ["."] = "got .",
      ["0"] = "got 0",
      ["<"] = "got <",
      [">"] = "got >",
      ["/"] = "got /",
      [":"] = "got :",
      ["="] = "got =",
      ["~"] = "got ~",
      ["["] = "got [",
   }
   for c = string.byte("a"), string.byte("z") do
      lex_any_char_states[string.char(c)] = "identifier"
   end
   for c = string.byte("A"), string.byte("Z") do
      lex_any_char_states[string.char(c)] = "identifier"
   end
   lex_any_char_states["_"] = "identifier"
   for c = string.byte("1"), string.byte("9") do
      lex_any_char_states[string.char(c)] = "number dec"
   end
   -- Characters that may continue an identifier.
   local lex_word = {}
   for c = string.byte("a"), string.byte("z") do
      lex_word[string.char(c)] = true
   end
   for c = string.byte("A"), string.byte("Z") do
      lex_word[string.char(c)] = true
   end
   for c = string.byte("0"), string.byte("9") do
      lex_word[string.char(c)] = true
   end
   lex_word["_"] = true
   local lex_decimals = {}
   for c = string.byte("0"), string.byte("9") do
      lex_decimals[string.char(c)] = true
   end
   local lex_hexadecimals = {}
   for c = string.byte("0"), string.byte("9") do
      lex_hexadecimals[string.char(c)] = true
   end
   for c = string.byte("a"), string.byte("f") do
      lex_hexadecimals[string.char(c)] = true
   end
   for c = string.byte("A"), string.byte("F") do
      lex_hexadecimals[string.char(c)] = true
   end
   -- Characters that always form a complete token by themselves, mapped to
   -- their token kind (punctuation uses the character itself as the kind).
   local lex_any_char_kinds = {}
   local single_char_kinds = { "[", "]", "(", ")", "{", "}", ",", "#", ";" }
   for _, c in ipairs(single_char_kinds) do
      lex_any_char_kinds[c] = c
   end
   for _, c in ipairs({ "+", "*", "|", "&", "%", "^" }) do
      lex_any_char_kinds[c] = "op"
   end
   -- Whitespace skipped between tokens; includes form feed, which Lua also
   -- treats as whitespace.
   local lex_space = {}
   for _, c in ipairs({ " ", "\t", "\v", "\f", "\n", "\r" }) do
      lex_space[c] = true
   end
   -- Characters that are complete escapes right after a backslash.
   local escapable_characters = {
      a = true,
      b = true,
      f = true,
      n = true,
      r = true,
      t = true,
      v = true,
      z = true,
      ["\\"] = true,
      ["\'"] = true,
      ["\""] = true,
      ["\r"] = true,
      ["\n"] = true,
   }
   -- Validate the escape whose first character after the backslash is `c`,
   -- located at position `i` of `input`. Returns how many extra characters
   -- (beyond `c`) the escape consumes, and a truthy value if it is valid.
   local function lex_string_escape(input, i, c)
      if escapable_characters[c] then
         return 0, true
      elseif c == "x" then
         return 2, (
            lex_hexadecimals[input:sub(i + 1, i + 1)] and
            lex_hexadecimals[input:sub(i + 2, i + 2)])
      elseif c == "u" then
         if input:sub(i + 1, i + 1) ~= "{" then
            -- "\u" must be followed by "{". This path used to fall through
            -- and return nil, crashing the caller at `i = i + skip`.
            return 0, false
         end
         local p = i + 2
         if not lex_hexadecimals[input:sub(p, p)] then
            return 2, false
         end
         while true do
            p = p + 1
            c = input:sub(p, p)
            if not lex_hexadecimals[c] then
               return p - i, c == "}"
            end
         end
      elseif lex_decimals[c] then
         local len = lex_decimals[input:sub(i + 1, i + 1)] and
            (lex_decimals[input:sub(i + 2, i + 2)] and 2 or 1) or
            0
         return len, tonumber(input:sub(i, i + len)) < 256
      else
         return 0, false
      end
   end

   --- Tokenize `input` (`filename` is only copied into error records).
   -- Returns the token list, whose entries are { x, y, tk, kind } and which
   -- ends with a "$EOF$" token, plus a list of lexical errors
   -- { filename, y, x, msg }. A leading "#!" line is skipped.
   function tl.lex(input, filename)
      local tokens = {}
      local state = "any"
      -- When false, the current character is examined again (in the new
      -- state) on the next iteration instead of advancing.
      local fwd = true
      local y = 1
      local x = 0
      local i = 0
      -- Number of "=" in the opening / candidate closing long brackets of
      -- long comments (lc_*) and long strings (ls_*).
      local lc_open_lvl = 0
      local lc_close_lvl = 0
      local ls_open_lvl = 0
      local ls_close_lvl = 0
      local errs = {}
      local nt = 0
      local tx
      local ty
      local ti
      local in_token = false
      local function begin_token()
         tx = x
         ty = y
         ti = i
         in_token = true
      end
      local function end_token(kind, tk)
         nt = nt + 1
         tokens[nt] = {
            x = tx,
            y = ty,
            tk = tk,
            kind = kind,
         }
         in_token = false
      end
      local function end_token_identifier()
         local tk = input:sub(ti, i - 1)
         nt = nt + 1
         tokens[nt] = {
            x = tx,
            y = ty,
            tk = tk,
            kind = keywords[tk] and "keyword" or "identifier",
         }
         in_token = false
      end
      -- Emit the token from its start up to, but excluding, the current char.
      local function end_token_prev(kind)
         local tk = input:sub(ti, i - 1)
         nt = nt + 1
         tokens[nt] = {
            x = tx,
            y = ty,
            tk = tk,
            kind = kind,
         }
         in_token = false
      end
      -- Emit the token from its start through the current char.
      local function end_token_here(kind)
         local tk = input:sub(ti, i)
         nt = nt + 1
         tokens[nt] = {
            x = tx,
            y = ty,
            tk = tk,
            kind = kind,
         }
         in_token = false
      end
      local function drop_token()
         in_token = false
      end
      -- Record an error describing the most recently emitted token.
      local function add_syntax_error()
         local t = tokens[nt]
         local msg
         if t.kind == "$ERR invalid_string$" then
            msg = "malformed string"
         elseif t.kind == "$ERR invalid_number$" then
            msg = "malformed number"
         elseif t.kind == "$ERR unfinished_comment$" then
            msg = "unfinished long comment"
         else
            msg = "invalid token '" .. t.tk .. "'"
         end
         table.insert(errs, {
            filename = filename,
            y = t.y,
            x = t.x,
            msg = msg,
         })
      end
      local len = #input
      if input:sub(1, 2) == "#!" then
         i = input:find("\n")
         if not i then
            i = len + 1
         end
         y = 2
         x = 0
      end
      state = "any"
      while i <= len do
         if fwd then
            i = i + 1
            if i > len then
               break
            end
         end
         local c = input:sub(i, i)
         if fwd then
            if c == "\n" then
               y = y + 1
               x = 0
            else
               x = x + 1
            end
         else
            fwd = true
         end
         if state == "any" then
            local st = lex_any_char_states[c]
            if st then
               state = st
               begin_token()
            else
               local k = lex_any_char_kinds[c]
               if k then
                  begin_token()
                  end_token(k, c)
               elseif not lex_space[c] then
                  begin_token()
                  end_token_here("$ERR invalid$")
                  add_syntax_error()
               end
            end
         elseif state == "identifier" then
            if not lex_word[c] then
               end_token_identifier()
               fwd = false
               state = "any"
            end
         elseif state == "string double" then
            if c == "\\" then
               state = "string double got \\"
            elseif c == "\"" then
               end_token_here("string")
               state = "any"
            end
         elseif state == "comment short" then
            if c == "\n" then
               state = "any"
            end
         elseif state == "got =" then
            local t
            if c == "=" then
               t = "=="
            else
               t = "="
               fwd = false
            end
            end_token("op", t)
            state = "any"
         elseif state == "got ." then
            if c == "." then
               state = "got .."
            elseif lex_decimals[c] then
               state = "number decfloat"
            else
               end_token(".", ".")
               fwd = false
               state = "any"
            end
         elseif state == "got :" then
            local t
            if c == ":" then
               t = "::"
            else
               t = ":"
               fwd = false
            end
            end_token(t, t)
            state = "any"
         elseif state == "got [" then
            if c == "[" then
               state = "string long"
            elseif c == "=" then
               ls_open_lvl = ls_open_lvl + 1
            else
               end_token("[", "[")
               fwd = false
               state = "any"
               ls_open_lvl = 0
            end
         elseif state == "number dec" then
            if lex_decimals[c] then
               -- keep consuming digits
            elseif c == "." then
               state = "number decfloat"
            elseif c == "e" or c == "E" then
               state = "number powersign"
            else
               end_token_prev("integer")
               fwd = false
               state = "any"
            end
         elseif state == "got -" then
            if c == "-" then
               state = "got --"
            else
               end_token("op", "-")
               fwd = false
               state = "any"
            end
         elseif state == "got .." then
            if c == "." then
               end_token("...", "...")
            else
               end_token("op", "..")
               fwd = false
            end
            state = "any"
         elseif state == "number hex" then
            if lex_hexadecimals[c] then
               -- keep consuming hex digits
            elseif c == "." then
               state = "number hexfloat"
            elseif c == "p" or c == "P" then
               state = "number powersign"
            else
               end_token_prev("integer")
               fwd = false
               state = "any"
            end
         elseif state == "got --" then
            if c == "[" then
               state = "got --["
            else
               fwd = false
               state = "comment short"
               drop_token()
            end
         elseif state == "got 0" then
            if c == "x" or c == "X" then
               state = "number hex"
            elseif c == "e" or c == "E" then
               state = "number powersign"
            elseif lex_decimals[c] then
               state = "number dec"
            elseif c == "." then
               state = "number decfloat"
            else
               end_token_prev("integer")
               fwd = false
               state = "any"
            end
         elseif state == "got --[" then
            if c == "[" then
               state = "comment long"
            elseif c == "=" then
               lc_open_lvl = lc_open_lvl + 1
            else
               fwd = false
               state = "comment short"
               drop_token()
               lc_open_lvl = 0
            end
         elseif state == "comment long" then
            if c == "]" then
               state = "comment long got ]"
            end
         elseif state == "comment long got ]" then
            if c == "]" then
               if lc_close_lvl == lc_open_lvl then
                  drop_token()
                  state = "any"
                  lc_open_lvl = 0
                  lc_close_lvl = 0
               else
                  -- This "]" may itself begin the real closing bracket (as in
                  -- "]=]==]" inside a level-2 comment): restart the count
                  -- here instead of discarding it.
                  lc_close_lvl = 0
               end
            elseif c == "=" then
               lc_close_lvl = lc_close_lvl + 1
            else
               state = "comment long"
               lc_close_lvl = 0
            end
         elseif state == "string double got \\" then
            local skip, valid = lex_string_escape(input, i, c)
            i = i + skip
            if not valid then
               end_token_here("$ERR invalid_string$")
               add_syntax_error()
            end
            x = x + skip
            state = "string double"
         elseif state == "string single" then
            if c == "\\" then
               state = "string single got \\"
            elseif c == "'" then
               end_token_here("string")
               state = "any"
            end
         elseif state == "string single got \\" then
            local skip, valid = lex_string_escape(input, i, c)
            i = i + skip
            if not valid then
               end_token_here("$ERR invalid_string$")
               add_syntax_error()
            end
            x = x + skip
            state = "string single"
         elseif state == "got ~" then
            local t
            if c == "=" then
               t = "~="
            else
               t = "~"
               fwd = false
            end
            end_token("op", t)
            state = "any"
         elseif state == "got <" then
            local t
            if c == "=" then
               t = "<="
            elseif c == "<" then
               t = "<<"
            else
               t = "<"
               fwd = false
            end
            end_token("op", t)
            state = "any"
         elseif state == "got >" then
            local t
            if c == "=" then
               t = ">="
            elseif c == ">" then
               t = ">>"
            else
               t = ">"
               fwd = false
            end
            end_token("op", t)
            state = "any"
         elseif state == "got /" then
            local t
            if c == "/" then
               t = "//"
            else
               t = "/"
               fwd = false
            end
            end_token("op", t)
            state = "any"
         elseif state == "string long" then
            if c == "]" then
               state = "string long got ]"
            end
         elseif state == "string long got ]" then
            if c == "]" then
               if ls_close_lvl == ls_open_lvl then
                  end_token_here("string")
                  state = "any"
                  ls_open_lvl = 0
                  ls_close_lvl = 0
               else
                  -- Mismatched level: this "]" may start the real closing
                  -- bracket, so restart the "=" count from it.
                  ls_close_lvl = 0
               end
            elseif c == "=" then
               ls_close_lvl = ls_close_lvl + 1
            else
               state = "string long"
               ls_close_lvl = 0
            end
         elseif state == "number hexfloat" then
            if c == "p" or c == "P" then
               state = "number powersign"
            elseif not lex_hexadecimals[c] then
               end_token_prev("number")
               fwd = false
               state = "any"
            end
         elseif state == "number decfloat" then
            if c == "e" or c == "E" then
               state = "number powersign"
            elseif not lex_decimals[c] then
               end_token_prev("number")
               fwd = false
               state = "any"
            end
         elseif state == "number powersign" then
            if c == "-" or c == "+" then
               state = "number power"
            elseif lex_decimals[c] then
               state = "number power"
            else
               end_token_here("$ERR invalid_number$")
               add_syntax_error()
               state = "any"
            end
         elseif state == "number power" then
            if not lex_decimals[c] then
               end_token_prev("number")
               fwd = false
               state = "any"
            end
         end
      end
      -- Flush (or report) a token that was still open at end of input.
      if in_token then
         if last_token_kind[state] then
            end_token_prev(last_token_kind[state])
            if last_token_kind[state]:sub(1, 4) == "$ERR" then
               add_syntax_error()
            elseif keywords[tokens[nt].tk] then
               tokens[nt].kind = "keyword"
            end
         else
            drop_token()
         end
      end
      table.insert(tokens, { x = x + 1, y = y, i = i, tk = "$EOF$", kind = "$EOF$" })
      return tokens, errs
   end
end
--- Binary-search `list` for the last element satisfying `cmp`.
-- `cmp(element, item)` must be true on a (possibly empty) prefix of `list`
-- and false afterwards. Returns the index and value of the last element of
-- that prefix, or nothing if no element satisfies `cmp`.
local function binary_search(list, item, cmp)
   local n = #list
   local lo, hi = 1, n
   while lo <= hi do
      local mid = math.floor((lo + hi) / 2)
      local val = list[mid]
      if not cmp(val, item) then
         hi = mid - 1
      elseif mid == n or not cmp(list[mid + 1], item) then
         return mid, val
      else
         lo = mid + 1
      end
   end
end
--- Return the text of the token that covers column `x` on line `y`, if any.
-- `tks` is expected in source order (as produced by tl.lex). The search
-- locates the last token starting at or before (y, x), then checks that the
-- token is on line `y` and spans column `x`.
function tl.get_token_at(tks, y, x)
   local _, tok = binary_search(tks, nil, function(t)
      if t.y ~= y then
         return t.y < y
      end
      return t.x <= x
   end)
   if not tok or tok.y ~= y then
      return
   end
   if tok.x <= x and x < tok.x + #tok.tk then
      return tok.tk
   end
end
-- Counter behind new_typeid(); holds the most recently issued id.
local last_typeid = 0
--- Issue the next type id (1, 2, 3, ...).
local function new_typeid()
   local id = last_typeid + 1
   last_typeid = id
   return id
end
-- Classifies typenames: true for the table-shaped types, false for every
-- other known typename (unknown names look up as nil).
local table_types = {}
for _, name in ipairs({ "array", "map", "arrayrecord", "record", "emptytable", "tupletable" }) do
   table_types[name] = true
end
for _, name in ipairs({
   "typetype", "nestedtype", "typevar", "typearg", "function", "enum",
   "boolean", "string", "nil", "thread", "number", "integer", "union",
   "nominal", "bad_nominal", "table_item", "unresolved_emptytable_value",
   "unresolved_typearg", "unresolvable_typearg", "circular_require",
   "tuple", "poly", "any", "unknown", "invalid", "unresolved", "none",
}) do
   table_types[name] = false
end
-- Fact starts as an empty table; it is used outside this part of the file.
local Fact = {}
-- Names of the supported variable attributes (presumably the `<const>`,
-- `<close>` and `<total>` annotations). is_attribute is an alias of the
-- same table.
local attributes = {}
for _, attr in ipairs({ "const", "close", "total" }) do
   attributes[attr] = true
end
local is_attribute = attributes
--- True when `t` is array-like ("array" or "arrayrecord").
local function is_array_type(t)
   local name = t.typename
   return name == "arrayrecord" or name == "array"
end
--- True when `t` is record-like ("record" or "arrayrecord").
local function is_record_type(t)
   local name = t.typename
   return name == "arrayrecord" or name == "record"
end
--- True when `t` is numeric ("number" or "integer").
local function is_number_type(t)
   local name = t.typename
   return name == "integer" or name == "number"
end
--- True when `t` denotes a type itself ("typetype" or "nestedtype").
local function is_typetype(t)
   local name = t.typename
   return name == "nestedtype" or name == "typetype"
end
-- Forward declarations, so mutually recursive parser functions can refer to
-- each other before their implementations are assigned.
local parse_type_list, parse_expression, parse_expression_and_tk
local parse_statements, parse_argument_list, parse_argument_type_list
local parse_type, parse_newtype, parse_enum_body, parse_record_body
--- Record a syntax error at token `i` and return the index to resume at.
-- If `i` is past the end of ps.tokens, the error is placed at the last token
-- (normally the "$EOF$" token), `msg` defaults to "unexpected end of file",
-- and the last index is returned. Otherwise `msg` is required, and the index
-- after `i` (clamped to the token count) is returned.
local function fail(ps, i, msg)
   local tokens = ps.tokens
   local tok = tokens[i]
   if not tok then
      local eof = tokens[#tokens]
      table.insert(ps.errs, { filename = ps.filename, y = eof.y, x = eof.x, msg = msg or "unexpected end of file" })
      return #tokens
   end
   table.insert(ps.errs, {
      filename = ps.filename,
      y = tok.y,
      x = tok.x,
      msg = assert(msg, "syntax error, but no error message provided"),
   })
   return math.min(#tokens, i + 1)
end
--- Set `node`'s end position to the last character of token `tk`.
local function end_at(node, tk)
   local width = #tk.tk
   node.yend = tk.y
   node.xend = tk.x + width - 1
end
--- Consume the exact token text `tk` at index `i`; otherwise record an
-- "expected" error and return fail()'s recovery index.
local function verify_tk(ps, i, tk)
   if ps.tokens[i].tk ~= tk then
      return fail(ps, i, "syntax error, expected '" .. tk .. "'")
   end
   return i + 1
end
--- Consume the `end` closing the construct that started at token `istart`.
-- On success: sets node.yend/xend to the `end` keyword. For constructs other
-- than "function" whose `end` differs from the start in both line and
-- column, it remembers a single alignment hint in ps.end_alignment_hint.
-- On failure: ends the node at the current token, flushes any pending hint
-- into ps.errs, then reports the missing `end`.
local function verify_end(ps, i, istart, node)
   local tok = ps.tokens[i]
   if tok.tk ~= "end" then
      end_at(node, tok)
      local hint = ps.end_alignment_hint
      if hint then
         table.insert(ps.errs, hint)
         ps.end_alignment_hint = nil
      end
      local start = ps.tokens[istart]
      return fail(ps, i, "syntax error, expected 'end' to close construct started at " .. ps.filename .. ":" .. start.y .. ":" .. start.x .. ":")
   end
   local endy, endx = tok.y, tok.x
   node.yend = endy
   node.xend = endx + 2
   local misaligned = node.kind ~= "function" and endy ~= node.y and endx ~= node.x
   if misaligned and not ps.end_alignment_hint then
      ps.end_alignment_hint = { filename = ps.filename, y = node.y, x = node.x, msg = "syntax error hint: construct starting here is not aligned with its 'end' at " .. ps.filename .. ":" .. endy .. ":" .. endx .. ":" }
   end
   return i + 1
end
--- Create an AST node positioned at token `i`. `kind` defaults to the
-- token's own kind.
local function new_node(tokens, i, kind)
   local tok = tokens[i]
   local node = { y = tok.y, x = tok.x, tk = tok.tk }
   node.kind = kind or tok.kind
   return node
end
--- Stamp a fresh typeid onto `t` in place and return `t`.
local function a_type(t)
   local id = new_typeid()
   t.typeid = id
   return t
end
--- Create a type of kind `typename` (required), positioned at token `i` of
-- the parse state and tagged with ps.filename.
local function new_type(ps, i, typename)
   local tok = ps.tokens[i]
   local t = {
      typename = assert(typename),
      filename = ps.filename,
      y = tok.y,
      x = tok.x,
      tk = tok.tk,
   }
   return a_type(t)
end
--- Return a new table with the same top-level fields as `t`. Nested tables
-- are shared rather than copied.
local function shallow_copy_type(t)
   local result = {}
   for field, val in pairs(t) do
      result[field] = val
   end
   return result
end
--- Consume a token of kind `kind` at index `i`. Returns the next index and a
-- new node (of kind `node_kind`, defaulting to the token's kind). On a
-- mismatch, records an error and returns only the recovery index.
local function verify_kind(ps, i, kind, node_kind)
   if ps.tokens[i].kind ~= kind then
      return fail(ps, i, "syntax error, expected " .. kind)
   end
   return i + 1, new_node(ps.tokens, i, node_kind)
end
--- Run `skip_fn` on a scratch parse state that shares ps's tokens. Tokens it
-- consumes are skipped, while any errors it records are discarded. Returns
-- whatever `skip_fn` returns.
local function skip(ps, i, skip_fn)
   local scratch = {
      filename = ps.filename,
      tokens = ps.tokens,
      errs = {},
      required_modules = {},
   }
   return skip_fn(scratch, i)
end
--- Report `msg` at token `i`, but resume where `skip_fn` stops. `skip_fn`
-- runs from `starti`, or from `i` when `starti` is omitted.
local function failskip(ps, i, msg, skip_fn, starti)
   local resume_at = skip(ps, starti or i, skip_fn)
   fail(ps, i, msg)
   return resume_at
end
--- Skip a record body that starts after token `i`, collecting into throwaway
-- tables. The dummy node's kind "function" presumably keeps verify_end from
-- producing an alignment hint for it.
local function skip_record(ps, i)
   return parse_record_body(ps, i + 1, {}, { kind = "function" })
end
--- Skip an enum body that starts after token `i`, collecting into throwaway
-- tables (the dummy node mirrors skip_record).
local function skip_enum(ps, i)
   return parse_enum_body(ps, i + 1, {}, { kind = "function" })
end
--- Parse the value part of a table-constructor entry. The obsolete inline
-- `record`/`enum` declarations are reported and skipped. Whenever no value
-- is produced, an "error_node" positioned at the previous token is returned.
local function parse_table_value(ps, i)
   local word = ps.tokens[i].tk
   local value
   if word == "record" then
      i = failskip(ps, i, "syntax error: this syntax is no longer valid; declare nested record inside a record", skip_record)
   elseif word == "enum" then
      i = failskip(ps, i, "syntax error: this syntax is no longer valid; declare nested enum inside a record", skip_enum)
   else
      i, value = parse_expression(ps, i)
   end
   return i, value or new_node(ps.tokens, i - 1, "error_node")
end
-- Parse one entry of a table constructor; `n` is the next positional index.
-- Entry forms and the resulting node.key_parsed:
--   [expr] = value      -> "long"
--   name = value        -> "short" (key becomes a quoted "string" node)
--   name: Type = value  -> "short", with node.decltype set
--   value               -> "implicit" (integer key n)
-- Returns the next token index, the "table_item" node, and the next
-- positional index (n + 1 only for implicit entries). On EOF or a missing
-- value, only the recovery index from fail() is returned.
local function parse_table_item(ps, i, n)
local node = new_node(ps.tokens, i, "table_item")
if ps.tokens[i].kind == "$EOF$" then
return fail(ps, i, "unexpected eof")
end
if ps.tokens[i].tk == "[" then
node.key_parsed = "long"
i = i + 1
i, node.key = parse_expression_and_tk(ps, i, "]")
i = verify_tk(ps, i, "=")
i, node.value = parse_table_value(ps, i)
return i, node, n
elseif ps.tokens[i].kind == "identifier" then
if ps.tokens[i + 1].tk == "=" then
node.key_parsed = "short"
i, node.key = verify_kind(ps, i, "identifier", "string")
node.key.conststr = node.key.tk
node.key.tk = '"' .. node.key.tk .. '"'
i = verify_tk(ps, i, "=")
i, node.value = parse_table_value(ps, i)
return i, node, n
elseif ps.tokens[i + 1].tk == ":" then
-- `name :` may begin a typed entry, or something else (presumably a method
-- call such as `obj:m()`). The typed form is therefore parsed on a scratch
-- state, and its errors are kept only if `= value` follows successfully.
node.key_parsed = "short"
local orig_i = i
local try_ps = {
filename = ps.filename,
tokens = ps.tokens,
errs = {},
required_modules = ps.required_modules,
}
i, node.key = verify_kind(try_ps, i, "identifier", "string")
node.key.conststr = node.key.tk
node.key.tk = '"' .. node.key.tk .. '"'
i = verify_tk(try_ps, i, ":")
i, node.decltype = parse_type(try_ps, i)
if node.decltype and ps.tokens[i].tk == "=" then
i = verify_tk(try_ps, i, "=")
i, node.value = parse_table_value(try_ps, i)
if node.value then
for _, e in ipairs(try_ps.errs) do
table.insert(ps.errs, e)
end
return i, node, n
end
end
-- Not a typed entry after all: drop the attempt and rewind, so the entry
-- is re-parsed below as a positional value.
node.decltype = nil
i = orig_i
end
end
-- Positional entry: the key is the implicit integer index n.
node.key = new_node(ps.tokens, i, "integer")
node.key_parsed = "implicit"
node.key.constnum = n
node.key.tk = tostring(n)
i, node.value = parse_expression(ps, i)
if not node.value then
return fail(ps, i, "expected an expression")
end
return i, node, n + 1
end
-- Generic list parser. Calls parse_item(ps, i, n) repeatedly, appending each
-- item to `list`, until it reaches a token whose text is in the `close` set
-- (or EOF). When a closer is found, list's end position is set to it; the
-- closer itself is not consumed.
-- `sep` is "sep" (comma-separated; a comma right before the closer is an
-- error) or "term" (items may also be terminated by `;`).
-- parse_item may return an updated counter `n`; if it returns none, the
-- previous value is kept.
-- On a missing separator it reports the accepted tokens. If the offending
-- token is on a new line and the closer is not "}", it inserts a synthetic
-- closing token into ps.tokens so parsing can recover.
local function parse_list(ps, i, list, close, sep, parse_item)
local n = 1
while ps.tokens[i].kind ~= "$EOF$" do
if close[ps.tokens[i].tk] then
end_at(list, ps.tokens[i])
break
end
local item
local oldn = n
i, item, n = parse_item(ps, i, n)
n = n or oldn
table.insert(list, item)
if ps.tokens[i].tk == "," then
i = i + 1
if sep == "sep" and close[ps.tokens[i].tk] then
fail(ps, i, "unexpected '" .. ps.tokens[i].tk .. "'")
return i, list
end
elseif sep == "term" and ps.tokens[i].tk == ";" then
i = i + 1
elseif not close[ps.tokens[i].tk] then
-- Build a deterministic, sorted list of the acceptable closers for the
-- message; `first` is also the text of any synthetic closer inserted below.
local options = {}
for k, _ in pairs(close) do
table.insert(options, "'" .. k .. "'")
end
table.sort(options)
local first = options[1]:sub(2, -2)
local msg
if first == ")" and ps.tokens[i].tk == "=" then
-- Likely an assignment inside an argument list: report it and skip
-- past the right-hand expression.
msg = "syntax error, cannot perform an assignment here (did you mean '=='?)"
i = failskip(ps, i, msg, parse_expression, i + 1)
else
table.insert(options, "','")
msg = "syntax error, expected one of: " .. table.concat(options, ", ")
fail(ps, i, msg)
end
if first ~= "}" and ps.tokens[i].y ~= ps.tokens[i - 1].y then
-- Recovery: pretend the list was closed at the end of the previous line.
table.insert(ps.tokens, i, { tk = first, y = ps.tokens[i - 1].y, x = ps.tokens[i - 1].x + 1, kind = "keyword" })
return i, list
end
end
end
return i, list
end
--- Parse `open item ... close` (items handled by parse_list), appending the
-- items to `list`. Returns the index after `close` and `list`.
local function parse_bracket_list(ps, i, list, open, close, sep, parse_item)
   local closers = { [close] = true }
   i = verify_tk(ps, i, open)
   i = parse_list(ps, i, list, closers, sep, parse_item)
   return verify_tk(ps, i, close), list
end
--- Parse a `{ ... }` table constructor into a "table_literal" node whose
-- array part holds the parsed entries.
local function parse_table_literal(ps, i)
   local literal = new_node(ps.tokens, i, "table_literal")
   local next_i, node = parse_bracket_list(ps, i, literal, "{", "}", "term", parse_table_item)
   return next_i, node
end
--- Parse a comma-separated list whose first item is optional.
-- The first item is attempted on a scratch state. If it fails, `list` is
-- returned untouched and the attempt's errors are dropped. Otherwise those
-- errors are copied into `ps`, and the remaining items are parsed directly.
local function parse_trying_list(ps, i, list, parse_item)
   local attempt = {
      filename = ps.filename,
      tokens = ps.tokens,
      errs = {},
      required_modules = ps.required_modules,
   }
   local after_first, first = parse_item(attempt, i)
   if not first then
      return i, list
   end
   for _, err in ipairs(attempt.errs) do
      table.insert(ps.errs, err)
   end
   table.insert(list, first)
   i = after_first
   while ps.tokens[i].tk == "," do
      local item
      i, item = parse_item(ps, i + 1)
      table.insert(list, item)
   end
   return i, list
end
--- Parse a non-empty `<item, ...>` list into a "tuple" type.
-- A closing `>>` (as in nested generics) is rewritten in place to `>` and
-- left unconsumed, so the enclosing list can close on it.
local function parse_anglebracket_list(ps, i, parse_item)
   if ps.tokens[i + 1].tk == ">" then
      return fail(ps, i + 1, "type argument list cannot be empty")
   end
   local typ = new_type(ps, i, "tuple")
   i = verify_tk(ps, i, "<")
   i = parse_list(ps, i, typ, { [">"] = true, [">>"] = true }, "sep", parse_item)
   local closer = ps.tokens[i]
   if closer.tk == ">>" then
      closer.tk = ">"
   elseif closer.tk == ">" then
      i = i + 1
   else
      return fail(ps, i, "syntax error, expected '>'")
   end
   return i, typ
end
--- Parse a type-argument name into a "typearg" type.
-- NOTE(review): the position is taken from the token *before* the name
-- (the `<` or `,`), not from the name itself; confirm this is intended.
local function parse_typearg(ps, i)
   i = verify_kind(ps, i, "identifier")
   local prev, name_tok = ps.tokens[i - 2], ps.tokens[i - 1]
   return i, a_type({
      y = prev.y,
      x = prev.x,
      typename = "typearg",
      typearg = name_tok.tk,
   })
end
--- Parse an optional `: T1, T2...` return-type list into a "tuple" type.
local function parse_return_types(ps, i)
   local next_i, rets = parse_type_list(ps, i, "rets")
   return next_i, rets
end
--- Parse a `function<...>(args): rets` type. A bare `function` with no
-- argument list becomes `function(any...): any...`. The type is marked as a
-- method when its first argument is flagged `is_self`.
local function parse_function_type(ps, i)
   local typ = new_type(ps, i, "function")
   i = i + 1
   if ps.tokens[i].tk == "<" then
      i, typ.typeargs = parse_anglebracket_list(ps, i, parse_typearg)
   end
   if ps.tokens[i].tk ~= "(" then
      typ.args = a_type({ typename = "tuple", is_va = true, a_type({ typename = "any" }) })
      typ.rets = a_type({ typename = "tuple", is_va = true, a_type({ typename = "any" }) })
   else
      i, typ.args = parse_argument_type_list(ps, i)
      i, typ.rets = parse_return_types(ps, i)
   end
   local first_arg = typ.args[1]
   if first_arg and first_arg.is_self then
      typ.is_method = true
   end
   return i, typ
end
-- Shared singleton instances of the built-in primitive types, created in
-- this fixed order (each receives the next typeid from a_type).
local NIL = a_type({ typename = "nil" })
local ANY = a_type({ typename = "any" })
local TABLE = a_type({ typename = "map", keys = ANY, values = ANY })
local NUMBER = a_type({ typename = "number" })
local STRING = a_type({ typename = "string" })
local THREAD = a_type({ typename = "thread" })
local BOOLEAN = a_type({ typename = "boolean" })
local INTEGER = a_type({ typename = "integer" })

-- Type names resolved directly to the singletons above. Each is keyed by its
-- own typename, except "table", which maps to the `{any:any}` map type.
local simple_types = { ["table"] = TABLE }
for _, t in ipairs({ NIL, ANY, NUMBER, STRING, THREAD, BOOLEAN, INTEGER }) do
   simple_types[t.typename] = t
end
--- Parse a base (non-union) type:
--   * a primitive name found in simple_types (returns the shared singleton),
--   * a possibly dotted nominal name with optional `<type arguments>`,
--   * `{T}` (array), `{T1, T2, ...}` (tupletable) or `{K : V}` (map),
--   * a `function` type, or `nil`.
-- Returns the next token index and the type, or only the index on error.
local function parse_base_type(ps, i)
   local tk = ps.tokens[i].tk
   if ps.tokens[i].kind == "identifier" then
      -- "table" is an identifier and resolves here to the shared {any:any}
      -- map. A separate `tk == "table"` branch further down was therefore
      -- unreachable and has been removed.
      local st = simple_types[tk]
      if st then
         return i + 1, st
      end
      local typ = new_type(ps, i, "nominal")
      typ.names = { tk }
      i = i + 1
      while ps.tokens[i].tk == "." do
         i = i + 1
         if ps.tokens[i].kind == "identifier" then
            table.insert(typ.names, ps.tokens[i].tk)
            i = i + 1
         else
            return fail(ps, i, "syntax error, expected identifier")
         end
      end
      if ps.tokens[i].tk == "<" then
         i, typ.typevals = parse_anglebracket_list(ps, i, parse_type)
      end
      return i, typ
   elseif tk == "{" then
      i = i + 1
      local decl = new_type(ps, i, "array")
      local t
      i, t = parse_type(ps, i)
      if not t then
         return i
      end
      if ps.tokens[i].tk == "}" then
         decl.elements = t
         end_at(decl, ps.tokens[i])
         i = verify_tk(ps, i, "}")
      elseif ps.tokens[i].tk == "," then
         decl.typename = "tupletable"
         decl.types = { t }
         local n = 2
         repeat
            i = i + 1
            i, decl.types[n] = parse_type(ps, i)
            if not decl.types[n] then
               break
            end
            n = n + 1
         until ps.tokens[i].tk ~= ","
         end_at(decl, ps.tokens[i])
         i = verify_tk(ps, i, "}")
      elseif ps.tokens[i].tk == ":" then
         decl.typename = "map"
         i = i + 1
         decl.keys = t
         i, decl.values = parse_type(ps, i)
         if not decl.values then
            return i
         end
         end_at(decl, ps.tokens[i])
         i = verify_tk(ps, i, "}")
      else
         return fail(ps, i, "syntax error; did you forget a '}'?")
      end
      return i, decl
   elseif tk == "function" then
      return parse_function_type(ps, i)
   elseif tk == "nil" then
      -- "nil" lexes as a keyword, so it is handled here rather than above.
      return i + 1, simple_types["nil"]
   end
   return fail(ps, i, "expected a type")
end
--- Parse a type: either a parenthesized type, or one or more base types
-- joined by `|`. Unions become a "union" type positioned at their first
-- member. Returns the next index and the type, or only the index on error.
parse_type = function(ps, i)
   if ps.tokens[i].tk == "(" then
      local inner
      i, inner = parse_type(ps, i + 1)
      i = verify_tk(ps, i, ")")
      return i, inner
   end
   local start = i
   local first
   i, first = parse_base_type(ps, i)
   if not first then
      return i
   end
   if ps.tokens[i].tk ~= "|" then
      return i, first
   end
   local union = new_type(ps, start, "union")
   union.types = { first }
   repeat
      local member
      i, member = parse_base_type(ps, i + 1)
      if not member then
         return i
      end
      table.insert(union.types, member)
   until ps.tokens[i].tk ~= "|"
   return i, union
end
--- Parse a list of types into a "tuple" type.
-- In "rets" and "decltype" modes the list must start with `:`; without it,
-- an empty tuple is returned and nothing is consumed. The list may be wrapped
-- in parentheses. In "rets" mode, a trailing `...` marks the tuple variadic
-- (`is_va`), which requires at least one type.
parse_type_list = function(ps, i, mode)
   local list = new_type(ps, i, "tuple")
   if mode == "rets" or mode == "decltype" then
      if ps.tokens[i].tk ~= ":" then
         return i, list
      end
      i = i + 1
   end
   local parenthesized = ps.tokens[i].tk == "("
   if parenthesized then
      i = i + 1
   end
   local before = i
   i = parse_trying_list(ps, i, list, parse_type)
   if i == before and ps.tokens[i].tk ~= ")" then
      fail(ps, i - 1, "expected a type list")
   end
   if mode == "rets" and ps.tokens[i].tk == "..." then
      i = i + 1
      if #list == 0 then
         fail(ps, i, "unexpected '...'")
      else
         list.is_va = true
      end
   end
   if parenthesized then
      i = verify_tk(ps, i, ")")
   end
   return i, list
end
--- Parse the rest of a function after its name or `function` keyword:
-- optional type arguments, the parameter list, return types, the body and
-- the closing `end`. The construct is considered to start at the token
-- before `i`. Fills node.typeargs/args/rets/body and returns (index, node).
local function parse_function_args_rets_body(ps, i, node)
   local construct_start = i - 1
   local tokens = ps.tokens
   if tokens[i].tk == "<" then
      i, node.typeargs = parse_anglebracket_list(ps, i, parse_typearg)
   end
   i, node.args = parse_argument_list(ps, i)
   i, node.rets = parse_return_types(ps, i)
   i, node.body = parse_statements(ps, i)
   end_at(node, tokens[i])
   local after_end = verify_end(ps, i, construct_start, node)
   assert(node.rets.typename == "tuple")
   return after_end, node
end
--- Parse an anonymous `function ... end` expression into a "function" node.
local function parse_function_value(ps, i)
   local fn_node = new_node(ps.tokens, i, "function")
   local next_i = verify_tk(ps, i, "function")
   return parse_function_args_rets_body(ps, next_i, fn_node)
end
--- Strip the delimiters from a string-literal token.
-- Returns the raw contents (escapes are not processed) and whether the
-- literal used long-bracket syntax.
local function unquote(str)
   local quote = str:sub(1, 1)
   if quote == '"' or quote == "'" then
      return str:sub(2, -2), false
   end
   local delim_len = #str:match("^%[=*%[")
   return str:sub(delim_len + 1, -(delim_len + 1)), true
end
--- Parse a primary token: a variable name, string, number, boolean, nil,
-- anonymous function, table constructor or `...`. Malformed string/number
-- tokens from the lexer get specific messages; anything else is reported as
-- a generic syntax error.
local function parse_literal(ps, i)
   local tok = ps.tokens[i]
   local tk, kind = tok.tk, tok.kind
   if kind == "identifier" then
      return verify_kind(ps, i, "identifier", "variable")
   end
   if kind == "string" then
      local node = new_node(ps.tokens, i, "string")
      node.conststr, node.is_longstring = unquote(tk)
      return i + 1, node
   end
   if kind == "number" or kind == "integer" then
      local value = tonumber(tk)
      local next_i, node = verify_kind(ps, i, kind)
      node.constnum = value
      return next_i, node
   end
   if tk == "true" or tk == "false" then
      return verify_kind(ps, i, "keyword", "boolean")
   end
   if tk == "nil" then
      return verify_kind(ps, i, "keyword", "nil")
   end
   if tk == "function" then
      return parse_function_value(ps, i)
   end
   if tk == "{" then
      return parse_table_literal(ps, i)
   end
   if kind == "..." then
      return verify_kind(ps, i, "...")
   end
   if kind == "$ERR invalid_string$" then
      return fail(ps, i, "malformed string")
   end
   if kind == "$ERR invalid_number$" then
      return fail(ps, i, "malformed number")
   end
   return fail(ps, i, "syntax error")
end
--- If node `n` is `require("mod")` or `pcall(require, "mod")` with a
-- constant string argument, return the module name; otherwise nil.
local function node_is_require_call(n)
   local callee, args = n.e1, n.e2
   -- require("mod")
   if callee and args
      and callee.kind == "variable" and callee.tk == "require"
      and args.kind == "expression_list" and #args == 1
      and args[1].kind == "string" then
      return args[1].conststr
   end
   -- pcall(require, "mod")
   local is_funcall = n.op and n.op.op == "@funcall"
   if is_funcall and callee and callee.tk == "pcall"
      and args and #args == 2
      and args[1].kind == "variable" and args[1].tk == "require"
      and args[2].kind == "string" and args[2].conststr then
      return args[2].conststr
   end
   return nil
end
-- Forward declaration. It is assigned inside the following `do` block, where
-- the operator precedence table it reads is in scope.
local an_operator
do
local precedences = {
[1] = {
["not"] = 11,
["#"] = 11,
["-"] = 11,
["~"] = 11,
},
[2] = {
["or"] = 1,
["and"] = 2,
["is"] = 3,
["<"] = 3,
[">"] = 3,
["<="] = 3,
[">="] = 3,
["~="] = 3,
["=="] = 3,
["|"] = 4,
["~"] = 5,
["&"] = 6,
["<<"] = 7,
[">>"] = 7,
[".."] = 8,
["+"] = 9,
["-"] = 9,
["*"] = 10,
["/"] = 10,
["//"] = 10,
["%"] = 10,
["^"] = 12,
["as"] = 50,
["@funcall"] = 100,
["@index"] = 100,
["."] = 100,
[":"] = 100,
},
}
local is_right_assoc = {
["^"] = true,
[".."] = true,
}
local function new_operator(tk, arity, op)
op = op or tk.tk
return { y = tk.y, x = tk.x, arity = arity, op = op, prec = precedences[arity][op] }
end
an_operator = function(node, arity, op)
return { y = node.y, x = node.x, arity = arity, op = op, prec = precedences[arity][op] }
end
local args_starters = {
["("] = true,
["{"] = true,
["string"] = true,
}
local E
local function after_valid_prefixexp(ps, prevnode, i)
return ps.tokens[i - 1].kind == ")" or
(prevnode.kind == "op" and
(prevnode.op.op == "@funcall" or
prevnode.op.op == "@index" or
prevnode.op.op == "." or
prevnode.op.op == ":")) or
prevnode.kind == "identifier" or
prevnode.kind == "variable"
end
local function failstore(tkop, e1)
return { y = tkop.y, x = tkop.x, kind = "paren", e1 = e1, failstore = true }
end
local function P(ps, i)
if ps.tokens[i].kind == "$EOF$" then
return i
end
local e1
local t1 = ps.tokens[i]
if precedences[1][ps.tokens[i].tk] ~= nil then
local op = new_operator(ps.tokens[i], 1)
i = i + 1
local prev_i = i
i, e1 = P(ps, i)
if not e1 then
fail(ps, prev_i, "expected an expression")
return i
end
e1 = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = e1 }
elseif ps.tokens[i].tk == "(" then
i = i + 1
local prev_i = i
i, e1 = parse_expression_and_tk(ps, i, ")")
if not e1 then
fail(ps, prev_i, "expected an expression")
return i
end
e1 = { y = t1.y, x = t1.x, kind = "paren", e1 = e1 }
else
i, e1 = parse_literal(ps, i)
end
if not e1 then
return i
end
while true do
local tkop = ps.tokens[i]
if tkop.kind == "," or tkop.kind == ")" then
break
end
if tkop.tk == "." or tkop.tk == ":" then
local op = new_operator(tkop, 2)
local prev_i = i
local key
i = i + 1
if ps.tokens[i].kind ~= "identifier" then
local skipped = skip(ps, i, parse_type)
if skipped > i + 1 then
fail(ps, i, "syntax error, cannot declare a type here (missing 'local' or 'global'?)")
return skipped, failstore(tkop, e1)
end
end
i, key = verify_kind(ps, i, "identifier")
if not key then
return i, failstore(tkop, e1)
end
if op.op == ":" then
if not args_starters[ps.tokens[i].kind] then
if ps.tokens[i].tk == "=" then
fail(ps, i, "syntax error, cannot perform an assignment here (missing 'local' or 'global'?)")
else
fail(ps, i, "expected a function call for a method")
end
return i, failstore(tkop, e1)
end
if not after_valid_prefixexp(ps, e1, prev_i) then
fail(ps, prev_i, "cannot call a method on this expression")
return i, failstore(tkop, e1)
end
end
e1 = { y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = key }
elseif tkop.tk == "(" then
local op = new_operator(tkop, 2, "@funcall")
local prev_i = i
local args = new_node(ps.tokens, i, "expression_list")
i, args = parse_bracket_list(ps, i, args, "(", ")", "sep", parse_expression)
if not after_valid_prefixexp(ps, e1, prev_i) then
fail(ps, prev_i, "cannot call this expression")
return i, failstore(tkop, e1)
end
e1 = { y = args.y, x = args.x, kind = "op", op = op, e1 = e1, e2 = args }
table.insert(ps.required_modules, node_is_require_call(e1))
elseif tkop.tk == "[" then
local op = new_operator(tkop, 2, "@index")
local prev_i = i
local idx
i = i + 1
i, idx = parse_expression_and_tk(ps, i, "]")
if not after_valid_prefixexp(ps, e1, prev_i) then
fail(ps, prev_i, "cannot index this expression")
return i, failstore(tkop, e1)
end
e1 = { y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = idx }
elseif tkop.kind == "string" or tkop.kind == "{" then
local op = new_operator(tkop, 2, "@funcall")
local prev_i = i
local args = new_node(ps.tokens, i, "expression_list")
local argument
if tkop.kind == "string" then
argument = new_node(ps.tokens, i)
argument.conststr = unquote(tkop.tk)
i = i + 1
else
i, argument = parse_table_literal(ps, i)
end
if not after_valid_prefixexp(ps, e1, prev_i) then
if tkop.kind == "string" then
fail(ps, prev_i, "cannot use a string here; if you're trying to call the previous expression, wrap it in parentheses")
else
fail(ps, prev_i, "cannot use a table here; if you're trying to call the previous expression, wrap it in parentheses")
end
return i, failstore(tkop, e1)
end
table.insert(args, argument)
e1 = { y = args.y, x = args.x, kind = "op", op = op, e1 = e1, e2 = args }
table.insert(ps.required_modules, node_is_require_call(e1))
elseif tkop.tk == "as" or tkop.tk == "is" then
local op = new_operator(tkop, 2, tkop.tk)
i = i + 1
local cast = new_node(ps.tokens, i, "cast")
if ps.tokens[i].tk == "(" then
i, cast.casttype = parse_type_list(ps, i, "casttype")
else
i, cast.casttype = parse_type(ps, i)
end
if not cast.casttype then
return i, failstore(tkop, e1)
end
e1 = { y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = cast, conststr = e1.conststr }
else
break
end
end
return i, e1
end
E = function(ps, i, lhs, min_precedence)
local lookahead = ps.tokens[i].tk
while precedences[2][lookahead] and precedences[2][lookahead] >= min_precedence do
local t1 = ps.tokens[i]
local op = new_operator(t1, 2)
i = i + 1
local rhs
i, rhs = P(ps, i)
if not rhs then
fail(ps, i, "expected an expression")
return i
end
lookahead = ps.tokens[i].tk
while precedences[2][lookahead] and ((precedences[2][lookahead] > (precedences[2][op.op])) or
(is_right_assoc[lookahead] and (precedences[2][lookahead] == precedences[2][op.op]))) do
i, rhs = E(ps, i, rhs, precedences[2][lookahead])
if not rhs then
fail(ps, i, "expected an expression")
return i
end
lookahead = ps.tokens[i].tk
end
lhs = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = lhs, e2 = rhs }
end
return i, lhs
end
parse_expression = function(ps, i)
local lhs
local istart = i
i, lhs = P(ps, i)
if lhs then
i, lhs = E(ps, i, lhs, 0)
end
if lhs then
return i, lhs, 0
end
if i == istart then
i = fail(ps, i, "expected an expression")
end
return i
end
end
-- Parses an expression that must be followed by the token `tk` (such as
-- ")" or "then") and consumes that token. A missing expression is replaced
-- by an "error_node" placeholder. If `tk` is missing, an error is reported;
-- when `tk` shows up within the next 20 tokens, parsing resumes right after
-- it so that a single stray token does not derail the rest of the parse.
parse_expression_and_tk = function(ps, i, tk)
   local expr
   i, expr = parse_expression(ps, i)
   expr = expr or new_node(ps.tokens, i - 1, "error_node")

   if ps.tokens[i].tk == tk then
      return i + 1, expr
   end

   local msg
   if ps.tokens[i].tk == "=" then
      msg = "syntax error, cannot perform an assignment here (did you mean '=='?)"
   else
      msg = "syntax error, expected '" .. tk .. "'"
   end

   for offset = 0, 19 do
      local tok = ps.tokens[i + offset]
      if tok.kind == "$EOF$" then
         break
      end
      if tok.tk == tk then
         fail(ps, i, msg)
         return i + offset + 1, expr
      end
   end
   return fail(ps, i, msg), expr
end
-- Parses a variable name optionally followed by an attribute annotation in
-- angle brackets (the accepted names are the keys of is_attribute).
-- Returns the next index and the identifier node (with .attribute set when
-- annotated), or only the index when the name itself is missing.
local function parse_variable_name(ps, i)
   local name_node
   i, name_node = verify_kind(ps, i, "identifier")
   if not name_node then
      return i
   end
   if ps.tokens[i].tk ~= "<" then
      return i, name_node
   end

   local attr
   i, attr = verify_kind(ps, i + 1, "identifier")
   if not attr then
      fail(ps, i, "expected a variable annotation")
   else
      if not is_attribute[attr.tk] then
         fail(ps, i, "unknown variable annotation: " .. attr.tk)
      end
      name_node.attribute = attr.tk
   end
   i = verify_tk(ps, i, ">")
   return i, name_node
end
-- Parses one parameter of a function definition: a name or `...`,
-- optionally followed by `: type`. Returns the next index, the "argument"
-- node (nil if the name was missing) and a trailing 0, like the other
-- list-item parsers in this file.
local function parse_argument(ps, i)
   local expected_kind = (ps.tokens[i].tk == "...") and "..." or "identifier"
   local arg_node
   i, arg_node = verify_kind(ps, i, expected_kind, "argument")
   if ps.tokens[i].tk == ":" then
      local decltype
      i, decltype = parse_type(ps, i + 1)
      if arg_node then
         arg_node.decltype = decltype
      end
   end
   return i, arg_node, 0
end
-- Parses the parenthesized parameter list of a function definition into an
-- "argument_list" node, reporting an error for every `...` that is not the
-- last parameter.
parse_argument_list = function(ps, i)
   local params = new_node(ps.tokens, i, "argument_list")
   i, params = parse_bracket_list(ps, i, params, "(", ")", "sep", parse_argument)
   local last = #params
   for pos, param in ipairs(params) do
      if param.tk == "..." and pos ~= last then
         fail(ps, i, "'...' can only be last argument")
      end
   end
   return i, params
end
-- Parses one entry of a function *type*'s argument list: an optional
-- `name:` prefix or a `...:` vararg prefix, followed by a type.
-- Returns the next index, the parsed type (nil if parse_type failed) and a
-- trailing 0, like the other list-item parsers in this file. An untyped
-- `...` is rejected with an error.
local function parse_argument_type(ps, i)
   local is_va = false
   local argument_name = nil
   if ps.tokens[i].kind == "identifier" and ps.tokens[i + 1].tk == ":" then
      argument_name = ps.tokens[i].tk
      i = i + 2
   elseif ps.tokens[i].tk == "..." then
      if ps.tokens[i + 1].tk == ":" then
         i = i + 2
         is_va = true
      else
         return fail(ps, i, "cannot have untyped '...' when declaring the type of an argument")
      end
   end
   local typ; i, typ = parse_type(ps, i)
   if typ then
      typ.is_va = is_va
      -- Only flag `self` when a type was actually parsed: after a syntax
      -- error such as `self: <garbage>`, typ is nil and indexing it would
      -- crash the parser instead of reporting the error.
      if argument_name == "self" then
         typ.is_self = true
      end
   end
   return i, typ, 0
end
-- Parses the parenthesized argument list of a function type into a "tuple"
-- type. A trailing `...: T` entry is recorded on the tuple itself
-- (tuple.is_va = true) and its flag is removed from the element.
parse_argument_type_list = function(ps, i)
   local tuple = new_type(ps, i, "tuple")
   i = parse_bracket_list(ps, i, tuple, "(", ")", "sep", parse_argument_type)
   local last = tuple[#tuple]
   if last and last.is_va then
      last.is_va = nil
      tuple.is_va = true
   end
   return i, tuple
end
-- Consumes one identifier token and returns the next index plus an
-- "identifier" node. On a mismatch, reports a syntax error and returns an
-- "error_node" instead, so callers always receive a node.
local function parse_identifier(ps, i)
   if ps.tokens[i].kind ~= "identifier" then
      i = fail(ps, i, "syntax error, expected identifier")
      return i, new_node(ps.tokens, i, "error_node")
   end
   return i + 1, new_node(ps.tokens, i, "identifier")
end
-- Parses `local function name(...) ... end` into a "local_function" node.
local function parse_local_function(ps, i)
   i = verify_tk(ps, verify_tk(ps, i, "local"), "function")
   local fn_node = new_node(ps.tokens, i - 2, "local_function")
   i, fn_node.name = parse_identifier(ps, i)
   return parse_function_args_rets_body(ps, i, fn_node)
end
-- Parses `function f()`, `function a.b.c()` or `function a.b:m()`.
-- A single name yields a "global_function" node. A dotted or method name
-- yields a "record_function" whose fn_owner is the chain of `.` operator
-- nodes for all but the last name. For methods (`:`), an implicit `self`
-- identifier is prepended to fn.args.
-- `ft` is "global" when called for `global function` (parse_global), or
-- "record" for a bare `function` statement (parse_record_function).
local function parse_function(ps, i, ft)
   local orig_i = i
   i = verify_tk(ps, i, "function")
   local fn = new_node(ps.tokens, i - 1, "global_function")

   local names = {}
   local name
   i, name = parse_identifier(ps, i)
   names[1] = name
   while ps.tokens[i].tk == "." do
      i, name = parse_identifier(ps, i + 1)
      names[#names + 1] = name
   end
   if ps.tokens[i].tk == ":" then
      i, name = parse_identifier(ps, i + 1)
      names[#names + 1] = name
      fn.is_method = true
   end

   local count = #names
   if count > 1 then
      fn.kind = "record_function"
      local owner = names[1]
      owner.kind = "type_identifier"
      for k = 2, count - 1 do
         local part = names[k]
         local dot = an_operator(part, 2, ".")
         part.kind = "identifier"
         owner = { y = part.y, x = part.x, kind = "op", op = dot, e1 = owner, e2 = part }
      end
      fn.fn_owner = owner
   end
   fn.name = names[count]

   local self_x, self_y = ps.tokens[i].x, ps.tokens[i].y
   i = parse_function_args_rets_body(ps, i, fn)
   if fn.is_method then
      table.insert(fn.args, 1, { x = self_x, y = self_y, tk = "self", kind = "identifier", is_self = true })
   end
   if not fn.name then
      return orig_i + 1
   end

   if ft == "global" and fn.kind == "record_function" then
      fail(ps, orig_i, "record functions cannot be annotated as 'global'")
   elseif ft == "record" and fn.kind == "global_function" then
      fn.implicit_global_function = true
   end
   return i, fn
end
-- Parses one arm of an if statement, starting at its keyword (`if`,
-- `elseif`, or `else` when is_else is set), and appends it to
-- node.if_blocks as arm number `n`. Returns the next index and `node`,
-- or only the index when the condition or body fails to parse.
local function parse_if_block(ps, i, n, node, is_else)
   local block = new_node(ps.tokens, i, "if_block")
   block.if_parent = node
   block.if_block_n = n
   i = i + 1 -- skip the `if` / `elseif` / `else` keyword
   if not is_else then
      i, block.exp = parse_expression_and_tk(ps, i, "then")
      if not block.exp then
         return i
      end
   end
   i, block.body = parse_statements(ps, i)
   local body = block.body
   if not body then
      return i
   end
   end_at(body, ps.tokens[i - 1])
   block.yend, block.xend = body.yend, body.xend
   table.insert(node.if_blocks, block)
   return i, node
end
-- Parses `if ... then ... {elseif ... then ...} [else ...] end` into an
-- "if" node whose if_blocks list holds one "if_block" per arm, numbered
-- from 1 in source order.
local function parse_if(ps, i)
   local istart = i
   local node = new_node(ps.tokens, i, "if")
   node.if_blocks = {}

   local arm = 1
   i, node = parse_if_block(ps, i, arm, node)
   if not node then
      return i
   end
   while ps.tokens[i].tk == "elseif" do
      arm = arm + 1
      i, node = parse_if_block(ps, i, arm, node)
      if not node then
         return i
      end
   end
   if ps.tokens[i].tk == "else" then
      i, node = parse_if_block(ps, i, arm + 1, node, true)
      if not node then
         return i
      end
   end
   return verify_end(ps, i, istart, node), node
end
-- Parses `while <exp> do <statements> end`.
local function parse_while(ps, i)
   local istart = i
   local node = new_node(ps.tokens, istart, "while")
   i = verify_tk(ps, istart, "while")
   i, node.exp = parse_expression_and_tk(ps, i, "do")
   i, node.body = parse_statements(ps, i)
   return verify_end(ps, i, istart, node), node
end
-- Parses a numeric for loop: `for v = from, to [, step] do ... end`.
local function parse_fornum(ps, i)
   local istart = i
   local node = new_node(ps.tokens, istart, "fornum")
   i, node.var = parse_identifier(ps, istart + 1)
   i = verify_tk(ps, i, "=")
   i, node.from = parse_expression_and_tk(ps, i, ",")
   i, node.to = parse_expression(ps, i)
   if ps.tokens[i].tk ~= "," then
      i = verify_tk(ps, i, "do")
   else
      i, node.step = parse_expression_and_tk(ps, i + 1, "do")
   end
   i, node.body = parse_statements(ps, i)
   return verify_end(ps, i, istart, node), node
end
-- Parses a generic for loop: `for a, b in e1 [, e2 [, e3]] do ... end`.
-- Between one and three iterator expressions are accepted.
local function parse_forin(ps, i)
   local istart = i
   local node = new_node(ps.tokens, istart, "forin")
   i = istart + 1
   local vars = new_node(ps.tokens, i, "variable_list")
   i, node.vars = parse_list(ps, i, vars, { ["in"] = true }, "sep", parse_variable_name)
   i = verify_tk(ps, i, "in")
   local exps = new_node(ps.tokens, i, "expression_list")
   node.exps = exps
   i = parse_list(ps, i, exps, { ["do"] = true }, "sep", parse_expression)
   local count = #exps
   if count == 0 then
      return fail(ps, i, "missing iterator expression in generic for")
   elseif count > 3 then
      return fail(ps, i, "too many expressions in generic for")
   end
   i = verify_tk(ps, i, "do")
   i, node.body = parse_statements(ps, i)
   return verify_end(ps, i, istart, node), node
end
-- Dispatches a `for` statement: `for name =` is the numeric form,
-- anything else is the generic for-in form.
local function parse_for(ps, i)
   local is_numeric = ps.tokens[i + 1].kind == "identifier" and ps.tokens[i + 2].tk == "="
   if is_numeric then
      return parse_fornum(ps, i)
   end
   return parse_forin(ps, i)
end
-- Parses `repeat <statements> until <exp>`. The body is flagged with
-- is_repeat, and the node's end position is the last consumed token.
local function parse_repeat(ps, i)
   local node = new_node(ps.tokens, i, "repeat")
   i = verify_tk(ps, i, "repeat")
   local body
   i, body = parse_statements(ps, i)
   node.body = body
   body.is_repeat = true
   i = verify_tk(ps, i, "until")
   i, node.exp = parse_expression(ps, i)
   end_at(node, ps.tokens[i - 1])
   return i, node
end
-- Parses a `do <statements> end` block.
local function parse_do(ps, i)
   local istart = i
   local node = new_node(ps.tokens, istart, "do")
   i = verify_tk(ps, istart, "do")
   i, node.body = parse_statements(ps, i)
   return verify_end(ps, i, istart, node), node
end
-- Parses a `break` statement.
local function parse_break(ps, i)
   local node = new_node(ps.tokens, i, "break")
   return verify_tk(ps, i, "break"), node
end
-- Parses `goto label`. The label text is read from the token following
-- `goto`, which is then checked to be an identifier.
local function parse_goto(ps, i)
   local node = new_node(ps.tokens, i, "goto")
   i = verify_tk(ps, i, "goto")
   node.label = ps.tokens[i].tk
   return verify_kind(ps, i, "identifier"), node
end
-- Parses a label definition `::name::`.
local function parse_label(ps, i)
   local node = new_node(ps.tokens, i, "label")
   i = verify_tk(ps, i, "::")
   node.label = ps.tokens[i].tk
   i = verify_kind(ps, i, "identifier")
   return verify_tk(ps, i, "::"), node
end
-- Keywords that end a statement list (they close the enclosing block).
local stop_statement_list = {
   ["end"] = true,
   ["else"] = true,
   ["elseif"] = true,
   ["until"] = true,
}
-- Tokens that end the expression list of a `return`: every statement-list
-- terminator, plus ";" and end of file.
local stop_return_list = { [";"] = true, ["$EOF$"] = true }
for keyword in pairs(stop_statement_list) do
   stop_return_list[keyword] = true
end
-- Parses `return [exp {, exp}] [;]` into a "return" node holding an
-- "expression_list" in node.exps.
local function parse_return(ps, i)
   local node = new_node(ps.tokens, i, "return")
   i = verify_tk(ps, i, "return")
   local exps = new_node(ps.tokens, i, "expression_list")
   node.exps = exps
   i = parse_list(ps, i, exps, stop_return_list, "sep", parse_expression)
   if ps.tokens[i].kind == ";" then
      i = i + 1
   end
   return i, node
end
-- Adds field `field_name` with type `t` to a record's `fields` map, and
-- appends the name to `field_order` the first time it is seen.
-- Redeclaring a field is allowed only for functions (overloads): a second
-- function turns the field into a "poly" type holding both, and later ones
-- are appended to that poly. Returns true on success; otherwise reports an
-- error and returns false.
local function store_field_in_record(ps, i, field_name, t, fields, field_order)
   local prev_t = fields[field_name]
   if not prev_t then
      fields[field_name] = t
      table.insert(field_order, field_name)
      return true
   end
   if t.typename == "function" then
      if prev_t.typename == "function" then
         local poly = new_type(ps, i, "poly")
         poly.types = { prev_t, t }
         fields[field_name] = poly
         return true
      elseif prev_t.typename == "poly" then
         table.insert(prev_t.types, t)
         return true
      end
   end
   fail(ps, i, "attempt to redeclare field '" .. field_name .. "' (only functions can be overloaded)")
   return false
end
-- Parses a nested `record Name ... end` or `enum Name ... end` inside a
-- record body (`i` points at the `record`/`enum` keyword) and stores it
-- as a "typetype" field named Name in `def`. `typename` selects the type
-- to build and `parse_body` is the matching body parser. Returns the next
-- index.
local function parse_nested_type(ps, i, def, typename, parse_body)
   local iv = i + 1
   local name
   i, name = verify_kind(ps, iv, "identifier", "type_identifier")
   if not name then
      return fail(ps, i, "expected a variable name")
   end
   local nt = new_node(ps.tokens, i - 2, "newtype")
   nt.newtype = new_type(ps, i, "typetype")
   local rdef = new_type(ps, i, typename)
   local body_end = parse_body(ps, i, rdef, nt, name.tk)
   if body_end then
      i = body_end
      nt.newtype.def = rdef
   end
   store_field_in_record(ps, iv, name.tk, nt.newtype, def.fields, def.field_order)
   return i
end
-- Parses the string items of an enum body up to `end`. Each "enum_item"
-- node is appended to `node`, and its unquoted value is recorded as a key
-- of def.enumset. `i` points at the first token of the body.
parse_enum_body = function(ps, i, def, node)
   local istart = i - 1
   local enumset = {}
   def.enumset = enumset
   while ps.tokens[i].tk ~= "$EOF$" and ps.tokens[i].tk ~= "end" do
      local item
      i, item = verify_kind(ps, i, "string", "enum_item")
      if item then
         table.insert(node, item)
         enumset[unquote(item.tk)] = true
      end
   end
   return verify_end(ps, i, istart, node), node
end
-- Names accepted after the `metamethod` keyword in a record body.
local metamethod_names = {}
for _, mm in ipairs({
   "__add", "__sub", "__mul", "__div", "__mod", "__pow", "__unm", "__idiv",
   "__band", "__bor", "__bxor", "__bnot", "__shl", "__shr",
   "__concat", "__len", "__eq", "__lt", "__le",
   "__index", "__newindex", "__call", "__tostring", "__pairs", "__gc", "__close",
}) do
   metamethod_names[mm] = true
end
-- Parses the body of a record definition, from the first token after the
-- record's name (or after the `record` keyword, when called from
-- parse_newtype) through the closing `end`.
-- Fills `def` with fields/field_order and, when present, typeargs,
-- is_userdata, an array-element type (`elements`, switching typename to
-- "arrayrecord") and meta_fields/meta_field_order.
-- `node` is passed to verify_end. `name` is the record's own name, used to
-- decide whether a function field with a `self` argument is a method.
-- Returns the index after `end` and `node`, or only an index on a
-- syntax error that aborts the body.
parse_record_body = function(ps, i, def, node, name)
local istart = i - 1
def.fields = {}
def.field_order = {}
-- Optional generic parameters in angle brackets.
if ps.tokens[i].tk == "<" then
i, def.typeargs = parse_anglebracket_list(ps, i, parse_typearg)
end
while not (ps.tokens[i].kind == "$EOF$" or ps.tokens[i].tk == "end") do
-- `userdata` marker; `userdata:` would instead be a field named userdata.
if ps.tokens[i].tk == "userdata" and ps.tokens[i + 1].tk ~= ":" then
if def.is_userdata then
fail(ps, i, "duplicated 'userdata' declaration in record")
else
def.is_userdata = true
end
i = i + 1
-- `{ T }` declares the array-element type; only one is allowed.
elseif ps.tokens[i].tk == "{" then
if def.typename == "arrayrecord" then
i = failskip(ps, i, "duplicated declaration of array element type in record", parse_type)
else
i = i + 1
local t
i, t = parse_type(ps, i)
if ps.tokens[i].tk == "}" then
i = verify_tk(ps, i, "}")
else
return fail(ps, i, "expected an array declaration")
end
def.typename = "arrayrecord"
def.elements = t
end
-- Nested `type Name = <definition>`.
elseif ps.tokens[i].tk == "type" and ps.tokens[i + 1].tk ~= ":" then
i = i + 1
local iv = i
local v
i, v = verify_kind(ps, i, "identifier", "type_identifier")
if not v then
return fail(ps, i, "expected a variable name")
end
i = verify_tk(ps, i, "=")
local nt
i, nt = parse_newtype(ps, i, v.tk)
if not nt or not nt.newtype then
return fail(ps, i, "expected a type definition")
end
store_field_in_record(ps, iv, v.tk, nt.newtype, def.fields, def.field_order)
-- Nested `record Name ... end` / `enum Name ... end`.
elseif ps.tokens[i].tk == "record" and ps.tokens[i + 1].tk ~= ":" then
i = parse_nested_type(ps, i, def, "record", parse_record_body)
elseif ps.tokens[i].tk == "enum" and ps.tokens[i + 1].tk ~= ":" then
i = parse_nested_type(ps, i, def, "enum", parse_enum_body)
else
-- Field declaration: `[metamethod] name: T` or `["key"]: T`.
local is_metamethod = false
if ps.tokens[i].tk == "metamethod" and ps.tokens[i + 1].tk ~= ":" then
is_metamethod = true
i = i + 1
end
local v
if ps.tokens[i].tk == "[" then
-- Bracketed key: must be a string literal.
i, v = parse_literal(ps, i + 1)
if v and not v.conststr then
return fail(ps, i, "expected a string literal")
end
i = verify_tk(ps, i, "]")
else
i, v = verify_kind(ps, i, "identifier", "variable")
end
local iv = i
if not v then
return fail(ps, i, "expected a variable name")
end
if ps.tokens[i].tk == ":" then
i = i + 1
local t
i, t = parse_type(ps, i)
if not t then
return fail(ps, i, "expected a type")
end
local field_name = v.conststr or v.tk
local fields = def.fields
local field_order = def.field_order
-- Metamethods are stored separately from regular fields.
if is_metamethod then
if not def.meta_fields then
def.meta_fields = {}
def.meta_field_order = {}
end
fields = def.meta_fields
field_order = def.meta_field_order
if not metamethod_names[field_name] then
fail(ps, i - 1, "not a valid metamethod: " .. field_name)
end
end
-- A function whose first argument is flagged is_self stays a method
-- only if that argument's type is this record (`selfarg.tk == name`)
-- and, for a generic record, its leading type values name the record's
-- type parameters in order. Otherwise the method/self flags are cleared.
if t.is_method and t.args and t.args[1] and t.args[1].is_self then
local selfarg = t.args[1]
if selfarg.tk ~= name or (def.typeargs and not selfarg.typevals) then
t.is_method = false
selfarg.is_self = false
elseif def.typeargs then
for j = 1, #def.typeargs do
if (not selfarg.typevals[j]) or selfarg.typevals[j].tk ~= def.typeargs[j].typearg then
t.is_method = false
selfarg.is_self = false
break
end
end
end
end
store_field_in_record(ps, iv, field_name, t, fields, field_order)
-- `name = ...` is the old nested-type syntax; point at the current form.
elseif ps.tokens[i].tk == "=" then
local next_word = ps.tokens[i + 1].tk
if next_word == "record" or next_word == "enum" then
return fail(ps, i, "syntax error: this syntax is no longer valid; use '" .. next_word .. " " .. v.tk .. "'")
elseif next_word == "functiontype" then
return fail(ps, i, "syntax error: this syntax is no longer valid; use 'type " .. v.tk .. " = function('...")
else
return fail(ps, i, "syntax error: this syntax is no longer valid; use 'type " .. v.tk .. " = '...")
end
else
fail(ps, i, "syntax error: expected ':' for an attribute or '=' for a nested type")
end
end
end
i = verify_end(ps, i, istart, node)
return i, node
end
-- Parses the right-hand side of a type definition into a "newtype" node
-- whose newtype.def holds the defined type. The definition is a
-- `record ... end` body, an `enum ... end` body, or any other type
-- expression. `name` is the name being defined; it is forwarded to the
-- record body parser. Returns only the index when a plain type expression
-- fails to parse.
parse_newtype = function(ps, i, name)
   local node = new_node(ps.tokens, i, "newtype")
   node.newtype = new_type(ps, i, "typetype")

   local keyword = ps.tokens[i].tk
   if keyword == "record" or keyword == "enum" then
      local def = new_type(ps, i, keyword)
      if keyword == "record" then
         i = parse_record_body(ps, i + 1, def, node, name)
      else
         i = parse_enum_body(ps, i + 1, def, node)
      end
      node.newtype.def = def
      return i, node
   end

   i, node.newtype.def = parse_type(ps, i)
   if not node.newtype.def then
      return i
   end
   return i, node
end
-- Parses the comma-separated expressions that follow `=` (`i` points at
-- the `=`) into asgn.exps. Returns the next index and `asgn`. If an
-- expression fails to parse, returns only the index; asgn.exps is reset to
-- nil when nothing had been collected.
local function parse_assignment_expression_list(ps, i, asgn)
   local exps = new_node(ps.tokens, i, "expression_list")
   asgn.exps = exps
   repeat
      local val
      i, val = parse_expression(ps, i + 1)
      if not val then
         if #exps == 0 then
            asgn.exps = nil
         end
         return i
      end
      table.insert(exps, val)
   until ps.tokens[i].tk ~= ","
   return i, asgn
end
-- Parses a statement that begins with an expression: a function call
-- statement or an assignment.
local parse_call_or_assignment
do
   -- Returns whether `node` can be assigned to: a variable, an index
   -- `a[b]`, or a field access `a.b`. The answer is stored in node.is_lvalue.
   local function is_lvalue(node)
      local assignable = node.kind == "variable"
      if not assignable and node.kind == "op" then
         local opname = node.op.op
         assignable = opname == "@index" or opname == "."
      end
      node.is_lvalue = assignable
      return assignable
   end

   -- Parses one additional assignment target (after a comma).
   local function parse_variable(ps, i)
      local target
      i, target = parse_expression(ps, i)
      if target and is_lvalue(target) then
         return i, target
      end
      return fail(ps, i, "expected a variable")
   end

   -- Returns the expression unchanged when it is a call (or an expression
   -- that ended in a syntax error); otherwise parses `t1, t2, ... = exps`
   -- into an "assignment" node.
   parse_call_or_assignment = function(ps, i)
      local istart = i
      local exp
      i, exp = parse_expression(ps, istart)
      if not exp then
         return i
      end
      local is_call = exp.op and exp.op.op == "@funcall"
      if is_call or exp.failstore then
         return i, exp
      end
      if not is_lvalue(exp) then
         return fail(ps, i, "syntax error")
      end

      local asgn = new_node(ps.tokens, istart, "assignment")
      local targets = new_node(ps.tokens, istart, "variable_list")
      asgn.vars = targets
      targets[1] = exp
      if ps.tokens[i].tk == "," then
         i = parse_trying_list(ps, i + 1, targets, parse_variable)
         if #targets < 2 then
            return fail(ps, i, "syntax error")
         end
      end
      if ps.tokens[i].tk ~= "=" then
         verify_tk(ps, i, "=")
         return i
      end
      local result
      i, result = parse_assignment_expression_list(ps, i, asgn)
      return i, result
   end
end
-- Parses `a, b <attr> : T1, T2 [= exps]` following `local`/`global`
-- (`i` points at the first name). `node_name` is "local_declaration" or
-- "global_declaration". The obsolete forms `x = record`, `x = enum` and
-- `x = functiontype` are reported together with the current syntax, and
-- the rest of the definition is skipped.
local function parse_variable_declarations(ps, i, node_name)
   local asgn = new_node(ps.tokens, i, node_name)
   asgn.vars = new_node(ps.tokens, i, "variable_list")
   i = parse_trying_list(ps, i, asgn.vars, parse_variable_name)
   if #asgn.vars == 0 then
      return fail(ps, i, "expected a local variable definition")
   end
   i, asgn.decltype = parse_type_list(ps, i, "decltype")
   if ps.tokens[i].tk ~= "=" then
      return i, asgn
   end

   local scope = node_name == "local_declaration" and "local" or "global"
   local next_word = ps.tokens[i + 1].tk
   if next_word == "record" then
      return failskip(ps, i + 1, "syntax error: this syntax is no longer valid; use '" .. scope .. " record " .. asgn.vars[1].tk .. "'", skip_record)
   elseif next_word == "enum" then
      return failskip(ps, i + 1, "syntax error: this syntax is no longer valid; use '" .. scope .. " enum " .. asgn.vars[1].tk .. "'", skip_enum)
   elseif next_word == "functiontype" then
      return failskip(ps, i + 1, "syntax error: this syntax is no longer valid; use '" .. scope .. " type " .. asgn.vars[1].tk .. " = function('...", parse_function_type)
   end

   local result
   i, result = parse_assignment_expression_list(ps, i, asgn)
   return i, result
end
-- Parses `local type Name = ...` / `global type Name = ...`, where `i`
-- points two tokens before the name (at `local`/`global`). The value is
-- either a `require(...)` call, whose argument must be a literal, or a type
-- definition parsed by parse_newtype. `global type Name` without `=` is
-- accepted as a declaration with no value.
local function parse_type_declaration(ps, i, node_name)
   i = i + 2 -- skip `local`/`global` and `type`
   local asgn = new_node(ps.tokens, i, node_name)
   i, asgn.var = parse_variable_name(ps, i)
   if not asgn.var then
      return fail(ps, i, "expected a type name")
   end
   if node_name == "global_type" and ps.tokens[i].tk ~= "=" then
      return i, asgn
   end
   i = verify_tk(ps, i, "=")

   local tok = ps.tokens[i]
   if tok.kind == "identifier" and tok.tk == "require" then
      local require_at = i
      i, asgn.value = parse_call_or_assignment(ps, i)
      if asgn.value and not node_is_require_call(asgn.value) then
         fail(ps, require_at, "require() for type declarations must have a literal argument")
      end
      return i, asgn
   end

   i, asgn.value = parse_newtype(ps, i, asgn.var.tk)
   local value = asgn.value
   if not value then
      return i
   end
   local def = value.newtype.def
   if not def.names then
      def.names = { asgn.var.tk }
   end
   return i, asgn
end
-- Parses `local record Name ... end`, `global enum Name ... end` and the
-- like (`i` points at `local`/`global`). Builds a "newtype" node around a
-- fresh `type_name` type named Name and fills it with `parse_body`.
local function parse_type_constructor(ps, i, node_name, type_name, parse_body)
   local asgn = new_node(ps.tokens, i, node_name)
   local nt = new_node(ps.tokens, i, "newtype")
   asgn.value = nt
   nt.newtype = new_type(ps, i, "typetype")
   local def = new_type(ps, i, type_name)
   nt.newtype.def = def

   i, asgn.var = verify_kind(ps, i + 2, "identifier")
   local var = asgn.var
   if not var then
      return fail(ps, i, "expected a type name")
   end
   def.names = { var.tk }
   i = parse_body(ps, i, def, nt, var.tk)
   return i, asgn
end
-- Error-recovery skipper for a `type Name = ...` statement written without
-- `local`/`global`: parses it as a local type declaration and returns only
-- the index after it.
local function skip_type_declaration(ps, i)
   local next_i = parse_type_declaration(ps, i - 1, "local_type")
   return next_i
end
-- Parses a statement beginning with `local`: a local function, a local
-- `type`/`record`/`enum` definition, or a local variable declaration.
local function parse_local(ps, i)
   local ntk = ps.tokens[i + 1].tk
   if ntk == "function" then
      return parse_local_function(ps, i)
   end
   if (ntk == "type" or ntk == "record" or ntk == "enum") and ps.tokens[i + 2].kind == "identifier" then
      if ntk == "type" then
         return parse_type_declaration(ps, i, "local_type")
      elseif ntk == "record" then
         return parse_type_constructor(ps, i, "local_type", "record", parse_record_body)
      end
      return parse_type_constructor(ps, i, "local_type", "enum", parse_enum_body)
   end
   return parse_variable_declarations(ps, i + 1, "local_declaration")
end
-- Parses a statement beginning with `global`: a global function, a global
-- `type`/`record`/`enum` definition, or a global variable declaration.
-- Anything else is parsed as a call/assignment starting at `global`.
local function parse_global(ps, i)
   local ntk = ps.tokens[i + 1].tk
   if ntk == "function" then
      return parse_function(ps, i + 1, "global")
   end
   if (ntk == "type" or ntk == "record" or ntk == "enum") and ps.tokens[i + 2].kind == "identifier" then
      if ntk == "type" then
         return parse_type_declaration(ps, i, "global_type")
      elseif ntk == "record" then
         return parse_type_constructor(ps, i, "global_type", "record", parse_record_body)
      end
      return parse_type_constructor(ps, i, "global_type", "enum", parse_enum_body)
   end
   if ps.tokens[i + 1].kind == "identifier" then
      return parse_variable_declarations(ps, i + 1, "global_declaration")
   end
   return parse_call_or_assignment(ps, i)
end
-- Parses a `function` statement with no `local`/`global` prefix, by
-- delegating to parse_function with ft = "record".
local function parse_record_function(ps, i)
   local ft = "record"
   return parse_function(ps, i, ft)
end
-- Statement parsers keyed by the token that starts the statement.
-- Statements not listed here are handled by parse_statements itself.
local parse_statement_fns = {
   ["::"] = parse_label,
   ["break"] = parse_break,
   ["do"] = parse_do,
   ["for"] = parse_for,
   ["function"] = parse_record_function,
   ["global"] = parse_global,
   ["goto"] = parse_goto,
   ["if"] = parse_if,
   ["local"] = parse_local,
   ["repeat"] = parse_repeat,
   ["return"] = parse_return,
   ["while"] = parse_while,
}
-- Handlers for `type`/`record`/`enum` statements written without a
-- `local` or `global` prefix. Each reports the problem and skips the
-- declaration with the matching skipper, so parsing can continue.
local needs_local_or_global = {}
needs_local_or_global["type"] = function(ps, i)
   return failskip(ps, i, "types need to be declared with 'local type' or 'global type'", skip_type_declaration)
end
needs_local_or_global["record"] = function(ps, i)
   return failskip(ps, i, "records need to be declared with 'local record' or 'global record'", skip_record)
end
needs_local_or_global["enum"] = function(ps, i)
   return failskip(ps, i, "enums need to be declared with 'local enum' or 'global enum'", skip_enum)
end
-- Parses statements until end of file or, unless `toplevel` is set, a
-- block terminator (end/else/elseif/until), which is left unconsumed.
-- Semicolons between statements are skipped; each one sets
-- `semicolon = true` on the statement just parsed. When a statement fails
-- to parse, the rest of the tokens on its source line are skipped to
-- resynchronize. Returns the next index and a "statements" node.
parse_statements = function(ps, i, toplevel)
   local node = new_node(ps.tokens, i, "statements")
   local item
   while true do
      while ps.tokens[i].kind == ";" do
         i = i + 1
         if item then
            item.semicolon = true
         end
      end

      local tok = ps.tokens[i]
      if tok.kind == "$EOF$" then
         break
      end
      local tk = tok.tk
      if stop_statement_list[tk] and not toplevel then
         break
      end

      local parse_fn = parse_statement_fns[tk]
      if not parse_fn then
         local skip_fn = needs_local_or_global[tk]
         if skip_fn and ps.tokens[i + 1].kind == "identifier" then
            parse_fn = skip_fn
         else
            parse_fn = parse_call_or_assignment
         end
      end

      i, item = parse_fn(ps, i)
      if item then
         table.insert(node, item)
      elseif i > 1 then
         local failed_line = ps.tokens[i - 1].y
         while ps.tokens[i].kind ~= "$EOF$" and ps.tokens[i].y == failed_line do
            i = i + 1
         end
      end
   end
   end_at(node, ps.tokens[i])
   return i, node
end
-- Sorts `errors` in place by (filename, line, column, original order) and
-- removes every error that reports the same position in the same file as
-- the error before it, keeping the earliest-reported one.
-- Each error is a table with numeric `y`/`x` and an optional `filename`.
local function clear_redundant_errors(errors)
   -- Remember the original order so the sort is stable.
   for i, err in ipairs(errors) do
      err.i = i
   end
   table.sort(errors, function(a, b)
      local af = a.filename or ""
      local bf = b.filename or ""
      return af < bf or
         (af == bf and (a.y < b.y or
         (a.y == b.y and (a.x < b.x or
         (a.x == b.x and (a.i < b.i))))))
   end)

   -- Duplicates must match on file as well as position: the same line and
   -- column in two different files are unrelated errors. Starting with no
   -- "previous" error means the first error is never dropped.
   local redundant = {}
   local lastf, lastx, lasty = nil, nil, nil
   for i, err in ipairs(errors) do
      err.i = nil
      local f = err.filename or ""
      if f == lastf and err.x == lastx and err.y == lasty then
         table.insert(redundant, i)
      end
      lastf, lastx, lasty = f, err.x, err.y
   end
   -- Remove from the back so earlier indices stay valid.
   for i = #redundant, 1, -1 do
      table.remove(errors, redundant[i])
   end
end
-- Parses a token list (as returned by tl.lex) into a "statements" AST.
-- Syntax errors are appended to `errs` (a fresh table if omitted), which is
-- then sorted and de-duplicated by position. Returns the AST and the list
-- of module names passed to literal require() calls.
function tl.parse_program(tokens, errs, filename)
   local ps = {
      tokens = tokens,
      errs = errs or {},
      filename = filename or "",
      required_modules = {},
   }
   local _, ast = parse_statements(ps, 1, true)
   clear_redundant_errors(ps.errs)
   return ast, ps.required_modules
end
-- Lexes and parses `input`. Returns the AST, the error list returned by
-- tl.lex (to which parse errors are appended, assuming tl.lex returns a
-- table), and the modules required via literal require() calls.
function tl.parse(input, filename)
   local tokens, errs = tl.lex(input, filename)
   local ast, required = tl.parse_program(tokens, errs, filename)
   return ast, errs, required
end
-- Returns an iterator over a record type's fields in declaration order,
-- yielding (name, type) pairs. When `meta` is set, the metamethod fields
-- are iterated instead. Yields nothing if the record has no such fields.
local function fields_of(t, meta)
   local order, fields
   if meta then
      order, fields = t.meta_field_order, t.meta_fields
   else
      order, fields = t.field_order, t.fields
   end
   if not fields then
      return function() end
   end
   local pos = 1
   return function()
      local name = order[pos]
      if not name then
         return nil
      end
      pos = pos + 1
      return name, fields[name]
   end
end
-- Forward declaration; assigned later in the file (used by recurse_type's
-- TL_DEBUG tracing).
local show_type
-- State of the TL_DEBUG trace written to stderr:
-- current nesting depth (incremented by tl_debug_indent_push, decremented
-- by tl_debug_indent_pop),
local tl_debug_indent = 0
-- the pushed entry that has not been printed yet, if any,
local tl_debug_entry = nil
-- and the last source line printed (a blank line separates later lines).
local tl_debug_y = 1
-- Formats a "line:column" location for the debug trace, using "?" for a
-- missing coordinate. tostring(nil) is the string "nil", never nil, so the
-- fallback must test the value itself rather than tostring's result.
local function tl_debug_loc(y, x)
   return (y and tostring(y) or "?") .. ":" .. (x and tostring(x) or "?")
end
-- Opens a nested entry in the TL_DEBUG trace. Entries are buffered: if a
-- previous entry is still pending, it is flushed to stderr first and the
-- indentation level increases, since the new entry is nested inside it.
-- A blank line is written when the flushed entry is on a later source line
-- than the last one printed. The new entry's message is fmt:format(...).
local function tl_debug_indent_push(mark, y, x, fmt, ...)
   local pending = tl_debug_entry
   if pending then
      if pending.y and pending.y > tl_debug_y then
         io.stderr:write("\n")
         tl_debug_y = pending.y
      end
      local line = (" "):rep(tl_debug_indent) .. pending.mark .. " " ..
         tl_debug_loc(pending.y, pending.x) .. " " ..
         pending.msg .. "\n"
      io.stderr:write(line)
      io.stderr:flush()
      tl_debug_entry = nil
      tl_debug_indent = tl_debug_indent + 1
   end
   tl_debug_entry = { mark = mark, y = y, x = x, msg = fmt:format(...) }
end
-- Closes the current TL_DEBUG trace entry. If the matching push is still
-- pending (nothing was nested inside it), it is printed as a single line
-- marked `single`, with fmt:format(...) replacing its message when fmt is
-- given. Otherwise the indentation level decreases and, when fmt is given,
-- a closing line marked `mark` is printed.
local function tl_debug_indent_pop(mark, single, y, x, fmt, ...)
   local pending = tl_debug_entry
   if not pending then
      tl_debug_indent = tl_debug_indent - 1
      if fmt then
         io.stderr:write((" "):rep(tl_debug_indent) .. mark .. " " .. fmt:format(...) .. "\n")
         io.stderr:flush()
      end
      return
   end
   local msg = fmt and fmt:format(...) or pending.msg
   if y and y > tl_debug_y then
      io.stderr:write("\n")
      tl_debug_y = y
   end
   io.stderr:write((" "):rep(tl_debug_indent) .. single .. " " .. tl_debug_loc(y, x) .. " " .. msg .. "\n")
   io.stderr:flush()
   tl_debug_entry = nil
end
-- Walks a type tree depth-first.
-- For each type it:
--   1. calls visit.cbs[typename].before(ast), if present;
--   2. recurses into every child type, collecting the results in `xs`, in
--      this order: typeargs, array part, types, def, keys, values, elements,
--      fields, meta_fields, args, rets, typevals, ktype, vtype;
--   3. calls visit.cbs[typename].after(ast, xs), then
--      visit.after(ast, xs, ret), each if present.
-- Returns the result of the last callback that ran (nil if none).
-- When TL_DEBUG is set, each visit is traced to stderr.
local function recurse_type(ast, visit)
local kind = ast.typename
if TL_DEBUG then
tl_debug_indent_push("---", ast.y, ast.x, "[%s] = %s", kind, show_type(ast))
end
local cbs = visit.cbs
local cbkind = cbs and cbs[kind]
if cbkind then
local cbkind_before = cbkind.before
if cbkind_before then
cbkind_before(ast)
end
end
local xs = {}
if ast.typeargs then
for _, child in ipairs(ast.typeargs) do
table.insert(xs, recurse_type(child, visit))
end
end
-- NOTE(review): array-part results are stored by position (xs[i]), so for
-- a type that has both typeargs and an array part they overwrite the
-- typeargs results inserted above. Confirm this is intended.
for i, child in ipairs(ast) do
xs[i] = recurse_type(child, visit)
end
if ast.types then
for _, child in ipairs(ast.types) do
table.insert(xs, recurse_type(child, visit))
end
end
if ast.def then
table.insert(xs, recurse_type(ast.def, visit))
end
if ast.keys then
table.insert(xs, recurse_type(ast.keys, visit))
end
if ast.values then
table.insert(xs, recurse_type(ast.values, visit))
end
if ast.elements then
table.insert(xs, recurse_type(ast.elements, visit))
end
if ast.fields then
for _, child in fields_of(ast) do
table.insert(xs, recurse_type(child, visit))
end
end
if ast.meta_fields then
for _, child in fields_of(ast, "meta") do
table.insert(xs, recurse_type(child, visit))
end
end
if ast.args then
-- For a method type, the first argument is visited only while it is
-- still flagged is_self; every other argument is always visited.
for i, child in ipairs(ast.args) do
if i > 1 or not ast.is_method or child.is_self then
table.insert(xs, recurse_type(child, visit))
end
end
end
if ast.rets then
for _, child in ipairs(ast.rets) do
table.insert(xs, recurse_type(child, visit))
end
end
if ast.typevals then
for _, child in ipairs(ast.typevals) do
table.insert(xs, recurse_type(child, visit))
end
end
if ast.ktype then
table.insert(xs, recurse_type(ast.ktype, visit))
end
if ast.vtype then
table.insert(xs, recurse_type(ast.vtype, visit))
end
-- The per-typename `after` result is passed to the global `after`,
-- which may replace it.
local ret
local cbkind_after = cbkind and cbkind.after
if cbkind_after then
ret = cbkind_after(ast, xs)
end
local visit_after = visit.after
if visit_after then
ret = visit_after(ast, xs, ret)
end
if TL_DEBUG then
tl_debug_indent_pop("---", "---", ast.y, ast.x)
end
return ret
end
-- Visits each type argument of `ast`, if it has any, with recurse_type.
-- The results are discarded.
local function recurse_typeargs(ast, visit_type)
   local typeargs = ast.typeargs
   if not typeargs then
      return
   end
   for _, typearg in ipairs(typeargs) do
      recurse_type(typearg, visit_type)
   end
end
-- Calls the optional node-visitor hook visit_node.cbs[ast.kind][name] with
-- (ast, xs), when every level of that lookup exists. Returns nothing.
local function extra_callback(name, ast, xs, visit_node)
   local hook = visit_node.cbs
   hook = hook and hook[ast.kind]
   hook = hook and hook[name]
   if hook then
      hook(ast, xs)
   end
end
-- AST node kinds marked, as the table's name says, as not needing
-- recursion into child nodes; looked up by the AST walker.
local no_recurse_node = {}
for _, kind in ipairs({
   "...", "nil", "cast", "goto", "break", "label", "number", "string",
   "boolean", "integer", "variable", "error_node", "identifier", "type_identifier",
}) do
   no_recurse_node[kind] = true
end
-- Walk a statement/expression AST bottom-up and return the root's result.
-- For every node, `recurse`:
--   1. calls visit_node.cbs[kind].before(ast) when present;
--   2. fills `xs` with the results of visiting the node's children, using
--      the per-kind `walkers` table below (kinds with no walker must be in
--      `no_recurse_node`, otherwise the assert fails);
--   3. calls cbs[kind].after(ast, xs) and then visit_node.after(ast, xs, ret);
--      the value returned is visit_node.after's result when it is set,
--      otherwise cbs[kind].after's result (nil if neither exists).
-- Type annotations reached from nodes are visited with recurse_type and
-- `visit_type`.
-- Returns nil immediately when `root` is nil. Raises "missing cbs in
-- visit_node" when visit_node.cbs is absent and allow_missing_cbs is unset.
local function recurse_node(root,
visit_node,
visit_type)
if not root then
return
end
-- forward-declared so the walkers below can recurse into children
local recurse
-- container nodes (statements, lists, table literals): xs[i] = child i
local function walk_children(ast, xs)
for i, child in ipairs(ast) do
xs[i] = recurse(child)
end
end
-- assignments/declarations: xs[1] = vars, xs[2] = declared type (if any),
-- xs[3] = right-hand expressions (if any); the "before_expressions" extra
-- callback fires between the two sides
local function walk_vars_exps(ast, xs)
xs[1] = recurse(ast.vars)
if ast.decltype then
xs[2] = recurse_type(ast.decltype, visit_type)
end
extra_callback("before_expressions", ast, xs, visit_node)
if ast.exps then
xs[3] = recurse(ast.exps)
end
end
local function walk_var_value(ast, xs)
xs[1] = recurse(ast.var)
xs[2] = recurse(ast.value)
end
-- local/global functions: xs = { name, args, return types, body }
local function walk_named_function(ast, xs)
recurse_typeargs(ast, visit_type)
xs[1] = recurse(ast.name)
xs[2] = recurse(ast.args)
xs[3] = recurse_type(ast.rets, visit_type)
extra_callback("before_statements", ast, xs, visit_node)
xs[4] = recurse(ast.body)
end
-- Per-kind child visitors. Each fills `xs` in the positional layout that
-- the matching cbs[kind].after callback reads.
local walkers = {
-- xs = { e1 result, e1 precedence, e2 result, e2 precedence }; unary
-- operators only set xs[1] and xs[2]
["op"] = function(ast, xs)
xs[1] = recurse(ast.e1)
local p1 = ast.e1.op and ast.e1.op.prec or nil
-- a method call on a string literal gets a very low "precedence" so a
-- precedence-based printer parenthesizes it: Lua rejects `"s":m()`
if ast.op.op == ":" and ast.e1.kind == "string" then
p1 = -999
end
xs[2] = p1
if ast.op.arity == 2 then
extra_callback("before_e2", ast, xs, visit_node)
-- the right operand of `is`/`as` is a type, not an expression
if ast.op.op == "is" or ast.op.op == "as" then
xs[3] = recurse_type(ast.e2.casttype, visit_type)
else
xs[3] = recurse(ast.e2)
end
xs[4] = (ast.e2.op and ast.e2.op.prec)
end
end,
["statements"] = walk_children,
["argument_list"] = walk_children,
["table_literal"] = walk_children,
["variable_list"] = walk_children,
["expression_list"] = walk_children,
-- xs = { key, value, declared type (if any) }
["table_item"] = function(ast, xs)
xs[1] = recurse(ast.key)
xs[2] = recurse(ast.value)
if ast.decltype then
xs[3] = recurse_type(ast.decltype, visit_type)
end
end,
["assignment"] = walk_vars_exps,
["local_declaration"] = walk_vars_exps,
["global_declaration"] = walk_vars_exps,
["local_type"] = walk_var_value,
-- xs = { var, value (only when the declaration has one) }
["global_type"] = function(ast, xs)
xs[1] = recurse(ast.var)
if ast.value then
xs[2] = recurse(ast.value)
end
end,
-- one result per if/elseif/else block, in order
["if"] = function(ast, xs)
for _, e in ipairs(ast.if_blocks) do
table.insert(xs, recurse(e))
end
end,
-- xs = { condition (absent for `else`), body }
["if_block"] = function(ast, xs)
if ast.exp then
xs[1] = recurse(ast.exp)
end
extra_callback("before_statements", ast, xs, visit_node)
xs[2] = recurse(ast.body)
end,
["while"] = function(ast, xs)
xs[1] = recurse(ast.exp)
extra_callback("before_statements", ast, xs, visit_node)
xs[2] = recurse(ast.body)
end,
-- body comes first: in `repeat ... until exp` the condition follows it
["repeat"] = function(ast, xs)
xs[1] = recurse(ast.body)
xs[2] = recurse(ast.exp)
end,
-- anonymous function: xs = { args, return types, body }
["function"] = function(ast, xs)
recurse_typeargs(ast, visit_type)
xs[1] = recurse(ast.args)
xs[2] = recurse_type(ast.rets, visit_type)
extra_callback("before_statements", ast, xs, visit_node)
xs[3] = recurse(ast.body)
end,
["local_function"] = walk_named_function,
["global_function"] = walk_named_function,
-- xs = { owner, name, args, return types, body }
["record_function"] = function(ast, xs)
recurse_typeargs(ast, visit_type)
xs[1] = recurse(ast.fn_owner)
xs[2] = recurse(ast.name)
xs[3] = recurse(ast.args)
xs[4] = recurse_type(ast.rets, visit_type)
extra_callback("before_statements", ast, xs, visit_node)
xs[5] = recurse(ast.body)
end,
["forin"] = function(ast, xs)
xs[1] = recurse(ast.vars)
xs[2] = recurse(ast.exps)
extra_callback("before_statements", ast, xs, visit_node)
xs[3] = recurse(ast.body)
end,
-- xs = { var, from, to, step (nil when absent), body }
["fornum"] = function(ast, xs)
xs[1] = recurse(ast.var)
xs[2] = recurse(ast.from)
xs[3] = recurse(ast.to)
xs[4] = ast.step and recurse(ast.step)
extra_callback("before_statements", ast, xs, visit_node)
xs[5] = recurse(ast.body)
end,
["return"] = function(ast, xs)
xs[1] = recurse(ast.exps)
end,
["do"] = function(ast, xs)
xs[1] = recurse(ast.body)
end,
["paren"] = function(ast, xs)
xs[1] = recurse(ast.e1)
end,
["newtype"] = function(ast, xs)
xs[1] = recurse_type(ast.newtype, visit_type)
end,
["argument"] = function(ast, xs)
if ast.decltype then
xs[1] = recurse_type(ast.decltype, visit_type)
end
end,
}
if not visit_node.allow_missing_cbs and not visit_node.cbs then
error("missing cbs in visit_node")
end
local visit_after = visit_node.after
recurse = function(ast)
local xs = {}
local kind = assert(ast.kind)
local cbs = visit_node.cbs
local cbkind = cbs and cbs[kind]
if cbkind then
if cbkind.before then
cbkind.before(ast)
end
end
-- debug tracing: stop past TL_DEBUG_MAXLINE and log entry into the node
if TL_DEBUG then
if ast.y > TL_DEBUG_MAXLINE then
error("Halting execution at input line " .. ast.y)
end
local k = kind == "op" and "op " .. ast.op.op or kind
tl_debug_indent_push("{{{", ast.y, ast.x, "[%s]", k)
end
local fn = walkers[kind]
if fn then
fn(ast, xs)
else
-- every kind must either have a walker or be a known leaf
assert(no_recurse_node[kind])
end
local ret
local cbkind_after = cbkind and cbkind.after
if cbkind_after then
ret = cbkind_after(ast, xs)
end
-- the generic `after` hook sees (and may replace) the per-kind result
if visit_after then
ret = visit_after(ast, xs, ret)
end
if TL_DEBUG then
local k = kind == "op" and "op " .. ast.op.op or kind
tl_debug_indent_pop("}}}", "***", ast.y, ast.x, "[%s] = %s", k, ast.type and show_type(ast.type))
end
return ret
end
return recurse(root)
end
-- Operators that pretty_print_ast emits with no surrounding spaces, keyed
-- by arity (1 = unary, 2 = binary): `-x`, `~x`, `#t`, `a.b`, `a:b`.
local tight_op = {
[1] = {
["-"] = true,
["~"] = true,
["#"] = true,
},
[2] = {
["."] = true,
[":"] = true,
},
}
-- Operators that pretty_print_ast emits with spaces: a space after the
-- unary `not`, and a space on both sides of each binary operator below.
local spaced_op = {
[1] = {
["not"] = true,
},
[2] = {
["or"] = true,
["and"] = true,
["<"] = true,
[">"] = true,
["<="] = true,
[">="] = true,
["~="] = true,
["=="] = true,
["|"] = true,
["~"] = true,
["&"] = true,
["<<"] = true,
[">>"] = true,
[".."] = true,
["+"] = true,
["-"] = true,
["*"] = true,
["/"] = true,
["//"] = true,
["%"] = true,
["^"] = true,
},
}
-- Default pretty-printer options: keep the original indentation and line breaks.
local default_pretty_print_ast_opts = {
preserve_indent = true,
preserve_newlines = true,
}
-- "Fast" options, selected by passing `true` as pretty_print_ast's mode:
-- keep line breaks but skip indentation bookkeeping.
local fast_pretty_print_ast_opts = {
preserve_indent = false,
preserve_newlines = true,
}
-- Maps a Teal typename to the Lua runtime type name that type() reports for
-- its values. pretty_print_ast uses it to lower `x is T`; typenames missing
-- here fall back to "userdata" (for userdata records) or "table".
local primitive = {
["function"] = "function",
["enum"] = "string",
["boolean"] = "boolean",
["string"] = "string",
["nil"] = "nil",
["number"] = "number",
["integer"] = "number",
["thread"] = "thread",
}
-- Render an AST back to Lua source code.
-- Teal-only constructs are lowered on the way out:
--   * an `as` cast keeps only its expression;
--   * `x is T` becomes a runtime type test;
--   * type declarations become plain tables, or are elided;
--   * variable attributes are emitted only for the "5.4" target.
-- @param ast        root node to print.
-- @param gen_target target Lua version. With "5.4", <const>/<close>
--                   attributes are kept. Otherwise they are dropped, and a
--                   <close> attribute makes the call fail.
-- @param mode       one of: an options table ({ preserve_indent,
--                   preserve_newlines }); `true` for
--                   fast_pretty_print_ast_opts; anything else for
--                   default_pretty_print_ast_opts.
-- @return the generated code string, or nil plus an error message.
function tl.pretty_print_ast(ast, gen_target, mode)
-- set by callbacks when the AST cannot be emitted for gen_target
local err
-- current indentation level
local indent = 0
local opts
if type(mode) == "table" then
opts = mode
elseif mode == true then
opts = fast_pretty_print_ast_opts
else
opts = default_pretty_print_ast_opts
end
-- Indentation levels saved by increment_indent for nodes whose first child
-- sits on the node's own line; restored by decrement_indent.
local save_indent = {}
-- `before` hook for block-like nodes. The first child is node.body or node[1].
-- If it starts on a later line, go one level deeper, resuming from the last
-- saved level when the current level is 0. If it starts on the same line,
-- save the current level and reset it to 0.
local function increment_indent(node)
local child = node.body or node[1]
if not child then
return
end
if child.y ~= node.y then
if indent == 0 and #save_indent > 0 then
indent = save_indent[#save_indent] + 1
else
indent = indent + 1
end
else
table.insert(save_indent, indent)
indent = 0
end
end
-- Undo increment_indent; `child` must be the child increment_indent inspected.
local function decrement_indent(node, child)
if child.y ~= node.y then
indent = indent - 1
else
indent = table.remove(save_indent)
end
end
-- Without indentation preservation, no `before` hook is installed (the cbs
-- below read increment_indent after this point) and decrementing is a no-op.
if not opts.preserve_indent then
increment_indent = nil
decrement_indent = function() end
end
-- Append string `s` to fragment `out`, adding its newline count to out.h.
local function add_string(out, s)
table.insert(out, s)
if string.find(s, "\n", 1, true) then
for _nl in s:gmatch("\n") do
out.h = out.h + 1
end
end
end
-- Append a rendered child fragment to `out`. A fragment has the shape
-- { y = source line, h = extra lines spanned, ...pieces }; empty fragments
-- are skipped.
-- A child from an earlier source line moves out.y up (types use y = -1 and
-- are ignored).
-- With preserve_newlines, a child whose line lies below what `out` already
-- covers is preceded by enough newlines to put it back on its source line.
-- Otherwise `space` (if given) is used as the separator.
-- `current_indent` is emitted only when the child starts a new line or no
-- `space` argument was given.
local function add_child(out, child, space, current_indent)
if #child == 0 then
return
end
if child.y ~= -1 and child.y < out.y then
out.y = child.y
end
if child.y > out.y + out.h and opts.preserve_newlines then
local delta = child.y - (out.y + out.h)
out.h = out.h + delta
table.insert(out, ("\n"):rep(delta))
else
if space then
if space ~= "" then
table.insert(out, space)
end
current_indent = nil
end
end
if current_indent and opts.preserve_indent then
table.insert(out, (" "):rep(current_indent))
end
table.insert(out, child)
out.h = out.h + child.h
end
-- Flatten a fragment tree into one string (nested tables are replaced in
-- place by their concatenation).
local function concat_output(out)
for i, s in ipairs(out) do
if type(s) == "table" then
out[i] = concat_output(s)
end
end
return table.concat(out)
end
-- Render the table constructor that stands in for a record type. Only
-- fields that are nested record type definitions are emitted, each as
-- `name = {...}, `.
local function print_record_def(typ)
local out = { "{" }
for _, name in ipairs(typ.field_order) do
if is_typetype(typ.fields[name]) and is_record_type(typ.fields[name].def) then
table.insert(out, name)
table.insert(out, " = ")
table.insert(out, print_record_def(typ.fields[name].def))
table.insert(out, ", ")
end
end
table.insert(out, "}")
return table.concat(out)
end
local visit_node = {}
-- attribute suffixes for the 5.4 target; Teal's `total` is emitted as <const>
local lua_54_attribute = {
["const"] = " <const>",
["close"] = " <close>",
["total"] = " <const>",
}
-- Per-kind renderers: each `after` receives the node and its children's
-- fragments (laid out as filled in by recurse_node's walkers) and returns a
-- new fragment.
visit_node.cbs = {
-- Statements on the same source line are joined with "; ", or with " "
-- right after a statement that carried an explicit semicolon.
["statements"] = {
after = function(node, children)
local out = { y = node.y, h = 0 }
local space
for i, child in ipairs(children) do
add_child(out, child, space, indent)
if node[i].semicolon then
table.insert(out, ";")
space = " "
else
space = "; "
end
end
return out
end,
},
-- `local a <attr>, b = exps`: attributes only for the 5.4 target, and a
-- <close> attribute on any other target records an error.
["local_declaration"] = {
after = function(node, children)
local out = { y = node.y, h = 0 }
table.insert(out, "local ")
for i, var in ipairs(node.vars) do
if i > 1 then
add_string(out, ", ")
end
add_string(out, var.tk)
if var.attribute then
if gen_target ~= "5.4" and var.attribute == "close" then
err = "attempt to emit a <close> attribute for a non 5.4 target"
end
if gen_target == "5.4" then
add_string(out, lua_54_attribute[var.attribute])
end
end
end
if children[3] then
table.insert(out, " =")
add_child(out, children[3], " ")
end
return out
end,
},
-- `local T = <type value>`; nothing is emitted when var.elide_type is set
["local_type"] = {
after = function(node, children)
local out = { y = node.y, h = 0 }
if not node.var.elide_type then
table.insert(out, "local")
add_child(out, children[1], " ")
table.insert(out, " =")
add_child(out, children[2], " ")
end
return out
end,
},
-- `T = <type value>`; nothing is emitted when there is no value
["global_type"] = {
after = function(node, children)
local out = { y = node.y, h = 0 }
if children[2] then
add_child(out, children[1])
table.insert(out, " =")
add_child(out, children[2], " ")
end
return out
end,
},
-- a global declared without an initializer produces no code
["global_declaration"] = {
after = function(node, children)
local out = { y = node.y, h = 0 }
if children[3] then
add_child(out, children[1])
table.insert(out, " =")
add_child(out, children[3], " ")
end
return out
end,
},
["assignment"] = {
after = function(node, children)
local out = { y = node.y, h = 0 }
add_child(out, children[1])
table.insert(out, " =")
add_child(out, children[3], " ")
return out
end,
},
-- whole if-chain: the rendered blocks, then `end` on the node's last line
["if"] = {
after = function(node, children)
local out = { y = node.y, h = 0 }
for i, child in ipairs(children) do
add_child(out, child, i > 1 and " ", child.y ~= node.y and indent)
end
add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent)
return out
end,
},
-- one arm: `if`/`elseif` with a condition, or `else` without one
["if_block"] = {
before = increment_indent,
after = function(node, children)
local out = { y = node.y, h = 0 }
if node.if_block_n == 1 then
table.insert(out, "if")
elseif not node.exp then
table.insert(out, "else")
else
table.insert(out, "elseif")
end
if node.exp then
add_child(out, children[1], " ")
table.insert(out, " then")
end
add_child(out, children[2], " ")
decrement_indent(node, node.body)
return out
end,
},
["while"] = {
before = increment_indent,
after = function(node, children)
local out = { y = node.y, h = 0 }
table.insert(out, "while")
add_child(out, children[1], " ")
table.insert(out, " do")
add_child(out, children[2], " ")
decrement_indent(node, node.body)
add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent)
return out
end,
},
-- `until` is placed on the node's last line, followed by the condition
["repeat"] = {
before = increment_indent,
after = function(node, children)
local out = { y = node.y, h = 0 }
table.insert(out, "repeat")
add_child(out, children[1], " ")
decrement_indent(node, node.body)
add_child(out, { y = node.yend, h = 0, [1] = "until " }, " ", indent)
add_child(out, children[2])
return out
end,
},
["do"] = {
before = increment_indent,
after = function(node, children)
local out = { y = node.y, h = 0 }
table.insert(out, "do")
add_child(out, children[1], " ")
decrement_indent(node, node.body)
add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent)
return out
end,
},
["forin"] = {
before = increment_indent,
after = function(node, children)
local out = { y = node.y, h = 0 }
table.insert(out, "for")
add_child(out, children[1], " ")
table.insert(out, " in")
add_child(out, children[2], " ")
table.insert(out, " do")
add_child(out, children[3], " ")
decrement_indent(node, node.body)
add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent)
return out
end,
},
-- `for v = from, to[, step] do ... end`
["fornum"] = {
before = increment_indent,
after = function(node, children)
local out = { y = node.y, h = 0 }
table.insert(out, "for")
add_child(out, children[1], " ")
table.insert(out, " =")
add_child(out, children[2], " ")
table.insert(out, ",")
add_child(out, children[3], " ")
if children[4] then
table.insert(out, ",")
add_child(out, children[4], " ")
end
table.insert(out, " do")
add_child(out, children[5], " ")
decrement_indent(node, node.body)
add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent)
return out
end,
},
["return"] = {
after = function(node, children)
local out = { y = node.y, h = 0 }
table.insert(out, "return")
if #children[1] > 0 then
add_child(out, children[1], " ")
end
return out
end,
},
["break"] = {
after = function(node, _children)
local out = { y = node.y, h = 0 }
table.insert(out, "break")
return out
end,
},
-- comma-separated list; also reused for expression and argument lists
["variable_list"] = {
after = function(node, children)
local out = { y = node.y, h = 0 }
local space
for i, child in ipairs(children) do
if i > 1 then
table.insert(out, ",")
space = " "
end
add_child(out, child, space, child.y ~= node.y and indent)
end
return out
end,
},
-- `{}` when empty; otherwise items separated by ",", with a trailing ","
-- when the literal spans several lines
["table_literal"] = {
before = increment_indent,
after = function(node, children)
local out = { y = node.y, h = 0 }
if #children == 0 then
table.insert(out, "{}")
return out
end
table.insert(out, "{")
local n = #children
for i, child in ipairs(children) do
add_child(out, child, " ", child.y ~= node.y and indent)
if i < n or node.yend ~= node.y then
table.insert(out, ",")
end
end
decrement_indent(node, node[1])
add_child(out, { y = node.yend, h = 0, [1] = "}" }, " ", indent)
return out
end,
},
-- key_parsed has three cases:
--   "implicit": positional item, only the value is printed.
--   "short":    `name = v`. The first and last characters of the rendered
--               key (presumably its quotes) are stripped.
--   otherwise:  `[key] = v`. A long-string key is padded with spaces so
--               its `[[`/`]]` do not merge with the brackets.
["table_item"] = {
after = function(node, children)
local out = { y = node.y, h = 0 }
if node.key_parsed ~= "implicit" then
if node.key_parsed == "short" then
children[1][1] = children[1][1]:sub(2, -2)
add_child(out, children[1])
table.insert(out, " = ")
else
table.insert(out, "[")
if node.key_parsed == "long" and node.key.is_longstring then
table.insert(children[1], 1, " ")
table.insert(children[1], " ")
end
add_child(out, children[1])
table.insert(out, "] = ")
end
end
add_child(out, children[2])
return out
end,
},
["local_function"] = {
before = increment_indent,
after = function(node, children)
local out = { y = node.y, h = 0 }
table.insert(out, "local function")
add_child(out, children[1], " ")
table.insert(out, "(")
add_child(out, children[2])
table.insert(out, ")")
add_child(out, children[4], " ")
decrement_indent(node, node.body)
add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent)
return out
end,
},
["global_function"] = {
before = increment_indent,
after = function(node, children)
local out = { y = node.y, h = 0 }
table.insert(out, "function")
add_child(out, children[1], " ")
table.insert(out, "(")
add_child(out, children[2])
table.insert(out, ")")
add_child(out, children[4], " ")
decrement_indent(node, node.body)
add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent)
return out
end,
},
-- `function owner.name(...)` or `function owner:name(...)`. For methods,
-- the first rendered argument (self) is removed; if a "," follows it, that
-- "," and the piece after it are removed too.
["record_function"] = {
before = increment_indent,
after = function(node, children)
local out = { y = node.y, h = 0 }
table.insert(out, "function")
add_child(out, children[1], " ")
table.insert(out, node.is_method and ":" or ".")
add_child(out, children[2])
table.insert(out, "(")
if node.is_method then
table.remove(children[3], 1)
if children[3][1] == "," then
table.remove(children[3], 1)
table.remove(children[3], 1)
end
end
add_child(out, children[3])
table.insert(out, ")")
add_child(out, children[5], " ")
decrement_indent(node, node.body)
add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent)
return out
end,
},
["function"] = {
before = increment_indent,
after = function(node, children)
local out = { y = node.y, h = 0 }
table.insert(out, "function(")
add_child(out, children[1])
table.insert(out, ")")
add_child(out, children[3], " ")
decrement_indent(node, node.body)
add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent)
return out
end,
},
-- no `after` callback, so a cast node renders to nil
["cast"] = {},
["paren"] = {
after = function(node, children)
local out = { y = node.y, h = 0 }
table.insert(out, "(")
add_child(out, children[1], "", indent)
table.insert(out, ")")
return out
end,
},
["op"] = {
after = function(node, children)
local out = { y = node.y, h = 0 }
if node.op.op == "@funcall" then
add_child(out, children[1], "", indent)
table.insert(out, "(")
add_child(out, children[3], "", indent)
table.insert(out, ")")
elseif node.op.op == "@index" then
add_child(out, children[1], "", indent)
table.insert(out, "[")
if node.e2.is_longstring then
table.insert(children[3], 1, " ")
table.insert(children[3], " ")
end
add_child(out, children[3], "", indent)
table.insert(out, "]")
-- casts have no runtime effect: emit the expression only
elseif node.op.op == "as" then
add_child(out, children[1], "", indent)
-- `x is T` becomes one of:
--   `math.type(x) == "integer"`  when T is integer
--   `x == nil`                   when T is nil
--   `type(x) == "<lua type>"`    otherwise (children[3] holds the Lua
--                                type name rendered by visit_type)
elseif node.op.op == "is" then
if node.e2.casttype.typename == "integer" then
table.insert(out, "math.type(")
add_child(out, children[1], "", indent)
table.insert(out, ") == \"integer\"")
elseif node.e2.casttype.typename == "nil" then
add_child(out, children[1], "", indent)
table.insert(out, " == nil")
else
table.insert(out, "type(")
add_child(out, children[1], "", indent)
table.insert(out, ") == \"")
add_child(out, children[3], "", indent)
table.insert(out, "\"")
end
elseif spaced_op[node.op.arity][node.op.op] or tight_op[node.op.arity][node.op.op] then
local space = spaced_op[node.op.arity][node.op.op] and " " or ""
-- parenthesize an operand whose precedence (children[2] /
-- children[4]) is lower than this operator's
if children[2] and node.op.prec > tonumber(children[2]) then
table.insert(children[1], 1, "(")
table.insert(children[1], ")")
end
if node.op.arity == 1 then
table.insert(out, node.op.op)
add_child(out, children[1], space, indent)
elseif node.op.arity == 2 then
add_child(out, children[1], "", indent)
if space == " " then
table.insert(out, " ")
end
table.insert(out, node.op.op)
if children[4] and node.op.prec > tonumber(children[4]) then
table.insert(children[3], 1, "(")
table.insert(children[3], ")")
end
add_child(out, children[3], space, indent)
end
else
error("unknown node op " .. node.op.op)
end
return out
end,
},
-- leaf nodes print their source token
["variable"] = {
after = function(node, _children)
local out = { y = node.y, h = 0 }
add_string(out, node.tk)
return out
end,
},
-- value of a type declaration: the aliased name, a record's table
-- constructor, or `{}`
["newtype"] = {
after = function(node, _children)
local out = { y = node.y, h = 0 }
if node.is_alias then
table.insert(out, table.concat(node.newtype.def.names, "."))
elseif is_record_type(node.newtype.def) then
table.insert(out, print_record_def(node.newtype.def))
else
table.insert(out, "{}")
end
return out
end,
},
["goto"] = {
after = function(node, _children)
local out = { y = node.y, h = 0 }
table.insert(out, "goto ")
table.insert(out, node.label)
return out
end,
},
["label"] = {
after = function(node, _children)
local out = { y = node.y, h = 0 }
table.insert(out, "::")
table.insert(out, node.label)
table.insert(out, "::")
return out
end,
},
}
-- Type nodes render as the Lua runtime type name of their (resolved) type;
-- in this printer that rendering is consumed by `is` checks.
local visit_type = {}
visit_type.cbs = {
["string"] = {
after = function(typ, _children)
local out = { y = typ.y or -1, h = 0 }
local r = typ.resolved or typ
local lua_type = primitive[r.typename] or
(r.is_userdata and "userdata") or
"table"
table.insert(out, lua_type)
return out
end,
},
}
-- every type kind shares the same renderer
visit_type.cbs["typetype"] = visit_type.cbs["string"]
visit_type.cbs["typevar"] = visit_type.cbs["string"]
visit_type.cbs["typearg"] = visit_type.cbs["string"]
visit_type.cbs["function"] = visit_type.cbs["string"]
visit_type.cbs["thread"] = visit_type.cbs["string"]
visit_type.cbs["array"] = visit_type.cbs["string"]
visit_type.cbs["map"] = visit_type.cbs["string"]
visit_type.cbs["tupletable"] = visit_type.cbs["string"]
visit_type.cbs["arrayrecord"] = visit_type.cbs["string"]
visit_type.cbs["record"] = visit_type.cbs["string"]
visit_type.cbs["enum"] = visit_type.cbs["string"]
visit_type.cbs["boolean"] = visit_type.cbs["string"]
visit_type.cbs["nil"] = visit_type.cbs["string"]
visit_type.cbs["number"] = visit_type.cbs["string"]
visit_type.cbs["integer"] = visit_type.cbs["string"]
visit_type.cbs["union"] = visit_type.cbs["string"]
visit_type.cbs["nominal"] = visit_type.cbs["string"]
visit_type.cbs["bad_nominal"] = visit_type.cbs["string"]
visit_type.cbs["emptytable"] = visit_type.cbs["string"]
visit_type.cbs["table_item"] = visit_type.cbs["string"]
visit_type.cbs["unresolved_emptytable_value"] = visit_type.cbs["string"]
visit_type.cbs["tuple"] = visit_type.cbs["string"]
visit_type.cbs["poly"] = visit_type.cbs["string"]
visit_type.cbs["any"] = visit_type.cbs["string"]
visit_type.cbs["unknown"] = visit_type.cbs["string"]
visit_type.cbs["invalid"] = visit_type.cbs["string"]
visit_type.cbs["unresolved"] = visit_type.cbs["string"]
visit_type.cbs["none"] = visit_type.cbs["string"]
-- list-like nodes render like variable_list; leaf nodes print their token
visit_node.cbs["expression_list"] = visit_node.cbs["variable_list"]
visit_node.cbs["argument_list"] = visit_node.cbs["variable_list"]
visit_node.cbs["identifier"] = visit_node.cbs["variable"]
visit_node.cbs["number"] = visit_node.cbs["variable"]
visit_node.cbs["integer"] = visit_node.cbs["variable"]
visit_node.cbs["string"] = visit_node.cbs["variable"]
visit_node.cbs["nil"] = visit_node.cbs["variable"]
visit_node.cbs["boolean"] = visit_node.cbs["variable"]
visit_node.cbs["..."] = visit_node.cbs["variable"]
visit_node.cbs["argument"] = visit_node.cbs["variable"]
visit_node.cbs["type_identifier"] = visit_node.cbs["variable"]
local out = recurse_node(ast, visit_node, visit_type)
if err then
return nil, err
end
-- with preserve_newlines, anchor the output at line 1 so leading newlines
-- place the first statement on its original source line
local code
if opts.preserve_newlines then
code = { y = 1, h = 0 }
add_child(code, out)
else
code = out
end
return concat_output(code)
end
-- Mark the array `t` in place as a variadic tuple type (its last element
-- may repeat) and return the result of a_type on it.
local function VARARG(t)
t.typename = "tuple"
t.is_va = true
return a_type(t)
end
-- Mark the array `t` in place as a tuple type and return the result of
-- a_type on it.
local function TUPLE(t)
t.typename = "tuple"
return a_type(t)
end
-- Build a union type whose members are the types in the array `t`.
local function UNION(t)
local union_type = { typename = "union", types = t }
return a_type(union_type)
end
-- Shared singleton types; they are created once here and referenced below.
-- NONE / INVALID / UNKNOWN / CIRCULAR_REQUIRE are marker types identified by
-- their typename.
local NONE = a_type({ typename = "none" })
local INVALID = a_type({ typename = "invalid" })
local UNKNOWN = a_type({ typename = "unknown" })
local CIRCULAR_REQUIRE = a_type({ typename = "circular_require" })
-- a function taking and returning any number of `any` values
local FUNCTION = a_type({ typename = "function", args = VARARG({ ANY }), rets = VARARG({ ANY }) })
-- reference to the nominal type named FILE
local NOMINAL_FILE = a_type({ typename = "nominal", names = { "FILE" } })
-- function(any) with no return values, named for xpcall's message handler
local XPCALL_MSGH_FUNCTION = a_type({ typename = "function", args = TUPLE({ ANY }), rets = TUPLE({}) })
-- userdata has no dedicated type here; it is treated as `any`
local USERDATA = ANY
-- Result-type lookup tables for binary operators. Each maps an operand
-- typename to a table that maps the other operand's typename to the result
-- type (presumably indexed [left][right]; confirm at the use site). A
-- missing entry means the combination has no built-in result here.
--
-- Arithmetic operators (+ - * % //): integer with integer stays integer;
-- any float operand makes the result a number.
local numeric_binop = {
["number"] = {
["number"] = NUMBER,
["integer"] = NUMBER,
},
["integer"] = {
["integer"] = INTEGER,
["number"] = NUMBER,
},
}
-- `/` and `^`: the result is always a float (number).
local float_binop = {
["number"] = {
["number"] = NUMBER,
["integer"] = NUMBER,
},
["integer"] = {
["integer"] = NUMBER,
["number"] = NUMBER,
},
}
-- Bitwise operators: the result is always an integer.
local integer_binop = {
["number"] = {
["number"] = INTEGER,
["integer"] = INTEGER,
},
["integer"] = {
["integer"] = INTEGER,
["number"] = INTEGER,
},
}
-- Ordering operators (< <= > >=) are only defined for number/integer pairs
-- and for two strings. Booleans are deliberately absent: Lua raises
-- "attempt to compare two boolean values" for e.g. `true < false`, so the
-- checker must not accept that as a built-in comparison.
local relational_binop = {
["number"] = {
["integer"] = BOOLEAN,
["number"] = BOOLEAN,
},
["integer"] = {
["number"] = BOOLEAN,
["integer"] = BOOLEAN,
},
["string"] = {
["string"] = BOOLEAN,
},
}
-- `==` / `~=`: result types keyed by the two operand typenames.
-- Allowed combinations:
--   * number/integer with each other;
--   * same-kind primitives;
--   * table kinds with related table kinds and with emptytable;
--   * most kinds with nil.
local equality_binop = {
["number"] = {
["number"] = BOOLEAN,
["integer"] = BOOLEAN,
["nil"] = BOOLEAN,
},
["integer"] = {
["number"] = BOOLEAN,
["integer"] = BOOLEAN,
["nil"] = BOOLEAN,
},
["string"] = {
["string"] = BOOLEAN,
["nil"] = BOOLEAN,
},
["boolean"] = {
["boolean"] = BOOLEAN,
["nil"] = BOOLEAN,
},
["record"] = {
["emptytable"] = BOOLEAN,
["arrayrecord"] = BOOLEAN,
["record"] = BOOLEAN,
["nil"] = BOOLEAN,
},
["array"] = {
["emptytable"] = BOOLEAN,
["arrayrecord"] = BOOLEAN,
["array"] = BOOLEAN,
["nil"] = BOOLEAN,
},
["arrayrecord"] = {
["emptytable"] = BOOLEAN,
["arrayrecord"] = BOOLEAN,
["record"] = BOOLEAN,
["array"] = BOOLEAN,
["nil"] = BOOLEAN,
},
["map"] = {
["emptytable"] = BOOLEAN,
["map"] = BOOLEAN,
["nil"] = BOOLEAN,
},
["thread"] = {
["thread"] = BOOLEAN,
["nil"] = BOOLEAN,
},
}
-- Unary operator result types: unop_types[operator][operand typename].
-- `#` yields an integer for strings and table kinds; `-` keeps the numeric
-- kind; `~` (bitwise not) always yields an integer; `not` yields a boolean.
local unop_types = {
["#"] = {
["arrayrecord"] = INTEGER,
["string"] = INTEGER,
["array"] = INTEGER,
["tupletable"] = INTEGER,
["map"] = INTEGER,
["emptytable"] = INTEGER,
},
["-"] = {
["number"] = NUMBER,
["integer"] = INTEGER,
},
["~"] = {
["number"] = INTEGER,
["integer"] = INTEGER,
},
["not"] = {
["string"] = BOOLEAN,
["number"] = BOOLEAN,
["integer"] = BOOLEAN,
["boolean"] = BOOLEAN,
["record"] = BOOLEAN,
["arrayrecord"] = BOOLEAN,
["array"] = BOOLEAN,
["tupletable"] = BOOLEAN,
["map"] = BOOLEAN,
["emptytable"] = BOOLEAN,
["thread"] = BOOLEAN,
},
}
-- Lua metamethod names for the overloadable unary operators.
local unop_to_metamethod = {
["#"] = "__len",
["-"] = "__unm",
["~"] = "__bnot",
}
-- Binary operator -> result-type lookup table (operand typename -> operand
-- typename -> result type). Several operators share the same table object.
local binop_types = {
["+"] = numeric_binop,
["-"] = numeric_binop,
["*"] = numeric_binop,
["%"] = numeric_binop,
["/"] = float_binop,
["//"] = numeric_binop,
["^"] = float_binop,
["&"] = integer_binop,
["|"] = integer_binop,
["<<"] = integer_binop,
[">>"] = integer_binop,
["~"] = integer_binop,
["=="] = equality_binop,
["~="] = equality_binop,
["<="] = relational_binop,
[">="] = relational_binop,
["<"] = relational_binop,
[">"] = relational_binop,
-- `or`: result type of `a or b` for built-in operand combinations
["or"] = {
["boolean"] = {
["boolean"] = BOOLEAN,
["function"] = FUNCTION,
},
["number"] = {
["integer"] = NUMBER,
["number"] = NUMBER,
["boolean"] = BOOLEAN,
},
["integer"] = {
["integer"] = INTEGER,
["number"] = NUMBER,
["boolean"] = BOOLEAN,
},
["string"] = {
["string"] = STRING,
["boolean"] = BOOLEAN,
["enum"] = STRING,
},
["function"] = {
["boolean"] = BOOLEAN,
},
["array"] = {
["boolean"] = BOOLEAN,
},
["record"] = {
["boolean"] = BOOLEAN,
},
["arrayrecord"] = {
["boolean"] = BOOLEAN,
},
["map"] = {
["boolean"] = BOOLEAN,
},
["enum"] = {
["string"] = STRING,
},
["thread"] = {
["boolean"] = BOOLEAN,
},
},
-- `..`: concatenating strings, enums and numbers yields a string
[".."] = {
["string"] = {
["string"] = STRING,
["enum"] = STRING,
["number"] = STRING,
["integer"] = STRING,
},
["number"] = {
["integer"] = STRING,
["number"] = STRING,
["string"] = STRING,
["enum"] = STRING,
},
["integer"] = {
["integer"] = STRING,
["number"] = STRING,
["string"] = STRING,
["enum"] = STRING,
},
["enum"] = {
["number"] = STRING,
["integer"] = STRING,
["string"] = STRING,
["enum"] = STRING,
},
},
}
-- Lua metamethod names for binary operators; "@index" (the indexing
-- expression) maps to __index. `~=`, `>` and `>=` have no entry of their own.
local binop_to_metamethod = {
["+"] = "__add",
["-"] = "__sub",
["*"] = "__mul",
["/"] = "__div",
["%"] = "__mod",
["^"] = "__pow",
["//"] = "__idiv",
["&"] = "__band",
["|"] = "__bor",
["~"] = "__bxor",
["<<"] = "__shl",
[">>"] = "__shr",
[".."] = "__concat",
["=="] = "__eq",
["<"] = "__lt",
["<="] = "__le",
["@index"] = "__index",
}
-- True when `t` is the unknown type or a still-unresolved emptytable value
-- (both are treated as "unknown" by the rest of the checker's helpers).
local function is_unknown(t)
local name = t.typename
if name == "unknown" or name == "unresolved_emptytable_value" then
return true
end
return false
end
-- Render type `t` as readable text (without the "inferred at" suffix that
-- show_type appends).
-- `short` abbreviates records to "record" and omits a string type's literal
-- token.
-- `seen` maps already-rendered types to their text. A type still being
-- rendered maps to "...", so recursive types terminate. Nested types are
-- rendered through show_type with the same `seen` table.
local function show_type_base(t, short, seen)
if seen[t] then
return seen[t]
end
seen[t] = "..."
local function show(typ)
return show_type(typ, short, seen)
end
-- Name or Name<A, B>
if t.typename == "nominal" then
if t.typevals then
local out = { table.concat(t.names, "."), "<" }
local vals = {}
for _, v in ipairs(t.typevals) do
table.insert(vals, show(v))
end
table.insert(out, table.concat(vals, ", "))
table.insert(out, ">")
return table.concat(out)
else
return table.concat(t.names, ".")
end
-- (A, B)
elseif t.typename == "tuple" then
local out = {}
for _, v in ipairs(t) do
table.insert(out, show(v))
end
return "(" .. table.concat(out, ", ") .. ")"
-- {A, B}
elseif t.typename == "tupletable" then
local out = {}
for _, v in ipairs(t.types) do
table.insert(out, show(v))
end
return "{" .. table.concat(out, ", ") .. "}"
elseif t.typename == "poly" then
local out = {}
for _, v in ipairs(t.types) do
table.insert(out, show(v))
end
return "polymorphic function (with types " .. table.concat(out, " and ") .. ")"
-- A | B
elseif t.typename == "union" then
local out = {}
for _, v in ipairs(t.types) do
table.insert(out, show(v))
end
return table.concat(out, " | ")
elseif t.typename == "emptytable" then
return "{}"
elseif t.typename == "map" then
return "{" .. show(t.keys) .. " : " .. show(t.values) .. "}"
elseif t.typename == "array" then
return "{" .. show(t.elements) .. "}"
elseif t.typename == "enum" then
return t.names and table.concat(t.names, ".") or "enum"
-- record<T> ({E}k1: A; k2: B), fields in field_order
elseif is_record_type(t) then
if short then
return "record"
else
local out = { "record" }
if t.typeargs then
table.insert(out, "<")
local typeargs = {}
for _, v in ipairs(t.typeargs) do
table.insert(typeargs, show(v))
end
table.insert(out, table.concat(typeargs, ", "))
table.insert(out, ">")
end
table.insert(out, " (")
if t.elements then
table.insert(out, "{" .. show(t.elements) .. "}")
end
local fs = {}
for _, k in ipairs(t.field_order) do
local v = t.fields[k]
table.insert(fs, k .. ": " .. show(v))
end
table.insert(out, table.concat(fs, "; "))
table.insert(out, ")")
return table.concat(out)
end
-- function<T>(self, A, ...: B): R...
-- For methods the first argument is shown as "self".
elseif t.typename == "function" then
local out = { "function" }
if t.typeargs then
table.insert(out, "<")
local typeargs = {}
for _, v in ipairs(t.typeargs) do
table.insert(typeargs, show(v))
end
table.insert(out, table.concat(typeargs, ", "))
table.insert(out, ">")
end
table.insert(out, "(")
local args = {}
if t.is_method then
table.insert(args, "self")
end
for i, v in ipairs(t.args) do
if not t.is_method or i > 1 then
table.insert(args, (i == #t.args and t.args.is_va and "...: " or "") .. show(v))
end
end
table.insert(out, table.concat(args, ", "))
table.insert(out, ")")
if #t.rets > 0 then
table.insert(out, ": ")
local rets = {}
for i, v in ipairs(t.rets) do
table.insert(rets, show(v) .. (i == #t.rets and t.rets.is_va and "..." or ""))
end
table.insert(out, table.concat(rets, ", "))
end
return table.concat(out)
elseif t.typename == "number" or
t.typename == "integer" or
t.typename == "boolean" or
t.typename == "thread" then
return t.typename
-- "string", followed by the literal token when known (non-short mode)
elseif t.typename == "string" then
if short then
return "string"
else
return t.typename ..
(t.tk and " " .. t.tk or "")
end
-- type variables: the "@..." suffix is stripped unless TL_DEBUG is set
elseif t.typename == "typevar" then
return TL_DEBUG and t.typevar or (t.typevar:gsub("@.*", ""))
elseif t.typename == "typearg" then
return TL_DEBUG and t.typearg or (t.typearg:gsub("@.*", ""))
elseif t.typename == "unresolvable_typearg" then
return (TL_DEBUG and t.typearg or (t.typearg:gsub("@.*", ""))) .. " (unresolved generic)"
elseif is_unknown(t) then
return "<unknown type>"
elseif t.typename == "invalid" then
return "<invalid type>"
elseif t.typename == "any" then
return "<any type>"
elseif t.typename == "nil" then
return "nil"
elseif t.typename == "none" then
return ""
elseif is_typetype(t) then
return "type " .. show(t.def) .. (t.is_alias and " (alias)" or "")
elseif t.typename == "bad_nominal" then
return table.concat(t.names, ".") .. " (an unknown type)"
-- fallback for any other kind: typename plus the table's identity
else
return "<" .. t.typename .. " " .. tostring(t) .. ">"
end
end
-- Suffix " (inferred at FILE:LINE:COL)" describing where the type `t` was
-- inferred, built from t.inferred_at_file and t.inferred_at.{y,x}.
local function inferred_msg(t)
local at = t.inferred_at
local location = t.inferred_at_file .. ":" .. at.y .. ":" .. at.x
return " (inferred at " .. location .. ")"
end
-- Public renderer: show_type_base plus an " (inferred at ...)" suffix when
-- the type records where it was inferred. Results are memoized in `seen`
-- (a fresh table when omitted).
show_type = function(t, short, seen)
seen = seen or {}
local cached = seen[t]
if cached then
return cached
end
local text = show_type_base(t, short, seen)
if t.inferred_at then
text = text .. inferred_msg(t)
end
seen[t] = text
return text
end
-- Look for `module_name` along a `;`-separated Lua search `path`.
-- For each template, `?` is replaced by the module name (dots turned into
-- slashes) and a trailing ".lua" is swapped for `suffix`.
-- Returns the first candidate that opens, its open read handle, and
-- `tried`. Otherwise returns nil, nil and `tried`, with one
-- "no file '...'" line appended per candidate that could not be opened.
local function search_for(module_name, suffix, path, tried)
-- Loop-invariant: compute the slash-separated name once. '%' is escaped
-- because gsub treats it specially in replacement strings (a bare '%'
-- there is an error in Lua 5.2+).
local slash_name = module_name:gsub("%.", "/")
local replacement = slash_name:gsub("%%", "%%%%")
for entry in path:gmatch("[^;]+") do
local filename = entry:gsub("%?", replacement)
local tl_filename = filename:gsub("%.lua$", suffix)
local fd = io.open(tl_filename, "rb")
if fd then
return tl_filename, fd, tried
end
table.insert(tried, "no file '" .. tl_filename .. "'")
end
return nil, nil, tried
end
-- Derive a module name from a source filename using the search path
-- (TL_PATH, or package.path).
-- For each `;`-separated template, the part of `filename` matching `?` is
-- captured and its path separators become dots, e.g. "./foo/bar.tl" with
-- "./?.lua" gives "foo.bar". Templates are tried with ".d.tl", ".tl" and
-- ".lua" endings.
-- When no template matches, the extension is stripped and separators in the
-- whole filename become dots.
local function filename_to_module_name(filename)
local path = os.getenv("TL_PATH") or package.path
for entry in path:gmatch("[^;]+") do
-- Escape every pattern magic character except '?', the placeholder, so
-- directories such as "lua-5.3" or "my-lib" match literally.
local escaped = entry:gsub("[%^%$%(%)%%%.%[%]%*%+%-]", "%%%0")
-- Capture the placeholder: without a capture, match() would return the
-- whole filename rather than the module part.
local lua_pat = "^" .. escaped:gsub("%?", "(.+)") .. "$"
local d_tl_pat = lua_pat:gsub("%%.lua%$", "%%.d%%.tl$")
local tl_pat = lua_pat:gsub("%%.lua%$", "%%.tl$")
-- ".d.tl" must be tried before ".tl", or "foo.d.tl" would yield "foo.d".
for _, pat in ipairs({ d_tl_pat, tl_pat, lua_pat }) do
local cap = filename:match(pat)
if cap then
return (cap:gsub("[/\\]", "."))
end
end
end
return (filename:gsub("%.lua$", ""):gsub("%.d%.tl$", ""):gsub("%.tl$", ""):gsub("[/\\]", "."))
end
--- Locate the file for `module_name` on TL_PATH (or package.path).
-- Candidates are tried in order: ".d.tl" (only when `search_dtl` is set),
-- then ".tl", then ".lua".
-- @return filename and an open read handle, or nil, nil, and the list of
--   "no file '...'" messages for every location tried
function tl.search_module(module_name, search_dtl)
   local tried = {}
   local path = os.getenv("TL_PATH") or package.path
   local suffixes = search_dtl and { ".d.tl", ".tl", ".lua" } or { ".tl", ".lua" }
   for _, suffix in ipairs(suffixes) do
      local found, fd
      found, fd, tried = search_for(module_name, suffix, path, tried)
      if found then
         return found, fd
      end
   end
   return nil, nil, tried
end
-- Return a new array holding every key of `m`, sorted with the default `<`
-- ordering (so the keys must be mutually comparable).
local function sorted_keys(m)
   local keys, count = {}, 0
   for key in pairs(m) do
      count = count + 1
      keys[count] = key
   end
   table.sort(keys)
   return keys
end
-- For record types, set `field_order` to the record's field names in sorted
-- order. Other types are left untouched.
local function fill_field_order(t)
   if t.typename ~= "record" then
      return
   end
   t.field_order = sorted_keys(t.fields)
end
-- Resolve the type of module `module_name` within `env`.
-- Modules already present in env.modules are returned directly. Otherwise
-- the module is searched for (including ".d.tl"). Teal files, or any file in
-- lax mode, are processed. Other found files are closed unused.
-- Returns the module type and a flag: true when loaded/cached, otherwise
-- whether a file was found at all (the type is then INVALID).
local function require_module(module_name, lax, env)
   local cached = env.modules[module_name]
   if cached then
      return cached, true
   end
   local found, fd = tl.search_module(module_name, true)
   local loadable = found and (lax or found:match("tl$"))
   if not loadable then
      if fd then
         fd:close()
      end
      return INVALID, found ~= nil
   end
   local result, err = tl.process(found, env, module_name, fd)
   assert(result, err)
   return result.type, true
end
-- Parsed and type-checked compat snippets, keyed by their source text, so
-- each distinct snippet is parsed only once per process. The text (not the
-- entry name) is the key because the text varies with gen_compat.
local compat_code_cache = {}
--- Prepend compatibility prologue statements to `program`.
-- For each name in `used_set` (visited in sorted order), a small Lua snippet
-- is parsed and its statements are inserted at the front of `program`. The
-- statements are inserted in the order they are loaded.
-- @param program AST statement list, patched in place (its `y` is set to 1)
-- @param used_set set of compat entry names required by the program
-- @param gen_compat "off" disables this; "optional" loads compat modules
--   via pcall(require, ...); any other value uses a plain require
local function add_compat_entries(program, used_set, gen_compat)
   if gen_compat == "off" or not next(used_set) then
      return
   end

   -- TL_DEBUG is cleared while the internal snippets are parsed and
   -- type-checked, and restored at the end. It is saved only after the early
   -- return above, so that path leaves the caller's value untouched.
   local tl_debug = TL_DEBUG
   TL_DEBUG = nil

   local used_list = sorted_keys(used_set)
   local compat_loaded = false
   local n = 1

   -- Parse (once per distinct text) and insert the statements of `text`.
   local function load_code(text)
      local code = compat_code_cache[text]
      if not code then
         code = tl.parse(text, "@internal")
         tl.type_check(code, { filename = "<internal>", lax = false, gen_compat = "off" })
         compat_code_cache[text] = code
      end
      for _, c in ipairs(code) do
         table.insert(program, n, c)
         n = n + 1
      end
   end

   -- Source text that loads module `m`; "optional" tolerates its absence.
   local function req(m)
      return (gen_compat == "optional") and
      "pcall(require, '" .. m .. "')" or
      "true, require('" .. m .. "')"
   end

   -- The snippets below are complete, balanced Lua chunks (each `if` has its
   -- `end`, each `(` its `)`), so they parse without relying on the
   -- parser's error recovery.
   for _, name in ipairs(used_list) do
      if name == "table.unpack" then
         load_code("local _tl_table_unpack = unpack or table.unpack")
      elseif name == "bit32" then
         load_code("local bit32 = bit32; if not bit32 then local p, m = " .. req("bit32") .. "; if p then bit32 = m end end")
      elseif name == "mt" then
         load_code("local _tl_mt = function(m, s, a, b) return (getmetatable(s == 1 and a or b)[m](a, b)) end")
      elseif name == "math.maxinteger" then
         load_code("local _tl_math_maxinteger = math.maxinteger or math.pow(2,53)")
      elseif name == "math.mininteger" then
         load_code("local _tl_math_mininteger = math.mininteger or -math.pow(2,53) - 1")
      else
         if not compat_loaded then
            load_code("local _tl_compat; if (tonumber((_VERSION or ''):match('[%d.]*$')) or 0) < 5.3 then local p, m = " .. req("compat53.module") .. "; if p then _tl_compat = m end end")
            compat_loaded = true
         end
         load_code((("local $NAME = _tl_compat and _tl_compat.$NAME or $NAME"):gsub("%$NAME", name)))
      end
   end
   program.y = 1
   TL_DEBUG = tl_debug
end
-- Return a fresh set (name -> true) of standard-library globals flagged as
-- needing compat code: only "utf8" in lax mode, the full list otherwise.
local function get_stdlib_compat(lax)
   local names
   if lax then
      names = { "utf8" }
   else
      names = {
         "io", "math", "string", "table", "utf8", "coroutine", "os",
         "package", "debug", "load", "loadfile", "assert", "pairs",
         "ipairs", "pcall", "xpcall", "rawlen",
      }
   end
   local set = {}
   for _, name in ipairs(names) do
      set[name] = true
   end
   return set
end
-- Binary bitwise operators mapped to the equivalent function names of the
-- bit32 library (lshift, rshift, band, bor, bxor).
local bit_operators = {
   ["<<"] = "lshift",
   [">>"] = "rshift",
   ["&"] = "band",
   ["|"] = "bor",
   ["~"] = "bxor",
}
-- Rewrite operator node `node` in place into the call `mod_name.fn_name(e1, e2)`.
-- Every synthesized node reuses `node`'s source position.
local function convert_node_to_compat_call(node, mod_name, fn_name, e1, e2)
   local op = node.op
   op.op, op.arity, op.prec = "@funcall", 2, 100
   local y, x = node.y, node.x
   local callee = { y = y, x = x, kind = "op", op = an_operator(node, 2, ".") }
   callee.e1 = { y = y, x = x, kind = "identifier", tk = mod_name }
   callee.e2 = { y = y, x = x, kind = "identifier", tk = fn_name }
   node.e1 = callee
   local call_args = { y = y, x = x, kind = "expression_list" }
   node.e2 = call_args
   call_args[1] = e1
   call_args[2] = e2
end
-- Rewrite operator node `node` in place into the call
-- `_tl_mt("<mt_name>", which_self, e1, e2)`, using `node`'s source position
-- for every synthesized node.
local function convert_node_to_compat_mt_call(node, mt_name, which_self, e1, e2)
   local op = node.op
   op.op, op.arity, op.prec = "@funcall", 2, 100
   local y, x = node.y, node.x
   node.e1 = { y = y, x = x, kind = "identifier", tk = "_tl_mt" }
   local call_args = { y = y, x = x, kind = "expression_list" }
   node.e2 = call_args
   call_args[1] = { y = y, x = x, kind = "string", tk = "\"" .. mt_name .. "\"" }
   call_args[2] = { y = y, x = x, kind = "integer", tk = tostring(which_self) }
   call_args[3] = e1
   call_args[4] = e2
end
-- Typeid reserved by the first init_globals call (nil until then). Later
-- calls rewind last_typeid to it while building the standard-library types.
local globals_typeid
-- Counter appended to the names of generic type variables ("A@2", "B@2", ...)
-- so that every generic standard-library signature gets distinct names.
local fresh_typevar_ctr = 1
--- Build the global environment seen by the type checker.
-- Constructs the types of the Lua standard library and wraps each one as a
-- const global variable.
-- @param lax passed to get_stdlib_compat to decide which globals are
--   flagged `needs_compat`
-- @return globals: name -> { t = type, needs_compat = true|nil,
--   attribute = "const" }, plus the special "@is_va" entry
-- @return standard_library: name -> type
local function init_globals(lax)
   local globals = {}
   local stdlib_compat = get_stdlib_compat(lax)
   -- The first call reserves globals_typeid. Later calls rewind last_typeid
   -- to it while building and restore the saved value at the end. Presumably
   -- this gives stdlib types the same ids on every call (assumes a_type draws
   -- ids from last_typeid -- confirm).
   local is_first_init = globals_typeid == nil
   local save_typeid = last_typeid
   if is_first_init then
      globals_typeid = new_typeid()
   else
      last_typeid = globals_typeid
   end

   -- Build a generic function type with `n` fresh type variables named
   -- "A@k", "B@k", ... (k = the bumped fresh_typevar_ctr). `f` receives the
   -- typevars and returns a table of the function's fields (args, rets, ...).
   local function a_gfunction(n, f)
      local typevars = {}
      local typeargs = {}
      local c = string.byte("A") - 1
      fresh_typevar_ctr = fresh_typevar_ctr + 1
      for i = 1, n do
         local name = string.char(c + i) .. "@" .. fresh_typevar_ctr
         typevars[i] = a_type({ typename = "typevar", typevar = name })
         typeargs[i] = a_type({ typename = "typearg", typearg = name })
      end
      local t = f(_tl_table_unpack(typevars))
      t.typename = "function"
      t.typeargs = typeargs
      return a_type(t)
   end

   -- Like a_gfunction, but the resulting generic type is a record.
   local function a_grecord(n, f)
      local t = a_gfunction(n, f)
      t.typename = "record"
      return t
   end

   local LOAD_FUNCTION = a_type({ typename = "function", args = {}, rets = TUPLE({ STRING }) })
   local OS_DATE_TABLE = a_type({
      typename = "record",
      fields = {
         ["year"] = INTEGER,
         ["month"] = INTEGER,
         ["day"] = INTEGER,
         ["hour"] = INTEGER,
         ["min"] = INTEGER,
         ["sec"] = INTEGER,
         ["wday"] = INTEGER,
         ["yday"] = INTEGER,
         ["isdst"] = BOOLEAN,
      },
   })
   local OS_DATE_TABLE_FORMAT = a_type({ typename = "enum", enumset = { ["!*t"] = true, ["*t"] = true } })
   local DEBUG_GETINFO_TABLE = a_type({
      typename = "record",
      fields = {
         ["name"] = STRING,
         ["namewhat"] = STRING,
         ["source"] = STRING,
         ["short_src"] = STRING,
         ["linedefined"] = INTEGER,
         ["lastlinedefined"] = INTEGER,
         ["what"] = STRING,
         ["currentline"] = INTEGER,
         ["istailcall"] = BOOLEAN,
         ["nups"] = INTEGER,
         ["nparams"] = INTEGER,
         ["isvararg"] = BOOLEAN,
         ["func"] = ANY,
         ["activelines"] = a_type({ typename = "map", keys = INTEGER, values = BOOLEAN }),
      },
   })
   local DEBUG_HOOK_EVENT = a_type({
      typename = "enum",
      enumset = {
         ["call"] = true,
         ["tail call"] = true,
         ["return"] = true,
         ["line"] = true,
         ["count"] = true,
      },
   })
   local DEBUG_HOOK_FUNCTION = a_type({
      typename = "function",
      args = TUPLE({ DEBUG_HOOK_EVENT, INTEGER }),
      rets = TUPLE({}),
   })
   local TABLE_SORT_FUNCTION = a_gfunction(1, function(a) return { args = TUPLE({ a, a }), rets = TUPLE({ BOOLEAN }) } end)

   -- Nominal references to the "metatable" record. They are collected so their
   -- `found` field can be pointed at the real definition once
   -- standard_library exists (see the end of this function).
   local metatable_nominals = {}
   local function METATABLE(a)
      local t = a_type({ typename = "nominal", names = { "metatable" }, typevals = { a } })
      table.insert(metatable_nominals, t)
      return t
   end
   local function ARRAY(t)
      return a_type({
         typename = "array",
         elements = t,
      })
   end
   local function MAP(k, v)
      return a_type({
         typename = "map",
         keys = k,
         values = v,
      })
   end
   -- Marks an optional parameter in the signatures below. It is currently the
   -- identity function, so optionality is not encoded in the type itself.
   local function OPT(x)
      return x
   end

   local standard_library = {
      ["..."] = VARARG({ STRING }),
      ["any"] = a_type({ typename = "typetype", def = ANY }),
      ["arg"] = ARRAY(STRING),
      ["assert"] = a_gfunction(2, function(a, b) return { args = TUPLE({ a, OPT(b) }), rets = TUPLE({ a }) } end),
      ["collectgarbage"] = a_type({
         typename = "poly",
         types = {
            a_type({ typename = "function", args = TUPLE({ a_type({ typename = "enum", enumset = { ["collect"] = true, ["count"] = true, ["stop"] = true, ["restart"] = true } }) }), rets = TUPLE({ NUMBER }) }),
            a_type({ typename = "function", args = TUPLE({ a_type({ typename = "enum", enumset = { ["step"] = true, ["setpause"] = true, ["setstepmul"] = true } }), NUMBER }), rets = TUPLE({ NUMBER }) }),
            a_type({ typename = "function", args = TUPLE({ a_type({ typename = "enum", enumset = { ["isrunning"] = true } }) }), rets = TUPLE({ BOOLEAN }) }),
            a_type({ typename = "function", args = TUPLE({ STRING, OPT(NUMBER) }), rets = TUPLE({ a_type({ typename = "union", types = { BOOLEAN, NUMBER } }) }) }),
         },
      }),
      ["dofile"] = a_type({ typename = "function", args = TUPLE({ OPT(STRING) }), rets = VARARG({ ANY }) }),
      ["error"] = a_type({ typename = "function", args = TUPLE({ ANY, NUMBER }), rets = TUPLE({}) }),
      ["getmetatable"] = a_gfunction(1, function(a) return { args = TUPLE({ a }), rets = TUPLE({ METATABLE(a) }) } end),
      ["ipairs"] = a_gfunction(1, function(a) return { args = TUPLE({ ARRAY(a) }), rets = TUPLE({
         a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ INTEGER, a }) }),
      }), } end),
      ["load"] = a_type({ typename = "function", args = TUPLE({ UNION({ STRING, LOAD_FUNCTION }), OPT(STRING), OPT(STRING), OPT(TABLE) }), rets = TUPLE({ FUNCTION, STRING }) }),
      ["loadfile"] = a_type({ typename = "function", args = TUPLE({ OPT(STRING), OPT(STRING), OPT(TABLE) }), rets = TUPLE({ FUNCTION, STRING }) }),
      ["next"] = a_type({
         typename = "poly",
         types = {
            a_gfunction(2, function(a, b) return { args = TUPLE({ MAP(a, b), OPT(a) }), rets = TUPLE({ a, b }) } end),
            a_gfunction(1, function(a) return { args = TUPLE({ ARRAY(a), OPT(a) }), rets = TUPLE({ INTEGER, a }) } end),
         },
      }),
      ["pairs"] = a_gfunction(2, function(a, b) return { args = TUPLE({ a_type({ typename = "map", keys = a, values = b }) }), rets = TUPLE({
         a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ a, b }) }),
      }), } end),
      ["pcall"] = a_type({ typename = "function", args = VARARG({ FUNCTION, ANY }), rets = VARARG({ BOOLEAN, ANY }) }),
      ["xpcall"] = a_type({ typename = "function", args = VARARG({ FUNCTION, XPCALL_MSGH_FUNCTION, ANY }), rets = VARARG({ BOOLEAN, ANY }) }),
      ["print"] = a_type({ typename = "function", args = VARARG({ ANY }), rets = TUPLE({}) }),
      ["rawequal"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ BOOLEAN }) }),
      ["rawget"] = a_type({ typename = "function", args = TUPLE({ TABLE, ANY }), rets = TUPLE({ ANY }) }),
      ["rawlen"] = a_type({ typename = "function", args = TUPLE({ UNION({ TABLE, STRING }) }), rets = TUPLE({ INTEGER }) }),
      ["rawset"] = a_type({
         typename = "poly",
         types = {
            a_gfunction(2, function(a, b) return { args = TUPLE({ MAP(a, b), a, b }), rets = TUPLE({}) } end),
            a_gfunction(1, function(a) return { args = TUPLE({ ARRAY(a), NUMBER, a }), rets = TUPLE({}) } end),
            a_type({ typename = "function", args = TUPLE({ TABLE, ANY, ANY }), rets = TUPLE({}) }),
         },
      }),
      ["require"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({}) }),
      ["select"] = a_type({
         typename = "poly",
         types = {
            a_gfunction(1, function(a) return { args = VARARG({ NUMBER, a }), rets = TUPLE({ a }) } end),
            a_type({ typename = "function", args = VARARG({ NUMBER, ANY }), rets = TUPLE({ ANY }) }),
            a_type({ typename = "function", args = VARARG({ STRING, ANY }), rets = TUPLE({ INTEGER }) }),
         },
      }),
      ["setmetatable"] = a_gfunction(1, function(a) return { args = TUPLE({ a, METATABLE(a) }), rets = TUPLE({ a }) } end),
      ["tonumber"] = a_type({
         typename = "poly",
         types = {
            a_type({ typename = "function", args = TUPLE({ ANY }), rets = TUPLE({ NUMBER }) }),
            a_type({ typename = "function", args = TUPLE({ ANY, NUMBER }), rets = TUPLE({ INTEGER }) }),
         },
      }),
      ["tostring"] = a_type({ typename = "function", args = TUPLE({ ANY }), rets = TUPLE({ STRING }) }),
      ["type"] = a_type({ typename = "function", args = TUPLE({ ANY }), rets = TUPLE({ STRING }) }),
      ["FILE"] = a_type({
         typename = "typetype",
         def = a_type({
            typename = "record",
            is_userdata = true,
            fields = {
               ["close"] = a_type({ typename = "function", args = TUPLE({ NOMINAL_FILE }), rets = TUPLE({ BOOLEAN, STRING }) }),
               ["flush"] = a_type({ typename = "function", args = TUPLE({ NOMINAL_FILE }), rets = TUPLE({}) }),
               ["lines"] = a_type({ typename = "function", args = VARARG({ NOMINAL_FILE, a_type({ typename = "union", types = { STRING, NUMBER } }) }), rets = TUPLE({
                  a_type({ typename = "function", args = TUPLE({}), rets = VARARG({ STRING }) }),
               }), }),
               ["read"] = a_type({ typename = "function", args = TUPLE({ NOMINAL_FILE, UNION({ STRING, NUMBER }) }), rets = TUPLE({ STRING, STRING }) }),
               ["seek"] = a_type({ typename = "function", args = TUPLE({ NOMINAL_FILE, OPT(STRING), OPT(NUMBER) }), rets = TUPLE({ INTEGER, STRING }) }),
               ["setvbuf"] = a_type({ typename = "function", args = TUPLE({ NOMINAL_FILE, STRING, OPT(NUMBER) }), rets = TUPLE({}) }),
               ["write"] = a_type({ typename = "function", args = VARARG({ NOMINAL_FILE, STRING }), rets = TUPLE({ NOMINAL_FILE, STRING }) }),
            },
            meta_fields = { ["__close"] = FUNCTION },
            meta_field_order = { "__close" },
         }),
      }),
      ["metatable"] = a_type({
         typename = "typetype",
         def = a_grecord(1, function(a) return {
            fields = {
               ["__call"] = a_type({ typename = "function", args = VARARG({ a, ANY }), rets = VARARG({ ANY }) }),
               ["__gc"] = a_type({ typename = "function", args = TUPLE({ a }), rets = TUPLE({}) }),
               ["__index"] = ANY,
               ["__len"] = a_type({ typename = "function", args = TUPLE({ a }), rets = TUPLE({ ANY }) }),
               ["__mode"] = a_type({ typename = "enum", enumset = { ["k"] = true, ["v"] = true, ["kv"] = true } }),
               ["__newindex"] = ANY,
               ["__pairs"] = a_gfunction(2, function(k, v)
                  return {
                     args = TUPLE({ a }),
                     rets = TUPLE({ a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ k, v }) }) }),
                  }
               end),
               ["__tostring"] = a_type({ typename = "function", args = TUPLE({ a }), rets = TUPLE({ STRING }) }),
               ["__name"] = STRING,
               ["__add"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }),
               ["__sub"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }),
               ["__mul"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }),
               ["__div"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }),
               ["__idiv"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }),
               ["__mod"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }),
               ["__pow"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }),
               ["__unm"] = a_type({ typename = "function", args = TUPLE({ ANY }), rets = TUPLE({ ANY }) }),
               ["__band"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }),
               ["__bor"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }),
               ["__bxor"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }),
               ["__bnot"] = a_type({ typename = "function", args = TUPLE({ ANY }), rets = TUPLE({ ANY }) }),
               ["__shl"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }),
               ["__shr"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }),
               ["__concat"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }),
               ["__eq"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ BOOLEAN }) }),
               ["__lt"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ BOOLEAN }) }),
               ["__le"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ BOOLEAN }) }),
               ["__close"] = a_type({ typename = "function", args = TUPLE({ a }), rets = TUPLE({}) }),
            },
         } end),
      }),
      ["coroutine"] = a_type({
         typename = "record",
         fields = {
            ["create"] = a_type({ typename = "function", args = TUPLE({ FUNCTION }), rets = TUPLE({ THREAD }) }),
            ["close"] = a_type({ typename = "function", args = TUPLE({ THREAD }), rets = TUPLE({ BOOLEAN, STRING }) }),
            ["isyieldable"] = a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ BOOLEAN }) }),
            ["resume"] = a_type({ typename = "function", args = VARARG({ THREAD, ANY }), rets = VARARG({ BOOLEAN, ANY }) }),
            ["running"] = a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ THREAD, BOOLEAN }) }),
            ["status"] = a_type({ typename = "function", args = TUPLE({ THREAD }), rets = TUPLE({ STRING }) }),
            ["wrap"] = a_type({ typename = "function", args = TUPLE({ FUNCTION }), rets = TUPLE({ FUNCTION }) }),
            ["yield"] = a_type({ typename = "function", args = VARARG({ ANY }), rets = VARARG({ ANY }) }),
         },
      }),
      ["debug"] = a_type({
         typename = "record",
         fields = {
            ["Info"] = a_type({
               typename = "typetype",
               def = DEBUG_GETINFO_TABLE,
            }),
            ["Hook"] = a_type({
               typename = "typetype",
               def = DEBUG_HOOK_FUNCTION,
            }),
            ["HookEvent"] = a_type({
               typename = "typetype",
               def = DEBUG_HOOK_EVENT,
            }),
            ["debug"] = a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({}) }),
            ["gethook"] = a_type({ typename = "function", args = TUPLE({ OPT(THREAD) }), rets = TUPLE({ DEBUG_HOOK_FUNCTION, INTEGER }) }),
            ["getlocal"] = a_type({
               typename = "poly",
               types = {
                  a_type({ typename = "function", args = TUPLE({ THREAD, FUNCTION, NUMBER }), rets = TUPLE({}) }),
                  a_type({ typename = "function", args = TUPLE({ FUNCTION, NUMBER }), rets = TUPLE({}) }),
               },
            }),
            ["getmetatable"] = a_gfunction(1, function(a) return { args = TUPLE({ a }), rets = TUPLE({ METATABLE(a) }) } end),
            ["getregistry"] = a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ TABLE }) }),
            ["getupvalue"] = a_type({ typename = "function", args = TUPLE({ FUNCTION, NUMBER }), rets = TUPLE({ ANY }) }),
            ["getuservalue"] = a_type({ typename = "function", args = TUPLE({ USERDATA, NUMBER }), rets = TUPLE({ ANY }) }),
            ["sethook"] = a_type({
               typename = "poly",
               types = {
                  a_type({ typename = "function", args = TUPLE({ THREAD, DEBUG_HOOK_FUNCTION, STRING, NUMBER }), rets = TUPLE({}) }),
                  a_type({ typename = "function", args = TUPLE({ DEBUG_HOOK_FUNCTION, STRING, NUMBER }), rets = TUPLE({}) }),
               },
            }),
            ["setlocal"] = a_type({
               typename = "poly",
               types = {
                  a_type({ typename = "function", args = TUPLE({ THREAD, NUMBER, NUMBER, ANY }), rets = TUPLE({ STRING }) }),
                  a_type({ typename = "function", args = TUPLE({ NUMBER, NUMBER, ANY }), rets = TUPLE({ STRING }) }),
               },
            }),
            ["setmetatable"] = a_gfunction(1, function(a) return { args = TUPLE({ a, METATABLE(a) }), rets = TUPLE({ a }) } end),
            ["setupvalue"] = a_type({ typename = "function", args = TUPLE({ FUNCTION, NUMBER, ANY }), rets = TUPLE({ STRING }) }),
            ["setuservalue"] = a_type({ typename = "function", args = TUPLE({ USERDATA, ANY, NUMBER }), rets = TUPLE({ USERDATA }) }),
            ["traceback"] = a_type({
               typename = "poly",
               types = {
                  a_type({ typename = "function", args = TUPLE({ THREAD, STRING, NUMBER }), rets = TUPLE({ STRING }) }),
                  a_type({ typename = "function", args = TUPLE({ STRING, NUMBER }), rets = TUPLE({ STRING }) }),
               },
            }),
            ["upvalueid"] = a_type({ typename = "function", args = TUPLE({ FUNCTION, NUMBER }), rets = TUPLE({ USERDATA }) }),
            ["upvaluejoin"] = a_type({ typename = "function", args = TUPLE({ FUNCTION, NUMBER, FUNCTION, NUMBER }), rets = TUPLE({}) }),
            ["getinfo"] = a_type({
               typename = "poly",
               types = {
                  a_type({ typename = "function", args = TUPLE({ ANY }), rets = TUPLE({ DEBUG_GETINFO_TABLE }) }),
                  a_type({ typename = "function", args = TUPLE({ ANY, STRING }), rets = TUPLE({ DEBUG_GETINFO_TABLE }) }),
                  a_type({ typename = "function", args = TUPLE({ ANY, ANY, STRING }), rets = TUPLE({ DEBUG_GETINFO_TABLE }) }),
               },
            }),
         },
      }),
      ["io"] = a_type({
         typename = "record",
         fields = {
            ["close"] = a_type({ typename = "function", args = TUPLE({ OPT(NOMINAL_FILE) }), rets = TUPLE({ BOOLEAN, STRING }) }),
            ["flush"] = a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({}) }),
            ["input"] = a_type({ typename = "function", args = TUPLE({ OPT(UNION({ STRING, NOMINAL_FILE })) }), rets = TUPLE({ NOMINAL_FILE }) }),
            ["lines"] = a_type({ typename = "function", args = VARARG({ OPT(STRING), a_type({ typename = "union", types = { STRING, NUMBER } }) }), rets = TUPLE({
               a_type({ typename = "function", args = TUPLE({}), rets = VARARG({ STRING }) }),
            }), }),
            ["open"] = a_type({ typename = "function", args = TUPLE({ STRING, STRING }), rets = TUPLE({ NOMINAL_FILE, STRING }) }),
            ["output"] = a_type({ typename = "function", args = TUPLE({ OPT(UNION({ STRING, NOMINAL_FILE })) }), rets = TUPLE({ NOMINAL_FILE }) }),
            ["popen"] = a_type({ typename = "function", args = TUPLE({ STRING, STRING }), rets = TUPLE({ NOMINAL_FILE, STRING }) }),
            ["read"] = a_type({ typename = "function", args = TUPLE({ UNION({ STRING, NUMBER }) }), rets = TUPLE({ STRING, STRING }) }),
            ["stderr"] = NOMINAL_FILE,
            ["stdin"] = NOMINAL_FILE,
            ["stdout"] = NOMINAL_FILE,
            ["tmpfile"] = a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ NOMINAL_FILE }) }),
            ["type"] = a_type({ typename = "function", args = TUPLE({ ANY }), rets = TUPLE({ STRING }) }),
            ["write"] = a_type({ typename = "function", args = VARARG({ STRING }), rets = TUPLE({ NOMINAL_FILE, STRING }) }),
         },
      }),
      ["math"] = a_type({
         typename = "record",
         fields = {
            ["abs"] = a_type({
               typename = "poly",
               types = {
                  a_type({ typename = "function", args = TUPLE({ INTEGER }), rets = TUPLE({ INTEGER }) }),
                  a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }),
               },
            }),
            ["acos"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }),
            ["asin"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }),
            ["atan"] = a_type({ typename = "function", args = TUPLE({ NUMBER, OPT(NUMBER) }), rets = TUPLE({ NUMBER }) }),
            ["atan2"] = a_type({ typename = "function", args = TUPLE({ NUMBER, NUMBER }), rets = TUPLE({ NUMBER }) }),
            ["ceil"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ INTEGER }) }),
            ["cos"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }),
            ["cosh"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }),
            ["deg"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }),
            ["exp"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }),
            ["floor"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ INTEGER }) }),
            ["fmod"] = a_type({
               typename = "poly",
               types = {
                  a_type({ typename = "function", args = TUPLE({ INTEGER, INTEGER }), rets = TUPLE({ INTEGER }) }),
                  a_type({ typename = "function", args = TUPLE({ NUMBER, NUMBER }), rets = TUPLE({ NUMBER }) }),
               },
            }),
            ["frexp"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER, NUMBER }) }),
            ["huge"] = NUMBER,
            ["ldexp"] = a_type({ typename = "function", args = TUPLE({ NUMBER, NUMBER }), rets = TUPLE({ NUMBER }) }),
            ["log"] = a_type({ typename = "function", args = TUPLE({ NUMBER, NUMBER }), rets = TUPLE({ NUMBER }) }),
            ["log10"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }),
            ["max"] = a_type({
               typename = "poly",
               types = {
                  a_type({ typename = "function", args = VARARG({ INTEGER }), rets = TUPLE({ INTEGER }) }),
                  a_gfunction(1, function(a) return { args = VARARG({ a }), rets = TUPLE({ a }) } end),
                  a_type({ typename = "function", args = VARARG({ a_type({ typename = "union", types = { NUMBER, INTEGER } }) }), rets = TUPLE({ NUMBER }) }),
                  a_type({ typename = "function", args = VARARG({ ANY }), rets = TUPLE({ ANY }) }),
               },
            }),
            ["maxinteger"] = a_type({ typename = "integer", needs_compat = true }),
            ["min"] = a_type({
               typename = "poly",
               types = {
                  a_type({ typename = "function", args = VARARG({ INTEGER }), rets = TUPLE({ INTEGER }) }),
                  a_gfunction(1, function(a) return { args = VARARG({ a }), rets = TUPLE({ a }) } end),
                  a_type({ typename = "function", args = VARARG({ a_type({ typename = "union", types = { NUMBER, INTEGER } }) }), rets = TUPLE({ NUMBER }) }),
                  a_type({ typename = "function", args = VARARG({ ANY }), rets = TUPLE({ ANY }) }),
               },
            }),
            ["mininteger"] = a_type({ typename = "integer", needs_compat = true }),
            ["modf"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ INTEGER, NUMBER }) }),
            ["pi"] = NUMBER,
            ["pow"] = a_type({ typename = "function", args = TUPLE({ NUMBER, NUMBER }), rets = TUPLE({ NUMBER }) }),
            ["rad"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }),
            ["random"] = a_type({
               typename = "poly",
               types = {
                  a_type({ typename = "function", args = TUPLE({ NUMBER, NUMBER }), rets = TUPLE({ INTEGER }) }),
                  a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ NUMBER }) }),
               },
            }),
            ["randomseed"] = a_type({ typename = "function", args = TUPLE({ NUMBER, NUMBER }), rets = TUPLE({ INTEGER, INTEGER }) }),
            ["sin"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }),
            ["sinh"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }),
            ["sqrt"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }),
            ["tan"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }),
            ["tanh"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }),
            ["tointeger"] = a_type({ typename = "function", args = TUPLE({ ANY }), rets = TUPLE({ INTEGER }) }),
            ["type"] = a_type({ typename = "function", args = TUPLE({ ANY }), rets = TUPLE({ STRING }) }),
            ["ult"] = a_type({ typename = "function", args = TUPLE({ NUMBER, NUMBER }), rets = TUPLE({ BOOLEAN }) }),
         },
      }),
      ["os"] = a_type({
         typename = "record",
         fields = {
            ["clock"] = a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ NUMBER }) }),
            ["date"] = a_type({
               typename = "poly",
               types = {
                  a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ STRING }) }),
                  a_type({ typename = "function", args = TUPLE({ OS_DATE_TABLE_FORMAT, NUMBER }), rets = TUPLE({ OS_DATE_TABLE }) }),
                  a_type({ typename = "function", args = TUPLE({ OPT(STRING), OPT(NUMBER) }), rets = TUPLE({ STRING }) }),
               },
            }),
            ["difftime"] = a_type({ typename = "function", args = TUPLE({ NUMBER, NUMBER }), rets = TUPLE({ NUMBER }) }),
            ["execute"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ BOOLEAN, STRING, INTEGER }) }),
            ["exit"] = a_type({ typename = "function", args = TUPLE({ UNION({ NUMBER, BOOLEAN }), BOOLEAN }), rets = TUPLE({}) }),
            ["getenv"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ STRING }) }),
            ["remove"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ BOOLEAN, STRING }) }),
            ["rename"] = a_type({ typename = "function", args = TUPLE({ STRING, STRING }), rets = TUPLE({ BOOLEAN, STRING }) }),
            ["setlocale"] = a_type({ typename = "function", args = TUPLE({ STRING, OPT(STRING) }), rets = TUPLE({ STRING }) }),
            ["time"] = a_type({ typename = "function", args = TUPLE({ OPT(OS_DATE_TABLE) }), rets = TUPLE({ INTEGER }) }),
            ["tmpname"] = a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ STRING }) }),
         },
      }),
      ["package"] = a_type({
         typename = "record",
         fields = {
            ["config"] = STRING,
            ["cpath"] = STRING,
            ["loaded"] = a_type({
               typename = "map",
               keys = STRING,
               values = ANY,
            }),
            ["loaders"] = a_type({
               typename = "array",
               elements = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ ANY }) }),
            }),
            ["loadlib"] = a_type({ typename = "function", args = TUPLE({ STRING, STRING }), rets = TUPLE({ FUNCTION }) }),
            ["path"] = STRING,
            ["preload"] = TABLE,
            ["searchers"] = a_type({
               typename = "array",
               elements = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ ANY }) }),
            }),
            ["searchpath"] = a_type({ typename = "function", args = TUPLE({ STRING, STRING, OPT(STRING), OPT(STRING) }), rets = TUPLE({ STRING, STRING }) }),
         },
      }),
      ["string"] = a_type({
         typename = "record",
         fields = {
            ["byte"] = a_type({
               typename = "poly",
               types = {
                  a_type({ typename = "function", args = TUPLE({ STRING, OPT(NUMBER) }), rets = TUPLE({ INTEGER }) }),
                  a_type({ typename = "function", args = TUPLE({ STRING, NUMBER, NUMBER }), rets = VARARG({ INTEGER }) }),
               },
            }),
            ["char"] = a_type({ typename = "function", args = VARARG({ NUMBER }), rets = TUPLE({ STRING }) }),
            ["dump"] = a_type({ typename = "function", args = TUPLE({ FUNCTION, OPT(BOOLEAN) }), rets = TUPLE({ STRING }) }),
            ["find"] = a_type({ typename = "function", args = TUPLE({ STRING, STRING, OPT(NUMBER), OPT(BOOLEAN) }), rets = VARARG({ INTEGER, INTEGER, STRING }) }),
            ["format"] = a_type({ typename = "function", args = VARARG({ STRING, ANY }), rets = TUPLE({ STRING }) }),
            ["gmatch"] = a_type({ typename = "function", args = TUPLE({ STRING, STRING }), rets = TUPLE({
               a_type({ typename = "function", args = TUPLE({}), rets = VARARG({ STRING }) }),
            }), }),
            ["gsub"] = a_type({
               typename = "poly",
               types = {
                  a_type({ typename = "function", args = TUPLE({ STRING, STRING, STRING, NUMBER }), rets = TUPLE({ STRING, INTEGER }) }),
                  a_type({ typename = "function", args = TUPLE({ STRING, STRING, a_type({ typename = "map", keys = STRING, values = STRING }), NUMBER }), rets = TUPLE({ STRING, INTEGER }) }),
                  a_type({ typename = "function", args = TUPLE({ STRING, STRING, a_type({ typename = "function", args = VARARG({ STRING }), rets = TUPLE({ STRING }) }) }), rets = TUPLE({ STRING, INTEGER }) }),
                  a_type({ typename = "function", args = TUPLE({ STRING, STRING, a_type({ typename = "function", args = VARARG({ STRING }), rets = TUPLE({ NUMBER }) }) }), rets = TUPLE({ STRING, INTEGER }) }),
                  a_type({ typename = "function", args = TUPLE({ STRING, STRING, a_type({ typename = "function", args = VARARG({ STRING }), rets = TUPLE({ BOOLEAN }) }) }), rets = TUPLE({ STRING, INTEGER }) }),
                  a_type({ typename = "function", args = TUPLE({ STRING, STRING, a_type({ typename = "function", args = VARARG({ STRING }), rets = TUPLE({}) }) }), rets = TUPLE({ STRING, INTEGER }) }),
               },
            }),
            ["len"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ INTEGER }) }),
            ["lower"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ STRING }) }),
            ["match"] = a_type({ typename = "function", args = TUPLE({ STRING, STRING, NUMBER }), rets = VARARG({ STRING }) }),
            ["pack"] = a_type({ typename = "function", args = VARARG({ STRING, ANY }), rets = TUPLE({ STRING }) }),
            ["packsize"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ INTEGER }) }),
            ["rep"] = a_type({ typename = "function", args = TUPLE({ STRING, NUMBER, OPT(STRING) }), rets = TUPLE({ STRING }) }),
            ["reverse"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ STRING }) }),
            ["sub"] = a_type({ typename = "function", args = TUPLE({ STRING, NUMBER, NUMBER }), rets = TUPLE({ STRING }) }),
            ["unpack"] = a_type({ typename = "function", args = TUPLE({ STRING, STRING, OPT(NUMBER) }), rets = VARARG({ ANY }) }),
            ["upper"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ STRING }) }),
         },
      }),
      ["table"] = a_type({
         typename = "record",
         fields = {
            ["concat"] = a_type({ typename = "function", args = TUPLE({ ARRAY(UNION({ STRING, NUMBER })), OPT(STRING), OPT(NUMBER), OPT(NUMBER) }), rets = TUPLE({ STRING }) }),
            ["insert"] = a_type({
               typename = "poly",
               types = {
                  a_gfunction(1, function(a) return { args = TUPLE({ ARRAY(a), NUMBER, a }), rets = TUPLE({}) } end),
                  a_gfunction(1, function(a) return { args = TUPLE({ ARRAY(a), a }), rets = TUPLE({}) } end),
               },
            }),
            ["move"] = a_type({
               typename = "poly",
               types = {
                  a_gfunction(1, function(a) return { args = TUPLE({ ARRAY(a), NUMBER, NUMBER, NUMBER }), rets = TUPLE({ ARRAY(a) }) } end),
                  a_gfunction(1, function(a) return { args = TUPLE({ ARRAY(a), NUMBER, NUMBER, NUMBER, ARRAY(a) }), rets = TUPLE({ ARRAY(a) }) } end),
               },
            }),
            ["pack"] = a_type({ typename = "function", args = VARARG({ ANY }), rets = TUPLE({ TABLE }) }),
            ["remove"] = a_gfunction(1, function(a) return { args = TUPLE({ ARRAY(a), OPT(NUMBER) }), rets = TUPLE({ a }) } end),
            ["sort"] = a_gfunction(1, function(a) return { args = TUPLE({ ARRAY(a), OPT(TABLE_SORT_FUNCTION) }), rets = TUPLE({}) } end),
            ["unpack"] = a_gfunction(1, function(a) return { needs_compat = true, args = TUPLE({ ARRAY(a), NUMBER, NUMBER }), rets = VARARG({ a }) } end),
         },
      }),
      ["utf8"] = a_type({
         typename = "record",
         fields = {
            ["char"] = a_type({ typename = "function", args = VARARG({ NUMBER }), rets = TUPLE({ STRING }) }),
            ["charpattern"] = STRING,
            ["codepoint"] = a_type({ typename = "function", args = TUPLE({ STRING, OPT(NUMBER), OPT(NUMBER) }), rets = VARARG({ INTEGER }) }),
            -- `for p, c in utf8.codes(s)` yields the byte position `p` and
            -- the code point `c` of each character; both are integers.
            ["codes"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({
               a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ INTEGER, INTEGER }) }),
            }), }),
            ["len"] = a_type({ typename = "function", args = TUPLE({ STRING, NUMBER, NUMBER }), rets = TUPLE({ INTEGER }) }),
            ["offset"] = a_type({ typename = "function", args = TUPLE({ STRING, NUMBER, NUMBER }), rets = TUPLE({ INTEGER }) }),
         },
      }),
      ["_VERSION"] = STRING,
   }

   -- Record fields (including the definitions behind typetypes) get a
   -- deterministic, sorted field order.
   for _, t in pairs(standard_library) do
      fill_field_order(t)
      if is_typetype(t) then
         fill_field_order(t.def)
      end
   end
   fill_field_order(OS_DATE_TABLE)
   fill_field_order(DEBUG_GETINFO_TABLE)

   -- Resolve the nominal references created above to their definitions.
   NOMINAL_FILE.found = standard_library["FILE"]
   for _, m in ipairs(metatable_nominals) do
      m.found = standard_library["metatable"]
   end

   for name, typ in pairs(standard_library) do
      globals[name] = { t = typ, needs_compat = stdlib_compat[name], attribute = "const" }
   end
   globals["@is_va"] = { t = ANY }

   if not is_first_init then
      last_typeid = save_typeid
   end
   return globals, standard_library
end
-- Create a fresh type-checking environment.
-- gen_compat: true/nil -> "optional", false -> "off"; any other value is kept.
-- gen_target: defaults to "5.1" when running on Lua 5.1/5.2, else "5.3".
-- Targeting "5.4" requires gen_compat to be explicitly "off".
-- Every standard-library record is registered as a module, then each name
-- in `predefined` is required into the environment.
-- Returns the environment, or nil plus an error message.
tl.init_env = function(lax, gen_compat, gen_target, predefined)
if gen_compat == false then
gen_compat = "off"
elseif gen_compat == nil or gen_compat == true then
gen_compat = "optional"
end
if not gen_target then
local legacy_host = _VERSION == "Lua 5.1" or _VERSION == "Lua 5.2"
gen_target = legacy_host and "5.1" or "5.3"
end
if gen_compat ~= "off" and gen_target == "5.4" then
return nil, "gen-compat must be explicitly 'off' when gen-target is '5.4'"
end
local globals, standard_library = init_globals(lax)
local env = {
ok = true,
modules = {},
loaded = {},
loaded_order = {},
globals = globals,
gen_compat = gen_compat,
gen_target = gen_target,
}
for lib_name, lib_type in pairs(standard_library) do
if lib_type.typename == "record" then
env.modules[lib_name] = lib_type
end
end
for _, module_name in ipairs(predefined or {}) do
if require_module(module_name, lax, env) == INVALID then
return nil, string.format("Error: could not predefine module '%s'", module_name)
end
end
return env
end
tl.type_check = function(ast, opts)
opts = opts or {}
local env = opts.env
if not env then
local err
env, err = tl.init_env(opts.lax, opts.gen_compat, opts.gen_target)
if err then
return nil, err
end
end
if opts.module_name then
env.modules[opts.module_name] = a_type({ typename = "typetype", def = CIRCULAR_REQUIRE })
end
local lax = opts.lax
local filename = opts.filename
local st = { env.globals }
local symbol_list = {}
local symbol_list_n = 0
local all_needs_compat = {}
local dependencies = {}
local warnings = {}
local errors = {}
local module_type
-- Look up variable `name`, searching the scope stack `st` from the innermost
-- scope outwards. Returns the variable record, the index of the scope it was
-- found in, and its attribute; returns nothing if the name is not found.
-- `use` selects the side effects on the variable found:
--   "lvalue" on a narrowed variable: returns a synthetic record carrying the
--     pre-narrowing type (`narrowed_from`) and marks the variable used; a
--     narrowed variable without `narrowed_from` is skipped and the search
--     continues in outer scopes.
--   "use_type": marks the variable as used as a type.
--   "check_only": marks nothing.
--   anything else (including "lvalue" on a non-narrowed variable): marks the
--     variable as used.
-- Globals (scope index 1) flagged `needs_compat` are recorded in
-- `all_needs_compat` unless the lvalue/narrowed path is taken.
local function find_var(name, use)
for i = #st, 1, -1 do
local scope = st[i]
local var = scope[name]
if var then
if use == "lvalue" and var.is_narrowed then
-- assignments target the declared (wider) type, not the narrowed one
if var.narrowed_from then
var.used = true
return { t = var.narrowed_from, attribute = var.attribute }, i, var.attribute
end
else
if i == 1 and var.needs_compat then
all_needs_compat[name] = true
end
if use == "use_type" then
var.used_as_type = true
elseif use ~= "check_only" then
var.used = true
end
return var, i, var.attribute
end
end
end
end
-- Build a record type whose fields are the variables of the global scope
-- (st[1]), skipping internal "@"-prefixed entries. Fields are listed in
-- sorted order. A second, nil, return value is kept for callers.
local function simulate_g()
local fields = {}
for var_name, var in pairs(st[1]) do
local is_internal = var_name:sub(1, 1) == "@"
if not is_internal then
fields[var_name] = var.t
end
end
local g = {
typeid = globals_typeid,
typename = "record",
field_order = sorted_keys(fields),
fields = fields,
}
return g, nil
end
-- Forward declaration; assigned further below.
local resolve_typevars
-- New typevar with `t`'s base name (any "@N" suffix removed), tagged with the
-- current fresh_typevar_ctr value.
local function fresh_typevar(t)
local base = t.typevar:gsub("@.*", "")
return a_type({
typename = "typevar",
typevar = base .. "@" .. fresh_typevar_ctr,
})
end
-- Same as fresh_typevar, for typearg declarations.
local function fresh_typearg(t)
local base = t.typearg:gsub("@.*", "")
return a_type({
typename = "typearg",
typearg = base .. "@" .. fresh_typevar_ctr,
})
end
-- For a generic type, return a copy whose type arguments/variables are
-- renamed with a new counter value so separate uses do not share bindings.
-- Non-generic types are returned unchanged.
local function ensure_fresh_typeargs(t)
if t.typeargs then
fresh_typevar_ctr = fresh_typevar_ctr + 1
local ok, fresh = resolve_typevars(t, fresh_typevar, fresh_typearg)
assert(ok, "Internal Compiler Error: error creating fresh type variables")
return fresh
end
return t
end
-- Return the type (with fresh type arguments) and attribute of variable
-- `name`. Returns nil for unresolved type arguments, nothing if not found.
local function find_var_type(name, use)
local var = find_var(name, use)
if not var then
return
end
local t = var.t
if t.typename == "unresolved_typearg" then
return nil
end
return ensure_fresh_typeargs(t), var.attribute
end
-- Build an error record { y, x, msg, filename } located at `where`.
-- When extra arguments are given, each one is rendered with show_type and
-- substituted into the `msg` format string. A nil argument is rendered as
-- "nil" so that every format slot keeps its matching argument.
-- Returns nil instead of an error if any argument has typename "invalid".
local function error_in_type(where, msg, ...)
local n = select("#", ...)
if n > 0 then
local showt = {}
for i = 1, n do
local t = select(i, ...)
if t then
if t.typename == "invalid" then
return nil
end
showt[i] = show_type(t)
else
-- Keep the list dense: a hole would make the unpack below stop early
-- (the length of a table with holes is unspecified) and misalign or
-- drop the remaining format arguments.
showt[i] = "nil"
end
end
msg = msg:format(_tl_table_unpack(showt, 1, n))
end
return {
y = where.y,
x = where.x,
msg = msg,
filename = where.filename or filename,
}
end
-- Record an error located at `t` in the checker's `errors` list.
-- Returns true if an error was recorded, false if error_in_type produced
-- none (it returns nil when an argument is the invalid type).
local function type_error(t, msg, ...)
local e = error_in_type(t, msg, ...)
if not e then
return false
end
table.insert(errors, e)
return true
end
-- Resolve a dotted type name (`names` = { "A", "B", ... }) to its type-type,
-- walking nested record fields. Returns nil if any segment is missing.
-- The result must be a type-type, or a typearg when `accept_typearg` is set.
local function find_type(names, accept_typearg)
local typ = find_var_type(names[1], "use_type")
if not typ then
return nil
end
typ = typ.found or typ
for i = 2, #names do
local fields = typ.fields or (typ.def and typ.def.fields)
if not fields then
return nil
end
local member = fields[names[i]]
if member == nil then
return nil
end
member = ensure_fresh_typeargs(member)
typ = member.found or member
end
if is_typetype(typ) then
return typ
end
if accept_typearg and typ.typename == "typearg" then
return typ
end
end
-- Classify a type by its runtime representation as a union member:
-- "table", "userdata", or the type's own name (e.g. "function", "string").
-- Type-types, tuples (first element) and nominals are unwrapped first; a
-- nominal whose name cannot be found classifies as "invalid".
local function union_type(t)
local cur = t
while true do
if is_typetype(cur) then
cur = cur.def
elseif cur.typename == "tuple" then
cur = cur[1]
elseif cur.typename == "nominal" then
local typetype = cur.found or find_type(cur.names)
if not typetype then
return "invalid"
end
cur = typetype
else
break
end
end
local name = cur.typename
if name == "record" then
if cur.is_userdata then
return "userdata"
end
return "table"
end
if table_types[name] then
return "table"
end
return name
end
-- Check whether union `typ` can be discriminated at runtime. Each member is
-- classified with union_type(), and the union is rejected if it has more than
-- one userdata type, more than one table type, more than one function type,
-- or more than one string/enum member. Only the first primitive `string`
-- member counts towards the string/enum limit; later `string` members are
-- not counted.
-- Returns true on success. On failure returns false plus a format string
-- with a `%s` slot for the union. If a member classifies as "invalid",
-- returns false, nil (no message).
local function is_valid_union(typ)
local n_table_types = 0
local n_function_types = 0
local n_userdata_types = 0
local n_string_enum = 0
local has_primitive_string_type = false
for _, t in ipairs(typ.types) do
local ut = union_type(t)
if ut == "userdata" then
n_userdata_types = n_userdata_types + 1
if n_userdata_types > 1 then
return false, "cannot discriminate a union between multiple userdata types: %s"
end
elseif ut == "table" then
n_table_types = n_table_types + 1
if n_table_types > 1 then
return false, "cannot discriminate a union between multiple table types: %s"
end
elseif ut == "function" then
n_function_types = n_function_types + 1
if n_function_types > 1 then
return false, "cannot discriminate a union between multiple function types: %s"
end
-- a repeated primitive `string` member is not counted again
elseif ut == "enum" or (ut == "string" and not has_primitive_string_type) then
n_string_enum = n_string_enum + 1
if n_string_enum > 1 then
return false, "cannot discriminate a union between multiple string/enum types: %s"
end
if ut == "string" then
has_primitive_string_type = true
end
elseif ut == "invalid" then
return false, nil
end
end
return true
end
-- Validate that union `u` can be discriminated at runtime (see
-- is_valid_union). Any error is appended to `errs` (created if needed) when
-- `store_errs` is set, otherwise to the checker's `errors` list.
-- Returns `u`, or INVALID if the union is not valid. The second return value
-- is the error list when `store_errs` is set, otherwise the falsy
-- `store_errs` itself.
local function validate_union(where, u, store_errs, errs)
local valid, err = is_valid_union(u)
if err then
errs = store_errs and (errs or {}) or errors
table.insert(errs, error_in_type(where, err, u))
end
local result = u
if not valid then
result = INVALID
end
return result, store_errs and errs
end
-- Unwrap a type-type to the type it defines; any other type is returned as-is.
local function resolve_typetype(t)
if not is_typetype(t) then
return t
end
return t.def
end
-- Type names that have no nested types; resolve_typevars returns them
-- unchanged rather than copying them.
local no_nested_types = {}
for _, leaf_name in ipairs({ "string", "number", "integer", "boolean", "thread", "any", "enum", "nil", "unknown" }) do
no_nested_types[leaf_name] = true
end
-- Default typevar resolver: look the type variable up in scope. A bound
-- string type (possibly a constant) is widened to the plain STRING type.
local function default_resolve_typevars_callback(t)
local rt = find_var_type(t.typevar)
if not rt then
return nil
end
if rt.typename == "string" then
return STRING
end
return rt
end
-- Return a copy of `typ` with its type variables resolved.
-- fn_var(t) is called for each "typevar" node and returns the replacement
-- type, or nil to keep the variable. It defaults to looking the variable up
-- in scope (default_resolve_typevars_callback).
-- fn_arg(t), if given, replaces each "typearg" node.
-- Nodes are memoized in `seen`, so shared and self-referencing structures
-- are copied once. Leaf types (no_nested_types) and nominals without type
-- values are returned as-is.
-- After resolution, type arguments of the top-level type whose variable was
-- resolved are removed from the copy.
-- Returns true plus the copy, or false, INVALID and the collected errors if
-- a union in the result fails validate_union().
resolve_typevars = function(typ, fn_var, fn_arg)
local errs
local seen = {}
local resolved = {}
fn_var = fn_var or default_resolve_typevars_callback
-- resolve(t, all_same) returns the copy of `t` plus a flag that is true
-- only if neither `t` nor anything visited so far (`all_same`) changed.
local function resolve(t, all_same)
local same = true
if no_nested_types[t.typename] or (t.typename == "nominal" and not t.typevals) then
return t, all_same
end
if seen[t] then
return seen[t], all_same
end
local orig_t = t
if t.typename == "typevar" then
local rt = fn_var(t)
if rt then
resolved[orig_t.typevar] = true
if no_nested_types[rt.typename] or (rt.typename == "nominal" and not rt.typevals) then
seen[orig_t] = rt
return rt, false
end
same = false
t = rt
end
end
local copy = {}
-- register before recursing so self-references map to this copy
seen[orig_t] = copy
copy.is_userdata = t.is_userdata
copy.typename = t.typename
copy.filename = t.filename
copy.x = t.x
copy.y = t.y
copy.yend = t.yend
copy.xend = t.xend
copy.names = t.names
for i, tf in ipairs(t) do
copy[i], same = resolve(tf, same)
end
if t.typename == "array" then
copy.elements, same = resolve(t.elements, same)
elseif t.typename == "typearg" then
if fn_arg then
copy = fn_arg(t)
else
copy.typearg = t.typearg
end
elseif t.typename == "unresolvable_typearg" then
copy.typearg = t.typearg
elseif t.typename == "typevar" then
copy.typevar = t.typevar
elseif is_typetype(t) then
copy.def, same = resolve(t.def, same)
elseif t.typename == "nominal" then
copy.typevals, same = resolve(t.typevals, same)
copy.found = t.found
elseif t.typename == "function" then
if t.typeargs then
copy.typeargs = {}
for i, tf in ipairs(t.typeargs) do
copy.typeargs[i], same = resolve(tf, same)
end
end
copy.is_method = t.is_method
copy.args, same = resolve(t.args, same)
copy.rets, same = resolve(t.rets, same)
elseif t.typename == "record" or t.typename == "arrayrecord" then
if t.typeargs then
copy.typeargs = {}
for i, tf in ipairs(t.typeargs) do
copy.typeargs[i], same = resolve(tf, same)
end
end
if t.elements then
copy.elements, same = resolve(t.elements, same)
end
copy.fields = {}
copy.field_order = {}
for i, k in ipairs(t.field_order) do
copy.field_order[i] = k
copy.fields[k], same = resolve(t.fields[k], same)
end
if t.meta_fields then
copy.meta_fields = {}
copy.meta_field_order = {}
for i, k in ipairs(t.meta_field_order) do
copy.meta_field_order[i] = k
copy.meta_fields[k], same = resolve(t.meta_fields[k], same)
end
end
elseif t.typename == "map" then
copy.keys, same = resolve(t.keys, same)
copy.values, same = resolve(t.values, same)
elseif t.typename == "union" then
copy.types = {}
for i, tf in ipairs(t.types) do
copy.types[i], same = resolve(tf, same)
end
-- substituted members may make the union non-discriminable
copy, errs = validate_union(t, copy, true, errs)
elseif t.typename == "poly" or t.typename == "tupletable" then
copy.types = {}
for i, tf in ipairs(t.types) do
copy.types[i], same = resolve(tf, same)
end
elseif t.typename == "tuple" then
copy.is_va = t.is_va
end
-- an unchanged copy keeps the original identity; a changed one gets a new id
copy.typeid = same and orig_t.typeid or new_typeid()
return copy, same and all_same
end
local copy, same = resolve(typ, true)
if errs then
return false, INVALID, errs
end
if copy.typeargs and not same then
-- drop type arguments that were bound during resolution
for i = #copy.typeargs, 1, -1 do
if resolved[copy.typeargs[i].typearg] then
table.remove(copy.typeargs, i)
end
end
if not copy.typeargs[1] then
copy.typeargs = nil
end
end
return true, copy
end
-- Once the type of an empty table `{}` is known, replace the variable it was
-- assigned to (`emptytable.assigned_to`) with a new entry of type `fresh_t`.
-- For a global declaration only the global scope (st[1]) is searched;
-- otherwise all scopes are searched, from the innermost outwards.
-- NOTE(review): every scope that has a variable with that name is
-- overwritten, not just the innermost one, and the new entry keeps none of
-- the old variable's other fields (attribute, declared_at, used, ...).
-- Confirm this is intended when an outer variable with the same name is
-- shadowed.
local function infer_emptytable(emptytable, fresh_t)
local is_global = (emptytable.declared_at and emptytable.declared_at.kind == "global_declaration")
local nst = is_global and 1 or #st
for i = nst, 1, -1 do
local scope = st[i]
if scope[emptytable.assigned_to] then
scope[emptytable.assigned_to] = { t = fresh_t }
end
end
end
-- Reduce a tuple to its first element; an empty tuple becomes NIL.
-- Non-tuple types are returned unchanged.
local function resolve_tuple(t)
local first = t
if first.typename == "tuple" then
first = first[1]
end
return first == nil and NIL or first
end
-- Append a warning of kind `tag` at `node`'s position; `fmt` is formatted
-- with the remaining arguments.
local function node_warning(tag, node, fmt, ...)
local w = {
y = node.y,
x = node.x,
msg = fmt:format(...),
filename = filename,
tag = tag,
}
table.insert(warnings, w)
end
-- Report an error at `node`, mark the node's type INVALID and return it.
local function node_error(node, msg, ...)
type_error(node, msg, ...)
node.type = INVALID
return node.type
end
-- Wrap one error_in_type result in a list (empty if none was produced).
local function terr(t, s, ...)
local err = error_in_type(t, s, ...)
return { err }
end
-- Warn about a variable whose type is unknown.
local function add_unknown(node, name)
local fmt = "unknown variable: %s"
node_warning("unknown", node, fmt, name)
end
-- Warn that `node` redeclares a name that is already visible in scope.
-- For local/record function declarations the reported name is node.name.tk;
-- otherwise it is node.tk. Names starting with "_" are exempt, matching the
-- rule used by unused_warning. `old_var`, when it has `declared_at`,
-- supplies the position of the original declaration.
local function redeclaration_warning(node, old_var)
local var_kind = "variable"
local var_name = node.tk
if node.kind == "local_function" or node.kind == "record_function" then
var_kind = "function"
var_name = node.name.tk
end
-- Check the declared name rather than node.tk: for function declarations
-- node.tk is presumably the leading keyword token, not the function name,
-- so checking it would never honour the "_" opt-out.
if var_name:sub(1, 1) == "_" then return end
local short_error = "redeclaration of " .. var_kind .. " '%s'"
if old_var and old_var.declared_at then
node_warning("redeclaration", node, short_error .. " (originally declared at %d:%d)", var_name, old_var.declared_at.y, old_var.declared_at.x)
else
node_warning("redeclaration", node, short_error, var_name)
end
end
-- Emit a redeclaration warning at `at` if `new_name` is already visible in
-- any scope. The lookup does not mark the existing variable as used.
local function check_if_redeclaration(new_name, at)
local previous = find_var(new_name, "check_only")
if not previous then
return
end
redeclaration_warning(at, previous)
end
-- Warn that label or variable `name` was never used.
-- Skipped for entries without a declaration site, "is"-narrowings, and names
-- starting with "_" or "@". Labels are recognised by their "::" prefix.
local function unused_warning(name, var)
local prefix = name:sub(1, 1)
if not var.declared_at or var.is_narrowed == "is" or prefix == "_" or prefix == "@" then
return
end
if name:sub(1, 2) == "::" then
node_warning("unused", var.declared_at, "unused label %s", name)
return
end
local what
if var.is_func_arg then
what = "argument"
elseif var.t.typename == "function" then
what = "function"
elseif is_typetype(var.t) then
what = "type"
else
what = "variable"
end
node_warning("unused", var.declared_at, "unused %s %s: %s", what, name, show_type(var.t))
end
-- Move the errors in `src` into `dst`, prefixing each message with `prefix`.
-- `src` may be nil, in which case nothing happens. The error tables from
-- `src` are modified in place (msg, and possibly y/x/filename).
-- When `where` has a line, an error is re-anchored to `where` and to the
-- current file if it comes from another file, has no line, or lies before
-- `where` in the source.
local function add_errs_prefixing(where, src, dst, prefix)
if not src then
return
end
for _, err in ipairs(src) do
err.msg = prefix .. err.msg
if where and where.y and (
(err.filename ~= filename) or
(not err.y) or
(where.y > err.y or (where.y == err.y and where.x > err.x))) then
err.y = where.y
err.x = where.x
err.filename = filename
end
table.insert(dst, err)
end
end
-- Resolve type variables in `t` using the current scope; on failure the
-- errors are reported at `where` and INVALID is returned.
local function resolve_typevars_at(where, t)
assert(where)
local ok, typ, errs = resolve_typevars(t)
if ok then
return typ
end
assert(where.y)
add_errs_prefixing(where, errs, errors, "")
return typ
end
-- Like resolve_typevars_at, but always returns a distinct type object,
-- annotated with the location and file where it was inferred.
local function infer_at(where, t)
local ret = resolve_typevars_at(where, t)
if not ret or ret == t then
ret = shallow_copy_type(t)
end
ret.inferred_at = where
ret.inferred_at_file = filename
return ret
end
-- Return `t` without its constant value (`tk`), copying only when needed.
local function drop_constant_value(t)
if t.tk then
local copy = shallow_copy_type(t)
copy.tk = nil
return copy
end
return t
end
-- Reserve the next symbol-list slot for `node` so add_var can fill it later.
local function reserve_symbol_list_slot(node)
local slot = symbol_list_n + 1
symbol_list_n = slot
node.symbol_list_slot = slot
end
-- Forward declaration; assigned further below.
local get_unresolved
-- Add or update `name` in the innermost scope and return its variable record.
-- With `narrow` set (flow-typing narrowing): an existing entry in this scope
-- gets type `t`, remembering its previous type in `narrowed_from` the first
-- time it is narrowed; otherwise a new narrowed entry is created. The name is
-- then recorded in the scope's "@unresolved" narrows set.
-- Without `narrow`: a new variable replaces any entry in this scope. First a
-- redeclaration warning may be issued (unless suppressed, or the name is
-- "self", "..." or "@"-prefixed), and an unused warning is issued for a
-- replaced entry that was never used.
local function add_to_scope(node, name, t, attribute, narrow, dont_check_redeclaration)
local scope = st[#st]
local var = scope[name]
if narrow then
if var then
if var.is_narrowed then
-- already narrowed: keep the original narrowed_from
var.t = t
return var
end
var.is_narrowed = narrow
var.narrowed_from = var.t
var.t = t
else
var = { t = t, attribute = attribute, is_narrowed = narrow, declared_at = node }
scope[name] = var
end
local unresolved = get_unresolved(scope)
unresolved.narrows[name] = true
return var
end
if not dont_check_redeclaration and
node and
name ~= "self" and
name ~= "..." and
name:sub(1, 1) ~= "@" then
check_if_redeclaration(name, node)
end
if var and not var.used then
unused_warning(name, var)
end
var = { t = t, attribute = attribute, is_narrowed = nil, declared_at = node }
scope[name] = var
return var
end
-- Declare `name` with type `t` in the innermost scope (see add_to_scope) and
-- return the variable record.
-- In lax mode, warns about variables of unknown type (except "self", "..."
-- and narrowings). Without an attribute, any constant value (`tk`) is
-- removed from `t`.
-- When `node` is given and the type is neither "unresolved" nor "none", the
-- node's type is set if not already present, and a symbol-list entry is
-- written. The entry goes into the slot reserved by reserve_symbol_list_slot
-- if there is one, otherwise into a newly appended slot.
local function add_var(node, name, t, attribute, narrow, dont_check_redeclaration)
if lax and node and is_unknown(t) and (name ~= "self" and name ~= "...") and not narrow then
add_unknown(node, name)
end
if not attribute then
t = drop_constant_value(t)
end
local var = add_to_scope(node, name, t, attribute, narrow, dont_check_redeclaration)
if node and t.typename ~= "unresolved" and t.typename ~= "none" then
node.type = node.type or t
local slot
if node.symbol_list_slot then
slot = node.symbol_list_slot
else
symbol_list_n = symbol_list_n + 1
slot = symbol_list_n
end
symbol_list[slot] = { y = node.y, x = node.x, name = name, typ = t }
end
return var
end
-- Compare two types where at least one is a type variable, using `comp`
-- (same_type or is_a) for the actual comparison.
-- Type variables with the same name match. If the type variable (t2's if
-- both are type variables) is already bound in scope, its bound type is
-- compared with the other side, keeping the original argument order.
-- Otherwise the other side is resolved and, unless it is "unknown", bound to
-- the type variable in the current scope (type-types unwrapped to their
-- definition). In that case the comparison succeeds.
-- Returns true, or false plus errors.
local function compare_and_infer_typevars(t1, t2, comp)
if t1.typevar == t2.typevar then
return true
end
local typevar = t2.typevar or t1.typevar
local vt = find_var_type(typevar)
if vt then
if t2.typevar then
return comp(t1, vt)
else
return comp(vt, t2)
end
else
-- unbound: infer the type variable from the other side
local other = t2.typevar and t1 or t2
local ok, resolved, errs = resolve_typevars(other)
if not ok then
return false, errs
end
if resolved.typename ~= "unknown" then
resolved = resolve_typetype(resolved)
add_var(nil, typevar, resolved)
end
return true
end
end
-- Forward declarations; assigned further below.
local same_type
local is_a
-- Check every field of `rec1` against the type returned by `t2(k)` for its
-- name: with same_type when `invariant` is set, otherwise with is_a.
-- A field for which `t2` returns nil is an error only in invariant,
-- non-lax mode. Returns true, or false plus the field errors.
local function match_record_fields(rec1, t2, invariant)
local fielderrs = {}
local compare = invariant and same_type or is_a
for _, k in ipairs(rec1.field_order) do
local f = rec1.fields[k]
local t2k = t2(k)
if t2k ~= nil then
local ok, errs = compare(f, t2k)
if not ok then
add_errs_prefixing(nil, errs, fielderrs, "record field doesn't match: " .. k .. ": ")
end
elseif invariant and not lax then
table.insert(fielderrs, error_in_type(f, "unknown field " .. k))
end
end
if #fielderrs == 0 then
return true
end
return false, fielderrs
end
-- Check that record `rec1` is compatible with record `rec2` field by field
-- (see match_record_fields). Records that differ in userdata-ness never
-- match. Returns true, or false plus errors naming both types.
local function match_fields_to_record(rec1, rec2, invariant)
if rec1.is_userdata ~= rec2.is_userdata then
local err = error_in_type(rec1, "userdata record doesn't match: %s", rec2)
return false, { err }
end
local lookup = function(k) return rec2.fields[k] end
local ok, fielderrs = match_record_fields(rec1, lookup, invariant)
if ok then
return true
end
local errs = {}
add_errs_prefixing(nil, fielderrs, errs, show_type(rec1) .. " is not a " .. show_type(rec2) .. ": ")
return false, errs
end
-- Check that every field of record `rec1` fits the value type of `map`.
local function match_fields_to_map(rec1, map)
local ok = match_record_fields(rec1, function(_) return map.values end)
if ok then
return true
end
return false, { error_in_type(rec1, "record is not a valid map; not all fields have the same type") }
end
-- Compare `a` with `b` using `cmp`. On mismatch, move the errors into `errs`
-- prefixed with `ctx` and the optional position `n` (e.g. "argument 2: ")
-- and return false; otherwise return true.
local function arg_check(where, cmp, a, b, n, errs, ctx)
local matches, match_errs = cmp(a, b)
if matches then
return true
end
local label = ctx
if n then
label = label .. " " .. n
end
add_errs_prefixing(where, match_errs, errs, label .. ": ")
return false
end
-- True if every type in `t1s` matches (cmp(t2, t1)) some type in `t2s`.
local function has_all_types_of(t1s, t2s, cmp)
for _, t1 in ipairs(t1s) do
local j = 1
while t2s[j] ~= nil and not cmp(t2s[j], t1) do
j = j + 1
end
if t2s[j] == nil then
return false
end
end
return true
end
-- Turn an error list into the (ok, errs) convention: true when empty,
-- otherwise false plus the list.
local function any_errors(all_errs)
if #all_errs > 0 then
return false, all_errs
end
return true
end
-- Mark every type-type field of record `t` as closed, recursing into
-- nested record definitions.
local function close_nested_records(t)
for _, ft in pairs(t.fields) do
if is_typetype(ft) then
ft.closed = true
local def = ft.def
if is_record_type(def) then
close_nested_records(def)
end
end
end
end
-- Mark every type declared in scope `vars` (and the types nested in its
-- records) as closed.
local function close_types(vars)
for _, var in pairs(vars) do
local vt = var.t
if is_typetype(vt) then
vt.closed = true
local def = vt.def
if is_record_type(def) then
close_nested_records(def)
end
end
end
end
-- Emit "unused" warnings for the variables of scope `vars`.
-- A variable that was never used but was used as a type is not reported;
-- its declaration is flagged `elide_type` instead. For a used type-type that
-- aliases another variable (`aliasing`), the aliased variable is marked used
-- and its declaration is un-flagged. Warnings are emitted in source order
-- (line, then column).
local function check_for_unused_vars(vars)
if not next(vars) then
return
end
local list = {}
for name, var in pairs(vars) do
if var.declared_at and not var.used then
if var.used_as_type then
var.declared_at.elide_type = true
else
table.insert(list, { y = var.declared_at.y, x = var.declared_at.x, name = name, var = var })
end
elseif var.used and is_typetype(var.t) and var.aliasing then
var.aliasing.used = true
var.aliasing.declared_at.elide_type = false
end
end
if list[1] then
-- pairs() order is unspecified, so sort for deterministic output
table.sort(list, function(a, b)
return a.y < b.y or (a.y == b.y and a.x < b.x)
end)
for _, u in ipairs(list) do
unused_warning(u.name, u.var)
end
end
end
-- Return the "@unresolved" record (pending labels, nominals, global types
-- and narrows) stored in `scope`, or in the nearest enclosing scope when
-- `scope` is nil. If none is found, a new empty record is created and
-- declared in the innermost scope via add_var.
get_unresolved = function(scope)
local unresolved
if not scope then
unresolved = find_var_type("@unresolved")
else
local entry = scope["@unresolved"]
unresolved = entry and entry.t
end
if unresolved then
return unresolved
end
unresolved = {
typename = "unresolved",
labels = {},
nominals = {},
global_types = {},
narrows = {},
}
add_var(nil, "@unresolved", unresolved)
return unresolved
end
-- Push a new, empty scope. With a node, also write a "@{" scope-start
-- marker to the symbol list at the node's position.
local function begin_scope(node)
st[#st + 1] = {}
if not node then
return
end
symbol_list_n = symbol_list_n + 1
symbol_list[symbol_list_n] = { y = node.y, x = node.x, name = "@{" }
end
-- Close the innermost scope.
-- If the scope has an "@unresolved" record, its pending labels, nominals and
-- global types are merged into the enclosing scope's record (narrows are not
-- merged). If the enclosing scope has no record, this one is moved there with
-- its narrows reset. The scope's types are then closed, unused-variable
-- warnings are emitted, and the scope is popped.
-- With a node: if nothing was added to the symbol list since begin_scope
-- wrote its "@{" marker, that marker is removed; otherwise a "@}" marker is
-- appended at the node's end position.
local function end_scope(node)
local scope = st[#st]
local unresolved = scope["@unresolved"]
if unresolved then
local next_scope = st[#st - 1]
local upper = next_scope["@unresolved"]
if upper then
for name, nodes in pairs(unresolved.t.labels) do
for _, n in ipairs(nodes) do
upper.t.labels[name] = upper.t.labels[name] or {}
table.insert(upper.t.labels[name], n)
end
end
for name, types in pairs(unresolved.t.nominals) do
for _, typ in ipairs(types) do
upper.t.nominals[name] = upper.t.nominals[name] or {}
table.insert(upper.t.nominals[name], typ)
end
end
for name, _ in pairs(unresolved.t.global_types) do
upper.t.global_types[name] = true
end
else
next_scope["@unresolved"] = unresolved
unresolved.t.narrows = {}
end
end
close_types(scope)
check_for_unused_vars(scope)
table.remove(st)
if node then
-- drop an empty "@{ ... @}" pair instead of recording it
if symbol_list[symbol_list_n].name == "@{" then
symbol_list[symbol_list_n] = nil
symbol_list_n = symbol_list_n - 1
else
symbol_list_n = symbol_list_n + 1
symbol_list[symbol_list_n] = { y = assert(node.yend), x = assert(node.xend), name = "@}" }
end
end
end
-- Visitor helper: close the node's scope and give the node the NONE type,
-- which is also returned.
local function end_scope_and_none_type(node, _children)
end_scope(node)
node.type = NONE
return node.type
end
-- Forward declaration: resolve_nominal is assigned inside the do-block
-- below, which keeps its helper match_typevals local to that block.
local resolve_nominal
do
-- Instantiate definition `def` with the type values of nominal `t`: in a
-- temporary scope each type argument is bound to its value, and the type
-- variables in `def` are resolved. Reports an error and returns nil on an
-- argument-count mismatch, on type values given to a non-generic
-- definition, or on a generic definition used without type values.
local function match_typevals(t, def)
if t.typevals and def.typeargs then
if #t.typevals ~= #def.typeargs then
type_error(t, "mismatch in number of type arguments")
return nil
end
begin_scope()
for i, tt in ipairs(t.typevals) do
add_var(nil, def.typeargs[i].typearg, tt)
end
local ret = resolve_typevars_at(t, def)
end_scope()
return ret
elseif t.typevals then
type_error(t, "spurious type arguments")
return nil
elseif def.typeargs then
type_error(t, "missing type arguments in %s", def)
return nil
else
return def
end
end
-- Resolve nominal type `t` to its (instantiated) definition, caching the
-- result in t.resolved and the type-type in t.found.
-- An alias, or a definition that is itself a nominal, is followed one
-- level via `.found`. A circular-require placeholder is returned without
-- caching. An unknown name reports an error and returns INVALID (not
-- cached). Other failures report an error and cache a "bad_nominal"
-- placeholder. If `t` has no filename, its filename is taken from the
-- result, and so is its position when both x and y are missing.
resolve_nominal = function(t)
if t.resolved then
return t.resolved
end
local resolved
local typetype = t.found or find_type(t.names)
if not typetype then
type_error(t, "unknown type %s", t)
return INVALID
elseif is_typetype(typetype) then
if typetype.is_alias then
typetype = typetype.def.found
assert(is_typetype(typetype))
end
if typetype.def.typename == "circular_require" then
return typetype.def
end
if typetype.def.typename == "nominal" then
typetype = typetype.def.found
assert(is_typetype(typetype))
end
assert(typetype.def.typename ~= "nominal")
resolved = match_typevals(t, typetype.def)
else
type_error(t, table.concat(t.names, ".") .. " is not a type")
end
if not resolved then
resolved = a_type({ typename = "bad_nominal", names = t.names })
end
if not t.filename then
t.filename = resolved.filename
if t.x == nil and t.y == nil then
t.x = resolved.x
t.y = resolved.y
end
end
t.found = typetype
t.resolved = resolved
return resolved
end
end
-- True if `t1` and `t2` are the same single-segment name and that name is
-- registered as a global type that is not yet resolved.
local function are_same_unresolved_global_type(t1, t2)
if #t1.names ~= 1 or #t2.names ~= 1 then
return false
end
local name = t1.names[1]
if name ~= t2.names[1] then
return false
end
return get_unresolved().global_types[name] and true or false
end
-- Check whether nominal types `t1` and `t2` refer to the same definition
-- (same typeid of the type-type) and, when generic, have pairwise identical
-- type values.
-- If either name cannot be found, they still count as the same when both
-- are the same not-yet-declared global type. Otherwise an "unknown type"
-- error is reported for each missing side and false, {} is returned.
-- For different definitions, the error message adds the definition location
-- when both types display with the same name.
-- NOTE(review): when the definitions match but only one side has type
-- values, or the counts differ, no branch returns and the function yields
-- nil with no errors. Confirm callers treat this as intended.
local function are_same_nominals(t1, t2)
local same_names
if t1.found and t2.found then
same_names = t1.found.typeid == t2.found.typeid
else
local ft1 = t1.found or find_type(t1.names)
local ft2 = t2.found or find_type(t2.names)
if ft1 and ft2 then
same_names = ft1.typeid == ft2.typeid
else
if are_same_unresolved_global_type(t1, t2) then
return true
end
if not ft1 then
type_error(t1, "unknown type %s", t1)
end
if not ft2 then
type_error(t2, "unknown type %s", t2)
end
return false, {}
end
end
if same_names then
if t1.typevals == nil and t2.typevals == nil then
return true
elseif t1.typevals and t2.typevals and #t1.typevals == #t2.typevals then
local all_errs = {}
for i = 1, #t1.typevals do
local _, errs = same_type(t1.typevals[i], t2.typevals[i])
add_errs_prefixing(t1, errs, all_errs, "type parameter <" .. show_type(t2.typevals[i]) .. ">: ")
end
if #all_errs == 0 then
return true
else
return false, all_errs
end
end
else
local t1name = show_type(t1)
local t2name = show_type(t2)
-- identical display names: disambiguate by definition location
if t1name == t2name then
local t1r = resolve_nominal(t1)
if t1r.filename then
t1name = t1name .. " (defined in " .. t1r.filename .. ":" .. t1r.y .. ")"
end
local t2r = resolve_nominal(t2)
if t2r.filename then
t2name = t2name .. " (defined in " .. t2r.filename .. ":" .. t2r.y .. ")"
end
end
return false, terr(t1, t1name .. " is not a " .. t2name)
end
end
-- Forward declarations; assigned further below.
local is_lua_table_type
local resolve_tuple_and_nominal = nil
-- Invariant record comparison: the fields of each record must match the
-- other's exactly, checked in both directions (t1 against t2 first).
local function invariant_match_fields_to_record(t1, t2)
local ok, errs = match_fields_to_record(t1, t2, true)
if ok then
ok, errs = match_fields_to_record(t2, t1, true)
if ok then
return true
end
end
return ok, errs
end
-- Invariant type equality: true if `t1` and `t2` are the same type, else
-- false plus a list of errors.
-- Type variables are compared (and possibly inferred) via
-- compare_and_infer_typevars. An empty table matches any Lua table type.
-- Compound types are compared structurally: array elements, tuple-table
-- entries (same arity required), map keys and values, union members (in
-- both directions), nominal identity, record fields (both directions), and
-- function argument/return counts and types.
-- Nested errors are prefixed with their position ("keys: ", "values: ",
-- "tuple entry N: ", "argument N: ", "return N: ").
same_type = function(t1, t2)
assert(type(t1) == "table")
assert(type(t2) == "table")
if t1.typeid == t2.typeid then
if TL_DEBUG then
local st1, st2 = show_type_base(t1, false, {}), show_type_base(t2, false, {})
assert(st1 == st2, st1 .. " ~= " .. st2)
end
return true
end
if t1.typename == "typevar" or t2.typename == "typevar" then
return compare_and_infer_typevars(t1, t2, same_type)
end
if t1.typename == "emptytable" and is_lua_table_type(resolve_tuple_and_nominal(t2)) then
return true
end
if t2.typename == "emptytable" and is_lua_table_type(resolve_tuple_and_nominal(t1)) then
return true
end
if t1.typename ~= t2.typename then
return false, terr(t1, "got %s, expected %s", t1, t2)
end
if t1.typename == "array" then
return same_type(t1.elements, t2.elements)
elseif t1.typename == "tupletable" then
-- tuples of different arity can never be the same (invariant) type
if #t1.types ~= #t2.types then
return false, terr(t1, "tuples have different number of elements: got " .. #t1.types .. ", expected " .. #t2.types)
end
local all_errs = {}
for i = 1, #t1.types do
arg_check(t1, same_type, t1.types[i], t2.types[i], i, all_errs, "tuple entry")
end
return any_errors(all_errs)
elseif t1.typename == "map" then
local all_errs = {}
arg_check(t1, same_type, t1.keys, t2.keys, nil, all_errs, "keys")
arg_check(t1, same_type, t1.values, t2.values, nil, all_errs, "values")
return any_errors(all_errs)
elseif t1.typename == "union" then
if has_all_types_of(t1.types, t2.types, same_type) and
has_all_types_of(t2.types, t1.types, same_type) then
return true
else
return false, terr(t1, "got %s, expected %s", t1, t2)
end
elseif t1.typename == "nominal" then
return are_same_nominals(t1, t2)
elseif t1.typename == "record" then
return invariant_match_fields_to_record(t1, t2)
elseif t1.typename == "function" then
-- argument positions shown to the user exclude the implicit `self`
local argdelta = t1.is_method and 1 or 0
if #t1.args ~= #t2.args then
if t1.is_method ~= t2.is_method then
return false, terr(t1, "different number of input arguments: method and non-method are not the same type")
end
return false, terr(t1, "different number of input arguments: got " .. #t1.args - argdelta .. ", expected " .. #t2.args - argdelta)
end
if #t1.rets ~= #t2.rets then
return false, terr(t1, "different number of return values: got " .. #t1.rets .. ", expected " .. #t2.rets)
end
local all_errs = {}
for i = 1, #t1.args do
arg_check(t1, same_type, t1.args[i], t2.args[i], i - argdelta, all_errs, "argument")
end
for i = 1, #t1.rets do
arg_check(t1, same_type, t1.rets[i], t2.rets[i], i, all_errs, "return")
end
return any_errors(all_errs)
elseif t1.typename == "arrayrecord" then
local ok, errs = same_type(t1.elements, t2.elements)
if not ok then
return ok, errs
end
return invariant_match_fields_to_record(t1, t2)
end
return true
end
-- Combine a list of types into one type. Nested unions are flattened,
-- tuples are reduced to their first element (resolve_tuple), nil is dropped
-- and duplicates are removed.
-- Primitive types are deduplicated by type name. Constant primitives (with a
-- `tk`) are instead deduplicated by typeid, unless `flatten_constants` is set.
-- Non-primitive types are deduplicated by typeid, using the resolved
-- definition's typeid for nominals.
-- Returns INVALID if the INVALID type was among the inputs, the single
-- remaining type if only one is left, and a union otherwise. A one-element
-- input list is returned unchanged, without any of this processing.
local function unite(types, flatten_constants)
if #types == 1 then
return types[1]
end
local ts = {}
local stack = {}
local types_seen = {}
-- pre-mark nil (by id and by name) so it never enters the result
types_seen[NIL.typeid] = true
types_seen["nil"] = true
local i = 1
-- process the input list, draining members of nested unions from `stack`
-- before moving on to the next input
while types[i] or stack[1] do
local t
if stack[1] then
t = table.remove(stack)
else
t = types[i]
i = i + 1
end
t = resolve_tuple(t)
if t.typename == "union" then
for _, s in ipairs(t.types) do
table.insert(stack, s)
end
else
if primitive[t.typename] and (flatten_constants or not t.tk) then
if not types_seen[t.typename] then
types_seen[t.typename] = true
table.insert(ts, t)
end
else
local typeid = t.typeid
if t.typename == "nominal" then
typeid = resolve_nominal(t).typeid
end
if not types_seen[typeid] then
types_seen[typeid] = true
table.insert(ts, t)
end
end
end
end
if types_seen[INVALID.typeid] then
return INVALID
end
if #ts == 1 then
return ts[1]
else
return a_type({
typename = "union",
types = ts,
})
end
end
-- Move the errors in `ctx_errs` (if any) into `errs`, prefixing each message
-- with `ctx`. The error tables are modified in place.
local function add_map_errors(errs, ctx, ctx_errs)
if not ctx_errs then
return
end
for _, e in ipairs(ctx_errs) do
e.msg = ctx .. e.msg
errs[#errs + 1] = e
end
end
-- Merge map key and value errors into the (ok, errs) convention: true if
-- neither list is given, otherwise false plus the prefixed errors
-- (key errors first).
local function combine_map_errs(key_errs, value_errs)
if key_errs or value_errs then
local errs = {}
add_map_errors(errs, "in map key: ", key_errs)
add_map_errors(errs, "in map value: ", value_errs)
return false, errs
end
return true
end
-- Type names whose values are Lua tables at runtime.
local known_table_types = {
["array"] = true,
["map"] = true,
["record"] = true,
["arrayrecord"] = true,
["tupletable"] = true,
}
-- True if `t` is a table type not marked as userdata; false for a userdata
-- record; nil if the type name is not a table type at all.
is_lua_table_type = function(t)
if not known_table_types[t.typename] then
return nil
end
return not t.is_userdata
end
-- Forward declaration; assigned further below.
local expand_type
-- Convert tuple-table type `tupletype` into an array type.
-- First tries an array whose element type is the union of all entry types
-- (constants flattened). If that union is not discriminable, it instead
-- starts from an array of the first entry's type and widens it with each
-- later entry via expand_type.
-- Returns the array type, or nil plus errors if widening does not produce an
-- array.
local function arraytype_from_tuple(where, tupletype)
local element_type = unite(tupletype.types, true)
local valid = element_type.typename ~= "union" and true or is_valid_union(element_type)
if valid then
return a_type({
elements = element_type,
typename = "array",
})
end
local arr_type = a_type({
elements = tupletype.types[1],
typename = "array",
})
for i = 2, #tupletype.types do
arr_type = expand_type(where, arr_type, a_type({ elements = tupletype.types[i], typename = "array" }))
if not arr_type or not arr_type.elements then
return nil, terr(tupletype, "unable to convert tuple %s to array", tupletype)
end
end
return arr_type
end
is_a = function(t1, t2, for_equality)
assert(type(t1) == "table")
assert(type(t2) == "table")
if lax and (is_unknown(t1) or is_unknown(t2)) then
return true
end
if t1.typeid == t2.typeid then
if TL_DEBUG then
local st1, st2 = show_type_base(t1, false, {}), show_type_base(t2, false, {})
assert(st1 == st2, st1 .. " ~= " .. st2)
end
return true
end
if t1.typename == "bad_nominal" or t2.typename == "bad_nominal" then
return false
end
if t2.typename ~= "tuple" then
t1 = resolve_tuple(t1)
end
if t2.typename == "tuple" and t1.typename ~= "tuple" then
t1 = a_type({
typename = "tuple",
[1] = t1,
})
end
if t1.typename == "typevar" or t2.typename == "typevar" then
return compare_and_infer_typevars(t1, t2, is_a)
end
if t1.typename == "nil" then
return true
end
if t2.typename == "any" then
return true
elseif t1.typename == "union" then
if t2.typename == "union" then
local used = {}
for _, t in ipairs(t1.types) do
local ok = false
begin_scope()
for _, u in ipairs(t2.types) do
if not used[u] then
if is_a(t, u, for_equality) then
used[u] = t
ok = true
break
end
end
end
end_scope()
if not ok then
return false, terr(t1, "got %s, expected %s", t1, t2)
end
end
for u, t in pairs(used) do
is_a(t, u, for_equality)
end
return true
else
for _, t in ipairs(t1.types) do
if not is_a(t, t2, for_equality) then
return false, terr(t1, "got %s, expected %s", t1, t2)
end
end
return true
end
elseif t2.typename == "union" then
for _, t in ipairs(t2.types) do
if is_a(t1, t, for_equality) then
return true
end
end
elseif t2.typename == "poly" then
for _, t in ipairs(t2.types) do
if not is_a(t1, t, for_equality) then
return false, terr(t1, "cannot match against all alternatives of the polymorphic type")
end
end
return true
elseif t1.typename == "poly" then
for _, t in ipairs(t1.types) do
if is_a(t, t2, for_equality) then
return true
end
end
return false, terr(t1, "cannot match against any alternatives of the polymorphic type")
elseif t1.typename == "nominal" and t2.typename == "nominal" then
local same, err = are_same_nominals(t1, t2)
if same then
return true
end
local t1r = resolve_tuple_and_nominal(t1)
local t2r = resolve_tuple_and_nominal(t2)
if is_record_type(t1r) and is_record_type(t2r) then
return same, err
else
return is_a(t1r, t2r, for_equality)
end
elseif t1.typename == "enum" and t2.typename == "string" then
local ok
if for_equality then
ok = t2.tk and t1.enumset[unquote(t2.tk)]
else
ok = true
end
if ok then
return true
else
return false, terr(t1, "enum is incompatible with %s", t2)
end
elseif t1.typename == "integer" and t2.typename == "number" then
return true
elseif t1.typename == "string" and t2.typename == "enum" then
local ok = t1.tk and t2.enumset[unquote(t1.tk)]
if ok then
return true
else
if t1.tk then
return false, terr(t1, "%s is not a member of %s", t1, t2)
else
return false, terr(t1, "string is not a %s", t2)
end
end
elseif t1.typename == "nominal" or t2.typename == "nominal" then
local t1r = resolve_tuple_and_nominal(t1)
local t2r = resolve_tuple_and_nominal(t2)
local ok, errs = is_a(t1r, t2r, for_equality)
if errs and #errs == 1 then
if errs[1].msg:match("^got ") then
errs = terr(t1, "got %s, expected %s", t1, t2)
end
end
return ok, errs
elseif t1.typename == "emptytable" and is_lua_table_type(t2) then
return true
elseif t2.typename == "array" then
if is_array_type(t1) then
if is_a(t1.elements, t2.elements) then
local t1e = resolve_tuple_and_nominal(t1.elements)
local t2e = resolve_tuple_and_nominal(t2.elements)
if t2e.typename == "enum" and t1e.typename == "string" and #t1.types > 1 then
for i = 2, #t1.types do
local t = t1.types[i]
if not is_a(t, t2e) then
return false, terr(t, "%s is not a member of %s", t, t2e)
end
end
end
return true
end
elseif t1.typename == "tupletable" then
if t2.inferred_len and t2.inferred_len > #t1.types then
return false, terr(t1, "incompatible length, expected maximum length of " .. tostring(#t1.types) .. ", got " .. tostring(t2.inferred_len))
end
local t1a, err = arraytype_from_tuple(t1.inferred_at, t1)
if not t1a then
return false, err
end
if not is_a(t1a, t2) then
return false, terr(t2, "got %s (from %s), expected %s", t1a, t1, t2)
end
return true
elseif t1.typename == "map" then
local _, errs_keys, errs_values
_, errs_keys = is_a(t1.keys, INTEGER)
_, errs_values = is_a(t1.values, t2.elements)
return combine_map_errs(errs_keys, errs_values)
end
elseif t2.typename == "record" then
if is_record_type(t1) then
return match_fields_to_record(t1, t2)
elseif is_typetype(t1) and is_record_type(t1.def) then
return is_a(t1.def, t2, for_equality)
end
elseif t2.typename == "arrayrecord" then
if t1.typename == "array" then
return is_a(t1.elements, t2.elements)
elseif t1.typename == "tupletable" then
if t2.inferred_len and t2.inferred_len > #t1.types then
return false, terr(t1, "incompatible length, expected maximum length of " .. tostring(#t1.types) .. ", got " .. tostring(t2.inferred_len))
end
local t1a, err = arraytype_from_tuple(t1.inferred_at, t1)
if not t1a then
return false, err
end
if not is_a(t1a, t2) then
return false, terr(t2, "got %s (from %s), expected %s", t1a, t1, t2)
end
return true
elseif t1.typename == "record" then
return match_fields_to_record(t1, t2)
elseif t1.typename == "arrayrecord" then
if not is_a(t1.elements, t2.elements) then
return false, terr(t1, "array parts have incompatible element types")
end
return match_fields_to_record(t1, t2)
elseif is_typetype(t1) and is_record_type(t1.def) then
return is_a(t1.def, t2, for_equality)
end
elseif t2.typename == "map" then
if t1.typename == "map" then
local _, errs_keys, errs_values
if t2.keys.typename ~= "any" then
_, errs_keys = same_type(t2.keys, t1.keys)
end
if t2.values.typename ~= "any" then
_, errs_values = same_type(t1.values, t2.values)
end
return combine_map_errs(errs_keys, errs_values)
elseif t1.typename == "array" or t1.typename == "tupletable" then
local elements
if t1.typename == "tupletable" then
local arr_type = arraytype_from_tuple(t1.inferred_at, t1)
if not arr_type then
return false, terr(t1, "Unable to convert tuple %s to map", t1)
end
elements = arr_type.elements
else
elements = t1.elements
end
local _, errs_keys, errs_values
_, errs_keys = is_a(INTEGER, t2.keys)
_, errs_values = is_a(elements, t2.values)
return combine_map_errs(errs_keys, errs_values)
elseif is_record_type(t1) then
if not is_a(t2.keys, STRING) then
return false, terr(t1, "can't match a record to a map with non-string keys")
end
if t2.keys.typename == "enum" then
for _, k in ipairs(t1.field_order) do
if not t2.keys.enumset[k] then
return false, terr(t1, "key is not an enum value: " .. k)
end
end
end
return match_fields_to_map(t1, t2)
end
elseif t2.typename == "tupletable" then
if t1.typename == "tupletable" then
for i = 1, math.min(#t1.types, #t2.types) do
if not is_a(t1.types[i], t2.types[i], for_equality) then
return false, terr(t1, "in tuple entry " .. tostring(i) .. ": got %s, expected %s", t1.types[i], t2.types[i])
end
end
if for_equality and #t1.types ~= #t2.types then
return false, terr(t1, "tuples are not the same size")
end
if #t1.types > #t2.types then
return false, terr(t1, "tuple %s is too big for tuple %s", t1, t2)
end
return true
elseif is_array_type(t1) then
if t1.inferred_len and t1.inferred_len > #t2.types then
return false, terr(t1, "incompatible length, expected maximum length of " .. tostring(#t2.types) .. ", got " .. tostring(t1.inferred_len))
end
local len = (t1.inferred_len and t1.inferred_len > 0) and
t1.inferred_len or
#t2.types
for i = 1, len do
if not is_a(t1.elements, t2.types[i], for_equality) then
return false, terr(t1, "tuple entry " .. tostring(i) .. " of type %s does not match type of array elements, which is %s", t2.types[i], t1.elements)
end
end
return true
end
elseif t1.typename == "function" and t2.typename == "function" then
local all_errs = {}
if (not t2.args.is_va) and #t1.args > #t2.args then
table.insert(all_errs, error_in_type(t1, "incompatible number of arguments: got " .. #t1.args .. " %s, expected " .. #t2.args .. " %s", t1.args, t2.args))
else
for i = ((t1.is_method or t2.is_method) and 2 or 1), #t1.args do
arg_check(nil, is_a, t1.args[i], t2.args[i] or ANY, i, all_errs, "argument")
end
end
local diff_by_va = #t2.rets - #t1.rets == 1 and t2.rets.is_va
if #t1.rets < #t2.rets and not diff_by_va then
table.insert(all_errs, error_in_type(t1, "incompatible number of returns: got " .. #t1.rets .. " %s, expected " .. #t2.rets .. " %s", t1.rets, t2.rets))
else
local nrets = #t2.rets
if diff_by_va then
nrets = nrets - 1
end
for i = 1, nrets do
local _, errs = is_a(t1.rets[i], t2.rets[i])
add_errs_prefixing(nil, errs, all_errs, "return " .. i .. ": ")
end
end
if #all_errs == 0 then
return true
else
return false, all_errs
end
elseif lax and ((not for_equality) and t2.typename == "boolean") then
return true
elseif t1.typename == t2.typename then
return true
end
return false, terr(t1, "got %s, expected %s", t1, t2)
end
-- Check that a value of type `t1` may be assigned where `t2` is expected,
-- recording any mismatch against `node` with a "<context>: [<name>: ]" prefix.
-- Targets that are `{}`-declared tables (or a not-yet-resolved value of one)
-- are not checked: their type is inferred from `t1` instead.
-- Returns true when the assignment is acceptable.
local function assert_is_a(node, t1, t2, context, name)
   t1 = resolve_tuple(t1)
   t2 = resolve_tuple(t2)

   if lax and (is_unknown(t1) or is_unknown(t2)) then
      return true
   end
   if t1.typename == "nil" then
      return true
   end

   if t2.typename == "unresolved_emptytable_value" then
      -- Numeric keys make the empty table an array; anything else a map.
      local emptytable = t2.emptytable_type
      local inferred
      if is_number_type(emptytable.keys) then
         inferred = a_type({ typename = "array", elements = t1 })
      else
         inferred = a_type({ typename = "map", keys = emptytable.keys, values = t1 })
      end
      infer_emptytable(emptytable, infer_at(node, inferred))
      return true
   end

   if t2.typename == "emptytable" then
      if is_lua_table_type(t1) then
         infer_emptytable(t2, infer_at(node, t1))
         return true
      end
      if t1.typename == "emptytable" then
         return true
      end
      local prefix = context .. ": " .. (name and (name .. ": ") or "")
      node_error(node, prefix .. "assigning %s to a variable declared with {}", t1)
      return false
   end

   local ok, match_errs = is_a(t1, t2)
   add_errs_prefixing(node, match_errs, errors, context .. ": " .. (name and (name .. ": ") or ""))
   return ok
end
-- Whether a value of type `t` may be declared `<close>`: nil is allowed,
-- invalid types are not, and anything else must (after nominal resolution,
-- skipped for function types) expose a `__close` metamethod.
-- Returns true/false, or nil when the type has no metafields at all.
local function type_is_closable(t)
   if t.typename == "invalid" then
      return false
   end
   if same_type(t, NIL) then
      return true
   end
   local resolved = t
   if resolved.typename ~= "function" then
      resolved = resolve_nominal(resolved)
   end
   local meta = resolved.meta_fields
   return meta and meta["__close"] ~= nil
end
-- Expression-node kinds whose values can never be closable (literals).
local definitely_not_closable_exprs = { ["string"] = true, ["number"] = true, ["integer"] = true, ["boolean"] = true, ["table_literal"] = true }
-- Returns true when expression node `e` is one of the kinds above, nil otherwise.
local function expr_is_definitely_not_closable(e)
   local kind = e.kind
   return definitely_not_closable_exprs[kind]
end
-- Names already reported through add_unknown_dot.
local unknown_dots = {}
-- Report `name` as unknown via add_unknown, but only the first time it is seen.
local function add_unknown_dot(node, name)
   if unknown_dots[name] then
      return
   end
   unknown_dots[name] = true
   add_unknown(node, name)
end
-- Turn a callee type into something callable.
-- In lax mode an unknown callee becomes `function(unknown...): unknown...`.
-- A non-function callee (or the record behind a type) that has a `__call`
-- metamethod is replaced by that metamethod; the callee is then inserted as
-- the first argument in `args` (mutated in place) and the call becomes a
-- method call.
-- Returns the (possibly replaced) callee type and the updated is_method flag.
local function resolve_for_call(func, args, is_method)
   if lax and is_unknown(func) then
      func = a_type({ typename = "function", args = VARARG({ UNKNOWN }), rets = VARARG({ UNKNOWN }) })
   end
   func = resolve_tuple_and_nominal(func)
   if func.typename == "function" or func.typename == "poly" then
      return func, is_method
   end
   if is_typetype(func) and func.def.typename == "record" then
      func = func.def
   end
   local call_mm = func.meta_fields and func.meta_fields["__call"]
   if call_mm then
      table.insert(args, 1, func)
      func = resolve_tuple_and_nominal(call_mm)
      is_method = true
   end
   return func, is_method
end
local type_check_function_call
do
-- For every type argument of `f` that no scope resolves, declare a
-- placeholder: `unknown` in lax mode, otherwise an "unresolvable_typearg".
local function mark_invalid_typeargs(f)
   if not f.typeargs then
      return
   end
   for _, a in ipairs(f.typeargs) do
      local name = a.typearg
      if not find_var_type(name) then
         local placeholder = lax and UNKNOWN or { typename = "unresolvable_typearg", typearg = name }
         add_var(nil, name, placeholder)
      end
   end
end
-- For each `{}`-typed entry of tuple `xs`, infer its type from the matching
-- entry of tuple `ys` (falling back to the last entry when `ys` is varargs)
-- and replace the entry in `xs` with the inferred type. The inference
-- location is `wheres[i + delta]` when available, else `where`.
local function infer_emptytables(where, wheres, xs, ys, delta)
   assert(xs.typename == "tuple")
   assert(ys.typename == "tuple")
   local last_y = #ys
   for i = 1, #xs do
      local x = xs[i]
      local tn = x.typename
      if tn == "emptytable" or tn == "unresolved_emptytable_value" then
         local y = ys[i]
         if not y and ys.is_va then
            y = ys[last_y]
         end
         if y then
            local w = (wheres and wheres[i + delta]) or where
            local inferred_y = infer_at(w, y)
            infer_emptytable(x, inferred_y)
            xs[i] = inferred_y
         end
      end
   end
end
local check_args_rets
do
   -- Check each element of tuple `xs` against the matching element of tuple
   -- `ys`, from index `from` up to the longer of the two lengths.
   -- A missing `xs` entry falls back to its last entry when `xs` is varargs,
   -- and to NIL otherwise; a position with no `ys` entry (and no varargs
   -- `ys` to fall back on) is skipped. `delta` offsets the position used in
   -- error messages and in the `wheres` lookup.
   -- Returns true, or nil plus the error list from the first failing check.
   local function check_func_type_list(where, wheres, xs, ys, from, delta, mode)
      assert(xs.typename == "tuple", xs.typename)
      assert(ys.typename == "tuple", ys.typename)
      local errs = {}
      local n_xs = #xs
      local n_ys = #ys
      for i = from, math.max(n_xs, n_ys) do
         local pos = i + delta
         local x = xs[i] or (xs.is_va and xs[n_xs]) or NIL
         local y = ys[i] or (ys.is_va and ys[n_ys])
         if y then
            local w = wheres and wheres[pos] or where
            if not arg_check(w, is_a, x, y, pos, errs, mode) then
               return nil, errs
            end
         end
      end
      return true
   end

   -- Check a call of function type `f` with argument types `args` and, when
   -- given, the expected return types `rets`.
   -- `argdelta == -1` marks a method call: args[1] is the receiver, checked
   -- once against f.args[1] as "self"; the remaining arguments are then
   -- checked starting from the second position.
   -- On success, empty tables among the arguments are inferred from f.args,
   -- unresolved type arguments are marked, and f's resolved return types are
   -- returned. On failure returns nil plus the argument errors (a return-type
   -- mismatch alone yields an empty error list).
   check_args_rets = function(where, where_args, f, args, rets, argdelta)
      local rets_ok = true
      local args_ok
      local args_errs
      local from = 1
      if argdelta == -1 then
         from = 2
         local errs = {}
         if not arg_check(where, is_a, args[1], f.args[1], nil, errs, "self") then
            return nil, errs
         end
      end
      if rets then
         rets = infer_at(where, rets)
         infer_emptytables(where, nil, rets, f.rets, 0)
         rets_ok = check_func_type_list(where, nil, f.rets, rets, 1, 0, "return")
      end
      -- Start at `from` so a method receiver already checked as "self" is
      -- not checked again as "argument 0".
      args_ok, args_errs = check_func_type_list(where, where_args, args, f.args, from, argdelta, "argument")
      if (not args_ok) or (not rets_ok) then
         return nil, args_errs or {}
      end
      infer_emptytables(where, where_args, args, f.args, argdelta)
      mark_invalid_typeargs(f)
      return resolve_typevars_at(where, f.rets)
   end
end
-- Declare each type argument of `func` as an "unresolved_typearg" in the
-- current scope so that checking a call can bind it.
local function push_typeargs(func)
   local typeargs = func.typeargs
   if not typeargs then
      return
   end
   for _, fnarg in ipairs(typeargs) do
      add_var(nil, fnarg.typearg, { typename = "unresolved_typearg" })
   end
end
-- Remove `func`'s type arguments from the innermost scope (undoing
-- push_typeargs after a failed overload attempt).
local function pop_typeargs(func)
   local typeargs = func.typeargs
   if not typeargs then
      return
   end
   for _, fnarg in ipairs(typeargs) do
      local top = st[#st]
      local name = fnarg.typearg
      if top[name] then
         top[name] = nil
      end
   end
end
-- Report a call that matched no signature and return a best-effort result.
-- If `errs` is given (errors from the first attempted signature), they are
-- appended to `errors`. Otherwise no signature had a compatible arity, and a
-- "wrong number of arguments" error lists the accepted counts, deduplicated
-- and in ascending numeric order.
-- Returns the resolved return types of `func` (of its first alternative, for
-- a poly), so checking can continue past the failed call.
local function fail_call(node, func, nargs, errs)
   if errs then
      for _, err in ipairs(errs) do
         table.insert(errors, err)
      end
   else
      local counts = {}
      if func.typename == "poly" then
         local seen = {}
         for _, f in ipairs(func.types) do
            local n = #f.args
            if not seen[n] then
               seen[n] = true
               table.insert(counts, n)
            end
         end
         -- Sort as numbers: a string sort would list "10" before "2".
         table.sort(counts)
      else
         counts[1] = #func.args
      end
      local expects = {}
      for i, n in ipairs(counts) do
         expects[i] = tostring(n)
      end
      node_error(node, "wrong number of arguments (given " .. nargs .. ", expects " .. table.concat(expects, " or ") .. ")")
   end
   local f = func.typename == "poly" and func.types[1] or func
   mark_invalid_typeargs(f)
   return resolve_typevars_at(node, f.rets)
end
-- Type-check a call of `func` with argument types `args`.
-- Non-function callees are first passed through resolve_for_call (which may
-- substitute a `__call` metamethod and turn the call into a method call).
-- For a `poly` callee the alternatives are tried in three passes:
-- exact arity first, then fewer arguments than declared, then varargs
-- alternatives taking extra arguments. The first alternative whose
-- check_args_rets succeeds wins.
-- Returns the resolved return types and the matched function type. Otherwise
-- returns fail_call's result, which reports either the first attempt's errors
-- or an arity error when no alternative was attempted.
local function check_call(where, where_args, func, args, is_method, argdelta)
assert(type(func) == "table")
assert(type(args) == "table")
if not (func.typename == "function" or func.typename == "poly") then
func, is_method = resolve_for_call(func, args, is_method)
end
-- Method calls always shift argument positions by one (the receiver).
argdelta = is_method and -1 or argdelta or 0
local is_func = func.typename == "function"
local is_poly = func.typename == "poly"
if not (is_func or is_poly) then
return node_error(where, "not a function: %s", func)
end
-- A plain function is tried once; a poly is tried in 3 passes over all its alternatives.
local passes, n = 1, 1
if is_poly then
passes, n = 3, #func.types
end
local given = #args
local tried
local first_errs
for pass = 1, passes do
for i = 1, n do
-- Skip alternatives that already failed in an earlier pass.
if (not tried) or not tried[i] then
local f = is_func and func or func.types[i]
-- A method invoked with '.': allowed (hint only) when the first argument
-- matches the self type; no hint when the receiver expression is a type.
if f.is_method and not is_method then
if args[1] and is_a(args[1], f.args[1]) then
local receiver_is_typetype = where.e1.e1 and where.e1.e1.type and where.e1.e1.type.resolved and where.e1.e1.type.resolved.typename == "typetype"
if not receiver_is_typetype then
node_warning("hint", where, "invoked method as a regular function: consider using ':' instead of '.'")
end
else
return node_error(where, "invoked method as a regular function: use ':' instead of '.'")
end
end
-- Arity filter: a plain function accepts <= declared args (or more when
-- varargs); for a poly, the accepted arity depends on the current pass.
local expected = #f.args
if (is_func and (given <= expected or (f.args.is_va and given > expected))) or
(is_poly and ((pass == 1 and given == expected) or
(pass == 2 and given < expected) or
(pass == 3 and f.args.is_va and given > expected))) then
push_typeargs(f)
local matched, errs = check_args_rets(where, where_args, f, args, where.expected, argdelta)
if matched then
return matched, f
end
-- Keep only the errors of the first failed attempt for reporting.
first_errs = first_errs or errs
-- NOTE(review): infers emptytables in f.rets against f.rets itself;
-- the intent is not evident from this block.
if where.expected then
infer_emptytables(where, where_args, f.rets, f.rets, argdelta)
end
-- Undo this alternative's type-argument bindings before trying the next.
if is_poly then
tried = tried or {}
tried[i] = true
pop_typeargs(f)
end
end
end
end
end
return fail_call(where, func, given, first_errs)
end
-- Check a call inside its own scope (so type-argument bindings made while
-- matching do not leak). A non-tuple `where.expected` is first wrapped into a
-- one-element tuple, in place. The matched function type is stored on `e1`
-- when given. Returns the resolved return types.
type_check_function_call = function(where, where_args, func, args, e1, is_method, argdelta)
   local expected = where.expected
   if expected and expected.typename ~= "tuple" then
      where.expected = a_type({ typename = "tuple", expected })
   end
   begin_scope()
   local ret, f = check_call(where, where_args, func, args, is_method, argdelta)
   ret = resolve_typevars_at(where, ret)
   end_scope()
   if e1 then
      e1.type = f
   end
   return ret
end
end
-- Look up field `key` of the record-like type `tbl` (strings and enums are
-- looked up in the `string` library type). `rec` is the indexed expression
-- node and is used only to name it in error messages.
-- Returns the field type, or nil plus an error-message format (containing
-- "%s" for the type, except for the alias and unknown-type messages). In lax
-- mode, indexing an unknown or `{}` value yields INVALID.
local function match_record_key(tbl, rec, key)
   assert(type(tbl) == "table")
   assert(type(rec) == "table")
   assert(type(key) == "string")
   tbl = resolve_tuple_and_nominal(tbl)
   if tbl.typename == "string" or tbl.typename == "enum" then
      tbl = find_var_type("string")
   end
   if tbl.is_alias then
      return nil, "cannot use a nested type alias as a concrete value"
   end
   tbl = resolve_typetype(tbl)
   local is_var = rec.kind == "variable"
   if is_record_type(tbl) then
      assert(tbl.fields, "record has no fields!?")
      local field = tbl.fields[key]
      if field then
         return field
      end
      if is_var then
         return nil, "invalid key '" .. key .. "' in record '" .. rec.tk .. "' of type %s"
      end
      return nil, "invalid key '" .. key .. "' in type %s"
   end
   if tbl.typename == "emptytable" or is_unknown(tbl) then
      if lax then
         return INVALID
      end
      return nil, "cannot index a value of unknown type"
   end
   if is_var then
      return nil, "cannot index key '" .. key .. "' in " .. tbl.typename .. " '" .. rec.tk .. "' of type %s"
   end
   return nil, "cannot index key '" .. key .. "' in type %s"
end
-- Undo a flow-sensitive narrowing of `var` in `scope`: restore the type it was
-- narrowed from, or drop the entry when it had none, and forget the pending
-- narrowing record. Returns true if `var` was narrowed, false otherwise.
local function widen_in_scope(scope, var)
   local entry = scope[var]
   assert(entry, "no " .. var .. " in scope")
   if not entry.is_narrowed then
      return false
   end
   if entry.narrowed_from then
      entry.t = entry.narrowed_from
      entry.narrowed_from = nil
      entry.is_narrowed = nil
   else
      scope[var] = nil
   end
   get_unresolved(scope).narrows[var] = nil
   return true
end
-- Walk the scope stack from innermost outwards, widening each narrowing of
-- `name`; stop at the first scope that declares `name` without narrowing it.
-- Returns true if anything was widened.
local function widen_back_var(name)
   local widened = false
   local i = #st
   while i >= 1 do
      local scope = st[i]
      if scope[name] then
         if not widen_in_scope(scope, name) then
            break
         end
         widened = true
      end
      i = i - 1
   end
   return widened
end
-- Whether the syntax tree under `root` contains an assignment whose targets
-- include the plain variable `name`. Type nodes never count.
local function assigned_anywhere(name, root)
   -- Callback for "assignment" nodes: does any target name `name`?
   local function assignment_targets_name(node, _children)
      for _, v in ipairs(node.vars) do
         if v.kind == "variable" and v.tk == name then
            return true
         end
      end
      return false
   end
   -- Generic `after` callback: ORs boolean child results (non-boolean child
   -- results are ignored) into `ret`.
   local function fold_children(_node, children, ret)
      local found = ret or false
      for _, c in ipairs(children) do
         if type(c) == "boolean" then
            found = found or c
         end
      end
      return found
   end
   local visit_node = {
      cbs = {
         ["assignment"] = { after = assignment_targets_name },
      },
      after = fold_children,
   }
   local visit_type = {
      after = function()
         return false
      end,
   }
   return recurse_node(root, visit_node, visit_type)
end
-- Widen every pending narrowing in every scope, innermost first. When `node`
-- is given, only variables assigned somewhere inside `node` are widened.
local function widen_all_unions(node)
   for i = #st, 1, -1 do
      local scope = st[i]
      local unresolved = scope["@unresolved"]
      local narrows = unresolved and unresolved.t.narrows
      if narrows then
         for name in pairs(narrows) do
            if (not node) or assigned_anywhere(name, node) then
               widen_in_scope(scope, name)
            end
         end
      end
   end
end
-- Declare (or re-check) global `var` with type `valtype` in the outermost scope.
-- Fails with an error if a local of the same name is in scope. For an existing
-- global it reports reassignment of an attributed global (when `is_assigning`),
-- attribute mismatches and type mismatches, and returns nil. Otherwise it
-- records the global (const when `node` carries an attribute) and returns the
-- new entry.
local function add_global(node, var, valtype, is_assigning)
   if lax and is_unknown(valtype) and (var ~= "self" and var ~= "...") then
      add_unknown(node, var)
   end
   local existing, scope, existing_attr = find_var(var)
   if existing and scope > 1 then
      node_error(node, "cannot define a global when a local with the same name is in scope")
      return nil
   end
   -- `node` may be nil (see the `if node` guard below).
   local is_const = node ~= nil and node.attribute ~= nil
   if existing then
      if is_assigning and existing_attr then
         -- Report the attribute the global was declared with: the assigning
         -- node need not carry one (node.attribute may be nil here).
         node_error(node, "cannot reassign to <" .. existing_attr .. "> global: " .. var)
      end
      if existing_attr and not is_const then
         node_error(node, "global was previously declared as <" .. existing_attr .. ">: " .. var)
      end
      if (not existing_attr) and is_const then
         node_error(node, "global was previously declared as not <" .. node.attribute .. ">: " .. var)
      end
      if valtype and not same_type(existing.t, valtype) then
         node_error(node, "cannot redeclare global with a different type: previous type of " .. var .. " is %s", existing.t)
      end
      return nil
   end
   st[1][var] = { t = valtype, attribute = is_const and "const" or nil }
   if node then
      node.type = node.type or valtype
   end
   return st[1][var]
end
-- Normalize a function's declared return list into a tuple type.
-- In lax mode an empty list becomes `unknown...`; a plain list is wrapped
-- with TUPLE. The result is asserted to carry a typeid.
local function get_rets(rets)
   if lax and #rets == 0 then
      return VARARG({ UNKNOWN })
   end
   local t = rets.typename and rets or TUPLE(rets)
   assert(t.typeid)
   return t
end
-- Declare the hidden per-function variables "@is_va" (ANY if the function is
-- variadic, NIL otherwise) and "@return" (the declared returns, or an empty
-- tuple), and report type arguments never used in the signature.
local function add_internal_function_variables(node)
   local is_va = node.args.type.is_va
   add_var(nil, "@is_va", is_va and ANY or NIL)
   add_var(nil, "@return", node.rets or a_type({ typename = "tuple" }))
   local typeargs = node.typeargs
   if not typeargs then
      return
   end
   for _, t in ipairs(typeargs) do
      local v = find_var(t.typearg, "check_only")
      if (not v) or (not v.used_as_type) then
         type_error(t, "type argument '%s' is not used in function signature", t)
      end
   end
end
-- Pre-declare the function being defined by `node` under its name, built from
-- its argument types and declared returns, so that its body may call it
-- recursively.
local function add_function_definition_for_recursion(node)
   local args = a_type({ typename = "tuple" })
   for _, fnarg in ipairs(node.args) do
      table.insert(args, fnarg.type)
   end
   local fn_type = a_type({
      typename = "function",
      args = args,
      rets = get_rets(node.rets),
   })
   add_var(nil, node.name.tk, fn_type)
end
-- Report everything still unresolved in the innermost scope and clear it:
-- gotos whose label was never seen, and nominal types that are still unknown
-- (unless they were declared as global types).
local function fail_unresolved()
   local top = st[#st]
   local unresolved = top["@unresolved"]
   if not unresolved then
      return
   end
   top["@unresolved"] = nil
   local pending = unresolved.t
   for name, nodes in pairs(pending.labels) do
      for _, node in ipairs(nodes) do
         node_error(node, "no visible label '" .. name .. "' for goto")
      end
   end
   for name, types in pairs(pending.nominals) do
      if not pending.global_types[name] then
         for _, typ in ipairs(types) do
            assert(typ.x)
            assert(typ.y)
            type_error(typ, "unknown type %s", typ)
         end
      end
   end
end
-- Close a function's scope, first reporting anything left unresolved in it.
local function end_function_scope(node)
   fail_unresolved()
   end_scope(node)
end
-- Resolve `t` through resolve_tuple and, if the result is nominal, through
-- resolve_nominal. The result is asserted not to be nominal.
resolve_tuple_and_nominal = function(t)
   local resolved = resolve_tuple(t)
   if resolved.typename == "nominal" then
      resolved = resolve_nominal(resolved)
   end
   assert(resolved.typename ~= "nominal")
   return resolved
end
-- Flatten a list of expression types. Every element but the last is resolved
-- with resolve_tuple_and_nominal; a trailing tuple is spliced in element by
-- element (unresolved); any other trailing type is appended as-is.
local function flatten_list(list)
   local exps = {}
   local n = #list
   if n == 0 then
      return exps
   end
   for i = 1, n - 1 do
      exps[#exps + 1] = resolve_tuple_and_nominal(list[i])
   end
   local last = list[n]
   if last.typename == "tuple" then
      for _, val in ipairs(last) do
         exps[#exps + 1] = val
      end
   else
      exps[#exps + 1] = last
   end
   return exps
end
-- Expand the value types of an assignment into a flat list.
-- All but the last value go through resolve_tuple; a trailing tuple (e.g. a
-- call's returns) is spliced in. If the values are varargs (from `vals`, or
-- from the trailing tuple), the final value type is repeated until `wanted`
-- entries exist. Returns an empty list for nil or empty `vals`.
local function get_assignment_values(vals, wanted)
   local ret = {}
   if vals == nil or #vals == 0 then
      return ret
   end
   local is_va = vals.is_va
   for i = 1, #vals - 1 do
      ret[i] = resolve_tuple(vals[i])
   end
   local last = vals[#vals]
   if last.typename == "tuple" then
      is_va = last.is_va
      for _, v in ipairs(last) do
         table.insert(ret, v)
      end
   else
      table.insert(ret, last)
   end
   -- Pad with the last value TYPE, not the trailing tuple itself (which
   -- would later resolve to that tuple's first element).
   local rlast = ret[#ret]
   if is_va and rlast and #ret < wanted then
      while #ret < wanted do
         table.insert(ret, rlast)
      end
   end
   return ret
end
-- Return the type shared by all fields `field_names` of record `a`.
-- If two fields differ, report `errmsg` (extended with the first offending
-- pair of names) via node_error and return its result.
local function match_all_record_field_names(node, a, field_names, errmsg)
   local common
   for _, k in ipairs(field_names) do
      local ftype = a.fields[k]
      if not common then
         common = ftype
      elseif not same_type(ftype, common) then
         errmsg = errmsg .. string.format(" (types of fields '%s' and '%s' do not match)", field_names[1], k)
         common = nil
         break
      end
   end
   if not common then
      return node_error(node, errmsg)
   end
   return common
end
-- Try to type an operator `op` via a metamethod: on `a` first, then (for
-- binary operators other than "@index") on `b`.
-- Returns the metamethod call's result type and which operand provided the
-- metamethod (1 or 2). In lax mode with an unknown operand it returns
-- UNKNOWN, nil; otherwise nil, nil.
local function check_metamethod(node, op, a, b)
   local method_name, where_args, args
   if a and b then
      method_name = binop_to_metamethod[op]
      where_args = { node.e1, node.e2 }
      args = { typename = "tuple", a, b }
   else
      method_name = unop_to_metamethod[op]
      where_args = { node.e1 }
      args = { typename = "tuple", a }
   end
   local key = method_name or ""
   local meta_on_operator = 1
   local metamethod = a.meta_fields and a.meta_fields[key]
   if not metamethod and b and op ~= "@index" then
      metamethod = b.meta_fields and b.meta_fields[key]
      meta_on_operator = 2
   end
   if metamethod then
      local ret = type_check_function_call(node, where_args, metamethod, args, nil, false, 0)
      return resolve_tuple_and_nominal(ret), meta_on_operator
   end
   if lax and ((a and is_unknown(a)) or (b and is_unknown(b))) then
      return UNKNOWN, nil
   end
   return nil, nil
end
-- Type an indexing expression `a[b]`, where `anode`/`bnode` are the indexed
-- and key expression nodes and `a`/`b` their types.
-- Handles tuples (constant or variable integer keys), arrays, empty tables
-- (whose key type is inferred from the first use), maps, string/enum-literal
-- keys on record-like types, and enum-typed keys on records (all enum values
-- must map to fields of one type). If none of these applies, an `__index`
-- metamethod is tried before the collected error is reported.
-- Returns the element/field type, or node_error's result.
local function type_check_index(anode, bnode, a, b)
local orig_a = a
local orig_b = b
a = resolve_typetype(resolve_tuple_and_nominal(a))
b = resolve_tuple_and_nominal(b)
if lax and is_unknown(a) then
return UNKNOWN
end
-- Error format string and its arguments, reported only if no metamethod applies.
local errm
local erra
local errb
if a.typename == "tupletable" and is_a(b, INTEGER) then
-- Constant integer key: must be a whole number within 1..#a.types.
if bnode.constnum then
if bnode.constnum >= 1 and bnode.constnum <= #a.types and bnode.constnum == math.floor(bnode.constnum) then
return a.types[bnode.constnum]
end
errm, erra = "index " .. tostring(bnode.constnum) .. " out of range for tuple %s", a
else
local array_type = arraytype_from_tuple(bnode, a)
if array_type then
return array_type.elements
end
errm = "cannot index this tuple with a variable because it would produce a union type that cannot be discriminated at runtime"
end
elseif is_array_type(a) and is_a(b, INTEGER) then
return a.elements
elseif a.typename == "emptytable" then
-- First index of a `{}` table fixes its key type (and where it was inferred).
if a.keys == nil then
a.keys = resolve_tuple(orig_b)
a.keys_inferred_at = assert(anode)
a.keys_inferred_at_file = filename
end
if is_a(orig_b, a.keys) then
return a_type({ y = anode.y, x = anode.x, typename = "unresolved_emptytable_value", emptytable_type = a })
end
errm, erra, errb = "inconsistent index type: got %s, expected %s (type of keys inferred at " ..
a.keys_inferred_at_file .. ":" ..
a.keys_inferred_at.y .. ":" ..
a.keys_inferred_at.x .. ": )", orig_b, a.keys
elseif a.typename == "map" then
if is_a(orig_b, a.keys) then
return a.values
end
errm, erra, errb = "wrong index type: got %s, expected %s", orig_b, a.keys
elseif bnode.kind == "string" or bnode.kind == "enum_item" then
-- Literal key: a record-style field lookup on the unresolved type.
local t, e = match_record_key(orig_a, anode, bnode.conststr)
if t then
return t
end
errm, erra = e, orig_a
elseif is_record_type(a) then
if b.typename == "enum" then
-- Every enum value must name a field of the record...
local field_names = sorted_keys(b.enumset)
for _, k in ipairs(field_names) do
if not a.fields[k] then
errm, erra = "enum value '" .. k .. "' is not a field in %s", a
break
end
end
-- ...and all those fields must share one type.
if not errm then
return match_all_record_field_names(bnode, a, field_names,
"cannot index, not all enum values map to record fields of the same type")
end
elseif is_a(b, STRING) then
errm, erra = "cannot index object of type %s with a string, consider using an enum", orig_a
else
errm, erra, errb = "cannot index object of type %s with %s", orig_a, orig_b
end
else
errm, erra, errb = "cannot index object of type %s with %s", orig_a, orig_b
end
-- Last resort: an `__index` metamethod on the indexed type.
local meta_t = check_metamethod(anode, "@index", a, orig_b)
if meta_t then
return meta_t
end
return node_error(bnode, errm, erra, errb)
end
-- Widen `old` so that it also admits `new` (used while typing table literals).
-- If `new` already fits, `old` is returned unchanged. Otherwise:
--  * map with string keys + record: the record's field types are merged
--    into the map's value type;
--  * record + record: `old` is converted in place into a map from string
--    to the union of both records' field types;
--  * union: `new` is appended to its alternatives (in place);
--  * anything else: a new union of `old` and `new` is returned.
-- Returns the widened type (usually `old` itself, mutated).
expand_type = function(where, old, new)
   if not old or old.typename == "nil" then
      return new
   end
   if is_a(new, old) then
      return old
   end
   if old.typename == "map" and is_record_type(new) then
      if old.keys.typename == "string" then
         for _, ftype in fields_of(new) do
            old.values = expand_type(where, old.values, ftype)
         end
      else
         node_error(where, "cannot determine table literal type")
      end
   elseif is_record_type(old) and is_record_type(new) then
      old.typename = "map"
      old.keys = STRING
      for _, ftype in fields_of(old) do
         if not old.values then
            old.values = ftype
         else
            old.values = expand_type(where, old.values, ftype)
         end
      end
      -- Fold the new record's fields into OLD's value type (previously they
      -- were written to `new.values` and lost).
      for _, ftype in fields_of(new) do
         if not old.values then
            old.values = ftype
         else
            old.values = expand_type(where, old.values, ftype)
         end
      end
      old.fields = nil
      old.field_order = nil
   elseif old.typename == "union" then
      new.tk = nil
      table.insert(old.types, new)
   else
      old.tk = nil
      new.tk = nil
      return unite({ old, new })
   end
   return old
end
-- Resolve the record named by a type expression (`Name` or `Name.Sub...`)
-- so new members can be declared on it.
-- Returns the record type (the definition behind a type, if any), the
-- variable entry of the outermost name, and the dotted name. When the name is
-- unknown, the outer record is closed, or a nested field does not exist, the
-- first two results are nil and only the dotted name is returned.
local function find_record_to_extend(exp)
   if exp.kind == "type_identifier" then
      local v = find_var(exp.tk)
      if not v then
         return nil, nil, exp.tk
      end
      local t = v.t
      if t.closed then
         return nil, nil, exp.tk
      end
      return t.def or t, v, exp.tk
   elseif exp.kind == "op" then
      local t, v, rname = find_record_to_extend(exp.e1)
      local fname = exp.e2.tk
      local dname = rname .. "." .. fname
      if not t then
         return nil, nil, dname
      end
      local field = t.fields and t.fields[fname]
      -- A missing nested field used to crash here on `nil.def`.
      if not field then
         return nil, nil, dname
      end
      return field.def or field, v, dname
   end
end
-- Compute the type of `self` for a method declared on the type expression
-- `exp` (`Name` or `Name.Sub...`).
-- A bare type name becomes a nominal reference to that type. For a dotted
-- name whose parent is nominal, the nested type is appended to the parent's
-- `names` and becomes its `found` (parent mutated in place); for a non-nominal
-- parent, the named field's type is returned directly.
local function get_self_type(exp)
   if exp.kind == "type_identifier" then
      local t = find_var_type(exp.tk)
      if not t then
         return nil
      end
      if t.typename ~= "typetype" then
         return t
      end
      return a_type({
         y = exp.y,
         x = exp.x,
         typename = "nominal",
         names = { exp.tk },
         found = t,
      })
   elseif exp.kind == "op" then
      local parent = get_self_type(exp.e1)
      if not parent then
         return nil
      end
      local member = exp.e2.tk
      if parent.typename ~= "nominal" then
         return parent.fields and parent.fields[member]
      end
      local def = parent.found and parent.found.def
      local nested = def and def.fields and def.fields[member]
      if nested then
         table.insert(parent.names, member)
         parent.found = nested
      end
      return parent
   end
end
local facts_and
local facts_or
local facts_not
local apply_facts
local FACT_TRUTHY
do
-- Calling `Fact{...}` attaches a metatable whose __tostring renders the fact:
-- "(x is T)", "(x == T)", "*" (truthy), "(not F)", "(F1 or F2)", "(F1 and F2)".
setmetatable(Fact, {
   __call = function(_, fact)
      return setmetatable(fact, {
         __tostring = function(f)
            local kind = f.fact
            if kind == "is" or kind == "==" then
               return ("(%s %s %s)"):format(f.var, kind, show_type(f.typ))
            elseif kind == "truthy" then
               return "*"
            elseif kind == "not" then
               return ("(not %s)"):format(tostring(f.f1))
            elseif kind == "or" or kind == "and" then
               return ("(%s %s %s)"):format(tostring(f.f1), kind, tostring(f.f2))
            end
         end,
      })
   end,
})
-- The fact that carries no information beyond "the value is truthy".
FACT_TRUTHY = Fact({ fact = "truthy" })
-- Conjunction of two facts (either may be nil).
facts_and = function(where, f1, f2)
   return Fact({ fact = "and", f1 = f1, f2 = f2, where = where })
end
-- Disjunction of two facts; nil unless both facts are known.
facts_or = function(where, f1, f2)
   if not (f1 and f2) then
      return nil
   end
   return Fact({ fact = "or", f1 = f1, f2 = f2, where = where })
end
-- Negation of a fact; nil when there is no fact to negate.
facts_not = function(where, f1)
   if not f1 then
      return nil
   end
   return Fact({ fact = "not", f1 = f1, where = where })
end
-- Union of two types (t2 listed first).
local function unite_types(t1, t2)
   local pair = { t2, t1 }
   return unite(pair)
end
-- Intersection of two types. With a union on either side, keep the union
-- members assignable to the other type; otherwise return the narrower of the
-- two, or INVALID if neither is assignable to the other.
local function intersect_types(t1, t2)
   local u, other = t1, t2
   if t2.typename == "union" then
      u, other = t2, t1
   end
   if u.typename ~= "union" then
      if is_a(u, other) then
         return u
      end
      if is_a(other, u) then
         return other
      end
      return INVALID
   end
   local out = {}
   for _, t in ipairs(u.types) do
      if is_a(t, other) then
         out[#out + 1] = t
      end
   end
   return unite(out)
end
-- `t` resolved, if that resolution is a union; otherwise `t` unchanged.
local function resolve_if_union(t)
   local rt = resolve_tuple_and_nominal(t)
   return rt.typename == "union" and rt or t
end
-- Remove from union `t1` every member matching (same_type) `t2` or one of its
-- union members. A non-union `t1` is returned as-is; removing everything
-- yields INVALID.
local function subtract_types(t1, t2)
   t1 = resolve_if_union(t1)
   if t1.typename ~= "union" then
      return t1
   end
   t2 = resolve_if_union(t2)
   local removed = t2.types or { t2 }
   local kept = {}
   for _, at in ipairs(t1.types) do
      local present = false
      for _, bt in ipairs(removed) do
         if same_type(at, bt) then
            present = true
            break
         end
      end
      if not present then
         kept[#kept + 1] = at
      end
   end
   if #kept == 0 then
      return INVALID
   end
   return unite(kept)
end
-- Mutually recursive fact evaluators, defined below.
local eval_not, not_facts, or_facts, and_facts, eval_fact
-- An "is" fact stating that `f.var` has the INVALID type (same location as `f`).
local function invalid_from(f)
   local fact = { fact = "is", var = f.var, typ = INVALID, where = f.where }
   return Fact(fact)
end
-- Negate a set of per-variable facts.
-- For an "is" fact on a variable of known, non-typevar type, the result is
-- "is <declared type minus the tested type>"; if the tested type can never
-- match, a "branch" warning is emitted and the result is INVALID. Unknown
-- variables get INVALID. "==" facts (and typevar "is" facts) become
-- location-less "==" facts on the variable's declared type.
not_facts = function(fs)
   local ret = {}
   for var, f in pairs(fs) do
      local typ = find_var_type(f.var, "check_only")
      local fact, where = "==", f.where
      if not typ then
         typ = INVALID
      elseif f.fact == "==" then
         where = nil
      elseif f.fact == "is" then
         if typ.typename == "typevar" then
            where = nil
         elseif is_a(f.typ, typ) then
            fact = "is"
            typ = subtract_types(typ, f.typ)
         else
            node_warning("branch", f.where, f.var .. " (of type %s) can never be a %s", show_type(typ), show_type(f.typ))
            typ = INVALID
         end
      end
      ret[var] = Fact({ fact = fact, var = var, typ = typ, where = where })
   end
   return ret
end
-- Evaluate the negation of fact `f` into a table of per-variable facts.
-- Uses De Morgan for and/or; "F and truthy" negates to "not F", and
-- "F or truthy" is treated as F.
eval_not = function(f)
   if not f then
      return {}
   end
   local kind = f.fact
   local rhs_truthy = f.f2 and f.f2.fact == "truthy"
   if kind == "is" then
      return not_facts({ [f.var] = f })
   elseif kind == "not" then
      return eval_fact(f.f1)
   elseif kind == "and" then
      if rhs_truthy then
         return eval_not(f.f1)
      end
      return or_facts(not_facts(eval_fact(f.f1)), not_facts(eval_fact(f.f2)))
   elseif kind == "or" then
      if rhs_truthy then
         return eval_fact(f.f1)
      end
      return and_facts(not_facts(eval_fact(f.f1)), not_facts(eval_fact(f.f2)))
   end
   return not_facts(eval_fact(f))
end
-- Disjunction of two per-variable fact tables: only variables constrained on
-- both sides survive, with the union of both types. The result is "is" only
-- when both sides are "is" facts.
or_facts = function(fs1, fs2)
   local ret = {}
   for var, f2 in pairs(fs2) do
      local f1 = fs1[var]
      if f1 then
         local both_is = f1.fact == "is" and f2.fact == "is"
         ret[var] = Fact({ fact = both_is and "is" or "==", var = var, typ = unite_types(f2.typ, f1.typ), where = f2.where })
      end
   end
   return ret
end
-- Conjunction of two per-variable fact tables. Variables on both sides get the
-- intersection of their types; variables on one side keep their type as "==".
-- If the result mixes "is" and "==" facts, all are downgraded to "==".
and_facts = function(fs1, fs2)
   local ret = {}
   local has = {}
   for var, f1 in pairs(fs1) do
      local f2 = fs2[var]
      local fact, rt = "==", f1.typ
      if f2 then
         if f2.fact == "is" and f1.fact == "is" then
            fact = "is"
         end
         rt = intersect_types(f1.typ, f2.typ)
      end
      ret[var] = Fact({ fact = fact, var = var, typ = rt, where = f1.where })
      has[fact] = true
   end
   for var, f2 in pairs(fs2) do
      if not fs1[var] then
         ret[var] = Fact({ fact = "==", var = var, typ = f2.typ, where = f2.where })
         has["=="] = true
      end
   end
   if has["is"] and has["=="] then
      for _, f in pairs(ret) do
         f.fact = "=="
      end
   end
   return ret
end
-- Evaluate fact `f` into a table mapping variable names to facts.
-- An "is" test that always holds emits a "branch" warning; one that can never
-- hold is an error and yields an INVALID fact, as does testing an unknown
-- variable. "F and truthy" is treated as F, and "F or truthy" as "not F".
eval_fact = function(f)
   if not f then
      return {}
   end
   local kind = f.fact
   if kind == "is" then
      local typ = find_var_type(f.var, "check_only")
      if not typ then
         return { [f.var] = invalid_from(f) }
      end
      if typ.typename ~= "typevar" then
         if is_a(typ, f.typ) then
            node_warning("branch", f.where, f.var .. " (of type %s) is always a %s", show_type(typ), show_type(f.typ))
         elseif not is_a(f.typ, typ) then
            node_error(f.where, f.var .. " (of type %s) can never be a %s", typ, f.typ)
            return { [f.var] = invalid_from(f) }
         end
      end
      return { [f.var] = f }
   elseif kind == "==" then
      return { [f.var] = f }
   elseif kind == "not" then
      return eval_not(f.f1)
   elseif kind == "truthy" then
      return {}
   elseif kind == "and" then
      if f.f2 and f.f2.fact == "truthy" then
         return eval_fact(f.f1)
      end
      return and_facts(eval_fact(f.f1), eval_fact(f.f2))
   elseif kind == "or" then
      if f.f2 and f.f2.fact == "truthy" then
         return eval_not(f.f1)
      end
      return or_facts(eval_fact(f.f1), eval_fact(f.f2))
   end
end
-- Evaluate fact `known` and declare each constrained variable in the current
-- scope as a const narrowed ("is") variable of its inferred type. An INVALID
-- result is reported as an error but still declared. A fact without a
-- location yields a type with no `inferred_at`.
apply_facts = function(where, known)
   if not known then
      return
   end
   for v, f in pairs(eval_fact(known)) do
      if f.typ.typename == "invalid" then
         node_error(where, "cannot resolve a type for " .. v .. " here")
      end
      local t = infer_at(where, f.typ)
      if not f.where then
         t.inferred_at = nil
      end
      add_var(nil, v, t, "const", "is")
   end
end
end
-- Settle pending references to nominal type `name`: in the innermost scope
-- that has any, resolve each of them and drop them from the pending list.
local function dismiss_unresolved(name)
   for i = #st, 1, -1 do
      local unresolved = st[i]["@unresolved"]
      local pending = unresolved and unresolved.t.nominals[name]
      if pending then
         for _, t in ipairs(pending) do
            resolve_nominal(t)
         end
         unresolved.t.nominals[name] = nil
         return
      end
   end
end
local type_check_funcall
-- Type-check a call to `pcall` or `xpcall`.
-- `b` holds the argument types: the protected function, then (xpcall only)
-- the message handler, then the forwarded arguments. The leading entries are
-- removed from `b` in place. The forwarded call is checked as an ordinary
-- call, and BOOLEAN is prepended to its return types.
local function special_pcall_xpcall(node, _a, b, argdelta)
   local is_xpcall = node.e1.tk == "xpcall"
   local base_nargs = is_xpcall and 2 or 1
   if #node.e2 < base_nargs then
      node_error(node, "wrong number of arguments (given " .. #node.e2 .. ", expects at least " .. base_nargs .. ")")
      return TUPLE({ BOOLEAN })
   end
   local ftype = table.remove(b, 1)
   if is_xpcall then
      local msgh = table.remove(b, 1)
      assert_is_a(node.e2[2], msgh, XPCALL_MSGH_FUNCTION, "in message handler")
   end
   -- Build a synthetic call node for the protected function and its arguments.
   local fe2 = {}
   for i = base_nargs + 1, #node.e2 do
      table.insert(fe2, node.e2[i])
   end
   local fnode = {
      y = node.y,
      x = node.x,
      kind = "op",
      op = { op = "@funcall" },
      e1 = node.e2[1],
      e2 = fe2,
   }
   local rets = type_check_funcall(fnode, ftype, b, argdelta + base_nargs)
   if rets.typename ~= "tuple" then
      rets = a_type({ typename = "tuple", rets })
   end
   table.insert(rets, 1, BOOLEAN)
   return rets
end
-- Built-in functions with dedicated type-checking, keyed by global name.
-- Each handler receives (call node, callee type, argument types, argdelta).
local special_functions = {}

-- rawget(t, k): typed like the indexing t[k].
special_functions["rawget"] = function(node, _a, b, _argdelta)
   if #b ~= 2 then
      return node_error(node, "rawget expects two arguments")
   end
   return type_check_index(node.e2[1], node.e2[2], b[1], b[2])
end

-- print_type(): dump every scope's variables; print_type(x): print x's type
-- and emit it as a "debug" warning.
special_functions["print_type"] = function(node, _a, b, _argdelta)
   if #b ~= 0 then
      local shown = show_type(b[1])
      print(shown)
      node_warning("debug", node.e2[1], "type is: %s", shown)
      return b
   end
   print("-----------------------------------------")
   for i, scope in ipairs(st) do
      for s, v in pairs(scope) do
         print(("%2d %-14s %-11s %s"):format(i, s, v.t.typename, show_type(v.t):sub(1, 50)))
      end
   end
   print("-----------------------------------------")
   return NONE
end

-- require("literal"): the type of the loaded module, recorded as a dependency.
special_functions["require"] = function(node, _a, b, _argdelta)
   if #b ~= 1 then
      return node_error(node, "require expects one literal argument")
   end
   local arg_node = node.e2[1]
   if arg_node.kind ~= "string" then
      return node_error(node, "don't know how to resolve a dynamic require")
   end
   local module_name = assert(arg_node.conststr)
   local t, found = require_module(module_name, lax, env)
   if not found then
      return node_error(node, "module not found: '" .. module_name .. "'")
   end
   if t.typename == "invalid" then
      if lax then
         return UNKNOWN
      end
      return node_error(node, "no type information for required module: '" .. module_name .. "'")
   end
   dependencies[module_name] = t.filename
   return t
end

special_functions["pcall"] = special_pcall_xpcall
special_functions["xpcall"] = special_pcall_xpcall

-- assert(...): an ordinary call that also marks the call node as truthy.
special_functions["assert"] = function(node, a, b, argdelta)
   node.known = FACT_TRUTHY
   return type_check_function_call(node, node.e2, a, b, node, false, argdelta)
end
-- Type-checks a call node.
-- * A plain variable callee with a registered special handler goes to
--   that handler.
-- * A `obj:method(...)` callee gets the receiver type prepended to the
--   argument types `b` and is checked as a method call.
-- * Everything else is checked as an ordinary function call.
type_check_funcall = function(node, a, b, argdelta)
   argdelta = argdelta or 0
   local callee = node.e1
   if callee.kind == "variable" then
      local special = special_functions[callee.tk]
      if special then
         return special(node, a, b, argdelta)
      end
   elseif callee.op and callee.op.op == ":" then
      table.insert(b, 1, callee.e1.type)
      return type_check_function_call(node, node.e2, a, b, callee, true)
   end
   return type_check_function_call(node, node.e2, a, b, callee, false, argdelta)
end
-- True when the i-th expression of a declaration is a bare reference to a
-- variable with the same name as the i-th declared variable (as in
-- `local x = x`). When there is no such expression, the falsy value from
-- the missing lookup is returned (nil or false), exactly like an and-chain.
local function is_localizing_a_variable(node, i)
   local exps = node.exps
   if not exps then
      return exps
   end
   local exp = exps[i]
   if not exp then
      return exp
   end
   if exp.kind ~= "variable" then
      return false
   end
   return exp.tk == node.vars[i].tk
end
-- Resolves a type declaration whose definition is a nominal reference.
-- Returns the resolved type plus the variable found for the first name
-- under "use_type" (the alias target, if any).
-- * A non-nominal definition is returned unchanged with no alias.
-- * A nominal with type arguments is resolved in place
--   (`typetype.def` is replaced).
-- * Otherwise the name must denote a type. If it does not, an error is
--   reported and a "bad_nominal" placeholder is returned.
local function resolve_nominal_typetype(typetype)
   local def = typetype.def
   if def.typename ~= "nominal" then
      return typetype, nil
   end
   local names = def.names
   local aliasing = find_var(names[1], "use_type")
   if def.typevals then
      typetype.def = resolve_nominal(def)
      return typetype, aliasing
   end
   local resolved = find_type(names)
   if not (resolved and is_typetype(resolved)) then
      type_error(typetype, "%s is not a type", typetype)
      resolved = a_type({ typename = "bad_nominal", names = names })
   end
   return resolved, aliasing
end
-- Type to use for the i-th declared variable when neither a declared nor
-- an inferred type is available.
-- * In lax mode the variable becomes UNKNOWN.
-- * Otherwise an error is reported on the variable. The message depends
--   on whether the declaration had any initializer expressions at all.
local function missing_initializer(node, i, name)
   if lax then
      return UNKNOWN
   end
   local var = node.vars[i]
   if not node.exps then
      return node_error(var, "variable '" .. name .. "' has no type or initial value")
   end
   return node_error(var, "assignment in declaration did not produce an initial value for variable '" .. name .. "'")
end
-- `before_expressions` hook shared by declarations and assignments.
-- For each right-hand expression that has a corresponding declared type,
-- stores that type as the expression's `expected` type, together with an
-- `expected_context` used to prefix error messages.
-- Declared types come from `node.decltype` for declarations and from
-- `children[1]` (the assigned variables' types) for assignments.
-- If the last expression must supply several remaining declared types
-- (fewer expressions than types), it is expected to be a tuple of all
-- of them.
local function set_expected_types_to_decltypes(node, children)
   local decls = node.kind == "assignment" and children[1] or node.decltype
   if not (decls and node.exps) then
      return
   end
   local ndecl = #decls
   local nexps = #node.exps
   for i = 1, nexps do
      local typ = decls[i]
      if typ then
         if i == nexps and ndecl > nexps then
            typ = a_type({ y = node.y, x = node.x, filename = filename, typename = "tuple", types = {} })
            for a = i, ndecl do
               table.insert(typ, decls[a])
            end
         end
         -- A declaration can list more types than variables
         -- (e.g. `local x: integer, string = 1, "a"`). The extra
         -- positions have no variable, so don't index a nil entry.
         local var = node.vars[i]
         node.exps[i].expected = typ
         node.exps[i].expected_context = { kind = node.kind, name = var and var.tk }
      end
   end
end
-- True when `n` is a finite whole number >= 1, i.e. usable as a positive
-- array index. Returns nil when `n` is nil (a non-constant key) and false
-- for other non-matching numbers.
-- Infinity (e.g. from the literal `1e309`) is rejected explicitly:
-- math.floor(math.huge) == math.huge, so the integrality test alone
-- would accept it.
local function is_positive_int(n)
   return n and n >= 1 and n < math.huge and math.floor(n) == n
end
-- Message prefixes for errors raised while checking an expression,
-- keyed by the kind of statement that set its expected type.
local context_name = {
   assignment = "in assignment",
   global_declaration = "in global declaration",
   local_declaration = "in local declaration",
}
-- Prefixes `msg` with a description of the context `ctx`
-- (`{ kind = ..., name = ... }`), e.g. "in local declaration: x: msg".
-- `msg` is returned unchanged when there is no context or its kind has
-- no entry in context_name.
local function in_context(ctx, msg)
   local where = ctx and context_name[ctx.kind]
   if not where then
      return msg
   end
   local label = ctx.name and (ctx.name .. ": ") or ""
   return where .. ": " .. label .. msg
end
-- Records that constant key `key` was used by table entry `node`.
-- If the key was already recorded in `seen_keys`, reports an error that
-- points at the first declaration's position instead.
-- A nil key (non-constant) is ignored.
local function check_redeclared_key(node, ctx, seen_keys, key)
   if key == nil then
      return
   end
   local previous = seen_keys[key]
   if not previous then
      seen_keys[key] = node
      return
   end
   local where = filename .. ":" .. previous.y .. ":" .. previous.x
   node_error(node, in_context(ctx, "redeclared key " .. tostring(key) .. " (previously declared at " .. where .. ")"))
end
-- Infers a type for a table constructor when no usable expected type is
-- available. `children` holds one "table_item" type per entry (with
-- `kname`, `ktype`, `vtype`); `node[i].key` is the matching key node.
--
-- Outcome by mix of keys:
--   constant string keys only        -> record
--   numeric keys only                -> array (one element type) or
--                                       tupletable (differing types)
--   other keys only                  -> map
--   constant string + numeric keys   -> arrayrecord
--   constant string + other keys     -> map, if the map key type is string
--   numeric + other keys             -> error ("cannot determine type ...")
-- Duplicate constant keys are reported through check_redeclared_key.
local function infer_table_literal(node, children)
   local typ = a_type({
      filename = filename,
      y = node.y,
      x = node.x,
      typename = "emptytable",
   })
   local is_record = false
   local is_array = false
   local is_map = false
   local is_tuple = false
   local is_not_tuple = false
   -- Next positional index that an implicit (unkeyed) entry occupies.
   local last_array_idx = 1
   -- One past the largest numeric index seen so far, in the same
   -- convention as last_array_idx; the inferred length is this minus 1.
   local largest_array_idx = -1
   local seen_keys = {}
   for i, child in ipairs(children) do
      assert(child.typename == "table_item")
      local ck = child.kname
      local n = node[i].key.constnum
      local b = nil
      if child.ktype.typename == "boolean" then
         b = (node[i].key.tk == "true")
      end
      local key = ck or n or b
      check_redeclared_key(node[i], nil, seen_keys, key)
      local uvtype = resolve_tuple(child.vtype)
      if ck then
         is_record = true
         if not typ.fields then
            typ.fields = {}
            typ.field_order = {}
         end
         typ.fields[ck] = uvtype
         table.insert(typ.field_order, ck)
      elseif is_number_type(child.ktype) then
         is_array = true
         if not is_not_tuple then
            is_tuple = true
         end
         if not typ.types then
            typ.types = {}
         end
         if node[i].key_parsed == "implicit" then
            -- A trailing multi-value expression fills several slots.
            if i == #children and child.vtype.typename == "tuple" then
               for _, c in ipairs(child.vtype) do
                  typ.elements = expand_type(node, typ.elements, c)
                  typ.types[last_array_idx] = resolve_tuple(c)
                  last_array_idx = last_array_idx + 1
               end
            else
               typ.types[last_array_idx] = uvtype
               last_array_idx = last_array_idx + 1
               typ.elements = expand_type(node, typ.elements, uvtype)
            end
         else
            if not is_positive_int(n) then
               -- Non-constant, fractional or non-positive index: still
               -- contributes to the element type, but rules out a tuple.
               typ.elements = expand_type(node, typ.elements, uvtype)
               is_not_tuple = true
            elseif n then
               typ.types[n] = uvtype
               -- Explicit [n] occupies index n, so the bound is n + 1 in
               -- the "one past" convention. Storing n here would make
               -- {[1] = a, [2] = b} infer a length of 1.
               if n + 1 > largest_array_idx then
                  largest_array_idx = n + 1
               end
               typ.elements = expand_type(node, typ.elements, uvtype)
            end
         end
         if last_array_idx > largest_array_idx then
            largest_array_idx = last_array_idx
         end
         if not typ.elements then
            is_array = false
         end
      else
         is_map = true
         child.ktype.tk = nil
         typ.keys = expand_type(node, typ.keys, child.ktype)
         typ.values = expand_type(node, typ.values, uvtype)
      end
   end
   if is_array and is_map then
      typ.typename = "map"
      typ.keys = expand_type(node, typ.keys, INTEGER)
      typ.values = expand_type(node, typ.values, typ.elements)
      typ.elements = nil
      node_error(node, "cannot determine type of table literal")
   elseif is_record and is_array then
      typ.typename = "arrayrecord"
   elseif is_record and is_map then
      if typ.keys.typename == "string" then
         typ.typename = "map"
         for _, ftype in fields_of(typ) do
            typ.values = expand_type(node, typ.values, ftype)
         end
         typ.fields = nil
         typ.field_order = nil
      else
         node_error(node, "cannot determine type of table literal")
      end
   elseif is_array then
      if is_not_tuple then
         typ.typename = "array"
         typ.inferred_len = largest_array_idx - 1
      else
         -- Only positive constant indices: an array if every slot has
         -- the same type, otherwise a tupletable.
         local pure_array = true
         local last_t
         for _, current_t in pairs(typ.types) do
            if last_t then
               if not same_type(last_t, current_t) then
                  pure_array = false
                  break
               end
            end
            last_t = current_t
         end
         if not pure_array then
            typ.typename = "tupletable"
         else
            typ.typename = "array"
            typ.inferred_len = largest_array_idx - 1
         end
      end
   elseif is_record then
      typ.typename = "record"
   elseif is_map then
      typ.typename = "map"
   elseif is_tuple then
      typ.typename = "tupletable"
      if not typ.types or #typ.types == 0 then
         node_error(node, "cannot determine type of tuple elements")
      end
   end
   return typ
end
-- Applies, at `where`, the facts that hold when none of the first `n`
-- conditions of `ifnode` was true, i.e. the conjunction of their
-- negations. Blocks without a condition expression are skipped.
local function infer_negation_of_if_blocks(where, ifnode, n)
   local blocks = ifnode.if_blocks
   local negated = facts_not(where, blocks[1].exp.known)
   local e = 2
   while e <= n do
      local cond = blocks[e].exp
      if cond then
         negated = facts_and(where, negated, facts_not(where, cond.known))
      end
      e = e + 1
   end
   apply_facts(where, negated)
end
-- Determines the type of the i-th variable `var` of a declaration `node`.
-- `infertypes` holds the types inferred from the initializer expressions.
-- Returns:
--   ok          false if any error was reported for this variable
--   t           the declared type if present, else the inferred type,
--               else the result of missing_initializer
--   has_value   whether an inferred type was available
local function determine_declaration_type(var, node, infertypes, i)
local ok = true
local name = var.tk
local infertype = infertypes and infertypes[i]
-- In lax mode, an inferred `nil` type is treated as no inferred type.
if lax and infertype and infertype.typename == "nil" then
infertype = nil
end
local decltype = node.decltype and node.decltype[i]
if decltype then
if resolve_tuple_and_nominal(decltype) == INVALID then
decltype = INVALID
end
if infertype then
-- The initializer must be compatible with the declared type.
ok = assert_is_a(node.vars[i], infertype, decltype, context_name[node.kind], name)
end
else
if infertype and infertype.typename == "unresolvable_typearg" then
node_error(node.vars[i], "cannot infer declaration type; an explicit type annotation is necessary")
ok = false
infertype = INVALID
elseif infertype and infertype.is_method then
-- Copy the type (with a fresh typeid) rather than clearing the
-- is_method flag on the shared original.
infertype = shallow_copy_type(infertype)
infertype.typeid = new_typeid()
infertype.is_method = false
end
end
-- <total>: the declared/inferred type must be a map or record, and the
-- initializer must provide every field/key.
if var.attribute == "total" then
local rd = decltype and resolve_tuple_and_nominal(decltype)
if rd and (rd.typename ~= "map" and rd.typename ~= "record") then
node_error(var, "attribute <total> only applies to maps and records")
ok = false
elseif not infertype then
node_error(var, "variable declared <total> does not declare an initialization value")
ok = false
-- The completeness check is skipped when the initializer expression
-- itself carries the "total" attribute.
elseif not (node.exps[i] and node.exps[i].attribute == "total") then
local ri = resolve_tuple_and_nominal(infertype)
if ri.typename ~= "map" and ri.typename ~= "record" then
node_error(var, "attribute <total> only applies to maps and records")
ok = false
elseif not infertype.is_total then
-- `is_total` / `missing` are set on table literal types by the
-- total_record_check / total_map_check helpers.
local missing = ""
if infertype.missing then
missing = " (missing: " .. table.concat(infertype.missing, ", ") .. ")"
end
if ri.typename == "map" then
node_error(var, "map variable declared <total> does not declare values for all possible keys" .. missing)
ok = false
elseif ri.typename == "record" then
node_error(var, "record variable declared <total> does not declare values for all fields" .. missing)
ok = false
end
end
infertype.is_total = nil
end
end
local t = decltype or infertype
if t == nil then
t = missing_initializer(node, i, name)
elseif t.typename == "emptytable" then
-- Remember where an empty-table variable was declared and under
-- which name.
t.declared_at = node
t.assigned_to = name
end
-- An inferred literal length does not carry over to the variable's type.
t.inferred_len = nil
return ok, t, infertype ~= nil
end
-- Type introduced by a `local type` / `global type` statement.
-- * If the value is a call (a `require(...)` of a module), it is
--   resolved through the "require" special function with a single
--   string argument.
-- * Otherwise the newtype is resolved with resolve_nominal_typetype,
--   which also returns the alias target, if any.
local function get_type_declaration(node)
   local value = node.value
   local is_call = value.kind == "op" and value.op.op == "@funcall"
   if not is_call then
      return resolve_nominal_typetype(value.newtype)
   end
   local require_handler = special_functions["require"]
   return require_handler(value, find_var_type("require"), { STRING }, 0)
end
-- Accumulator step for the totality checks.
-- * If `key` is absent from `seen_keys`, its string form is appended to
--   `missing` (created on demand) and totality becomes false.
-- * Otherwise the current `is_total` / `missing` values pass through.
local function total_check_key(key, seen_keys, is_total, missing)
   if seen_keys[key] then
      return is_total, missing
   end
   local list = missing or {}
   list[#list + 1] = tostring(key)
   return false, list
end
-- Checks whether a record literal provides every value field of record
-- type `t`. `seen_keys` holds the keys present in the literal.
-- Returns the totality flag and, if incomplete, the list of missing field
-- names. Fields holding types are not required. A record with
-- metamethod fields (`meta_field_order`) is never considered total.
local function total_record_check(t, seen_keys)
   if t.meta_field_order then
      return false
   end
   local complete, absent = true, nil
   for _, fname in ipairs(t.field_order) do
      local is_value_field = not is_typetype(t.fields[fname])
      if is_value_field then
         complete, absent = total_check_key(fname, seen_keys, complete, absent)
      end
   end
   return complete, absent
end
-- Checks whether a map literal provides a value for every possible key of
-- map type `t`. This is only decidable for enum keys (every enum value,
-- in sorted order) and boolean keys (true and false). Any other key type
-- yields `false, nil`.
local function total_map_check(t, seen_keys)
   local keytype = resolve_tuple_and_nominal(t.keys)
   local candidates
   if keytype.typename == "enum" then
      candidates = sorted_keys(keytype.enumset)
   elseif keytype.typename == "boolean" then
      candidates = { true, false }
   else
      return false, nil
   end
   local is_total, missing = true, nil
   for _, key in ipairs(candidates) do
      is_total, missing = total_check_key(key, seen_keys, is_total, missing)
   end
   return is_total, missing
end
-- Holder for the type checker's AST visitor callbacks; `visit_node.cbs`
-- (assigned right after this) maps node kinds such as "statements" or
-- "local_declaration" to their before/after hooks.
local visit_node = {}
visit_node.cbs = {
["statements"] = {
before = function(node)
begin_scope(node)
end,
after = function(node, _children)
if #st == 2 then
fail_unresolved()
end
if not node.is_repeat then
end_scope(node)
end
node.type = NONE
return node.type
end,
},
["local_type"] = {
before = function(node)
local name = node.var.tk
local resolved, aliasing = get_type_declaration(node)
local var = add_var(node.var, name, resolved, node.var.attribute)
node.value.type = resolved
if aliasing then
var.aliasing = aliasing
node.value.is_alias = true
end
end,
after = function(node, _children)
dismiss_unresolved(node.var.tk)
node.type = NONE
return node.type
end,
},
["global_type"] = {
before = function(node)
local name = node.var.tk
local unresolved = get_unresolved()
if node.value then
local resolved, aliasing = get_type_declaration(node)
local added = add_global(node.var, name, resolved)
node.value.newtype = resolved
if aliasing then
added.aliasing = aliasing
node.value.is_alias = true
end
if added and unresolved.global_types[name] then
unresolved.global_types[name] = nil
end
else
if not st[1][name] then
unresolved.global_types[name] = true
end
end
end,
after = function(node, _children)
dismiss_unresolved(node.var.tk)
node.type = NONE
return node.type
end,
},
["local_declaration"] = {
before = function(node)
for _, var in ipairs(node.vars) do
reserve_symbol_list_slot(var)
end
end,
before_expressions = set_expected_types_to_decltypes,
after = function(node, children)
local encountered_close = false
local infertypes = get_assignment_values(children[3], #node.vars)
for i, var in ipairs(node.vars) do
if var.attribute == "close" then
if opts.gen_target == "5.4" then
if encountered_close then
node_error(var, "only one <close> per declaration is allowed")
else
encountered_close = true
end
else
node_error(var, "<close> attribute is only valid for Lua 5.4 (current target is " .. tostring(opts.gen_target) .. ")")
end
end
local ok, t = determine_declaration_type(var, node, infertypes, i)
if var.attribute == "close" then
if not type_is_closable(t) then
node_error(var, "to-be-closed variable " .. var.tk .. " has a non-closable type %s", t)
elseif node.exps and node.exps[i] and expr_is_definitely_not_closable(node.exps[i]) then
node_error(var, "to-be-closed variable " .. var.tk .. " assigned a non-closable value")
end
end
assert(var)
add_var(var, var.tk, t, var.attribute, is_localizing_a_variable(node, i) and "declaration")
if ok and infertypes and infertypes[i] then
local where = node.exps[i] or node.exps
local infertype = infertypes[i]
local rt = resolve_tuple_and_nominal(t)
if rt.typename ~= "enum" and not same_type(t, infertype) then
add_var(where, var.tk, infer_at(where, infertype), "const", "declaration")
end
end
dismiss_unresolved(var.tk)
end
node.type = NONE
return node.type
end,
},
["global_declaration"] = {
before_expressions = set_expected_types_to_decltypes,
after = function(node, children)
local infertypes = get_assignment_values(children[3], #node.vars)
for i, var in ipairs(node.vars) do
local _, t, is_inferred = determine_declaration_type(var, node, infertypes, i)
if var.attribute == "close" then
node_error(var, "globals may not be <close>")
end
add_global(var, var.tk, t, is_inferred)
var.type = t
dismiss_unresolved(var.tk)
end
node.type = NONE
return node.type
end,
},
["assignment"] = {
before_expressions = set_expected_types_to_decltypes,
after = function(node, children)
local valtypes = get_assignment_values(children[3], #children[1])
local exps = flatten_list(valtypes)
for i, vartype in ipairs(children[1]) do
local varnode = node.vars[i]
local attr = varnode.attribute
if varnode.kind == "variable" then
if widen_back_var(varnode.tk) then
vartype, attr = find_var_type(varnode.tk)
end
end
if attr then
node_error(varnode, "cannot assign to <" .. attr .. "> variable")
end
if vartype then
local val = exps[i]
if is_typetype(resolve_tuple_and_nominal(vartype)) then
node_error(varnode, "cannot reassign a type")
elseif val then
assert_is_a(varnode, val, vartype, "in assignment")
local rval = resolve_tuple_and_nominal(val)
if rval.typename == "function" then
widen_all_unions()
end
if varnode.kind == "variable" and vartype.typename == "union" then
add_var(varnode, varnode.tk, val, nil, "is")
end
else
node_error(varnode, "variable is not being assigned a value")
if #node.exps == 1 and node.exps[1].kind == "op" and node.exps[1].op.op == "@funcall" then
local rets = node.exps[1].type
if rets.typename == "tuple" then
local msg = #rets == 1 and
"only 1 value is returned by the function" or
("only " .. #rets .. " values are returned by the function")
node_warning("hint", varnode, msg)
end
end
end
else
node_error(varnode, "unknown variable")
end
end
node.type = NONE
return node.type
end,
},
["if"] = {
after = function(node, _children)
local all_return = true
for _, b in ipairs(node.if_blocks) do
if not b.block_returns then
all_return = false
break
end
end
if all_return then
node.block_returns = true
infer_negation_of_if_blocks(node, node, #node.if_blocks)
end
node.type = NONE
return node.type
end,
},
["if_block"] = {
before = function(node)
begin_scope(node)
if node.if_block_n > 1 then
infer_negation_of_if_blocks(node, node.if_parent, node.if_block_n - 1)
end
end,
before_statements = function(node)
if node.exp then
apply_facts(node.exp, node.exp.known)
end
end,
after = function(node, _children)
end_scope(node)
if #node.body > 0 and node.body[#node.body].block_returns then
node.block_returns = true
end
node.type = NONE
return node.type
end,
},
["while"] = {
before = function(node)
widen_all_unions(node)
end,
before_statements = function(node)
begin_scope(node)
apply_facts(node.exp, node.exp.known)
end,
after = end_scope_and_none_type,
},
["label"] = {
before = function(node)
widen_all_unions()
local label_id = "::" .. node.label .. "::"
if st[#st][label_id] then
node_error(node, "label '" .. node.label .. "' already defined at " .. filename)
end
local unresolved = st[#st]["@unresolved"]
node.type = a_type({ y = node.y, x = node.x, typename = "none" })
local var = add_var(node, label_id, node.type)
if unresolved then
if unresolved.t.labels[node.label] then
var.used = true
end
unresolved.t.labels[node.label] = nil
end
end,
},
["goto"] = {
after = function(node, _children)
if not find_var_type("::" .. node.label .. "::") then
local unresolved = get_unresolved(st[#st])
unresolved.labels[node.label] = unresolved.labels[node.label] or {}
table.insert(unresolved.labels[node.label], node)
end
node.type = NONE
return node.type
end,
},
["repeat"] = {
before = function(node)
widen_all_unions(node)
end,
after = end_scope_and_none_type,
},
["forin"] = {
before = function(node)
begin_scope(node)
end,
before_statements = function(node)
widen_all_unions(node)
local exp1 = node.exps[1]
local args = {
typename = "tuple",
node.exps[2] and node.exps[2].type,
node.exps[3] and node.exps[3].type,
}
local exp1type = resolve_for_call(exp1.type, args)
if exp1type.typename == "poly" then
type_check_function_call(exp1, { node.exps[2], node.exps[3] }, exp1type, args, exp1, false, 0)
exp1type = exp1.type or exp1type
end
if exp1type.typename == "function" then
if exp1.op and exp1.op.op == "@funcall" then
local t = resolve_tuple_and_nominal(exp1.e2.type)
if exp1.e1.tk == "pairs" and is_array_type(t) then
node_warning("hint", exp1, "hint: applying pairs on an array: did you intend to apply ipairs?")
end
if exp1.e1.tk == "pairs" and t.typename ~= "map" then
if not (lax and is_unknown(t)) then
if is_record_type(t) then
match_all_record_field_names(exp1.e2, t, t.field_order,
"attempting pairs loop on a record with attributes of different types")
local ct = t.typename == "record" and "{string:any}" or "{any:any}"
node_warning("hint", exp1.e2, "hint: if you want to iterate over fields of a record, cast it to " .. ct)
else
node_error(exp1.e2, "cannot apply pairs on values of type: %s", exp1.e2.type)
end
end
elseif exp1.e1.tk == "ipairs" then
if t.typename == "tupletable" then
local arr_type = arraytype_from_tuple(exp1.e2, t)
if not arr_type then
node_error(exp1.e2, "attempting ipairs loop on tuple that's not a valid array: %s", exp1.e2.type)
end
elseif not is_array_type(t) then
if not (lax and (is_unknown(t) or t.typename == "emptytable")) then
node_error(exp1.e2, "attempting ipairs loop on something that's not an array: %s", exp1.e2.type)
end
end
end
end
local last
local rets = exp1type.rets
for i, v in ipairs(node.vars) do
local r = rets[i]
if not r then
if rets.is_va then
r = last
else
r = lax and UNKNOWN or INVALID
end
end
add_var(v, v.tk, r)
last = r
end
if (not lax) and (not rets.is_va and #node.vars > #rets) then
local nrets = #rets
local at = node.vars[nrets + 1]
local n_values = nrets == 1 and "1 value" or tostring(nrets) .. " values"
node_error(at, "too many variables for this iterator; it produces " .. n_values)
end
else
if not (lax and is_unknown(exp1type)) then
node_error(exp1, "expression in for loop does not return an iterator")
end
end
end,
after = end_scope_and_none_type,
},
["fornum"] = {
before_statements = function(node, children)
widen_all_unions(node)
begin_scope(node)
local from_t = resolve_tuple_and_nominal(children[2])
local to_t = resolve_tuple_and_nominal(children[3])
local step_t = children[4] and resolve_tuple_and_nominal(children[4])
local t = (from_t.typename == "integer" and
to_t.typename == "integer" and
(not step_t or step_t.typename == "integer")) and
INTEGER or
NUMBER
add_var(node.var, node.var.tk, t)
end,
after = end_scope_and_none_type,
},
["return"] = {
before = function(node)
local rets = find_var_type("@return")
if rets then
for i, exp in ipairs(node.exps) do
exp.expected = rets[i]
end
end
end,
after = function(node, children)
node.block_returns = true
local rets = find_var_type("@return")
if not rets then
rets = infer_at(node, children[1])
module_type = resolve_tuple_and_nominal(rets)
module_type.tk = nil
st[2]["@return"] = { t = rets }
end
local what = "in return value"
if rets.inferred_at then
what = what .. inferred_msg(rets)
end
local nrets = #rets
local vatype
if nrets > 0 then
vatype = rets.is_va and rets[nrets]
end
if #children[1] > nrets and (not lax) and not vatype then
node_error(node, "in " .. what .. ": excess return values, expected " .. #rets .. " %s, got " .. #children[1] .. " %s", rets, children[1])
end
if nrets > 1 and
#node.exps == 1 and
node.exps[1].kind == "op" and
(node.exps[1].op.op == "and" or node.exps[1].op.op == "or") and
#node.exps[1].e2.type > 1 then
node_warning("hint", node.exps[1].e2, "additional return values are being discarded due to '" .. node.exps[1].op.op .. "' expression; suggest parentheses if intentional")
end
for i = 1, #children[1] do
local expected = rets[i] or vatype
if expected then
expected = resolve_tuple(expected)
local where = (node.exps[i] and node.exps[i].x) and
node.exps[i] or
node.exps
assert(where and where.x)
assert_is_a(where, children[1][i], expected, what)
end
end
node.type = NONE
return node.type
end,
},
["variable_list"] = {
after = function(node, children)
node.type = TUPLE(children)
local n = #children
if n > 0 and children[n].typename == "tuple" then
if children[n].is_va then
node.type.is_va = true
end
local tuple = children[n]
for i, c in ipairs(tuple) do
children[n + i - 1] = c
end
end
return node.type
end,
},
["table_literal"] = {
before = function(node)
if node.expected then
local decltype = resolve_tuple_and_nominal(node.expected)
if decltype.typename == "tupletable" then
for _, child in ipairs(node) do
local n = child.key.constnum
if n and is_positive_int(n) then
child.value.expected = decltype.types[n]
end
end
elseif is_array_type(decltype) then
for _, child in ipairs(node) do
if child.key.constnum then
child.value.expected = decltype.elements
end
end
elseif decltype.typename == "map" then
for _, child in ipairs(node) do
child.key.expected = decltype.keys
child.value.expected = decltype.values
end
end
if is_record_type(decltype) then
for _, child in ipairs(node) do
if child.key.conststr then
child.value.expected = decltype.fields[child.key.conststr]
end
end
end
end
end,
after = function(node, children)
node.known = FACT_TRUTHY
if node.expected then
local decltype = resolve_tuple_and_nominal(node.expected)
if decltype.typename == "union" then
for _, t in ipairs(decltype.types) do
local rt = resolve_tuple_and_nominal(t)
if is_lua_table_type(rt) then
node.expected = t
decltype = rt
break
end
end
end
if not is_lua_table_type(decltype) then
node.type = infer_table_literal(node, children)
return node.type
end
local is_record = is_record_type(decltype)
local is_array = is_array_type(decltype)
local is_tupletable = decltype.typename == "tupletable"
local is_map = decltype.typename == "map"
local force_array = nil
local seen_keys = {}
for i, child in ipairs(children) do
assert(child.typename == "table_item")
local cvtype = resolve_tuple(child.vtype)
local ck = child.kname
local n = node[i].key.constnum
local b = nil
if child.ktype.typename == "boolean" then
b = (node[i].key.tk == "true")
end
check_redeclared_key(node[i], node.expected_context, seen_keys, ck or n or b)
if is_record and ck then
local df = decltype.fields[ck]
if not df then
node_error(node[i], in_context(node.expected_context, "unknown field " .. ck))
else
assert_is_a(node[i], cvtype, df, "in record field", ck)
end
elseif is_tupletable and is_number_type(child.ktype) then
local dt = decltype.types[n]
if not n then
node_error(node[i], in_context(node.expected_context, "unknown index in tuple %s"), decltype)
elseif not dt then
node_error(node[i], in_context(node.expected_context, "unexpected index " .. n .. " in tuple %s"), decltype)
else
assert_is_a(node[i], cvtype, dt, in_context(node.expected_context, "in tuple"), "at index " .. tostring(n))
end
elseif is_array and is_number_type(child.ktype) then
if child.vtype.typename == "tuple" and i == #children and node[i].key_parsed == "implicit" then
for ti, tt in ipairs(child.vtype) do
assert_is_a(node[i], tt, decltype.elements, in_context(node.expected_context, "expected an array"), "at index " .. tostring(i + ti - 1))
end
else
assert_is_a(node[i], cvtype, decltype.elements, in_context(node.expected_context, "expected an array"), "at index " .. tostring(n))
end
elseif node[i].key_parsed == "implicit" then
force_array = expand_type(node[i], force_array, child.vtype)
elseif is_map then
assert_is_a(node[i], child.ktype, decltype.keys, in_context(node.expected_context, "in map key"))
assert_is_a(node[i], cvtype, decltype.values, in_context(node.expected_context, "in map value"))
else
node_error(node[i], in_context(node.expected_context, "unexpected key of type %s in table of type %s"), child.ktype, decltype)
end
end
if force_array then
node.type = a_type({
inferred_at = node,
inferred_at_file = filename,
typename = "array",
elements = force_array,
})
else
node.type = resolve_typevars_at(node, node.expected)
if node.expected == node.type and node.type.typename == "nominal" then
node.type = {
typeid = node.type.typeid,
typename = "nominal",
names = node.type.names,
found = node.type.found,
resolved = node.type.resolved,
}
end
end
if decltype.typename == "record" then
node.type.is_total, node.type.missing = total_record_check(decltype, seen_keys)
elseif decltype.typename == "map" then
node.type.is_total, node.type.missing = total_map_check(decltype, seen_keys)
end
else
node.type = infer_table_literal(node, children)
end
return node.type
end,
},
["table_item"] = {
after = function(node, children)
local kname = node.key.conststr
local ktype = children[1]
local vtype = children[2]
if node.decltype then
vtype = node.decltype
assert_is_a(node.value, children[2], node.decltype, "in table item")
end
if vtype.is_method then
vtype = shallow_copy_type(vtype)
vtype.typeid = new_typeid()
vtype.is_method = false
end
node.type = a_type({
y = node.y,
x = node.x,
typename = "table_item",
kname = kname,
ktype = ktype,
vtype = vtype,
})
return node.type
end,
},
["local_function"] = {
before = function(node)
widen_all_unions()
reserve_symbol_list_slot(node)
begin_scope(node)
end,
before_statements = function(node)
add_internal_function_variables(node)
add_function_definition_for_recursion(node)
end,
after = function(node, children)
end_function_scope(node)
local rets = get_rets(children[3])
add_var(node, node.name.tk, ensure_fresh_typeargs(a_type({
y = node.y,
x = node.x,
typename = "function",
typeargs = node.typeargs,
args = children[2],
rets = rets,
filename = filename,
})))
return node.type
end,
},
["global_function"] = {
before = function(node)
widen_all_unions()
begin_scope(node)
if node.implicit_global_function then
local typ = find_var_type(node.name.tk)
if typ then
if typ.typename == "function" then
node.is_predeclared_local_function = true
elseif not lax then
node_error(node, "cannot declare function: type of " .. node.name.tk .. " is %s", typ)
end
elseif not lax then
node_error(node, "functions need an explicit 'local' or 'global' annotation")
end
end
end,
before_statements = function(node)
add_internal_function_variables(node)
add_function_definition_for_recursion(node)
end,
after = function(node, children)
end_function_scope(node)
if node.is_predeclared_local_function then
return node.type
end
add_global(node, node.name.tk, ensure_fresh_typeargs(a_type({
y = node.y,
x = node.x,
typename = "function",
typeargs = node.typeargs,
args = children[2],
rets = get_rets(children[3]),
filename = filename,
})))
return node.type
end,
},
["record_function"] = {
before = function(node)
widen_all_unions()
begin_scope(node)
end,
before_statements = function(node, children)
add_internal_function_variables(node)
local rtype = resolve_tuple_and_nominal(resolve_typetype(children[1]))
if rtype.typename == "emptytable" then
rtype.typename = "record"
rtype.fields = {}
rtype.field_order = {}
end
if lax and rtype.typename == "unknown" then
return
end
if not is_record_type(rtype) then
node_error(node, "not a module: %s", rtype)
return
end
if node.is_method then
local selftype = get_self_type(node.fn_owner)
if not selftype then
node_error(node, "could not resolve type of self")
return
end
children[3][1] = selftype
add_var(nil, "self", selftype)
end
local fn_type = ensure_fresh_typeargs(a_type({
y = node.y,
x = node.x,
typename = "function",
is_method = node.is_method,
typeargs = node.typeargs,
args = children[3],
rets = get_rets(children[4]),
filename = filename,
}))
local open_t, open_v, owner_name = find_record_to_extend(node.fn_owner)
local open_k = owner_name .. "." .. node.name.tk
local rfieldtype = rtype.fields[node.name.tk]
if rfieldtype then
rfieldtype = resolve_tuple_and_nominal(rfieldtype)
if open_v and open_v.implemented and open_v.implemented[open_k] then
redeclaration_warning(node)
end
local ok, err = same_type(fn_type, rfieldtype)
if not ok then
if rfieldtype.typename == "poly" then
add_errs_prefixing(node, err, errors, "type signature does not match declaration: field has multiple function definitions (such polymorphic declarations are intended for Lua module interoperability)")
return
end
local shortname = node.fn_owner.type.typename == "nominal" and
show_type(node.fn_owner.type) or
owner_name
local msg = "type signature of '" .. node.name.tk .. "' does not match its declaration in " .. shortname .. ": "
add_errs_prefixing(node, err, errors, msg)
return
end
else
if lax or rtype == open_t then
rtype.fields[node.name.tk] = fn_type
table.insert(rtype.field_order, node.name.tk)
else
node_error(node, "cannot add undeclared function '" .. node.name.tk .. "' outside of the scope where '" .. owner_name .. "' was originally declared")
return
end
end
if open_v then
if not open_v.implemented then
open_v.implemented = {}
end
open_v.implemented[open_k] = true
end
node.name.type = fn_type
end,
after = function(node, _children)
end_function_scope(node)
node.type = NONE
return node.type
end,
},
["function"] = {
before = function(node)
widen_all_unions(node)
begin_scope(node)
end,
before_statements = function(node)
add_internal_function_variables(node)
end,
after = function(node, children)
end_function_scope(node)
node.type = ensure_fresh_typeargs(a_type({
y = node.y,
x = node.x,
typename = "function",
typeargs = node.typeargs,
args = children[1],
rets = children[2],
filename = filename,
}))
return node.type
end,
},
["cast"] = {
after = function(node, _children)
node.type = node.casttype
return node.type
end,
},
["paren"] = {
before = function(node)
node.e1.expected = node.expected
end,
after = function(node, children)
node.known = node.e1 and node.e1.known
node.type = resolve_tuple(children[1])
return node.type
end,
},
["op"] = {
before = function(node)
begin_scope()
if node.expected then
if node.op.op == "and" then
node.e2.expected = node.expected
elseif node.op.op == "or" then
node.e1.expected = node.expected
if not (node.e2.kind == "table_literal" and #node.e2 == 0) then
node.e2.expected = node.expected
end
end
end
end,
before_e2 = function(node)
if node.op.op == "and" then
apply_facts(node, node.e1.known)
elseif node.op.op == "or" then
apply_facts(node, facts_not(node, node.e1.known))
elseif node.op.op == "@funcall" then
if node.e1.type.typename == "function" then
local argdelta = (node.e1.op and node.e1.op.op == ":") and -1 or 0
if node.expected then
is_a(node.e1.type.rets, node.expected)
end
for i, typ in ipairs(node.e1.type.args) do
if node.e2[i + argdelta] then
node.e2[i + argdelta].expected = typ
end
end
end
elseif node.op.op == "@index" then
if node.e1.type.typename == "map" then
node.e2.expected = node.e1.type.keys
end
end
end,
after = function(node, children)
end_scope()
local a = children[1]
local b = children[3]
local orig_a = a
local orig_b = b
local ra = a and resolve_tuple_and_nominal(a)
local rb = b and resolve_tuple_and_nominal(b)
local expected = node.expected and resolve_tuple_and_nominal(node.expected)
if ra.typename == "circular_require" or (ra.def and ra.def.typename == "circular_require") then
node_error(node, "cannot dereference a type from a circular require")
node.type = INVALID
return node.type
end
if is_typetype(ra) and ra.def.typename == "record" then
ra = ra.def
end
if rb and is_typetype(rb) and rb.def.typename == "record" then
rb = rb.def
end
if node.op.op == "@funcall" then
if lax and is_unknown(a) then
if node.e1.op and node.e1.op.op == ":" and node.e1.e1.kind == "variable" then
add_unknown_dot(node, node.e1.e1.tk .. "." .. node.e1.e2.tk)
end
end
node.type = type_check_funcall(node, a, b)
elseif node.op.op == "." then
assert(node.e2.kind == "identifier")
local bnode = {
y = node.e2.y,
x = node.e2.x,
tk = node.e2.tk,
kind = "string",
conststr = node.e2.tk,
}
local btype = a_type({
y = node.e2.y,
x = node.e2.x,
tk = '"' .. node.e2.tk .. '"',
typename = "string",
})
node.type = type_check_index(node.e1, bnode, orig_a, btype)
if node.type.needs_compat and opts.gen_compat ~= "off" then
if node.e1.kind == "variable" and node.e2.kind == "identifier" then
local key = node.e1.tk .. "." .. node.e2.tk
node.kind = "variable"
node.tk = "_tl_" .. node.e1.tk .. "_" .. node.e2.tk
all_needs_compat[key] = true
end
end
elseif node.op.op == "@index" then
node.type = type_check_index(node.e1, node.e2, a, b)
elseif node.op.op == "as" then
node.type = b
elseif node.op.op == "is" then
if rb.typename == "integer" then
all_needs_compat["math"] = true
end
if ra.typename == "typetype" then
node_error(node, "can only use 'is' on variables, not types")
elseif node.e1.kind == "variable" then
node.known = Fact({ fact = "is", var = node.e1.tk, typ = b, where = node })
else
node_error(node, "can only use 'is' on variables")
end
node.type = BOOLEAN
elseif node.op.op == ":" then
if lax and (is_unknown(a) or a.typename == "typevar") then
if node.e1.kind == "variable" then
add_unknown_dot(node.e1, node.e1.tk .. "." .. node.e2.tk)
end
node.type = UNKNOWN
else
local t, e = match_record_key(a, node.e1, node.e2.conststr or node.e2.tk)
if not t then
node.type = INVALID
return node_error(node.e2, e, a == INVALID and a or resolve_tuple(orig_a))
end
node.type = t
end
elseif node.op.op == "not" then
node.known = facts_not(node, node.e1.known)
node.type = BOOLEAN
elseif node.op.op == "and" then
node.known = facts_and(node, node.e1.known, node.e2.known)
node.type = resolve_tuple(b)
elseif node.op.op == "or" and b.typename == "nil" then
node.known = nil
node.type = resolve_tuple(a)
elseif node.op.op == "or" and is_lua_table_type(ra) and b.typename == "emptytable" then
node.known = nil
node.type = resolve_tuple(a)
elseif node.op.op == "or" and
((ra.typename == "enum" and rb.typename == "string" and is_a(rb, ra)) or
(ra.typename == "string" and rb.typename == "enum" and is_a(ra, rb))) then
node.known = nil
node.type = (ra.typename == "enum" and ra or rb)
elseif node.op.op == "or" and expected and expected.typename == "union" then
node.known = facts_or(node, node.e1.known, node.e2.known)
local u = unite({ ra, rb }, true)
if u.typename == "union" then
u = validate_union(node, u)
end
node.type = u
elseif node.op.op == "or" and is_a(rb, ra) then
node.known = facts_or(node, node.e1.known, node.e2.known)
if expected then
local a_is = is_a(a, node.expected)
local b_is = is_a(b, node.expected)
if a_is and b_is then
node.type = node.expected
elseif a_is then
node.type = resolve_tuple(b)
else
node.type = resolve_tuple(a)
end
else
node.type = resolve_tuple(a)
end
node.type.tk = nil
elseif node.op.op == "==" or node.op.op == "~=" then
node.type = BOOLEAN
if is_a(b, a, true) or a.typename == "typevar" then
if node.op.op == "==" and node.e1.kind == "variable" then
node.known = Fact({ fact = "==", var = node.e1.tk, typ = b, where = node })
end
elseif is_a(a, b, true) or b.typename == "typevar" then
if node.op.op == "==" and node.e2.kind == "variable" then
node.known = Fact({ fact = "==", var = node.e2.tk, typ = a, where = node })
end
elseif lax and (is_unknown(a) or is_unknown(b)) then
node.type = UNKNOWN
else
return node_error(node, "types are not comparable for equality: %s and %s", a, b)
end
elseif node.op.arity == 1 and unop_types[node.op.op] then
a = ra
if a.typename == "union" then
a = unite(a.types, true)
end
local types_op = unop_types[node.op.op]
node.type = types_op[a.typename]
local meta_on_operator
if not node.type then
node.type, meta_on_operator = check_metamethod(node, node.op.op, a)
if not node.type then
return node_error(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' on type %s", resolve_tuple(orig_a))
end
end
if a.typename == "map" then
if a.keys.typename == "number" or a.keys.typename == "integer" then
node_warning("hint", node, "using the '#' operator on a map with numeric key type may produce unexpected results")
else
return node_error(node, "using the '#' operator on this map will always return 0")
end
end
if node.type.typename ~= "boolean" and not is_unknown(node.type) then
node.known = FACT_TRUTHY
end
if node.op.op == "~" and env.gen_target == "5.1" then
if meta_on_operator then
all_needs_compat["mt"] = true
convert_node_to_compat_mt_call(node, unop_to_metamethod[node.op.op], 1, node.e1)
else
all_needs_compat["bit32"] = true
convert_node_to_compat_call(node, "bit32", "bnot", node.e1)
end
end
elseif node.op.arity == 2 and binop_types[node.op.op] then
if node.op.op == "or" then
node.known = facts_or(node, node.e1.known, node.e2.known)
end
a = ra
b = rb
if a.typename == "union" then
a = unite(a.types, true)
end
if b.typename == "union" then
b = unite(b.types, true)
end
local types_op = binop_types[node.op.op]
node.type = types_op[a.typename] and types_op[a.typename][b.typename]
local meta_on_operator
if not node.type then
node.type, meta_on_operator = check_metamethod(node, node.op.op, a, b)
if not node.type then
return node_error(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", resolve_tuple(orig_a), resolve_tuple(orig_b))
end
end
if types_op == numeric_binop or node.op.op == ".." then
node.known = FACT_TRUTHY
end
if node.op.op == "//" and env.gen_target == "5.1" then
if meta_on_operator then
all_needs_compat["mt"] = true
convert_node_to_compat_mt_call(node, "__idiv", meta_on_operator, node.e1, node.e2)
else
local div = { y = node.y, x = node.x, kind = "op", op = an_operator(node, 2, "/"), e1 = node.e1, e2 = node.e2 }
convert_node_to_compat_call(node, "math", "floor", div)
end
elseif bit_operators[node.op.op] and env.gen_target == "5.1" then
if meta_on_operator then
all_needs_compat["mt"] = true
convert_node_to_compat_mt_call(node, binop_to_metamethod[node.op.op], meta_on_operator, node.e1, node.e2)
else
all_needs_compat["bit32"] = true
convert_node_to_compat_call(node, "bit32", bit_operators[node.op.op], node.e1, node.e2)
end
end
else
error("unknown node op " .. node.op.op)
end
return node.type
end,
},
["variable"] = {
after = function(node, _children)
if node.tk == "..." then
local va_sentinel = find_var_type("@is_va")
if not va_sentinel or va_sentinel.typename == "nil" then
return node_error(node, "cannot use '...' outside a vararg function")
end
end
if node.tk == "_G" then
node.type, node.attribute = simulate_g()
else
local use = node.is_lvalue and "lvalue" or "use"
node.type, node.attribute = find_var_type(node.tk, use)
end
if node.type and is_typetype(node.type) then
node.type = a_type({
y = node.y,
x = node.x,
typename = "nominal",
names = { node.tk },
found = node.type,
resolved = node.type,
})
end
if node.type == nil then
node.type = a_type({ typename = "unknown" })
if lax then
add_unknown(node, node.tk)
else
return node_error(node, "unknown variable: " .. node.tk)
end
end
return node.type
end,
},
["type_identifier"] = {
after = function(node, _children)
node.type, node.attribute = find_var_type(node.tk)
if node.type == nil then
if lax then
node.type = UNKNOWN
add_unknown(node, node.tk)
else
return node_error(node, "unknown variable: " .. node.tk)
end
end
return node.type
end,
},
["argument"] = {
after = function(node, children)
local t = children[1]
if not t then
t = UNKNOWN
end
if node.tk == "..." then
t = a_type({ typename = "tuple", is_va = true, t })
end
add_var(node, node.tk, t).is_func_arg = true
return node.type
end,
},
["identifier"] = {
after = function(node, _children)
node.type = node.type or NONE
return node.type
end,
},
["newtype"] = {
after = function(node, _children)
node.type = node.type or node.newtype
return node.type
end,
},
["error_node"] = {
after = function(node, _children)
node.type = INVALID
return node.type
end,
},
}
visit_node.cbs["break"] = {
after = function(node, _children)
node.type = NONE
return node.type
end,
}
visit_node.cbs["do"] = visit_node.cbs["break"]
local function after_literal(node)
node.type = a_type({
y = node.y,
x = node.x,
typename = node.kind,
tk = node.tk,
})
node.known = FACT_TRUTHY
return node.type
end
visit_node.cbs["string"] = {
after = function(node, _children)
after_literal(node)
if node.expected then
if node.expected.typename == "enum" and is_a(node.type, node.expected) then
node.type = node.expected
end
end
return node.type
end,
}
visit_node.cbs["number"] = { after = after_literal }
visit_node.cbs["integer"] = { after = after_literal }
visit_node.cbs["boolean"] = {
after = function(node, _children)
after_literal(node)
node.known = (node.tk == "true") and FACT_TRUTHY or nil
return node.type
end,
}
visit_node.cbs["nil"] = visit_node.cbs["boolean"]
visit_node.cbs["..."] = visit_node.cbs["variable"]
visit_node.cbs["argument_list"] = visit_node.cbs["variable_list"]
visit_node.cbs["expression_list"] = visit_node.cbs["variable_list"]
visit_node.after = function(node, _children)
if type(node.type) ~= "table" then
error(node.kind .. " did not produce a type")
end
if type(node.type.typename) ~= "string" then
error(node.kind .. " type does not have a typename")
end
return node.type
end
local visit_type = {
cbs = {
["string"] = {
after = function(typ, _children)
return typ
end,
},
["function"] = {
before = function(_typ, _children)
begin_scope()
end,
after = function(typ, _children)
end_scope()
return ensure_fresh_typeargs(typ)
end,
},
["record"] = {
before = function(typ, _children)
begin_scope()
for name, typ2 in fields_of(typ) do
if typ2.typename == "typetype" then
typ2.typename = "nestedtype"
local resolved, is_alias = resolve_nominal_typetype(typ2)
if is_alias then
typ2.is_alias = true
typ2.def.resolved = resolved
end
add_var(nil, name, resolved)
end
end
end,
after = function(typ, children)
end_scope()
local i = 1
if typ.typeargs then
for _, _ in ipairs(typ.typeargs) do
typ.typeargs[i] = children[i]
i = i + 1
end
end
if typ.elements then
typ.elements = children[i]
i = i + 1
end
for name, _ in fields_of(typ) do
local ftype = children[i]
if ftype.typename == "nestedtype" then
ftype.typename = "typetype"
end
typ.fields[name] = ftype
i = i + 1
end
for name, _ in fields_of(typ, "meta") do
local ftype = children[i]
if ftype.typename == "nestedtype" then
ftype.typename = "typetype"
end
typ.meta_fields[name] = ftype
i = i + 1
end
return typ
end,
},
["typearg"] = {
after = function(typ, _children)
add_var(nil, typ.typearg, a_type({
y = typ.y,
x = typ.x,
typename = "typearg",
typearg = typ.typearg,
}))
return typ
end,
},
["typevar"] = {
after = function(typ, _children)
if not find_var_type(typ.typevar) then
type_error(typ, "undefined type variable " .. typ.typevar)
end
return typ
end,
},
["nominal"] = {
after = function(typ, _children)
if typ.found then
return typ
end
local t = find_type(typ.names, true)
if t then
if t.typename == "typearg" then
typ.names = nil
typ.typename = "typevar"
typ.typevar = t.typearg
else
if t.is_alias then
t = t.def.resolved
end
if not (t.def and t.def.typename == "circular_require") then
typ.found = t
end
end
else
local name = typ.names[1]
local unresolved = get_unresolved()
unresolved.nominals[name] = unresolved.nominals[name] or {}
table.insert(unresolved.nominals[name], typ)
end
return typ
end,
},
["union"] = {
after = function(typ, _children)
return (validate_union(typ, typ))
end,
},
},
after = function(typ, _children, ret)
if type(ret) ~= "table" then
error(typ.typename .. " did not produce a type")
end
if type(ret.typename) ~= "string" then
error("type node does not have a typename")
end
return ret
end,
}
if not opts.run_internal_compiler_checks then
visit_node.after = nil
visit_type.after = nil
end
visit_type.cbs["tupletable"] = visit_type.cbs["string"]
visit_type.cbs["typetype"] = visit_type.cbs["string"]
visit_type.cbs["nestedtype"] = visit_type.cbs["string"]
visit_type.cbs["array"] = visit_type.cbs["string"]
visit_type.cbs["map"] = visit_type.cbs["string"]
visit_type.cbs["arrayrecord"] = visit_type.cbs["record"]
visit_type.cbs["enum"] = visit_type.cbs["string"]
visit_type.cbs["boolean"] = visit_type.cbs["string"]
visit_type.cbs["nil"] = visit_type.cbs["string"]
visit_type.cbs["number"] = visit_type.cbs["string"]
visit_type.cbs["integer"] = visit_type.cbs["string"]
visit_type.cbs["thread"] = visit_type.cbs["string"]
visit_type.cbs["bad_nominal"] = visit_type.cbs["string"]
visit_type.cbs["emptytable"] = visit_type.cbs["string"]
visit_type.cbs["table_item"] = visit_type.cbs["string"]
visit_type.cbs["unresolved_emptytable_value"] = visit_type.cbs["string"]
visit_type.cbs["tuple"] = visit_type.cbs["string"]
visit_type.cbs["poly"] = visit_type.cbs["string"]
visit_type.cbs["any"] = visit_type.cbs["string"]
visit_type.cbs["unknown"] = visit_type.cbs["string"]
visit_type.cbs["invalid"] = visit_type.cbs["string"]
visit_type.cbs["unresolved"] = visit_type.cbs["string"]
visit_type.cbs["none"] = visit_type.cbs["string"]
assert(ast.kind == "statements")
recurse_node(ast, visit_node, visit_type)
close_types(st[1])
check_for_unused_vars(st[1])
clear_redundant_errors(errors)
add_compat_entries(ast, all_needs_compat, env.gen_compat)
local result = {
ast = ast,
env = env,
type = module_type or BOOLEAN,
filename = filename,
warnings = warnings,
type_errors = errors,
symbol_list = symbol_list,
dependencies = dependencies,
}
env.loaded[filename] = result
table.insert(env.loaded_order, filename)
if opts.module_name then
env.modules[opts.module_name] = result.type
end
return result
end
-- Translation from internal typenames to the public numeric codes in
-- tl.typecodes; tl.get_types uses it to fill the `t` field of each type entry.
-- Several internal kinds share one public code, so the table is built from
-- groups of { code, typename, typename, ... }.
local typename_to_typecode = {}
do
   local codes = tl.typecodes
   local groups = {
      { codes.TYPE_VARIABLE, "typevar", "typearg", "unresolved_typearg", "unresolvable_typearg" },
      { codes.FUNCTION, "function" },
      { codes.ARRAY, "array" },
      { codes.MAP, "map" },
      { codes.TUPLE, "tupletable" },
      { codes.ARRAYRECORD, "arrayrecord" },
      { codes.RECORD, "record" },
      { codes.ENUM, "enum" },
      { codes.BOOLEAN, "boolean" },
      { codes.STRING, "string" },
      { codes.NIL, "nil" },
      { codes.THREAD, "thread" },
      { codes.NUMBER, "number" },
      { codes.INTEGER, "integer" },
      { codes.IS_UNION, "union" },
      { codes.NOMINAL, "nominal", "bad_nominal", "circular_require" },
      { codes.EMPTY_TABLE, "emptytable", "unresolved_emptytable_value" },
      { codes.IS_POLY, "poly" },
      { codes.ANY, "any" },
      { codes.INVALID, "invalid" },
      -- Internal bookkeeping kinds all report as UNKNOWN.
      { codes.UNKNOWN, "unknown", "none", "tuple", "table_item", "unresolved", "typetype", "nestedtype" },
   }
   for _, group in ipairs(groups) do
      local code = group[1]
      for idx = 2, #group do
         typename_to_typecode[group[idx]] = code
      end
   end
end
-- Typenames that the position index built by tl.get_types ignores:
-- store() returns early for any type whose typename is in this set.
local skip_types = {}
for _, skipped in ipairs({ "none", "tuple", "table_item", "unresolved", "typetype", "nestedtype" }) do
   skip_types[skipped] = true
end
-- Builds (or extends) a type report for a type-checked result.
--
-- result: a value produced by tl.type_check / tl.process_string; its `ast`,
--   `filename`, `symbol_list` and `env.globals` are read here.
-- trenv: optional report environment returned by a previous call. Passing it
--   back lets several files share one type-number space and one report.
--
-- Returns the report `tr` (fields `by_pos`, `types`, `symbols`, `globals`)
-- and the report environment to pass into the next call.
function tl.get_types(result, trenv)
   local filename = result.filename or "?"

   -- Sets index 0 to false on a sequence table; presumably a marker that lets
   -- a serializer emit it as an array rather than an object (TODO confirm
   -- against the consumer of this report).
   local function mark_array(x)
      local arr = x
      arr[0] = false
      return x
   end

   -- Start a fresh report unless the caller is extending an earlier one.
   if not trenv then
      trenv = {
         next_num = 1,
         typeid_to_num = {},
         tr = {
            by_pos = {},
            types = {},
            symbols = mark_array({}),
            globals = {},
         },
      }
   end

   local tr = trenv.tr
   local typeid_to_num = trenv.typeid_to_num

   local get_typenum

   -- Records a function type's argument and return type numbers, and whether
   -- it is variadic, on the type-info entry `ti`.
   local function store_function(ti, rt)
      local args = {}
      for _, fnarg in ipairs(rt.args) do
         table.insert(args, mark_array({ get_typenum(fnarg), nil }))
      end
      ti.args = mark_array(args)
      local rets = {}
      for _, fnarg in ipairs(rt.rets) do
         table.insert(rets, mark_array({ get_typenum(fnarg), nil }))
      end
      ti.rets = mark_array(rets)
      ti.vararg = not not rt.is_va
   end

   -- Returns the report number for type `t`. On first sight it registers `t`
   -- and, recursively, every type it refers to. Numbers are keyed by typeid,
   -- so each type is described once per report.
   get_typenum = function(t)
      assert(t.typeid)
      local n = typeid_to_num[t.typeid]
      if n then
         return n
      end

      n = trenv.next_num

      -- Describe a type definition by its definition, and a single-element
      -- tuple by its only element.
      local rt = t
      if is_typetype(rt) then
         rt = rt.def
      elseif rt.typename == "tuple" and #rt == 1 then
         rt = rt[1]
      end

      local ti = {
         t = assert(typename_to_typecode[rt.typename]),
         str = show_type(t, true),
         file = t.filename,
         y = t.y,
         x = t.x,
      }
      -- Register before recursing, so that a type referring back to itself
      -- hits the typeid_to_num lookup above instead of recursing forever.
      tr.types[n] = ti
      typeid_to_num[t.typeid] = n
      trenv.next_num = trenv.next_num + 1

      if t.found then
         ti.ref = get_typenum(t.found)
      end
      if t.resolved then
         rt = t
      end
      assert(not is_typetype(rt))

      if is_record_type(rt) then
         local r = {}
         for _, k in ipairs(rt.field_order) do
            local v = rt.fields[k]
            r[k] = get_typenum(v)
         end
         ti.fields = r
      end

      if is_array_type(rt) then
         ti.elements = get_typenum(rt.elements)
      end

      if rt.typename == "map" then
         ti.keys = get_typenum(rt.keys)
         ti.values = get_typenum(rt.values)
      elseif rt.typename == "enum" then
         ti.enums = mark_array(sorted_keys(rt.enumset))
      elseif rt.typename == "function" then
         store_function(ti, rt)
      elseif rt.typename == "poly" or rt.typename == "union" or rt.typename == "tupletable" then
         local tis = {}
         for _, pt in ipairs(rt.types) do
            table.insert(tis, get_typenum(pt))
         end
         ti.types = mark_array(tis)
      end

      return n
   end

   local visit_node = { allow_missing_cbs = true }
   local visit_type = { allow_missing_cbs = true }

   -- by_pos[filename][y][x] = type number of the node or type at that position.
   local ft = {}
   tr.by_pos[filename] = ft

   local function store(y, x, typ)
      if not typ or skip_types[typ.typename] then
         return
      end
      local yt = ft[y]
      if not yt then
         yt = {}
         ft[y] = yt
      end
      yt[x] = get_typenum(typ)
   end

   visit_node.after = function(node)
      store(node.y, node.x, node.type)
   end

   visit_type.after = function(typ)
      store(typ.y or 0, typ.x or 0, typ)
   end

   recurse_node(result.ast, visit_node, visit_type)

   -- Types without a source position were filed under line 0; drop them.
   tr.by_pos[filename][0] = nil

   -- First pass: flag `skip` on both "@{" and "@}" markers of any scope
   -- whose own level holds no typed symbols, so empty scopes are omitted.
   do
      local n = 0
      local p = 0
      local n_stack, p_stack = {}, {}
      local level = 0
      for i, s in ipairs(result.symbol_list) do
         if s.typ then
            n = n + 1
         elseif s.name == "@{" then
            level = level + 1
            n_stack[level], p_stack[level] = n, p
            n, p = 0, i
         else
            if n == 0 then
               result.symbol_list[p].skip = true
               s.skip = true
            end
            n, p = n_stack[level], p_stack[level]
            level = level - 1
         end
      end
   end

   -- Second pass: append every surviving symbol as { y, x, name, id }.
   -- In a "@{" entry, id is patched to the index of its matching "@}".
   -- In a "@}" entry, id is the index just before its "@{", which lets
   -- tl.symbols_in_scope jump over a closed scope in one step.
   do
      local stack = {}
      local level = 0
      -- `i` is the index each symbol will occupy in tr.symbols. A shared
      -- trenv may already hold symbols from earlier files, so counting must
      -- start after them; otherwise the scope links point at, and overwrite,
      -- a previous file's entries.
      local i = #tr.symbols
      for _, s in ipairs(result.symbol_list) do
         if not s.skip then
            i = i + 1
            local id
            if s.typ then
               id = get_typenum(s.typ)
            elseif s.name == "@{" then
               level = level + 1
               stack[level] = i
               id = -1
            else
               local other = stack[level]
               level = level - 1
               tr.symbols[other][4] = i
               id = other - 1
            end
            local sym = mark_array({ s.y, s.x, s.name, id })
            table.insert(tr.symbols, sym)
         end
      end
   end

   -- Report global variables, skipping internal "@"-prefixed entries.
   local gkeys = sorted_keys(result.env.globals)
   for _, name in ipairs(gkeys) do
      if name:sub(1, 1) ~= "@" then
         local var = result.env.globals[name]
         tr.globals[name] = get_typenum(var.t)
      end
   end

   return tr, trenv
end
-- Returns a table mapping each symbol name visible at position (y, x) to its
-- type number, using the `symbols` list built by tl.get_types.
--
-- The walk starts at the symbol found by binary_search for (y, x), assumed
-- to be the last one at or before that position (TODO confirm against
-- binary_search), and goes backwards:
--   * "@}" jumps past the whole closed scope (its id is the index just
--     before the matching "@{");
--   * "@{" is stepped over, which leaves the current scope;
--   * any other entry is a symbol in scope.
-- The first binding met for a name is the nearest declaration, so later
-- (older or outer) bindings must not overwrite it; this keeps shadowing right.
function tl.symbols_in_scope(tr, y, x)
   local function find(symbols, at_y, at_x)
      local function le(a, b)
         return a[1] < b[1] or
         (a[1] == b[1] and a[2] <= b[2])
      end
      return binary_search(symbols, { at_y, at_x }, le) or 0
   end

   local ret = {}
   local n = find(tr.symbols, y, x)
   local symbols = tr.symbols

   while n >= 1 do
      local s = symbols[n]
      local symbol_name = s[3]
      if symbol_name == "@{" then
         n = n - 1
      elseif symbol_name == "@}" then
         n = s[4]
      else
         if ret[symbol_name] == nil then
            ret[symbol_name] = s[4]
         end
         n = n - 1
      end
   end

   return ret
end
-- Reads all remaining content from `fd` and strips a leading UTF-8
-- byte-order mark.
-- Returns the content (and a nil error), or nil plus the error reported by
-- fd:read when the read fails.
local function read_full_file(fd)
   -- Decimal escapes: Lua 5.1 does not understand "\x" escapes.
   local bom = "\239\187\191"
   local content, err = fd:read("*a")
   if not content then
      return nil, err
   end
   if content:sub(1, #bom) == bom then
      content = content:sub(#bom + 1)
   end
   return content, err
end
-- Reads, parses and type-checks a single file.
--
-- filename: path of the file. It is used to open the file, as the env cache
--   key, and to choose Teal or Lua from its extension.
-- env: optional environment; a result already cached for `filename` in
--   env.loaded is returned as-is.
-- module_name: optional, forwarded to tl.process_string.
-- fd: optional open handle to read instead of opening `filename`. It is
--   closed after reading in either case.
-- Returns the result of tl.process_string, or nil plus an error message.
tl.process = function(filename, env, module_name, fd)
   if env and env.loaded and env.loaded[filename] then
      return env.loaded[filename]
   end

   local input, err
   if not fd then
      fd, err = io.open(filename, "rb")
      if not fd then
         return nil, "could not open " .. filename .. ": " .. err
      end
   end

   input, err = read_full_file(fd)
   fd:close()
   if not input then
      -- A custom handle may fail without a message; never concatenate nil.
      return nil, "could not read " .. filename .. ": " .. (err or "unknown error")
   end

   -- Extension is matched case-insensitively, so "FOO.TL" is treated like "foo.tl".
   local extension = filename:match("%.([a-zA-Z]+)$")
   extension = extension and extension:lower()

   local is_lua
   if extension == "tl" then
      is_lua = false
   elseif extension == "lua" then
      is_lua = true
   else
      -- No recognized extension: a "#!...lua..." shebang marks Lua source.
      is_lua = input:match("^#![^\n]*lua[^\n]*\n") ~= nil
   end

   return tl.process_string(input, is_lua, env, filename, module_name)
end
-- Parses and type-checks source text.
--
-- input: the source code.
-- is_lua: truthy to check in lax (Lua) mode; also used when creating a
--   default env.
-- env: optional environment (a fresh one is created from tl.init_env).
-- filename: optional; when given, it also derives module_name if that is
--   absent. Results are cached under it in env.loaded.
-- module_name: optional module name for the checked unit.
--
-- If there are syntax errors and env.keep_going is not set, type checking is
-- skipped and a failed result (ok = false) is cached and returned.
-- Otherwise it returns the result of tl.type_check, with `syntax_errors`
-- attached.
function tl.process_string(input, is_lua, env, filename, module_name)
   if filename and not module_name then
      module_name = filename_to_module_name(filename)
   end

   env = env or tl.init_env(is_lua)

   local cached = env.loaded and env.loaded[filename]
   if cached then
      return cached
   end

   filename = filename or ""

   local program, syntax_errors = tl.parse(input, filename)

   local stop_early = not env.keep_going and #syntax_errors > 0
   if stop_early then
      local failed = {
         ok = false,
         filename = filename,
         module_name = module_name,
         type = BOOLEAN,
         type_errors = {},
         syntax_errors = syntax_errors,
         env = env,
      }
      env.loaded[filename] = failed
      table.insert(env.loaded_order, filename)
      return failed
   end

   local checked = tl.type_check(program, {
      filename = filename,
      module_name = module_name,
      lax = is_lua,
      gen_compat = env.gen_compat,
      gen_target = env.gen_target,
      env = env,
   })
   checked.syntax_errors = syntax_errors
   return checked
end
-- Compiles Teal source text to Lua source text.
--
-- Returns the generated code and the process result. If the input has no
-- AST or has syntax errors, it returns nil and the result instead. Any code
-- generation error is stored on the result as `gen_error`.
tl.gen = function(input, env)
   env = env or assert(tl.init_env(), "Default environment initialization failed")

   local result = tl.process_string(input, false, env)

   local generatable = result.ast and #result.syntax_errors == 0
   if not generatable then
      return nil, result
   end

   local code, gen_error = tl.pretty_print_ast(result.ast, env.gen_target)
   result.gen_error = gen_error
   return code, result
end
-- Package searcher installed by tl.loader.
--
-- It locates `module_name` with tl.search_module, parses it and runs
-- tl.type_check in the shared tl.package_loader_env. The check's result,
-- including any type errors, is not inspected here. It then compiles the AST
-- to Lua and returns a loader function that runs the chunk and stores its
-- value in package.loaded.
-- When the module is not found, or cannot be read, it returns the list of
-- tried paths as a message string, which is how Lua searchers report a miss.
-- It raises on syntax errors and on generator output that fails to load.
local function tl_package_loader(module_name)
   local found_filename, fd, tried = tl.search_module(module_name, false)
   if found_filename then
      local input = read_full_file(fd)
      -- Close before checking the result so a failed read does not leak the handle.
      fd:close()
      if not input then
         return table.concat(tried, "\n\t")
      end
      local program, errs = tl.parse(input, found_filename)
      if #errs > 0 then
         error(found_filename .. ":" .. errs[1].y .. ":" .. errs[1].x .. ": " .. errs[1].msg)
      end
      local lax = not not found_filename:match("lua$")
      local env = tl.package_loader_env
      if not env then
         tl.package_loader_env = tl.init_env(lax)
         env = tl.package_loader_env
      end
      tl.type_check(program, {
         lax = lax,
         filename = found_filename,
         module_name = module_name,
         env = env,
         run_internal_compiler_checks = false,
      })
      local code = assert(tl.pretty_print_ast(program, env.gen_target, true))
      local chunk, err = load(code, "@" .. found_filename, "t")
      if chunk then
         return function()
            local ret = chunk()
            package.loaded[module_name] = ret
            return ret
         end
      else
         error("Internal Compiler Error: Teal generator produced invalid Lua. Please report a bug at https://github.com/teal-language/tl\n\n" .. err)
      end
   end
   return table.concat(tried, "\n\t")
end
-- Registers the Teal package searcher so that `require` can load .tl modules.
function tl.loader()
   -- Lua 5.2+ calls the searcher list package.searchers; Lua 5.1 calls it package.loaders.
   local searchers = package.searchers or package.loaders
   -- Insert at position 2 so the standard first searcher (package.preload) still runs first.
   table.insert(searchers, 2, tl_package_loader)
end
do
   -- Maps Lua _VERSION strings to the code-generation target used for them.
   -- Lua 5.2 shares the 5.1 target.
   local target_for_version = {
      ["Lua 5.1"] = "5.1",
      ["Lua 5.2"] = "5.1",
      ["Lua 5.3"] = "5.3",
      ["Lua 5.4"] = "5.4",
   }

   -- Returns the generation target for a Lua _VERSION string, or no value
   -- when the version is not recognized.
   function tl.target_from_lua_version(str)
      local target = target_for_version[str]
      if target then
         return target
      end
   end
end
-- Picks the Teal environment used by tl.load.
--
-- With no custom Lua environment table, every call shares
-- tl.package_loader_env. Otherwise each distinct `env_tbl` gets its own Teal
-- env, created on first use. These are cached in tl.load_envs, whose keys are
-- weak so the cache does not keep env tables alive.
local function env_for(lax, env_tbl)
   if not env_tbl then
      tl.package_loader_env = tl.package_loader_env or tl.init_env(lax)
      return tl.package_loader_env
   end

   local envs = tl.load_envs
   if not envs then
      envs = setmetatable({}, { __mode = "k" })
      tl.load_envs = envs
   end

   local found = envs[env_tbl]
   if not found then
      found = tl.init_env(lax)
      envs[env_tbl] = found
   end
   return found
end
-- Teal counterpart of Lua's `load`: parses `input`, type-checks it, compiles
-- it for the running Lua version and loads the result.
--
-- chunkname: used in messages; a name ending in "lua" selects lax mode.
-- mode: as for Lua's `load`, plus "c". With "c", any type error makes the
--   call fail with the collected messages.
-- ...: forwarded to `load` (e.g. an environment table). The first of these
--   values also selects a Teal environment via env_for.
-- Returns the loaded chunk, or nil plus an error message.
tl.load = function(input, chunkname, mode, ...)
   local program, errs = tl.parse(input, chunkname)
   if #errs > 0 then
      return nil, (chunkname or "") .. ":" .. errs[1].y .. ":" .. errs[1].x .. ": " .. errs[1].msg
   end
   local lax = chunkname and not not chunkname:match("lua$")
   if not tl.package_loader_env then
      tl.package_loader_env = tl.init_env(lax)
   end
   local result = tl.type_check(program, {
      lax = lax,
      -- Without a chunkname, label the chunk with its first 45 characters.
      filename = chunkname or ("string \"" .. input:sub(1, 45) .. (#input > 45 and "..." or "") .. "\""),
      env = env_for(lax, ...),
      run_internal_compiler_checks = false,
   })
   if mode and mode:match("c") then
      if #result.type_errors > 0 then
         local errout = {}
         for _, err in ipairs(result.type_errors) do
            table.insert(errout, err.filename .. ":" .. err.y .. ":" .. err.x .. ": " .. (err.msg or ""))
         end
         return nil, table.concat(errout, "\n")
      end
      mode = mode:gsub("c", "")
      -- A mode of just "c" would leave "", which Lua's load rejects for every
      -- chunk; fall back to load's default mode instead.
      if mode == "" then
         mode = nil
      end
   end
   local code, err = tl.pretty_print_ast(program, tl.target_from_lua_version(_VERSION), true)
   if not code then
      return nil, err
   end
   return load(code, chunkname, mode, ...)
end
-- Export the compiler API table as the module's value.
return tl