#include "config.h"

#include "handshake.h"

#include <cassert>
#include <cstdio>
#include <tuple>

#include "manager.h"
#include "download/download_main.h"
#include "net/throttle_list.h"
#include "protocol/extensions.h"
#include "protocol/handshake_manager.h"
#include "torrent/download_info.h"
#include "torrent/error.h"
#include "torrent/exceptions.h"
#include "torrent/poll.h"
#include "torrent/throttle.h"
#include "torrent/net/fd.h"
#include "torrent/net/network_config.h"
#include "torrent/tracker/dht_controller.h"
#include "torrent/utils/log.h"
#include "utils/diffie_hellman.h"
#include "utils/sha1.h"

#if DISABLED__USE_EXTRA_DEBUG
// Extra-verbose handshake tracing, compiled out unless DISABLED__USE_EXTRA_DEBUG
// is enabled. `sa` is the address to print (every caller in this file passes
// m_address). Callers with a conversion-free format pass a dummy trailing
// argument such as `0`, because a variadic macro needs at least one variadic
// argument before C++20.
#define LT_LOG_EXTRA_DEBUG_SA(sa, log_fmt, ...)                         \
  lt_log_print(LOG_CONNECTION_HANDSHAKE, "handshake->%s: " log_fmt, sap_pretty_str(sa).c_str(), __VA_ARGS__);
#else
#define LT_LOG_EXTRA_DEBUG_SA(sa, log_fmt, ...)
#endif

namespace {
class handshake_succeeded : public torrent::network_error {
};

class handshake_error : public torrent::network_error {
public:
  handshake_error(int type, int error) : m_type(type), m_error(error) {}

  const char* what() const noexcept override  { return "Handshake error"; }
  int         type() const noexcept           { return m_type; }
  int         error() const noexcept          { return m_error; }

private:
  int     m_type;
  int     m_error;
};

} // namespace

namespace torrent {

// Wraps socket `fd` for a handshake driven by HandshakeManager `m`.
// `encryptionOptions` is handed to m_encryption (its options() are tested
// against net::NetworkConfig::encryption_* flags elsewhere in this file).
Handshake::Handshake(int fd, HandshakeManager* m, int encryptionOptions) :
  m_manager(m),

  // Until the download is identified, account traffic on the global throttles.
  m_uploadThrottle(manager->upload_throttle()->throttle_list()),
  m_downloadThrottle(manager->download_throttle()->throttle_list()),

  m_encryption(encryptionOptions),
  m_extensions(torrent::HandshakeManager::default_extensions()) {

  set_fd(SocketFd(fd));

  // Both protocol buffers start out empty.
  m_readBuffer.reset();
  m_writeBuffer.reset();

  // When the timeout task fires, report this handshake back to its manager.
  m_task_timeout.slot() = [this, owner = m] { owner->receive_timeout(this); };
}

// The connection must already have been torn down (deactivated, released or
// destroyed) before deletion: the timeout may not be pending and the fd must
// no longer be valid.
Handshake::~Handshake() {
  assert(!m_task_timeout.is_scheduled());
  assert(!get_fd().is_valid());

  // Release any encryption state still held.
  m_encryption.cleanup();
}

// Starts a handshake for an accepted connection from `sa`: chooses the first
// read state, registers for read/error events and arms the timeout.
void
Handshake::initialize_incoming(const sockaddr* sa) {
  m_incoming = true;
  m_address  = sa_copy(sa);

  // READ_ENC_KEY also recognizes a plaintext BitTorrent handshake, so use it
  // whenever incoming encryption is allowed or required.
  const bool may_be_encrypted =
    (m_encryption.options() & (net::NetworkConfig::encryption_allow_incoming | net::NetworkConfig::encryption_require)) != 0;

  m_state = may_be_encrypted ? READ_ENC_KEY : READ_INFO;

  this_thread::poll()->open(this);
  this_thread::poll()->insert_read(this);
  this_thread::poll()->insert_error(this);

  // Use lower timeout here.
  this_thread::scheduler()->wait_for_ceil_seconds(&m_task_timeout, 60s);
}

// Starts a handshake for an outgoing connection to `sa` on behalf of download
// `d` and peer `peerInfo`: marks the peer as handshaking, switches to the
// download's throttles, and waits for the connect to become writable.
void
Handshake::initialize_outgoing(const sockaddr* sa, DownloadMain* d, PeerInfo* peerInfo) {
  m_download = d;

  m_peerInfo = peerInfo;
  m_peerInfo->set_flags(PeerInfo::flag_handshake);

  m_incoming = false;
  m_address = sa_copy(sa);

  // The download is known up front, so replace the global throttles chosen in
  // the constructor. std::tie binds references to the members; the former
  // `std::make_pair(...) = ...` assigned into a temporary copy and changed
  // nothing.
  std::tie(m_uploadThrottle, m_downloadThrottle) = m_download->throttles(m_address.get());

  m_state = CONNECTING;

  // Writability signals completion of the connect attempt (see event_write).
  this_thread::poll()->open(this);
  this_thread::poll()->insert_write(this);
  this_thread::poll()->insert_error(this);

  this_thread::scheduler()->wait_for_ceil_seconds(&m_task_timeout, 60s);
}

// Stops all activity on the handshake: cancels the timeout and removes the
// socket from the poller (which closes it). Requires an open fd.
void
Handshake::deactivate_connection() {
  const bool fd_open = get_fd().is_valid();

  if (!fd_open)
    throw internal_error("Handshake::deactivate_connection called but m_fd is not open.");

  m_state = INACTIVE;

  this_thread::scheduler()->erase(&m_task_timeout);
  this_thread::poll()->remove_and_close(this);
}

void
Handshake::release_connection() {
  if (!get_fd().is_valid())
    throw internal_error("Handshake::release_connection called but m_fd is not open.");

  m_peerInfo->unset_flags(PeerInfo::flag_handshake);
  m_peerInfo = NULL;

  get_fd().clear();
}

void
Handshake::destroy_connection() {
  if (!get_fd().is_valid())
    throw internal_error("Handshake::destroy_connection called but m_fd is not open.");

  manager->connection_manager()->dec_socket_count();

  get_fd().close();
  get_fd().clear();

  if (m_peerInfo == NULL)
    return;

  m_download->peer_list()->disconnected(m_peerInfo, 0);

  m_peerInfo->unset_flags(PeerInfo::flag_handshake);
  m_peerInfo = NULL;

  if (!m_extensions->is_default()) {
    m_extensions->cleanup();
    delete m_extensions;
  }
}

// Computes the encryption options for a reconnect attempt: retry is disabled,
// and encryption_try_outgoing is cleared for a plaintext retry or set for an
// encrypted retry. Any other retry mode is an internal error.
int
Handshake::retry_options() {
  uint32_t opts = m_encryption.options() & ~net::NetworkConfig::encryption_enable_retry;

  switch (m_encryption.retry()) {
  case HandshakeEncryption::RETRY_PLAIN:
    opts &= ~net::NetworkConfig::encryption_try_outgoing;
    break;

  case HandshakeEncryption::RETRY_ENCRYPTED:
    opts |= net::NetworkConfig::encryption_try_outgoing;
    break;

  default:
    throw internal_error("Invalid retry type.");
  }

  return opts;
}

// Reads up to `length` bytes from the socket into `buf` (read_stream_throws)
// and passes the count through the download throttle's
// node_used_unthrottled(), returning its result.
inline uint32_t
Handshake::read_unthrottled(void* buf, uint32_t length) {
  const auto bytes_read = read_stream_throws(buf, length);
  return m_downloadThrottle->node_used_unthrottled(bytes_read);
}

// Writes up to `length` bytes of `buf` to the socket (write_stream_throws)
// and passes the count through the upload throttle's
// node_used_unthrottled(), returning its result.
inline uint32_t
Handshake::write_unthrottled(const void* buf, uint32_t length) {
  const auto bytes_written = write_stream_throws(buf, length);
  return m_uploadThrottle->node_used_unthrottled(bytes_written);
}

// Handshake::read_proxy_connect()
// Entry: 0, 0
// * 0, [0, 508>
//
// Reads the proxy's reply to our CONNECT request and searches for the blank
// line ("\r\n\r\n") that ends its headers. Returns true once it has been
// seen; bytes following it are kept (via move_unused()) for the next state.
// Returns false while the terminator has not arrived yet.
bool
Handshake::read_proxy_connect() {
  // Being greedy for now.
  m_readBuffer.move_end(read_unthrottled(m_readBuffer.end(), 512));

  const char* pattern = "\r\n\r\n";
  const unsigned int patternLength = 4;

  if (m_readBuffer.remaining() < patternLength)
    return false;

  const auto* pattern_begin = reinterpret_cast<const uint8_t*>(pattern);
  auto itr = std::search(m_readBuffer.begin(), m_readBuffer.end(),
                         pattern_begin, pattern_begin + patternLength);

  // Decide before move_unused(): it relocates the buffer end, so comparing
  // `itr` against end() afterwards gives a meaningless answer (e.g. "found"
  // whenever more than 4 bytes were buffered without a terminator).
  const bool found = itr != m_readBuffer.end();

  // On success skip past the terminator. Otherwise keep the last
  // patternLength bytes in case the terminator straddles two reads.
  m_readBuffer.set_position_itr(found ? (itr + patternLength) : (itr - patternLength));
  m_readBuffer.move_unused();

  return found;
}

// Handshake::read_encryption_key()
// Entry: * 0, [0, 508>
// IU 20, [20, enc_pad_read_size>
// *E 96, [96, enc_pad_read_size>
//
// Reads the peer's 96-byte public key for the encrypted handshake. For
// incoming connections the first 20 bytes are inspected first. If they are a
// plaintext BitTorrent handshake, the state switches to READ_INFO instead, or
// the connection is dropped when encryption is required.
//
// Returns false when more data is needed, or true after m_state has been
// advanced (READ_INFO or READ_ENC_SYNC). Throws handshake_error when
// plaintext is rejected or the shared secret cannot be computed.
bool
Handshake::read_encryption_key() {
  if (m_incoming) {
    // Read just enough to tell a plaintext handshake (length byte 19 followed
    // by the protocol string) apart from the start of a public key.
    if (m_readBuffer.remaining() < 20)
      m_readBuffer.move_end(read_unthrottled(m_readBuffer.end(), 20 - m_readBuffer.remaining()));

    if (m_readBuffer.remaining() < 20)
      return false;

    if (m_readBuffer.peek_8() == 19 && std::memcmp(m_readBuffer.position() + 1, m_protocol, 19) == 0) {
      // got unencrypted BT handshake
      if (m_encryption.options() & net::NetworkConfig::encryption_require)
        throw handshake_error(ConnectionManager::handshake_dropped, e_handshake_unencrypted_rejected);

      m_state = READ_INFO;
      return true;
    }
  }

  // Read as much of key, pad and sync string as we can; this is safe
  // because peer can't send anything beyond the initial BT handshake
  // because it doesn't know our encryption choice yet.
  if (m_readBuffer.remaining() < enc_pad_read_size)
    m_readBuffer.move_end(read_unthrottled(m_readBuffer.end(), enc_pad_read_size - m_readBuffer.remaining()));

  // but we need at least the key at this point
  // NOTE(review): this tests size_end() rather than remaining(); presumably
  // equivalent because position() is at the buffer start here — confirm.
  if (m_readBuffer.size_end() < 96)
    return false;

  // If the handshake fails after this, it wasn't because the peer
  // doesn't like encrypted connections, so don't retry unencrypted.
  m_encryption.set_retry(HandshakeEncryption::RETRY_NONE);

  // Incoming peers only receive our key now. Outgoing connections queued it
  // when the connection was established (event_write -> prepare_key_plus_pad).
  if (m_incoming)
    prepare_key_plus_pad();

  if(!m_encryption.key()->compute_secret(m_readBuffer.position(), 96))
    throw handshake_error(ConnectionManager::handshake_failed, e_handshake_invalid_encryption);
  m_readBuffer.consume(96);

  // Determine the synchronisation string.
  // Incoming: HASH('req1', S) sent by the peer; outgoing: the encrypted VC.
  if (m_incoming)
    m_encryption.hash_req1_to_sync();
  else
    m_encryption.encrypt_vc_to_sync(m_download->info()->hash().c_str());

  // also put as much as we can write so far in the buffer
  if (!m_incoming)
    prepare_enc_negotiation();

  m_state = READ_ENC_SYNC;
  return true;
}

// Handshake::read_encryption_sync()
// *E 96, [96, enc_pad_read_size>
//
// Finds the synchronisation string computed in read_encryption_key() in the
// incoming stream. The sync string may be preceded by up to enc_pad_size
// bytes of padding. Not finding it within that window is a handshake
// failure. Returns false while more data is needed.
bool
Handshake::read_encryption_sync() {
  const auto* sync_first = reinterpret_cast<const uint8_t*>(m_encryption.sync());
  const auto* sync_last  = sync_first + m_encryption.sync_length();

  auto find_sync = [&] {
    return std::search(m_readBuffer.position(), m_readBuffer.end(), sync_first, sync_last);
  };

  // The previous state usually already read the sync string, so look in the
  // buffer before touching the socket.
  auto match = find_sync();

  if (match == m_readBuffer.end()) {
    int toRead = enc_pad_size + m_encryption.sync_length() - m_readBuffer.remaining();

    // The whole pad + sync window is buffered and still no match.
    if (toRead <= 0)
      throw handshake_error(ConnectionManager::handshake_failed, e_handshake_encryption_sync_failed);

    m_readBuffer.move_end(read_unthrottled(m_readBuffer.end(), toRead));

    match = find_sync();

    if (match == m_readBuffer.end())
      return false;
  }

  const auto skipped = std::distance(m_readBuffer.position(), match);

  if (!m_incoming) {
    // Outgoing: stop at the sync string itself; READ_ENC_NEGOT decrypts and
    // checks the VC starting there.
    m_readBuffer.consume(skipped);
    m_state = READ_ENC_NEGOT;
    return true;
  }

  // Incoming: skip HASH('req1' + S) (20 bytes) so the SKEY hash is next.
  m_readBuffer.consume(skipped + 20);
  m_state = READ_ENC_SKEY;
  return true;
}

// Incoming only (entered from read_encryption_sync). Reads the 20-byte
// obfuscated info-hash, identifies the download, switches to its throttles,
// initializes encryption in both directions and queues our encrypted VC.
// Returns false while fewer than 20 bytes are available.
bool
Handshake::read_encryption_skey() {
  LT_LOG_EXTRA_DEBUG_SA(m_address, "read_encryption_skey", 0)

  if (!fill_read_buffer(20))
    return false;

  // Undo the obfuscation in place and look the download up by the result
  // (the outgoing side builds this value in prepare_enc_negotiation()).
  m_encryption.deobfuscate_hash(reinterpret_cast<char*>(m_readBuffer.position()));
  m_download = m_manager->download_info_obfuscated(reinterpret_cast<char*>(m_readBuffer.position()));
  m_readBuffer.consume(20);

  validate_download();

  // We don't allow encrypted connections for meta-data downloads.
  if (m_download->info()->is_meta_download())
    throw handshake_error(ConnectionManager::handshake_dropped, e_handshake_invalid_encryption);

  // Switch from the global throttles to those selected for this download and
  // address. std::tie binds references to the members; the former
  // `std::make_pair(...) = ...` only assigned into a temporary copy.
  std::tie(m_uploadThrottle, m_downloadThrottle) = m_download->throttles(m_address.get());

  m_encryption.initialize_encrypt(m_download->info()->hash().c_str(), m_incoming);
  m_encryption.initialize_decrypt(m_download->info()->hash().c_str(), m_incoming);

  // Anything already buffered after the hash is part of the encrypted stream.
  m_encryption.info()->decrypt(m_readBuffer.position(), m_readBuffer.remaining());

  // Queue our encrypted VC; the crypto selection is appended in
  // read_encryption_negotiation().
  HandshakeEncryption::copy_vc(m_writeBuffer.end());
  m_encryption.info()->encrypt(m_writeBuffer.end(), HandshakeEncryption::vc_length);
  m_writeBuffer.move_end(HandshakeEncryption::vc_length);

  m_state = READ_ENC_NEGOT;
  return true;
}

// Reads the encrypted negotiation block: the VC, a 32-bit crypto field and a
// 16-bit pad length. Incoming connections choose a crypto method from the
// peer's offer and queue the reply. Outgoing connections validate the method
// the peer selected. Returns false while enc_negotiation_size bytes are not
// yet available; otherwise sets m_state to READ_ENC_PAD with the pad length
// stored in m_readPos. Throws handshake_error on an invalid VC, an oversized
// pad or an unacceptable crypto method.
bool
Handshake::read_encryption_negotiation() {
  LT_LOG_EXTRA_DEBUG_SA(m_address, "read_encryption_negotiation", 0)

  if (!fill_read_buffer(enc_negotiation_size))
    return false;

  if (!m_incoming) {
    // Start decrypting, but don't decrypt beyond the initial
    // encrypted handshake and later the pad because we may have read
    // too much data which may be unencrypted if the peer chose that
    // option.
    m_encryption.initialize_decrypt(m_download->info()->hash().c_str(), m_incoming);
    m_encryption.info()->decrypt(m_readBuffer.position(), enc_negotiation_size);
  }

  if (!HandshakeEncryption::compare_vc(m_readBuffer.position()))
    throw handshake_error(ConnectionManager::handshake_failed, e_handshake_invalid_value);

  m_readBuffer.consume(HandshakeEncryption::vc_length);

  m_encryption.set_crypto(m_readBuffer.read_32());
  m_readPos = m_readBuffer.read_16();       // length of padC/padD

  if (m_readPos > enc_pad_size)
    throw handshake_error(ConnectionManager::handshake_failed, e_handshake_invalid_value);

  // choose one of the offered encryptions, or check the chosen one is valid
  if (m_incoming) {
    // Order: plaintext if we prefer it and it is offered; reject if RC4 is
    // required but not offered; otherwise RC4, then plaintext.
    if ((m_encryption.options() & net::NetworkConfig::encryption_prefer_plaintext) && m_encryption.has_crypto_plain()) {
      m_encryption.set_crypto(HandshakeEncryption::crypto_plain);

    } else if ((m_encryption.options() & net::NetworkConfig::encryption_require_RC4) && !m_encryption.has_crypto_rc4()) {
      throw handshake_error(ConnectionManager::handshake_dropped, e_handshake_unencrypted_rejected);

    } else if (m_encryption.has_crypto_rc4()) {
      m_encryption.set_crypto(HandshakeEncryption::crypto_rc4);

    } else if (m_encryption.has_crypto_plain()) {
      m_encryption.set_crypto(HandshakeEncryption::crypto_plain);

    } else {
      throw handshake_error(ConnectionManager::handshake_failed, e_handshake_invalid_encryption);
    }

    // at this point we can also write the rest of our negotiation reply
    // (crypto_select followed by a zero padD length, encrypted in place)
    m_writeBuffer.write_32(m_encryption.crypto());
    m_writeBuffer.write_16(0);
    m_encryption.info()->encrypt(m_writeBuffer.end() - 4 - 2, 4 + 2);

  } else {
    if (m_encryption.crypto() != HandshakeEncryption::crypto_rc4 && m_encryption.crypto() != HandshakeEncryption::crypto_plain)
      throw handshake_error(ConnectionManager::handshake_failed, e_handshake_invalid_encryption);

    if ((m_encryption.options() & net::NetworkConfig::encryption_require_RC4) && (m_encryption.crypto() != HandshakeEncryption::crypto_rc4))
      throw handshake_error(ConnectionManager::handshake_failed, e_handshake_invalid_encryption);
  }

  if (!m_incoming) {
    // decrypt appropriate part of buffer: only pad or all
    // (with plaintext selected, only padD is still encrypted)
    if (m_encryption.crypto() == HandshakeEncryption::crypto_plain)
      m_encryption.info()->decrypt(m_readBuffer.position(), std::min<uint32_t>(m_readPos, m_readBuffer.remaining()));
    else
      m_encryption.info()->decrypt(m_readBuffer.position(), m_readBuffer.remaining());
  }

  // next, skip padC/padD
  m_state = READ_ENC_PAD;
  return true;
}

// Final step of the encryption negotiation.
// Incoming: read the 2-byte len(IA) following padC and continue with the
// initial payload (READ_ENC_IA); returns false while it is not available.
// Outgoing: nothing more to read; unless RC4 was selected the stream is
// marked obfuscated, and the plaintext BitTorrent handshake follows.
bool
Handshake::read_negotiation_reply() {
  if (m_incoming) {
    LT_LOG_EXTRA_DEBUG_SA(m_address, "read_negotiation_reply", 0)

    if (!fill_read_buffer(2))
      return false;

    // The peer may send initial payload that is RC4 encrypted even if
    // we have selected plaintext encryption, so read it ahead of BT
    // handshake.
    m_encryption.set_length_ia(m_readBuffer.read_16());

    // The initial payload may be at most one BitTorrent handshake long.
    if (m_encryption.length_ia() > handshake_size)
      throw handshake_error(ConnectionManager::handshake_failed, e_handshake_invalid_value);

    m_state = READ_ENC_IA;
    return true;
  }

  if (m_encryption.crypto() != HandshakeEncryption::crypto_rc4)
    m_encryption.info()->set_obfuscated();

  m_state = READ_INFO;
  return true;
}

// Reads the BitTorrent handshake header (length byte + protocol string,
// 8 option bytes) and the 20-byte info-hash, leaving the peer id for
// READ_PEER. Incoming connections identify (or cross-check) the download,
// switch to its throttles and queue our own handshake. Outgoing connections
// verify the info-hash. Returns false while the header is incomplete.
bool
Handshake::read_info() {
  LT_LOG_EXTRA_DEBUG_SA(m_address, "read_info", 0)

  // Result deliberately ignored: even a partial read allows the early
  // protocol check below.
  fill_read_buffer(handshake_size);

  // Check the first byte as early as possible so we can
  // disconnect non-BT connections if they send less than 20 bytes.
  if ((m_readBuffer.remaining() >= 1 && m_readBuffer.peek_8() != 19) ||
      (m_readBuffer.remaining() >= 20 &&
       (std::memcmp(m_readBuffer.position() + 1, m_protocol, 19) != 0)))
    throw handshake_error(ConnectionManager::handshake_failed, e_handshake_not_bittorrent);

  if (m_readBuffer.remaining() < part1_size)
    return false;

  // If the handshake fails after this, it isn't being rejected because
  // it is unencrypted, so don't retry.
  m_encryption.set_retry(HandshakeEncryption::RETRY_NONE);

  // Skip the length byte and protocol string.
  m_readBuffer.consume(20);

  // Should do some option field stuff here, for now just copy.
  m_readBuffer.read_range(m_options, m_options + 8);

  // Check the info hash.
  if (m_incoming) {
    if (m_download != NULL) {
      // Have the download from the encrypted handshake, make sure it
      // matches the BT handshake.
      if (m_download->info()->hash().not_equal_to(reinterpret_cast<char*>(m_readBuffer.position())))
        throw handshake_error(ConnectionManager::handshake_failed, e_handshake_invalid_value);

    } else {
      m_download = m_manager->download_info(reinterpret_cast<char*>(m_readBuffer.position()));
    }

    validate_download();

    // Switch from the global throttles to this download's. std::tie binds
    // references; the former `std::make_pair(...) = ...` assigned into a
    // temporary copy and left the members unchanged.
    std::tie(m_uploadThrottle, m_downloadThrottle) = m_download->throttles(m_address.get());

    prepare_handshake();

  } else {
    if (m_download->info()->hash().not_equal_to(reinterpret_cast<char*>(m_readBuffer.position())))
      throw handshake_error(ConnectionManager::handshake_failed, e_handshake_invalid_value);
  }

  // Skip the info-hash.
  m_readBuffer.consume(20);
  m_state = READ_PEER;

  return true;
}

// Reads the 20-byte peer id and starts the post-handshake exchange: sends our
// extension handshake if supported, then queues either the bitfield or, when
// there is nothing to advertise, a zero-length (keep-alive) message. Moves to
// READ_MESSAGE and extends the timeout. Returns false while the id is incomplete.
bool
Handshake::read_peer() {
  LT_LOG_EXTRA_DEBUG_SA(m_address, "read_peer", 0)

  if (!fill_read_buffer(20))
    return false;

  prepare_peer_info();

  // Send EXTENSION_PROTOCOL handshake message if peer supports it.
  if (m_peerInfo->supports_extensions())
    write_extension_handshake();

  // Replay HAVE messages we receive after starting to send the bitfield.
  // This avoids replaying HAVEs for pieces received between starting the
  // handshake and now (e.g. when connecting takes longer). Ideally we
  // should make a snapshot of the bitfield here in case it changes while
  // we're sending it (if it can't be sent in one write() call).
  m_initialized_time = this_thread::cached_time();

  auto* local_bitfield = m_download->file_list()->bitfield();
  const bool nothing_to_advertise = local_bitfield->is_all_unset() || m_download->initial_seeding() != NULL;

  if (!nothing_to_advertise) {
    prepare_bitfield();
  } else {
    // The download is just starting so we're not sending any
    // bitfield. Pretend we wrote it already.
    m_writePos = local_bitfield->size_bytes();
    m_writeBuffer.write_32(0);

    if (m_encryption.info()->is_encrypted())
      m_encryption.info()->encrypt(m_writeBuffer.end() - 4, 4);
  }

  m_state = READ_MESSAGE;
  this_thread::poll()->insert_write(this);

  // Give some extra time for reading/writing the bitfield.
  this_thread::scheduler()->update_wait_for_ceil_seconds(&m_task_timeout, 120s);

  return true;
}

// Continues reading the peer's bitfield directly into m_bitfield, starting at
// offset m_readPos and decrypting in place when decryption is active.
// Returns true once the whole bitfield has been received.
bool
Handshake::read_bitfield() {
  LT_LOG_EXTRA_DEBUG_SA(m_address, "read_bitfield: size:%" PRIu32, m_bitfield.size_bytes());

  const auto total_bytes = m_bitfield.size_bytes();

  if (m_readPos < total_bytes) {
    auto dest = m_bitfield.begin() + m_readPos;
    uint32_t got = read_unthrottled(dest, total_bytes - m_readPos);

    if (m_encryption.info()->decrypt_valid())
      m_encryption.info()->decrypt(dest, got);

    m_readPos += got;
  }

  return m_readPos == total_bytes;
}

// Reads a complete extension-protocol message (the peer's initial extension
// handshake) from the read buffer and hands its payload to m_extensions.
// Returns false while the message is incomplete; throws handshake_error if
// its length is invalid or it cannot fit in the buffer.
bool
Handshake::read_extension() {
  // The message must hold at least the message id and the extension id
  // (2 bytes); shorter lengths would underflow `length` below and lead to an
  // out-of-bounds memcpy/consume. Longer than the buffer can never be read.
  if (m_readBuffer.peek_32() < 2 || m_readBuffer.peek_32() > m_readBuffer.reserved())
    throw handshake_error(ConnectionManager::handshake_failed, e_handshake_invalid_value);

  int32_t need = m_readBuffer.peek_32() + 4 - m_readBuffer.remaining();

  // We currently can't handle an extension handshake that doesn't
  // completely fit in the buffer.  However these messages are usually
  // ~100 bytes large and the buffer holds over 1000 bytes so it
  // should be ok. Else maybe figure out how to disable extensions for
  // when peer connects next time.
  //
  // In addition, make sure there's at least 5 bytes available after
  // the PEX message has been read, so that we can fit the preamble of
  // the BITFIELD message.
  if (need + 5 > m_readBuffer.reserved_left()) {
    m_readBuffer.move_unused();

    if (need + 5 > m_readBuffer.reserved_left())
      throw handshake_error(ConnectionManager::handshake_failed, e_handshake_invalid_value);
  }

  LT_LOG_EXTRA_DEBUG_SA(m_address, "read_extension", 0)

  if (!fill_read_buffer(m_readBuffer.peek_32() + 4))
    return false;

  // Payload length, excluding the message id and extension id bytes.
  uint32_t length = m_readBuffer.read_32() - 2;
  m_readBuffer.read_8();  // message id
  m_extensions->read_start(m_readBuffer.read_8(), length, false);

  std::memcpy(m_extensions->read_position(), m_readBuffer.position(), length);
  m_extensions->read_move(length);

  // Does this check need to check if it is a handshake we read?
  if (!m_extensions->is_complete())
    throw internal_error("Could not read extension handshake even though it should be in the read buffer.");

  m_extensions->read_done();
  m_readBuffer.consume(length);
  return true;
}

// Reads a PORT message that arrived during the handshake. A 2-byte payload is
// forwarded to the DHT controller as a node at the peer's address; other
// payloads are skipped. Returns false while the message is incomplete.
bool
Handshake::read_port() {
  const auto msg_length = m_readBuffer.peek_32();

  if (msg_length > m_readBuffer.reserved())
    throw handshake_error(ConnectionManager::handshake_failed, e_handshake_invalid_value);

  int32_t need = msg_length + 4 - m_readBuffer.remaining();

  // Require room for the whole message plus the 5-byte preamble of the next
  // message, compacting the buffer once if necessary.
  if (need + 5 > m_readBuffer.reserved_left()) {
    m_readBuffer.move_unused();

    if (need + 5 > m_readBuffer.reserved_left())
      throw handshake_error(ConnectionManager::handshake_failed, e_handshake_invalid_value);
  }

  LT_LOG_EXTRA_DEBUG_SA(m_address, "read_port", 0)

  if (!fill_read_buffer(msg_length + 4))
    return false;

  // Bytes following the message id.
  uint32_t payload = m_readBuffer.read_32() - 1;
  m_readBuffer.read_8();

  if (payload == 2)
    manager->dht_controller()->add_node(m_address.get(), m_readBuffer.peek_16());

  m_readBuffer.consume(payload);
  return true;
}

// Marks the read side finished. Stops read events and makes sure m_bitfield
// is usable: an all-unset bitfield of the download's size when the peer sent
// none, otherwise m_bitfield.update() is applied to the received one. Queues
// the post-handshake data if our bitfield was skipped. Throws
// handshake_succeeded when the write side is also done.
void
Handshake::read_done() {
  if (m_readDone)
    throw internal_error("Handshake::read_done() m_readDone != false.");

//   if (m_peerInfo->supports_extensions() && m_extensions->is_initial_handshake())
//     throw handshake_error(ConnectionManager::handshake_failed, e_handshake_invalid_order);

  m_readDone = true;
  this_thread::poll()->remove_read(this);

  if (!m_bitfield.empty()) {
    m_bitfield.update();
  } else {
    m_bitfield.set_size_bits(m_download->file_list()->bitfield()->size_bits());
    m_bitfield.allocate();
    m_bitfield.unset_all();
  }

  // Should've started to write post handshake data already, but we were
  // still reading the bitfield/extension and postponed it. If we had no
  // bitfield to send, we need to send a keep-alive now.
  auto* local_bitfield = m_download->file_list()->bitfield();

  if (m_writePos == local_bitfield->size_bytes())
    prepare_post_handshake(local_bitfield->is_all_unset() || m_download->initial_seeding() != NULL);

  if (m_writeDone)
    throw handshake_succeeded();
}

// Read-event handler driving the handshake state machine. Each state consumes
// what it can; a step that needs more data breaks out of the switch and waits
// for the next read event. Finished steps fall through (or 'goto restart')
// to the state they selected. After the switch, pending output is flushed via
// event_write(). Exceptions from the steps are translated into
// HandshakeManager callbacks by the handlers at the end.
void
Handshake::event_read() {
  try {

restart:
    switch (m_state) {
    case PROXY_CONNECT:
      if (!read_proxy_connect())
        break;

      m_state = PROXY_DONE;

      // Proxy tunnel established: continue with the handshake on the write side.
      this_thread::poll()->insert_write(this);
      return event_write();

    case READ_ENC_KEY:
      if (!read_encryption_key())
        break;

      // read_encryption_key() may have detected a plaintext handshake.
      if (m_state != READ_ENC_SYNC)
        goto restart;

      [[fallthrough]];
    case READ_ENC_SYNC:
      if (!read_encryption_sync())
        break;

      // Outgoing connections go to READ_ENC_NEGOT instead.
      if (m_state != READ_ENC_SKEY)
        goto restart;

      [[fallthrough]];
    case READ_ENC_SKEY:
      if (!read_encryption_skey())
        break;

      [[fallthrough]];
    case READ_ENC_NEGOT:
      if (!read_encryption_negotiation())
        break;

      if (m_state != READ_ENC_PAD)
        goto restart;

      [[fallthrough]];
    case READ_ENC_PAD:
      // m_readPos holds the pad length set by read_encryption_negotiation().
      if (m_readPos) {
        LT_LOG_EXTRA_DEBUG_SA(m_address, "event_read : READ_ENC_PAD : m_readPos:%" PRIu32, m_readPos)

        // Read padC + lenIA or padD; pad length in m_readPos.
        if (!fill_read_buffer(m_readPos + (m_incoming ? 2 : 0)))
          // This can be improved (consume as much as was read)
          break;

        m_readBuffer.consume(m_readPos);
        m_readPos = 0;
      }

      if (!read_negotiation_reply())
        break;

      // Outgoing connections continue with READ_INFO.
      if (m_state != READ_ENC_IA)
        goto restart;

      [[fallthrough]];
    case READ_ENC_IA:
      LT_LOG_EXTRA_DEBUG_SA(m_address, "event_read : READ_ENC_IA", 0)

      // Just read (and automatically decrypt) the initial payload
      // and leave it in the buffer for READ_INFO later.
      if (m_encryption.length_ia() > 0 && !fill_read_buffer(m_encryption.length_ia()))
        break;

      // Nothing beyond the initial payload may have been sent yet.
      if (m_readBuffer.remaining() > m_encryption.length_ia())
        throw handshake_error(ConnectionManager::handshake_failed, e_handshake_invalid_value);

      if (m_encryption.crypto() != HandshakeEncryption::crypto_rc4)
        m_encryption.info()->set_obfuscated();

      m_state = READ_INFO;

      [[fallthrough]];
    case READ_INFO:
      if (!read_info())
        break;

      if (m_state != READ_PEER)
        goto restart;

      [[fallthrough]];
    case READ_PEER:
      if (!read_peer())
        break;

      // Is this correct?
      if (m_state != READ_MESSAGE)
        goto restart;

      [[fallthrough]];
    case READ_MESSAGE:
    case POST_HANDSHAKE:
      // For meta-downloads, we aren't interested in the bitfield or
      // extension messages here, PCMetadata handles all that. The
      // bitfield only refers to the single-chunk meta-data, so fake that.
      if (m_download->info()->is_meta_download()) {
        m_bitfield.set_size_bits(1);
        m_bitfield.allocate();
        m_bitfield.set(0);
        read_done();
        break;
      }

      LT_LOG_EXTRA_DEBUG_SA(m_address, "event_read : READ_MESSAGE", 0);

      // Need room for a 4-byte length plus 1-byte message id.
      if (m_readBuffer.reserved_left() < 5)
        m_readBuffer.move_unused();

      fill_read_buffer(5);

      // Received a keep-alive message which means we won't be
      // getting any bitfield.
      if (m_readBuffer.remaining() >= 4 && m_readBuffer.peek_32() == 0) {
        m_readBuffer.read_32();
        read_done();
        break;
      }

      if (m_readBuffer.remaining() < 5)
        break;

      m_readPos = 0;

      // Extension handshake was sent after BT handshake but before
      // bitfield, so handle that. If we've already received a message
      // of this type then we will assume the peer won't be sending a
      // bitfield, as the second extension message will be part of the
      // normal traffic, not the handshake.
      if (m_readBuffer.peek_8_at(4) == protocol_bitfield) {
        const Bitfield* bitfield = m_download->file_list()->bitfield();

        // Reject a second bitfield or one whose length does not match the
        // download (payload size + 1 byte message id).
        if (!m_bitfield.empty() || m_readBuffer.read_32() != bitfield->size_bytes() + 1)
          throw handshake_error(ConnectionManager::handshake_failed, e_handshake_invalid_value);

        m_readBuffer.read_8();

        m_bitfield.set_size_bits(bitfield->size_bits());
        m_bitfield.allocate();

        // Copy whatever part of the bitfield is already buffered; the rest is
        // read directly by read_bitfield().
        m_readPos = std::min<uint32_t>(m_bitfield.size_bytes(), m_readBuffer.remaining());
        std::memcpy(m_bitfield.begin(), m_readBuffer.position(), m_readPos);
        m_readBuffer.consume(m_readPos);

        m_state = READ_BITFIELD;

      } else if (m_readBuffer.peek_8_at(4) == protocol_extension && m_extensions->is_initial_handshake()) {
        m_readPos = 0;
        m_state = READ_EXT;

      } else if (m_readBuffer.peek_8_at(4) == protocol_port) {
        // Some peers seem to send the port message before handshake,
        // so handle it here.
        m_readPos = 0;
        m_state = READ_PORT;

      } else {
        // Any other message ends the handshake phase; it stays buffered.
        read_done();
        break;
      }

      [[fallthrough]];
    case READ_BITFIELD:
    case READ_EXT:
    case READ_PORT:
      // Gather the different command types into the same case group
      // so that we don't need 'goto restart' above.
      if ((m_state == READ_BITFIELD && !read_bitfield()) ||
          (m_state == READ_EXT && !read_extension()) ||
          (m_state == READ_PORT && !read_port()))
        break;

      m_state = READ_MESSAGE;

      // Done once the bitfield is in and no initial extension handshake is
      // still expected; otherwise look at the next message.
      if (!m_bitfield.empty() && (!m_peerInfo->supports_extensions() || !m_extensions->is_initial_handshake())) {
        read_done();
        break;
      }

      goto restart;

    default:
      throw internal_error("Handshake::event_read() called in invalid state.");
    }

    // Call event_write if we have any data to write. Make sure
    // event_write() doesn't get called twice in this function.
    if (m_writeBuffer.remaining() && !this_thread::poll()->in_write(this)) {
      this_thread::poll()->insert_write(this);
      return event_write();
    }

} catch (const handshake_succeeded&) {
  m_manager->receive_succeeded(this);

} catch (const handshake_error& e) {
  m_manager->receive_failed(this, e.type(), e.error());

} catch (const network_error&) {
  m_manager->receive_failed(this, ConnectionManager::handshake_failed, e_handshake_network_read_error);
}
}

bool
Handshake::fill_read_buffer(int size) {
  LT_LOG_EXTRA_DEBUG_SA(m_address, "fill_read_buffer : size:%i remaining:%" PRIu16 " reserved_left:%" PRIu16,
                        size, m_readBuffer.remaining(), m_readBuffer.reserved_left())

  if (m_readBuffer.remaining() < size) {
    if (size - m_readBuffer.remaining() > m_readBuffer.reserved_left())
      throw internal_error("Handshake::fill_read_buffer(...) Buffer overflow.");

    int read = m_readBuffer.move_end(read_unthrottled(m_readBuffer.end(), size - m_readBuffer.remaining()));

    if (m_encryption.info()->decrypt_valid())
      m_encryption.info()->decrypt(m_readBuffer.end() - read, read);
  }

  return m_readBuffer.remaining() >= size;
}

// Rejects the handshake (handshake_dropped) unless m_download is known,
// active and accepting new peers. Each case reports its own error code.
inline void
Handshake::validate_download() {
  if (m_download == NULL)
    throw handshake_error(ConnectionManager::handshake_dropped, e_handshake_unknown_download);

  auto* info = m_download->info();

  if (!info->is_active())
    throw handshake_error(ConnectionManager::handshake_dropped, e_handshake_inactive_download);

  if (!info->is_accepting_new_peers())
    throw handshake_error(ConnectionManager::handshake_dropped, e_handshake_not_accepting_connections);
}

void
Handshake::event_write() {
  int socket_error;

  try {

    switch (m_state) {
    case CONNECTING:
      if (!fd_get_socket_error(m_fileDesc, &socket_error))
        throw internal_error("Handshake::event_write() fd_get_socket_error failed : " + std::string(strerror(errno)));

      if (socket_error != 0)
        throw handshake_error(ConnectionManager::handshake_failed, e_handshake_network_unreachable);

      this_thread::poll()->insert_read(this);

      if (m_encryption.options() & net::NetworkConfig::encryption_use_proxy) {
        prepare_proxy_connect();

        m_state = PROXY_CONNECT;
        break;
      }

      [[fallthrough]];
    case PROXY_DONE:
      // If there's any bytes remaining, it means we got a reply from
      // the other side before our proxy connect command was finished
      // written. This probably means the other side isn't a proxy.
      if (m_writeBuffer.remaining())
        throw handshake_error(ConnectionManager::handshake_failed, e_handshake_not_bittorrent);

      m_writeBuffer.reset();

      if (m_encryption.options() & (net::NetworkConfig::encryption_try_outgoing | net::NetworkConfig::encryption_require)) {
        prepare_key_plus_pad();

        // if connection fails, peer probably closed it because it was encrypted, so retry encrypted if enabled
        if (!(m_encryption.options() & net::NetworkConfig::encryption_require))
          m_encryption.set_retry(HandshakeEncryption::RETRY_PLAIN);

        m_state = READ_ENC_KEY;

      } else {
        // if connection is closed before we read the handshake, it might
        // be rejected because it is unencrypted, in that case retry encrypted
        m_encryption.set_retry(HandshakeEncryption::RETRY_ENCRYPTED);

        prepare_handshake();

        if (m_incoming)
          m_state = READ_PEER;
        else
          m_state = READ_INFO;
      }

      break;

    case READ_MESSAGE:
    case READ_BITFIELD:
    case READ_EXT:
      write_bitfield();
      return;

    default:
      break;
    }

    if (!m_writeBuffer.remaining())
      throw internal_error("event_write called with empty write buffer.");

    if (m_writeBuffer.consume(write_unthrottled(m_writeBuffer.position(), m_writeBuffer.remaining()))) {
      if (m_state == POST_HANDSHAKE)
        write_done();
      else
        this_thread::poll()->remove_write(this);
    }

  } catch (const handshake_succeeded&) {
    m_manager->receive_succeeded(this);

  } catch (const handshake_error& e) {
    m_manager->receive_failed(this, e.type(), e.error());

  } catch (const network_error&) {
    m_manager->receive_failed(this, ConnectionManager::handshake_failed, e_handshake_network_write_error);
  }
}

// Appends an HTTP/1.0 CONNECT request for the peer's address to the write
// buffer. Throws internal_error if formatting fails or would be truncated.
void
Handshake::prepare_proxy_connect() {
  // Format at end() so the bytes written match what move_end() appends below.
  int advance = snprintf(reinterpret_cast<char*>(m_writeBuffer.end()), m_writeBuffer.reserved_left(),
                         "CONNECT %s:%hu HTTP/1.0\r\n\r\n", sap_addr_str(m_address).c_str(), sap_port(m_address));

  // snprintf returns a negative value on error, and otherwise the length the
  // complete string would have. A length equal to or larger than the buffer
  // size means the output (including the terminating NUL) was truncated.
  if (advance < 0 || advance >= m_writeBuffer.reserved_left())
    throw internal_error("Handshake::prepare_proxy_connect() snprintf failed.");

  m_writeBuffer.move_end(advance);
}

// Initializes the key exchange and queues our 96-byte public key followed by
// a pad of random length (below enc_pad_size) filled with random bytes.
// Throws handshake_error if m_encryption cannot be initialized.
void
Handshake::prepare_key_plus_pad() {
  if (!m_encryption.initialize())
    throw handshake_error(ConnectionManager::handshake_failed, e_handshake_invalid_value);

  m_encryption.key()->store_pub_key(m_writeBuffer.end(), 96);
  m_writeBuffer.move_end(96);

  const int pad_length = random() % enc_pad_size;
  auto pad_bytes = std::make_unique<char[]>(pad_length);

  for (int i = 0; i < pad_length; ++i)
    pad_bytes[i] = ::random();

  m_writeBuffer.write_len(pad_bytes.get(), pad_length);
}

void
Handshake::prepare_enc_negotiation() {
  char hash[20];

  // first piece, HASH('req1' + S)
  sha1_salt("req1", 4, m_encryption.key()->c_str(), m_encryption.key()->size(), m_writeBuffer.end());
  m_writeBuffer.move_end(20);

  // second piece, HASH('req2' + SKEY) ^ HASH('req3' + S)
  m_writeBuffer.write_len(m_download->info()->hash_obfuscated().c_str(), 20);
  sha1_salt("req3", 4, m_encryption.key()->c_str(), m_encryption.key()->size(), hash);

  for (int i = 0; i < 20; i++)
    m_writeBuffer.end()[i - 20] ^= hash[i];

  // last piece, ENCRYPT(VC, crypto_provide, len(PadC), PadC, len(IA))
  m_encryption.initialize_encrypt(m_download->info()->hash().c_str(), m_incoming);

  Buffer::iterator old_end = m_writeBuffer.end();

  HandshakeEncryption::copy_vc(m_writeBuffer.end());
  m_writeBuffer.move_end(HandshakeEncryption::vc_length);

  if (m_encryption.options() & net::NetworkConfig::encryption_require_RC4)
    m_writeBuffer.write_32(HandshakeEncryption::crypto_rc4);
  else
    m_writeBuffer.write_32(HandshakeEncryption::crypto_plain | HandshakeEncryption::crypto_rc4);

  m_writeBuffer.write_16(0);
  m_writeBuffer.write_16(handshake_size);
  m_encryption.info()->encrypt(old_end, m_writeBuffer.end() - old_end);

  // write and encrypt BT handshake as initial payload IA
  prepare_handshake();
}

void
Handshake::prepare_handshake() {
  m_writeBuffer.write_8(19);
  m_writeBuffer.write_range(m_protocol, m_protocol + 19);

  std::memset(m_writeBuffer.end(), 0, 8);
  *(m_writeBuffer.end()+5) |= 0x10;    // support extension protocol
  if (manager->dht_controller()->is_active())
    *(m_writeBuffer.end()+7) |= 0x01;  // DHT support, enable PORT message
  m_writeBuffer.move_end(8);

  m_writeBuffer.write_range(m_download->info()->hash().c_str(), m_download->info()->hash().c_str() + 20);
  m_writeBuffer.write_range(m_download->info()->local_id().c_str(), m_download->info()->local_id().c_str() + 20);

  if (m_encryption.info()->is_encrypted())
    m_encryption.info()->encrypt(m_writeBuffer.end() - handshake_size, handshake_size);
}

void
Handshake::prepare_peer_info() {
  // The read buffer currently points at the remote peer's 20-byte id; it is
  // consumed only after the id has been stored below.
  const char* const remote_id = reinterpret_cast<const char*>(m_readBuffer.position());

  // Refuse to connect to ourselves.
  if (std::memcmp(remote_id, m_download->info()->local_id().c_str(), 20) == 0)
    throw handshake_error(ConnectionManager::handshake_failed, e_handshake_is_self);

  // TODO: PeerInfo handling for outgoing connections needs to be moved to
  // HandshakeManager. Until then only incoming handshakes may arrive here
  // without a PeerInfo; they are registered with the peer list now.
  if (m_peerInfo == nullptr) {
    if (!m_incoming)
      throw internal_error("Handshake::prepare_peer_info() !m_incoming.");

    m_peerInfo = m_download->peer_list()->connected(m_address.get(), PeerList::connect_incoming);

    if (m_peerInfo == nullptr)
      throw handshake_error(ConnectionManager::handshake_failed, e_handshake_no_peer_info);

    if (m_peerInfo->failed_counter() > torrent::HandshakeManager::max_failed)
      throw handshake_error(ConnectionManager::handshake_dropped, e_handshake_toomanyfailed);

    m_peerInfo->set_flags(PeerInfo::flag_handshake);
  }

  // Store the 8 option bytes held in m_options, then the peer id.
  std::memcpy(m_peerInfo->set_options(), m_options, 8);

  m_peerInfo->mutable_id().assign(remote_id);
  m_readBuffer.consume(20);

  hash_string_to_hex(m_peerInfo->id(), m_peerInfo->mutable_id_hex());

  // Meta (metadata-only) downloads are useless without the extension protocol.
  const bool needs_extensions = m_download->info()->is_meta_download();

  if (needs_extensions && !m_peerInfo->supports_extensions())
    throw handshake_error(ConnectionManager::handshake_dropped, e_handshake_unwanted_connection);
}

void
Handshake::prepare_bitfield() {
  m_writeBuffer.write_32(m_download->file_list()->bitfield()->size_bytes() + 1);
  m_writeBuffer.write_8(protocol_bitfield);

  if (m_encryption.info()->is_encrypted())
    m_encryption.info()->encrypt(m_writeBuffer.end() - 5, 5);

  m_writePos = 0;
}

void
Handshake::prepare_post_handshake(bool must_write) {
  if (m_writePos != m_download->file_list()->bitfield()->size_bytes())
    throw internal_error("Handshake::prepare_post_handshake called while bitfield not written completely.");

  m_state = POST_HANDSHAKE;

  Buffer::iterator old_end = m_writeBuffer.end();

  // Send PORT message for DHT if enabled and peer supports it.
  if (m_peerInfo->supports_dht() &&
      manager->dht_controller()->is_active() &&
      manager->dht_controller()->is_receiving_requests()) {

    m_writeBuffer.write_32(3);
    m_writeBuffer.write_8(protocol_port);
    m_writeBuffer.write_16(manager->dht_controller()->port());

    // manager->dht_controller()->port_sent();
  }

  // Send a keep-alive if we still must send something.
  if (must_write && old_end == m_writeBuffer.end())
    m_writeBuffer.write_32(0);

  if (m_encryption.info()->is_encrypted())
    m_encryption.info()->encrypt(old_end, m_writeBuffer.end() - old_end);

  if (!m_writeBuffer.remaining())
    write_done();
}

void
Handshake::write_done() {
  // Writing is finished: record it and stop watching for writability.
  m_writeDone = true;
  this_thread::poll()->remove_write(this);

  // Checking m_readDone alone is enough here; the call made from
  // event_read() happens before that path sets m_readDone.
  if (!m_readDone)
    return;

  // Both directions are done; the handshake completed.
  throw handshake_succeeded();
}

void
Handshake::write_extension_handshake() {
  DownloadInfo* info = m_download->info();

  if (m_extensions->is_default()) {
    m_extensions = new ProtocolExtension;
    m_extensions->set_info(m_peerInfo, m_download);
  }

  // PEX may be disabled but still active if disabled since last download tick.
  if (info->is_pex_enabled() && info->is_pex_active() && info->size_pex() < info->max_size_pex())
    m_extensions->set_local_enabled(ProtocolExtension::UT_PEX);

  DataBuffer message = m_extensions->generate_handshake_message();

  m_writeBuffer.write_32(message.length() + 2);
  m_writeBuffer.write_8(protocol_extension);
  m_writeBuffer.write_8(ProtocolExtension::HANDSHAKE);
  m_writeBuffer.write_range(message.data(), message.end());

  if (m_encryption.info()->is_encrypted())
    m_encryption.info()->encrypt(m_writeBuffer.end() - message.length() - 2 - 4, message.length() + 2 + 4);

  message.clear();
}

// Writes the body of the local bitfield message. The 5-byte header is queued
// in m_writeBuffer by prepare_bitfield(), which also resets m_writePos to 0;
// m_writePos then counts bitfield bytes handed to write_unthrottled().
//
// Anything still pending in m_writeBuffer is flushed first; if that flush is
// incomplete we return and wait for the next write event. Once the whole
// bitfield has been written we either drop write interest (reading not done
// yet) or queue the post-handshake messages.
//
// NOTE(review): in the encrypted path, encrypted-but-unsent bitfield bytes
// are left in m_writeBuffer starting at begin(). On the next call the
// remaining()/consume() flush at the top appears to send those bytes without
// advancing m_writePos, after which the code below copies and re-encrypts
// the same bitfield range again. Confirm against Buffer::consume()/remaining()
// semantics; if so, a partial write of an encrypted bitfield would put
// duplicate bytes on the wire.
void
Handshake::write_bitfield() {
  const Bitfield* bitfield = m_download->file_list()->bitfield();

  if (m_writeDone != false)
    throw internal_error("Handshake::event_write() m_writeDone != false.");

  // Flush anything already queued in m_writeBuffer; bail out if it could not
  // all be written yet.
  if (m_writeBuffer.remaining())
    if (!m_writeBuffer.consume(write_unthrottled(m_writeBuffer.position(), m_writeBuffer.remaining())))
      return;

  if (m_writePos != bitfield->size_bytes()) {
    if (m_encryption.info()->is_encrypted()) {
      // Encrypted: bitfield bytes pass through m_writeBuffer so they can be
      // encrypted in place; the first size_end() bytes in the buffer are
      // already encrypted but not yet sent.
      if (m_writePos == 0)
        m_writeBuffer.reset();	// this should be unnecessary now

      // How many further bitfield bytes to encrypt: bounded by what is left
      // of the bitfield and by buffer capacity, minus what is already buffered.
      uint32_t length = std::min<uint32_t>(bitfield->size_bytes() - m_writePos, m_writeBuffer.reserved()) - m_writeBuffer.size_end();

      if (length > 0) {
        // Skip the bytes already sitting (encrypted) in the buffer.
        std::memcpy(m_writeBuffer.end(), bitfield->begin() + m_writePos + m_writeBuffer.size_end(), length);
        m_encryption.info()->encrypt(m_writeBuffer.end(), length);
        m_writeBuffer.move_end(length);
      }

      // Send from the start of the buffer, then shift any unsent tail down.
      length = write_unthrottled(m_writeBuffer.begin(), m_writeBuffer.size_end());
      m_writePos += length;

      if (length != m_writeBuffer.size_end() && length > 0)
        std::memmove(m_writeBuffer.begin(), m_writeBuffer.begin() + length, m_writeBuffer.size_end() - length);

      // NOTE(review): length is unsigned; this relies on -length converting
      // to the negative offset move_end() expects.
      m_writeBuffer.move_end(-length);

    } else {
      // Unencrypted: write straight from the bitfield storage.
      m_writePos += write_unthrottled(bitfield->begin() + m_writePos,
                                      bitfield->size_bytes() - m_writePos);
    }
  }

  // We can't call prepare_post_handshake until the read code is done reading
  // the bitfield, so if we get here before then, postpone the post handshake
  // data until reading is done. Since we're done writing, remove us from the
  // poll in that case.
  if (m_writePos == bitfield->size_bytes()) {
    if (!m_readDone)
      this_thread::poll()->remove_write(this);
    else
      prepare_post_handshake(false);
  }
}

void
Handshake::event_error() {
  // A socket error on a handshake that is INACTIVE is treated as an internal bug.
  if (m_state == INACTIVE)
    throw internal_error("Handshake::event_error() called on an inactive handshake.");

  // Otherwise report the socket error as a failed handshake.
  const int failure_type = ConnectionManager::handshake_failed;
  m_manager->receive_failed(this, failure_type, e_handshake_network_socket_error);
}

} // namespace torrent
