/* LICENSE:
  =========================================================================
    CMPack'04 Source Code Release for OPEN-R SDK 1.1.5-r2 for ERS7
    Copyright (C) 2004 Multirobot Lab [Project Head: Manuela Veloso]
    School of Computer Science, Carnegie Mellon University
    All rights reserved.
  ========================================================================= */

#include <math.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <sys/time.h>
#include <unistd.h>

#include <getopt.h>

#include <vector>
using std::vector;
using std::pair;
using std::make_pair;

#include "agent/headers/Util.h"

#include "util.h"
#include "lprintf.h"
#include "image.h"

#include "learn.h"
#include "kernreg.h"
#include "misc_process.h"

// Set by the -w command-line option: also write per-colour weight-map images.
bool save_weight_maps = false;
// Set by the -r command-line option. NOTE(review): nothing in this file reads
// it; confirm whether another module does, otherwise -r has no effect.
bool use_cache_file   = false;

/* Disabled constants; no live code in this file uses them.
#define Y_VAR 8
#define U_VAR 8
#define V_VAR 8
#define OVERSAMP 16
*/

// Number of colour classes read from the configuration file (set by process()).
int NUM;

// Current wall-clock time as seconds since the Unix epoch, with the
// microsecond part of gettimeofday() folded in as a fraction.
double GetTimeSec()
{
  struct timeval now;
  gettimeofday(&now, NULL);

  const double whole_seconds = (double)now.tv_sec;
  const double fraction      = now.tv_usec * (1.0E-6);
  return whole_seconds + fraction;
}

// Uncomment to randomly relabel a fraction of the training examples built by
// make_list() (label-noise injection).
//#define TEST_MISCLASSIFICATION_ERROR

#ifdef TEST_MISCLASSIFICATION_ERROR
// Probability that an example's class is replaced by a random class index.
static double random_class_prob=0.05;
#endif

// Convert a training image and its hand-labelled counterpart into a linked
// list of examples. New examples are pushed at the front, so the list is in
// reverse pixel order.
//
// For each pixel, the label pixel's colour is matched against
// colors[0..num-1].rcolor to find the class index:
//  - matched: the image pixel becomes an example (y,u,v taken from its red,
//    green, blue channels). The exception is a gray pixel (r==g==b) other than
//    255, which is counted as ignored ("garbage bits at bottom of image").
//  - unmatched: gray labels with a value in [64,192] are counted as ignored.
//    Anything else is an error, and the positions of the first 99 are logged.
// Returns NULL (and logs it) if the two images differ in size. Otherwise it
// returns the list, which is NULL when no example was added.
example *make_list(SImage &img,SImage &label,color_info *colors,int num)
{
  if(!img.samesize(label)){
    lprintf("  Image and label sizes differ, skipping\n");
    return(NULL);
  }

  lprintf("  Adding %dx%d image ",img.w,img.h);

  const int size = img.w * img.h;
  example *list = NULL;
  int num_added = 0, num_ignored = 0, num_unknown = 0;

  // Label colours come in runs, so the class index of the previous label
  // colour is cached and only looked up again when the colour changes.
  int c = 0;
  rgb prev_label = colors[0].rcolor;

  for(int i=0; i<size; i++){
    rgb lab = label.img[i];
    if(prev_label != lab){
      c = 0;
      while((c < num) && colors[c].rcolor!=lab) c++;
      prev_label = lab;
    }

    if(c < num){
      // Keep the noisy label separate. Overwriting c would corrupt the cached
      // index and mislabel every following pixel with the same label colour.
      int ex_label = c;
#ifdef TEST_MISCLASSIFICATION_ERROR
      if(drand48() < random_class_prob)
        ex_label = lrand48() % num;
#endif

      rgb pix = img.img[i];
      if(pix.red!=pix.green || pix.green!=pix.blue || pix.red==255){
        example *p = new example;  // operator new throws on failure
        p->y = pix.red;
        p->u = pix.green;
        p->v = pix.blue;
        p->label = ex_label;

        p->next = list;
        list = p;
        num_added++;
      }else{
        // ignore garbage bits at bottom of image
        num_ignored++;
      }
    }else if(lab.red==lab.green && lab.green==lab.blue && lab.red<=192 && lab.red>=64){
      num_ignored++;
    }else{
      num_unknown++;
      if(num_unknown < 100){
        if(num_unknown == 1) lprintf("  Unknown label(s) at ");
        lprintf("(%d,%d) ",i%img.w,i/img.w);
      }
    }
  }

  if(num_unknown > 0) lprintf("\n");

  lprintf("  %d added, %d ignored, %d errors\n",num_added,num_ignored,num_unknown);

  return(list);
}

// Convenience overload: load the image and its label image from disk and
// build the example list from them (see make_list(SImage&, SImage&, ...)).
example *make_list(const char *image_file,const char *label_file,color_info *colors,int num)
{
  SImage img;
  SImage label;

  lprintf("  Image='%s' Labels='%s'\n",image_file,label_file);

  img.load(image_file);
  label.load(label_file);

  example *examples = make_list(img,label,colors,num);
  return examples;
}

void Classify(TMap &tmap,
	      SImage &image,SImage &output,
	      color_info *colors,int num)
{
  int i,c,size;
  rgb r,err_color;
#ifdef USE_AMBIGUOUS
  int sc;
  rgb r2;
#endif

  if(!image.samesize(output)) return;

  size = image.w * image.h;

  err_color.red   = 0;
  err_color.green = 255;
  err_color.blue  = 0;

  for(i=0; i<size; i++){
    r = image.img[i];
#ifdef USE_AMBIGUOUS
    c  = tmap.classify(r.red,r.green,r.blue);
    sc = tmap.classify_subclass(r.red,r.green,r.blue);
    r  = ( c < COLORS)? colors[ c].rcolor : err_color;
    r2 = (sc < COLORS)? colors[sc].rcolor : err_color;
    output.img[i] = blend(r,r2,0.33);
#else
    c = tmap.classify(r.red,r.green,r.blue);
    output.img[i] = (c < COLORS)? colors[c].rcolor : err_color;
#endif
  }
}

// Load image_file, classify it with tmap and save the colour-coded result to
// output_file. The colouring rules are those of the SImage overload.
void Classify(TMap &tmap,
	      char *image_file,char *output_file,
	      color_info *colors,int num)
{
  SImage image,out;

  lprintf("  Test Image='%s' Output='%s'\n",image_file,output_file);

  image.load(image_file);
  // presumably gives `out` the same size as `image`, which the SImage
  // overload requires before it writes anything
  out.copy(image);

  Classify(tmap,image,out,colors,num);
  out.save(output_file);
}

// Resolution of the Y/U/V threshold table built by this tool (see the init()
// calls in main()). The diagnostic images lay out the Y_LEV levels as
// U_LEV x V_LEV tiles, four tiles per row.
#define Y_LEV 16
#define U_LEV 64
#define V_LEV 64

// Fixed display colours; darkgray marks unclassified (255) cells in the
// threshold image.
rgb black    = {  0,  0,  0};
rgb darkgray = { 64, 64, 64};

// Scan the weight table for one colour class.
// wmap holds COLORS interleaved weights per box (entry i*COLORS+color_idx),
// and size is the number of boxes. On return, min_weight is the smallest
// strictly positive weight (+HUGE_VAL if there is none). max_weight is the
// largest weight, floored at 0.0.
// Returns true if any weight of this class is non-zero.
bool find_weight_range(double &min_weight, double &max_weight, int color_idx, int size, double *wmap)
{
  bool found_one=false;

  min_weight=+HUGE_VAL;
  max_weight=0.0;

  for(int i=0; i<size; i++) {
    double weight=wmap[i*COLORS+color_idx];

    if(weight!=0.0)
      found_one=true;

    // Zero (and negative) weights have no logarithm. Including them would pin
    // the minimum to 0 whenever any box is empty. The log-scaled display range
    // would then fall back to the caller's tiny floor and lose all contrast.
    if(weight > 0.0 && weight < min_weight)
      min_weight = weight;
    if(weight > max_weight)
      max_weight = weight;
  }

  return found_one;
}

// Scale a weight onto [0.0, 1.0] logarithmically for display.
// Weights below min_weight map to 0.0, and weights at or above max_weight map
// to 1.0. If the range is degenerate (max_weight <= min_weight), every weight
// >= min_weight maps to 1.0 instead of computing 0/0 = NaN.
// min_weight is expected to be > 0, since it is passed to log().
double scaleWeight(double weight, double min_weight, double max_weight) {
  if(weight < min_weight)
    return 0.0;

  // Clamp to the documented range. Callers multiply by 255.999 and cast to an
  // integer channel value, so NaN or values > 1 would be invalid.
  if(max_weight <= min_weight || weight >= max_weight)
    return 1.0;

  double min_log = log(min_weight);
  double max_log = log(max_weight);

  double weight_log = log(weight);

  return (weight_log - min_log) / (max_log - min_log);
}

static const double black_frac=1e-200;

void dump_weightmaps(thresh &thr,char *out_prefix,double *wmap,int size, int num_colors)
{
  double max_weight=0.0,min_weight=+HUGE_VAL;
  double color_maxw,color_minw;

  for(int color_idx=0; color_idx<COLORS; color_idx++) {
    bool has_data;
    has_data = find_weight_range(color_minw,color_maxw,color_idx,size,wmap);

    if(has_data) {
      min_weight = min(min_weight, color_minw);
      max_weight = max(max_weight, color_maxw);
    }
  }

  min_weight = max(min_weight, black_frac * max_weight);

  for(int color_idx=0; color_idx<num_colors; color_idx++) {
    SImage out,level;
    static char out_file[1024];
    
    sprintf(out_file,"%s%d.ppm",out_prefix,color_idx);

    level.allocate(U_LEV,V_LEV);
    out.allocate(U_LEV*4,V_LEV*Y_LEV/4);

    for(int y_idx=0; y_idx<Y_LEV; y_idx++) {
      for(int u_idx=0; u_idx<U_LEV; u_idx++) {
        for(int v_idx=0; v_idx<V_LEV; v_idx++) {
          int box_idx = thr.box_loc(y_idx,u_idx,v_idx);

          double val;

          val = scaleWeight(wmap[box_idx*COLORS+color_idx],min_weight,max_weight);
          rgb *out_pix=&level.img[u_idx*V_LEV + v_idx];
          out_pix->red = out_pix->green = out_pix->blue = (int)(val*255.999);

          //printf("color %d, idx %d, full idx %d, weight %g, scaled %g, val %d\n",
          //       color_idx,box_idx,box_idx*COLORS+color_idx,
          //       wmap[box_idx*COLORS+color_idx],val,(int)(val*255.999));
        }
      }

      out.blit(level,U_LEV*(y_idx%4),V_LEV*(y_idx/4));
    }
    
    out.save(out_file);
  }
}

void dump_color(thresh &thr,char *out_file,double *wmap,uchar *tmap,int size, int num_colors, color_info *colors)
{
  SImage out,level;
  level.allocate(U_LEV,V_LEV);
  out.allocate(U_LEV*4,V_LEV*Y_LEV/4);

  for(int y_idx=0; y_idx<Y_LEV; y_idx++) {
    for(int u_idx=0; u_idx<U_LEV; u_idx++) {
      for(int v_idx=0; v_idx<V_LEV; v_idx++) {
        int box_idx = thr.box_loc(y_idx,u_idx,v_idx);

        double r,g,b;
        double total_weight;

        r = g = b = 0.0;
        total_weight = 0.0;
        
        for(int color_idx=0; color_idx<num_colors; color_idx++) {
          double val;
          
          val = wmap[box_idx*COLORS+color_idx];
          rgb src_color=colors[color_idx].rcolor;
          r += val * src_color.red;
          g += val * src_color.green;
          b += val * src_color.blue;

          total_weight += val;

          //printf("y %2d u %2d v %2d c %2d, w %10g, tw %10g, rgb'=(%10g,%10g,%10g)\n",
          //       y_idx,u_idx,v_idx,color_idx,val,total_weight,
          //       (total_weight!=0.0) ? (r/total_weight) : 0x80*1.0, 
          //       (total_weight!=0.0) ? (g/total_weight) : 0x00*1.0, 
          //       (total_weight!=0.0) ? (b/total_weight) : 0x80*1.0);
        }

        rgb *out_pix=&level.img[u_idx*V_LEV + v_idx];
        if(total_weight > 0.0) {
          out_pix->red   = (int)rint(r/total_weight);
          out_pix->green = (int)rint(g/total_weight);
          out_pix->blue  = (int)rint(b/total_weight);
        }
        else {
          out_pix->red   = 0x80;
          out_pix->green = 0x00;
          out_pix->blue  = 0x80;
        }
      }
    }

    out.blit(level,U_LEV*(y_idx%4),V_LEV*(y_idx/4));
  }
    
  out.save(out_file);
}

void dump_threshold(char *out_file,uchar *tmap,int size,color_info *colors)
{
  SImage out,level;
  int l,i,c;

  level.allocate(U_LEV,V_LEV);
  out.allocate(U_LEV*4,V_LEV*Y_LEV/4);

  for(l=0; l<Y_LEV; l++){
    for(i=0; i<U_LEV*V_LEV; i++){
      c = tmap[l*U_LEV*V_LEV + i];
      
      level.img[i] = (c == 255)? darkgray : colors[c].rcolor;
    }

    out.blit(level,U_LEV*(l%4),V_LEV*(l/4));
  }

  out.save(out_file);
}

// Write a threshold map to out_file. The file is a text header
// "TMAP\nYUV8\n<size_y> <size_u> <size_v>\n" followed by size_y*size_u*size_v
// raw bytes from tmap. A failure to open, a short write, or an error while
// flushing on close is reported on stderr.
void save_threshold(const char *out_file,uchar *tmap,int size_y,int size_u,int size_v)
{
  // Binary mode: the payload is raw bytes and must not be newline-translated.
  FILE *out = fopen(out_file,"wb");

  if(out){
    fprintf(out,"TMAP\nYUV8\n%d %d %d\n",
            size_y,size_u,size_v);
    int size = size_y * size_u * size_v;
    size_t wrote = fwrite(tmap,sizeof(uchar),size,out);
    bool ok = (wrote == (size_t)size) && !ferror(out);

    // fclose flushes buffered data; a failure here means the file is incomplete
    if(fclose(out) != 0) ok = false;

    if(ok) return;
  }

  fprintf(stderr, "error saving tmap file\n");
}

#define MAXBUF 1024

// Read forward in `in` until a line that is not a comment (does not start
// with '#') contains `key`, then return true. The stream is left at the start
// of the following line. Returns false at EOF.
// Lines longer than MAXBUF are read in pieces. A comment line's later pieces
// are still treated as comment, and a key split across two pieces is not found.
bool find_label(FILE *in,const char *key)
{
  char buf[MAXBUF];
  bool at_line_start = true;
  bool in_comment = false;

  while(fgets(buf,MAXBUF,in)){
    // comment status is decided by the first piece of each physical line
    if(at_line_start) in_comment = (buf[0] == '#');

    if(!in_comment && strstr(buf,key)){
      // discard the rest of an overlong line so the caller starts on a new line
      if(!strchr(buf,'\n')){
        int ch;
        while((ch = fgetc(in)) != EOF && ch != '\n')
          ;
      }
      return(true);
    }

    at_line_start = (strchr(buf,'\n') != NULL);
  }

  return(false);
}

// Extract the next double-quoted string from src.
// The characters between the next pair of '"' are copied into dest: at most
// len-1 of them, always NUL-terminated when len > 0. Characters that do not
// fit are skipped. Returns a pointer just past the closing quote, ready for
// the next call.
// If src has no opening quote, dest becomes "" and the returned pointer
// addresses src's terminating NUL. Returns NULL if src is NULL.
char *get_quote_string(char *dest,int len,char *src)
{
  if(!src) return(NULL);

  char *p = src;

  // Skip to the opening quote. Only step over it if one was actually found;
  // otherwise we would walk past the string's terminating NUL.
  while(*p && *p!='"') p++;
  if(*p == '"') p++;

  // copy the quoted text, bounded by the destination size
  int l = 0;
  while(*p && *p!='"' && l<len-1){
    dest[l++] = *p++;
  }
  if(len > 0) dest[l] = 0;

  // discard any text that did not fit, then the closing quote
  while(*p && *p!='"') p++;
  if(*p == '"') p++;

  return(p);
}

// Default colour classes. Each entry holds a name, the label/display colour,
// a weight and a confidence, plus a high-confidence value when USE_AMBIGUOUS
// is defined. process() overwrites entries with the definitions read from the
// configuration file.
color_info test_colors[COLORS] = {
#ifdef USE_AMBIGUOUS
  {"Background", {  0,  0,  0}, 0.5, 0.5, 0.5},
  {"Orange",     {255,128,  0}, 1.0, 0.9, 0.9},
  {"Green",      {  0,128,  0}, 1.0, 0.5, 0.5},
  {"Pink",       {255,  0,128}, 1.0, 0.5, 0.5},
  {"Cyan",       {  0,255,255}, 1.0, 0.5, 0.5},
  {"Yellow",     {255,255,  0}, 1.0, 0.5, 0.5},
  {"DarkBlue",   {  0,  0,255}, 2.0, 0.5, 0.5},
  {"Red",        {255,  0,  0}, 2.0, 0.5, 0.5},
  {"White",      {255,255,255}, 1.0, 0.5, 0.5},
  {"Error",      {255,  0,255}, 1.0, 0.5, 0.5}
#else
  {"Background", {  0,  0,  0}, 0.5, 0.5,},
  {"Orange",     {255,128,  0}, 1.0, 0.9,},
  {"Green",      {  0,128,  0}, 1.0, 0.5,},
  {"Pink",       {255,  0,128}, 1.0, 0.5,},
  {"Cyan",       {  0,255,255}, 1.0, 0.5,},
  {"Yellow",     {255,255,  0}, 1.0, 0.5,},
  {"DarkBlue",   {  0,  0,255}, 2.0, 0.5,},
  {"Red",        {255,  0,  0}, 2.0, 0.5,},
  {"White",      {255,255,255}, 1.0, 0.5,},
  {"Error",      {255,  0,255}, 1.0, 0.5,}
#endif
};

// Threshold map produced by learning; process() saves it to out.tm.
TMap tmap;
// Default learner, used unless -k is given.
thresh tlearn;
// Kernel-regression learner, used when -k sets UseKernelLearner.
KernelRegLearner klearn;
bool UseKernelLearner = false;

void process(char *file)
{
  vector<pair<const char *,const char *> > training_files;
  char buf[MAXBUF];
  char file1[MAXBUF];
  char file2[MAXBUF];
  double time_start,time_end;
  char *str;
  FILE *in;
  example *list,*tmp_list;

  in = fopen(file,"rb");
  if(!in) return;

  training_files.reserve(5);

  lprintf("COLORS:\n");
  find_label(in,"colors");
  int color_idx=0;
  while(fgets(buf,MAXBUF,in) && buf[0]!='}'){
    if(buf[0] != '#'){
      char *loc=buf;

      char name[MAXBUF];
      int r,g,b;
      double weight,conf;
#ifdef USE_AMBIGUOUS
      double conf_hi;
#endif

      loc = get_quote_string(name,MAXBUF,buf);

#ifdef USE_AMBIGUOUS
      if(sscanf(loc," , ( %d , %d , %d ) , %lg , %lg , %lg",&r,&g,&b,&weight,&conf,&conf_hi)!=6) {
        fprintf(stderr,"error reading end of line '%s'\n",buf);
      }
#else
      if(sscanf(loc," , ( %d , %d , %d ) , %lg , %lg",&r,&g,&b,&weight,&conf)!=5) {
        fprintf(stderr,"error reading end of line '%s'\n",buf);
      }
#endif
      
      test_colors[color_idx].name=strdup(name);
      test_colors[color_idx].rcolor.red  =r;
      test_colors[color_idx].rcolor.green=g;
      test_colors[color_idx].rcolor.blue =b;
      test_colors[color_idx].weight=weight;
      test_colors[color_idx].confidence=conf;
#ifdef USE_AMBIGUOUS
      test_colors[color_idx].confidenceHigh=conf_hi;
#endif
      color_idx++;
    }
  }

  NUM=color_idx;

  lprintf("LOADING:\n");
  find_label(in,"train");
  while(fgets(buf,MAXBUF,in) && buf[0]!='}'){
    if(buf[0] != '#'){
      str = buf;
      file1[0] = file2[0] = 0;
      str = get_quote_string(file1,MAXBUF,str);
      str = get_quote_string(file2,MAXBUF,str);

      training_files.push_back(make_pair(strdup(file1),strdup(file2)));

      list = make_list(file1,file2,test_colors,NUM);

      if(UseKernelLearner){
        klearn.addList(list);
      }else{
        tlearn.addList(list);
      }
    }
  }

  lprintf("LEARNING:\n");
  fflush(stdout);

  time_start = GetTimeSec();
  if(UseKernelLearner){
    printf("building BSP..."); fflush(stdout);
    klearn.build();
    printf("done.\n");

    klearn.learn(tmap,test_colors,NUM);
  }else{
    tlearn.learnMap(test_colors,NUM);
    memcpy(tmap.tmap, tlearn.tmap, tmap.size);
  }
  time_end = GetTimeSec();
  lprintf("Elapsed time: %fsec\n",time_end - time_start);

  lprintf("REBUILDING EXAMPLE LIST:\n");
  // rebuild a complete example list
  list = NULL;
  for(int training_idx=0; training_idx<(int)training_files.size(); training_idx++){
    tmp_list = make_list(training_files[training_idx].first,
                         training_files[training_idx].second,
                         test_colors,NUM);
    if(list==NULL){
      list = tmp_list;
    }else if(tmp_list!=NULL){
      example *ex;
      for(ex=tmp_list; ex->next!=NULL; ex=ex->next)
        ;
      ex->next = list;
      list = tmp_list;
    }
  }
  lprintf("CALCULATING CONFUSION MATRIX:\n");
  CalcConfusionMatrix(list,tmap,test_colors,NUM);

  lprintf("CALCULATING AVERAGE COLORS:\n");
  CalcAverageColors(list,tmap,test_colors,NUM);

  lprintf("TESTING:\n");
  find_label(in,"test");
  while(fgets(buf,MAXBUF,in) && buf[0]!='}'){
    if(buf[0] != '#'){
      str = buf;
      file1[0] = file2[0] = 0;
      str = get_quote_string(file1,MAXBUF,str);
      str = get_quote_string(file2,MAXBUF,str);
      if(file1[0] && file2[0]){
        Classify(tmap,file1,file2,test_colors,NUM);
      }
    }
  }

  if(!UseKernelLearner){
    if(save_weight_maps){
      dump_weightmaps(tlearn, "weight", tlearn.wmap, Y_LEV*U_LEV*V_LEV, NUM);
    }
    dump_color(tlearn, "color.png", tlearn.wmap, tlearn.tmap, Y_LEV*U_LEV*V_LEV, NUM, test_colors);
  }

  lprintf("SAVING:\n");
  dump_threshold("tmap.png", tmap.tmap, Y_LEV*U_LEV*V_LEV, test_colors);
  save_threshold("out.tm", tmap.tmap, Y_LEV, U_LEV, V_LEV);
}

// Print the command-line help text to stdout.
void usage()
{
  static const char help_text[] =
    "usage: thresh [-rcwk] colors.txt\n"
    " -r : rerun using cache file if present\n"
    " -c : use alternate input file to colors.txt\n"
    " -w : save per-color weight map images\n"
    " -k : use kernel learner\n"
    " -h : print this usage/help\n";

  fputs(help_text, stdout);
}

// Entry point.
// Options:
//   -r          set use_cache_file
//   -c <file>   configuration file (default colors.txt, or colors_ambig.txt
//               when built with USE_AMBIGUOUS)
//   -w          also dump per-colour weight maps
//   -k          use the kernel-regression learner
//   -h          print usage and exit
// A trailing non-option argument, as advertised by usage(), is also accepted
// as the configuration file and takes precedence over -c.
// Unknown options print usage and exit with status 1.
int main(int argc,char **argv)
{
#ifdef USE_AMBIGUOUS
  static char default_input[] = "colors_ambig.txt";
#else
  static char default_input[] = "colors.txt";
#endif
  // process() takes a non-const char*. Pointing at a writable array avoids
  // binding a string literal to char*, which is ill-formed since C++11.
  char *input_file = default_input;
  // getopt() returns int. Storing it in a char breaks the -1 end-of-options
  // test on platforms where char is unsigned, causing an endless loop.
  int ch;

  lprintf_open("thresh.log");
  tmap.init(Y_LEV,U_LEV,V_LEV);
  tlearn.init(Y_LEV,U_LEV,V_LEV);

  // process the command line
  while((ch = getopt(argc,argv,"rc:whk")) != -1){
    switch(ch){
      case 'r':
        use_cache_file = true;
        break;
      case 'c':
        input_file = optarg;
        break;
      case 'w':
        save_weight_maps = true;
        break;
      case 'k':
        UseKernelLearner = true;
        break;
      case 'h':
        usage();
        lprintf_close();
        return(0);
      default:
        // unknown option or -c without an argument; getopt() has already
        // printed a diagnostic
        usage();
        lprintf_close();
        return(1);
    }
  }

  if(optind < argc)
    input_file = argv[optind];

  process(input_file);

  lprintf_close();

  return(0);
}
