/* LICENSE:
  =========================================================================
    CMPack'04 Source Code Release for OPEN-R SDK 1.1.5-r2 for ERS7
    Copyright (C) 2004 Multirobot Lab [Project Head: Manuela Veloso]
    School of Computer Science, Carnegie Mellon University
    All rights reserved.
  ========================================================================= */

/* This program takes a series of PPM images from the robot's
   camera. Each image should be of a solid colored surface with
   a matte finish. (i.e. no specular highlights). We determine
   the "true" color of each image by taking the mean of a circle
   of pixels in the center of it. We then calculate a LSQ linear
   fit for each pixel from the series of images to figure out how
   to correct for color distortion.

   Since we can't store floating point numbers in a PPM file, we
   scale each slope factor by 128, so a stored value of 128 represents
   a factor of 1.0. On the robot, we'll multiply the YUV values in the
   image by the corresponding slope values in the mask and then right
   shift by 7 bits to get back approximately correct results (and get
   them back much faster than a floating point multiply would). (Be
   careful of overflow if you ever implement this - (255*255) >> 7 is
   larger than 255, so you end up losing the most significant bits when
   you convert to a uchar.)

   Also be careful - we can't represent more than a factor of 2
   (okay, 1.99... really) using this approximation. We cap the
   integer factors at 255 in this utility.

   The offset values for YU are signed chars and the offset value
   for V is a negated unsigned char (-255 to 0 stored as 0-255 in
   the mask file)
*/

#include <stdio.h>
#include <math.h>

#include "PPMImage.h"
#include "RGBTriplet.h"

typedef unsigned char uchar;

/* Builds one "ideal" solid-color image per input image.

   For each images[i], the pixels lying within a radius of height/8 of the
   image center are averaged per channel. Every pixel of retval[i] is then
   set to that (rounded) mean color, and the mean is printed to stdout.
   The channels hold Y/U/V even though RGBTriplet names them
   red/green/blue.

   images: array of `count` images. All of them are assumed to have the
           same dimensions as images[0]; this is NOT checked here.
   count:  number of images.
   Returns a new[]'d array of `count` new'd PPMImages (the caller owns the
   array and every image), or NULL when images is NULL or count <= 0. */
PPMImage **buildSolidImages(PPMImage **images, int count)
{
  if(images == NULL || count <= 0)
    return NULL;

  int width  = images[0]->getWidth();
  int height = images[0]->getHeight();

  PPMImage **retval = new PPMImage*[count];

  for(int i=0; i<count; i++)
    retval[i] = new PPMImage(width, height);

  // Integer division: the radius is a whole number of pixels.
  double r_thresh = height/8;
  // For images shorter than 8 rows the radius would be 0. With an even
  // width or height no pixel lies exactly on the center, so nothing would
  // be averaged (division by zero below). A radius of 1 always reaches
  // the pixel(s) nearest the center. Radii for height >= 8 are unchanged.
  if(r_thresh < 1.0)
    r_thresh = 1.0;

  double center_x = (width -1)/2.0;
  double center_y = (height-1)/2.0;

  // Step through the image and average pixels that are within
  // a given radius of the center. We'll assume that this is
  // the uniform color present in the image.
  for(int i=0; i<count; i++){
    double y_accum = 0, u_accum = 0, v_accum = 0;
    int num_pixels = 0;

    for(int y=0; y<height; y++){
      for(int x=0; x<width; x++){
        double dx = x - center_x;
        double dy = y - center_y;

        double r = sqrt(dx*dx + dy*dy);
        if(r <= r_thresh){
          y_accum += images[i]->get(x, y).red;
          u_accum += images[i]->get(x, y).green;
          v_accum += images[i]->get(x, y).blue;
          num_pixels++;
        }
      }
    }

    // num_pixels can only still be 0 for an empty (0-pixel) image.
    if(num_pixels > 0){
      y_accum /= num_pixels;
      u_accum /= num_pixels;
      v_accum /= num_pixels;
    }

    // Now we set each pixel of our return image to the average
    // color of the center.
    RGBTriplet mean_color((unsigned char)rint(y_accum),
                          (unsigned char)rint(u_accum),
                          (unsigned char)rint(v_accum));
    for(int y=0; y<height; y++){
      for(int x=0; x<width; x++){
        retval[i]->set(x, y, mean_color);
      }
    }

    printf("y=%g u=%g v=%g\n",y_accum,u_accum,v_accum);
  }

  return retval;
}

int main(int argc, char **argv)
{
  if(argc==1){
    printf("Usage: pass me image files of blank colored surfaces.\n");
    return 0;
  }

  int width=0, height=0, count = argc - 1;
  PPMImage **images = new PPMImage*[count];

  // Load all of our images from file
  for(int i=1; i<argc; i++){
    images[i - 1] = new PPMImage(argv[i]);
  }

  width  = images[0]->getWidth();
  height = images[0]->getHeight();

  PPMImage **corrected_images = buildSolidImages(images, count);

  // Run least squares for each and every pixel to
  // figure out a good value of A (constant to
  // multiply the actual value by to get the
  // corrected value)
  double *y_a_list = new double[width*height];
  double *u_a_list = new double[width*height];
  double *v_a_list = new double[width*height];
  double *y_b_list = new double[width*height];
  double *u_b_list = new double[width*height];
  double *v_b_list = new double[width*height];
  int offset = 0;
  for(int r=0; r<height; r++){
    for(int c=0; c<width; c++){
      // We need to track some statistics. 
      double Y_sum_x = 0, Y_sum_y = 0, Y_sum_x_x = 0, Y_sum_x_y = 0, Y_sum_y_y = 0;
      double U_sum_x = 0, U_sum_y = 0, U_sum_x_x = 0, U_sum_x_y = 0, U_sum_y_y = 0;
      double V_sum_x = 0, V_sum_y = 0, V_sum_x_x = 0, V_sum_x_y = 0, V_sum_y_y = 0;
      
      for(int i=0; i<count; i++){
	int raw_Y = images[i]->get(c,r).red;
	int raw_U = images[i]->get(c,r).green;
	int raw_V = images[i]->get(c,r).blue;
	int want_Y = corrected_images[i]->get(c,r).red;
	int want_U = corrected_images[i]->get(c,r).green;
	int want_V = corrected_images[i]->get(c,r).blue;

	Y_sum_x += raw_Y;
        U_sum_x += raw_U;
        V_sum_x += raw_V; 

	Y_sum_y += want_Y;
        U_sum_y += want_U;
        V_sum_y += want_V; 

	Y_sum_x_x +=  raw_Y* raw_Y;
	U_sum_x_x +=  raw_U* raw_U;
	V_sum_x_x +=  raw_V* raw_V;

	Y_sum_x_y +=  raw_Y*want_Y;
	U_sum_x_y +=  raw_U*want_U;
	V_sum_x_y +=  raw_V*want_V;

	Y_sum_y_y += want_Y*want_Y;
	U_sum_y_y += want_U*want_U;
	V_sum_y_y += want_V*want_V;
      }

      double num, denom;

      // Calculations for least squares from any stats book.
      // We want a and b such that a*x+b minimizes the squared
      // error at each point. We're calculating a separate
      // a and b for each pixel. (which may be overkill)
      num = Y_sum_x_y - Y_sum_x*Y_sum_y/count;
      denom = Y_sum_x_x - Y_sum_x*Y_sum_x/count;
      if(fabs(denom) < .001){
	y_a_list[offset] = 1;
      }else{
	y_a_list[offset] = num/denom;
      }

      num = U_sum_x_y - U_sum_x*U_sum_y/count;
      denom = U_sum_x_x - U_sum_x*U_sum_x/count;
      if(fabs(denom) < .001){
	u_a_list[offset] = 1;
      }else{
	u_a_list[offset] = num/denom;
      }

      num = V_sum_x_y - V_sum_x*V_sum_y/count;
      denom = V_sum_x_x - V_sum_x*V_sum_x/count;
      if(fabs(denom) < .001){
	//printf("V denom: %lf\n", denom);
	v_a_list[offset] = 1;
      }else{
	v_a_list[offset] = num/denom;
      }

      // Hey, if we're going to clamp these is it better to use
      // the y-intercept corresponding to the clamped or
      // unclamped value?
      if(y_a_list[offset] > 2.0) y_a_list[offset] = 2.0;
      if(y_a_list[offset] < 0.0) y_a_list[offset] = 0.0;

      if(u_a_list[offset] > 2.0) u_a_list[offset] = 2.0;
      if(u_a_list[offset] < 0.0) u_a_list[offset] = 0.0;

      if(v_a_list[offset] > 2.0) v_a_list[offset] = 2.0;
      if(v_a_list[offset] < 0.0) v_a_list[offset] = 0.0;

      num = Y_sum_y - y_a_list[offset]*Y_sum_x;
      denom = count;
      y_b_list[offset] = num/denom;

      num = U_sum_y - u_a_list[offset]*U_sum_x;
      denom = count;
      u_b_list[offset] = num/denom;

      num = V_sum_y - v_a_list[offset]*V_sum_x;
      denom = count;
      v_b_list[offset] = num/denom;

      offset++;
    }
  }

  
  // We need to store our values for A and B in a convenient format.
  // Let's double the width of the image and even numbered columns 
  // will contain A, odd numbered columns will contain B.
  
  // We need to convert from our lists of doubles into single bytes.
  // Our A values for all 3 channels tend to be between 0 and 2.0.
  // A few are higher, but only a very few in the test data.
  // Why don't we just take an unsigned byte and save (double val)*128
  // clamped to the appropriate interval? The 128 is handy 'cuz we can
  // do an integer division on the robots with a right shift of 7 bits.

  // B values for the Y and U channels tend to fall between -60 and 60.
  // we can get away with a signed char (again, clamped). However, for
  // the V channel we need -210 to about 10. (Only a very few samples are
  // greater than 0). I'm thinking unsigned char * (-1) as storage.
  // Again, we'll make with the clamping.


  // --

  uchar *mask_y_a = new uchar[width*height];
  uchar *mask_y_b = new uchar[width*height];
  uchar *mask_u_a = new uchar[width*height];
  uchar *mask_u_b = new uchar[width*height];
  uchar *mask_v_a = new uchar[width*height];
  uchar *mask_v_b = new uchar[width*height];

  offset = 0;
  for(int y=0; y<height; y++){
    for(int x=0; x<width; x++){   

      int y_a = (int)(y_a_list[offset]*128.0);
      if(y_a < 0){
	y_a = 0;
	printf("warning at (%3d,%3d): clamping y_a to    0 for raw value %lf\n",x,y,y_a_list[offset]);
      }else if(y_a > 255){
	y_a = 255;
	printf("warning at (%3d,%3d): clamping y_a to  255 for raw value %lf\n",x,y,y_a_list[offset]);
      }
      
      int u_a = (int)(u_a_list[offset]*128.0);
      if(u_a < 0){
	u_a = 0;
	printf("warning at (%3d,%3d): clamping u_a to    0 for raw value %lf\n",x,y,u_a_list[offset]);
      }else if(u_a > 255){
	u_a = 255;
	printf("warning at (%3d,%3d): clamping u_a to  255 for raw value %lf\n",x,y,u_a_list[offset]);
      }

      int v_a = (int)(v_a_list[offset]*128.0);
      if(v_a < 0){
	v_a = 0;
	printf("warning at (%3d,%3d): clamping v_a to    0 for raw value %lf\n",x,y,v_a_list[offset]);
      }else if(v_a > 255){
	v_a = 255;
	printf("warning at (%3d,%3d): clamping v_a to  255 for raw value %lf\n",x,y,v_a_list[offset]);
      }

      int y_b = (int)y_b_list[offset];
      if(y_b < -127){
	y_b = -127;
	printf("warning at (%3d,%3d): clamping y_b to -127 for raw value %lf\n",x,y,y_b_list[offset]);
      }else if(y_b > 127){
	y_b = 127;
	printf("warning at (%3d,%3d): clamping y_b to  127 for raw value %lf\n",x,y,y_b_list[offset]);
      }

      int u_b = (int)u_b_list[offset];
      if(u_b < -127){
	u_b = -127;
	printf("warning at (%3d,%3d): clamping u_b to -127 for raw value %lf\n",x,y,u_b_list[offset]);
      }else if(u_b > 127){
	u_b = 127;
	printf("warning at (%3d,%3d): clamping u_b to  127 for raw value %lf\n",x,y,u_b_list[offset]);
      }

      int v_b = (int)v_b_list[offset];
      //printf("%d vb\n",v_b);
      if(v_b < -255){
	v_b = -255;
	printf("warning at (%3d,%3d): clamping v_b to -255 for raw value %lf\n",x,y,v_b_list[offset]);
      }else if(v_b > 0){
	v_b = 0;
	printf("warning at (%3d,%3d): clamping v_b to    0 for raw value %lf\n",x,y,v_b_list[offset]);
      }
      v_b = -v_b; // get rid of - sign; we'll add it back when applying the mask.

      if(y==5 && x==5){
        printf("(%d,%d) y:a=%u,b=%d u:a=%u,b=%d v:a=%u,b=%d\n",x,y,y_a,y_b,u_a,u_b,v_a,-v_b);
      }

      // Wow, this is easy to forget.
      offset++;

      mask_y_a[y*width+x] = y_a;
      mask_y_b[y*width+x] = y_b;
      mask_u_a[y*width+x] = u_a;
      mask_u_b[y*width+x] = u_b;
      mask_v_a[y*width+x] = v_a;
      mask_v_b[y*width+x] = v_b;
    }
  }

  // We'll want a nice PPM so the user can see what it looks like.
  // Double the width so we have room for both A and B.
  PPMImage *save_me = new PPMImage(2*width, 2*height);
  for(int y=0; y<height; y++){
    for(int x=0; x<width; x++){   
      int y_a,y_b,u_a,u_b,v_a,v_b;

      y_a = mask_y_a[y*width+x];
      y_b = mask_y_b[y*width+x];
      u_a = mask_u_a[y*width+x];
      u_b = mask_u_b[y*width+x];
      v_a = mask_v_a[y*width+x];
      v_b = mask_v_b[y*width+x];

      save_me->set(2*x,   y, RGBTriplet((unsigned char)y_a,
                                        (unsigned char)u_a,
                                        (unsigned char)v_a));
      save_me->set(2*x+1, y, RGBTriplet((unsigned char)(y_b+128),
                                        (unsigned char)(u_b+128),
                                        (unsigned char)(-v_b+255)));
      
      save_me->set(x,       y+height, RGBTriplet((unsigned char)y_a,
                                                 (unsigned char)u_a,
                                                 (unsigned char)v_a));
      save_me->set(x+width, y+height, RGBTriplet((unsigned char)(y_b+128),
                                                 (unsigned char)(u_b+128),
                                                 (unsigned char)(-v_b+255)));
    }
  }
  save_me->save("output.ppm");

  // Now we need to write an image file that CMVision will
  // understand. We'll use the planar YUV format.
  FILE *outfile = fopen("mask_img.raw", "wb");
  for(int y=0; y<height; y++){
    // We need to write these one channel at a time.
    // *sigh* I hope fwrite buffers or we're gonna
    // be thrashing the harddrive.
    for(int c=0; c<width; c++){
      unsigned char a_buf = mask_y_a[y*width+c];
      unsigned char b_buf = mask_y_b[y*width+c];
      fwrite(&a_buf, 1, sizeof(a_buf), outfile);
      fwrite(&b_buf, 1, sizeof(b_buf), outfile);

      //if(y==5 && c==5) printf("(%d,%d) y:a=%u b=%u\n",c,y,a_buf,b_buf);
    }
    for(int c=0; c<width; c++){
      unsigned char a_buf = mask_u_a[y*width+c];
      unsigned char b_buf = mask_u_b[y*width+c];
      fwrite(&a_buf, 1, sizeof(a_buf), outfile);
      fwrite(&b_buf, 1, sizeof(b_buf), outfile);

      //if(y==5 && c==5) printf("(%d,%d) u:a=%u b=%u\n",c,y,a_buf,b_buf);
    }
    for(int c=0; c<width; c++){
      unsigned char a_buf = mask_v_a[y*width+c];
      unsigned char b_buf = mask_v_b[y*width+c];
      fwrite(&a_buf, 1, sizeof(a_buf), outfile);
      fwrite(&b_buf, 1, sizeof(b_buf), outfile);

      //if(y==5 && c==5) printf("(%d,%d) v:a=%u b=%u\n",c,y,a_buf,b_buf);
    }
  }
  fclose(outfile);

  delete[] mask_y_a;
  delete[] mask_y_b;
  delete[] mask_u_a;
  delete[] mask_u_b;
  delete[] mask_v_a;
  delete[] mask_v_b;

  return 0;
}

