#include <stdlib.h>

#include <jit/jit.h>
#include <ruby.h>

/* Ruby module and classes exposed by this extension; assigned in
 * Init_littlejit. */
VALUE rb_mJIT;      /* LittleJIT */
VALUE rb_cFunction; /* LittleJIT::Function -- wraps a struct Function */
VALUE rb_cValue;    /* LittleJIT::Value -- wraps a libjit jit_value_t */

/* Holder for a libjit function and the context it lives in.
 * create_function allocates one (with a fresh context per function) and
 * function_s_new wraps it in a LittleJIT::Function whose GC free
 * callback is function_free.
 */
struct Function
{
  jit_context_t context;   /* context the function was created in */
  jit_function_t function; /* the function being built, compiled and applied */
};

/* Abstraction to create a function that takes num_args integer
 * arguments and returns an integer
 */
static struct Function * create_function(unsigned long num_args)
{
  jit_type_t * param_types = ALLOCA_N(jit_type_t, num_args);
  unsigned long j;

  jit_type_t signature;
  jit_context_t context;
  jit_function_t function;
  struct Function * f;

  for(j = 0; j < num_args; ++j)
  {
    param_types[j] = jit_type_int;
  }

  signature = jit_type_create_signature(
      jit_abi_cdecl, /* C calling convention */
      jit_type_int,  /* Return type */
      param_types,   /* Parameter types */
      num_args,      /* Number of parameters */
      1);            /* Increment reference count */

  context = jit_context_create();
  function = jit_function_create(context, signature);

  f = ALLOC_N(struct Function, 1);
  f->context = context;
  f->function = function;

  return f;
}

/* GC free callback for LittleJIT::Function objects.
 *
 * Destroys the libjit context (libjit documents jit_context_destroy as
 * releasing everything created in that context, including the function)
 * and then frees the holder itself. create_function allocated the holder
 * with Ruby's allocator, so it is released with xfree. Previously the
 * holder was never freed and every Function leaked it.
 */
static void function_free(struct Function * function)
{
  jit_context_destroy(function->context);
  xfree(function);
}

/* Function.new(num_args)
 *
 * Builds an empty libjit function taking num_args ints and returning an
 * int (see create_function). The result is wrapped in a
 * LittleJIT::Function with no mark callback and function_free as its
 * free callback.
 */
static VALUE function_s_new(VALUE self, VALUE num_args_v)
{
  struct Function * holder = create_function(NUM2ULONG(num_args_v));

  return Data_Wrap_Struct(rb_cFunction, 0, function_free, holder);
}

/* Hash mapping a libjit value (keyed by value_key) to the Ruby Function
 * object that owns it. mark_value consults it so that a live Value keeps
 * its owning Function alive.
 * There's a more efficient way to do this, but it's more complex and
 * relies on the internals of libjit.
 *
 * NOTE: Ruby's GC does not scan C static variables. The hash stored here
 * has to be registered as a GC root (rb_global_variable or
 * rb_gc_register_address) to be protected from collection.
 */
static VALUE value_to_function = Qnil;

/* Turn a libjit value pointer into a hash key without allocating.
 *
 * mark_value and free_value run inside the garbage collector, where new
 * Ruby objects must not be created. ULONG2NUM can create a Bignum for a
 * large address, and it drops pointer bits where unsigned long is
 * narrower than a pointer. Setting the Fixnum tag bit on the raw pointer
 * bits instead yields an immediate Fixnum. This assumes jit_value_t
 * pointers are at least 2-byte aligned, so the tag bit never carries
 * address information.
 */
static VALUE value_key(jit_value_t value)
{
  return (VALUE)value | FIXNUM_FLAG;
}

/* GC mark callback for LittleJIT::Value objects: mark the Function that
 * owns the wrapped value, as recorded by wrap_value.
 */
static void mark_value(jit_value_t value)
{
  rb_gc_mark(rb_hash_aref(value_to_function, value_key(value)));
}

/* GC free callback for LittleJIT::Value objects: remove the value's
 * entry so its owning Function is no longer held through it. The libjit
 * value itself is not freed here.
 *
 * NOTE(review): on Ruby 1.8, rb_hash_delete yields to the current block
 * when the key is missing. A missing key here (see the duplicate-wrapping
 * note on wrap_value) would do so from inside the GC. Confirm for the
 * targeted Ruby version.
 */
static void free_value(jit_value_t value)
{
  rb_hash_delete(value_to_function, value_key(value));
}

/* Wrap a libjit value in a LittleJIT::Value and record `function` as its
 * owner, so the Function stays alive while the Value is reachable.
 *
 * Raises RuntimeError if value is NULL (libjit's failure result).
 *
 * The owner is recorded before the Value object is created, so
 * mark_value finds it from the moment the object exists. If
 * Data_Wrap_Struct itself raises, the entry is left behind.
 *
 * NOTE(review): entries are keyed by the jit_value_t address only. If
 * the same jit_value_t is wrapped twice (e.g. if libjit returns the same
 * value for repeated get_param calls), freeing either Ruby object
 * removes the shared entry. Confirm whether that can happen.
 */
static VALUE wrap_value(jit_value_t value, VALUE function)
{
  if(!value)
  {
    rb_raise(rb_eRuntimeError, "libjit did not return a value");
  }

  rb_hash_aset(value_to_function, value_key(value), function);

  return Data_Wrap_Struct(rb_cValue, mark_value, free_value, value);
}

/* rb_ensure body for Function#compile.
 *
 * Takes the context's build lock, hands the Function to the caller's
 * block so it can emit instructions, then compiles the result. Raises
 * RuntimeError if libjit cannot compile the function; otherwise returns
 * nil. The lock is released by the paired ensure function.
 */
static VALUE function_yield_and_compile(VALUE self)
{
  struct Function * func;

  Data_Get_Struct(self, struct Function, func);
  jit_context_build_start(func->context);
  rb_yield(self);

  if(jit_function_compile(func->function) == 0)
  {
    rb_raise(rb_eRuntimeError, "Unable to compile function");
  }
  return Qnil;
}

/* rb_ensure cleanup for Function#compile.
 *
 * If ruby_errinfo is set (an exception is propagating), the half-built
 * function is abandoned. Either way, the build lock taken in
 * function_yield_and_compile is released. Returns nil.
 *
 * NOTE(review): ruby_errinfo is Ruby 1.8's `$!`. It is also non-nil when
 * compile is called from inside a rescue clause, and it is not set for
 * non-exception exits such as `break` out of the block.
 * NOTE(review): libjit documents jit_function_abandon as reclaiming a
 * function that was never compiled. f->function is not cleared here, so
 * later calls on this Function may use a freed pointer -- confirm.
 */
static VALUE function_abandon_if_exception(VALUE self)
{
  struct Function * func;
  int failed;

  Data_Get_Struct(self, struct Function, func);

  failed = RTEST(ruby_errinfo);
  if(failed)
  {
    jit_function_abandon(func->function);
  }
  jit_context_build_end(func->context);
  return Qnil;
}

/* Function#compile { |function| ... }
 *
 * Runs function_yield_and_compile, which locks the build context, yields
 * self and compiles the function. It is paired through rb_ensure with
 * function_abandon_if_exception, which abandons on exception and unlocks
 * the context. The unlock therefore happens even when the block or the
 * compile raises. Returns self.
 */
static VALUE function_compile(VALUE self)
{
  (void)rb_ensure(function_yield_and_compile, self,
                  function_abandon_if_exception, self);
  return self;
}

/* Return a placeholder (LittleJIT::Value) for the nth parameter of the
 * function.
 *
 * Raises IndexError if n is negative or not below the signature's
 * parameter count. Raises RuntimeError if libjit cannot produce the
 * value. Previously a negative n was converted to a huge unsigned index
 * and a NULL result was wrapped silently.
 */
static VALUE function_get_param(VALUE self, VALUE n)
{
  struct Function * f;
  jit_value_t value;
  int index = NUM2INT(n);
  unsigned int num_params;

  Data_Get_Struct(self, struct Function, f);
  num_params = jit_type_num_params(jit_function_get_signature(f->function));

  if(index < 0 || (unsigned int)index >= num_params)
  {
    rb_raise(rb_eIndexError,
        "parameter index %d out of range (function takes %u)",
        index, num_params);
  }

  value = jit_value_get_param(f->function, (unsigned int)index);
  if(!value)
  {
    rb_raise(rb_eRuntimeError, "Unable to get parameter %d", index);
  }

  return wrap_value(value, self);
}

/* Append an instruction to add v1 and v2 and return a placeholder
 * (LittleJIT::Value) for the result of the operation.
 *
 * Raises TypeError unless both operands are LittleJIT::Value objects.
 * Raises RuntimeError if libjit fails to emit the instruction.
 */
static VALUE function_insn_add(VALUE self, VALUE v1, VALUE v2)
{
  struct Function * f;
  jit_value_t value1, value2, result;

  /* Data_Get_Struct only checks for T_DATA. A Function is T_DATA too and
   * would otherwise be reinterpreted as a jit_value_t. */
  if(!RTEST(rb_obj_is_kind_of(v1, rb_cValue)) ||
     !RTEST(rb_obj_is_kind_of(v2, rb_cValue)))
  {
    rb_raise(rb_eTypeError, "operands must be LittleJIT::Value objects");
  }

  Data_Get_Struct(self, struct Function, f);
  Data_Get_Struct(v1, struct _jit_value, value1);
  Data_Get_Struct(v2, struct _jit_value, value2);

  result = jit_insn_add(f->function, value1, value2);
  if(!result)
  {
    rb_raise(rb_eRuntimeError, "Unable to append add instruction");
  }

  return wrap_value(result, self);
}

/* Append an instruction to return retval_v from the function.
 *
 * Raises TypeError unless retval_v is a LittleJIT::Value. Raises
 * RuntimeError if libjit fails to emit the instruction (its failure
 * result was previously ignored). Returns nil.
 */
static VALUE function_insn_return(VALUE self, VALUE retval_v)
{
  struct Function * f;
  jit_value_t retval;

  /* Data_Get_Struct only checks for T_DATA. A Function is T_DATA too and
   * would otherwise be reinterpreted as a jit_value_t. */
  if(!RTEST(rb_obj_is_kind_of(retval_v, rb_cValue)))
  {
    rb_raise(rb_eTypeError, "expected a LittleJIT::Value");
  }

  Data_Get_Struct(self, struct Function, f);
  Data_Get_Struct(retval_v, struct _jit_value, retval);

  if(!jit_insn_return(f->function, retval))
  {
    rb_raise(rb_eRuntimeError, "Unable to append return instruction");
  }

  return Qnil;
}

/* Call an already compiled function with integer arguments and return
 * its integer result.
 *
 * Raises ArgumentError unless exactly one argument per signature
 * parameter is given. jit_function_apply reads one argument per
 * parameter, so a short list used to make it read past `args`.
 * Raises RuntimeError if jit_function_apply reports failure. Previously
 * its result was ignored and an uninitialized value was returned.
 */
static VALUE function_apply(int argc, VALUE * argv, VALUE self)
{
  struct Function * f;
  jit_int * arg_values;
  void * * args;
  unsigned int num_params;
  int j;
  jit_int result = 0;

  Data_Get_Struct(self, struct Function, f);

  num_params = jit_type_num_params(jit_function_get_signature(f->function));
  if((unsigned int)argc != num_params)
  {
    rb_raise(rb_eArgError, "wrong number of arguments (%d for %u)",
        argc, num_params);
  }

  arg_values = ALLOCA_N(jit_int, argc);
  args = ALLOCA_N(void *, argc);

  for(j = 0; j < argc; ++j)
  {
    arg_values[j] = NUM2INT(argv[j]);
    args[j] = &arg_values[j];
  }

  if(!jit_function_apply(f->function, args, &result))
  {
    rb_raise(rb_eRuntimeError, "Unable to apply function");
  }

  return INT2NUM(result);
}

/* Extension entry point (Ruby calls Init_<name> when the extension is
 * loaded). Defines the LittleJIT module, LittleJIT::Function with its
 * methods, and LittleJIT::Value, then creates the value-to-function hash.
 */
void Init_littlejit(void)
{
  rb_mJIT = rb_define_module("LittleJIT");
  rb_cFunction = rb_define_class_under(rb_mJIT, "Function", rb_cObject);
  rb_cValue = rb_define_class_under(rb_mJIT, "Value", rb_cObject);
  rb_define_singleton_method(rb_cFunction, "new", function_s_new, 1);
  rb_define_method(rb_cFunction, "compile", function_compile, 0);
  rb_define_method(rb_cFunction, "get_param", function_get_param, 1);
  rb_define_method(rb_cFunction, "insn_add", function_insn_add, 2);
  rb_define_method(rb_cFunction, "insn_return", function_insn_return, 1);
  rb_define_method(rb_cFunction, "apply", function_apply, -1);

  /* Ruby's GC does not scan C static variables. Register the variable as
   * a GC root before storing the hash in it; otherwise the hash could be
   * collected while the Value callbacks still use it. */
  rb_global_variable(&value_to_function);
  value_to_function = rb_hash_new();
}

