'''record origin of repository changes

The auditlog extension records the time and responsible party for
repository changes. For local changes, this is the user name given by
the ``USER`` environment variable. For remote changes via hgweb or
ssh, it is the ``REMOTE_USER`` variable from the WSGI request or, for
SSH, the ``USER`` environment variable. If the
``auditlog.record-public-key`` option is enabled, the extension also
attempts to record the TLS client certificate or the SSH key used for
the connection. Recording the SSH key requires sshd to run with
``ExposeAuthInfo yes``.
'''

# TODO:
# localpeer exchange
# smarter exchange protocol

import hashlib
import time

from mercurial.i18n import _
from mercurial.utils import cborutil
from mercurial.revlogutils import (
    constants as revlog_constants,
)
from mercurial import (
    cmdutil,
    encoding,
    error,
    exchange,
    extensions,
    localrepo,
    logcmdutil,
    phases,
    pycompat,
    registrar,
    revlog,
    util,
    wireprototypes,
    wireprotov1server,
)

testedwith = b'ships-with-hg-core'

# Configuration options of the extension (all boolean, off by default):
#   record-public-key: for remote changes, also store the TLS client
#                      certificate or SSH key
#   client-mode:       pull auditlog records and accept pushed ones
#   server-mode:       write auditlog records and serve/push them
configtable = {}
configitem = registrar.configitem(configtable)

configitem(b'auditlog', b'record-public-key', default=False)
configitem(b'auditlog', b'client-mode', default=False)
configitem(b'auditlog', b'server-mode', default=False)

# Output template of ``hg auditlog``; the old phase and the public key are
# only printed when the record has them.
default_template = (
    b'changeset {changeset} user {user} date {date|isodate}'
    b' phase {phase}{if(oldphase, " (old: {oldphase})")}'
    b' source {source}{if(pki, " pki {pki}")}\n'
)

cmdtable = {}
command = registrar.command(cmdtable)


def capabilities(orig, repo, proto):
    """Advertise the auditlog roles enabled for ``repo``.

    Wraps ``wireprotov1server._capabilities``: the original capability list
    is extended with ``auditlog-client`` and/or ``auditlog-server``
    depending on the ``auditlog.client-mode`` / ``auditlog.server-mode``
    options, in that order.
    """
    caps = orig(repo, proto)
    configbool = repo.ui.configbool
    for option, capability in (
        (b'client-mode', b'auditlog-client'),
        (b'server-mode', b'auditlog-server'),
    ):
        if configbool(b'auditlog', option):
            caps.append(capability)
    return caps


@wireprotov1server.wireprotocommand(b'queryauditlog', permission=b'pull')
def queryauditlogwireproto(repo, proto):
    """Describe the local auditlog to a peer.

    The reply is ``b'<size> <digest>'``: the number of records and the hex
    digest returned by ``repo.auditlog.status()``.
    """
    size, digest = repo.auditlog.status()
    payload = b'%d %s' % (size, pycompat.bytestr(digest))
    return wireprototypes.bytesresponse(payload)


@wireprotov1server.wireprotocommand(b'pushauditlog', permission=b'push')
def pushauditlogwireproto(repo, proto):
    """Receive auditlog records pushed by a peer.

    The payload is a concatenation of CBOR-encoded record maps. All records
    are validated before anything is written, so one malformed record
    rejects the whole batch instead of leaving a partially applied one.
    Records for changesets unknown to this repository are skipped.

    Returns ``b'1'`` on success and ``b'0'`` if the payload could not be
    decoded or contained an invalid record.
    """
    # XXX stream decoding?
    raw_records = b''.join(proto.getpayload())
    ui = repo.ui

    try:
        records = cborutil.decodeall(raw_records)
    except cborutil.CBORDecodeError:
        ui.warn(
            _(b'WARNING: incomplete or broken records obtained from peer\n')
        )
        return wireprototypes.bytesresponse(b'0')
    ui.debug(b'auditlog: received %d records\n' % len(records))

    # Validate up front: returning from inside the transaction block below
    # would close (commit) the transaction with the records added so far.
    for record in records:
        if not verifyparsedrecord(ui, record):
            return wireprototypes.bytesresponse(b'0')

    unfi = repo.unfiltered()
    auditlog = repo.auditlog.auditlogfile()

    with repo.lock(), repo.transaction(b'pushauditlog') as tr:
        for record in records:
            node = record[b'node']
            if node not in unfi:
                ui.debug(
                    b'auditlog: skipping record for unknown node %s\n' % node
                )
                continue

            auditlog.addrevision(
                b''.join(cborutil.streamencode(record)),
                tr,
                unfi[node].rev(),
                repo.nullid,
                repo.nullid,
            )
    return wireprototypes.bytesresponse(b'1')


@wireprotov1server.wireprotocommand(
    b'pullauditlog', b'their_size their_digest', permission=b'pull'
)
def pullauditlogwireproto(repo, proto, their_size, their_digest):
    """Return the auditlog records the requesting peer is missing.

    ``their_size`` is the number of records the peer holds (as a decimal
    byte string) and ``their_digest`` the digest ``auditlog.status()``
    computed over them. If the first ``their_size`` local records produce
    the same digest, only the records after them are returned. Otherwise,
    or if the peer has more records than this repository, the whole log is
    returned.

    The response body is the concatenation of the raw auditlog revisions,
    i.e. a sequence of CBOR-encoded record maps. It is empty when the peer
    is already up to date.
    """
    # A negative count can only come from a broken peer; treat it like an
    # empty log instead of asking the revlog for negative revisions.
    their_size = max(int(their_size), 0)
    auditlog = repo.auditlog.auditlogfile()
    my_size = len(auditlog)
    if my_size >= their_size:
        my_digest = repo.auditlog.status(their_size)[1]
        if pycompat.bytestr(my_digest) != their_digest:
            repo.ui.debug(
                b'auditlog: remote has unknown records, sending all\n'
            )
            their_size = 0
    else:
        repo.ui.debug(b'auditlog: remote has more records, sending all\n')
        their_size = 0
    return wireprototypes.bytesresponse(
        b''.join(auditlog.rawdata(rev) for rev in range(their_size, my_size))
    )


def exchangepushauditlog(orig, pushop):
    """Send auditlog records to the push destination.

    Wraps ``exchange._pushobsolete``: the original step runs first and its
    result is returned unchanged. This only does anything in server mode,
    when the step is not already done, and when the peer advertises
    ``auditlog-client``. The peer's record count and digest
    (``queryauditlog``) decide what is sent:
    - if the first local records produce the same digest, only the records
      past that count;
    - otherwise, the whole log.
    """
    res = orig(pushop)
    repo = pushop.repo
    if not repo.ui.configbool(b'auditlog', b'server-mode'):
        return res

    if b'auditlog' in pushop.stepsdone or not pushop.remote.capable(
        b'auditlog-client'
    ):
        return res
    pushop.stepsdone.add(b'auditlog')

    repo.ui.debug(b'auditlog: checking remote auditlog\n')
    auditlog = repo.auditlog.auditlogfile()
    their_size, their_digest = pushop.remote._call(b'queryauditlog').split(b' ')
    their_size = int(their_size)
    my_size = len(auditlog)
    if my_size >= their_size:
        # Do not unpack into ``_``: assigning it anywhere would make the
        # gettext alias ``_`` a local of this function and break the
        # warnings below.
        my_digest = repo.auditlog.status(their_size)[1]
        if pycompat.bytestr(my_digest) != their_digest:
            repo.ui.debug(
                b'auditlog: remote has unknown records, sending all\n'
            )
            their_size = 0
    else:
        repo.ui.debug(b'auditlog: remote has more records, sending all\n')
        their_size = 0

    if my_size == their_size:
        repo.ui.debug(b'auditlog: already up-to-date\n')
        return res

    data = b''.join(auditlog.rawdata(rev) for rev in range(their_size, my_size))
    repo.ui.debug(
        b'auditlog: sending %d records (%d bytes)\n'
        % (my_size - their_size, len(data))
    )
    fp = util.stringio(data)
    stream = pushop.remote._calltwowaystream(b'pushauditlog', fp)
    result = stream.read()
    if result == b'0':
        pushop.ui.warn(_(b'WARNING: auditlog not processed by remote\n'))
    elif result != b'1':
        pushop.ui.warn(_(b'bad remote reply\n'))
    return res


def exchangepullauditlog(orig, pullop):
    """Request auditlog data from peer.

    Wraps ``exchange._pullobsolete``: the original step runs first and its
    result is returned unchanged. This only does anything in client mode
    and when the peer advertises ``auditlog-server``.

    The auditlog is replicated fully as long as the revisions are known
    locally as well. Partial replication is currently not supported.
    Records for changesets unknown locally are skipped. If the reply cannot
    be decoded, or any record in it is invalid, no record from it is added.
    """
    res = orig(pullop)
    if not pullop.repo.ui.configbool(b'auditlog', b'client-mode'):
        return res

    # Future versions might do the exchange directly as part of bundle2,
    # so check if the step is already done.
    # No further operations are necessary if the peer doesn't serve auditlogs.
    if b'auditlog' in pullop.stepsdone or not pullop.remote.capable(
        b'auditlog-server'
    ):
        return res

    # Mark the step under the same name that is checked above.
    pullop.stepsdone.add(b'auditlog')
    repo = pullop.repo
    ui = repo.ui
    tr = pullop.trmanager.transaction()
    unfi = repo.unfiltered()
    auditlog = repo.auditlog.auditlogfile()
    my_size, my_digest = repo.auditlog.status()

    # Wire protocol arguments must be bytes; hexdigest() returns str.
    raw_records = pullop.remote._call(
        b'pullauditlog',
        their_size=b'%d' % my_size,
        their_digest=pycompat.bytestr(my_digest),
    )
    try:
        records = cborutil.decodeall(raw_records)
    except cborutil.CBORDecodeError:
        ui.warn(
            _(b'WARNING: incomplete or broken records obtained from server\n')
        )
        return res
    ui.debug(
        b'auditlog: received %d records (%d bytes)\n'
        % (len(records), len(raw_records))
    )

    # Validate everything before writing: the pull transaction is committed
    # by the caller, so records added before a bad one would otherwise stay.
    for record in records:
        if not verifyparsedrecord(ui, record):
            return res

    for record in records:
        node = record[b'node']
        if node not in unfi:
            ui.debug(b'auditlog: skipping record for unknown node %s\n' % node)
            continue

        auditlog.addrevision(
            b''.join(cborutil.streamencode(record)),
            tr,
            unfi[node].rev(),
            repo.nullid,
            repo.nullid,
        )

    return res


class auditlog(object):
    '''An interface to auditlog data.

    Records are stored in a revlog with radix ``auditlog`` in the repository
    store. Each revision holds one CBOR-encoded record map, and its linkrev
    is the changelog revision of the changeset the record refers to.
    '''

    def __init__(self, repo):
        '''Create an instance bound to ``repo``.

        The revlog itself is opened lazily by ``auditlogfile()``.
        '''
        self.repo = repo
        self._auditlogrevlog = None
        # Result of the first prepare_record() call, reused afterwards.
        self._baserecord = None

    def status(self, limit=None):
        '''Return ``(size, digest)`` describing the auditlog.

        ``size`` is the total number of records. ``digest`` is the hex
        SHA-256 over the sorted nodes of the first ``limit`` records (all
        records if ``limit`` is None). The nodes are sorted, so the order of
        those records does not affect the digest.
        '''
        auditlogfile = self.auditlogfile()
        size = len(auditlogfile)
        if limit is None:
            limit = size
        m = hashlib.sha256()
        for node in sorted(auditlogfile.node(i) for i in range(limit)):
            m.update(node)
        return size, m.hexdigest()

    def auditlogfile(self):
        '''Return the auditlog revlog, opening it on first use.'''
        if self._auditlogrevlog is None:
            self._auditlogrevlog = revlog.revlog(
                self.repo.store.opener,
                target=(revlog_constants.KIND_OTHER, None),
                radix=b'auditlog',
                checkambig=False,
                mmaplargeindex=True,
            )
        return self._auditlogrevlog

    def repair(self):
        '''Rewrite the auditlog after changesets were stripped.

        Records whose node no longer exists in the repository are dropped.
        Records whose linkrev no longer points at their node are re-linked
        to the node's current revision. The surviving revisions are copied
        into a new ``auditlog.strip`` revlog, whose files are then renamed
        over the original ones. Nothing happens if no record is affected.

        Raises RuntimeError if the auditlog uses a docket file.
        '''
        auditlog = self.auditlogfile()
        unfi = self.repo.unfiltered()
        kill_revs = set()
        fix_revs = {}
        for auditrev in auditlog:
            record = cborutil.decodeall(auditlog.rawdata(auditrev))[0]
            clrev = auditlog.linkrev(auditrev)
            if record[b'node'] not in unfi:
                kill_revs.add(auditrev)
            elif clrev >= len(unfi) or unfi[clrev].hex() != record[b'node']:
                fix_revs[auditrev] = unfi[record[b'node']].rev()
        if not kill_revs and not fix_revs:
            return
        if auditlog._docket_file:
            raise RuntimeError(
                b'Stripping of auditlog with docket file is not supported'
            )
        tr = self.repo.currenttransaction()
        newauditlog = revlog.revlog(
            self.repo.store.opener,
            target=(revlog_constants.KIND_OTHER, None),
            radix=b'auditlog.strip',
            checkambig=False,
        )
        for auditrev in auditlog:
            if auditrev in kill_revs:
                continue
            # Index entry fields used here: [4] is the linkrev, [5] and [6]
            # are the parent revisions.
            entry = auditlog.index[auditrev]
            with newauditlog._writing(tr):
                newauditlog.addrevision(
                    auditlog.rawdata(auditrev),
                    tr,
                    entry[4]
                    if auditrev not in fix_revs
                    else fix_revs[auditrev],
                    auditlog.node(entry[5]),
                    auditlog.node(entry[6]),
                )
        oldindex = auditlog._indexfile
        newindex = newauditlog._indexfile
        olddata = auditlog._datafile
        newdata = newauditlog._datafile
        # Drop the references to the old revlog before its files are
        # replaced (presumably so no stale handle is reused).
        self._auditlogrevlog = None
        del auditlog
        vfs = self.repo.store.vfs
        if vfs.exists(newindex):
            vfs.rename(newindex, oldindex)
        elif vfs.exists(oldindex):
            vfs.unlink(oldindex)
        if vfs.exists(newdata):
            vfs.rename(newdata, olddata)
        elif vfs.exists(olddata):
            vfs.unlink(olddata)

    def prepare_record(self, ui, repo, source):
        '''Return the fields shared by all records written for ``source``.

        The result holds ``date`` (nanoseconds since the epoch), ``user`` and
        ``source`` (``remote`` for ``serve``, ``local`` for ``push``,
        ``pull``, ``unbundle`` and ``local``). For remote changes with
        ``auditlog.record-public-key`` enabled, it may also hold
        ``public-key``.

        Returns None after reporting an error if no user can be determined
        or the SSH authentication info cannot be read. Raises
        ProgrammingError for any other ``source``.
        '''
        # NOTE(review): the first result is cached on this instance and
        # returned for later calls regardless of ``ui`` and ``source``, so
        # the date and user are those of the first caller. Confirm that an
        # instance never serves more than one user (e.g. in hgweb).
        if self._baserecord is not None:
            return self._baserecord
        baserecord = {
            b'date': int(time.time() * 1000000000),
        }
        if source == b'serve':
            # REMOTE_USER is set for wsgi requests when authentication is enabled.
            remoteuser = ui.environ.get(b'REMOTE_USER')
            # USER is set by ssh, use it as fallback.
            user = encoding.environ.get(b'USER')
            if not remoteuser and not user:
                ui.error(
                    _(
                        b'authenticated user not found; refusing to write auditlog\n'
                    )
                )
                return None
            baserecord.update(
                {b'user': remoteuser or user, b'source': b'remote'}
            )
            if ui.configbool(b'auditlog', b'record-public-key'):
                sslcert = ui.environ.get(b'SSL_CLIENT_CERT')
                if sslcert:
                    baserecord[b'public-key'] = sslcert
                sshuserauthfile = encoding.environ.get(b'SSH_USER_AUTH')
                if sshuserauthfile:
                    try:
                        sshuserauth = util.readfile(sshuserauthfile)
                    except OSError:
                        ui.error(
                            _(
                                b'unable to read SSH user info; refusing to write auditlog\n'
                            )
                        )
                        return None
                    # The file may list several authentication methods, one
                    # per line; record the key of the first publickey line
                    # only, without the following lines.
                    for line in sshuserauth.splitlines():
                        if line.startswith(b'publickey '):
                            baserecord[b'public-key'] = line.split(b' ', 1)[
                                1
                            ].strip()
                            break
        elif source in (b'push', b'pull', b'unbundle', b'local'):
            user = encoding.environ.get(b'USER')
            if not user:
                ui.error(
                    _(
                        b'authenticated user not found; refusing to write auditlog\n'
                    )
                )
                return None
            baserecord.update(
                {
                    b'user': user,
                    b'source': b'local',
                }
            )
        else:
            raise error.ProgrammingError(b'unknown source: %s' % source)

        self._baserecord = baserecord
        return baserecord


@command(b'auditlog', cmdutil.formatteropts, b'dump the auditlog')
def printauditlog(ui, repo, *revs, **opts):
    """dump the auditlog

    Print one entry per auditlog record. If revisions are given, only
    records linked to those changesets are shown. Records that fail
    validation are skipped.
    """
    opts = pycompat.byteskwargs(opts)
    if not opts.get(b'template'):
        opts[b'template'] = default_template
    fm = ui.formatter(b'auditlog', opts)

    def printrecord(record):
        fm.startitem()
        fm.data(
            changeset=record[b'node'],
            user=record[b'user'],
            # Dates are stored in nanoseconds; formatdate wants seconds.
            date=fm.formatdate((record[b'date'] / 1000000000, 0)),
            phase=record[b'phase'],
            oldphase=record.get(b'old-phase'),
            source=record[b'source'],
            pki=record.get(b'public-key'),
        )

    auditlog = repo.auditlog.auditlogfile()
    # None means "no filter". A revset that matches nothing yields an empty
    # set, which must filter out every record rather than disable filtering.
    wanted = None
    if revs:
        wanted = set(logcmdutil.revrange(repo, revs))

    for rev in auditlog:
        if wanted is not None and auditlog.linkrev(rev) not in wanted:
            continue
        record = verifyrecord(ui, auditlog.rawdata(rev))
        if record is None:
            continue  # skip bad records for now
        printrecord(record)

    fm.end()


# Lowercase hex digits (as ints, since iterating bytes yields ints).
_HEXDIGITS = frozenset(b'0123456789abcdef')


def _isbytes(value):
    return isinstance(value, bytes)


def _isint(value):
    # CBOR booleans decode to bool, a subclass of int; they are not dates.
    return isinstance(value, int) and not isinstance(value, bool)


def _ishexnode(value):
    # Records written locally use ctx.hex(): 40 lowercase hex digits.
    # Requiring exactly that stops peer-supplied values such as b'tip' or
    # b'.' from being resolved by the repository lookups done on received
    # records.
    return (
        isinstance(value, bytes)
        and len(value) == 40
        and all(c in _HEXDIGITS for c in value)
    )


# key -> (mandatory, check) for auditlog record maps. Keys not listed here
# (e.g. b'transaction') are accepted without checks.
_RECORD_SCHEMA = {
    b'node': (True, _ishexnode),
    b'user': (True, _isbytes),
    b'date': (True, _isint),
    b'phase': (True, _isbytes),
    b'old-phase': (False, _isbytes),
    b'source': (True, _isbytes),
    b'public-key': (False, _isbytes),
}


def verifyparsedrecord(ui, record):
    """Return True if ``record`` is a well-formed decoded auditlog record.

    ``record`` must be a dict containing every mandatory key of the schema
    above, and every schema key that is present must pass its check. On
    failure, a debug message naming the problem is emitted through ``ui``
    and False is returned.
    """
    if not isinstance(record, dict):
        ui.debug(b'auditlog: found bad record: dictionary expected\n')
        return False
    for key, (mandatory, typecheck) in _RECORD_SCHEMA.items():
        if key in record:
            if not typecheck(record[key]):
                ui.debug(
                    b'auditlog: record entry has the wrong type: %s\n' % key
                )
                return False
        elif mandatory:
            ui.debug(b'auditlog: mandatory key missing from record: %s\n' % key)
            return False
    return True


def verifyrecord(ui, raw_record):
    """Decode and validate one raw auditlog revision.

    ``raw_record`` must decode to exactly one CBOR item, and that item must
    pass ``verifyparsedrecord``. Returns the decoded record, or None (after
    a debug message) if the data is unparsable or invalid.
    """
    try:
        decoded = cborutil.decodeall(raw_record)
    except cborutil.CBORDecodeError:
        ui.debug(b'auditlog: found unparsable record\n')
        return None

    if not isinstance(decoded, list):
        ui.debug(b'auditlog: found bad record: list expected\n')
        return None
    if len(decoded) != 1:
        ui.debug(b'auditlog: found bad record: list of one element expected\n')
        return None

    (record,) = decoded
    return record if verifyparsedrecord(ui, record) else None


def pretxnclosehook(ui, repo, **kwargs):
    """Record the changes of the closing transaction in the auditlog.

    Registered as a ``pretxnclose`` hook in server mode.
    - Transactions whose source is ``push-response`` are ignored.
    - ``strip`` transactions only repair the auditlog.
    - Otherwise, one record is written for every changeset added by the
      transaction, and one for every phase change that has a known old
      phase. All records share the fields returned by ``prepare_record``.

    Returns 0 on success and 1 if ``prepare_record`` refused to produce the
    shared fields.
    """
    if kwargs.get('source') == b'push-response':
        return 0
    if kwargs['txnname'] == b'strip':
        repo.auditlog.repair()
        return 0

    baserecord = repo.auditlog.prepare_record(
        ui, repo, kwargs.get('source', b'local')
    )
    if baserecord is None:
        return 1
    auditlog = repo.auditlog.auditlogfile()
    tr = repo.currenttransaction()
    # Resolve changesets through the unfiltered repository: revisions added
    # by this transaction may not be visible in ``repo`` (for instance if
    # they are already obsolete), and a filtered lookup fails for them.
    unfi = repo.unfiltered()

    def addrecord(rev, record):
        # One revlog revision per record, linked to changelog revision rev.
        auditlog.addrevision(
            b''.join(cborutil.streamencode(record)),
            tr,
            rev,
            repo.nullid,
            repo.nullid,
        )

    if 'node' in kwargs and 'node_last' in kwargs:
        revs = range(
            unfi[kwargs['node']].rev(), unfi[kwargs['node_last']].rev() + 1
        )
    else:
        revs = range(kwargs['changes'][b'origrepolen'], len(unfi))
    for rev in revs:
        ctx = unfi[rev]
        addrecord(
            rev,
            {
                b'node': ctx.hex(),
                b'phase': phases.phasenames[ctx.phase()],
                b'transaction': kwargs['txnid'],
                **baserecord,
            },
        )

    for revs, (oldphase, newphase) in tr.changes[b'phases']:
        # Entries without an old phase are skipped; the loop above already
        # recorded the changesets added by this transaction.
        if oldphase is None:
            continue
        for rev in revs:
            ctx = unfi[rev]
            assert ctx.phase() == newphase
            addrecord(
                rev,
                {
                    b'node': ctx.hex(),
                    b'old-phase': phases.phasenames[oldphase],
                    b'phase': phases.phasenames[newphase],
                    b'transaction': kwargs['txnid'],
                    **baserecord,
                },
            )

    ui.note(_(b'recorded push in auditlog\n'))
    return 0


def extsetup(ui):
    """Install the auditlog wrappers.

    The wire protocol capability list is extended. The auditlog exchange
    piggybacks on exchange's ``_pullobsolete`` and ``_pushobsolete`` steps;
    each wrapper calls the original step first.
    """
    wrappers = (
        (wireprotov1server, '_capabilities', capabilities),
        (exchange, '_pullobsolete', exchangepullauditlog),
        (exchange, '_pushobsolete', exchangepushauditlog),
    )
    for module, funcname, wrapper in wrappers:
        extensions.wrapfunction(module, funcname, wrapper)


def reposetup(ui, repo):
    """Give local repositories an ``auditlog`` attribute.

    Reports an error if both client and server mode are enabled. In server
    mode, also registers ``pretxnclosehook`` so that transactions write
    auditlog records.
    """
    if not repo.local():
        return

    servermode = repo.ui.configbool(b'auditlog', b'server-mode')
    clientmode = repo.ui.configbool(b'auditlog', b'client-mode')
    if servermode and clientmode:
        ui.error(
            _(
                b'auditlog misconfigured: client-mode and server-mode are mutually exclusive\n'
            )
        )

    if servermode:
        # Register under a suffixed name so that a plain ``pretxnclose``
        # hook configured by the user is not overwritten.
        ui.setconfig(
            b'hooks', b'pretxnclose.auditlog', pretxnclosehook, b'auditlog'
        )

    class auditlogrepo(repo.__class__):
        # Cached on the repository object; storecache ties the cache to the
        # ``auditlog.i`` store file.
        @localrepo.storecache(b'auditlog.i')
        def auditlog(self):
            # ``auditlog`` here resolves to the module-level class.
            return auditlog(self)

    repo.__class__ = auditlogrepo
