/* Recursive descent parser + Pratt parser (for expressions) * TODO: * - DRY code that handle list of tokens, I have like three almost identical functions for that. * - Use an arena for the AST nodes. Nuke all of them with a single call * when we no longer need the AST. */ #include #include "ast.h" #include "pre.h" #include "parse.h" #include "lex.h" #include "state.h" #include "messages.h" #include "libs/stb_ds.h" #define MAX_STMTS_IN_BLOCK 2000 #define MAX_PROC_ARG_COUNT 127 #define EXPR_INIT_PREC 1 /* Consume a token and match it */ #define next_match(lexer, tokt) \ do { LexToken t = lex_scan(lexer); lex_match(lexer, &t, tokt); } while (0) /* Scans a token (mutating `t`), and if its id matches `ttype`, * it executes the code block. Otherwise, the scanned token * gets put back (so a next call to `lex_scan` can pick it up). */ #define matchopt(t, ttype, ps) \ if ((t = lex_scan(ps->lexer)).id != ttype) { \ lex_backup((ps)->lexer, t); \ } else #define token_is_binop(t) (t >= T_PLUS && t <= T_NOTEQUAL) #define token_is_atom(t) (t >= T_IDENT && t <= T_DECNUMBER) #define token_is_unary(t) (t == T_MINUS || t == T_LOGNOT) #define token_is_expr_start(t) (token_is_unary(t) || token_is_atom(t)) #define parse_error(ctx, ...) \ do { error((ctx)->cm, &((ctx)->lexer->cur_loc), __VA_ARGS__); (ctx)->ok = false; } while (0) typedef Optional(AstIdentTypePair) OptAstIdentTypePair; typedef struct { int pred; bool left_assoc; /* false if right assoc... */ } OperatorPrec; /* Operator table specifying the precedence and associativeness * of each operator, used by the expression parser. * The precedence goes from lower to higher. */ const OperatorPrec OperatorTable[] = { [T_LOGOR] = {1, true}, [T_LOGAND] = {2, true}, [T_LESSTHAN] = {3, true}, [T_GREATTHAN] = {3, true}, [T_LOGICEQUAL] = {3, true}, [T_NOTEQUAL] = {3, true}, [T_PLUS] = {4, true}, [T_MINUS] = {4, true}, [T_STAR] = {5, true}, [T_BAR] = {5, true}, }; static Ast * expr(ParserState *ps, int minprec); static Ast * expr_comma_list(ParserState *ps); static Ast * stmt(ParserState *ps, LexToken token); static Ast * stmt_list_until(ParserState *ps, bool putback, const enum LexTokenId *end_markers, isize len); static Ast * make_tree(enum AstType type, Location loc) { Ast *tree = calloc(1, sizeof(Ast)); tree->type = type; tree->loc = loc; return tree; } static Ast * make_binop(enum LexTokenId op, Location loc, Ast *lhs, Ast *rhs) { Ast *tree = make_tree(AST_BINEXPR, loc); tree->bin.op = Str_from_c(TokenIdStr[op]); tree->bin.left = lhs; tree->bin.right = rhs; return tree; } static Ast * make_ident_node(Str ident, Location loc) { Ast *tree = make_tree(AST_IDENT, loc); tree->ident = ident; return tree; } static OptAstIdentTypePair ident_type_pair(ParserState *ps) { AstIdentTypePair itp = {0}; /* ident */ LexToken token = lex_scan(ps->lexer); lex_match(ps->lexer, &token, T_IDENT); itp.ident = token.ident; itp.ident_loc = token.loc; /* type */ next_match(ps->lexer, T_COLON); /* optional qualifier */ token = lex_scan(ps->lexer); if (token.id == T_VAR) { itp.kind = SymVar; } else { itp.kind = SymLet; lex_backup(ps->lexer, token); } token = lex_scan(ps->lexer); if (token.id != T_IDENT) { parse_error(ps, "expected a type, got %s instead", TokenIdStr[token.id]); return None(OptAstIdentTypePair); } itp.dtype_loc = token.loc; itp.dtype = token.ident; return Some(OptAstIdentTypePair, itp); } static Vec(AstIdentTypePair) proc_arglist(ParserState *ps) { Vec(AstIdentTypePair) args = nil; LexToken next; for (;;) { OptAstIdentTypePair oitp = ident_type_pair(ps); if (!oitp.ok) return nil; if (arrlen(args) + 1 > MAX_PROC_ARG_COUNT) { parse_error(ps, "more than %d (implementation limit) proc arguments", MAX_PROC_ARG_COUNT); return nil; } arrput(args, oitp.val); next = lex_scan(ps->lexer); /* do we have a comma? if not, we reached the end of the list */ if (next.id != T_COMMA) break; /* check if we have an expression next to this comma, we do this * to allow a trailling comma */ next = lex_scan(ps->lexer); if (next.id != T_IDENT) break; lex_backup(ps->lexer, next); } trace("token in arglist out: %s\n", TokenIdStr[next.id]); lex_backup(ps->lexer, next); if (arrlen(args) == 0) { arrfree(args); return nil; } return args; } static Ast * proc_decl(ParserState *ps) { LexToken proc_name = lex_scan(ps->lexer); lex_match(ps->lexer, &proc_name, T_IDENT); Ast *proc = make_tree(AST_PROCDEF, ps->lexer->cur_loc); proc->proc.name = proc_name.ident; trace("proc name: %s\n", proc->proc.name.s); LexToken token = lex_scan(ps->lexer); if (token.id == T_STAR) { proc->proc.ispublic = true; token = lex_scan(ps->lexer); } lex_match(ps->lexer, &token, T_LPAREN); token = lex_scan(ps->lexer); if (token.id != T_RPAREN) { lex_backup(ps->lexer, token); proc->proc.args = proc_arglist(ps); token = lex_scan(ps->lexer); } lex_match(ps->lexer, &token, T_RPAREN); /* return type */ token = lex_scan(ps->lexer); if (token.id == T_COLON) { token = lex_scan(ps->lexer); lex_match(ps->lexer, &token, T_IDENT); proc->proc.rettype = make_ident_node(token.ident, ps->lexer->cur_loc); } else { lex_backup(ps->lexer, token); } /* body */ proc->proc.body = stmt_list_until(ps, false, (enum LexTokenId[]){T_RBRACE}, 1); return proc; } static Ast * function_call(ParserState *ps, Str ident, bool ate_lp) { Ast *funcc = make_tree(AST_PROCCALL, ps->lexer->cur_loc); funcc->call = (AstProcCall){ .name = ident }; if (!ate_lp) next_match(ps->lexer, T_LPAREN); LexToken next = lex_scan(ps->lexer); if (token_is_expr_start(next.id)) { lex_backup(ps->lexer, next); funcc->call.args = expr_comma_list(ps); } else { lex_backup(ps->lexer, next); } next_match(ps->lexer, T_RPAREN); trace("function call to: %s\n", ident.s); return funcc; } static Ast * variable_assign(ParserState *ps, Str ident, Location loc) { Ast *tree = make_tree(AST_VARASSIGN, loc); tree->varassgn.name = ident; tree->varassgn.expr = expr(ps, EXPR_INIT_PREC); return tree; } static Ast * funccall_or_assignment(ParserState *ps, Str ident) { LexToken token; matchopt(token, T_EQUAL, ps) { return variable_assign(ps, ident, ps->lexer->cur_loc); } return function_call(ps, ident, false); } static Ast * variable_decl(ParserState *ps, enum LexTokenId decl_kind) { static const enum SymbolKind Token2SemaVarKind[] = { [T_LET] = SymLet, [T_VAR] = SymVar, [T_CONST] = SymConst, }; Assert(decl_kind == T_LET || decl_kind == T_VAR || decl_kind == T_CONST); Ast *decl = make_tree(AST_VARDECL, ps->lexer->cur_loc); LexToken token = lex_scan(ps->lexer); lex_match(ps->lexer, &token, T_IDENT); decl->var = (AstVarDecl) { .name = token.ident, .kind = Token2SemaVarKind[decl_kind], }; /* type */ matchopt(token, T_COLON, ps) { token = lex_scan(ps->lexer); if (token.id != T_IDENT) { parse_error(ps, "expected a type, got %s instead", TokenIdStr[token.id]); return nil; } decl->var.datatype = make_ident_node(token.ident, token.loc); } /* assignment expression */ matchopt(token, T_EQUAL, ps) { trace("assignment of decl here\n"); decl->var.expr = expr(ps, EXPR_INIT_PREC); } trace( "var decl %s %s: %s\n", TokenIdStr[decl_kind], decl->var.name.s, decl->var.datatype != nil ? (char *)decl->var.datatype->ident.s : "(no type)" ); /* if there's no type there must be an expr */ /* TODO: move to semantic analysis phase? */ if (decl->var.datatype == nil && decl->var.expr == nil) { parse_error( ps, "'%s' declaration must have an assignment expression if no type is specified, " "but neither a type nor expression was supplied", TokenIdStr[decl_kind] ); return nil; } return decl; } static Ast * return_stmt(ParserState *ps) { Ast *tree = make_tree(AST_RETURN, ps->lexer->cur_loc); LexToken next = lex_scan(ps->lexer); if (token_is_expr_start(next.id)) { lex_backup(ps->lexer, next); tree->ret = expr(ps, EXPR_INIT_PREC); } else { lex_backup(ps->lexer, next); } return tree; } static Ast * break_stmt(ParserState *ps) { return make_tree(AST_BREAK, ps->lexer->cur_loc); } static Ast * discard_stmt(ParserState *ps) { Ast *tree = make_tree(AST_DISCARD, ps->lexer->cur_loc); tree->discard.expr = expr(ps, EXPR_INIT_PREC); return tree; } static Ast * parse_pragma(ParserState *ps) { Ast *tree = make_tree(AST_PRAGMA, ps->lexer->cur_loc); LexToken next = lex_scan(ps->lexer); lex_match(ps->lexer, &next, T_LBRACKET); next = lex_scan(ps->lexer); lex_match(ps->lexer, &next, T_RBRACKET); return tree; } /* A declaration "decorated" with a pragma */ static Ast * decorated_decl(ParserState *ps) { Ast *attr = parse_pragma(ps); LexToken next = lex_scan(ps->lexer); switch (next.id) { case T_PROC: attr->pragma.node = proc_decl(ps); break; case T_CONST: case T_LET: case T_VAR: attr->pragma.node = variable_decl(ps, next.id); break; default: parse_error(ps, "node of kind '%s' cannot have a pragma", TokenIdStr[next.id]); return nil; } return attr; } static Ast * if_stmt_expr(ParserState *ps) { const enum LexTokenId if_block_ends[] = {T_ELSE, T_ELIF, T_RBRACE}; Ast *tree = make_tree(AST_IF, ps->lexer->cur_loc); /* parse `if` */ tree->ifse.cond = expr(ps, EXPR_INIT_PREC); tree->ifse.true_body = stmt_list_until(ps, true, if_block_ends, countof(if_block_ends)); tree->ifse.false_body = nil; LexToken next = lex_scan(ps->lexer); AstElif elif_tree; /* parse `elif`s and else */ for (;;) { switch (next.id) { case T_END: /* only has true branch */ return tree; case T_ELSE: /* once we see an `else` block, we assume the end of the `if` block, * enforcing that `else` must be the last. */ trace("we got else\n"); tree->ifse.false_body = stmt_list_until(ps, true, (enum LexTokenId[]){T_ELIF, T_RBRACE}, 2); next = lex_scan(ps->lexer); if (next.id == T_ELIF) { parse_error(ps, "'elif' branch after 'else' branch not allowed"); lex_backup(ps->lexer, next); return nil; } return tree; case T_ELIF: trace("we got elif\n"); elif_tree.cond = expr(ps, EXPR_INIT_PREC); elif_tree.body = stmt_list_until(ps, true, if_block_ends, countof(if_block_ends)); next = lex_scan(ps->lexer); arrput(tree->ifse.elifs, elif_tree); /* no more `elif` blocks neither an `else` block next */ if (next.id == T_END) return tree; Assert(next.id == T_ELSE || next.id == T_ELIF); break; default: lex_backup(ps->lexer, next); parse_error(ps, "expected 'elif' or 'else', got '%s'", TokenIdStr[next.id]); return nil; } } return tree; } static Ast * while_stmt(ParserState *ps) { Ast *tree = make_tree(AST_LOOP, ps->lexer->cur_loc); tree->loop.precond = expr(ps, EXPR_INIT_PREC); tree->loop.body = stmt_list_until(ps, false, (enum LexTokenId[]){T_RBRACE}, 1); return tree; } static Ast * atom(ParserState *ps) { Ast *tree = nil; LexToken t = lex_scan(ps->lexer); LexToken next; switch (t.id) { case T_NUMBER: tree = make_tree(AST_NUMBER, ps->lexer->cur_loc); tree->number.n = t.inumber; trace("number in atom: %lu\n", t.inumber); return tree; case T_STRING: tree = make_tree(AST_STRLIT, ps->lexer->cur_loc); tree->strlit = t.str; return tree; case T_IDENT: next = lex_scan(ps->lexer); /* It is a plain symbol or a function call? */ if (next.id == T_LPAREN) { tree = function_call(ps, t.ident, true); } else { lex_backup(ps->lexer, next); tree = make_tree(AST_IDENT, ps->lexer->cur_loc); tree->ident = t.ident; } return tree; case T_LPAREN: tree = expr(ps, EXPR_INIT_PREC); next_match(ps->lexer, T_RPAREN); return tree; default: parse_error(ps, "expected a number, identifier or expression, not '%s'", TokenIdStr[t.id]); } return nil; } static Ast * unary(ParserState *ps) { LexToken next = lex_scan(ps->lexer); if (token_is_unary(next.id)) { Ast *unt = make_tree(AST_UNARY, ps->lexer->cur_loc); unt->unary.op = Str_from_c(TokenIdStr[next.id]); unt->unary.atom = atom(ps); return unt; } lex_backup(ps->lexer, next); return atom(ps); } /* Parse a binary expression or an atom. This implements the Pratt parser algorithm. * See also: * - https://eli.thegreenplace.net/2012/08/02/parsing-expressions-by-precedence-climbing * - https://www.oilshell.org/blog/2016/11/01.html * XXX: Mutate to the shunting yard variation? Since it uses an explicit stack instead of the call * stack, guard against deeply nested expressions. */ static Ast * expr(ParserState *ps, int minprec) { Ast *tree = unary(ps); for (;;) { LexToken t = lex_scan(ps->lexer); if (!token_is_binop(t.id) || t.id == T_END || t.id == T_RPAREN || OperatorTable[t.id].pred < minprec) { lex_backup(ps->lexer, t); break; } const OperatorPrec op = OperatorTable[t.id]; const int next_prec = op.left_assoc ? op.pred + 1 : op.pred; tree = make_binop(t.id, ps->lexer->cur_loc, tree, expr(ps, next_prec)); } return tree; } static Vec(Ast *) sep_list(ParserState *ps, Ast *(*prod_fn)(Compiler *, void *)) { (void)ps, (void)prod_fn; Vec(Ast *) prod = nil; return prod; } static Ast * expr_comma_list(ParserState *ps) { Ast *tree = make_tree(AST_EXPRS, ps->lexer->cur_loc); Vec(Ast *) exprs = nil; LexToken next; for (;;) { arrput(exprs, expr(ps, EXPR_INIT_PREC)); next = lex_scan(ps->lexer); trace("commalist tok: %s\n", TokenIdStr[next.id]); /* do we have a comma? if not, we reached the end of the list */ if (next.id != T_COMMA) break; next = lex_scan(ps->lexer); /* check if we have an expression next to this comma, we do this * to allow a trailling comma */ if (!token_is_expr_start(next.id)) break; lex_backup(ps->lexer, next); } lex_backup(ps->lexer, next); if (arrlen(exprs) == 0) { free(tree); arrfree(exprs); return nil; } tree->exprs = exprs; return tree; } static bool token_id_in_list(enum LexTokenId c, const enum LexTokenId *toks, isize len) { for (isize i = 0; i < len; ++i) if (c == toks[i]) return true; return false; } /* Parses a statement list until the token `end_marker`. Returns `nil` if the statement list * is empty. */ static Ast * stmt_list_until(ParserState *ps, bool putback, const enum LexTokenId *end_markers, isize len) { next_match(ps->lexer, T_LBRACE); LexToken token = lex_scan(ps->lexer); Vec(Ast *) stmts = nil; Ast *body = make_tree(AST_STMTS, ps->lexer->cur_loc); /* stmt* */ while (!token_id_in_list(token.id, end_markers, len)) { trace("stmt list token: %s\n", TokenIdStr[token.id]); if (arrlen(stmts) + 1 > MAX_STMTS_IN_BLOCK) { parse_error(ps, "more than %d (implementation limit) statements in block", MAX_STMTS_IN_BLOCK); return nil; } arrput(stmts, stmt(ps, token)); token = lex_scan(ps->lexer); if (token.id == T_EOF) { parse_error(ps, "unexpected EOF, expected a statement or `end`"); break; } if (token.id == T_SEMICOLON) token = lex_scan(ps->lexer); } //lex_match(ps->lexer, &token, end_marker); trace("token before end next_match: %s\n", TokenIdStr[token.id]); if (putback) lex_backup(ps->lexer, token); /* empty list, just return nil instead of wasting space on a 0-length * vector */ if (arrlen(stmts) == 0) { free(body); arrfree(stmts); return nil; } body->stmts = stmts; return body; } static Ast * stmt(ParserState *ps, LexToken token) { switch (token.id) { case T_IDENT: return funccall_or_assignment(ps, token.ident); case T_CONST: case T_LET: case T_VAR: return variable_decl(ps, token.id); case T_PROC: return proc_decl(ps); case T_HASH: return decorated_decl(ps); case T_RETURN: return return_stmt(ps); case T_BREAK: return break_stmt(ps); case T_DISCARD: return discard_stmt(ps); case T_IF: return if_stmt_expr(ps); case T_ELIF: parse_error(ps, "stray 'elif'"); return nil; case T_WHILE: return while_stmt(ps); case T_ELSE: parse_error(ps, "'else' with no accompanying 'if'"); return nil; case T_END: parse_error(ps, "stray 'end' keyword"); return nil; case T_EOF: parse_error(ps, "unexpected EOF while parsing a statement"); return nil; default: parse_error(ps, "invalid statement '%s'", TokenIdStr[token.id]); exit(1); } return nil; } /* Parse statements until EOF. */ static Ast * stmt_list(ParserState *ps) { Ast *tree = make_tree(AST_STMTS, ps->lexer->cur_loc); for (;;) { const LexToken next = lex_scan(ps->lexer); if (next.id == T_EOF) break; arrput(tree->stmts, stmt(ps, next)); } return tree; } ParserState * parse_new(Compiler *cm, LexState *ls) { ParserState *ps = calloc(1, sizeof(*ps)); ps->cm = cm; ps->lexer = ls; ps->ok = true; return ps; } void parse_destroy(ParserState *ps) { free(ps); } Ast * parse(ParserState *ps) { return stmt_list(ps); }