#include <stream.h>
#include <string.h>
#include <math.h>
#include "error.h"
#include "perf.h"
#include "dataseq.h"
#include "vq.h"
#include "codevec.h"

// C math routines used in this file (pow in the put() routines, log in the
// rate/entropy computations); <math.h> above also declares them.
extern "C" double pow(double,double);
extern "C" double log(double);
// rand() is only referenced when compiled with SPLIT_RAND (see FullSearchVQ::train).
//extern "C" long rand();
// Defined elsewhere; not referenced in the code visible in this file.
extern DATAVEC* dvp;
// Sentinel rate for empty cells: stored in rate[m] and tested with ==;
// also the initial currentlambda in MTSVQ::train.
const double Inf = 9.99999e99;	// practically infinity

/*
 *  Build a VQ of the type named by the next token of an input stream.
 *
 *  A token starting with '#' begins a comment line; the remainder of that
 *  line is discarded and the next token is examined.  "CW" yields a CODEVEC
 *  and "MTSVQ" yields an MTSVQ, each parsed by its istream constructor.
 *  Any other keyword is reported through Error() and NULL is returned.
 *  NOTE(review): STDVQ::put writes an "STDVQ" keyword that is not accepted
 *  here — confirm whether STDVQ descriptions are meant to be readable.
 */
VQ* new_VQ(istream& is) {

    char token[256];
    for (is >> token; token[0] == '#'; is >> token)
	is.get(token, 256);		// drop the rest of the comment line

    if (strcmp(token, "CW") == 0)
	return new CODEVEC(is);
    if (strcmp(token, "MTSVQ") == 0)
	return new MTSVQ(is);

    Error(token, " is not a vq type");
    return NULL;
    }

/*
 *  Build a VQ from the description stored in the named file.
 *  An open failure is reported through Error(); parsing is delegated to
 *  new_VQ(istream&).
 */
VQ* new_VQ(char* filename) {

    filebuf buffer;
    int opened = (buffer.open(filename, input) != 0);
    if (!opened)
	Error("can't open vq description file ", filename);
    istream in(&buffer);
    return new_VQ(in);
    }

/*
 *  Test a VQ on every block of a DATASEQ, summing the PERF values
 *  returned by test(const DATABLK&) for each block in list order.
 */
PERF VQ::test(const DATASEQ& ds) const {

    PERF total(0,0);
    DATABLK* block = ds.head;
    while (block) {
	total += test(*block);
	block = block->next;
	}
    return total;
    }

/*
 *  Allocate space for and construct a VQ from a file.
 */
void VQ::get(char* filename) {

    filebuf fb;
    if (fb.open(filename,input)==0)
	Error("can't open vq description file ",filename);
    istream is(&fb);
    char vqtype[50];
    is >> vqtype; // ignore vqtype, hope it's okay
    get(is);
    }

/*
 *  Write out a VQ to a file.
 */
void VQ::put(char* filename) {

    filebuf fb;
    if (fb.open(filename,output)==0)
	Error("can't open vq output file ",filename);
    ostream os(&fb);
    put(os);
    }

/*
 *  Construct a FullSearchVQ as a deep copy of another one: every child
 *  codeword is duplicated via copy(), and the rate table and the scalar
 *  statistics are copied by value.  An empty source yields null cw/rate.
 */
FullSearchVQ::FullSearchVQ(const FullSearchVQ& fullsearchvq) {

    width   = fullsearchvq.width;
    size    = fullsearchvq.size;
    logsize = fullsearchvq.logsize;
    entropy = fullsearchvq.entropy;
    fullcnt = fullsearchvq.fullcnt;

    cw = 0;
    rate = 0;
    if (size == 0) return;

    cw = new VQ*[size];
    rate = new double[size];
    for (int k=0; k<size; k++) {
	cw[k] = fullsearchvq.cw[k]->copy();	// deep copy of each child
	rate[k] = fullsearchvq.rate[k];
	}
    }

/*
 *  Destruct a FullSearchVQ.
 *  Each child codeword is owned by this object and destroyed recursively.
 *  cw and rate were allocated with new[] (copy ctor, get, splitNode), so
 *  they must be released with delete[]; scalar delete on them is undefined.
 */
FullSearchVQ::~FullSearchVQ() {

    for (int m=0; m<size; m++)
	delete cw[m]; // recursive destruct
    delete [] cw;
    delete [] rate;
    }

/*
 *  Test a FullSearchVQ on one data block.  The block is re-blocked with
 *  DATASEQ(db,width); each resulting piece is encoded with nearest() and
 *  the PERF values are summed.  The winning codeword indices are unused.
 */
PERF FullSearchVQ::test(const DATABLK& db) const {

    PERF total(0,0);
    DATASEQ pieces(db,width);
    for (DATABLK* piece=pieces.head; piece!=0; piece=piece->next) {
	int ignored;			// nearest() reports the index here
	total += nearest(*piece,ignored);
	}
    return total;
    }

/*
 *  Train a FullSearchVQ (private helper function).
 *
 *  Each pass (1) assigns every training vector to its nearest codeword,
 *  (2) optionally reseeds empty cells, (3) retrains every non-empty
 *  codeword on the vectors assigned to it, and (4) sets rate[m] to
 *  -log2(pop[m]/totalpop) (Inf for empty cells), entropy to the
 *  population-weighted mean rate, and fullcnt to the number of non-empty
 *  cells.
 *
 *  ds                  training data, re-blocked to this VQ's width.
 *  precision           passes continue while the relative improvement
 *                      (oldperf-perf)/oldperf exceeds this (bounded by
 *                      MAXITER); one further pass follows with lastiter set,
 *                      in which empty cells are not reseeded.
 *  fillEmptyCellsFlag  nonzero: reseed each empty cell by splitting off a
 *                      vector from the populated, not-yet-split cell with the
 *                      largest total distortion.
 */
void FullSearchVQ::train(const DATASEQ& ds, double precision,
	int fillEmptyCellsFlag) {

    DATASEQ vqds(ds,width);		// dataseq for all codewords
    DATASEQ* cwds = new DATASEQ[size];	// dataseq for each codeword
    long* pop = new long[size];		// population of each codeword
    double* dist = new double[size];	// total distortion of each cw
    char* split = new char[size];	// split indicator for each cw

    int m;				// codeword index
    PERF oldperf, perf;

    int lastiter = 0;			// set for one extra pass after convergence

    for (int i=0; lastiter || i<2 || i<MAXITER && (oldperf-perf)/oldperf>precision; i++) {

	oldperf = perf;
	perf.dist = perf.rate = 0;
	for (m=0; m<size; m++) { pop[m]=0; dist[m]=0; split[m]=0; }
	long totalpop = 0;

	// Optimize encoder for decoder.
	while (vqds.head) {
	    totalpop++;
	    DATABLK* dbp = vqds.pop();
	    PERF tmp_perf = nearest(*dbp,m);
	    pop[m]++;
	    perf += tmp_perf;
	    dist[m] += tmp_perf.dist;
	    cwds[m].push(dbp);
	    }

	// Debug.
	// cerr << "iter=" << i << " ";
	// cerr << "dist=" << perf.dist << " ";
	// cerr << "rate=" << perf.rate << " ";
	// cerr << "perf=" << double(perf) << "\n";

	// Deal with empty cells if necessary.
	if (fillEmptyCellsFlag && !lastiter) {

	    // Count the number of full cells (cells that can be split).
	    fullcnt = 0;
	    for (m=0; m<size; m++)
		if (pop[m]>1) fullcnt++;

	    // Scan for empty cells (as long as there are full cells).
	    for (int mt=0; fullcnt && mt<size; mt++) {
		if (pop[mt]) continue;

		// Found an empty cell.  Now find a full cell with the
		// highest distortion that hasn't already been split.
		double worstdist = -1;
		int worstm = -1;
		for (m=0; m<size; m++)
		    if (pop[m]>1 && !split[m] && dist[m]>worstdist) {
			worstdist = dist[m];
			worstm = m;
			}

		// Found a full cell to split.  Now split off the 
		// first vector which is far away.
		fullcnt--;			// one less cell to split
		split[worstm]++;		// don't split it again
		worstdist /= pop[worstm];	// average distortion
		pop[worstm] = pop[mt] = 0;	// reset populations
		DATASEQ tmp_ds;
		tmp_ds += cwds[worstm];	// transfer samples to tmp_ds
		while (tmp_ds.head) {
		    DATABLK* dbp = tmp_ds.pop();
#ifdef SPLIT_RAND
		    if (rand()&0X1000) {
#else
#ifdef SPLIT_GEQ
		    if (pop[mt]==0 && cw[worstm]->test(*dbp).dist>=worstdist) {
#else
		    if (pop[mt]==0 && cw[worstm]->test(*dbp).dist>worstdist) {
#endif
#endif
			cwds[mt].push(dbp);	// transfer to empty cell
			pop[mt]++;
			}
		    else {			// retain in full cell
			cwds[worstm].push(dbp);	
			pop[worstm]++;
			}
		    }
		}
	    }

	// Optimize decoder for encoder.
	double prec = (i<2)? 1.0 : (oldperf-perf)/oldperf;
	for (m=0; m<size; m++) {
	    if (pop[m]==0) continue;		// ignore still empty cells
	    cw[m]->train(cwds[m],prec);
	    vqds += cwds[m];
	    }

	// Optimize codelengths & set entropy.
	entropy = 0.0;
	fullcnt = 0;
	for (m=0; m<size; m++)
	    if (pop[m]) {
		rate[m] = -log((double)pop[m]/totalpop)/log(2.0);
		entropy += rate[m]*pop[m];
		fullcnt++;
		}
	    else rate[m] = Inf;
	if (totalpop) entropy /= totalpop;
        // Arm one final pass once the stopping test is met; clear it afterwards.
        if (!lastiter && !(i<1 || i<MAXITER-1 && (oldperf-perf)/oldperf>precision)) lastiter = 1; else lastiter=0;
	}

    // All four buffers came from new[]: delete[] runs the DATASEQ destructor
    // for every element of cwds (scalar delete was undefined behavior).
    delete [] cwds;
    delete [] pop;
    delete [] dist;
    delete [] split;
    }

/*
 *  Input a FullSearchVQ.
 *
 *  Reads width, size and M, then M (codeword, rate) pairs; each codeword is
 *  built with new_VQ().  Codewords M..size-1 are copies of codeword 0 with
 *  rate[0].  logsize is set to log2(size); entropy and fullcnt are reset.
 */
void FullSearchVQ::get(istream& is) {

    is >> width;
    is >> size;
    cw = new VQ*[size];
    rate = new double[size];

    // Next M entries specify the codewords.
    int M;
    is >> M;
    if (M<=0 || M>size) Error("bad format for FullSearchVQ");

    // m is shared by both loops: the second resumes where the first stopped.
    // (Standard C++ ends a for-init variable's scope with the loop.)
    int m;
    for (m=0; m<M; m++) {
	cw[m] = new_VQ(is);
	is >> rate[m];
	}

    // Remaining codewords are copies of codeword 0.
    for (; m<size; m++) {
	cw[m] = cw[0]->copy();
	rate[m] = rate[0];
	}

    logsize = log(size)/log(2.0);
    entropy = 0.0;
    fullcnt = 0;
    }

/*
 *  Output a FullSearchVQ: a header "width size size" (every codeword is
 *  listed explicitly), then for each codeword a "# prob = 2^-rate" comment
 *  line, the codeword itself, and its rate.
 */
void FullSearchVQ::put(ostream& os) const {

    os << width << " " << size << " " << size << "\n";
    int m = 0;
    while (m < size) {
	double prob = pow(2.0, -rate[m]);
	os << "# prob = " << prob << "\n";
	cw[m]->put(os);
	os << rate[m] << "\n";
	++m;
	}
    }

/*
 *  Find nearest STDVQ codeword (private helper function).
 *
 *  Returns the smallest PERF (per PERF's operator<) over all non-empty
 *  codewords (rate[m] != Inf); each candidate's rate includes logsize.
 *  bestm receives the winning index.  If every cell is empty, rate[0] is
 *  reset to 0 and codeword 0 is used.
 *  NOTE(review): size == 0 is not handled (cw[0] would be invalid).
 */
PERF STDVQ::nearest(const DATABLK& db, int& bestm) const {

    // Skip over empty cells.  The bound is tested first so that rate[size]
    // is never read when every cell is empty.
    int m;
    for (m=0; m<size && rate[m]==Inf; m++);
    if (m == size) { m=0; rate[0]=0; }

    bestm = m;
    PERF bestperf = cw[bestm]->test(db);
    bestperf.rate += logsize;
    for (m=bestm+1; m<size; m++) {
	if (rate[m]==Inf) continue;	// empty cell
	PERF perf = cw[m]->test(db);
	perf.rate += logsize;
	if (perf<bestperf) { bestm=m; bestperf=perf; }
	}
    return bestperf;
    }

/*
 *  Train an STDVQ on the data by delegating to FullSearchVQ::train with
 *  empty-cell reseeding enabled.
 */
void STDVQ::train(const DATASEQ& ds, double precision) {

    const int fillEmptyCells = 1;
    FullSearchVQ::train(ds, precision, fillEmptyCells);
    }

/*
 *  Return a heap-allocated deep copy of this STDVQ (caller owns it).
 */
VQ* STDVQ::copy() const {

    STDVQ* duplicate = new STDVQ(*this);
    return duplicate;
    }

/*
 *  Output an STDVQ: optional "# entropy" / "# fullcnt" comment lines
 *  (written only when the value is nonzero), the "STDVQ " keyword, then
 *  the FullSearchVQ body.
 */
void STDVQ::put(ostream& os) const {

    if (entropy != 0) os << "# entropy = " << entropy << "\n";
    if (fullcnt != 0) os << "# fullcnt = " << fullcnt << "\n";
    os << "STDVQ ";
    FullSearchVQ::put(os);
    }

/*
 *  Build a TREENODE of the type named by the next token of an input stream.
 *
 *  A token starting with '#' begins a comment line; the remainder of that
 *  line is discarded.  Only the "NTN" keyword is accepted; anything else is
 *  reported through Error() and NULL is returned.
 */
TREENODE* new_TREENODE(istream& is) {

    char token[256];
    for (is >> token; token[0] == '#'; is >> token)
	is.get(token, 256);		// drop the rest of the comment line

    if (strcmp(token, "NTN") != 0) {
	Error(token, " Unknown TREENODE");
	return NULL;
	}
    return new NTN(is);
    }

/*
 *  Construct a TREENODE from another TREENODE.
 *  The STDVQ base copies the children; this node's own codeword is deep
 *  copied, and the per-node statistics and flags are copied by value.
 *  The training data (vqds) is not copied here.
 */
TREENODE::TREENODE(const TREENODE& treenode) : STDVQ(treenode) {

    // Flags.
    subtree = treenode.subtree;
    split = treenode.split;

    // Per-node statistics.
    nodepop = treenode.nodepop;
    nodeperf = treenode.nodeperf;
    deltaperf = treenode.deltaperf;
    lambda = treenode.lambda;
    lambdaopt = treenode.lambdaopt;

    // This node's own codeword (deep copy).
    nodecw = treenode.nodecw->copy();
    }

/*
 *  Destruct a TREENODE by releasing this node's own codeword.  Children are
 *  released by the base-class destructor (FullSearchVQ deletes cw[]).
 */
TREENODE::~TREENODE() {

    VQ* own = nodecw;
    delete own;
    }

void TREENODE::splitData() {

    // Split data among children.
    long totalpop = 0;
    for (int m=0; m<size; m++) {
	((TREENODE*)cw[m])->nodepop = 0;
	((TREENODE*)cw[m])->nodeperf = 0;
	((TREENODE*)cw[m])->subtree = 0;
	}
    while (vqds.head) {
	totalpop++;
	DATABLK* dbp = vqds.pop();
	PERF tmp_perf = nearest(*dbp,m);
	tmp_perf.rate -= logsize;	// don't include parent's rate
	((TREENODE*)cw[m])->nodepop++;
	((TREENODE*)cw[m])->nodeperf += tmp_perf;
	((TREENODE*)cw[m])->vqds.push(dbp);
	}
    nodepop = totalpop;

    // Compute rate and entropy.
    entropy = 0.0;
    fullcnt = 0;
    for (m=0; m<size; m++)
	if (((TREENODE*)cw[m])->nodepop) {
	    rate[m] = -log((double)((TREENODE*)cw[m])->nodepop/totalpop)/log(2.0);
	    entropy += rate[m]*((TREENODE*)cw[m])->nodepop;
	    fullcnt++;
	    }
	else rate[m] = Inf;
    if (totalpop) entropy /= totalpop;

    // Compute deltaperf.
    deltaperf = -nodeperf;
    for (m=0; m<size; m++)
	deltaperf += ((TREENODE*)cw[m])->nodeperf;
    deltaperf.rate += deltaRate() * totalpop;
}

/*
 *  Grow one level below this node and evaluate that split.
 *
 *  If the node has no children yet, default_size children are created as
 *  copies of this (still childless) node, each with rate logsize.  The
 *  children are then trained on vqds (when it holds data), the data is
 *  partitioned with splitData(), and lambda and lambdaopt are both set to
 *  -deltaperf.dist/deltaperf.rate, or to 0 when the rate change is not
 *  positive.
 */
void TREENODE::splitNode(double precision, int default_size) {

    if (size == 0) {
	cw = new VQ*[default_size];
	rate = new double[default_size];
	logsize = log(default_size)/log(2.0);
	for (int k=0; k<default_size; k++) {
	    cw[k] = (TREENODE*) copy();	// copied while size is still 0,
	    rate[k] = logsize;		// so each child starts childless
	    }
	size = default_size;
	}

    subtree = 1;
    if (vqds.head != 0)
	train(vqds,precision);

    splitData();
    double slope = 0;
    if (deltaperf.rate > 0)
	slope = -deltaperf.dist / deltaperf.rate;
    lambda = lambdaopt = slope;
    }

/*
 *  Perform the best available split in the subtree rooted at this node.
 *
 *  Already split: recurse into the child whose lambdaopt equals this node's
 *  lambdaopt, then recompute lambdaopt as the maximum over the children and
 *  return the child's result.  Not yet split: mark this node split, evaluate
 *  a split of every child via splitNode(), set lambdaopt to the largest
 *  child lambdaopt, and return this node's own deltaperf.
 *  Returns 0 (converted to PERF) when lambdaopt is 0 or no child matches.
 */
PERF TREENODE::doBestSplit(double precision) {

    if (split) {
        if (lambdaopt == 0) return 0;   // when is this used?

        // Locate the child carrying the best slope.  The bound keeps a
        // mismatch from running past the end of cw[]; m is declared outside
        // the loop because it is used after it.
        int m;
        for (m=0; m<size && lambdaopt!=((TREENODE*)cw[m])->lambdaopt; m++);
        if (m == size) return 0;

        PERF delta = ((TREENODE*)cw[m])->doBestSplit(precision);
        for (m=0; m<size; m++)
            if ((m==0) || ((TREENODE*)cw[m])->lambdaopt > lambdaopt)
                lambdaopt = ((TREENODE*)cw[m])->lambdaopt;
        return delta;
        }
    else {
        split = 1;
        for (int m=0; m<size; m++) {
            ((TREENODE*)cw[m])->split = 0;
            ((TREENODE*)cw[m])->splitNode(precision,size);
            if ((m==0) || ((TREENODE*)cw[m])->lambdaopt > lambdaopt)
                lambdaopt = ((TREENODE*)cw[m])->lambdaopt;
            }
        return deltaperf;
        }
    }

/*
 *  Remove leaves from unsplit nodes.
 *
 *  Recurses through split nodes.  At every unsplit node, any children are
 *  destroyed (recursively), and cw, rate and size are reset so the node
 *  becomes a leaf.
 */
void TREENODE::delete_unsplit() {

    if (split) {
	for (int m=0; m<size; m++) 
	    ((TREENODE*)cw[m])->delete_unsplit();
	}
    else if (size>0) {
	for (int m=0; m<size; m++)
	    delete cw[m]; // recursive destruct
	// cw and rate were allocated with new[], so they need delete[].
	delete [] cw;
	delete [] rate;
	cw = 0;
	rate = 0;
	size = 0;
	}
    }
 
/*
 *  Remove training sequence from nodes.
 */
void TREENODE::delete_vqds() {

    // Empty vqds.
    while (vqds.head) delete vqds.pop();

    // Repeat for each subtree.
    for (int m=0; m<size; m++)
	((TREENODE*)cw[m])->delete_vqds();
    }

/*
 *  Count the codewords strictly below this node: its direct children plus
 *  all of their descendants (this node itself is not counted).
 */
long TREENODE::countCWs() {

    long total = size;					// direct children
    for (int k=0; k<size; k++)
	total += ((TREENODE*)cw[k])->countCWs();	// their descendants
    return total;
    }

/*
 *  Train a TREENODE on the data.
 *  As a leaf (no children, or subtree unset) only this node's codeword is
 *  trained, on the data re-blocked to width.  Otherwise every child's
 *  subtree flag is cleared and the children are trained as an STDVQ.
 */
void TREENODE::train(const DATASEQ& ds, double precision) {

    int actsAsLeaf = (size==0 || subtree==0);
    if (!actsAsLeaf) {
	for (int k=0; k<size; k++)
	    ((TREENODE*)cw[k])->subtree = 0;
	STDVQ::train(ds,precision);
	return;
	}
    nodecw->train(DATASEQ(ds,width),precision);
    }


/*
 *  Input a TREENODE.
 *
 *  Reads width and size (size == 0 marks a leaf).  For an internal node it
 *  then reads M followed by M (child, rate) pairs, with children built by
 *  new_TREENODE(); children M..size-1 are copies of child 0 with rate[0],
 *  and logsize is set to log2(size).  Finally this node's own codeword is
 *  read with new_VQ().  entropy and fullcnt are reset to 0.
 *  NOTE(review): logsize is left unchanged for leaves.
 */
void TREENODE::get(istream& is) {

    is >> width;
    is >> size;		// size == 0 for leaf

    if (size > 0) {
	cw = new VQ*[size];
	rate = new double[size];

	// Next M entries specify the codewords.
	int M;
	is >> M;
	if (M<=0 || M>size) Error("bad format for TREENODE");

	// m is shared by both loops: the second resumes where the first
	// stopped.  (Standard C++ ends a for-init variable's scope with the loop.)
	int m;
	for (m=0; m<M; m++) {
	    cw[m] = new_TREENODE(is);
	    is >> rate[m];
	    }

	// Remaining codewords are copies of codeword 0.
	for (; m<size; m++) {
	    cw[m] = cw[0]->copy();
	    rate[m] = rate[0];
	    }

	logsize = log(size)/log(2.0);
	}
    else {
	cw = 0;
	rate = 0;
	}

    entropy = 0.0;
    fullcnt = 0;

    nodecw = new_VQ(is);
    }

/*
 *  Output a TREENODE body.  An internal node writes " size", then for each
 *  child a "# prob = 2^-rate" comment, the child and its rate.  A leaf
 *  writes only a newline.  In both cases this node's own codeword follows.
 */
void TREENODE::put(ostream& os) const {

    if (size <= 0) {
	os << "\n";
	}
    else {
	os << " " << size << "\n";
	int k = 0;
	while (k < size) {
	    double prob = pow(2.0,-rate[k]);
	    os << "# prob = " << prob << "\n";
	    cw[k]->put(os);
	    os << rate[k] << "\n";
	    ++k;
	    }
	}

    nodecw->put(os);
    }

/*
 *  Per-vector change in rate charged for splitting this node: 1/nodepop.
 *  NOTE(review): nodepop == 0 yields an infinite result.
 */
double NTN::deltaRate() const {

    double population = nodepop;
    return 1.0/population;
    }

/*
 *  Test an NTN on a data block (re-blocked to width).
 *  If the node has children and subtree is set, each piece is routed to its
 *  nearest child.  That child alone gets subtree=1 and is tested, and logsize
 *  is added to the rate.  Otherwise the node acts as a leaf and its own
 *  codeword is tested.
 */
PERF NTN::test(const DATABLK& db) const {

    PERF perf(0,0);
    DATASEQ ds(db,width);
    if (size != 0 && subtree) {
        for (DATABLK* blk=ds.head; blk; blk=blk->next) {
            int choice;
            for (int k=0; k<size; k++)
                ((TREENODE*)cw[k])->subtree = 0;
            nearest(*blk,choice);
            ((TREENODE*) cw[choice])->subtree = 1;
            perf += cw[choice]->test(*blk);
            perf.rate += logsize;
            }
        }
    else {
        perf = nodecw->test(ds);
        }
    return perf;
    }

/*
 *  Return a heap-allocated deep copy of this NTN (caller owns it).
 */
VQ* NTN::copy() const {

    NTN* duplicate = new NTN(*this);
    return duplicate;
    }

/*
 *  Output an NTN: the "NTN" keyword, width and size, followed by the
 *  TREENODE body.  (The STDVQ-style entropy/fullcnt comment lines are not
 *  written for NTN nodes.)
 */
void NTN::put(ostream& os) const {

    os << "NTN " << width << " " << size;
    TREENODE::put(os);
    }

/*
 *  Construct a TreeStructVQ as a deep copy of another: the whole tree is
 *  duplicated via the root's copy().
 */
TreeStructVQ::TreeStructVQ(const TreeStructVQ& treestructvq) {

    VQ* duplicate = treestructvq.root->copy();
    root = (TREENODE*)duplicate;
    constraint = treestructvq.constraint;
    }

/*
 *  Destruct a TreeStructVQ; deleting the root destroys the whole tree.
 */
TreeStructVQ::~TreeStructVQ() {

    TREENODE* top = root;
    delete top;
    }

/*
 *  Test a TreeStructVQ on a data block.  The root's subtree flag is set so
 *  the root is evaluated as a subtree rather than as a leaf.
 */
PERF TreeStructVQ::test(const DATABLK& db) const {

    TREENODE* top = root;
    top->subtree = 1;
    return top->test(db);
    }

/*
 *  Input a TreeStructVQ: the growth constraint (the original notes
 *  "depth for TSVQ"), followed by the root node.
 */
void TreeStructVQ::get(istream& is) {

    is >> constraint;
    TREENODE* top = new_TREENODE(is);
    root = top;
    }

/*
 *  Output a TreeStructVQ: the constraint, an informational codeword count,
 *  then the tree.
 *
 *  The count line is written as a '#' comment.  new_TREENODE(), called by
 *  get(), skips '#' comment lines.  Without the prefix it would read
 *  "Total_codewords" as a node type and reject the file, so output from
 *  put() could not be read back.
 */
void TreeStructVQ::put(ostream& os) const {

    os << constraint << "\n";
    // The + 1 counts the root cw itself; countCWs() counts only nodes below it.
    os << "# Total_codewords " << root->countCWs() + 1 << " codewords.\n";
    root->put(os);
    }

/*
 *  Train a MTSVQ (on the data).
 *
 *  Greedy tree growing.  The training data (re-blocked to the root's width)
 *  is appended to root->vqds and the root's own codeword is trained on it.
 *  One level of children is then evaluated with splitNode(), and
 *  doBestSplit() is called repeatedly, adding each returned deltaperf to
 *  perf.  Growth stops when doBestSplit() returns a non-positive rate change,
 *  or when the countCWs()-based size test against constraint fails (there is
 *  no growth at all when constraint < 0).  Finally, children of unsplit nodes
 *  are removed and all stored training data is freed.
 *
 *  NOTE(review): N is not declared in the code visible in this file; it is
 *  only used by the lastrate bookkeeping, whose output is commented out.
 */
void MTSVQ::train(const DATASEQ& ds, double precision) {

    root->vqds += DATASEQ(ds,root->width);

    root->nodecw->train(root->vqds,precision);
    root->nodeperf = root->nodecw->test(root->vqds);
    root->splitNode(precision);
    // Cleared so that the first doBestSplit() on the root takes its
    // "not yet split" branch and evaluates splits of the new children.
    root->split = 0;

    PERF perf = root->nodeperf;
    double currentlambda = Inf;
    double lastrate = -1.0;

    int done = 0;
    while (!done &&
          (constraint>=0) && ((((root->countCWs()+2)/2)+1)/2 < constraint)) {
        // Progress reporting (output disabled): report roughly every 0.1 of rate.
        if (perf.rate/N-lastrate >= .1) {
            // cerr << "Growing Distortion: " << perf.dist/N;
            // cerr << " Rate: " << perf.rate/N;
            // cerr << " Lambda: " << currentlambda << "\n";
            lastrate = perf.rate/N;
            }

        PERF deltaperf = root->doBestSplit(precision);
        // cerr << " NUmber: " << (((root->countCWs()+2)/2)+1)/2 << "\n";
        if (deltaperf.rate > 0) {
            perf += deltaperf;
            currentlambda = -deltaperf.dist / deltaperf.rate;
            }
        else done = 1;  // happens when all leaves are pure
        }

    // cerr << "Growing Distortion: " << perf.dist/N;
    // cerr << " Rate: " << perf.rate/N;
    // cerr << " Lambda: " << currentlambda << "\n";

    root->delete_unsplit();
    root->delete_vqds();
    }

/*
 *  Return a heap-allocated deep copy of this MTSVQ (caller owns it).
 */
VQ* MTSVQ::copy() const {

    MTSVQ* duplicate = new MTSVQ(*this);
    return duplicate;
}

/*
 *  Output a MTSVQ: the "MTSVQ " type keyword recognized by new_VQ(),
 *  followed by the TreeStructVQ body.
 */
void MTSVQ::put(ostream& os) const {

    const char* keyword = "MTSVQ ";
    os << keyword;
    TreeStructVQ::put(os);
}



