/*
 * nasd_linux_sock.c
 *
 * Linux in-kernel sockets
 *
 * Authors: Jim Zelenka, Nat Lanza
 */
/*
 * Copyright (c) of Carnegie Mellon University, 1999, 2000.
 *
 * Permission to reproduce, use, and prepare derivative works of
 * this software for internal use is granted provided the copyright
 * and "No Warranty" statements are included with all reproductions
 * and derivative works. This software may also be redistributed
 * without charge provided that the copyright and "No Warranty"
 * statements are included in all redistributions.
 *
 * NO WARRANTY. THIS SOFTWARE IS FURNISHED ON AN "AS IS" BASIS.
 * CARNEGIE MELLON UNIVERSITY MAKES NO WARRANTIES OF ANY KIND, EITHER
 * EXPRESSED OR IMPLIED AS TO THE MATTER INCLUDING, BUT NOT LIMITED
 * TO: WARRANTY OF FITNESS FOR PURPOSE OR MERCHANTABILITY, EXCLUSIVITY
 * OF RESULTS OR RESULTS OBTAINED FROM USE OF THIS SOFTWARE. CARNEGIE
 * MELLON UNIVERSITY DOES NOT MAKE ANY WARRANTY OF ANY KIND WITH RESPECT
 * TO FREEDOM FROM PATENT, TRADEMARK, OR COPYRIGHT INFRINGEMENT.
 */


#include <nasd/nasd_options.h>
#include <nasd/nasd_threadstuff.h>
#include <nasd/nasd_shutdown.h>
#include <nasd/nasd_types.h>
#include <nasd/nasd_common.h>
#include <nasd/nasd_srpc.h>

/*
 * Define NASD_LS_MAX_IOV to
 * MIN (NASD_LS_MAX_MAX_IOV, NASD_UIO_MAXIOV)
 *
 * NASD_LS_MAX_IOV sizes the on-stack iovec array used by
 * nasd_srpc_sys_sock_send(), so NASD_LS_MAX_MAX_IOV caps how much
 * stack that array may take. NASD_UIO_MAXIOV is not defined in this
 * file; presumably it comes from one of the NASD headers included
 * above -- if it were undefined the #if below would treat it as 0.
 */

#define NASD_LS_MAX_MAX_IOV 16

#if NASD_UIO_MAXIOV > NASD_LS_MAX_MAX_IOV
#define NASD_LS_MAX_IOV NASD_LS_MAX_MAX_IOV
#else /* NASD_UIO_MAXIOV > NASD_LS_MAX_MAX_IOV */
#define NASD_LS_MAX_IOV NASD_UIO_MAXIOV
#endif /* NASD_UIO_MAXIOV > NASD_LS_MAX_MAX_IOV */

/*
 * Number of times nasd_srpc_sys_sock_read() re-issues the receive
 * after it returns -EAGAIN before handing -EAGAIN back to its caller
 * (so at most MAX_EAGAIN_IN_A_ROW_ALLOWED + 1 attempts per read).
 */
#define MAX_EAGAIN_IN_A_ROW_ALLOWED 2

/*
 * Set TCP_NODELAY on the socket to the given value and, on success,
 * remember that value in sock->sock.nodelay.
 *
 * sock     socket whose sock.linux_sock is already set up
 * nodelay  value handed to TCP_NODELAY (0 disables, non-zero enables)
 *
 * Returns NASD_SUCCESS, or NASD_FAIL if setsockopt() reports an error,
 * in which case the cached setting is left untouched.
 */
nasd_status_t
nasd_srpc_sys_sock_setnodelay(
  nasd_srpc_sock_t           *sock,
  int                         nodelay)
{
  mm_segment_t oldfs;
  int ret, val;

  val = nodelay;

  /*
   * The option value lives on the kernel stack, but the socket-option
   * handlers fetch it with the user-access primitives. Widen the
   * address limit for the duration of the call, exactly as the
   * getsockopt() in nasd_srpc_sys_sock_conn() does.
   */
  oldfs = get_fs();
  set_fs(KERNEL_DS);
  ret = sock->sock.linux_sock->ops->setsockopt(sock->sock.linux_sock,
    IPPROTO_TCP, TCP_NODELAY, (char *)&val, sizeof(val));
  set_fs(oldfs);
  if (ret)
    return(NASD_FAIL);

  sock->sock.nodelay = nodelay;

  return(NASD_SUCCESS);
}

/*
 * Apply a set of buffer/TCP options to the socket.
 *
 * opts->nodelay_thresh is always copied into the socket. Each of
 * sndbuf, rcvbuf, nodelay, keepalive and reuseaddr is applied only
 * when its value is >= 0; a negative value means "leave alone".
 *
 * Every requested option is attempted even if an earlier one fails.
 * Returns NASD_SUCCESS if all attempted options were accepted,
 * NASD_FAIL if at least one was rejected.
 */
nasd_status_t
nasd_srpc_sys_sock_set_options(
  nasd_srpc_sock_t          *sock,
  nasd_srpc_sockbuf_opts_t  *opts)
{
  nasd_status_t rc;
  mm_segment_t oldfs;
  int ret, val;

  NASD_ASSERT(sock->sock.linux_sock != NULL);
  rc = NASD_SUCCESS;

  sock->sock.nodelay_thresh = opts->nodelay_thresh;

  /*
   * All option values below are passed from the kernel stack, while
   * the setsockopt handlers read them with the user-access
   * primitives. Widen the address limit once around the whole batch,
   * as nasd_srpc_sys_sock_conn() does for its getsockopt().
   */
  oldfs = get_fs();
  set_fs(KERNEL_DS);

  if (opts->sndbuf >= 0) {
    val = opts->sndbuf;
    ret = sock_setsockopt(sock->sock.linux_sock, SOL_SOCKET, SO_SNDBUF,
      (char *)&val, sizeof(val));
    if (ret) {
      rc = NASD_FAIL;
    }
  }

  if (opts->rcvbuf >= 0) {
    val = opts->rcvbuf;
    ret = sock_setsockopt(sock->sock.linux_sock, SOL_SOCKET, SO_RCVBUF,
      (char *)&val, sizeof(val));
    if (ret) {
      rc = NASD_FAIL;
    }
  }

  if (opts->nodelay >= 0) {
    val = opts->nodelay;
    ret = sock->sock.linux_sock->ops->setsockopt(sock->sock.linux_sock,
      IPPROTO_TCP, TCP_NODELAY, (char *)&val, sizeof(val));
    if (ret) {
      rc = NASD_FAIL;
    }
    else {
      /*
       * Keep the cached state in step with the socket; the
       * nodelay_thresh logic in nasd_srpc_sys_sock_send() decides
       * whether to toggle TCP_NODELAY from this field.
       */
      sock->sock.nodelay = val;
    }
  }

  if (opts->keepalive >= 0) {
    val = opts->keepalive;
    ret = sock_setsockopt(sock->sock.linux_sock, SOL_SOCKET, SO_KEEPALIVE,
      (char *)&val, sizeof(val));
    if (ret) {
      rc = NASD_FAIL;
    }
  }

  if (opts->reuseaddr >= 0) {
    val = opts->reuseaddr;
    ret = sock_setsockopt(sock->sock.linux_sock, SOL_SOCKET, SO_REUSEADDR,
      (char *)&val, sizeof(val));
    if (ret) {
      rc = NASD_FAIL;
    }
  }

  set_fs(oldfs);

  return(rc);
}

/*
 * System-specific socket initialization hook. The Linux in-kernel
 * implementation has no per-socket state to prepare, so this always
 * reports NASD_SUCCESS and does not touch sock.
 */
nasd_status_t
nasd_srpc_sys_sock_init(
  nasd_srpc_sock_t  *sock)
{
  nasd_status_t rc = NASD_SUCCESS;

  (void)sock;  /* intentionally unused */

  return(rc);
}

/*
 * Release the kernel socket held by sock, if any, and clear the
 * pointer so that a later call is a harmless no-op.
 *
 * nasd_srpc_sys_sock_conn() leaves sock.linux_sock NULL after a
 * failed connect and nasd_srpc_sys_sock_destroy() calls this
 * unconditionally, so a NULL socket must not reach sock_release().
 */
void
nasd_srpc_sys_do_sockclose(
  nasd_srpc_sock_t  *sock)
{
  if (sock->sock.linux_sock == NULL)
    return;

  sock_release(sock->sock.linux_sock);
  sock->sock.linux_sock = NULL;
}

/*
 * Create an AF_INET stream socket and connect it to ipaddr:ipport.
 *
 * ipaddr and ipport are run through nasd_hton32()/nasd_hton16()
 * before use (so presumably callers supply host byte order). On
 * success the peer address is recorded in sock->sock.peer_addr and
 * the socket's current TCP_NODELAY setting is read back into
 * sock->sock.nodelay.
 *
 * Returns:
 *   NASD_SUCCESS              connected
 *   NASD_FAIL                 the socket could not be created;
 *                             sock.linux_sock is left NULL
 *   NASD_SRPC_CANNOT_CONNECT  connect() or the TCP_NODELAY query
 *                             failed; the socket has been closed
 */
nasd_status_t
nasd_srpc_sys_sock_conn(
  nasd_srpc_sock_t  *sock,
  nasd_uint32        ipaddr,
  nasd_uint16        ipport)
{
  struct sockaddr_in peer;
  mm_segment_t saved_fs;
  int err, nodelay, optlen;

  peer.sin_family = AF_INET;
  peer.sin_addr.s_addr = nasd_hton32(ipaddr);
  peer.sin_port = nasd_hton16(ipport);

  /* socket creation and connect are done under the big kernel lock */
  lock_kernel();
  err = sock_create(AF_INET, SOCK_STREAM, PF_UNSPEC, &sock->sock.linux_sock);
  unlock_kernel();
  if (err != 0) {
    sock->sock.linux_sock = NULL;
    return(NASD_FAIL);
  }

  /* this socket is not attached to any file */
  sock->sock.linux_sock->file = NULL;

  lock_kernel();
  err = sock->sock.linux_sock->ops->connect(sock->sock.linux_sock,
    (struct sockaddr *)&peer, sizeof(peer), 0);
  unlock_kernel();
  if (err != 0)
    goto cannot_connect;

  sock->sock.peer_addr = peer;

  /*
   * Read back the initial TCP_NODELAY state. The result buffers are
   * kernel memory, hence the temporary switch to KERNEL_DS.
   */
  optlen = sizeof(nodelay);
  saved_fs = get_fs();
  set_fs(KERNEL_DS);
  err = sock->sock.linux_sock->ops->getsockopt(sock->sock.linux_sock,
    IPPROTO_TCP, TCP_NODELAY, (char *)&nodelay, &optlen);
  set_fs(saved_fs);
  if ((err != 0) || (optlen != sizeof(nodelay)))
    goto cannot_connect;

  sock->sock.nodelay = nodelay;

  return(NASD_SUCCESS);

cannot_connect:
  nasd_srpc_sys_do_sockclose(sock);
  return(NASD_SRPC_CANNOT_CONNECT);
}

/*
 * Hand nio iovec entries, total_len bytes in all, to sock_sendmsg()
 * in a single call. The buffers are kernel memory, so the address
 * limit is switched to KERNEL_DS around the call and restored after.
 *
 * Returns whatever sock_sendmsg() returns: the number of bytes
 * accepted, or a negative errno value.
 */
NASD_INLINE int
nasd_srpc_sys_sock_writev(
  nasd_srpc_sock_t  *sock,
  struct iovec      *iov,
  int                nio,
  int                total_len)
{
  struct msghdr mh;
  mm_segment_t saved_fs;
  int nbytes;

  /* payload */
  mh.msg_iov = iov;
  mh.msg_iovlen = nio;

  /* destination: the peer recorded for this socket */
  mh.msg_name = &sock->sock.peer_addr;
  mh.msg_namelen = sizeof(sock->sock.peer_addr);

  /* no ancillary data, no special flags */
  mh.msg_control = NULL;
  mh.msg_controllen = 0;
  mh.msg_flags = 0;

  saved_fs = get_fs();
  set_fs(KERNEL_DS);
  nbytes = sock_sendmsg(sock->sock.linux_sock, &mh, total_len);
  set_fs(saved_fs);

  return(nbytes);
}

/*
 * Send one contiguous buffer by wrapping it in a single-entry iovec
 * and delegating to nasd_srpc_sys_sock_writev(). Returns that
 * function's result unchanged.
 */
NASD_INLINE int
nasd_srpc_sys_sock_write(
  nasd_srpc_sock_t  *sock,
  void              *buf,
  int                len)
{
  struct iovec one;

  one.iov_len = len;
  one.iov_base = buf;

  return(nasd_srpc_sys_sock_writev(sock, &one, 1, len));
}

/*
 * Receive up to len bytes into buf with a single-entry iovec.
 *
 * blocking  non-zero: wait for data; zero: pass MSG_DONTWAIT
 *
 * A receive that returns -EAGAIN is re-issued up to
 * MAX_EAGAIN_IN_A_ROW_ALLOWED more times before -EAGAIN is handed
 * back. Returns the final sock_recvmsg() result: the number of bytes
 * received, or a negative errno value.
 */
NASD_INLINE int
nasd_srpc_sys_sock_read(
  nasd_srpc_sock_t  *sock,
  void              *buf,
  int                len,
  int                blocking)
{
  mm_segment_t oldfs;
  struct msghdr msg;
  struct iovec iov;
  int out_len, n, flags;

  /*
   * Keep the receive flags in a local. msg.msg_flags is an output
   * field of the receive path, so re-reading it on a retry could
   * pick up whatever the previous attempt stored there.
   */
  flags = blocking ? 0 : MSG_DONTWAIT;

  iov.iov_base = buf;
  iov.iov_len = len;

  msg.msg_name = &sock->sock.peer_addr;
  msg.msg_namelen = sizeof(sock->sock.peer_addr);
  msg.msg_iov = &iov;
  msg.msg_iovlen = 1;
  msg.msg_control = 0;
  msg.msg_controllen = 0;
  msg.msg_flags = flags;

  /* buf is kernel memory */
  oldfs = get_fs();
  set_fs(KERNEL_DS);

  n = 0;
  do {
    out_len = sock_recvmsg(sock->sock.linux_sock, &msg, len, flags);
  } while ((out_len == (-EAGAIN)) && ((n++) < MAX_EAGAIN_IN_A_ROW_ALLOWED));

  set_fs(oldfs);

  return(out_len);
}

/*
 * Push the first nio entries of iov (want bytes in total) out the
 * socket, re-issuing writev for the unsent tail after a short write.
 *
 * The iov entries are modified in place to track progress. *sent is
 * advanced by every byte the socket accepts, so on failure it still
 * reflects what went out before the failing call.
 *
 * Returns NASD_SUCCESS once all want bytes are sent (trivially so
 * when want is 0), or NASD_FAIL if a writev returns <= 0.
 */
static nasd_status_t
nasd_srpc_sys_sock_flush_iov(
  nasd_srpc_sock_t  *sock,
  struct iovec      *iov,
  int                nio,
  int                want,
  int               *sent)
{
  int f, ret, this_sent;
  char *tmp;

  f = 0;
  this_sent = 0;

  while (this_sent < want) {
    NASD_SRPC_INC_COUNTER(&nasd_srpc_stats.sock_writev_calls);
    /* the length must describe only what is still left in iov[f..] */
    ret = nasd_srpc_sys_sock_writev(sock, &iov[f], nio-f,
      want-this_sent);
    if (ret <= 0)
      return(NASD_FAIL);
    *sent += ret;
    this_sent += ret;
    if (this_sent < want) {
      /* Short write: skip the entries that went out completely... */
      for(;f<nio;f++) {
        if (ret < iov[f].iov_len)
          break;
        ret -= iov[f].iov_len;
      }
      NASD_ASSERT(f < nio);
      /* ...and trim the sent prefix off the first unfinished one. */
      if (ret) {
        iov[f].iov_len -= ret;
        tmp = (char *)iov[f].iov_base;
        iov[f].iov_base = &tmp[ret];
      }
    }
  }

  return(NASD_SUCCESS);
}

/*
 * Send every buffer on the memvec list, total_len bytes in all.
 *
 * When sock.nodelay_thresh is >= 0, TCP_NODELAY is first switched
 * off for sends of at least nodelay_thresh bytes and on for smaller
 * ones, if the cached sock.nodelay differs from what is wanted.
 *
 * A single-element list is sent with plain writes. Longer lists are
 * sent with writev in batches of at most NASD_LS_MAX_IOV buffers;
 * zero-length elements are skipped.
 *
 * *bytes_sent is always set to the number of bytes handed to the
 * socket, including on failure.
 *
 * Returns:
 *   NASD_SUCCESS       everything was sent
 *   NASD_MEM_LIST_ERR  a list element has a negative length
 *   NASD_FAIL          a write/writev returned <= 0
 *   other              status from nasd_srpc_sys_sock_setnodelay()
 */
nasd_status_t
nasd_srpc_sys_sock_send(
  nasd_srpc_sock_t    *sock,
  nasd_srpc_memvec_t  *memvec,
  int                  total_len,
  int                 *bytes_sent)
{
  struct iovec iov[NASD_LS_MAX_IOV];
  int sent, ret, want, nio, want_nodelay;
  nasd_srpc_memvec_t *vec;
  nasd_status_t rc;
  char *tmp;

  sent = 0;
  nio = 0;
  want = 0;

  NASD_ASSERT(sock->sock.linux_sock != NULL);
  NASD_SRPC_INC_COUNTER(&nasd_srpc_stats.sock_send_calls);

  if (sock->sock.nodelay_thresh >= 0) {
    /* below the threshold we want nodelay on, at or above it off */
    want_nodelay = (total_len < sock->sock.nodelay_thresh) ? 1 : 0;
    if (want_nodelay != (sock->sock.nodelay ? 1 : 0)) {
      rc = nasd_srpc_sys_sock_setnodelay(sock, want_nodelay);
      if (rc) {
        *bytes_sent = 0;
        return(rc);
      }
    }
  }

  if (memvec->next == NULL) {
    /*
     * No scatter-gather optimization
     */
    NASD_ASSERT(memvec->len == total_len);
    tmp = (char *)memvec->buf;
    for(sent=0;sent<memvec->len;) {
      NASD_SRPC_INC_COUNTER(&nasd_srpc_stats.sock_write_calls);
      ret = nasd_srpc_sys_sock_write(sock, &tmp[sent],
        memvec->len-sent);
      if (ret <= 0) {
        *bytes_sent = sent;
        return(NASD_FAIL);
      }
      sent += ret;
    }
    *bytes_sent = sent;
    return(NASD_SUCCESS);
  }

  /*
   * Use batches of writev() to go through
   * the memvec list.
   */

  for(vec=memvec;vec;vec=vec->next) {
    if (vec->len == 0)
      continue;
    if (vec->len < 0) {
      *bytes_sent = sent;
      return(NASD_MEM_LIST_ERR);
    }
    iov[nio].iov_base = vec->buf;
    iov[nio].iov_len = vec->len;
    want += vec->len;
    nio++;
    if (nio == NASD_LS_MAX_IOV) {
      /* batch is full: send it before collecting any more */
      rc = nasd_srpc_sys_sock_flush_iov(sock, iov, nio, want, &sent);
      if (rc) {
        *bytes_sent = sent;
        return(rc);
      }
      nio = 0;
      want = 0;
    }
    NASD_ASSERT(nio < NASD_LS_MAX_IOV);
  }

  /* send whatever is left in the final, partially filled batch */
  rc = nasd_srpc_sys_sock_flush_iov(sock, iov, nio, want, &sent);
  *bytes_sent = sent;

  return(rc);
}

/*
 * Receive data from the socket into buf.
 *
 * flags:
 *   NASD_SRPC_RECV_NOWAIT  do not block waiting for data
 *   NASD_SRPC_RECV_FILL    keep reading until len bytes have arrived;
 *                          without it a single read is issued
 *
 * With both flags set, the call returns immediately if nothing is
 * available, but once some data has arrived it switches to blocking
 * reads to complete the request.
 *
 * *bytes_received is always set to the number of bytes placed in buf.
 *
 * Returns:
 *   NASD_SUCCESS  data was received, or a non-blocking read found
 *                 nothing available (*bytes_received is then 0)
 *   NASD_AGAIN    the read returned -ERESTARTSYS
 *   NASD_FAIL     the read returned 0 or any other error
 */
nasd_status_t
nasd_srpc_sys_sock_recv(
  nasd_srpc_sock_t  *sock,
  void              *buf,
  int                len,
  int               *bytes_received,
  int                flags)
{
  int ret, got, blocking;
  char *tmp = (char *) buf;
  nasd_status_t rc = NASD_SUCCESS;

  NASD_ASSERT(sock->sock.linux_sock != NULL);
  NASD_SRPC_INC_COUNTER(&nasd_srpc_stats.sock_recv_calls);

  if (flags & NASD_SRPC_RECV_NOWAIT) { blocking = 0; }
  else                               { blocking = 1; }

  if (flags & NASD_SRPC_RECV_FILL) {
    got = 0;
    while (got < len) {
      NASD_SRPC_INC_COUNTER(&nasd_srpc_stats.sock_read_calls);
      ret = nasd_srpc_sys_sock_read(sock, &tmp[got], len-got, blocking);
      if (ret > 0) {
        got += ret;
        continue;
      }
      if ((ret == -EWOULDBLOCK) || (ret == -EAGAIN)) {
        NASD_ASSERT(blocking == 0);
        if (got) {
          /*
           * Part of the request has arrived; block for the rest.
           * got must stay as it is here -- only successful reads
           * advance it, never the negative error code.
           */
          blocking = 1;
          continue;
        }
        rc = NASD_SUCCESS;
      } else if (ret == -ERESTARTSYS) { rc = NASD_AGAIN;
      } else                          { rc = NASD_FAIL;  }
      ret = got; goto out;
    }

    NASD_ASSERT(got == len);
    ret = got;
  } else {
    NASD_SRPC_INC_COUNTER(&nasd_srpc_stats.sock_read_calls);
    ret = nasd_srpc_sys_sock_read(sock, buf, len, blocking);
    if (ret <= 0) {
      if ((ret == -EWOULDBLOCK) || (ret == -EAGAIN)) {
        NASD_ASSERT(blocking == 0);
        rc = NASD_SUCCESS;
      } else if (ret == -ERESTARTSYS) { rc = NASD_AGAIN;
      } else                          { rc = NASD_FAIL;  }
      ret = 0; goto out;
    }
  }

  rc = NASD_SUCCESS;

out:
  *bytes_received = ret;
  return rc;
}

/*
 * Tear down the system-specific part of the socket by closing it via
 * nasd_srpc_sys_do_sockclose(). Always reports NASD_SUCCESS.
 */
nasd_status_t
nasd_srpc_sys_sock_destroy(
  nasd_srpc_sock_t  *sock)
{
  nasd_status_t rc = NASD_SUCCESS;

  nasd_srpc_sys_do_sockclose(sock);

  return(rc);
}

/* Local Variables:  */
/* indent-tabs-mode: nil */
/* tab-width: 2 */
/* End: */
