From 1d943da0cf9154e7ce78ce867cdbb91531c5d78e Mon Sep 17 00:00:00 2001 From: Matthew Sotoudeh Date: Tue, 25 Jul 2023 14:58:33 -0700 Subject: initial dietc commit --- codegen.c | 771 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 771 insertions(+) create mode 100644 codegen.c (limited to 'codegen.c') diff --git a/codegen.c b/codegen.c new file mode 100644 index 0000000..9eaac4a --- /dev/null +++ b/codegen.c @@ -0,0 +1,771 @@ +#include "chibicc.h" +#include + +#define GP_MAX 6 +#define FP_MAX 8 + +int is_bad_number(double x) { + return (x != x) || (x == 1./0.) || (-x == 1./0.); +} + +static FILE *output_file; +static int RETURN_TMP; + +static void printnoln(char *fmt, ...); +static void println(char *fmt, ...); +static void print_tok(Token *tok); + +FILE *BUFFER; +int DO_BUFFER; + +static void flush_buffer() { + rewind(BUFFER); + int c; + while ((c = fgetc(BUFFER)) != EOF) + fputc(c, output_file); + rewind(BUFFER); + ftruncate(fileno(BUFFER), 0); +} + +void print_label(char *label) { + printnoln("_l"); + for (; *label; label++) { + if (*label == '.') printnoln("_"); + else printnoln("%c", *label); + } +} + +static void print_tok(Token *tok) { + if (tok->str) printnoln("%s", tok->str); + else { + assert(tok->loc); + for (int i = 0; i < tok->len; i++) + printnoln("%c", tok->loc[i]); + } +} + +void print_obj(Obj *obj) { + assert(obj->offset >= 0); + if (obj->is_local || strchr(obj->name, '.')) { + printnoln("_"); + for (char *c = obj->name; *c; c++) + if (*c == '.') printnoln("_"); + else printnoln("%c", *c); + printnoln("_%d", obj->offset); + } else { + printnoln("%s", obj->name); + } +} + +const void print_type(Type *type) { + if (type->kind == TY_FUNC) { + assert(type->pointer_type); + print_type(type->pointer_type); + } else { + assert(type->id); + printnoln("Type_%d", type->id); + } +} +static int depth; +static Obj *current_fn; + +static void gen_expr(Node *node, int to_tmp); +static void gen_stmt(Node *node); + +__attribute__((format(printf, 1, 2))) +static void println(char *fmt, ...) { + va_list ap; + va_start(ap, fmt); + if (DO_BUFFER) + vfprintf(BUFFER, fmt, ap); + else + vfprintf(output_file, fmt, ap); + va_end(ap); + if (DO_BUFFER) + fprintf(BUFFER, "\n"); + else + fprintf(output_file, "\n"); +} + +__attribute__((format(printf, 1, 2))) +static void printnoln(char *fmt, ...) { + va_list ap; + va_start(ap, fmt); + if (DO_BUFFER) + vfprintf(BUFFER, fmt, ap); + else + vfprintf(output_file, fmt, ap); + va_end(ap); +} + +static int count(void) { + static int i = 1; + return i++; +} + +// Round up `n` to the nearest multiple of `align`. For instance, +// align_to(5, 8) returns 8 and align_to(11, 8) returns 16. +int align_to(int n, int align) { + return (n + align - 1) / align * align; +} + +void decltmp(Type *type, int c) { + DO_BUFFER = 0; + printnoln("\t"); + print_type(type); + println(" t%d ;", c); + DO_BUFFER = 1; +} + +void decltmpptr(Type *type, int c) { + assert(type->pointer_type); + decltmp(type->pointer_type, c); +} + +// Compute the absolute address of a given node. +// It's an error if a given node does not reside in memory. +static void gen_addr(Node *node, int to_tmp) { + switch (node->kind) { + case ND_VAR: + decltmpptr(node->ty, to_tmp); + printnoln("\tt%d = & ", to_tmp); + print_obj(node->var); + println(" ;"); + return; + case ND_DEREF: + gen_expr(node->lhs, to_tmp); + return; + case ND_COMMA: + gen_expr(node->lhs, count()); + gen_addr(node->rhs, to_tmp); + return; + case ND_MEMBER: { + int s = count(); + gen_addr(node->lhs, s); + decltmpptr(node->ty, to_tmp); + printnoln("\tt%d = FIELDPTR ( t%d, ", to_tmp, s); + if (node->member->name) + print_tok(node->member->name); + else + printnoln("___dietc_f%d", node->member->idx); + println(" ) ;"); + return; + } + case ND_ASSIGN: + case ND_COND: + if (node->ty->kind == TY_STRUCT || node->ty->kind == TY_UNION) { + gen_expr(node, to_tmp); + return; + } + break; + case ND_VLA_PTR: assert(0); + } + + error_tok(node->tok, "not an lvalue"); +} + +// Generate code for a given node. +static void gen_expr(Node *node, int to_tmp) { + // println(" .loc %d %d", node->tok->file->file_no, node->tok->line_no); + + switch (node->kind) { + case ND_NULL_EXPR: + return; + case ND_NUM: { + decltmp(node->ty, to_tmp); + switch (node->ty->kind) { + case TY_FLOAT: + case TY_DOUBLE: + case TY_LDOUBLE: + println("\tt%d = %Lf ;", to_tmp, node->fval); + return; + default: + println("\tt%d = %ld ;", to_tmp, node->val); + return; + } + } + case ND_NEG: { + int c = count(); + gen_expr(node->lhs, c); + decltmp(node->ty, to_tmp); + println("\tt%d = - t%d ;", to_tmp, c); + return; + } + case ND_VAR: { + // we inline the *getaddr(node) + if (node->ty->kind == TY_ARRAY) { + decltmpptr(node->ty->base, to_tmp); + printnoln("\tt%d = ", to_tmp); + print_obj(node->var); + println(" ;"); + } else { + decltmp(node->ty, to_tmp); + printnoln("\tt%d = ", to_tmp); + print_obj(node->var); + println(" ;"); + } + return; + } + case ND_MEMBER: { + int c = count(); + gen_addr(node, c); + // TODO: what to do in this case? + if (node->ty->kind == TY_ARRAY) { + decltmpptr(node->ty->base, to_tmp); + println("\tt%d = * t%d ;", to_tmp, c); + } else { + decltmp(node->ty, to_tmp); + println("\tt%d = * t%d ;", to_tmp, c); + } + + if (node->member->is_bitfield) + fprintf(stderr, "WARNING: bitfields ignored"); + return; + } + case ND_DEREF: { + int c = count(); + gen_expr(node->lhs, c); + // TODO: deref a pointer to an array; is the temporary a pointer, or an + // array?? + if (node->ty->kind == TY_ARRAY) { + decltmpptr(node->ty->base, to_tmp); + println("\tt%d = * t%d ;", to_tmp, c); + } else { + decltmp(node->ty, to_tmp); + println("\tt%d = * t%d ;", to_tmp, c); + } + return; + } + case ND_ADDR: + gen_addr(node->lhs, to_tmp); + return; + case ND_ASSIGN: { + int lhsa = count(); + gen_addr(node->lhs, lhsa); + gen_expr(node->rhs, to_tmp); + + if (node->lhs->kind == ND_MEMBER && node->lhs->member->is_bitfield) + fprintf(stderr, "WARNING: bitfields ignored"); + + println("\t* t%d = t%d ;", lhsa, to_tmp); + return; + } + case ND_STMT_EXPR: + for (Node *n = node->body; n; n = n->next) { + if (n->next || n->kind != ND_EXPR_STMT) + gen_stmt(n); + else + gen_expr(n->lhs, to_tmp); + } + return; + case ND_COMMA: + gen_expr(node->lhs, count()); + gen_expr(node->rhs, to_tmp); + return; + case ND_CAST: { + if (definitely_same_type(node->lhs->ty, node->ty)) + return gen_expr(node->lhs, to_tmp); + int c = count(); + if (node->ty->kind == TY_VOID) { + gen_expr(node->lhs, c); + println("\t( Type_%d ) t%d ;", node->ty->id, c); + } else if (node->lhs->ty->kind == TY_UNION) { + // union *tuptr = &union + int uptr = count(); + gen_addr(node->lhs, uptr); + // T * tc = ( T * ) tuptr + decltmpptr(node->ty, c); + println("\tt%d = ( Type_%d ) t%d ;", + c, node->ty->pointer_type->id, uptr); + // dereference it + decltmp(node->ty, to_tmp); + println("\tt%d = * t%d ;", to_tmp, c); + } else { + gen_expr(node->lhs, c); + decltmp(node->ty, to_tmp); + println("\tt%d = ( Type_%d ) t%d ;", to_tmp, node->ty->id, c); + } + return; + } + case ND_MEMZERO: + printnoln("\tMEMZERO ( "); print_obj(node->var); println(" ) ;"); + return; + case ND_COND: { + int c = count(), cond = count(), condfalse = count(), + tmp1 = count(), tmp2 = count(); + gen_expr(node->cond, cond); + decltmp(ty_int, condfalse); + if (node->ty->kind != TY_VOID) + decltmp(node->ty, to_tmp); + println("\tt%d = ! t%d ;", condfalse, cond); + println("\tif ( t%d ) goto _L_else_%d ;", condfalse, c); + gen_expr(node->then, tmp1); + if (node->ty->kind != TY_VOID) + println("\tt%d = t%d ;", to_tmp, tmp1); + println("\tgoto _L_end_%d ;", c); + println("\t_L_else_%d :", c); + gen_expr(node->els, tmp2); + if (node->ty->kind != TY_VOID) + println("\tt%d = t%d ;", to_tmp, tmp2); + println("\t_L_end_%d :", c); + return; + } + case ND_NOT: { + int c = count(); + gen_expr(node->lhs, c); + decltmp(ty_int, to_tmp); + println("\tt%d = ! t%d ;", to_tmp, c); + return; + } + case ND_BITNOT: { + int c = count(); + gen_expr(node->lhs, c); + decltmp(node->ty, to_tmp); + println("\tt%d = ~ t%d ;", to_tmp, c); + return; + } + case ND_LOGAND: { + int c = count(), lhs = count(), lhsfalse = count(), + rhs = count(), rhsfalse = count(); + decltmp(ty_int, to_tmp); + decltmp(ty_int, lhsfalse); + decltmp(ty_int, rhsfalse); + gen_expr(node->lhs, lhs); + println("\tt%d = 0 ;", to_tmp); + println("\tt%d = ! t%d ;", lhsfalse, lhs); + println("\tif ( t%d ) goto _L_false_%d ;", lhsfalse, c); + gen_expr(node->rhs, rhs); + println("\tt%d = ! t%d ;", rhsfalse, rhs); + println("\tif ( t%d ) goto _L_false_%d ;", rhsfalse, c); + println("\tt%d = 1 ;", to_tmp); + println("\tgoto _L_end_%d;", c); + println("\t_L_false_%d :", c); + println("\t_L_end_%d :", c); + return; + } + case ND_LOGOR: { + int c = count(), lhs = count(), rhs = count(); + decltmp(ty_int, to_tmp); + println("\tt%d = 0 ;", to_tmp); + gen_expr(node->lhs, lhs); + println("\tif ( t%d ) goto _L_true_%d ;", lhs, c); + gen_expr(node->rhs, rhs); + println("\tif ( t%d ) goto _L_true_%d ;", rhs, c); + println("\tgoto _L_end_%d ;", c); + println("\t_L_true_%d :", c); + println("\tt%d = 1 ;", to_tmp); + println("\t_L_end_%d :", c); + return; + } + case ND_FUNCALL: { + int fnc = count(); + gen_expr(node->lhs, fnc); + + int n_args = 0; + for (Node *arg = node->args; arg; arg = arg->next) n_args++; + + int *args = calloc(n_args, sizeof(int)), + arg_i = 0; + for (Node *arg = node->args; arg; arg = arg->next, arg_i++) { + args[arg_i] = count(); + gen_expr(arg, args[arg_i]); + } + if (node->ty->kind != TY_VOID) { + decltmp(node->ty, to_tmp); + printnoln("\tt%d = ", to_tmp); + } else { + printnoln("\t"); + } + printnoln("t%d ( ", fnc); + for (int i = 0; i < n_args; i++) { + if (i) printnoln(", "); + printnoln("t%d ", args[i]); + } + println(") ;"); + return; + } + case ND_LABEL_VAL: assert(0); + case ND_CAS: assert(0); + case ND_EXCH: assert(0); + } + + int rhsc = count(), lhsc = count(); + gen_expr(node->rhs, rhsc); + gen_expr(node->lhs, lhsc); + + char *op = NULL; + switch (node->kind) { + case ND_ADD: op = "+"; break; + case ND_SUB: op = "-"; break; + case ND_MUL: op = "*"; break; + case ND_DIV: op = "/"; break; + case ND_MOD: op = "%"; break; + case ND_BITAND: op = "&"; break; + case ND_BITOR: op = "|"; break; + case ND_BITXOR: op = "^"; break; + case ND_EQ: op = "=="; break; + case ND_NE: op = "!="; break; + case ND_LT: op = "<"; break; + case ND_LE: op = "<="; break; + case ND_SHL: op = "<<"; break; + case ND_SHR: op = ">>"; break; + } + + decltmp(node->ty, to_tmp); + if (node->ty->base) { + println("\tt%d = PTR_BINARY ( t%d , %s , t%d ) ;", to_tmp, lhsc, op, rhsc); + } else { + println("\tt%d = BINARY ( t%d , %s , t%d ) ;", to_tmp, lhsc, op, rhsc); + } +} + +static void gen_stmt(Node *node) { + switch (node->kind) { + case ND_IF: { + int cond = count(), condfalse = count(), c = count(); + gen_expr(node->cond, cond); + decltmp(ty_int, condfalse); + println("\tt%d = ! t%d ;", condfalse, cond); + println("\tif ( t%d ) goto _L_else_%d ;", condfalse, c); + gen_stmt(node->then); + println("\tgoto _L_end_%d ;", c); + println("\t_L_else_%d :", c); + if (node->els) + gen_stmt(node->els); + println("\t_L_end_%d :", c); + return; + } + case ND_FOR: { + int c = count(), cond = count(), condfalse = count(); + if (node->init) + gen_stmt(node->init); + println("\t_L_begin_%d :", c); + if (node->cond) { + gen_expr(node->cond, cond); + decltmp(ty_int, condfalse); + println("\tt%d = ! t%d ;", condfalse, cond); + printnoln("\tif ( t%d ) goto ", condfalse); + print_label(node->brk_label); + println(" ;"); + } + gen_stmt(node->then); + printnoln("\t"); print_label(node->cont_label); println(" :"); + if (node->inc) + gen_expr(node->inc, count()); + println("\tgoto _L_begin_%d ;", c); + printnoln("\t"); print_label(node->brk_label); println(" :"); + return; + } + case ND_DO: { + int c = count(), cond = count(); + println("\t_L_begin_%d :", c); + gen_stmt(node->then); + printnoln("\t"); print_label(node->cont_label); println(" :"); + gen_expr(node->cond, cond); + println("\tif ( t%d ) goto _L_begin_%d ;", cond, c); + printnoln("\t"); print_label(node->brk_label); println(" :"); + return; + } + case ND_SWITCH: { + int cond = count(), eqtmp = count(); + gen_expr(node->cond, cond); + decltmp(ty_int, eqtmp); + + for (Node *n = node->case_next; n; n = n->case_next) { + if (n->begin == n->end) { + println("\tt%d = BINARY ( t%d , == , %ld ) ;", eqtmp, cond, n->begin); + printnoln("\tif ( t%d ) goto ", eqtmp); + print_label(n->label); println(" ;"); + continue; + } + + // [GNU] Case ranges + assert(!"unimplemented"); + } + + if (node->default_case) { + printnoln("\tgoto "); print_label(node->default_case->label); println(" ;"); + } + + printnoln("\tgoto "); print_label(node->brk_label); println(" ;"); + gen_stmt(node->then); + printnoln("\t"); print_label(node->brk_label); println(" :"); + return; + } + case ND_CASE: + printnoln("\t"); print_label(node->label); println(" :"); + gen_stmt(node->lhs); + return; + case ND_BLOCK: + for (Node *n = node->body; n; n = n->next) + gen_stmt(n); + return; + case ND_GOTO: + printnoln("\tgoto "); print_label(node->unique_label); println(" ;"); + return; + case ND_GOTO_EXPR: assert(0); return; + case ND_LABEL: + printnoln("\t"); print_label(node->unique_label); println(" :"); + gen_stmt(node->lhs); + return; + case ND_RETURN: + if (node->lhs) { + int expr = count(); + gen_expr(node->lhs, expr); + println("\tt%d = t%d ;", RETURN_TMP, expr); + } + println("\tgoto _L_RETURN ;"); + return; + case ND_EXPR_STMT: + gen_expr(node->lhs, count()); + return; + case ND_ASM: assert(0); return; + } + + error_tok(node->tok, "invalid statement"); +} + +// Assign offsets to local variables. +static void assign_lvar_offsets(Obj *prog) { + for (Obj *fn = prog; fn; fn = fn->next) { + if (!fn->is_function) + continue; + + // If a function has many parameters, some parameters are + // inevitably passed by stack rather than by register. + // The first passed-by-stack parameter resides at RBP+16. + int c = 1; + + for (Obj *var = fn->params; var; var = var->next) + var->offset = c++; + for (Obj *var = fn->locals; var; var = var->next) + if (!(var->offset)) + var->offset = c++; + } +} + +void emit_constant(int pos, char *data, Relocation **rel, Type *type) { + switch (type->kind) { + case TY_STRUCT: { + int i = 0; + printnoln("{ "); + for (struct Member *m = type->members; m; m = m->next, i++) { + if (i) printnoln(", "); + assert(!m->is_bitfield); + emit_constant(pos + m->offset, data, rel, m->ty); + } + printnoln("} "); + break; + } + case TY_UNION: { + struct Member *biggest = type->members; + for (struct Member *m = type->members; m; m = m->next) + if (m->ty->size > biggest->ty->size) biggest = m; + assert(biggest); + printnoln("{ "); + emit_constant(pos, data, rel, biggest->ty); + printnoln("} "); + break; + } + case TY_ARRAY: { + printnoln("{ "); + for (int i = 0; i < type->array_len; i++) { + if (i) printnoln(", "); + emit_constant(pos, data, rel, type->base); + pos += type->base->size; + } + printnoln("} "); + break; + } + case TY_CHAR: + assert(type->size == 1); + // we don't output ' ' so lexing can be done by splitting on whitespace + if (data[pos] > ' ' && data[pos] <= '~' && data[pos] != '\'' && data[pos] != '\\') + printnoln("'%c' ", data[pos++]); + else + printnoln("%d ", (int)data[pos++]); + break; + case TY_BOOL: + assert(type->size == 1); + printnoln("%d ", *((char*)(data+pos))); + break; + case TY_SHORT: + assert(type->size == sizeof(short)); + printnoln("%d ", *((short*)(data+pos))); + break; + case TY_INT: + assert(type->size == sizeof(int)); + printnoln("%d ", *((int*)(data+pos))); + break; + case TY_LONG: + assert(type->size == sizeof(long)); + printnoln("%ld ", *((long*)(data+pos))); + break; + case TY_FLOAT: { + assert(type->size == sizeof(float)); + float v = *((float*)(data + pos)); + if (v != v) printnoln("0./0. "); + else if (v == 1./0.) printnoln("1./0. "); + else if (-v == 1./0.) printnoln("-1./0. "); + else printnoln("%f ", v); + break; + } case TY_DOUBLE: { + assert(type->size == sizeof(double)); + double v = *((double*)(data + pos)); + if (v != v) printnoln("0./0. "); + else if (v == 1./0.) printnoln("1./0. "); + else if (-v == 1./0.) printnoln("-1./0. "); + else printnoln("%lf ", v); + break; + } case TY_LDOUBLE: { + assert(type->size == sizeof(long double)); + long double v = *((long double*)(data + pos)); + if (v != v) printnoln("0./0. "); + else if (v == 1./0.) printnoln("1./0. "); + else if (-v == 1./0.) printnoln("-1./0. "); + else printnoln("%Lf ", v); + break; + } case TY_PTR: + assert(type->size == 8); + if (*rel && (*rel)->offset == pos) { + // TODO: should addend be divided by (*rel)->ty->size? + printnoln("( void * ) ( & "); + if (strchr(*((*rel)->label), '.')) { + printnoln("_"); + for (char *c = *((*rel)->label); *c; c++) { + if (*c == '.') printnoln("_"); + else printnoln("%c", *c); + } + printnoln("_0 ) + %ld ", (*rel)->addend); + } else { + printnoln("%s ) + %ld ", *((*rel)->label), (*rel)->addend); + } + *rel = (*rel)->next; + } else { + // TODO: actually read the absolute address out + printnoln("( void * ) ( %ld ) + 0 ", *((uint64_t*)(data+pos))); + } + break; + case TY_ENUM: + switch (type->size) { + case sizeof(int): + printnoln("%d ", *((int*)(data+pos))); + break; + case sizeof(long): + printnoln("%ld ", *((long*)(data+pos))); + break; + default: assert(0); + } + break; + default: fprintf(stderr, "Got: %d\n", type->kind); assert(0); + } +} + +static void emit_data(Obj *prog) { + for (Obj *var = prog; var; var = var->next) { + if (var->is_static && var->is_function) { + printnoln("static Type_%d ", var->ty->id); + print_obj(var); println(" ;"); + } else if (var->is_static) { + printnoln("static Type_%d ", var->ty->id); + print_obj(var); println(" ;"); + } else { + printnoln("extern Type_%d ", var->ty->id); + print_obj(var); println(" ;"); + } + } + + for (Obj *var = prog; var; var = var->next) { + if (var->is_function || !var->is_definition) + continue; + + if (var->is_static) printnoln("static "); + print_type(var->ty); printnoln(" "); print_obj(var); + + // Common symbol + assert(!(opt_fcommon && var->is_tentative)); + + // .data or .tdata + if (var->init_data) { + printnoln(" = "); + Relocation *rel = var->rel; + emit_constant(0, var->init_data, &rel, var->ty); + } + println(";"); + } +} + +static void emit_text(Obj *prog) { + for (Obj *fn = prog; fn; fn = fn->next) { + if (!fn->is_function || !fn->is_definition) + continue; + + // No code is emitted for "static inline" functions + // if no one is referencing them. + if (!fn->is_live) + continue; + + current_fn = fn; + + if (fn->is_static) + printnoln("static "); + + print_type(fn->ty->return_ty); printnoln(" "); + printnoln("%s ( ", fn->name); + int p = 0; + for (Obj *param = fn->params; param; param = param->next, p++) { + if (p) printnoln(", "); + print_type(param->ty); printnoln(" "); print_obj(param); + param->is_param = 1; + printnoln(" "); + } + if (fn->va_area) { + if (p) printnoln(", ... "); + // else printnoln("... "); + } + println(") {"); + + RETURN_TMP = 0; + if (fn->ty->return_ty->kind != TY_VOID) { + RETURN_TMP = count(); + decltmp(fn->ty->return_ty, RETURN_TMP); + } + + for (Obj *var = fn->locals; var; var = var->next) { + if (!(var->ty->id)) continue; + if (var->is_param) continue; + if (var == fn->va_area) continue; + if (var == fn->alloca_bottom) continue; + printnoln("\t"); + print_type(var->ty); + printnoln(" "); + print_obj(var); + println(" ;"); + } + + // Emit code + DO_BUFFER = 1; + gen_stmt(fn->body); + assert(depth == 0); + flush_buffer(); + DO_BUFFER = 0; + + println("\t_L_RETURN :"); + if (fn->ty->return_ty->kind != TY_VOID) { + println("\treturn t%d ;", RETURN_TMP); + } else { + println("\treturn ;"); + } + + println("}"); + } +} + +void codegen(Obj *prog, FILE *out) { + output_file = out; + + BUFFER = tmpfile(); + assign_lvar_offsets(prog); + emit_data(prog); + emit_text(prog); +} -- cgit v1.2.3