#include <jit/jit.h>
#include <ruby.h>

#include <limits.h>
#include <stdlib.h>

/* The LittleJIT module and its LittleJIT::Function class. */
VALUE rb_mJIT, rb_cFunction;

/* Data wrapped by every LittleJIT::Function instance: a JIT function
 * together with the JIT context it was created in.
 */
struct Function {
  jit_context_t context;   /* context owning the function; function_free destroys it */
  jit_function_t function; /* function created in `context` and built by #compile */
};

/* Abstraction to create a function that takes num_args integer
 * arguments and returns an integer
 */
static struct Function * create_function(unsigned long num_args)
{
  jit_type_t * param_types = ALLOCA_N(jit_type_t, num_args);
  unsigned long j;

  jit_type_t signature;
  jit_context_t context;
  jit_function_t function;
  struct Function * f;

  for(j = 0; j < num_args; ++j)
  {
    param_types[j] = jit_type_int;
  }

  signature = jit_type_create_signature(
      jit_abi_cdecl, /* C calling convention */
      jit_type_int,  /* Return type */
      param_types,   /* Parameter types */
      num_args,      /* Number of parameters */
      1);            /* Increment reference count */

  context = jit_context_create();
  function = jit_function_create(context, signature);

  f = ALLOC_N(struct Function, 1);
  f->context = context;
  f->function = function;

  return f;
}

/* GC free function for LittleJIT::Function objects.
 *
 * Destroys the JIT context, and with it the function created in that
 * context. Then releases the holder struct itself, which was allocated
 * with Ruby's allocator (ALLOC / ALLOC_N) and so must be freed with xfree.
 */
static void function_free(struct Function * function)
{
  jit_context_destroy(function->context);
  xfree(function);
}

/* LittleJIT::Function.new(num_args)
 *
 * Create a function that takes num_args integer arguments and returns an
 * integer. The receiving class (self) is used for the new object, so
 * subclasses of LittleJIT::Function get instances of themselves.
 *
 * Raises ArgumentError if num_args is negative, plus whatever
 * create_function raises.
 */
static VALUE function_s_new(VALUE self, VALUE num_args_v)
{
  long num_args = NUM2LONG(num_args_v);
  VALUE obj;

  /* NUM2ULONG would silently wrap a negative count to a huge unsigned
   * value; reject it explicitly instead.
   */
  if(num_args < 0)
  {
    rb_raise(rb_eArgError, "negative argument count (%ld)", num_args);
  }

  /* Create the Ruby wrapper with a NULL payload first. If allocating the
   * object raises, no JIT context exists yet to leak. Ruby does not call
   * the free function while DATA_PTR is NULL, so a failure inside
   * create_function is also safe.
   */
  obj = Data_Wrap_Struct(self, 0, function_free, 0);
  DATA_PTR(obj) = create_function((unsigned long)num_args);

  return obj;
}

/* State shared between the body and the ensure clause of
 * function_compile. rb_ensure passes it to both as a VALUE.
 */
struct compile_state
{
  VALUE self;          /* the LittleJIT::Function being built */
  struct Function * f; /* the struct wrapped by self */
  int compiled;        /* set to 1 once jit_function_compile succeeds */
};

/* rb_ensure body: yield the function to the caller's block so it can emit
 * instructions, then compile it. function_compile has already locked the
 * build context before this runs.
 *
 * Raises RuntimeError if jit_function_compile fails. Exceptions from the
 * block propagate unchanged.
 */
static VALUE function_yield_and_compile(VALUE arg)
{
  struct compile_state * state = (struct compile_state *)arg;

  rb_yield(state->self);

  if(!jit_function_compile(state->f->function))
  {
    rb_raise(rb_eRuntimeError, "Unable to compile function");
  }

  state->compiled = 1;
  return Qnil;
}

/* rb_ensure clause: abandon the build unless compilation succeeded, then
 * unlock the build context.
 *
 * Success is tracked with an explicit flag rather than by inspecting
 * ruby_errinfo, for three reasons:
 * - errinfo can be stale (e.g. when compile is called inside a rescue
 *   clause);
 * - it is not set on non-exception exits such as break or throw;
 * - it is a Ruby 1.8-only global.
 */
static VALUE function_abandon_if_exception(VALUE arg)
{
  struct compile_state * state = (struct compile_state *)arg;
  struct Function * f = state->f;

  if(!state->compiled)
  {
    /* Abandoning a function that was never compiled may destroy it
     * entirely, so conservatively forget the pointer. A later compile then
     * raises instead of touching freed memory. A previously compiled
     * function keeps its old code and stays usable.
     */
    int was_compiled = jit_function_is_compiled(f->function);

    jit_function_abandon(f->function);
    if(!was_compiled)
    {
      f->function = NULL;
    }
  }

  jit_context_build_end(f->context);
  return Qnil;
}

/* LittleJIT::Function#compile { |function| ... }
 *
 * Build a function:
 * - Lock the build context
 * - Yield self to the caller's block
 * - Compile the function with jit_function_compile
 * - Abandon the build if the steps above did not complete
 * - Unlock the build context
 *
 * Returns self. Raises RuntimeError if compilation fails, or if an earlier
 * failed build already discarded the function.
 */
static VALUE function_compile(VALUE self)
{
  struct compile_state state;

  Data_Get_Struct(self, struct Function, state.f);
  if(state.f->function == NULL)
  {
    rb_raise(rb_eRuntimeError,
             "Function was discarded by an earlier failed build");
  }
  state.self = self;
  state.compiled = 0;

  /* Lock before entering rb_ensure, so the ensure clause only ever
   * unlocks a context that this call actually locked.
   */
  jit_context_build_start(state.f->context);

  rb_ensure(
      function_yield_and_compile,
      (VALUE)&state,
      function_abandon_if_exception,
      (VALUE)&state);

  return self;
}

/* Extension entry point, called by Ruby when littlejit is loaded.
 *
 * Defines the LittleJIT module and LittleJIT::Function, with the singleton
 * method Function.new(num_args) and the instance method Function#compile.
 * Uses a (void) prototype rather than an old-style empty parameter list.
 */
void Init_littlejit(void)
{
  rb_mJIT = rb_define_module("LittleJIT");
  rb_cFunction = rb_define_class_under(rb_mJIT, "Function", rb_cObject);
  rb_define_singleton_method(rb_cFunction, "new", function_s_new, 1);
  rb_define_method(rb_cFunction, "compile", function_compile, 0);
}

