#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <limits.h>

#include "viterbi.h"
#include "misc.h"

/* ---------------------------------------------------------------------- */

/*
 * Allocate and initialise a soft-decision Viterbi decoder for a rate 1/2
 * convolutional code (two output bits per input bit).
 *
 *   k      constraint length; the decoder tracks 1 << (k - 1) states.
 *          Must be in [1, 30] so that the shifts below are well defined.
 *   poly1  generator polynomial producing output bit 0.
 *   poly2  generator polynomial producing output bit 1.
 *
 * Returns a new decoder (release with viterbi_free()), or NULL if k is out
 * of range or any allocation fails.  On failure nothing is leaked.
 */
struct viterbi *viterbi_init(int k, int poly1, int poly2)
{
	struct viterbi *v;
	int i;

	if (k < 1 || k > 30) {
		fprintf(stderr, "init_viterbi: invalid constraint length %d\n", k);
		return NULL;
	}

	if ((v = calloc(1, sizeof(struct viterbi))) == NULL) {
		perror("init_viterbi: calloc");
		return NULL;
	}

	v->nstates = 1 << (k - 1);

	if ((v->output = calloc(1 << k, sizeof *v->output)) == NULL) {
		perror("init_viterbi: calloc");
		goto fail;
	}

	/*
	 * output[s] holds the two encoder output bits for shift-register
	 * contents s: bit 0 from poly1, bit 1 from poly2.
	 */
	for (i = 0; i < (1 << k); i++)
		v->output[i] = parity(poly1 & i) | (parity(poly2 & i) << 1);

	/* One row of path metrics and one row of survivor states per step. */
	for (i = 0; i < PATHMEM; i++) {
		v->metrics[i] = calloc(v->nstates, sizeof *v->metrics[i]);
		v->history[i] = calloc(v->nstates, sizeof *v->history[i]);

		if (v->metrics[i] == NULL || v->history[i] == NULL) {
			perror("init_viterbi: calloc");
			goto fail;
		}
	}

	/*
	 * Soft-decision branch metric table.  Larger metrics are better, so a
	 * received symbol value of 0 strongly favours bit 0 and a value of
	 * 255 strongly favours bit 1.
	 */
	for (i = 0; i < 256; i++) {
		v->mettab[0][i] = 128 - i;
		v->mettab[1][i] = i - 128;
	}

	v->ptr = 0;

	return v;

fail:
	/* v came from calloc, so rows not yet allocated are still NULL. */
	for (i = 0; i < PATHMEM; i++) {
		free(v->metrics[i]);
		free(v->history[i]);
	}
	free(v->output);
	free(v);
	return NULL;
}

/*
 * Return the decoder to its just-initialised state: every path-metric row
 * and every survivor-history row is zeroed, and the circular write
 * position is rewound to the first row.
 */
void viterbi_reset(struct viterbi *v)
{
	size_t rowbytes = v->nstates * sizeof(int);
	int row = 0;

	while (row < PATHMEM) {
		memset(v->history[row], 0, rowbytes);
		memset(v->metrics[row], 0, rowbytes);
		row++;
	}

	v->ptr = 0;
}

void viterbi_free(struct viterbi *v)
{
	int i;

	if (v) {
		free(v->output);

		for (i = 0; i < v->nstates; i++) {
			free(v->metrics[i]);
			free(v->history[i]);
		}

		free(v);
	}
}

/*
 * Forward declaration: traceback() is defined after viterbi_decode(),
 * which calls it every time the write position reaches a multiple of 8.
 */
static int traceback(struct viterbi *v, int *metric);

/*
 * Feed one received symbol pair into the decoder (one trellis step).
 *
 *   sym     two soft-decision symbols, sym[0] and sym[1], each 0..255;
 *           values near 0 favour a 0 bit, values near 255 a 1 bit.
 *           sym[0] pairs with output bit 0 (poly1), sym[1] with bit 1.
 *   metric  optional (may be NULL); when a byte is returned it receives
 *           the survivor path's metric growth over the decoded span
 *           (see traceback()).
 *
 * Returns a decoded byte (0..255) whenever the circular write position
 * reaches a multiple of 8, and -1 otherwise.
 */
int viterbi_decode(struct viterbi *v, unsigned char *sym, int *metric)
{
	unsigned int currptr, prevptr;
	int i, j, met[4], n;

	currptr = v->ptr;
	/*
	 * Add PATHMEM before subtracting: with unsigned arithmetic 0 - 1 wraps
	 * to UINT_MAX, which reduces to PATHMEM - 1 only for powers of two.
	 */
	prevptr = (currptr + PATHMEM - 1) % PATHMEM;

	/*
	 * Branch metric for each possible 2-bit encoder output; bit 0 of the
	 * index selects the table for sym[0], bit 1 the table for sym[1].
	 * Hard decisions can be emulated by quantising sym to 0/255.
	 */
	met[0] = v->mettab[0][sym[1]] + v->mettab[0][sym[0]];
	met[1] = v->mettab[0][sym[1]] + v->mettab[1][sym[0]];
	met[2] = v->mettab[1][sym[1]] + v->mettab[0][sym[0]];
	met[3] = v->mettab[1][sym[1]] + v->mettab[1][sym[0]];

	/*
	 * Add-compare-select: state n can be reached from predecessor n >> 1
	 * or (n + nstates) >> 1; keep whichever gives the larger metric.
	 */
	for (n = 0; n < v->nstates; n++) {
		int p0, p1, s0, s1, m0, m1;

		s0 = n;
		s1 = n + v->nstates;

		p0 = s0 >> 1;
		p1 = s1 >> 1;

		m0 = v->metrics[prevptr][p0] + met[v->output[s0]];
		m1 = v->metrics[prevptr][p1] + met[v->output[s1]];

		if (m0 > m1) {
			v->metrics[currptr][n] = m0;
			v->history[currptr][n] = p0;
		} else {
			v->metrics[currptr][n] = m1;
			v->history[currptr][n] = p1;
		}
	}

	v->ptr = (v->ptr + 1) % PATHMEM;

	/*
	 * Renormalise before the metrics can overflow.  Every stored metric is
	 * shifted by the same amount, so neither the best-state choice nor any
	 * metric difference changes.  This is done before the traceback so
	 * that traceback steps do not skip it.
	 */
	if (v->metrics[currptr][0] > INT_MAX / 2) {
		for (i = 0; i < PATHMEM; i++)
			for (j = 0; j < v->nstates; j++)
				v->metrics[i][j] -= INT_MAX / 2;
	}
	if (v->metrics[currptr][0] < INT_MIN / 2) {
		/* INT_MIN / 2 is negative: subtracting it moves values up. */
		for (i = 0; i < PATHMEM; i++)
			for (j = 0; j < v->nstates; j++)
				v->metrics[i][j] -= INT_MIN / 2;
	}

	if ((v->ptr % 8) == 0)
		return traceback(v, metric);

	return -1;
}

/*
 * Trace the survivor path back from the best state of the most recent
 * trellis step, then decode the 8 oldest bits held in the path memory.
 *
 *   metric  optional (may be NULL); receives the survivor-path metric
 *           8 steps after the oldest row minus the metric at the oldest
 *           row, i.e. the metric growth across the decoded bits.
 *
 * Returns the 8 decoded bits, oldest bit in the most significant position.
 */
static int traceback(struct viterbi *v, int *metric)
{
	int i, bestmetric, beststate, startmetric;
	unsigned int p, c;

	/*
	 * Most recently written row.  Adding PATHMEM before subtracting keeps
	 * the wrap correct for any PATHMEM, not only powers of two.
	 */
	p = (v->ptr + PATHMEM - 1) % PATHMEM;

	/*
	 * Find the state with the best metric
	 */
	bestmetric = INT_MIN;
	beststate = 0;

	for (i = 0; i < v->nstates; i++) {
		if (v->metrics[p][i] > bestmetric) {
			bestmetric = v->metrics[p][i];
			beststate = i;
		}
	}

	/*
	 * Trace back PATHMEM - 1 steps, recording the survivor state per row
	 */
	v->sequence[p] = beststate;

	for (i = 0; i < PATHMEM - 1; i++) {
		unsigned int prev = (p + PATHMEM - 1) % PATHMEM;

		v->sequence[prev] = v->history[p][v->sequence[p]];
		p = prev;
	}

	/* p now indexes the oldest row in the path memory. */
	startmetric = v->metrics[p][v->sequence[p]];

	/*
	 * Decode 8 bits, oldest first
	 */
	for (c = 0, i = 0; i < 8; i++) {
		/*
		 * low bit of state is the previous input bit
		 */
		c = (c << 1) | (v->sequence[p] & 1);
		p = (p + 1) % PATHMEM;
	}

	if (metric)
		*metric = v->metrics[p][v->sequence[p]] - startmetric;

	return c;
}

/* ---------------------------------------------------------------------- */

/*
 * Allocate a rate 1/2 convolutional encoder.
 *
 *   k      constraint length (shift-register width in bits); must be in
 *          [1, 30] so that 1 << k is well defined.
 *   poly1  generator polynomial producing output bit 0.
 *   poly2  generator polynomial producing output bit 1.
 *
 * Returns a new encoder (release with encoder_free()), or NULL if k is out
 * of range or allocation fails.
 */
struct encoder *encoder_init(int k, int poly1, int poly2)
{
	struct encoder *e;
	int i, size;

	if (k < 1 || k > 30) {
		fprintf(stderr, "init_encoder: invalid constraint length %d\n", k);
		return NULL;
	}

	if ((e = calloc(1, sizeof *e)) == NULL) {
		perror("init_encoder: calloc");
		return NULL;
	}

	size = 1 << k;	/* size of the output table */

	if ((e->output = calloc(size, sizeof *e->output)) == NULL) {
		perror("init_encoder: calloc");
		free(e);
		return NULL;
	}

	/*
	 * output[s] holds the two output bits for shift-register contents s:
	 * bit 0 from poly1, bit 1 from poly2.
	 */
	for (i = 0; i < size; i++)
		e->output[i] = parity(poly1 & i) | (parity(poly2 & i) << 1);

	e->shreg = 0;
	e->shregmask = size - 1;	/* keeps only the last k input bits */

	return e;
}

/*
 * Release an encoder obtained from encoder_init(); a NULL argument is
 * silently ignored.
 */
void encoder_free(struct encoder *e)
{
	if (e == NULL)
		return;

	free(e->output);
	free(e);
}

/*
 * Shift one input bit into the encoder and return the two coded output
 * bits for the new register contents (bit 0 from poly1, bit 1 from poly2).
 * Any non-zero 'bit' is treated as 1.
 *
 * NOTE(review): shreg is shifted without being masked, so this stays well
 * defined only if shreg has an unsigned type — confirm in viterbi.h.
 */
int encoder_encode(struct encoder *e, int bit)
{
	int inbit = (bit != 0);

	e->shreg <<= 1;
	e->shreg |= inbit;

	return e->output[e->shreg & e->shregmask];
}

/* ---------------------------------------------------------------------- */
