/* LICENSE: */

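/* Takes a raw correction mask file (read without any header) and
   applies it to one or more PPM images from the robot's camera.
   Each corrected image is written to app_<input filename>, and the
   per-channel standard deviations before and after correction are
   printed for each image and averaged over all of them.
*/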

#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "../../agent/headers/Util.h"

#include "PPMImage.h"
#include "RGBTriplet.h"

typedef unsigned char uchar;

// Population standard deviation of `count` samples, given their running sum
// and running sum of squares. The variance is clamped at zero because
// sum_sq/count - mean^2 can round to a tiny negative value, and sqrt() of a
// negative number yields NaN. Returns 0 when count is not positive.
static double std_dev(double sum, double sum_sq, int count) {
  if (count <= 0) return 0.0;
  double mean = sum / count;
  double variance = sum_sq / count - mean * mean;
  return variance > 0.0 ? sqrt(variance) : 0.0;
}

// Reads the correction mask for a width x height image from `path`.
// The file is read raw (no header): its first width*height*3*2 bytes are
// laid out per image row as three channel planes of `width` entries each,
// and every entry is a (multiplier, offset) byte pair.
// Returns a new[]-allocated buffer that the caller must delete[]. Exits
// with status 2 if the file can't be opened or is too short.
static uchar *load_mask(const char *path, int width, int height) {
  size_t mask_bytes = (size_t)width * (size_t)height * 3 * 2;

  FILE *mask_file = fopen(path, "rb");
  if (!mask_file) {
    printf("Couldn't open correction mask file '%s'\n", path);
    exit(2);
  }

  uchar *mask = new uchar[mask_bytes];
  size_t got = fread(mask, sizeof(uchar), mask_bytes, mask_file);
  fclose(mask_file);

  if (got != mask_bytes) {
    printf("Error reading correction mask file '%s'\n", path);
    delete[] mask;
    exit(2);
  }
  return mask;
}

// Applies the correction mask in argv[1] to every image in argv[2..].
// Each corrected image is saved as "app_<input name>". For each image the
// per-channel standard deviations before and after correction are printed,
// followed by their averages over all processed images.
// The mask is sized from the first image. Later images with different
// dimensions are skipped, because indexing the mask with them would read
// out of bounds.
int main(int argc, char **argv) {

  if (argc < 3) {
    printf("Usage: program <correction mask> <image_to_correct> [<image_to_correct> ...]\n");
    // Without this, the averaging below would divide by argc-2 <= 0.
    return 1;
  }

  uchar *mask = NULL;
  int mask_width = 0;
  int mask_height = 0;

  // Shift <pixel value> * <mask> right this many bits to convert back
  // to an integer value after approximating a floating point multiply
  // with integers. 2^mask_shift = 1.0 in the mask.
  const int mask_shift = 7;

  double avg_dev_y = 0.0, avg_dev_u = 0.0, avg_dev_v = 0.0;
  double cor_avg_dev_y = 0.0, cor_avg_dev_u = 0.0, cor_avg_dev_v = 0.0;
  int num_images = 0;

  for (int i = 2; i < argc; i++) {
    PPMImage *src = new PPMImage(argv[i]);

    int width  = src->getWidth();
    int height = src->getHeight();

    if (mask == NULL) {
      mask = load_mask(argv[1], width, height);
      mask_width  = width;
      mask_height = height;
    } else if (width != mask_width || height != mask_height) {
      printf("Image '%s' is %dx%d but the correction mask is %dx%d; skipping\n",
             argv[i], width, height, mask_width, mask_height);
      delete src;
      continue;
    }

    PPMImage *dst = new PPMImage(width, height);

    double sum_y = 0.0, sum_u = 0.0, sum_v = 0.0;
    double sum_yy = 0.0, sum_uu = 0.0, sum_vv = 0.0;
    double cor_sum_y = 0.0, cor_sum_u = 0.0, cor_sum_v = 0.0;
    double cor_sum_yy = 0.0, cor_sum_uu = 0.0, cor_sum_vv = 0.0;

    for (int y = 0; y < height; y++) {
      // Start of this row in the mask. Each row holds 3 channel planes of
      // `width` entries (indices below are entries, not bytes).
      size_t row = (size_t)y * width * 3;

      for (int x = 0; x < width; x++) {
        int py = src->get(x, y).red;
        int pu = src->get(x, y).green;
        int pv = src->get(x, y).blue;

        sum_y += py;
        sum_u += pu;
        sum_v += pv;

        sum_yy += py * py;
        sum_uu += pu * pu;
        sum_vv += pv * pv;

        // Byte offset of this pixel's (multiplier, offset) pair in each
        // channel plane.
        size_t iy = (row + (size_t)width * 0 + x) * 2;
        size_t iu = (row + (size_t)width * 1 + x) * 2;
        size_t iv = (row + (size_t)width * 2 + x) * 2;

        // Fixed-point multiply, then divide by 2^mask_shift (128) to get
        // back to the pixel's scale.
        py = (py * mask[iy]) >> mask_shift;
        pu = (pu * mask[iu]) >> mask_shift;
        pv = (pv * mask[iv]) >> mask_shift;

        // Now add the additive term, which is also encoded.
        // - Y and U are signed bytes. Use `signed char` explicitly, because
        //   plain `char` is unsigned on some platforms and would lose the sign.
        // - V is stored unsigned but negated, so it ranges from -255 to 0.
        signed char y_b = (signed char)mask[iy + 1];
        signed char u_b = (signed char)mask[iu + 1];
        int         v_b = -mask[iv + 1];

        py += y_b;
        pu += u_b;
        pv += v_b;

        // Clamp back into the range of a byte.
        if (py < 0) py = 0;
        if (py > 255) py = 255;

        if (pu < 0) pu = 0;
        if (pu > 255) pu = 255;

        if (pv < 0) pv = 0;
        if (pv > 255) pv = 255;

        cor_sum_y += py;
        cor_sum_u += pu;
        cor_sum_v += pv;

        cor_sum_yy += py * py;
        cor_sum_uu += pu * pu;
        cor_sum_vv += pv * pv;

        dst->set(x, y, RGBTriplet((unsigned char)py,
                                  (unsigned char)pu,
                                  (unsigned char)pv));
      }
    }

    int count = width * height;

    double dev_y = std_dev(sum_y, sum_yy, count);
    double dev_u = std_dev(sum_u, sum_uu, count);
    double dev_v = std_dev(sum_v, sum_vv, count);

    double cor_dev_y = std_dev(cor_sum_y, cor_sum_yy, count);
    double cor_dev_u = std_dev(cor_sum_u, cor_sum_uu, count);
    double cor_dev_v = std_dev(cor_sum_v, cor_sum_vv, count);

    printf("before dev:y=%8.4f u=%8.4f v=%8.4f for image '%s'\n",dev_y,dev_u,dev_v,argv[i]);
    printf("after  dev:y=%8.4f u=%8.4f v=%8.4f for image '%s'\n",cor_dev_y,cor_dev_u,cor_dev_v,argv[i]);

    avg_dev_y += dev_y;
    avg_dev_u += dev_u;
    avg_dev_v += dev_v;

    cor_avg_dev_y += cor_dev_y;
    cor_avg_dev_u += cor_dev_u;
    cor_avg_dev_v += cor_dev_v;

    // Bounded formatting: strcpy/strcat would overflow on long paths.
    char buf[1024];
    int written = snprintf(buf, sizeof buf, "app_%s", argv[i]);
    if (written < 0 || (size_t)written >= sizeof buf) {
      printf("Output filename for image '%s' is too long; not saved\n", argv[i]);
    } else {
      dst->save(buf);
    }

    delete dst;
    delete src;
    num_images++;
  }

  // Average over images actually processed (skipped ones excluded). At
  // least the first image is always processed, so num_images >= 1 here.
  if (num_images > 0) {
    avg_dev_y /= num_images;
    avg_dev_u /= num_images;
    avg_dev_v /= num_images;

    cor_avg_dev_y /= num_images;
    cor_avg_dev_u /= num_images;
    cor_avg_dev_v /= num_images;
  }

  printf("before avg dev:y=%8.4f u=%8.4f v=%8.4f\n",avg_dev_y,avg_dev_u,avg_dev_v);
  printf("after  avg dev:y=%8.4f u=%8.4f v=%8.4f\n",cor_avg_dev_y,cor_avg_dev_u,cor_avg_dev_v);

  delete[] mask;

  return 0;
}
