/* LICENSE:
  =========================================================================
    CMPack'04 Source Code Release for OPEN-R SDK 1.1.5-r2 for ERS7
    Copyright (C) 2004 Multirobot Lab [Project Head: Manuela Veloso]
    School of Computer Science, Carnegie Mellon University
    All rights reserved.
  ========================================================================= */


// Ugly parser to encode C4.5 output trees in a more useful
// form. Note the use of the word ugly to describe it. But
// at least it lets us use C4.5 to make decision trees
// without redistributing any of the code that comes with it
// (i.e. we respect the license)

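// To use it, you need the comment-free <stem>.names file
// and the text representation of the pruned tree output
// by C4.5. CUT OUT EVERYTHING BUT THE TREE - the parser is
// crappy and only speaks tree.
//
// Usage: parse_tree <names file> <tree filename>
// The binary tree is written to out.bintree in the current
// directory.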

#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <algobase.h>

struct DTreeNode {

  // What a node is: an internal test (choice) or a terminal class (leaf).
  enum types { T_CHOICE, T_LEAF };

  // Comparison a choice node was parsed with. C4.5 does not always list the
  // "<=" branch first, so a node can arrive as a ">" test; normalizeTests()
  // later swaps such nodes' children so "<=" always goes left.
  enum tests { TEST_LTE, TEST_GT };

  // A fresh node is a "<=" choice node with no children and zeroed payload.
  DTreeNode()
    : type(T_CHOICE),
      test_value(0.0),
      test_variable(0),
      test_type(TEST_LTE),
      leaf_class(0),
      left(NULL),
      right(NULL) {}

  int type;            // one of enum types

  // Choice-node payload: features[test_variable] compared against
  // test_value using test_type. Once normalized, "<=" selects left.
  double test_value;
  int test_variable;   // index into var_names
  int test_type;       // one of enum tests

  // Leaf-node payload.
  int leaf_class;      // index into class_names

  DTreeNode *left;
  DTreeNode *right;
};

// Names loaded from the .names file by readVarNames(). A name's position in
// these arrays is the integer id that getVarIndex()/getClassIndex() return
// and that ends up in the output tree.
int num_vars = 0;
char **var_names = NULL;

int num_classes = 0;
char **class_names = NULL;

void readVarNames(char *names_file);
// Size of the file in bytes, or -1 if it cannot be opened.
long getFileLength(char *filename);
// Whole file as a NUL-terminated buffer allocated with new[] (caller
// delete[]s it), or NULL if it cannot be read.
char *slurp(char *filename, long *len_return);
// Index of the name in var_names / class_names, or -1 if unknown.
int getVarIndex(char *var_name);
int getClassIndex(char *class_name);

// getTree() starts a strtok() scan over the tree text; recursiveGetTree()
// continues that same scan with strtok(NULL, ...), one node per call.
DTreeNode *getTree(char *tree_file);
DTreeNode *recursiveGetTree(int depth);

/* Convert all choice nodes so that the <= branch is the
   left subtree.
*/
void normalizeTests(DTreeNode *head);

/* Write the tree in pre-order, one record per node. Values are written
   as raw in-memory ints and doubles (host byte order and host sizeof):

   int type (0 = choice node, 1 = leaf)

   if(leaf)

   int class: index 0 -> (num_classes - 1), in the order defined in .names

   else

   int test_variable: index 0 -> (num_vars - 1), in the order defined in .names
   double test_value: features[test_variable] <= test_value => left branch
   write left subtree
   write right subtree
*/
void writeTree(DTreeNode *head, char *filename);
void recursiveWriteTree(DTreeNode *head, FILE *outfile);

int main(int argc, char **argv) {

  if(argc!=3) {
    printf("Usage:\nparse_tree <names file> <tree filename>");
  }

  readVarNames(argv[1]);

  DTreeNode *the_tree = getTree(argv[2]);

  // Convert > nodes to <= nodes for consistency.
  normalizeTests(the_tree);

  writeTree(the_tree, "out.bintree");

  return 0;
}

/* One pass over the text of a C4.5 .names file. The text is tokenized
   in place with strtok(), so `buf` is modified.

   Class names come first. They end at a lone "." token or at a name with
   a trailing '.'. Every later token except "continuous." is taken as a
   variable name. The pass resets and recomputes num_classes and num_vars.

   With store == false, names are only counted and echoed for debugging,
   and the longest name length is tracked in *max_len. With store == true,
   names are copied into class_names[] / var_names[], which must already
   hold enough buffers of sufficient size. */
static void scanNamesBuffer(char *buf, const char *sep, bool store,
                            int *max_len) {
  bool finished_classes = false;
  num_vars = 0;
  num_classes = 0;

  for(char *token = strtok(buf, sep); token!=NULL;
      token = strtok(NULL, sep)) {

    if(finished_classes) {
      // Only continuous attributes are handled; skip the domain token.
      if(strcmp("continuous.", token)==0)
        continue;
      if(store) {
        strcpy(var_names[num_vars], token);
      } else {
        printf("Found variable: %s\n", token);
        int len = (int)strlen(token);
        if(len > *max_len)
          *max_len = len;
      }
      num_vars++;
      continue;
    }

    if(strcmp(".", token)==0) {
      finished_classes = true;
      continue;
    }

    // The class list may end as "<last class>." with no space before the
    // period. Strip the period and close the class list.
    size_t tok_len = strlen(token);   // strtok never yields an empty token
    if(token[tok_len - 1]=='.') {
      token[tok_len - 1] = 0;
      finished_classes = true;
    }

    if(store) {
      strcpy(class_names[num_classes], token);
    } else {
      printf("Found class: %s\n", token);
      int len = (int)strlen(token);
      if(len > *max_len)
        *max_len = len;
    }
    num_classes++;
  }
}

/* Make some attempt at parsing a C4.5 .names file.

   Expects a comma-separated list of class names ended by a period,
   followed by one "varname: continuous." line per variable. Comments and
   non-continuous domains are NOT handled. Fills class_names/num_classes
   and var_names/num_vars. If the file cannot be read, both counts are left
   at 0. */
void readVarNames(char *names_file) {

  // '\t' and '\r' are separators too, so tab-indented lines and CRLF line
  // endings do not leak into names (or hide "continuous.").
  char sep[] = " ,:\t\r\n";

  char *first_pass = slurp(names_file, NULL);
  if(first_pass==NULL) {
    printf("Unable to read names file %s\n", names_file);
    num_vars = 0;
    num_classes = 0;
    return;
  }

  // strtok() destroys its input, so keep a pristine copy for pass two.
  // Re-reading the file could give different text from what was counted.
  size_t text_len = strlen(first_pass);
  char *second_pass = new char[text_len + 1];
  memcpy(second_pass, first_pass, text_len + 1);

  // Pass one: count the names and find the longest, to size storage.
  int max_name_len = 0;
  scanNamesBuffer(first_pass, sep, false, &max_name_len);
  delete[] first_pass;

  // +1 leaves room for the terminating NUL that strcpy() writes.
  var_names = new char *[num_vars];
  class_names = new char *[num_classes];

  for(int i=0; i<num_vars; i++)
    var_names[i] = new char[max_name_len + 1];

  for(int i=0; i<num_classes; i++)
    class_names[i] = new char[max_name_len + 1];

  // Pass two: same walk over identical text, this time keeping the names.
  scanNamesBuffer(second_pass, sep, true, NULL);
  delete[] second_pass;
}

/* Return the size of `filename` in bytes, or -1 if it cannot be opened or
   its size cannot be determined. */
long getFileLength(char *filename) {
  FILE *infile = fopen(filename, "rb");

  if(infile==NULL)
    return -1;

  // Only trust ftell() when the seek to the end actually succeeded;
  // ftell() itself reports failure as -1.
  long retval = -1;
  if(fseek(infile, 0, SEEK_END)==0)
    retval = ftell(infile);

  fclose(infile);

  return retval;
}

/* Reads an entire file into a char buf allocated with new[]. Call
   delete[] to free the returned value. The buffer is always
   NUL-terminated after the bytes actually read.

   If len_return is non-NULL, it receives the number of bytes actually
   read, which can be less than the file size on a short read.
   Returns NULL if the file cannot be opened or its size determined.
*/
char *slurp(char *filename, long *len_return) {

  // Open once and measure on the same handle. Measuring via a separate
  // open races with the file changing (or vanishing) between the opens.
  FILE *infile = fopen(filename, "rb");

  if(infile==NULL)
    return NULL;

  long length = -1;
  if(fseek(infile, 0, SEEK_END)==0)
    length = ftell(infile);

  if(length<0 || fseek(infile, 0, SEEK_SET)!=0) {
    fclose(infile);
    return NULL;
  }

  char *retval = new char[length + 1];

  size_t got = fread(retval, 1, (size_t)length, infile);
  if(got!=(size_t)length)
    printf("Did not read full file.\n");

  fclose(infile);

  retval[got] = 0; // terminate after what was really read

  if(len_return!=NULL)
    *len_return = (long)got;

  return retval;
}

/* Return the index of var_name in var_names[], or -1 if it is not a known
   variable. A NULL name also gives -1; strtok() produces one at the end of
   the input, and passing it to strcmp() would be undefined behavior. */
int getVarIndex(char *var_name) {

  if(var_name==NULL)
    return -1;

  for(int i=0; i<num_vars; i++)
    if(!strcmp(var_name, var_names[i]))
      return i;

  return -1;
}

/* Return the index of class_name in class_names[], or -1 if it is not a
   known class. A NULL name also gives -1; strtok() produces one at the end
   of the input, and passing it to strcmp() would be undefined behavior. */
int getClassIndex(char *class_name) {

  if(class_name==NULL)
    return -1;

  for(int i=0; i<num_classes; i++)
    if(!strcmp(class_name, class_names[i]))
      return i;

  return -1;
}

/* Parse the C4.5 tree text in tree_file into a tree of DTreeNodes
   allocated with new. Variable and class names must already have been
   loaded by readVarNames().

   Returns NULL if the file cannot be read or contains no tokens;
   otherwise returns whatever recursiveGetTree() builds for the root. */
DTreeNode *getTree(char *tree_file) {

  char sep[] = "\n :";

  char *text = slurp(tree_file, NULL);
  if(text==NULL) {
    printf("Unable to read tree file %s\n", tree_file);
    return NULL;
  }

  // Nothing but separators means there is no node to parse.
  if(text[strspn(text, sep)]=='\0') {
    printf("Tree file %s contains no nodes\n", tree_file);
    delete[] text;
    return NULL;
  }

  // recursiveGetTree() reads every token with strtok(NULL, ...), so the
  // scan must already be started on our buffer. Prefix a throwaway token
  // and consume it here. strtok is then positioned on the real root, so
  // the root goes through the same code as every other node. That also
  // covers a tree that is a single leaf.
  size_t text_len = strlen(text);
  char *buf = new char[text_len + 3];
  buf[0] = '#';
  buf[1] = '\n';
  memcpy(buf + 2, text, text_len + 1);
  delete[] text;

  strtok(buf, sep);  // yields the "#" placeholder

  // Depth 0 for the root; its subtrees print at depth 1.
  DTreeNode *retval = recursiveGetTree(0);

  // Nodes store only indices and atof() values, never token pointers, so
  // the buffer can be released. strtok's saved position now dangles; the
  // scan is not resumed after parsing.
  delete[] buf;

  return retval;
}

DTreeNode *recursiveGetTree(int depth) {

  char *token;
  char sep[] = "\n :";
  DTreeNode *retval = new DTreeNode();

  token = strtok(NULL, sep);

  // What type of node are we?
  if(getVarIndex(token)!=-1) {
    // We're a choice node
    retval->type = DTreeNode::T_CHOICE;

    // Parse our left subtree
    retval->test_variable = getVarIndex(token);

    // DEBUG
    for(int i=0; i<depth; i++)
      printf(" ");
    printf("%s ", token);

    token = strtok(NULL, sep);
    if(strcmp(token, "<=")==0)
      retval->test_type = DTreeNode::TEST_LTE;
    else if(strcmp(token, ">")==0)
      retval->test_type = DTreeNode::TEST_GT;
    else
      printf("Invalid comparison operator %s\n", token);

    // DEBUG
    printf("%s ", token);
    
    token = strtok(NULL, sep);
    retval->test_value = atof(token);
    
    // DEBUG
    printf("%s\n", token);

    retval->left = recursiveGetTree(depth + 1);

    // Now process the right subtree
    token = strtok(NULL, sep);
    if(getVarIndex(token)==-1)
      printf("Unable to lookup variable %s\n", token);
    
    if(getVarIndex(token)!=retval->test_variable)
      printf("Left and right subtree vars do not match!\n");

    // DEBUG
    for(int i=0; i<depth; i++)
      printf(" ");
    printf("%s ", token);
    
    token = strtok(NULL, sep);
    if(retval->test_type==DTreeNode::TEST_LTE &&
       strcmp(token, ">")!=0)
      printf("Invalid comparison operator %s\n", token);
    else if(retval->test_type==DTreeNode::TEST_GT &&
	    strcmp(token, "<=")!=0)
      printf("Invalid comparison operator %s\n", token);
    
    // DEBUG
    printf("%s ", token);

    token = strtok(NULL, sep);
    if(retval->test_value!=atof(token))
      printf("Right and left subtree values do not match!\n");
    
    printf("%s\n", token);

    retval->right = recursiveGetTree(depth + 1);

  } else if(getClassIndex(token)!=-1) {
    // We're a leaf node
    retval->type = DTreeNode::T_LEAF;
    retval->leaf_class = getClassIndex(token);
  } else {
    printf("Unable to parse %s at node start\n", token);
  }
  
  return retval;
}

/* Turn every ">" choice node in the subtree rooted at head into a "<="
   test by swapping its children. Afterwards the left branch is always
   taken when features[test_variable] <= test_value. A NULL head is
   ignored rather than dereferenced. */
void normalizeTests(DTreeNode *head) {
  if(head==NULL || head->type==DTreeNode::T_LEAF)
    return;

  if(head->test_type==DTreeNode::TEST_GT) {
    // Left held the "> value" subtree; after the swap it holds "<= value".
    DTreeNode *temp = head->left;
    head->left = head->right;
    head->right = temp;

    head->test_type = DTreeNode::TEST_LTE;
  }

  normalizeTests(head->left);
  normalizeTests(head->right);
}

/* Serialize the tree rooted at head into `filename`, overwriting it, in
   the pre-order record format written by recursiveWriteTree(). Failures
   to open, write or close the file are reported on stdout. */
void writeTree(DTreeNode *head, char *filename) {

  if(head==NULL) {
    printf("No tree to write to %s\n", filename);
    return;
  }

  FILE *outfile = fopen(filename, "wb");
  if(outfile==NULL) {
    printf("Unable to open %s for writing\n", filename);
    return;
  }

  // The root is serialized exactly like any other node.
  recursiveWriteTree(head, outfile);

  // fclose() flushes buffered data, so its result matters as much as the
  // stream's error flag from the earlier fwrite() calls.
  bool write_failed = ferror(outfile)!=0;
  if(fclose(outfile)!=0)
    write_failed = true;

  if(write_failed)
    printf("Error while writing tree to %s\n", filename);
}

/* Append the subtree rooted at head to outfile in pre-order. Each record
   is the node type followed by either the leaf's class index, or the
   choice node's variable index and threshold and then both subtrees
   (left first). */
void recursiveWriteTree(DTreeNode *head, FILE *outfile) {

  // Both record kinds start with the node type.
  fwrite(&head->type, 1, sizeof(int), outfile);

  if(head->type==DTreeNode::T_LEAF) {
    fwrite(&head->leaf_class, 1, sizeof(int), outfile);
    return;
  }

  fwrite(&head->test_variable, 1, sizeof(int), outfile);
  fwrite(&head->test_value, 1, sizeof(double), outfile);

  recursiveWriteTree(head->left, outfile);
  recursiveWriteTree(head->right, outfile);
}
