#include <jit/jit.h>
#include <ruby.h>

/*
 * Build and JIT-compile, inside `context`, a function with the C
 * signature `int f(int, int)` that returns the sum of its two parameters.
 *
 * context: the libjit context in which the function is created.  It is
 *          locked for the duration of the build and unlocked again before
 *          returning; it is NOT destroyed here, so the caller still has
 *          to dispose of it.
 *
 * Returns the compiled function, or NULL if the signature or the function
 * object could not be created, an instruction could not be emitted, or
 * compilation failed.
 */
static jit_function_t compile_add_function(jit_context_t context)
{
  jit_type_t param_types[2];
  jit_type_t signature;
  jit_function_t function;
  jit_value_t arg1, arg2, sum;
  int ok = 0;

  param_types[0] = jit_type_int;
  param_types[1] = jit_type_int;
  signature = jit_type_create_signature(
      jit_abi_cdecl, /* C calling convention */
      jit_type_int,  /* Return type */
      param_types,   /* Parameter types */
      2,             /* Number of parameters */
      1);            /* Increment reference count */
  if (!signature)
  {
    return NULL;
  }

  /* Lock the context while the function is being built */
  jit_context_build_start(context);

  /* Create the function object */
  function = jit_function_create(context, signature);

  /* jit_function_create takes its own reference to the signature, so our
   * reference must be released here or the type object is leaked on
   * every call. */
  jit_type_free(signature);

  if (!function)
  {
    goto done;
  }

  /* Construct the function body: return param0 + param1 */
  arg1 = jit_value_get_param(function, 0);
  arg2 = jit_value_get_param(function, 1);
  if (!arg1 || !arg2)
  {
    goto done;
  }

  sum = jit_insn_add(function, arg1, arg2);
  if (!sum || !jit_insn_return(function, sum))
  {
    goto done;
  }

  /* Compile the function; a zero result means compilation failed */
  if (!jit_function_compile(function))
  {
    goto done;
  }

  ok = 1;

done:
  /* Unlock the context on every path (it is only unlocked, not destroyed) */
  jit_context_build_end(context);

  return ok ? function : NULL;
}

/*
 * Ruby global function add(v1, v2): add two Ruby integers using a
 * function that is JIT-compiled with libjit on every call.
 *
 * self: receiver supplied by Ruby (unused).
 * v1, v2: Ruby numeric values; each is converted with NUM2INT, which
 *         raises a Ruby exception if the value cannot be converted to a
 *         C int.
 *
 * Returns the sum as a Ruby Integer (via INT2NUM).  The addition itself
 * is done on jit_int values by the compiled function.
 *
 * Raises RuntimeError if the JIT context cannot be created, the function
 * cannot be compiled, or the compiled function cannot be applied.
 */
static VALUE add(VALUE self, VALUE v1, VALUE v2)
{
  jit_context_t context;
  jit_function_t function;
  void *args[2];
  jit_int i1, i2;
  jit_int result = 0;
  int applied;

  /* Convert the arguments BEFORE allocating anything: NUM2INT may raise,
   * and a Ruby raise does not return here, so a context created earlier
   * would never be destroyed. */
  i1 = NUM2INT(v1);
  i2 = NUM2INT(v2);

  /* Create a context and compile the adder into it */
  context = jit_context_create();
  if (!context)
  {
    rb_raise(rb_eRuntimeError, "unable to create JIT context");
  }

  function = compile_add_function(context);
  if (!function)
  {
    /* Release the context first; rb_raise does not return */
    jit_context_destroy(context);
    rb_raise(rb_eRuntimeError, "unable to compile JIT add function");
  }

  /* Call the newly compiled function; libjit expects an array of
   * pointers to the argument values and a pointer to the return slot */
  args[0] = &i1;
  args[1] = &i2;
  applied = jit_function_apply(function, args, &result);

  /* Destroy the context now that the call is finished */
  jit_context_destroy(context);

  if (!applied)
  {
    rb_raise(rb_eRuntimeError, "JIT add function failed");
  }

  return INT2NUM(result);
}

/*
 * Extension entry point.  Registers the C function `add` as a global
 * Ruby function named "add" that takes exactly two arguments.
 *
 * Declared with an explicit (void) parameter list so the definition is a
 * proper prototype rather than an old-style unspecified-argument one.
 */
void Init_add(void)
{
  rb_define_global_function("add", add, 2);
}

