diff --git a/src/jit/jitter.c b/src/jit/jitter.c new file mode 100644 index 0000000..d623c62 --- /dev/null +++ b/src/jit/jitter.c @@ -0,0 +1,190 @@ +#include "jitter.h" + +#include +#include +#include + +#include "vector.h" + +static void jit_ast_internal(const struct ast_node *ast, struct vector *buf); + +static void jit_num(int num, struct vector *buf) +{ + // movl $num, %eax + // 08 NUM's value + append_vector(buf, 0xb8); + for (int i = 0; i < 4; ++i) + { + const unsigned char c = num & 0xff; + num >>= 8; + append_vector(buf, c); + } +} + +static void handle_division(const struct binop_node *bin_op, struct vector *buf) +{ + const unsigned char handle_rhs[] = { + 0x48, 0x83, 0xec, 0x04, // sub $0x4,%rsp + 0x89, 0x04, 0x24, // mov %eax,(%rsp) + }; + jit_ast_internal(bin_op->rhs, buf); + append_array(buf, handle_rhs, ARR_SIZE(handle_rhs)); + + const unsigned char handle_lhs[] = { + 0x8b, 0x0c, 0x24, // mov (%rsp),%ecx + 0x48, 0x83, 0xc4, 0x04, // add $0x4,%rsp + }; + jit_ast_internal(bin_op->lhs, buf); + append_array(buf, handle_lhs, ARR_SIZE(handle_lhs)); + + const unsigned char handle_div[] = { + 0x99, // cltd + 0xf7, 0xf9, // idiv %ecx + }; + append_array(buf, handle_div, ARR_SIZE(handle_div)); +} + +static void jit_binop(const struct binop_node* bin_op, struct vector *buf) +{ + + if (bin_op->op == DIVIDE) + { + handle_division(bin_op, buf); + return; + } + + const unsigned char handle_lhs[] = { + 0x48, 0x83, 0xec, 0x04, // sub $0x4,%rsp + 0x89, 0x04, 0x24, // mov %eax,(%rsp) + }; + jit_ast_internal(bin_op->lhs, buf); + append_array(buf, handle_lhs, ARR_SIZE(handle_lhs)); + + const unsigned char handle_rhs[] = { + 0x8b, 0x3c, 0x24, // mov (%rsp),%edi + 0x48, 0x83, 0xc4, 0x04, // add $0x4,%rsp + }; + jit_ast_internal(bin_op->rhs, buf); + append_array(buf, handle_rhs, ARR_SIZE(handle_rhs)); + + const unsigned char handle_plus[] = { + 0x01, 0xf8, // add %edi,%eax + }; + const unsigned char handle_minus[] = { + 0x29, 0xf8, // sub %edi,%eax + }; + const unsigned char handle_times[] = { + 0x0f, 0xaf, 0xc7, // imul %edi,%eax + }; + switch (bin_op->op) + { + case PLUS: + append_array(buf, handle_plus, ARR_SIZE(handle_plus)); + break; + case MINUS: + append_array(buf, handle_minus, ARR_SIZE(handle_minus)); + break; + case TIMES: + append_array(buf, handle_times, ARR_SIZE(handle_times)); + break; + case DIVIDE: + default: + /* Not handled */ + break; + } +} + +static void jit_unop(const struct unop_node* un_op, struct vector *buf) +{ + jit_ast_internal(un_op->rhs, buf); + + const unsigned char handle_negate[] = { + 0x6b, 0xc0, 0xff // imul $-1,%eax,%eax + }; + switch (un_op->op) + { + case NEGATE: + append_array(buf, handle_negate, ARR_SIZE(handle_negate)); + break; + case IDENTITY: + default: + /* Nothing to do */ + break; + } +} + +static void jit_ast_internal(const struct ast_node *ast, struct vector *buf) +{ + switch (ast->kind) + { + case BINOP: + jit_binop(&ast->val.bin_op, buf); + break; + case UNOP: + jit_unop(&ast->val.un_op, buf); + break; + case NUM: + jit_num(ast->val.num, buf); + break; + default: + break; + } +} + +static void jit_ast(const struct ast_node *ast, struct vector *buf) +{ + const unsigned char prologue[] = { + 0x55, // push %rbp + 0x48, 0x89, 0xe5, // mov %rsp,%rbp + }; + append_array(buf, prologue, ARR_SIZE(prologue)); + + jit_ast_internal(ast, buf); + + const unsigned char epilogue[] = { + 0x5d, // pop %rbp + 0xc3, // retq + }; + append_array(buf, epilogue, ARR_SIZE(epilogue)); +} + +#define WRITE_PROT (PROT_READ | PROT_WRITE) +#define WRITE_FLAGS (MAP_ANONYMOUS | MAP_PRIVATE) +static int exec_buf(const void *buf, size_t len) +{ + // Copy our JIT-ed instructions into a new buffer + void *ptr = mmap(NULL, len, WRITE_PROT, WRITE_FLAGS, 0, 0); + if (ptr == MAP_FAILED) + err(1, NULL); + memcpy(ptr, buf, len); + + // Make the JIT-ed function executable + if (mprotect(ptr, len, PROT_READ | PROT_EXEC) < 0) + err(1, NULL); + + // Cast our pointer into a function pointer + // Thanks to (-Wpedantic) we need to do this pragma song and dance... +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wpedantic" + int (*func)(void) = ptr; +#pragma GCC diagnostic pop + // Call the JIT-ed function + int ret = func(); + + // Un-map the JIT-ed function + if (munmap(ptr, len) < 0) + err(1, NULL); + + return ret; +} + +int jit_eval_ast(const struct ast_node *ast) +{ + struct vector *buf = create_vector(1); + jit_ast(ast, buf); + + int ret = exec_buf(buf->buf, buf->size); + + destroy_vector(buf); + return ret; +} diff --git a/src/jit/jitter.h b/src/jit/jitter.h new file mode 100644 index 0000000..41477a7 --- /dev/null +++ b/src/jit/jitter.h @@ -0,0 +1,8 @@ +#ifndef JITTER_H +#define JITTER_H + +#include "ast/ast.h" + +int jit_eval_ast(const struct ast_node *ast); + +#endif /* !JITTER_H */ diff --git a/src/jit/local.am b/src/jit/local.am new file mode 100644 index 0000000..6d75bbb --- /dev/null +++ b/src/jit/local.am @@ -0,0 +1,5 @@ +jitters_SOURCES += \ + %D%/jitter.c \ + %D%/jitter.h \ + %D%/vector.h \ + $(NULL) diff --git a/src/jit/vector.h b/src/jit/vector.h new file mode 100644 index 0000000..a3e22cf --- /dev/null +++ b/src/jit/vector.h @@ -0,0 +1,74 @@ +#ifndef VECTOR_H +#define VECTOR_H + +#include +#include +#include + +struct vector +{ + unsigned char *buf; + size_t size; + size_t capacity; +}; + +static inline void *xmalloc(size_t size) +{ + void *ret = malloc(size); + + if (ret == NULL) + err(1, NULL); + + return ret; +} + +static inline void *xrealloc(void *ptr, size_t size) +{ + void *ret = realloc(ptr, size); + + if (ret == NULL) + { + free(ptr); + err(1, NULL); + } + + return ret; +} + +static inline struct vector *create_vector(size_t init_cap) +{ + struct vector *vec = xmalloc(sizeof(*vec)); + + init_cap = (init_cap ? init_cap : 1); + + vec->buf = xmalloc(init_cap); + vec->size = 0; + vec->capacity = init_cap; + + return vec; +} + +static inline void destroy_vector(struct vector *vec) +{ + free(vec->buf); + free(vec); +} + +static inline void append_vector(struct vector *vec, unsigned char c) +{ + if (vec->size >= vec->capacity) + vec->buf = xrealloc(vec->buf, vec->capacity *= 2); + + vec->buf[vec->size++] = c; +} + +#define ARR_SIZE(ARR) (sizeof((ARR)) / sizeof((ARR)[0])) + +static inline void append_array( + struct vector *buf, const unsigned char *arr, size_t n) +{ + for (size_t i = 0; i < n; ++i) + append_vector(buf, arr[i]); +} + +#endif /* !VECTOR_H */ diff --git a/src/jitters.c b/src/jitters.c index 3ef248e..12a5e9e 100644 --- a/src/jitters.c +++ b/src/jitters.c @@ -10,6 +10,7 @@ #include "ast/ast.h" #include "compile/compiler.h" #include "eval/evaluator.h" +#include "jit/jitter.h" #include "parse/parse-jitters.h" #include "print/printer.h" @@ -21,11 +22,12 @@ static char doc[] = static struct argp_option options[] = { - {"evaluate", 'e', 0, 0, "Evaluate input by walking the tree" }, - {"compile", 'c', 0, 0, "Compile input to assembly" }, - {"print", 'p', 0, 0, "Print parsed expression" }, - {"output", 'o', "FILE", 0, "Output to FILE instead of standard output" }, - {"debug", 'd', 0, 0, "Emit debug output from parser" }, + {"evaluate", 'e', 0, 0, "Evaluate input by walking the tree", 0 }, + {"compile", 'c', 0, 0, "Compile input to assembly", 0 }, + {"jit", 'j', 0, 0, "JIT-compile input and evaluate", 0 }, + {"print", 'p', 0, 0, "Print parsed expression", 0 }, + {"output", 'o', "FILE", 0, "Output to FILE instead of standard output", 0 }, + {"debug", 'd', 0, 0, "Emit debug output from parser", 0 }, { 0 } }; @@ -33,6 +35,7 @@ struct arguments { bool debug; // Whether to activate Bison using debug trace bool compile; // Whether to compile the input + bool jit; // Whether to JIT the input bool evaluate; // Whether to evaluate the input bool print; // Whether to print the input const char *output_file; // Where to output @@ -53,6 +56,9 @@ static error_t parse_opt(int key, char *arg, struct argp_state *state) case 'e': arguments->evaluate = true; break; + case 'j': + arguments->jit = true; + break; case 'p': arguments->print = true; break; @@ -70,7 +76,7 @@ static error_t parse_opt(int key, char *arg, struct argp_state *state) return 0; } -static struct argp argp = { options, parse_opt, 0, doc }; +static struct argp argp = { options, parse_opt, 0, doc, 0, 0, 0, }; int main(int argc, char *argv[]) { @@ -91,6 +97,8 @@ int main(int argc, char *argv[]) if ((ret = yyparse(&ast)) == 0) { if (arguments.compile) compile_ast(ast, output); + if (arguments.jit) + fprintf(output, "%d\n", jit_eval_ast(ast)); if (arguments.evaluate) fprintf(output, "%d\n", evaluate_ast(ast)); if (arguments.print) diff --git a/src/local.am b/src/local.am index 51e8d61..6248cf8 100644 --- a/src/local.am +++ b/src/local.am @@ -17,5 +17,6 @@ jitters_LDADD = include %D%/ast/local.am include %D%/compile/local.am include %D%/eval/local.am +include %D%/jit/local.am include %D%/parse/local.am include %D%/print/local.am