jitters: jit: add JIT function

This commit is contained in:
Bruno BELANYI 2020-09-29 18:21:43 +02:00
parent 938dddc43a
commit 1a1148b2b2
6 changed files with 292 additions and 6 deletions

190
src/jit/jitter.c Normal file
View file

@ -0,0 +1,190 @@
#include "jitter.h"
#include <err.h>
#include <string.h>
#include <sys/mman.h>
#include "vector.h"
static void jit_ast_internal(const struct ast_node *ast, struct vector *buf);
static void jit_num(int num, struct vector *buf)
{
// movl $num, %eax
// 08 NUM's value
append_vector(buf, 0xb8);
for (int i = 0; i < 4; ++i)
{
const unsigned char c = num & 0xff;
num >>= 8;
append_vector(buf, c);
}
}
static void handle_division(const struct binop_node *bin_op, struct vector *buf)
{
const unsigned char handle_rhs[] = {
0x48, 0x83, 0xec, 0x04, // sub $0x4,%rsp
0x89, 0x04, 0x24, // mov %eax,(%rsp)
};
jit_ast_internal(bin_op->rhs, buf);
append_array(buf, handle_rhs, ARR_SIZE(handle_rhs));
const unsigned char handle_lhs[] = {
0x8b, 0x0c, 0x24, // mov (%rsp),%ecx
0x48, 0x83, 0xc4, 0x04, // add $0x4,%rsp
};
jit_ast_internal(bin_op->lhs, buf);
append_array(buf, handle_lhs, ARR_SIZE(handle_lhs));
const unsigned char handle_div[] = {
0x99, // cltd
0xf7, 0xf9, // idiv %ecx
};
append_array(buf, handle_div, ARR_SIZE(handle_div));
}
static void jit_binop(const struct binop_node* bin_op, struct vector *buf)
{
if (bin_op->op == DIVIDE)
{
handle_division(bin_op, buf);
return;
}
const unsigned char handle_lhs[] = {
0x48, 0x83, 0xec, 0x04, // sub $0x4,%rsp
0x89, 0x04, 0x24, // mov %eax,(%rsp)
};
jit_ast_internal(bin_op->lhs, buf);
append_array(buf, handle_lhs, ARR_SIZE(handle_lhs));
const unsigned char handle_rhs[] = {
0x8b, 0x3c, 0x24, // mov (%rsp),%edi
0x48, 0x83, 0xc4, 0x04, // add $0x4,%rsp
};
jit_ast_internal(bin_op->rhs, buf);
append_array(buf, handle_rhs, ARR_SIZE(handle_rhs));
const unsigned char handle_plus[] = {
0x01, 0xf8, // add %edi,%eax
};
const unsigned char handle_minus[] = {
0x29, 0xf8, // sub %edi,%eax
};
const unsigned char handle_times[] = {
0x0f, 0xaf, 0xc7, // imul %edi,%eax
};
switch (bin_op->op)
{
case PLUS:
append_array(buf, handle_plus, ARR_SIZE(handle_plus));
break;
case MINUS:
append_array(buf, handle_minus, ARR_SIZE(handle_minus));
break;
case TIMES:
append_array(buf, handle_times, ARR_SIZE(handle_times));
break;
case DIVIDE:
default:
/* Not handled */
break;
}
}
static void jit_unop(const struct unop_node* un_op, struct vector *buf)
{
jit_ast_internal(un_op->rhs, buf);
const unsigned char handle_negate[] = {
0x6b, 0xc0, 0xff // imul $-1,%eax,%eax
};
switch (un_op->op)
{
case NEGATE:
append_array(buf, handle_negate, ARR_SIZE(handle_negate));
break;
case IDENTITY:
default:
/* Nothing to do */
break;
}
}
static void jit_ast_internal(const struct ast_node *ast, struct vector *buf)
{
switch (ast->kind)
{
case BINOP:
jit_binop(&ast->val.bin_op, buf);
break;
case UNOP:
jit_unop(&ast->val.un_op, buf);
break;
case NUM:
jit_num(ast->val.num, buf);
break;
default:
break;
}
}
static void jit_ast(const struct ast_node *ast, struct vector *buf)
{
const unsigned char prologue[] = {
0x55, // push %rbp
0x48, 0x89, 0xe5, // mov %rsp,%rbp
};
append_array(buf, prologue, ARR_SIZE(prologue));
jit_ast_internal(ast, buf);
const unsigned char epilogue[] = {
0x5d, // pop %rbp
0xc3, // retq
};
append_array(buf, epilogue, ARR_SIZE(epilogue));
}
#define WRITE_PROT (PROT_READ | PROT_WRITE)
#define WRITE_FLAGS (MAP_ANONYMOUS | MAP_PRIVATE)
static int exec_buf(const void *buf, size_t len)
{
// Copy our JIT-ed instructions into a new buffer
void *ptr = mmap(NULL, len, WRITE_PROT, WRITE_FLAGS, 0, 0);
if (ptr == MAP_FAILED)
err(1, NULL);
memcpy(ptr, buf, len);
// Make the JIT-ed function executable
if (mprotect(ptr, len, PROT_READ | PROT_EXEC) < 0)
err(1, NULL);
// Cast our pointer into a function pointer
// Thanks to (-Wpedantic) we need to do this pragma song and dance...
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wpedantic"
int (*func)(void) = ptr;
#pragma GCC diagnostic pop
// Call the JIT-ed function
int ret = func();
// Un-map the JIT-ed function
if (munmap(ptr, len) < 0)
err(1, NULL);
return ret;
}
int jit_eval_ast(const struct ast_node *ast)
{
struct vector *buf = create_vector(1);
jit_ast(ast, buf);
int ret = exec_buf(buf->buf, buf->size);
destroy_vector(buf);
return ret;
}

8
src/jit/jitter.h Normal file
View file

@ -0,0 +1,8 @@
#ifndef JITTER_H
#define JITTER_H
#include "ast/ast.h"
int jit_eval_ast(const struct ast_node *ast);
#endif /* !JITTER_H */

5
src/jit/local.am Normal file
View file

@ -0,0 +1,5 @@
jitters_SOURCES += \
%D%/jitter.c \
%D%/jitter.h \
%D%/vector.h \
$(NULL)

74
src/jit/vector.h Normal file
View file

@ -0,0 +1,74 @@
#ifndef VECTOR_H
#define VECTOR_H
#include <err.h>
#include <stddef.h>
#include <stdlib.h>
struct vector
{
unsigned char *buf;
size_t size;
size_t capacity;
};
static inline void *xmalloc(size_t size)
{
void *ret = malloc(size);
if (ret == NULL)
err(1, NULL);
return ret;
}
static inline void *xrealloc(void *ptr, size_t size)
{
void *ret = realloc(ptr, size);
if (ret == NULL)
{
free(ptr);
err(1, NULL);
}
return ret;
}
static inline struct vector *create_vector(size_t init_cap)
{
struct vector *vec = xmalloc(sizeof(*vec));
init_cap = (init_cap ? init_cap : 1);
vec->buf = xmalloc(init_cap);
vec->size = 0;
vec->capacity = init_cap;
return vec;
}
static inline void destroy_vector(struct vector *vec)
{
free(vec->buf);
free(vec);
}
static inline void append_vector(struct vector *vec, unsigned char c)
{
if (vec->size >= vec->capacity)
vec->buf = xrealloc(vec->buf, vec->capacity *= 2);
vec->buf[vec->size++] = c;
}
#define ARR_SIZE(ARR) (sizeof((ARR)) / sizeof((ARR)[0]))
static inline void append_array(
struct vector *buf, const unsigned char *arr, size_t n)
{
for (size_t i = 0; i < n; ++i)
append_vector(buf, arr[i]);
}
#endif /* !VECTOR_H */

View file

@ -10,6 +10,7 @@
#include "ast/ast.h" #include "ast/ast.h"
#include "compile/compiler.h" #include "compile/compiler.h"
#include "eval/evaluator.h" #include "eval/evaluator.h"
#include "jit/jitter.h"
#include "parse/parse-jitters.h" #include "parse/parse-jitters.h"
#include "print/printer.h" #include "print/printer.h"
@ -21,11 +22,12 @@ static char doc[] =
static struct argp_option options[] = { static struct argp_option options[] = {
{"evaluate", 'e', 0, 0, "Evaluate input by walking the tree" }, {"evaluate", 'e', 0, 0, "Evaluate input by walking the tree", 0 },
{"compile", 'c', 0, 0, "Compile input to assembly" }, {"compile", 'c', 0, 0, "Compile input to assembly", 0 },
{"print", 'p', 0, 0, "Print parsed expression" }, {"jit", 'j', 0, 0, "JIT-compile input and evaluate", 0 },
{"output", 'o', "FILE", 0, "Output to FILE instead of standard output" }, {"print", 'p', 0, 0, "Print parsed expression", 0 },
{"debug", 'd', 0, 0, "Emit debug output from parser" }, {"output", 'o', "FILE", 0, "Output to FILE instead of standard output", 0 },
{"debug", 'd', 0, 0, "Emit debug output from parser", 0 },
{ 0 } { 0 }
}; };
@ -33,6 +35,7 @@ struct arguments
{ {
bool debug; // Whether to activate Bison using debug trace bool debug; // Whether to activate Bison using debug trace
bool compile; // Whether to compile the input bool compile; // Whether to compile the input
bool jit; // Whether to JIT the input
bool evaluate; // Whether to evaluate the input bool evaluate; // Whether to evaluate the input
bool print; // Whether to print the input bool print; // Whether to print the input
const char *output_file; // Where to output const char *output_file; // Where to output
@ -53,6 +56,9 @@ static error_t parse_opt(int key, char *arg, struct argp_state *state)
case 'e': case 'e':
arguments->evaluate = true; arguments->evaluate = true;
break; break;
case 'j':
arguments->jit = true;
break;
case 'p': case 'p':
arguments->print = true; arguments->print = true;
break; break;
@ -70,7 +76,7 @@ static error_t parse_opt(int key, char *arg, struct argp_state *state)
return 0; return 0;
} }
static struct argp argp = { options, parse_opt, 0, doc }; static struct argp argp = { options, parse_opt, 0, doc, 0, 0, 0, };
int main(int argc, char *argv[]) int main(int argc, char *argv[])
{ {
@ -91,6 +97,8 @@ int main(int argc, char *argv[])
if ((ret = yyparse(&ast)) == 0) { if ((ret = yyparse(&ast)) == 0) {
if (arguments.compile) if (arguments.compile)
compile_ast(ast, output); compile_ast(ast, output);
if (arguments.jit)
fprintf(output, "%d\n", jit_eval_ast(ast));
if (arguments.evaluate) if (arguments.evaluate)
fprintf(output, "%d\n", evaluate_ast(ast)); fprintf(output, "%d\n", evaluate_ast(ast));
if (arguments.print) if (arguments.print)

View file

@ -17,5 +17,6 @@ jitters_LDADD =
include %D%/ast/local.am include %D%/ast/local.am
include %D%/compile/local.am include %D%/compile/local.am
include %D%/eval/local.am include %D%/eval/local.am
include %D%/jit/local.am
include %D%/parse/local.am include %D%/parse/local.am
include %D%/print/local.am include %D%/print/local.am