/*
 * Copyright (c) 2004-2005 Sean Chittenden <sean@chittenden.org>
 *
 * Permission is hereby granted, free of charge, to any person
 * obtaining a copy of this software and associated documentation
 * files (the "Software"), to deal in the Software without
 * restriction, including without limitation the rights to use, copy,
 * modify, merge, publish, distribute, sublicense, and/or sell copies
 * of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be
 * included in all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 *
 * $RubyForge$
 */

#include <string.h>

#include "rbmemcache.h"

/* Ruby class handles; all of them are assigned in Init_memcache(). */
VALUE cMemcache;
VALUE cMemcacheReq;
VALUE cMemcacheRes;
VALUE cMemcacheStats;
VALUE cMemcacheServer;
/* Exception class raised when a connected server is being reconfigured. */
VALUE eMemcacheServerConn;

/* Cached method IDs for "to_f" and "to_i" (interned in Init_memcache()). */
static ID s_to_f, s_to_i;

/* Internal selectors that tell rb_memcache_set_cmd() which storage
 * command (add, replace or set) to issue. */
#define RBMC_SET_CMD_TYPE_ADD		1
#define RBMC_SET_CMD_TYPE_REPLACE	2
#define RBMC_SET_CMD_TYPE_SET		3

/* This set of #defines determines the data serialization type.  The
 * value is passed as the flags argument when storing and is switched
 * on by rb_memcache_restore_data() when reading. */
#define RBMC_DATA_STRING		0x0000
#define RBMC_DATA_NIL			0x0001
#define RBMC_DATA_FLOAT			0x0002
#define RBMC_DATA_FIXNUM		0x0003
#define RBMC_DATA_BOOL			0x0004
#define RBMC_DATA_MARSHAL		0x0005

/* This is reserved for flags (not referenced elsewhere in this file). */
#define RBMC_FLAGS_GZIP			0x1000

/* Forward declarations for the file-local method implementations that
 * Init_memcache() registers with the Ruby interpreter.  Functions
 * registered with arity -1 use the (argc, argv, self) calling
 * convention. */
static VALUE	 rb_memcache_add(const int argc, VALUE *argv, VALUE self);
static VALUE	 rb_memcache_alloc(VALUE klass);
static VALUE	 rb_memcache_decr(const int argc, const VALUE *argv, VALUE self);
static VALUE	 rb_memcache_delete(const int argc, const VALUE *argv, VALUE self);
static VALUE	 rb_memcache_flush(VALUE self, VALUE server);
static VALUE	 rb_memcache_flush_all(VALUE self);
static VALUE	 rb_memcache_get(VALUE self, const VALUE key);
static VALUE	 rb_memcache_get_array(const int argc, const VALUE *argv, VALUE self);
static VALUE	 rb_memcache_get_hash(const int argc, const VALUE *argv, VALUE self);
static VALUE	 rb_memcache_incr(const int argc, const VALUE *argv, VALUE self);
static VALUE	 rb_memcache_init(const int argc, const VALUE *argv, VALUE self);
static VALUE	 rb_memcache_hash(VALUE self, const VALUE str);
static VALUE	 rb_memcache_replace(const int argc, VALUE *argv, VALUE self);
static VALUE	 rb_memcache_restore_data(const u_int16_t flags, const void *val, const size_t bytes);
static VALUE	 rb_memcache_server_add(const int argc, VALUE *argv, VALUE self);
static VALUE	 rb_memcache_server_alloc(VALUE klass);
static VALUE	 rb_memcache_server_init(const int argc, VALUE *argv, VALUE self);
static VALUE	 rb_memcache_server_hostname(VALUE self);
static VALUE	 rb_memcache_server_hostname_eq(VALUE self, VALUE hostname);
static VALUE	 rb_memcache_server_port(VALUE self);
static VALUE	 rb_memcache_server_port_eq(VALUE self, VALUE port);
static VALUE	 rb_memcache_set(const int argc, VALUE *argv, VALUE self);
static VALUE	 rb_memcache_set_cmd(const int argc, VALUE *argv, VALUE self, const int type);
/* The explicit "new" constructors are only needed when the interpreter
 * does not provide rb_define_alloc_func(). */
#ifndef HAVE_RB_DEFINE_ALLOC_FUNC
static VALUE	 rb_memcache_server_new(const int argc, const VALUE *argv, VALUE self);
static VALUE	 rb_memcache_new(const int argc, const VALUE *argv, VALUE klass);
#endif


/*
 * Memcache#add
 *
 * Thin wrapper: hands the Ruby arguments, untouched, to
 * rb_memcache_set_cmd() with the ADD command type selected.
 */
static VALUE
rb_memcache_add(const int argc, VALUE *argv, VALUE self) {
  const int cmd = RBMC_SET_CMD_TYPE_ADD;

  return rb_memcache_set_cmd(argc, argv, self, cmd);
}


/*
 * Allocator for Memcache objects.
 *
 * Wraps a NULL data pointer; the real handle is attached later by
 * rb_memcache_init().  No mark function is registered and mc_free is
 * installed as the free function.
 */
static VALUE
rb_memcache_alloc(VALUE klass) {
  VALUE wrapper;

  wrapper = Data_Wrap_Struct(klass, NULL, mc_free, NULL);

  return wrapper;
}


/*
 * Memcache#decr(key [, amount])
 *
 * Decrements the value stored under +key+ by +amount+ (1 when omitted)
 * and returns the result of mc_decr() converted with UINT2NUM.
 *
 * Raises ArgumentError unless one or two arguments are given and
 * TypeError when +key+ is not a String.
 */
static VALUE
rb_memcache_decr(const int argc, const VALUE *argv, VALUE self) {
  struct memcache *mc;
  u_int32_t decr;

  Check_Type(self, T_DATA);
  Data_Get_Struct(self, struct memcache, mc);

  switch(argc) {
  case 1:
    decr = 1;
    break;
  case 2:
    decr = NUM2INT(argv[1]);
    break;
  default:
    rb_raise(rb_eArgError, "wrong number of arguments (1-2 args required)");
  }

  /* RSTRING() on a non-String VALUE would read arbitrary memory, so
   * validate the key once argc is known to be at least 1. */
  Check_Type(argv[0], T_STRING);

  /* Pass the requested amount; the previous code always passed 1 and
   * silently ignored the second argument. */
  return UINT2NUM(mc_decr(mc, RSTRING(argv[0])->ptr, RSTRING(argv[0])->len, decr));
}


/*
 * Memcache#delete(key [, hold])
 *
 * Deletes +key+ via mc_delete().  +hold+ defaults to 0 and is passed
 * through as the last argument of mc_delete().  Returns the mc_delete()
 * result converted with UINT2NUM.
 *
 * Raises ArgumentError unless one or two arguments are given and
 * TypeError when +key+ is not a String.
 */
static VALUE
rb_memcache_delete(const int argc, const VALUE *argv, VALUE self) {
  struct memcache *mc;
  time_t hold;

  Check_Type(self, T_DATA);
  Data_Get_Struct(self, struct memcache, mc);

  switch(argc) {
  case 1:
    hold = 0;
    break;
  case 2:
    hold = NUM2INT(argv[1]);
    break;
  default:
    rb_raise(rb_eArgError, "wrong number of arguments (1-2 args required)");
  }

  /* RSTRING() on a non-String VALUE would read arbitrary memory, so
   * validate the key once argc is known to be at least 1. */
  Check_Type(argv[0], T_STRING);

  return UINT2NUM(mc_delete(mc, RSTRING(argv[0])->ptr, RSTRING(argv[0])->len, hold));
}


/*
 * Memcache#flush(server)
 *
 * Flushes a single server.  +server+ must be a Memcache::Server
 * instance; anything else raises TypeError.  Returns the mc_flush()
 * result as a Fixnum.
 */
static VALUE
rb_memcache_flush(VALUE self, VALUE server) {
  struct memcache *mc;
  struct memcache_server *ms;

  Check_Type(self, T_DATA);
  Check_Type(server, T_DATA);

  /* Report the class of the offending argument, not of the receiver. */
  if (!rb_obj_is_instance_of(server, cMemcacheServer))
    rb_raise(rb_eTypeError, "wrong argument type %s (expected Memcache::Server)", rb_class2name(CLASS_OF(server)));

  Data_Get_Struct(self, struct memcache, mc);
  /* The server struct lives in +server+; unwrapping it from +self+
   * would reinterpret the struct memcache as a struct memcache_server. */
  Data_Get_Struct(server, struct memcache_server, ms);

  return INT2FIX(mc_flush(mc, ms));
}


/*
 * Memcache#flush_all
 *
 * Calls mc_flush_all() on the wrapped handle and returns its result as
 * a Fixnum.
 */
static VALUE
rb_memcache_flush_all(VALUE self) {
  struct memcache *cache;

  /* self must be a wrapped data object before it is unwrapped. */
  Check_Type(self, T_DATA);
  Data_Get_Struct(self, struct memcache, cache);

  return INT2FIX(mc_flush_all(cache));
}


/*
 * Memcache#get(key) / Memcache#[](key)
 *
 * Fetches a single key.  Returns the value rebuilt by
 * rb_memcache_restore_data() when the response is marked found, nil
 * otherwise.  Raises TypeError when +key+ is not a String.
 */
static VALUE
rb_memcache_get(VALUE self, const VALUE key) {
  VALUE ret;
  struct memcache_req *req;
  struct memcache_res *res;
  struct memcache *mc;

  Check_Type(self, T_DATA);
  /* Validate before allocating the request: RSTRING() on a non-String
   * VALUE would read arbitrary memory. */
  Check_Type(key, T_STRING);

  Data_Get_Struct(self, struct memcache, mc);

  req = mc_req_new();
  res = mc_req_add(req, RSTRING(key)->ptr, RSTRING(key)->len);
  mc_res_free_on_delete(res, 0);
  mc_get(mc, req);

  if (mc_res_found(res) == 1)
    ret = rb_memcache_restore_data(res->flags, res->val, res->bytes);
  else
    ret = Qnil;

  mc_req_free(req);
  return ret;
}


/* Need two types of get multi's.  One returns a parallel array for
 * the args passed in.  In the place of values not found, set Qnil.
 * The other get multi needs to return a hash.  keys passed in matched
 * with values, or Qnil if not found. */

/*
 * Memcache#get_a(key, ...)
 *
 * Fetches every key in one request and returns an Array built by
 * walking the request's response list: the restored value for each
 * response marked found, nil otherwise.
 *
 * Raises ArgumentError when called without arguments and TypeError when
 * any key is not a String.
 */
static VALUE
rb_memcache_get_array(const int argc, const VALUE *argv, VALUE self) {
  struct memcache_req *req;
  struct memcache_res *res;
  struct memcache *mc;
  u_int32_t i;
  VALUE ret;

  if (argc < 1)
    rb_raise(rb_eArgError, "wrong number of arguments (one or more args required)");

  Check_Type(self, T_DATA);
  Data_Get_Struct(self, struct memcache, mc);

  /* Validate all keys up front so a TypeError cannot be raised after
   * the request has been allocated. */
  for (i = 0; i < (u_int32_t)argc; i++)
    Check_Type(argv[i], T_STRING);

  ret = rb_ary_new2(argc);
  req = mc_req_new();
  for (i = 0; i < (u_int32_t)argc; i++) {
    /* Each key carries its own length; using argv[0]'s length for every
     * key truncated or over-read all keys after the first. */
    res = mc_req_add(req, RSTRING(argv[i])->ptr, RSTRING(argv[i])->len);
    mc_res_free_on_delete(res, 0);
  }

  mc_get(mc, req);

  for (res = req->query.tqh_first; res != NULL; res = res->entries.tqe_next) {
    if (mc_res_found(res) == 1) {
      rb_ary_push(ret, rb_memcache_restore_data(res->flags, res->val, res->bytes));
    } else {
      rb_ary_push(ret, Qnil);
    }
  }

  mc_req_free(req);
  return ret;
}


/*
 * Memcache#get_h(key, ...)
 *
 * Fetches every key in one request and returns a Hash keyed by the
 * response's key string; the value is the restored data for responses
 * marked found and nil otherwise.
 *
 * Raises ArgumentError when called without arguments and TypeError when
 * any key is not a String.
 */
static VALUE
rb_memcache_get_hash(const int argc, const VALUE *argv, VALUE self) {
  struct memcache_req *req;
  struct memcache_res *res;
  struct memcache *mc;
  u_int32_t i;
  VALUE ret;

  if (argc < 1)
    rb_raise(rb_eArgError, "wrong number of arguments (one or more args required)");

  Check_Type(self, T_DATA);
  Data_Get_Struct(self, struct memcache, mc);

  /* Validate all keys up front so a TypeError cannot be raised after
   * the request has been allocated. */
  for (i = 0; i < (u_int32_t)argc; i++)
    Check_Type(argv[i], T_STRING);

  ret = rb_hash_new();
  req = mc_req_new();
  for (i = 0; i < (u_int32_t)argc; i++) {
    /* Each key carries its own length; using argv[0]'s length for every
     * key truncated or over-read all keys after the first. */
    res = mc_req_add(req, RSTRING(argv[i])->ptr, RSTRING(argv[i])->len);
    mc_res_free_on_delete(res, 0);
  }

  mc_get(mc, req);

  for (res = req->query.tqh_first; res != NULL; res = res->entries.tqe_next) {
    if (mc_res_found(res) == 1) {
      rb_hash_aset(ret, rb_str_new(res->key, res->len), rb_memcache_restore_data(res->flags, res->val, res->bytes));
    } else {
      rb_hash_aset(ret, rb_str_new(res->key, res->len), Qnil);
    }
  }

  mc_req_free(req);
  return ret;
}


/*
 * Memcache.hash(str)
 *
 * Returns mc_hash_key() of the bytes of +str+, converted with UINT2NUM.
 * Raises TypeError when +str+ is not a String.
 */
static VALUE
rb_memcache_hash(VALUE self, const VALUE str) {
  /* RSTRING() on a non-String VALUE would read arbitrary memory. */
  Check_Type(str, T_STRING);

  return UINT2NUM(mc_hash_key(RSTRING(str)->ptr, RSTRING(str)->len));
}


/*
 * Memcache#incr(key [, amount])
 *
 * Increments the value stored under +key+ by +amount+ (1 when omitted)
 * and returns the result of mc_incr() converted with UINT2NUM.
 *
 * Raises ArgumentError unless one or two arguments are given and
 * TypeError when +key+ is not a String.
 */
static VALUE
rb_memcache_incr(const int argc, const VALUE *argv, VALUE self) {
  struct memcache *mc;
  u_int32_t incr;

  Check_Type(self, T_DATA);
  Data_Get_Struct(self, struct memcache, mc);

  switch(argc) {
  case 1:
    incr = 1;
    break;
  case 2:
    incr = NUM2INT(argv[1]);
    break;
  default:
    rb_raise(rb_eArgError, "wrong number of arguments (1-2 args required)");
  }

  /* RSTRING() on a non-String VALUE would read arbitrary memory, so
   * validate the key once argc is known to be at least 1. */
  Check_Type(argv[0], T_STRING);

  /* Pass the requested amount; the previous code always passed 1 and
   * silently ignored the second argument. */
  return UINT2NUM(mc_incr(mc, RSTRING(argv[0])->ptr, RSTRING(argv[0])->len, incr));
}


/*
 * Memcache#initialize
 *
 * Attaches a fresh handle from mc_new() to the wrapper created by the
 * allocator.  The arguments are accepted but not inspected.
 *
 * Raises TypeError when self is not a direct instance of Memcache and
 * ArgumentError when a handle is already attached.
 */
static VALUE
rb_memcache_init(const int argc, const VALUE *argv, VALUE self) {
  Check_Type(self, T_DATA);

  if (!rb_obj_is_instance_of(self, cMemcache)) {
    rb_raise(rb_eTypeError, "wrong argument type %s (expected Memcache)", rb_class2name(CLASS_OF(self)));
  }

  /* A non-NULL data pointer means initialize already ran once. */
  if (DATA_PTR(self) != NULL) {
    rb_raise(rb_eArgError, "Cannot re-initialize Memcache object.");
  }

  DATA_PTR(self) = mc_new();
  return self;
}


#ifndef HAVE_RB_DEFINE_ALLOC_FUNC
/*
 * Memcache.new fallback, compiled only when rb_define_alloc_func() is
 * unavailable: allocates the wrapper, runs the initializer on it and
 * returns it.
 */
static VALUE
rb_memcache_new(const int argc, const VALUE *argv, VALUE klass) {
  VALUE instance = rb_memcache_alloc(klass);

  rb_memcache_init(argc, argv, instance);

  return instance;
}
#endif


/*
 * Memcache#replace
 *
 * Thin wrapper: hands the Ruby arguments, untouched, to
 * rb_memcache_set_cmd() with the REPLACE command type selected.
 */
static VALUE
rb_memcache_replace(const int argc, VALUE *argv, VALUE self) {
  const int cmd = RBMC_SET_CMD_TYPE_REPLACE;

  return rb_memcache_set_cmd(argc, argv, self, cmd);
}


static VALUE
rb_memcache_restore_data(const u_int16_t flags, const void *val, const size_t bytes) {
  VALUE ret;

  switch (flags) {
  case RBMC_DATA_STRING:
    ret = rb_tainted_str_new(val, bytes);
    break;
  case RBMC_DATA_NIL:
    ret = rb_tainted_str_new(NULL, 0);
    break;
  case RBMC_DATA_FLOAT:
    ret = rb_funcall(rb_tainted_str_new(val, bytes), s_to_f, 0, NULL);
    break;
  case RBMC_DATA_FIXNUM:
    ret = rb_funcall(rb_tainted_str_new(val, bytes), s_to_i, 0, NULL);
    break;
  case RBMC_DATA_BOOL:
    if (bytes == 1) {
      if (val == (void *)0x1) {
	ret = Qtrue;
	break;
      } else if (val == (void *)0x0) {
	ret = Qfalse;
	break;
      }
    }

    rb_raise(rb_eRangeError, "invalid boolean value 0x%x", val);
    break;
  case RBMC_DATA_MARSHAL:
    ret = rb_marshal_load(rb_tainted_str_new(val, bytes));
    break;
  default:
    rb_raise(rb_eTypeError, "unable to handle client flags value 0x%x", flags);
  }

  return ret;
}


/*
 * Memcache#add_server(server)
 * Memcache#add_server(hostport)
 * Memcache#add_server(host, port)
 *
 * With one T_DATA argument, the argument must be a Memcache::Server and
 * its struct is handed to mc_server_add3().  Any other single argument
 * is converted to a String and passed to mc_server_add4().  With two
 * arguments both are converted to Strings and passed to
 * mc_server_add2().  The library's return code is returned as a Fixnum.
 *
 * Raises TypeError for a non-Server data object and ArgumentError for
 * any other argument count.
 */
static VALUE
rb_memcache_server_add(const int argc, VALUE *argv, VALUE self) {
  struct memcache *mc;
  struct memcache_server *ms;
  VALUE host, port;

  Check_Type(self, T_DATA);
  Data_Get_Struct(self, struct memcache, mc);

  switch (argc) {
  case 1:
    switch(TYPE(argv[0])) {
    case T_DATA:
      if (!rb_obj_is_instance_of(argv[0], cMemcacheServer))
        rb_raise(rb_eTypeError, "wrong argument type %s (expected Memcache::Server)",
                 rb_class2name(CLASS_OF(argv[0])));

      /* The server struct lives in the argument; unwrapping it from
       * +self+ would reinterpret the struct memcache as a server. */
      Data_Get_Struct(argv[0], struct memcache_server, ms);

      /* NOTE(review): the Server wrapper frees this struct with
       * mc_server_free when collected; confirm mc_server_add3() does
       * not also take ownership of it. */
      return INT2FIX(mc_server_add3(mc, ms));
    default:
      host = StringValue(argv[0]);
      return INT2FIX(mc_server_add4(mc, RSTRING(host)->ptr));
    }

  case 2:
    host = StringValue(argv[0]);
    port = StringValue(argv[1]);
    return INT2FIX(mc_server_add2(mc, RSTRING(host)->ptr, RSTRING(host)->len,
                                  RSTRING(port)->ptr, RSTRING(port)->len));
  default:
    rb_raise(rb_eArgError, "wrong number of arguments (1 or 2 args required)");
  }
}


/*
 * Allocator for Memcache::Server objects.
 *
 * Wraps a NULL data pointer; the real struct is attached later by
 * rb_memcache_server_init().  No mark function is registered and
 * mc_server_free is installed as the free function.
 */
static VALUE
rb_memcache_server_alloc(VALUE klass) {
  VALUE wrapper;

  wrapper = Data_Wrap_Struct(klass, NULL, mc_server_free, NULL);

  return wrapper;
}


/*
 * Memcache::Server#initialize(hostname [, port])
 *
 * Attaches a struct from mc_server_new() to the wrapper and stores
 * mc_strdup() copies of the hostname and, when given, the port.
 *
 * Raises TypeError when self is not a direct instance of
 * Memcache::Server, and ArgumentError when a struct is already attached
 * or the argument count is not 1 or 2.
 */
static VALUE
rb_memcache_server_init(const int argc, VALUE *argv, VALUE self) {
  struct memcache_server *server;

  Check_Type(self, T_DATA);

  if (!rb_obj_is_instance_of(self, cMemcacheServer)) {
    rb_raise(rb_eTypeError, "wrong argument type %s (expected Memcache::Server)", rb_class2name(CLASS_OF(self)));
  }

  /* A non-NULL data pointer means initialize already ran once. */
  if (DATA_PTR(self) != NULL) {
    rb_raise(rb_eArgError, "Cannot re-initialize Memcache object.");
  }

  /* The struct is attached before the argument count is checked. */
  DATA_PTR(self) = mc_server_new();
  Data_Get_Struct(self, struct memcache_server, server);

  if (argc != 1 && argc != 2) {
    rb_raise(rb_eArgError, "wrong number of arguments (1 or 2 args required)");
  }

  server->hostname = mc_strdup(StringValueCStr(argv[0]));
  if (argc == 2) {
    server->port = mc_strdup(StringValueCStr(argv[1]));
  }

  return self;
}


#ifndef HAVE_RB_DEFINE_ALLOC_FUNC
/*
 * Memcache::Server.new fallback, compiled only when
 * rb_define_alloc_func() is unavailable: allocates the wrapper, runs
 * the initializer on it and returns it.
 */
static VALUE
rb_memcache_server_new(const int argc, const VALUE *argv, VALUE klass) {
  VALUE instance = rb_memcache_server_alloc(klass);

  rb_memcache_server_init(argc, argv, instance);

  return instance;
}
#endif


/*
 * Memcache::Server#hostname
 *
 * Returns a new String copy of the stored hostname, or nil when no
 * hostname is set.
 */
static VALUE
rb_memcache_server_hostname(VALUE self) {
  struct memcache_server *server;

  Data_Get_Struct(self, struct memcache_server, server);

  if (server->hostname == NULL) {
    return Qnil;
  }

  return rb_str_new2(server->hostname);
}


/*
 * Memcache::Server#hostname=(hostname)
 *
 * Replaces the stored hostname with a copy of +hostname+ and returns
 * the new hostname as a String (nil if the copy is NULL).
 *
 * Raises Memcache::Server::Conn when the server's fd is not -1.
 * StringValueCStr() raises if +hostname+ cannot be used as a C string.
 */
static VALUE
rb_memcache_server_hostname_eq(VALUE self, VALUE hostname) {
  struct memcache_server *ms;
  char *new_hostname;

  Data_Get_Struct(self, struct memcache_server, ms);

  if (ms->fd != -1)
    rb_raise(eMemcacheServerConn, "already connected: unable to change the hostname");

  /* Convert and copy BEFORE releasing the old value: StringValueCStr()
   * can raise, and raising after xfree() would leave ms->hostname
   * dangling.  mc_strdup() is used, as in rb_memcache_server_init(),
   * so the copy matches the allocator that xfree() releases. */
  new_hostname = mc_strdup(StringValueCStr(hostname));

  if (ms->hostname != NULL)
    xfree(ms->hostname);
  ms->hostname = new_hostname;

  return ms->hostname == NULL ? Qnil : rb_str_new2(ms->hostname);
}


/*
 * Memcache::Server#port
 *
 * Returns a new String copy of the stored port, or nil when no port is
 * set.
 */
static VALUE
rb_memcache_server_port(VALUE self) {
  struct memcache_server *server;

  Data_Get_Struct(self, struct memcache_server, server);

  if (server->port == NULL) {
    return Qnil;
  }

  return rb_str_new2(server->port);
}


/*
 * Memcache::Server#port=(port)
 *
 * Replaces the stored port with the string form of +port+, which may be
 * a String or a Fixnum.  Returns the new port as a String (nil if the
 * copy is NULL), matching what Memcache::Server#port returns.
 *
 * Raises Memcache::Server::Conn when the server's fd is not -1 and
 * ArgumentError for any other argument type.
 */
static VALUE
rb_memcache_server_port_eq(VALUE self, VALUE port) {
  struct memcache_server *ms;
  VALUE port_str;
  char *new_port;

  Data_Get_Struct(self, struct memcache_server, ms);

  if (ms->fd != -1)
    rb_raise(eMemcacheServerConn, "already connected: unable to change the port");

  switch(TYPE(port)) {
  case T_STRING:
    port_str = port;
    break;
  case T_FIXNUM:
    /* A Fixnum has no implicit String conversion, so StringValueCStr()
     * on it raises; render it with to_s semantics first. */
    port_str = rb_obj_as_string(port);
    break;
  default:
    rb_raise(rb_eArgError, "port number must be a FixNum");
  }

  /* Copy before releasing the old value so a raise during conversion
   * cannot leave ms->port dangling. */
  new_port = mc_strdup(StringValueCStr(port_str));

  if (ms->port != NULL)
    xfree(ms->port);
  ms->port = new_port;

  /* ms->port is a char *; wrapping the pointer with INT2FIX() returned
   * a meaningless number. */
  return ms->port == NULL ? Qnil : rb_str_new2(ms->port);
}


/*
 * Memcache#set
 *
 * Thin wrapper: hands the Ruby arguments, untouched, to
 * rb_memcache_set_cmd() with the SET command type selected.
 */
static VALUE
rb_memcache_set(const int argc, VALUE *argv, VALUE self) {
  const int cmd = RBMC_SET_CMD_TYPE_SET;

  return rb_memcache_set_cmd(argc, argv, self, cmd);
}


/*
 * Memcache#[]=(key, val)
 *
 * Packs the fixed-arity (key, val) pair into an argument vector and
 * issues a SET through rb_memcache_set_cmd().
 */
static VALUE
rb_memcache_set_ary(VALUE self, VALUE key, VALUE val) {
  VALUE args[2];

  args[0] = key;
  args[1] = val;

  return rb_memcache_set_cmd(2, args, self, RBMC_SET_CMD_TYPE_SET);
}


/*
 * Shared implementation of Memcache#add, #replace and #set.
 *
 *   argv[0] - key (converted with StringValue)
 *   argv[1] - optional value; its Ruby type selects the RBMC_DATA_*
 *             flag and the wire representation
 *   argv[2] - optional expiration; nil means 0, anything else is
 *             converted with to_i
 *   type    - one of the RBMC_SET_CMD_TYPE_* selectors
 *
 * Returns the library's return code as a Fixnum.  Raises ArgumentError
 * unless 1-3 arguments are given.
 */
static VALUE
rb_memcache_set_cmd(const int argc, VALUE *argv, VALUE self, const int type) {
  /* One-byte payloads for booleans.  The previous code passed the
   * addresses 0x1 / 0x0 as the value buffer with a length of 1. */
  static char bool_true[1] = { 0x1 };
  static char bool_false[1] = { 0x0 };
  struct memcache *mc;
  char *key, *val;
  size_t key_len, val_len;
  time_t expire;
  u_int16_t flags;
  /* Separate, volatile holders keep both Ruby strings referenced from
   * the stack while key/val point into their buffers. */
  volatile VALUE key_str, val_str;
  int rc;

  Check_Type(self, T_DATA);
  Data_Get_Struct(self, struct memcache, mc);

  expire = 0;
  flags = 0;
  key_len = val_len = 0;
  key = val = NULL;
  key_str = val_str = Qnil;
  rc = 0;

  switch (argc) {
  case 3:
    /* key, value, expiration */
    if (TYPE(argv[2]) != T_NIL)
      expire = NUM2INT(rb_funcall(argv[2], s_to_i, 0, NULL));
    /* fallthrough */
  case 2:
    /* key, value */

    /* Most values get sent as strings, but have a specific client
     * flag set. */
    switch(TYPE(argv[1])) {
    case T_STRING:
      flags |= RBMC_DATA_STRING;
      val_str = argv[1];
      val_len = RSTRING(val_str)->len;
      val = RSTRING(val_str)->ptr;
      break;
    case T_NIL:
      flags |= RBMC_DATA_NIL;
      val_len = 0;
      val = NULL;
      break;
    case T_FLOAT:
      /* Floats and Fixnums have no implicit String conversion, so
       * StringValue() on them raises; use their to_s form instead. */
      flags |= RBMC_DATA_FLOAT;
      val_str = rb_obj_as_string(argv[1]);
      val_len = RSTRING(val_str)->len;
      val = RSTRING(val_str)->ptr;
      break;
    case T_FIXNUM:
      flags |= RBMC_DATA_FIXNUM;
      val_str = rb_obj_as_string(argv[1]);
      val_len = RSTRING(val_str)->len;
      val = RSTRING(val_str)->ptr;
      break;
    case T_TRUE:
      flags |= RBMC_DATA_BOOL;
      val = bool_true;
      val_len = sizeof(bool_true);
      break;
    case T_FALSE:
      flags |= RBMC_DATA_BOOL;
      val = bool_false;
      val_len = sizeof(bool_false);
      break;
    default:
      /* Everything else gets marshaled and sent as
       * RBMC_DATA_MARSHAL */
      flags |= RBMC_DATA_MARSHAL;
      val_str = rb_marshal_dump(argv[1], Qnil);
      val_len = RSTRING(val_str)->len;
      val = RSTRING(val_str)->ptr;
    }
    /* fallthrough */
  case 1:
    /* key */
    key_str = StringValue(argv[0]);
    key_len = RSTRING(key_str)->len;
    key = RSTRING(key_str)->ptr;
    break;
  default:
    rb_raise(rb_eArgError, "wrong number of arguments (1-3 args required)");
  }

  switch (type) {
  case RBMC_SET_CMD_TYPE_ADD:
    rc = mc_add(mc, key, key_len, val, val_len, expire, flags);
    break;
  case RBMC_SET_CMD_TYPE_REPLACE:
    rc = mc_replace(mc, key, key_len, val, val_len, expire, flags);
    break;
  case RBMC_SET_CMD_TYPE_SET:
    rc = mc_set(mc, key, key_len, val, val_len, expire, flags);
    break;
  default:
    rb_fatal("invalid cmd type");
  }

  /* A raw C int is not a valid VALUE (0 would read as false, other
   * values as arbitrary objects); box it as a Fixnum like the other
   * methods in this file do. */
  return INT2FIX(rc);
}


/*
 * Extension entry point.
 *
 * Interns the method IDs used by this file, installs Ruby's
 * xfree/xmalloc/xrealloc through mcMemSetup(), defines the Memcache
 * class tree and the connection exception, and registers every method
 * implemented above.
 */
void
Init_memcache(void) {
  s_to_f = rb_intern("to_f");
  s_to_i = rb_intern("to_i");

  /* Init memory context: this isn't any less screwed up than
   * ruby(1)'s memory context handling. */
  mcMemSetup(xfree, (mcMallocFunc)xmalloc, NULL, (mcReallocFunc)xrealloc);

  /* Classes */
  cMemcache = rb_define_class("Memcache", rb_cObject);
  cMemcacheReq = rb_define_class_under(cMemcache, "Request", rb_cObject);
  cMemcacheRes = rb_define_class_under(cMemcache, "Response", rb_cObject);
  cMemcacheStats = rb_define_class_under(cMemcache, "Stats", rb_cObject);
  cMemcacheServer = rb_define_class_under(cMemcache, "Server", rb_cObject);

  /* Exceptions */
  /* Note: derives directly from Exception, so a bare `rescue` (which
   * only catches StandardError) will not catch it. */
  eMemcacheServerConn = rb_define_class_under(cMemcacheServer, "Conn", rb_eException);

  /* Class initialization foo.  When did this start to suck so hard? */
  /* BEGIN: Memcache */
  /* Use the allocator hook when available; otherwise define explicit
   * "allocate" and "new" singleton methods. */
#ifdef HAVE_RB_DEFINE_ALLOC_FUNC
  rb_define_alloc_func(cMemcache, rb_memcache_alloc);
#else
  rb_define_singleton_method(cMemcache, "allocate", rb_memcache_alloc, 0);
  rb_define_singleton_method(cMemcache, "new", rb_memcache_new, -1);
#endif
  rb_define_method(cMemcache, "initialize", rb_memcache_init, -1);
  /* END: Memcache */

  /* BEGIN: Memcache::Server */
#ifdef HAVE_RB_DEFINE_ALLOC_FUNC
  rb_define_alloc_func(cMemcacheServer, rb_memcache_server_alloc);
#else
  rb_define_singleton_method(cMemcacheServer, "allocate", rb_memcache_server_alloc, 0);
  rb_define_singleton_method(cMemcacheServer, "new", rb_memcache_server_new, -1);
#endif
  rb_define_method(cMemcacheServer, "initialize", rb_memcache_server_init, -1);
  /* END: Memcache::Server */


  /* Memcache methods */
  /* "hash" is a singleton (class-level) method; the rest are instance
   * methods.  Arity -1 means the C function receives (argc, argv, self). */
  rb_define_singleton_method(cMemcache, "hash", rb_memcache_hash, 1);
  rb_define_method(cMemcache, "add", rb_memcache_add, -1);
  rb_define_method(cMemcache, "add_server", rb_memcache_server_add, -1);
  rb_define_method(cMemcache, "decr", rb_memcache_decr, -1);
  rb_define_method(cMemcache, "delete", rb_memcache_delete, -1);
  rb_define_method(cMemcache, "flush", rb_memcache_flush, 1);
  rb_define_method(cMemcache, "flush_all", rb_memcache_flush_all, 0);
  rb_define_method(cMemcache, "get", rb_memcache_get, 1);
  rb_define_method(cMemcache, "get_a", rb_memcache_get_array, -1);
  rb_define_method(cMemcache, "get_h", rb_memcache_get_hash, -1);
  rb_define_method(cMemcache, "incr", rb_memcache_incr, -1);
  rb_define_method(cMemcache, "replace", rb_memcache_replace, -1);
  rb_define_method(cMemcache, "set", rb_memcache_set, -1);
  /* Index operators: [] shares the implementation of "get". */
  rb_define_method(cMemcache, "[]", rb_memcache_get, 1);
  rb_define_method(cMemcache, "[]=", rb_memcache_set_ary, 2);


  /* Memcache::Server Instance methods */
  rb_define_method(cMemcacheServer, "hostname", rb_memcache_server_hostname, 0);
  rb_define_method(cMemcacheServer, "hostname=", rb_memcache_server_hostname_eq, 1);
  rb_define_method(cMemcacheServer, "port", rb_memcache_server_port, 0);
  rb_define_method(cMemcacheServer, "port=", rb_memcache_server_port_eq, 1);
}
