rutile/compiler/parse.c
2025-01-30 23:37:17 -03:00

670 lines
16 KiB
C

/* Recursive descent parser + Pratt parser (for expressions)
* TODO:
* - DRY up the code that handles lists of tokens; there are about three almost identical
* functions for that.
* - Use an arena for the AST nodes, so all of them can be freed with a single call
* when we no longer need the AST.
*/
#include <stdlib.h>
#include "ast.h"
#include "pre.h"
#include "parse.h"
#include "lex.h"
#include "state.h"
#include "messages.h"
#include "libs/stb_ds.h"
/* Implementation limit on the number of statements in one block; enforced in
 * `stmt_list_until`. */
#define MAX_STMTS_IN_BLOCK 2000
/* Implementation limit on the number of parameters of one proc; enforced in `proc_arglist`. */
#define MAX_PROC_ARG_COUNT 127
/* Minimum precedence handed to `expr` when starting a fresh expression. It equals the lowest
 * precedence listed in `OperatorTable`, so no listed operator is rejected at the top level. */
#define EXPR_INIT_PREC 1
/* Consume the next token from `lexer` and hand it to `lex_match` together with the expected
 * token id `tokt`.
 *
 * The scratch token uses a deliberately unusual name so that it cannot capture an identifier
 * that appears in the macro arguments. `lexer` is evaluated twice, so pass a side-effect-free
 * expression.
 */
#define next_match(lexer, tokt) \
    do { \
        LexToken next_match_tok_ = lex_scan((lexer)); \
        lex_match((lexer), &next_match_tok_, (tokt)); \
    } while (0)
/* Scans a token (mutating `t`), and if its id matches `ttype`,
 * it executes the code block that follows the macro invocation. Otherwise, the scanned token
 * gets put back (so a next call to `lex_scan` can pick it up).
 *
 * The expansion is an `if (...) { ... } else` with the `else` left open, so the invocation
 * must be followed by a statement or a braced block, and must not itself be used as the
 * unbraced body of another `if`. All arguments are parenthesized; `t` and `ps` are evaluated
 * more than once.
 */
#define matchopt(t, ttype, ps) \
    if (((t) = lex_scan((ps)->lexer)).id != (ttype)) { \
        lex_backup((ps)->lexer, (t)); \
    } else
/* Token classification helpers used by the expression parser.
 *
 * The range checks rely on the declaration order of the token ids (declared outside this
 * file). Every argument is parenthesized so that compound expressions classify correctly;
 * each macro still evaluates `t` more than once, so pass a side-effect-free expression.
 */
#define token_is_binop(t) ((t) >= T_PLUS && (t) <= T_NOTEQUAL)
#define token_is_atom(t) ((t) >= T_IDENT && (t) <= T_DECNUMBER)
#define token_is_unary(t) ((t) == T_MINUS || (t) == T_LOGNOT)
/* A token that can begin an expression: a unary operator or an atom. */
#define token_is_expr_start(t) (token_is_unary(t) || token_is_atom(t))
/* Report a parse error at the lexer's current location and mark the parse as failed by
 * clearing `ok` on `ctx` (a `ParserState *`). The remaining arguments are forwarded to
 * `error` (presumably a printf-style format plus arguments; confirm in messages.h).
 * Parsing is not aborted here; each caller decides how to recover. */
#define parse_error(ctx, ...) \
do { error((ctx)->cm, &((ctx)->lexer->cur_loc), __VA_ARGS__); (ctx)->ok = false; } while (0)
/* Return type of `ident_type_pair`: an `AstIdentTypePair` that may be absent. */
typedef Optional(AstIdentTypePair) OptAstIdentTypePair;
/* Binding information for one binary operator, as stored in `OperatorTable`. */
typedef struct {
int pred; /* precedence level; compared against `minprec` in `expr` (NOTE(review): name probably meant "prec") */
bool left_assoc; /* false if right assoc... */
} OperatorPrec;
/* Operator table specifying the precedence and associativity
 * of each operator, used by the expression parser.
 * The precedence goes from lower to higher.
 *
 * The table is indexed by token id. Ids that are not listed but fall below the largest
 * listed index are zero-initialized, so their `pred` is 0, which is below `EXPR_INIT_PREC`
 * and makes `expr` stop on them. Indexing is only in bounds for ids up to the largest
 * listed one; `expr` guards the lookup with `token_is_binop`.
 * NOTE(review): `T_BAR` shares the level of `T_STAR`, so it is presumably division; confirm
 * against the lexer.
 */
const OperatorPrec OperatorTable[] = {
[T_LOGOR] = {1, true},
[T_LOGAND] = {2, true},
[T_LESSTHAN] = {3, true},
[T_GREATTHAN] = {3, true},
[T_LOGICEQUAL] = {3, true},
[T_NOTEQUAL] = {3, true},
[T_PLUS] = {4, true},
[T_MINUS] = {4, true},
[T_STAR] = {5, true},
[T_BAR] = {5, true},
};
/* Forward declarations for productions that are used before their definitions below. */

/* Parse an expression whose operators all have precedence >= `minprec`. */
static Ast *
expr(ParserState *ps, int minprec);
/* Parse a comma-separated list of expressions. */
static Ast *
expr_comma_list(ParserState *ps);
/* Parse one statement; `token` is its already-scanned first token. */
static Ast *
stmt(ParserState *ps, LexToken token);
/* Parse a braced statement list that ends at any of the `len` tokens in `end_markers`. */
static Ast *
stmt_list_until(ParserState *ps, bool putback, const enum LexTokenId *end_markers, isize len);
/* Allocate a zero-initialized AST node of kind `type`, tagged with source location `loc`.
 *
 * Never returns a null pointer: every caller in this file dereferences the result
 * immediately, so an allocation failure aborts the process instead of leading to a null
 * dereference. The caller owns the returned node.
 */
static Ast *
make_tree(enum AstType type, Location loc)
{
    Ast *tree = calloc(1, sizeof *tree);

    if (tree == NULL)
        abort();
    tree->type = type;
    tree->loc = loc;
    return tree;
}
/* Build an AST_BINEXPR node at `loc` for operator token `op` applied to `lhs` and `rhs`.
 * The operator is stored as the token's printable name taken from `TokenIdStr`. */
static Ast *
make_binop(enum LexTokenId op, Location loc, Ast *lhs, Ast *rhs)
{
    Ast *node = make_tree(AST_BINEXPR, loc);

    node->bin.left = lhs;
    node->bin.right = rhs;
    node->bin.op = Str_from_c(TokenIdStr[op]);
    return node;
}
/* Build an AST_IDENT node at `loc` that carries the identifier `ident`. */
static Ast *
make_ident_node(Str ident, Location loc)
{
    Ast *node = make_tree(AST_IDENT, loc);

    node->ident = ident;
    return node;
}
/* Parse `IDENT ":" ["var"] IDENT`, i.e. a name followed by its (optionally `var`-qualified)
 * type.
 *
 * Returns Some(pair) on success. If the token in type position is not an identifier, an
 * error is reported and None is returned. Without the `var` qualifier the pair's kind is
 * SymLet.
 */
static OptAstIdentTypePair
ident_type_pair(ParserState *ps)
{
    AstIdentTypePair pair = {0};

    /* name */
    LexToken name_tok = lex_scan(ps->lexer);
    lex_match(ps->lexer, &name_tok, T_IDENT);
    pair.ident = name_tok.ident;
    pair.ident_loc = name_tok.loc;

    next_match(ps->lexer, T_COLON);

    /* optional `var` qualifier; anything else is pushed back for the type */
    LexToken qual_tok = lex_scan(ps->lexer);
    if (qual_tok.id != T_VAR) {
        pair.kind = SymLet;
        lex_backup(ps->lexer, qual_tok);
    } else {
        pair.kind = SymVar;
    }

    /* type name */
    LexToken type_tok = lex_scan(ps->lexer);
    if (type_tok.id != T_IDENT) {
        parse_error(ps, "expected a type, got %s instead", TokenIdStr[type_tok.id]);
        return None(OptAstIdentTypePair);
    }
    pair.dtype = type_tok.ident;
    pair.dtype_loc = type_tok.loc;

    return Some(OptAstIdentTypePair, pair);
}
/* Parse a non-empty, comma-separated list of `ident: type` pairs (proc parameters).
 * A trailing comma is allowed. The first token that does not belong to the list is pushed
 * back so the caller can match the closing parenthesis.
 *
 * Returns the vector of parsed pairs (owned by the caller), or nil after an error. On the
 * error paths the partially built vector is released rather than leaked.
 */
static Vec(AstIdentTypePair)
proc_arglist(ParserState *ps)
{
    Vec(AstIdentTypePair) args = nil;
    LexToken next;

    for (;;) {
        OptAstIdentTypePair oitp = ident_type_pair(ps);
        if (!oitp.ok)
            goto fail;
        if (arrlen(args) + 1 > MAX_PROC_ARG_COUNT) {
            parse_error(ps, "more than %d (implementation limit) proc arguments", MAX_PROC_ARG_COUNT);
            goto fail;
        }
        arrput(args, oitp.val);
        next = lex_scan(ps->lexer);
        /* do we have a comma? if not, we reached the end of the list */
        if (next.id != T_COMMA)
            break;
        /* check if we have another parameter next to this comma, we do this
         * to allow a trailing comma
         */
        next = lex_scan(ps->lexer);
        if (next.id != T_IDENT)
            break;
        lex_backup(ps->lexer, next);
    }
    trace("token in arglist out: %s\n", TokenIdStr[next.id]);
    lex_backup(ps->lexer, next);
    /* At least one pair was appended before leaving the loop, so `args` is never empty here. */
    return args;

fail:
    arrfree(args);
    return nil;
}
/* Parse a proc definition; the caller has already consumed the `proc` keyword.
 *
 * Shape: IDENT ["*"] "(" [parameters] ")" [":" IDENT] block
 * A `*` after the name marks the proc as public. The body is a braced statement list.
 * Returns the AST_PROCDEF node.
 */
static Ast *
proc_decl(ParserState *ps)
{
    const enum LexTokenId body_end[] = {T_RBRACE};

    LexToken name_tok = lex_scan(ps->lexer);
    lex_match(ps->lexer, &name_tok, T_IDENT);

    Ast *node = make_tree(AST_PROCDEF, ps->lexer->cur_loc);
    node->proc.name = name_tok.ident;
    trace("proc name: %s\n", node->proc.name.s);

    /* optional export marker, then the opening parenthesis */
    LexToken tok = lex_scan(ps->lexer);
    if (tok.id == T_STAR) {
        node->proc.ispublic = true;
        tok = lex_scan(ps->lexer);
    }
    lex_match(ps->lexer, &tok, T_LPAREN);

    /* parameters, unless the list is empty */
    tok = lex_scan(ps->lexer);
    if (tok.id != T_RPAREN) {
        lex_backup(ps->lexer, tok);
        node->proc.args = proc_arglist(ps);
        tok = lex_scan(ps->lexer);
    }
    lex_match(ps->lexer, &tok, T_RPAREN);

    /* optional return type */
    tok = lex_scan(ps->lexer);
    if (tok.id != T_COLON) {
        lex_backup(ps->lexer, tok);
    } else {
        tok = lex_scan(ps->lexer);
        lex_match(ps->lexer, &tok, T_IDENT);
        node->proc.rettype = make_ident_node(tok.ident, ps->lexer->cur_loc);
    }

    node->proc.body = stmt_list_until(ps, false, body_end, 1);
    return node;
}
/* Parse a call to `ident`. If `ate_lp` is true the caller has already consumed the opening
 * parenthesis; otherwise it is matched here. Arguments are parsed only when the token after
 * the parenthesis can start an expression. Returns the AST_PROCCALL node. */
static Ast *
function_call(ParserState *ps, Str ident, bool ate_lp)
{
    Ast *call = make_tree(AST_PROCCALL, ps->lexer->cur_loc);
    call->call = (AstProcCall){ .name = ident };

    if (!ate_lp)
        next_match(ps->lexer, T_LPAREN);

    /* peek one token to decide whether an argument list follows */
    LexToken peek = lex_scan(ps->lexer);
    lex_backup(ps->lexer, peek);
    if (token_is_expr_start(peek.id))
        call->call.args = expr_comma_list(ps);

    next_match(ps->lexer, T_RPAREN);
    trace("function call to: %s\n", ident.s);
    return call;
}
/* Build an AST_VARASSIGN node at `loc` that assigns to `ident` the expression parsed from
 * the token stream. The caller has already consumed the `=`. */
static Ast *
variable_assign(ParserState *ps, Str ident, Location loc)
{
    Ast *assign = make_tree(AST_VARASSIGN, loc);

    assign->varassgn.name = ident;
    assign->varassgn.expr = expr(ps, EXPR_INIT_PREC);
    return assign;
}
/* Parse a statement that starts with the identifier `ident`: an assignment when the next
 * token is `=`, otherwise a proc call (whose opening parenthesis is still unread). */
static Ast *
funccall_or_assignment(ParserState *ps, Str ident)
{
    LexToken tok = lex_scan(ps->lexer);

    if (tok.id == T_EQUAL)
        return variable_assign(ps, ident, ps->lexer->cur_loc);

    lex_backup(ps->lexer, tok);
    return function_call(ps, ident, false);
}
/* Parse a `let`/`var`/`const` declaration. The caller has already consumed the keyword and
 * passes it as `decl_kind`.
 *
 * Shape: IDENT [":" IDENT] ["=" expr]
 * At least one of the type annotation and the initializer must be present.
 *
 * Returns the AST_VARDECL node, or nil after reporting an error. On the error paths the
 * partially built declaration node is freed instead of being leaked.
 */
static Ast *
variable_decl(ParserState *ps, enum LexTokenId decl_kind)
{
    static const enum SymbolKind Token2SemaVarKind[] = {
        [T_LET] = SymLet,
        [T_VAR] = SymVar,
        [T_CONST] = SymConst,
    };

    Assert(decl_kind == T_LET || decl_kind == T_VAR || decl_kind == T_CONST);
    Ast *decl = make_tree(AST_VARDECL, ps->lexer->cur_loc);
    LexToken token = lex_scan(ps->lexer);
    lex_match(ps->lexer, &token, T_IDENT);
    decl->var = (AstVarDecl) {
        .name = token.ident,
        .kind = Token2SemaVarKind[decl_kind],
    };
    /* type */
    matchopt(token, T_COLON, ps) {
        token = lex_scan(ps->lexer);
        if (token.id != T_IDENT) {
            parse_error(ps, "expected a type, got %s instead", TokenIdStr[token.id]);
            /* no child nodes have been attached yet, so freeing the node is enough */
            free(decl);
            return nil;
        }
        decl->var.datatype = make_ident_node(token.ident, token.loc);
    }
    /* assignment expression */
    matchopt(token, T_EQUAL, ps) {
        trace("assignment of decl here\n");
        decl->var.expr = expr(ps, EXPR_INIT_PREC);
    }
    trace(
        "var decl %s %s: %s\n",
        TokenIdStr[decl_kind],
        decl->var.name.s,
        decl->var.datatype != nil ? (char *)decl->var.datatype->ident.s : "(no type)"
    );
    /* if there's no type there must be an expr */
    /* TODO: move to semantic analysis phase? */
    if (decl->var.datatype == nil && decl->var.expr == nil) {
        parse_error(
            ps,
            "'%s' declaration must have an assignment expression if no type is specified, "
            "but neither a type nor expression was supplied",
            TokenIdStr[decl_kind]
        );
        /* both children are nil here, so only the node itself needs releasing */
        free(decl);
        return nil;
    }
    return decl;
}
/* Parse the remainder of a `return` statement. The returned value is parsed only when the
 * next token can start an expression; otherwise the node's `ret` stays nil. */
static Ast *
return_stmt(ParserState *ps)
{
    Ast *ret = make_tree(AST_RETURN, ps->lexer->cur_loc);

    /* peek one token to see whether a value follows */
    LexToken peek = lex_scan(ps->lexer);
    lex_backup(ps->lexer, peek);
    if (token_is_expr_start(peek.id))
        ret->ret = expr(ps, EXPR_INIT_PREC);

    return ret;
}
/* Build the node for a `break` statement; no further tokens are consumed. */
static Ast *
break_stmt(ParserState *ps)
{
    Ast *brk = make_tree(AST_BREAK, ps->lexer->cur_loc);

    return brk;
}
/* Parse the expression that follows a `discard` keyword and wrap it in an AST_DISCARD node. */
static Ast *
discard_stmt(ParserState *ps)
{
    Ast *node = make_tree(AST_DISCARD, ps->lexer->cur_loc);

    node->discard.expr = expr(ps, EXPR_INIT_PREC);
    return node;
}
/* Parse a pragma's bracket pair. Only `[` immediately followed by `]` is matched; nothing
 * between the brackets is read. Returns the AST_PRAGMA node. */
static Ast *
parse_pragma(ParserState *ps)
{
    Ast *pragma = make_tree(AST_PRAGMA, ps->lexer->cur_loc);

    LexToken lbracket = lex_scan(ps->lexer);
    lex_match(ps->lexer, &lbracket, T_LBRACKET);

    LexToken rbracket = lex_scan(ps->lexer);
    lex_match(ps->lexer, &rbracket, T_RBRACKET);

    return pragma;
}
/* A declaration "decorated" with a pragma.
 *
 * Parses the pragma, then the proc or variable declaration it applies to, and attaches the
 * declaration to the pragma node. Any other kind of node after the pragma is an error: it is
 * reported, the pragma node is freed (it was previously leaked), and nil is returned.
 */
static Ast *
decorated_decl(ParserState *ps)
{
    Ast *attr = parse_pragma(ps);
    LexToken next = lex_scan(ps->lexer);

    switch (next.id) {
    case T_PROC:
        attr->pragma.node = proc_decl(ps);
        break;
    case T_CONST:
    case T_LET:
    case T_VAR:
        attr->pragma.node = variable_decl(ps, next.id);
        break;
    default:
        parse_error(ps, "node of kind '%s' cannot have a pragma", TokenIdStr[next.id]);
        free(attr);
        return nil;
    }
    return attr;
}
/* Parse an `if` construct: condition, true branch, any number of `elif` branches and an
 * optional final `else` branch. The caller has already consumed the `if` token.
 *
 * Every body is parsed by `stmt_list_until` with `putback` set, so the token that ended the
 * body is pushed back and re-scanned here to decide what follows.
 *
 * Returns the AST_IF node, or nil after reporting an error.
 *
 * NOTE(review): bodies end at T_RBRACE (see `if_block_ends`), but the switch below only
 * treats T_END as "no more branches". A body ending in `}` therefore appears to reach the
 * `default` error case. Confirm the intended grammar.
 * NOTE(review): the error returns do not free `tree` or the branches parsed so far.
 */
static Ast *
if_stmt_expr(ParserState *ps)
{
/* tokens that terminate the body of an `if` or `elif` branch */
const enum LexTokenId if_block_ends[] = {T_ELSE, T_ELIF, T_RBRACE};
Ast *tree = make_tree(AST_IF, ps->lexer->cur_loc);
/* parse `if` */
tree->ifse.cond = expr(ps, EXPR_INIT_PREC);
tree->ifse.true_body = stmt_list_until(ps, true, if_block_ends, countof(if_block_ends));
tree->ifse.false_body = nil;
/* re-scan the token that terminated the true branch */
LexToken next = lex_scan(ps->lexer);
/* scratch record; `cond` and `body` are assigned before each append */
AstElif elif_tree;
/* parse `elif`s and else */
for (;;) {
switch (next.id) {
case T_END: /* only has true branch */
return tree;
case T_ELSE:
/* once we see an `else` block, we assume the end of the `if` block,
* enforcing that `else` must be the last. */
trace("we got else\n");
tree->ifse.false_body = stmt_list_until(ps, true, (enum LexTokenId[]){T_ELIF, T_RBRACE}, 2);
/* the terminator of the `else` body is consumed here; only `elif` is pushed back */
next = lex_scan(ps->lexer);
if (next.id == T_ELIF) {
parse_error(ps, "'elif' branch after 'else' branch not allowed");
lex_backup(ps->lexer, next);
return nil;
}
return tree;
case T_ELIF:
trace("we got elif\n");
elif_tree.cond = expr(ps, EXPR_INIT_PREC);
elif_tree.body = stmt_list_until(ps, true, if_block_ends, countof(if_block_ends));
next = lex_scan(ps->lexer);
/* the record is copied into the vector, so the scratch variable can be reused */
arrput(tree->ifse.elifs, elif_tree);
/* no more `elif` blocks neither an `else` block next */
if (next.id == T_END)
return tree;
Assert(next.id == T_ELSE || next.id == T_ELIF);
/* leaves the switch only; the enclosing loop dispatches on the new `next` */
break;
default:
lex_backup(ps->lexer, next);
parse_error(ps, "expected 'elif' or 'else', got '%s'", TokenIdStr[next.id]);
return nil;
}
}
/* not reached: every path through the loop returns or iterates again */
return tree;
}
/* Parse a `while` loop after its keyword: the loop condition followed by a braced body.
 * Returns the AST_LOOP node. */
static Ast *
while_stmt(ParserState *ps)
{
    const enum LexTokenId body_end[] = {T_RBRACE};
    Ast *loop = make_tree(AST_LOOP, ps->lexer->cur_loc);

    loop->loop.precond = expr(ps, EXPR_INIT_PREC);
    loop->loop.body = stmt_list_until(ps, false, body_end, 1);
    return loop;
}
/* Parse an atom: a number, a string literal, an identifier (plain symbol or proc call), or a
 * parenthesized expression. Any other token is reported as an error and nil is returned. */
static Ast *
atom(ParserState *ps)
{
    LexToken tok = lex_scan(ps->lexer);

    if (tok.id == T_NUMBER) {
        Ast *num = make_tree(AST_NUMBER, ps->lexer->cur_loc);
        num->number.n = tok.inumber;
        trace("number in atom: %lu\n", tok.inumber);
        return num;
    }

    if (tok.id == T_STRING) {
        Ast *lit = make_tree(AST_STRLIT, ps->lexer->cur_loc);
        lit->strlit = tok.str;
        return lit;
    }

    if (tok.id == T_IDENT) {
        /* It is a plain symbol or a function call? */
        LexToken after = lex_scan(ps->lexer);
        if (after.id == T_LPAREN)
            return function_call(ps, tok.ident, true);

        lex_backup(ps->lexer, after);
        Ast *sym = make_tree(AST_IDENT, ps->lexer->cur_loc);
        sym->ident = tok.ident;
        return sym;
    }

    if (tok.id == T_LPAREN) {
        Ast *inner = expr(ps, EXPR_INIT_PREC);
        next_match(ps->lexer, T_RPAREN);
        return inner;
    }

    parse_error(ps, "expected a number, identifier or expression, not '%s'", TokenIdStr[tok.id]);
    return nil;
}
/* Parse an optional unary operator followed by an atom. Without a unary operator the token
 * is pushed back and the result is just the atom. The operand of a unary operator is a single
 * atom, so unary operators do not nest. */
static Ast *
unary(ParserState *ps)
{
    LexToken tok = lex_scan(ps->lexer);

    if (!token_is_unary(tok.id)) {
        lex_backup(ps->lexer, tok);
        return atom(ps);
    }

    Ast *node = make_tree(AST_UNARY, ps->lexer->cur_loc);
    node->unary.op = Str_from_c(TokenIdStr[tok.id]);
    node->unary.atom = atom(ps);
    return node;
}
/* Parse a binary expression or an atom. This implements the Pratt parser algorithm
 * (precedence climbing): only operators whose precedence is at least `minprec` are consumed
 * at this level; the first token that does not qualify is pushed back.
 * See also:
 * - https://eli.thegreenplace.net/2012/08/02/parsing-expressions-by-precedence-climbing
 * - https://www.oilshell.org/blog/2016/11/01.html
 * XXX: Mutate to the shunting yard variation? Since it uses an explicit stack instead of the call
 * stack, guard against deeply nested expressions.
 *
 * The operator node is created before its right operand is parsed. Previously the lexer
 * location and the recursive call were arguments of the same call, so the recorded location
 * depended on the compiler's (unspecified) argument evaluation order.
 */
static Ast *
expr(ParserState *ps, int minprec)
{
    Ast *tree = unary(ps);

    for (;;) {
        LexToken t = lex_scan(ps->lexer);
        if (!token_is_binop(t.id)
            || t.id == T_END
            || t.id == T_RPAREN
            || OperatorTable[t.id].pred < minprec) {
            lex_backup(ps->lexer, t);
            break;
        }
        const OperatorPrec op = OperatorTable[t.id];
        /* a left-associative operator must not be re-entered at its own level on the right */
        const int next_prec = op.left_assoc ? op.pred + 1 : op.pred;
        /* location is captured now, right after the operator token was scanned */
        Ast *node = make_binop(t.id, ps->lexer->cur_loc, tree, NULL);
        node->bin.right = expr(ps, next_prec);
        tree = node;
    }
    return tree;
}
/* Placeholder for a generic separated-list parser: it ignores both arguments and always
 * returns an empty (nil) vector. */
static Vec(Ast *)
sep_list(ParserState *ps, Ast *(*prod_fn)(Compiler *, void *))
{
    Vec(Ast *) none = nil;

    (void)ps;
    (void)prod_fn;
    return none;
}
/* Parse a comma-separated list of expressions, allowing a trailing comma. The first token
 * that does not belong to the list is pushed back. Returns an AST_EXPRS node that owns the
 * vector of parsed expressions. */
static Ast *
expr_comma_list(ParserState *ps)
{
    Ast *list = make_tree(AST_EXPRS, ps->lexer->cur_loc);
    Vec(Ast *) items = nil;

    for (;;) {
        int more = 0;

        arrput(items, expr(ps, EXPR_INIT_PREC));

        LexToken tok = lex_scan(ps->lexer);
        trace("commalist tok: %s\n", TokenIdStr[tok.id]);
        if (tok.id == T_COMMA) {
            /* a comma continues the list only if an expression follows it,
             * which is what makes a trailing comma legal */
            tok = lex_scan(ps->lexer);
            more = token_is_expr_start(tok.id);
        }
        /* the last scanned token is un-read whether the list continues or ends */
        lex_backup(ps->lexer, tok);
        if (!more)
            break;
    }

    if (arrlen(items) == 0) {
        free(list);
        arrfree(items);
        return nil;
    }
    list->exprs = items;
    return list;
}
/* Return true when token id `c` equals one of the first `len` entries of `toks`. */
static bool
token_id_in_list(enum LexTokenId c, const enum LexTokenId *toks, isize len)
{
    isize idx = 0;

    while (idx < len) {
        if (toks[idx] == c)
            return true;
        idx++;
    }
    return false;
}
/* Parses a braced statement list: matches `{`, then parses statements until a token listed
 * in `end_markers` (an array of `len` token ids) is scanned. A `;` directly after a
 * statement is skipped.
 *
 * If `putback` is true the terminating token is pushed back for the caller to inspect;
 * otherwise it stays consumed.
 *
 * Returns an AST_STMTS node, or nil when the list is empty or the per-block statement limit
 * is exceeded. On the limit error the node and the statement vector are released (they were
 * previously leaked); statements already parsed are not freed individually.
 */
static Ast *
stmt_list_until(ParserState *ps, bool putback, const enum LexTokenId *end_markers, isize len)
{
    next_match(ps->lexer, T_LBRACE);
    LexToken token = lex_scan(ps->lexer);
    Vec(Ast *) stmts = nil;
    Ast *body = make_tree(AST_STMTS, ps->lexer->cur_loc);

    /* stmt* */
    while (!token_id_in_list(token.id, end_markers, len)) {
        trace("stmt list token: %s\n", TokenIdStr[token.id]);
        if (arrlen(stmts) + 1 > MAX_STMTS_IN_BLOCK) {
            parse_error(ps, "more than %d (implementation limit) statements in block", MAX_STMTS_IN_BLOCK);
            goto discard;
        }
        arrput(stmts, stmt(ps, token));
        token = lex_scan(ps->lexer);
        if (token.id == T_EOF) {
            parse_error(ps, "unexpected EOF, expected a statement or `end`");
            break;
        }
        if (token.id == T_SEMICOLON)
            token = lex_scan(ps->lexer);
    }
    trace("token before end next_match: %s\n", TokenIdStr[token.id]);
    if (putback)
        lex_backup(ps->lexer, token);
    /* empty list, just return nil instead of wasting space on a 0-length
     * vector */
    if (arrlen(stmts) == 0)
        goto discard;
    body->stmts = stmts;
    return body;

discard:
    free(body);
    arrfree(stmts);
    return nil;
}
/* Parse one statement. `token` is the statement's first token, already scanned by the
 * caller; the production is chosen from its id.
 *
 * Returns the statement's AST node, or nil after reporting an error. Every error, including
 * an unrecognized leading token, is reported through `parse_error`, which clears the parser's
 * `ok` flag. The unrecognized-token case previously called `exit(1)`, terminating the whole
 * process from inside the parser; it now returns nil like the other error cases.
 */
static Ast *
stmt(ParserState *ps, LexToken token)
{
    switch (token.id) {
    case T_IDENT:
        return funccall_or_assignment(ps, token.ident);
    case T_CONST:
    case T_LET:
    case T_VAR:
        return variable_decl(ps, token.id);
    case T_PROC:
        return proc_decl(ps);
    case T_HASH:
        return decorated_decl(ps);
    case T_RETURN:
        return return_stmt(ps);
    case T_BREAK:
        return break_stmt(ps);
    case T_DISCARD:
        return discard_stmt(ps);
    case T_IF:
        return if_stmt_expr(ps);
    case T_WHILE:
        return while_stmt(ps);
    /* everything below is an error: report it and fall out to the shared `return nil` */
    case T_ELIF:
        parse_error(ps, "stray 'elif'");
        break;
    case T_ELSE:
        parse_error(ps, "'else' with no accompanying 'if'");
        break;
    case T_END:
        parse_error(ps, "stray 'end' keyword");
        break;
    case T_EOF:
        parse_error(ps, "unexpected EOF while parsing a statement");
        break;
    default:
        parse_error(ps, "invalid statement '%s'", TokenIdStr[token.id]);
        break;
    }
    return nil;
}
/* Parse top-level statements until the lexer yields EOF and collect them, in order, in a
 * single AST_STMTS node. Whatever `stmt` returns is appended, including nil for a statement
 * that failed to parse. */
static Ast *
stmt_list(ParserState *ps)
{
    Ast *root = make_tree(AST_STMTS, ps->lexer->cur_loc);
    LexToken tok = lex_scan(ps->lexer);

    while (tok.id != T_EOF) {
        arrput(root->stmts, stmt(ps, tok));
        tok = lex_scan(ps->lexer);
    }
    return root;
}
/* Create a parser state bound to compiler `cm` and lexer `ls`, with its `ok` flag set.
 *
 * Only the pointers are stored; this function does not copy `cm` or `ls`. Returns NULL if
 * the allocation fails (previously the null pointer was dereferenced). Release the result
 * with `parse_destroy`.
 */
ParserState *
parse_new(Compiler *cm, LexState *ls)
{
    ParserState *ps = calloc(1, sizeof *ps);

    if (ps == NULL)
        return NULL;
    ps->cm = cm;
    ps->lexer = ls;
    ps->ok = true;
    return ps;
}
/* Free a parser state created by `parse_new`. Only the state itself is freed; the compiler
 * and lexer it points to are left untouched, and so is any AST returned by `parse`. */
void
parse_destroy(ParserState *ps)
{
free(ps);
}
/* Parse the whole token stream of `ps`'s lexer and return the root AST_STMTS node.
 * Check the `ok` flag on `ps` afterwards to learn whether any parse error was reported. */
Ast *
parse(ParserState *ps)
{
    Ast *root = stmt_list(ps);

    return root;
}