/*_____________________________________________________________________
Learning decision tree based agent. 
(c) Attila Szalkai, 2006.06.18 White Stone Studio

The decision tree generation is iterative, not recursive.
Samples can be added at run time.

Interface description:
----------------------

Agent a;
a.init( nofInputs, nofActs, maxBranch, maxSamples, maxValues );
a.setSamples( ptr, nofSamples );
a.addSample( ptr );
a.genTree( act );
a.genTrees();
a.decide( sample, act );

_______________________________________________________________________

Example:
--------
Agent a;
short ti[6][7] = {
    0,1,1,0,1,  0, 1, 
    0,1,0,1,0,  1, 1,
    1,1,0,1,0,  0, 0,
};
short maxvs[] = { 2,2,2,2,2,  2, 2 };

void main() {
    a.init(5,2,200,100, maxvs);
    a.setSamples( ts, 1);            

    a.addSample( ti[0] );    
    a.genTrees();
    int ret = a.decide( &a.samples[0], 1);

    a.addSample( ti[1] );    
    a.genTrees();
    int ret = a.decide( &a.samples[1], 0);        
}
_______________________________________________________________________*/

// Debug-trace switch: when non-zero, tree generation and decide() print
// their steps with printf (and genBranch waits for a key press via getch).
int dlog = 0;
// Capacity of the per-value arrays (Branch::cases, Division::samplen, Division::numAct).
#define MAXCASE 8
// Enables the self-test section guarded by #ifdef LDTREE_TEST at the end of this file.
#define LDTREE_TEST

#include <stdio.h>
#include <conio.h>
#include <stdlib.h>
#include <stdio.h>

//------------------------------------
// Sample: properties, actions.
// Both pointers refer into a caller-supplied array of shorts; the data is
// not copied. Agent::setSamples / Agent::addSample set `is` to the start of
// a record and `as` to the position `inpn` shorts after it.
//------------------------------------
struct Sample {
    short *is;    // input (property) values, indexed by property number
    short *as;    // action values, indexed by action number
};


//------------------------------------
// Node in the decision tree.
// A node tests one input property and has one outgoing case per value.
//------------------------------------
struct Branch {
    short prop;             // index of the input property tested at this node
    short cases[MAXCASE];   // per property value: >0 = index of the child node in Agent::tree,
                            // <=0 = leaf, the decided act value is the negated entry
};


//----------------------------------------------------------------------
// A property by its N possible values divides the samples to N groups.
// Filled by Agent::sampleDivision for one property and one act.
//----------------------------------------------------------------------
struct Division {
    short total;                // number of samples counted into this division
    short samplen[MAXCASE];     // per property value: number of counted samples having that value
    short numAct[MAXCASE];      // per act value: number of counted samples having that act value

    void  init();               // reset all counters to zero
    short mostAct(short maxv);  // act value (below maxv) with the highest count in numAct
    short count(int maxv);      // size of the largest group among samplen[0..maxv)
};


//----------------------------------------------------------------------
// Iterative instead of recursive. This is the stack.
// One entry per level on the path from the root to the node being built.
//----------------------------------------------------------------------
struct Subst {
    short prop;     // property tested at this level
    short val;      // value (case) of that property currently being processed at this level
    short node;     // index of this level's node in the tree; stack[sp-1].node is always the parent
};


//----------------------------------------------------------------------
// The decision tree is packed within an agent.
// An agent holds the training samples and one decision tree per action;
// all trees live in the single `tree` array. init() must be called before
// any other member, because it allocates the arrays.
//----------------------------------------------------------------------
struct Agent {
    Sample *samples;        // sample table, allocated by init() with maxSample entries
    int    samplen;         // number of samples currently stored
    int    maxSample;       // capacity of `samples`
    int    addSample( short *d );               // append one sample record; 0 on success, -1 if rejected
    void   setSamples( short *d, int n );       // replace the sample set with n consecutive records
    bool   filterSample(Sample *s, int sp);     // does s match the property values on stack[0..sp)?
    int    sampleDivision(int pi, int act, int sp, Division *d);    // count matching samples by value of property pi

    int    inpn;            // number of input properties per sample
    int    actn;            // number of actions per sample
    int   *acts;            // per action: index of the root node of its tree, as returned by genTree
    short *maxvs; // max value of input or action: values of property p run from 0 to maxvs[p]-1
    void   init(int maxInp, int maxAct, int maxBranch, int maxSamp, short *mvs);    // allocate buffers, reset counters

    Branch *tree;           // node storage for all trees, allocated by init() with maxBranch entries
    int    branchn;         // number of nodes used in `tree`
    int    genBranch(int &sp, int act);     // build one node or leaf; 1 = node pushed on the stack, 0 = leaf reached
    int    genTree(int act);                // build the tree of one action; returns the index of its root
    void   genTrees();                      // build the trees of all actions
    void   logTree(int n, int e);           // print a range of tree nodes
    int    decide(Sample *s, int n);        // walk the tree of action n for sample s; returns the act value
};



//========================= DIVISION =========================

//--------------------------------------------------
// Init division.
// Clears the sample total and every per-value counter so the division can
// be refilled from scratch (see Agent::sampleDivision).
//--------------------------------------------------
void Division::init()
{
    total = 0;
    // A single loop with its own index: the original reused the index of a
    // previous for-statement, which is out of scope in standard C++.
    for (int i = 0; i < MAXCASE; i++) {
        samplen[i] = 0;
        numAct[i]  = 0;
    }
}

//---------------------------------------------------------
// The act value shared by the most samples.
// Looks at numAct[0..maxv) and returns the index holding the highest
// count; on a tie the lowest index wins. Returns -1 when maxv <= 0.
//---------------------------------------------------------
short Division::mostAct(short maxv)
{
    if (maxv <= 0)
        return -1;

    int winner = 0;
    for (int candidate = 1; candidate < maxv; ++candidate) {
        // Strictly greater, so the earliest maximum is kept.
        if (numAct[candidate] > numAct[winner])
            winner = candidate;
    }
    return winner;
}


//---------------------------------------------------------
// How big is the biggest sample group in the current division.
// Returns the maximum of samplen[0..maxv), never less than 0.
//---------------------------------------------------------
short Division::count(int maxv)
{
    int largest = 0;
    int value = 0;
    while (value < maxv) {
        const int groupSize = samplen[value];
        largest = (groupSize > largest) ? groupSize : largest;
        ++value;
    }
    return largest;
}



//================================== AGENT =========================================

// Work variables of tree generation.
// Explicit stack used instead of recursion: entry i describes level i of the
// path from the root to the node currently being built (see Subst).
// It is a file-scope static, so it is shared by every Agent in this file.
// NOTE(review): capacity is fixed at 100 levels and no code here checks it.
static Subst stack[100];    

//--------------------------------------------------
// Use the tree to decide action for a situation.
// The situation is represented as a sample.
//
// s  : sample whose input values (s->is) are looked up
// ai : index of the action whose tree is walked (root is acts[ai])
//
// Returns the act value stored in the leaf that was reached, or -1 when the
// walk leaves the generated tree (node index outside [0, branchn) or an
// input value outside [0, MAXCASE)).
//--------------------------------------------------
int Agent::decide(Sample *s, int ai)
{
    int bi = acts[ai];
    if (dlog) {
        // Trace every input of the sample; the original printed exactly five.
        for (int k = 0; k < inpn; k++)
            printf("%d ", s->is[k]);
        printf("  ");
    }

    // Only indices below branchn refer to generated nodes.
    while (bi >= 0 && bi < branchn) {
        // Go further on the case selected by the sample
        Branch *b = &tree[ bi ];
        short pi = b->prop;
        short val= s->is[ pi ];
        if (dlog) printf("%d)%d=%d ",bi, pi, val);

        // A value outside the cases array cannot select a case.
        if (val < 0 || val >= MAXCASE)
            break;
        bi = b->cases[ val ];

        // Leaf: the entry is the negated act value.
        if (bi<=0) {
            if (dlog) printf("%d. ", -bi);
            return -bi;
        }
    }

    // There was no leaf on decision tree for this sample
    return -1;
}


//----------------------------------------------------------------------
// Add a new sample on the fly.
// To take effect, the tree must be regenerated by genTrees().
//
// d : record of inpn input values followed by actn action values. Only the
//     pointers are stored, so the record must stay alive while it is used.
//
// Returns 0 on success, -1 when d is null or the sample table is full.
//----------------------------------------------------------------------
int Agent::addSample( short *d )
{
    // The table holds maxSample entries, so the last usable index is
    // maxSample-1; the original test rejected that slot as well.
    if (!d || samplen >= maxSample)
        return -1;

    samples[samplen].is = d;
    samples[samplen].as = d + inpn;
    samplen++;
    return 0;
}


//----------------------------------------------------------------------
// Set the initial samples if there is
//----------------------------------------------------------------------
void Agent::setSamples( short *d, int n )
{    
    for(int i=0; i<n; i++) {
        samples[i].is = d; d += inpn;        
        samples[i].as = d; d += actn;
    }    
    samplen = n;    
}

//----------------------------------------------------------------------
// Filter out samples that do not belong to this branch.
// Returns true when, for every stack level below sp, the sample's value of
// that level's property equals the value stored on the stack.
// With sp <= 0 every sample passes.
//----------------------------------------------------------------------
bool Agent::filterSample(Sample *s, int sp)
{
    for (int level = 0; level < sp; ++level) {
        const Subst &fixed = stack[level];
        if (s->is[ fixed.prop ] != fixed.val)
            return false;       // first mismatch decides
    }
    return true;
}

//-----------------------------------------------------------------------
// RETURNS:
// total:      How many samples depends on this property
// samplen[N]: How many samples need this propery to be the value n
//-----------------------------------------------------------------------
int Agent::sampleDivision(int pi, int act, int sp, Division *d)
{           
    // Looks through all the samples regarding to this property    
    d->init();
    Sample *s = samples;
    for(int si=0; si<samplen; si++, s++) if (s->as[act]<MAXCASE)
    {       
        // Filer out samples of not this branch
        if (!filterSample(s,sp)) 
            continue;

        // Not regarded and counts
        d->total++;
        d->samplen[ s->is[pi] ]++;

        // Most value of act
        d->numAct[ s->as[act] ]++;
    }
    if  (!d->total) return -1;

    // The value of the act really depends on this property.
    return d->count( maxvs[pi] );
}


//---------------------------------------------------------------------------------------
// Extend a subtree.
// The samples are divided to MAXCASE amount of groups by value of a property.
// The stack stores the current values of already regarded properties leading
// to this node in the decision tree.
// The order of the props can be different on different branches of the tree.
//
// sp  : current stack depth; incremented when a new node is pushed
// act : action whose tree is being generated
//
// Returns 1 when a new node was created and pushed on the stack (the caller
// goes on with its case 0), 0 when this position became a leaf.
//
// Changes against the original:
//  - loop indices are declared in the scope where they are used (the original
//    relied on pre-standard for-scope rules and does not compile as standard C++);
//  - the number of act values is read from maxvs[inpn + act], since maxvs holds
//    the input entries first and the action entries after them;
//  - when no unused property is left, the act counts are recounted for the
//    samples of this branch instead of reusing the static bestDiv of an
//    earlier call.
//---------------------------------------------------------------------------------------
int Agent::genBranch(int &sp, int act)
{
    // Mark already considered props
    // NOTE(review): done[] holds 1000 entries and is not checked against inpn.
    static int done[1000];
    int i;
    for(i=0; i<inpn; i++) done[i]=0;
    for(i=0; i<sp; i++) done[ stack[i].prop ]=1;

    // Has just arrived to this node to discover. sp is not yet processed.
    static int kkk=0;
    if (dlog) {
        printf( "%d Subtree sp=%d ", kkk++, sp);
        for(i=0; i<sp; i++)
            printf( "[%d]=%d ", stack[i].prop, stack[i].val);
        for(i=0; i<inpn; i++) printf( "%d:%d ", i, done[i]);
        printf( "\n");
    }

    // Find the most balancing property: the one whose largest value group is the smallest
    static Division div,bestDiv;
    int bestProp=-1;
    int bestDif =0;
    for(int prop=0; prop<inpn; prop++) {
        if (done[prop])
            continue;

        int dif = sampleDivision(prop, act, sp, &div);
        if (dif>=0 && (bestProp<0 || dif<bestDif)) {
            bestProp = prop;
            bestDif  = dif;
            bestDiv  = div;
        }
        if (dlog) printf( "  chk prop %d  total=%d s#0=%d s#1=%d\n", prop, div.total, div.samplen[0], div.samplen[1] );
    }

    // No property was chosen: either every property is already on the stack or no
    // sample reaches this branch. bestDiv is static and still holds an earlier
    // call's counts, so recount the act values of the samples that do reach this
    // branch (the counts do not depend on which property is passed). If there are
    // none, the previous contents are kept, as in the original.
    if (bestProp<0 && inpn>0 && sampleDivision(0, act, sp, &div)>=0)
        bestDiv = div;

    // If all samples would go to the same branch, there is no branch, the result does not depend on the property.
    // This is the worst prop. Special case, if not all the value of the act is the same, return the most popular.
    if (bestProp<0 || bestDif == bestDiv.total)
    {
        int bestAct = bestDiv.mostAct( maxvs[inpn + act] );
        if (sp) {
            // Store the leaf in the parent's current case as a negated act value
            short parent = stack[sp-1].node;
            short val    = stack[sp-1].val;
            tree[ parent ].cases[ val ] = -bestAct;
            if (dlog) printf( "of %d case %d is\n", parent, val);
        }
        // NOTE(review): with sp==0 nothing is stored, so this act gets no node at all.
        if (dlog) printf( "Leaf act=%d   a0:%d  a1:%d \n", bestAct, bestDiv.numAct[0], bestDiv.numAct[1] );
        if (dlog && 'q'==getch()) { if (dlog) logTree(0,branchn); exit(0); }
        return 0;
    }

    // Link this branch to parent in tree
    // NOTE(review): branchn is not checked against the size allocated in init().
    Branch *b = &tree[ branchn++ ];
    if (sp) {
        short parent = stack[sp-1].node;
        short val    = stack[sp-1].val;
        tree[ parent ].cases[ val ] = branchn-1;
        if (dlog) printf( "branch %d linked to %d case:%d\n", branchn-1, parent, val);
    }

    // The best dividing property. Fill this new branch and push it as the new top level
    b->prop = bestProp;
    stack[sp].node = branchn-1;
    stack[sp].prop = bestProp;
    stack[sp].val  = 0;
    sp++;

    // We know which remaining property divides the remaining samples the best way.
    if (dlog) printf( "Chosen prop=%d\n", bestProp);
    if (dlog && 'q'==getch()) { if (dlog) logTree(0,branchn); exit(0); }
    return 1;
}


//------------------------------------------------------------
// Beginning with most dividing properties divide samples
// to create one branch of the tree
//------------------------------------------------------------
int Agent::genTree(int act)
{    
    int ti=branchn;
    int sp = 0;
        
    // Loops on properties until there is a sample dividing property
    while(1)
    {
        // Process all cases on the branch. Step forward. It is 1 while it could do so.
        // (In the samples values of properties are the same as stored on stack before sp.)
        while( genBranch(sp,act) );
        
        // Leaf is reached. Loop to step back. 
        while(1)
        {
            // This node ended to leaf. See parent's next branch
            if (dlog) printf( "Step back .. " );                        
            if (--sp<0) 
                return ti;     

            stack[sp].val++;            
            if (dlog) printf( "next sp:%d  prop:%d  val=%d\n", sp, stack[sp].prop, stack[sp].val);
            if ( stack[sp].val < maxvs[ stack[sp].prop ]) {                
                // Go to genBranchre having the current property and value on top of stack.
                sp++;
                break;
            }            
        }
    }
    return -1;
}


//------------------------------------------------------------
// Generate trees of each action of samples.
// Every call rebuilds all trees from the current sample set and stores the
// root index of each action's tree in acts[].
//------------------------------------------------------------
void Agent::genTrees()
{
    // Start again from an empty node array. The original kept appending, so
    // every regeneration left the previous trees behind as dead nodes and the
    // array would eventually be overrun after repeated calls.
    branchn = 0;

    for (int i = 0; i < actn; i++)
        acts[i] = genTree(i);

    if (dlog) {
        // The nodes of action i end where those of action i+1 begin.
        for (int i = 0; i < actn; i++)
            logTree(acts[i], i<actn-1 ? acts[i+1] : branchn );
    }
}


//------------------------------------------------------------
// Log a part of the tree on console
//------------------------------------------------------------
void Agent::logTree(int n, int e)
{    
    printf("\nTree\n");
    for(int j=n; j<=e; j++) {
        printf("%3d on i[%d] ", j, tree[j].prop  );        
        for(int i=0; i<2; i++) {
            int v = tree[j].cases[i];
            if (v>0)         
                printf("%d=>goto %d ",i, v);
            else
                printf("%d=>ret %d ",i, -v);
        }
        printf("\n");
    }
}


//------------------------------------------------------------
// Allocate the agent's buffers and reset its counters.
//
// maxInp    : number of input properties per sample
// maxAct    : number of actions per sample
// maxBranch : number of tree nodes to allocate
// maxSamp   : number of sample slots to allocate
// mvs       : value-count table; the pointer is stored, not copied
//
// The arrays are value-initialized (zero filled), so root indices and node
// cases that have not been written yet read as 0 rather than garbage.
// NOTE(review): arrays from an earlier init() call are not released here.
//------------------------------------------------------------
void Agent::init(int maxInp, int maxAct, int maxBranch, int maxSamp, short *mvs )
{
    maxvs= mvs;
    inpn = maxInp;
    actn = maxAct;
    acts = new int[ actn ]();
    branchn=0;
    tree = new Branch[ maxBranch ]();
    samples = new Sample[ maxSamp ]();
    samplen = 0;
    maxSample = maxSamp;
}


//------------------------------------------------------------
// Generate cpp program from the generated tree. Just for fun.
// Writes "gen.cpp" in the current directory: one function D<i> per action,
// with one labelled line per tree node that tests cases 0 and 1 of the
// node's property.
// If the file cannot be opened, the reason is reported with perror and
// nothing is written (the original passed the null FILE* on to fprintf).
//------------------------------------------------------------
void saveCpp(Agent *a) {
    FILE *fp = fopen("gen.cpp","wt");
    if (!fp) {
        perror("gen.cpp");
        return;
    }
    fprintf(fp,"int i[100];\n");

    for(int i=0; i<a->actn; i++)
    {
        fprintf(fp,"int D%d() {\n", i);
        // Nodes of action i: from its root up to the next action's root (or the end).
        int beg = a->acts[i];
        int end = i<a->actn-1 ? a->acts[i+1] : a->branchn;
        for(int j=beg; j<end; j++)
        {
            fprintf(fp,"    L%d:", j);
            for(int c=0; c<2; c++) {
                int v = a->tree[j].cases[c];
                if (v>0)
                    fprintf(fp,"if (i[%d]==%d) goto L%d; ", a->tree[j].prop,c, v);
                else
                    fprintf(fp,"if (i[%d]==%d) return %d; ", a->tree[j].prop,c, -v);
            }
            fprintf(fp,"\n");
        }
        fprintf(fp,"}\n\n");
    }
    fclose(fp);
}

//______________________________________________________________________________________

#ifdef LDTREE_TEST
// Agent shared by both tests below.
Agent a;
// Value counts handed to Agent::init: five input entries, then two action entries.
short maxvs[] = { 2,2,2,2,2,  2,2, };

// Sample record layout used by the tests: 5 input values, then 2 action
// values (matches init(5,2,...) in test() and testInc()).

//--------------------------- Test 1
// 13 sample records, all loaded at once by test().
short ts[] = {
    1,0,1,1,0,  1, 1, 
    1,0,1,1,0,  0, 0,
    1,0,1,1,0,  0, 0,
    0,0,0,1,1,  0, 0,
    0,1,1,0,1,  0, 1,
    0,1,1,1,0,  1, 1,
    1,1,1,0,0,  1, 0,

    0,1,1,0,1,  0, 1,
    0,1,0,1,0,  1, 1,
    1,1,0,1,0,  0, 0,
    0,1,0,1,1,  1, 0,
    0,1,1,1,1,  1, 1,
    0,0,1,1,0,  0, 1,

};

void test() {
    a.init(5,2,200, 100, maxvs );
    a.setSamples( ts, 13);
    a.genTrees();

    //Does it work well with own samples?        
    for(int o=0; o<a.actn; o++)
        for(int i=0; i<13; i++) {
            int ret = a.decide( &a.samples[i], o);
            int exp = a.samples[i].as[o];
            printf("test sample %d: %d expected: %d  %s\n", i, ret, exp, ret==exp ? "OK" : "ERR" );
        }
    saveCpp(&a);
}

//--------------------------- Test 2
// Record layout as in Test 1: 5 input values followed by 2 action values.

// The 7 initial sample records loaded by testInc().
short ts2[] = {
    1,0,1,1,0,  1, 1, 
    1,0,1,1,0,  0, 0,
    1,0,1,1,0,  0, 0,
    0,0,0,1,1,  0, 0,
    0,1,1,0,1,  0, 1,
    0,1,1,1,0,  1, 1,
    1,1,1,0,0,  1, 0,
};

// The 6 records that testInc() adds one at a time, one row per record.
short ts2i[6][7] = {
    0,1,1,0,1,  0, 1, 
    0,1,0,1,0,  1, 1,
    1,1,0,1,0,  0, 0,
    0,1,0,1,1,  1, 0,
    0,1,1,1,1,  1, 1,
    0,0,1,1,0,  0, 1,
};

// Test 2: incremental learning.
// Starts from the 7 records of ts2, then adds the 6 records of ts2i one by
// one (waiting for a key press before each), regenerating the trees after
// every addition and checking all samples held by the agent.
// The check runs over a.samplen samples; the original bound of 7+k left out
// the sample that had just been added.
void testInc() {
    a.init(5,2,200,100, maxvs );
    a.setSamples( ts2, 7);
    a.genTrees();

    for(int k=0; k<6; k++) {
        printf(" - add sample %d. Press any key.\n", k );
        getch();
        a.addSample( ts2i[k] );
        a.genTrees();

        //Does it work well with own samples?
        for(int i=0; i<a.samplen; i++) {
            printf("test sample %2d ",i);
            for(int o=0; o<a.actn; o++) {
                int ret = a.decide( &a.samples[i], o);
                int expected = a.samples[i].as[o];
                printf("a%d val:%d exp:%d %s  ", o, ret, expected, ret==expected ? "OK" : "ER" );
            }
            printf("\n");
        }
    }
}


// Entry point of the self-test build: runs the incremental test.
// main must return int in standard C++; `void main` is rejected by
// conforming compilers.
int main() {
    testInc();
    return 0;
}

#endif