/* LICENSE:
  =========================================================================
    CMPack'04 Source Code Release for OPEN-R SDK 1.1.5-r2 for ERS7
    Copyright (C) 2004 Multirobot Lab [Project Head: Manuela Veloso]
    School of Computer Science, Carnegie Mellon University
    All rights reserved.
  ========================================================================= */

#include "agent/Vision/colors.h"

#include "lprintf.h"
#include "learn.h"
#include "kernreg.h"
#include "misc_process.h"

void CalcConfusionMatrix(example *exs,TMap &tmap,color_info *colors,int num_colors)
{
  double *confusion_matrix;
  long confusion_matrix_size;
  unsigned long num_examples;
  unsigned long *num_examples_color;

  confusion_matrix_size = num_colors*num_colors;
  confusion_matrix = new double[confusion_matrix_size];
  memset(confusion_matrix,0,sizeof(double)*confusion_matrix_size);
  num_examples = 0;
  num_examples_color = new unsigned long[num_colors];
  memset(num_examples_color,0,sizeof(unsigned long)*num_colors);

  for(example *ex=exs; ex!=NULL; ex=ex->next){
    uchar true_color;
    uchar label_color;

    true_color = ex->label;
    label_color = tmap.classify(ex->y,ex->u,ex->v);

    confusion_matrix[true_color*num_colors + label_color] += 1.0;
    num_examples++;
    num_examples_color[true_color]++;
  }

  lprintf("color: examples  name\n");
  for(int color_idx=0; color_idx<num_colors; color_idx++){
    lprintf(" %02d  : %8ld  '%s'\n",color_idx,num_examples_color[color_idx],
            colors[color_idx].name);
  }

  //for(int true_color_idx=0; true_color_idx<num_colors; true_color_idx++){
  //  for(int label_color_idx=0; label_color_idx<num_colors; label_color_idx++){
  //    confusion_matrix[true_color_idx*num_colors+label_color_idx] /= (double)num_examples;//num_examples_color[true_color_idx];
  //  }
  //}

  if(true){
    lprintf("      labeled as\n");
    lprintf("true: ");
    for(int label_color_idx=0; label_color_idx<num_colors; label_color_idx++){
      lprintf("%4d%4s ",label_color_idx,"");
    }
    lprintf("\n");
    for(int true_color_idx=0; true_color_idx<num_colors; true_color_idx++){
      lprintf(" %02d : ",true_color_idx);
      for(int label_color_idx=0; label_color_idx<num_colors; label_color_idx++){
        double pct_correct;
        pct_correct = confusion_matrix[true_color_idx*num_colors + label_color_idx] / (double)num_examples;
        lprintf("%8.4f ",pct_correct);
      }
      lprintf("\n");
    }
  }

  lprintf("\n  NORMALIZED\n");
  lprintf("      labeled as\n");
  lprintf("true: ");
  for(int label_color_idx=0; label_color_idx<num_colors; label_color_idx++){
    lprintf("%4d%4s ",label_color_idx,"");
  }
  lprintf("\n");
  for(int true_color_idx=0; true_color_idx<num_colors; true_color_idx++){
    lprintf(" %02d : ",true_color_idx);
    for(int label_color_idx=0; label_color_idx<num_colors; label_color_idx++){
      double pct_correct;
      pct_correct = confusion_matrix[true_color_idx*num_colors + label_color_idx] / (double)num_examples_color[true_color_idx];
      lprintf("%8.4f ",pct_correct);
    }
    lprintf("\n");
  }

  delete[] num_examples_color;
  delete[] confusion_matrix;
}

void CalcAverageColors(example *exs,TMap &tmap,color_info *colors,int num_colors)
{
#define SUM_TYPE unsigned long long

  yuv_generic<SUM_TYPE> *sum_colors;
  yuv_generic<uchar> *avg_colors;
  ulong *counts;
  ulong num_examples;

  sum_colors = new yuv_generic<SUM_TYPE>[num_colors*num_colors];
  memset(sum_colors,0,sizeof(yuv_generic<SUM_TYPE>)*num_colors*num_colors);
  avg_colors = new yuv_generic<uchar>[num_colors*num_colors];
  memset(avg_colors,0,sizeof(yuv_generic<uchar>)*num_colors*num_colors);
  counts = new ulong[num_colors*num_colors];
  num_examples = 0;

  for(int true_color_idx=0; true_color_idx<num_colors; true_color_idx++){
    for(int label_color_idx=0; label_color_idx<num_colors; label_color_idx++){
      int idx = true_color_idx*num_colors + label_color_idx;
      counts[idx] = 0;
      sum_colors[idx].y = 0;
      sum_colors[idx].u = 0;
      sum_colors[idx].v = 0;
    }
  }

  for(example *ex=exs; ex!=NULL; ex=ex->next){
    uchar true_color;
    uchar label_color;

    true_color = ex->label;
    label_color = tmap.classify(ex->y,ex->u,ex->v);

    sum_colors[true_color*num_colors + label_color].y += ex->y;
    sum_colors[true_color*num_colors + label_color].u += ex->u;
    sum_colors[true_color*num_colors + label_color].v += ex->v;
    num_examples++;
    counts[true_color*num_colors + label_color]++;
  }

  for(int true_color_idx=0; true_color_idx<num_colors; true_color_idx++){
    for(int label_color_idx=0; label_color_idx<num_colors; label_color_idx++){
      int idx = true_color_idx*num_colors + label_color_idx;
      int count = counts[idx];
      if(count > 0){
        avg_colors[idx].y = (uchar)((double)sum_colors[idx].y / count);
        avg_colors[idx].u = (uchar)((double)sum_colors[idx].u / count);
        avg_colors[idx].v = (uchar)((double)sum_colors[idx].v / count);
      }else{
        idx = true_color_idx*num_colors + true_color_idx;
        count = counts[idx];
        if(count > 0){
          avg_colors[idx].y = (uchar)((double)sum_colors[idx].y / count);
          avg_colors[idx].u = (uchar)((double)sum_colors[idx].u / count);
          avg_colors[idx].v = (uchar)((double)sum_colors[idx].v / count);
        }else{
          avg_colors[idx].y = 128;
          avg_colors[idx].u = 128;
          avg_colors[idx].v = 128;
        }
      }
    }
  }

  lprintf("      labeled as\n");
  lprintf("true: ");
  for(int label_color_idx=0; label_color_idx<num_colors; label_color_idx++){
    lprintf("%4d%7s ",label_color_idx,"");
  }
  lprintf("\n");
  for(int true_color_idx=0; true_color_idx<num_colors; true_color_idx++){
    yuv_generic<float> avg_avg_color;
    ulong count;

    avg_avg_color.y = 0;
    avg_avg_color.u = 0;
    avg_avg_color.v = 0;
    count = 0;
    lprintf(" %02d : ",true_color_idx);
    for(int label_color_idx=0; label_color_idx<num_colors; label_color_idx++){
      yuv_generic<uchar> *avg_color;
      int idx = true_color_idx*num_colors + label_color_idx;
      avg_color = &avg_colors[idx];

      if(counts[idx] > 0){
        lprintf("%03d,%03d,%03d ",avg_color->y,avg_color->u,avg_color->v);
      }else{
        lprintf("---,---,--- ");
      }

      avg_avg_color.y += (double)avg_color->y * counts[idx];
      avg_avg_color.u += (double)avg_color->u * counts[idx];
      avg_avg_color.v += (double)avg_color->v * counts[idx];
      count += counts[idx];
    }
    avg_avg_color.y /= count;
    avg_avg_color.u /= count;
    avg_avg_color.v /= count;
    lprintf("avg=%03d,%03d,%03d",(uchar)avg_avg_color.y,(uchar)avg_avg_color.u,(uchar)avg_avg_color.v);
    lprintf("\n");
  }

  FILE *avg_red_file;
  yuv_generic<uchar> *avg_red_orange_color;
  int red_idx=-1;
  int orange_idx=-1;
  for(int i=0; i<num_colors; i++){
    if(strcasecmp(colors[i].name,"red")==0){
      red_idx = i;
    }
    if(strcasecmp(colors[i].name,"orange")==0){
      orange_idx = i;
    }
  }
  if(orange_idx==-1 || counts[orange_idx]==0){
    lprintf("missing orange color or examples, skipping red v orange calibration\n");
    lprintf("orange_idx=%d count=%lu\n",orange_idx,counts[orange_idx]);
  }else if(red_idx==-1 || counts[red_idx]==0){
    lprintf("missing Red color or examples, skipping red v orange calibration\n");
    lprintf("red_idx=%d count=%lu\n",red_idx,counts[red_idx]);
  }else{
    avg_red_orange_color = &avg_colors[red_idx*num_colors+orange_idx];
    avg_red_file = fopen("red.prm","wb");
    if(avg_red_file==NULL){
      lprintf("couldn't open red v orange calibration output file, skipping\n");
    }else{
      fwrite(&avg_red_orange_color->y,sizeof(avg_red_orange_color->y),1,avg_red_file);
      fwrite(&avg_red_orange_color->u,sizeof(avg_red_orange_color->u),1,avg_red_file);
      fwrite(&avg_red_orange_color->v,sizeof(avg_red_orange_color->v),1,avg_red_file);
      fclose(avg_red_file);
    }
  }

  delete[] counts;
  delete[] avg_colors;
  delete[] sum_colors;

#undef SUM_TYPE
}
