/*
 ---------------------------------------------------------------------------
 Copyright (c) 2002, Dr Brian Gladman <brg@gladman.me.uk>, Worcester, UK.
 All rights reserved.

 LICENSE TERMS

 The free distribution and use of this software in both source and binary
 form is allowed (with or without changes) provided that:

   1. distributions of this source code include the above copyright
      notice, this list of conditions and the following disclaimer;

   2. distributions in binary form include the above copyright
      notice, this list of conditions and the following disclaimer
      in the documentation and/or other associated materials;

   3. the copyright holder's name is not used to endorse products
      built using this software without specific written permission.

 ALTERNATIVELY, provided that this notice is retained in full, this product
 may be distributed under the terms of the GNU General Public License (GPL),
 in which case the provisions of the GPL apply INSTEAD OF those given above.

 DISCLAIMER

 This software is provided 'as is' with no explicit or implied warranties
 in respect of its properties, including, but not limited to, correctness
 and/or fitness for purpose.
 ---------------------------------------------------------------------------
 Issue Date: 20/02/2003

 This file contains the code for implementing encryption and authentication
 using AES with a Carter-Wegman hash function
*/

#include "aescwc.i.h"

/* build configuration: LITTLE_ENDIAN selects the byte swapping	*/
/* form of swap_b32 below and the #ifdef'd conversions later in	*/
/* this file; it is hard-wired here, so it has to be removed by	*/
/* hand when building for a big endian target					*/
#define LITTLE_ENDIAN
/* number of message bytes consumed by one hash iteration		*/
#define CWC_BLOCK_SIZE	12

/* increment the four element big-endian counter held in		*/
/* x[n] .. x[n+3]: x[n+3] is incremented first and the carry	*/
/* moves towards x[n] only while an element wraps round to 0.	*/
/* It works for any unsigned element type (this file applies it	*/
/* both to the counter bytes of the nonce and to the words of	*/
/* the hash value).  Both arguments are evaluated several		*/
/* times, so they must be free of side effects.  The arguments	*/
/* and the whole expansion are parenthesised so that the macro	*/
/* is safe with expression arguments and inside larger			*/
/* expressions													*/

#define ui_inc(x,n)	(!++((x)[(n)+3]) && !++((x)[(n)+2]) && !++((x)[(n)+1]) && ++((x)[(n)]))

/* reverse byte order in 32 bit long word						*/
/* brot(x,n) rotates a 32 bit value left by n bits (0 < n < 32)	*/
/* with the value assumed to fit in 32 bits; swap_b32 combines	*/
/* two rotations to reverse the four bytes.  When LITTLE_ENDIAN	*/
/* is not defined swap_b32 returns its argument unchanged and	*/
/* brot is not defined											*/

#ifdef	LITTLE_ENDIAN
#define brot(x,n)	(((ulong)(x) << (n)) | ((ulong)(x) >> (32 - (n))))
#define swap_b32(x) ((brot((x),8) & 0x00ff00ff) | (brot((x),24) & 0xff00ff00))
#else
#define swap_b32(x)	(x)
#endif

/* four word addition l += r, any carry out of the top word is	*/
/* discarded.  Word 0 is the most significant and word 3 the	*/
/* least significant; each word holds a native integer value.	*/
/* l and r may be the same array (used to double a value) since	*/
/* every word is read before it is overwritten					*/

void add_4(ulong l[], ulong r[])
{	ulong	sum, carry;
	int		w;

	/* lowest word: there is no carry in, so the sum has wrapped	*/
	/* exactly when it is smaller than the original operand		*/
	sum = l[3] + r[3];
	carry = (sum < l[3]);
	l[3] = sum;

	/* middle words: if the sum equals the original operand then	*/
	/* addend plus carry in was zero modulo the word size, and in	*/
	/* that case the carry out is the same as the carry in			*/
	for(w = 2; w > 0; --w)
	{
		sum = l[w] + r[w] + carry;
		if(sum != l[w])
			carry = (sum < l[w]);
		l[w] = sum;
	}

	/* top word: add the final carry, dropping any overflow		*/
	l[0] += r[0] + carry;
}

/* four word by four word multiplication giving an eight word	*/
/* product r = a * b.  In all three arrays index 0 is the most	*/
/* significant word.  The product is built one output word		*/
/* (column) at a time, starting with the least significant, and	*/
/* the code relies on each input word fitting in 32 bits.  r	*/
/* must not overlap a or b because r is written while a and b	*/
/* are still being read											*/

void mlt_4(ulong r[], const ulong a[], const ulong b[])
{	unsigned long long	prod, acc_lo = 0, acc_hi;
	int		col, first, last, t;

	for(col = 0; col < 8; ++col)
	{
		/* word t of a (counted from the least significant end)	*/
		/* pairs with word col - t of b; both indexes must stay	*/
		/* in the range 0..3, which bounds t as follows			*/
		first = (col > 3 ? col - 3 : 0);
		last = col - first;
		acc_hi = 0;

		for(t = first; t <= last; ++t)
		{
			prod = (unsigned long long)a[3 - t] * b[3 - col + t];
			acc_lo += (ulong)prod;		/* low halves -> this column	*/
			acc_hi += (prod >> 32);		/* high halves -> next column	*/
		}

		r[7 - col] = (ulong)acc_lo;

		/* carry out of this column plus the saved high halves	*/
		/* form the starting value for the next column			*/
		acc_lo = (acc_lo >> 32) + acc_hi;
	}
}

/* Carter-Wegman hash iteration on 12 bytes of data:			*/
/*     hash = ((hash + block) * zval) mod (2^127 - 1)			*/
/* in[0..2] hold the 12 input bytes as three 32 bit words; the	*/
/* running value is kept in ctx->hash[0..3] with word 0 the		*/
/* most significant.  ctx->hash is also used as the eight word	*/
/* workspace for the product									*/

void do_cwc(ulong in[], aes_cwc_ctx ctx[1])
{	ulong	blk[4];
	/* NOTE(review): the reduction below only makes sense if lo	*/
	/* addresses the low four words of the eight word product,	*/
	/* i.e. if BLOCK_LEN is 4 - confirm against the header		*/
	ulong	*lo = ctx->hash + BLOCK_LEN;

	/* widen the 96 bit input to four words, converting each	*/
	/* big endian 32 bit item into a native integer value		*/
	blk[0] = 0;
	blk[1] = swap_b32(in[0]);
	blk[2] = swap_b32(in[1]);
	blk[3] = swap_b32(in[2]);

	/* block + current hash value								*/
	add_4(blk, ctx->hash);

	/* multiply by the hash key; the product overwrites all	*/
	/* eight words starting at ctx->hash						*/
	mlt_4(ctx->hash, blk, ctx->zval);

	/* reduce modulo (2^127 - 1): writing the product as		*/
	/* 2^128 * hi + lo = (2^127 - 1) * 2 * hi + 2 * hi + lo,	*/
	/* the remainder can be formed from 2 * hi + lo				*/

	/* hash[0..3] = 2 * hi										*/
	add_4(ctx->hash, ctx->hash);

	/* if the top bit of lo is set take 2^127 - 1 away from it:	*/
	/* clear the bit and add 1 to the word just below lo, the	*/
	/* lowest word of 2 * hi (an even value, so it cannot wrap)	*/
	if(lo[0] & 0x80000000)
	{
		lo[0] &= 0x7fffffff;
		lo[-1] += 1;
	}

	/* hash[0..3] = 2 * hi + lo, then apply the same correction	*/
	/* to the sum, this time with a full four word increment	*/
	add_4(ctx->hash, lo);
	if(ctx->hash[0] & 0x80000000)
	{
		ctx->hash[0] &= 0x7fffffff;
		ui_inc((ulong*)ctx->hash, 0);
	}
}

/* initialise AES-CWC encryption mode							*/
/*																*/
/* Clears the context, builds the counter block from the nonce,	*/
/* sets up the AES key schedule, derives the hash key and the	*/
/* first key stream block, and hashes the header data.			*/
/*																*/
/* key, k_len     AES key and its length in bytes (16, 24 or	*/
/*                32); for any other length the function		*/
/*                returns early, leaving the context zeroed	*/
/*                apart from the counter block					*/
/* nonce, n_len   nonce bytes; the counter block has room for	*/
/*                11 of them.  At most min(n_len, 11) bytes are	*/
/*                read, a shorter nonce is padded with zero	*/
/*                bytes and bytes beyond the 11th are ignored	*/
/* auth, a_len    header data that is authenticated but not	*/
/*                encrypted, zero padded to a multiple of 12	*/
/*                bytes for hashing								*/
/* ctx            the context to initialise						*/

void aes_cwc_init(
    const unsigned char key[], unsigned int k_len,		/* the key value to be used         */
    const unsigned char nonce[], unsigned int n_len,	/* the nonce value                  */
    const unsigned char auth[], unsigned int a_len,		/* authenticated data (len < 2^32)	*/
    aes_cwc_ctx ctx[1])									/* the CWC context                  */
{	unsigned int	i, pos;

	/* set all bytes in the context to zero		*/
	memset(ctx, 0, sizeof(aes_cwc_ctx));

	/* counter block: 0x80, 11 nonce bytes, 4 counter bytes.	*/
	/* never read past the n_len bytes the caller supplied; any	*/
	/* missing nonce bytes keep the zero value set above		*/
	ctx->nonce[0] = 0x80;
	for(i = 0; i < 11 && i < n_len; ++i)
		ctx->nonce[i + 1] = nonce[i];

	/* set up encryption context				*/
	switch(k_len)
	{
	case 16: aes_encrypt_key128(key, ctx->enc_ctx); break;
	case 24: aes_encrypt_key192(key, ctx->enc_ctx); break;
	case 32: aes_encrypt_key256(key, ctx->enc_ctx); break;
	default:
		return;
	}

	/* hash key: encrypt the block 0xc0, 0, ..., 0 (the rest of	*/
	/* zval is still zero) and clear the top bit of the result	*/
	((uchar*)ctx->zval)[0] = 0xc0;
	aes_encrypt((uchar*)ctx->zval, (uchar*)ctx->zval, ctx->enc_ctx);
	((uchar*)ctx->zval)[0] &= 0x7f;

#ifdef LITTLE_ENDIAN
	/* reverse z value into little endian form	*/
	ctx->zval[0] = swap_b32(ctx->zval[0]);
	ctx->zval[1] = swap_b32(ctx->zval[1]);
	ctx->zval[2] = swap_b32(ctx->zval[2]);
	ctx->zval[3] = swap_b32(ctx->zval[3]);
#endif

	/* first key stream block: counter value 1	*/
	ctx->nonce[15] = 1;
	aes_encrypt(ctx->nonce, ctx->xor_buf, ctx->enc_ctx);

	/* authenticate the header data				*/

	i = 0; pos = 0;
	/* if both buffers are 4 byte aligned hash the header in	*/
	/* place.  i <= a_len always holds, so 'a_len - i' cannot	*/
	/* wrap round (the form 'i + CWC_BLOCK_SIZE < a_len' does	*/
	/* when a_len is close to 2^32).  The last block, full or	*/
	/* partial, is always left for the byte loop below			*/
	if(!((ulong)auth & 3) && !((ulong)ctx->cwc_buf & 3))
	{
		while(a_len - i > CWC_BLOCK_SIZE)
		{
			do_cwc((ulong*)(auth + i), ctx);
			i += CWC_BLOCK_SIZE;
		}
	}

	/* remaining bytes go through the block buffer	*/
	while(i < a_len)
	{
		((uchar*)ctx->cwc_buf)[pos++] = auth[i++];

		if(pos == CWC_BLOCK_SIZE)
		{
			do_cwc(ctx->cwc_buf, ctx);
			pos = 0;
		}
	}

	if(pos)	/* if bytes remain in the buffer	*/
	{
		/* zero pad the final partial block so that the message	*/
		/* data starts on a fresh hash block					*/
		while(pos < CWC_BLOCK_SIZE)
			((uchar*)ctx->cwc_buf)[pos++] = 0;

		do_cwc(ctx->cwc_buf, ctx);
	}

	/* set the header length in the context		*/
	ctx->hlen_lo = a_len;
	ctx->hlen_hi = 0;
}

/* encrypt len bytes of data in place and add the resulting	*/
/* ciphertext to the running hash.  The position in the current	*/
/* key stream block is derived from ctx->mlen_lo and a partly	*/
/* filled hash block is carried in ctx->cwc_buf / ctx->pos, so	*/
/* a message may be passed in several pieces.  The 64 bit		*/
/* message length in ctx->mlen_hi:mlen_lo is advanced by len	*/

void aes_cwc_encrypt(unsigned char data[], unsigned int len, aes_cwc_ctx ctx[1])
{	unsigned int i, pos;

	/* stage 1: xor the data with the counter mode key stream	*/

	i = 0; pos = ctx->mlen_lo & (AES_BLOCK_SIZE - 1);
	/* if buffers and count are aligned use long operations.	*/
	/* i <= len always holds, so 'len - i' cannot wrap round	*/
	/* (the form 'i + AES_BLOCK_SIZE < len' does when len is	*/
	/* within one block of 2^32)								*/
	if(!pos && !((ulong)data & 3) && !((ulong)ctx->xor_buf & 3))
	{	ulong	*p, *q = (ulong*)(ctx->xor_buf);

		while(len - i > AES_BLOCK_SIZE)
		{
			p = (ulong*)(data + i);
			p[0] ^= q[0]; p[1] ^= q[1]; p[2] ^= q[2]; p[3] ^= q[3];
			i += AES_BLOCK_SIZE;
			/* step the counter and make the next key stream block	*/
			ui_inc(ctx->nonce, 12);
			aes_encrypt(ctx->nonce, ctx->xor_buf, ctx->enc_ctx);
		}
	}

	/* remaining bytes (always including the last block) are	*/
	/* processed one at a time									*/
	while(i < len)
	{
		data[i++] ^= ctx->xor_buf[pos++];

		if(pos == AES_BLOCK_SIZE)
		{
			ui_inc(ctx->nonce, 12); pos = 0;
			aes_encrypt(ctx->nonce, ctx->xor_buf, ctx->enc_ctx);
		}
	}

	/* stage 2: hash the ciphertext that is now held in data	*/

	i = 0; pos = ctx->pos;
	/* if buffers and count are aligned use long operations	*/
	if(!pos && !((ulong)data & 3) && !((ulong)ctx->cwc_buf & 3))
	{
		while(len - i > CWC_BLOCK_SIZE)
		{
			do_cwc((ulong*)(data + i), ctx);
			i += CWC_BLOCK_SIZE;
		}
	}

	while(i < len)
	{
		((uchar*)ctx->cwc_buf)[pos++] = data[i++];

		if(pos == CWC_BLOCK_SIZE)
		{
			do_cwc(ctx->cwc_buf, ctx); pos = 0;
		}
	}
	/* remember how much of the hash block buffer is in use	*/
	ctx->pos = pos;

	/* 64 bit message length: carry into the high word if the	*/
	/* low word wraps round										*/
	if((ctx->mlen_lo += len) < len)
		(ctx->mlen_hi)++;
}

/* add len bytes of ciphertext in data to the running hash and	*/
/* then decrypt them in place.  The hash is taken before the	*/
/* data is changed, so it covers the same ciphertext bytes as	*/
/* aes_cwc_encrypt hashes.  A partly filled hash block is		*/
/* carried in ctx->cwc_buf / ctx->pos and the key stream		*/
/* position is derived from ctx->mlen_lo, so a message may be	*/
/* passed in several pieces.  The 64 bit message length in		*/
/* ctx->mlen_hi:mlen_lo is advanced by len						*/

void aes_cwc_decrypt(unsigned char data[], unsigned int len, aes_cwc_ctx ctx[1])
{	unsigned int i, pos;

	/* stage 1: hash the ciphertext								*/

	i = 0; pos = ctx->pos;
	/* if buffers and count are aligned use long operations.	*/
	/* i <= len always holds, so 'len - i' cannot wrap round	*/
	/* (the form 'i + CWC_BLOCK_SIZE < len' does when len is	*/
	/* within one block of 2^32)								*/
	if(!pos && !((ulong)data & 3) && !((ulong)ctx->cwc_buf & 3))
	{
		while(len - i > CWC_BLOCK_SIZE)
		{
			do_cwc((ulong*)(data + i), ctx);
			i += CWC_BLOCK_SIZE;
		}
	}

	/* remaining bytes (always including the last block) go	*/
	/* through the hash block buffer							*/
	while(i < len)
	{
		((uchar*)ctx->cwc_buf)[pos++] = data[i++];

		if(pos == CWC_BLOCK_SIZE)
		{
			do_cwc(ctx->cwc_buf, ctx); pos = 0;
		}
	}
	/* remember how much of the hash block buffer is in use	*/
	ctx->pos = pos;

	/* stage 2: xor the data with the counter mode key stream	*/

	i = 0; pos = ctx->mlen_lo & (AES_BLOCK_SIZE - 1);
	/* if buffers and count are aligned use long operations	*/
	if(!pos && !((ulong)data & 3) && !((ulong)ctx->xor_buf & 3))
	{	ulong	*p, *q = (ulong*)(ctx->xor_buf);

		while(len - i > AES_BLOCK_SIZE)
		{
			p = (ulong*)(data + i);
			p[0] ^= q[0]; p[1] ^= q[1]; p[2] ^= q[2]; p[3] ^= q[3];
			i += AES_BLOCK_SIZE;
			/* step the counter and make the next key stream block	*/
			ui_inc(ctx->nonce, 12);
			aes_encrypt(ctx->nonce, ctx->xor_buf, ctx->enc_ctx);
		}
	}

	while(i < len)
	{
		data[i++] ^= ctx->xor_buf[pos++];

		if(pos == AES_BLOCK_SIZE)
		{
			ui_inc(ctx->nonce, 12); pos = 0;
			aes_encrypt(ctx->nonce, ctx->xor_buf, ctx->enc_ctx);
		}
	}

	/* 64 bit message length: carry into the high word if the	*/
	/* low word wraps round										*/
	if((ctx->mlen_lo += len) < len)
		(ctx->mlen_hi)++;
}

/* complete the hash and write the authentication tag			*/
/*																*/
/* data    receives the tag bytes								*/
/* a_len   number of tag bytes wanted; the tag is one AES		*/
/*         block, so at most AES_BLOCK_SIZE bytes are written	*/
/*         and a larger request is truncated to that size		*/
/* ctx     the context; ctx->zval, ctx->hash, ctx->xor_buf and	*/
/*         the counter bytes of ctx->nonce are overwritten, so	*/
/*         the context cannot process more data afterwards		*/

void aes_cwc_end(unsigned char data[], unsigned int a_len, aes_cwc_ctx ctx[1])
{	unsigned int	pos;

	/* zero pad and hash any partial block left in the buffer	*/
	pos = ctx->pos;
	if(pos)
	{
		while(pos < CWC_BLOCK_SIZE)
			((uchar*)ctx->cwc_buf)[pos++] = 0;

		do_cwc(ctx->cwc_buf, ctx);
	}

	/* the hash key is no longer needed, so zval is reused to	*/
	/* hold the final block: header length then message length,	*/
	/* each as a 64 bit value									*/
	ctx->zval[0] = ctx->hlen_hi;
	ctx->zval[1] = ctx->hlen_lo;
	ctx->zval[2] = ctx->mlen_hi;
	ctx->zval[3] = ctx->mlen_lo;

	/* the length block is added without a further multiply	*/
	add_4(ctx->hash, ctx->zval);

	/* same correction as in do_cwc when the top bit is set:	*/
	/* clear it and add one										*/
	if(ctx->hash[0] & 0x80000000)
	{
		ctx->hash[0] &= 0x7fffffff;
		ui_inc(ctx->hash, 0);
	}

#ifdef LITTLE_ENDIAN
	/* put the hash words back into big endian byte order		*/
	ctx->hash[0] = swap_b32(ctx->hash[0]);
	ctx->hash[1] = swap_b32(ctx->hash[1]);
	ctx->hash[2] = swap_b32(ctx->hash[2]);
	ctx->hash[3] = swap_b32(ctx->hash[3]);
#endif

	/* encrypt the hash value in place							*/
	aes_encrypt((uchar*)ctx->hash, (uchar*)ctx->hash, ctx->enc_ctx);

	/* encrypt the counter block with its counter set to zero	*/
	ctx->nonce[12] = 0; ctx->nonce[13] = 0;
	ctx->nonce[14] = 0; ctx->nonce[15] = 0;
	aes_encrypt(ctx->nonce, ctx->xor_buf, ctx->enc_ctx);

	/* only AES_BLOCK_SIZE bytes of each encrypted block are	*/
	/* meaningful; without this limit a larger a_len would read	*/
	/* past xor_buf and copy unrelated memory into the tag		*/
	if(a_len > AES_BLOCK_SIZE)
		a_len = AES_BLOCK_SIZE;

	/* tag = encrypted hash xor encrypted counter block			*/
	for(pos = 0; pos < a_len; ++pos)
		data[pos] = ((uchar*)ctx->hash)[pos] ^ ctx->xor_buf[pos];
}
