/*****************************************************************************
 * ass2shirai.cpp
 ***************************************************************************** 
 *
 * Description: Assignment 2 (COMPSCI775)
 *
 * Topic: Correspondence Analysis for Binocular Stereo - Shirai's Algorithm
 *
 * (c) 2002 Christian Graf, Uli Schroeder, YongTao Zou (Group 7)
 *
 *****************************************************************************
 *
 * $Revision: 1.0 $
 *
 */

#include <math.h>
#include <iostream.h>
#include "HalconCpp.h"
#include "String.h"

class SHIRAI {

  private:

    double MEAN(int **leftImage, int k, int n, int xLeft, int y) {
		// k      : influences window size (1<=k<=5)
		// n      : window size (n=2*k+1)
		// xLeft  : x of p
		// y      : y of p
		double result = 0;
		for (int i = -k; i<=k; i++) {
			for (int j = -k; j<=k; j++) {
				result += leftImage[xLeft+i][y+j];
			}
		}
		result /= (n*n);
		return result;
	}

	double VARIANCE(int **leftImage, int k, int n, int xLeft, int y) {
		// k      : influences window size (1<=k<=5)
		// n      : window size (n=2*k+1)
		// xLeft  : x of p
		// y      : y of p
		double result = 0;
		double mean = 0;
		mean = MEAN(leftImage, k, n, xLeft, y);
		for (int i=-k; i<=k; i++) {
			for (int j=-k; j<=k; j++) {
				result += (leftImage[xLeft+i][y+j]*leftImage[xLeft+i][y+j]) - (mean*mean);
			}
		}
		result /= (n*n);
		return result;
	}

	double SE(int **leftImage, int **rightImage, int k, int xLeft, int y, int xRight) {
		// k      : influences window size (1<=k<=5)
		// xLeft  : x of p
		// xRight : x of q
		// y      : y of p
		int result = 0;
		for (int i=-k; i<=k; i++) {
			for (int j=-k; j<=k; j++) {
				result += (int)((leftImage[xLeft+i][y+j] - rightImage[xRight+i][y+j])*(leftImage[xLeft+i][y+j] - rightImage[xRight+i][y+j]));
			}
		}
		return result;
	}

/*
	double MSE(int **leftImage, int **rightImage, int k, int n, int xLeft, int y, int xRight) {
		// k      : influences window size (1<=k<=5)
		// n      : window size (n=2*k+1)
		// xLeft  : x of p
		// xRight : x of q
		// y      : y of p
		double result = 0;
		for (int i=-k; i<=k; i++) {
			for (int j=-k; j<=k; j++) {
				result += (leftImage[xLeft+i][y+j] - rightImage[xRight+i][y+j])*(leftImage[xLeft+i][y+j] - rightImage[xRight+i][y+j]);
			}
		}
		result /= (n*n);
		return result;
	}
*/

  public:

    SHIRAI(HByteImage *leftImage, HByteImage *rightImage, HByteImage *edgeImage, HByteImage *outputImage, double d1, double d2, double d3, double maxInterval) {
		double *similarityValues = NULL;
		int n, k, leftBorder, rightBorder;
		int belowD1Count;
		int aboveD2Count;
		int minPos;
		int maximumDisparity = 0;
		int imageWidth = (*leftImage).Width();
		int imageHeight = (*leftImage).Height();
		bool continueLoop;
		double windowVariance;
		// Allocate memory for outBuffer
		int **outBuffer = NULL;
		outBuffer = (int **) new int*[imageHeight];
		for (int i = 0; i < imageHeight; i++) {
			outBuffer[i] = (int *) new int[imageWidth];
		}
		// Create two-dimensional array and store left image in it because of access speed
		int **leftInBuffer = NULL;
		leftInBuffer = (int **) new int*[imageHeight];
		for (int i = 0; i < imageHeight; i++) {
			leftInBuffer[i] = (int *) new int[imageWidth];
		}
		for (int yPos = 0; yPos < imageHeight; yPos++) {
			for (int xPos = 0; xPos < imageWidth; xPos++) {
				leftInBuffer[xPos][yPos] = (*leftImage)(xPos, yPos);
			}
		}		
		// Create two-dimensional array and store right image in it because of access speed
		int **rightInBuffer = NULL;
		rightInBuffer = (int **) new int*[imageHeight];
		for (int i = 0; i < imageHeight; i++) {
			rightInBuffer[i] = (int *) new int[imageWidth];
		}
		for (int yPos = 0; yPos < imageHeight; yPos++) {
			for (int xPos = 0; xPos < imageWidth; xPos++) {
				rightInBuffer[xPos][yPos] = (*rightImage)(xPos, yPos);
			}
		}		
		// Shirai algorithm
		for (int yPos = 1; yPos < imageHeight-1; yPos++) {
			for (int xPos = 1; xPos < imageWidth-1; xPos++) {
				outBuffer[xPos][yPos] = 0;										// make sure standard colour for pixel is black
				if ((*edgeImage)(xPos, yPos) == 0) {							// begin: check if pixel is an edge pixel
					k = 1;														// initialise k
					rightBorder = xPos;											// initialise right interval border
					leftBorder = rightBorder-(int)(imageWidth*maxInterval);		// initialise left interval border
					if (leftBorder < 1) {										// begin: check if left border is valid
						leftBorder = 1;											// correct leftBorder
					}															// end: check if left border is valid
					continueLoop = true;										// enable loop to start
					while (continueLoop) {										// begin: loop
						n = 2*k+1;												// set window size
						belowD1Count = 0;										// will hold number of values below d1
						aboveD2Count = 0;										// will hold number of values above d2
						minPos = -1;											// will hold x value at which we habe minimum disparity (unique if belowD1Count = 1)
						similarityValues = (double *) new double[xPos+1];		// allocate memory to store similarity values for current search interval
						windowVariance = VARIANCE(leftInBuffer, k, n, xPos, yPos); // compute VARIANCE for this window size
						if (windowVariance != 0) {								// begin: check if windowVariance is valid
							for (int i = rightBorder; i >= leftBorder; i--) {	// begin: search the interval from rightBorder to leftBorder
																				// measure similarity of leftImage(xPos, yPos) with rightImage(i, yPos)
								similarityValues[i] = SE(leftInBuffer, rightInBuffer, k, xPos, yPos, i) / windowVariance;
//								similarityValues[i] = MSE(leftImage, rightImage, k, n, xPos, yPos, i) / windowVariance;
																				// begin: check if similarity is below d1?
								if ((similarityValues[i] < d1) && (similarityValues[i] >= 0)) {
									belowD1Count++;								// increase belowD1Count
									minPos = i;									// store x position of minimum
								}												// end: check if similarity is below d1?
																				// begin: check if similarity is above d2?
								if ((similarityValues[i] > d2) || (similarityValues[i] < 0)) {
									aboveD2Count++;								// increase aboveD2Count
								}												// end: check if similarity is above d2?
							}													// end: search the interval from rightBorder to leftBorder
						}														// end: check if windowVariance is valid
						if (belowD1Count == 1) {								// check if there exists a unique minimum smaller than d1
							outBuffer[xPos][yPos] = xPos - minPos;				// set disparity for point p
							if (outBuffer[xPos][yPos] > maximumDisparity) {		// begin: check for maximum Disparity
								maximumDisparity = outBuffer[xPos][yPos];		// save maximumDisparity for later scaling to 0-255
							}													// end: check for maximum disparity
							continueLoop = false;								// stop the loop from running again
						} else {
							if (aboveD2Count == rightBorder-leftBorder+1) {		// check if all similarity values are greater than d2
								outBuffer[xPos][yPos] = 0;						// calculation of disparity impossible => 0
								continueLoop = false;							// stop the loop from running again
							} else {
																				// window has maximum size or would be out of bounds in the next step
								if (((int)floor(n/2)+2 >= xPos) || (xPos+(int)ceil(n/2)+2 >= imageWidth) ||
					      		    ((int)floor(n/2)+2 >= yPos) || (yPos+(int)ceil(n/2)+2 >= imageHeight) ||
					      		    (k==5)) {
									outBuffer[xPos][yPos] = 0;					// calculation of disparity impossible => 0
									continueLoop = false;						// stop the loop from running again
								} else {										// preprocessing for next try
									k++;										// increase k
																				// begin: check if leftBorder is valid and if interval can become smaller
									while (((similarityValues[leftBorder] > d3) || ((int)floor(n/2) >= leftBorder)) &&
										   (leftBorder < rightBorder)) {
										leftBorder++;							// move left border
									}											// end: check if leftBorder is valid and if interval can become smaller
																				// begin: check if rightBorder is valid and if interval can become smaller
									while ((similarityValues[rightBorder] > d3) && (leftBorder < rightBorder)) {
										rightBorder--;							// move right border
									}											// end: check if rightBorder is valid and if interval can become smaller
									if ((int)floor(n/2) >= leftBorder) {		// begin: check if leftBorder is still valid
										outBuffer[xPos][yPos] = 0;				// calculation of disparity impossible => 0
										continueLoop = false;					// stop the loop from running again
									}											// end: check if leftBorder is still valid
								}
							}
						}
						delete [] similarityValues;								// free memory allocated for similarity array
					}															// end: loop
				} 																// end: check if pixel is an edge pixel
			}
		}
		if (maximumDisparity != 0) {
			// Scale values to the range 0-255
			double factor = 255/maximumDisparity;
			for (int y=0; y<imageHeight; y++) {
				for (int x=0; x<imageWidth; x++) {
					outBuffer[x][y]=(int)floor((outBuffer[x][y]*factor)+0.5);
				}
			}
			// Copy picture buffer to output picture
			for (int y=1; y<(imageHeight-1); y++) {
				for (int x=1; x<(imageWidth-1); x++) {
					(*outputImage)(x, y) = outBuffer[x][y];
				}
			}
		}
		// Free buffers reserved during processing
		for (int i = 0; i<imageHeight; i++) {
			delete [] outBuffer[i];
		}
		delete [] outBuffer;
		for (int i = 0; i<imageHeight; i++) {
			delete [] leftInBuffer[i];
		}
		delete [] leftInBuffer;
		for (int i = 0; i<imageHeight; i++) {
			delete [] rightInBuffer[i];
		}
		delete [] rightInBuffer;
	}
	
};

main(int argc, char* argv[]) {
	cout << "----------------------------------------------" << endl;
	cout << " Assignment 2 (COMPSCI775):                   " << endl;
	cout << " Correspondence Analysis for Binocular Stereo " << endl;
	cout << " (Shirai's Algorithm)                         " << endl;
	cout << "----------------------------------------------" << endl;
	cout << " Group 7:                                     " << endl;
	cout << " Christian Graf                               " << endl;
	cout << " Uli Schroeder                                " << endl;
	cout << " YongTao Zou                                  " << endl;
	cout << "----------------------------------------------" << endl;
	if (argc != 8) {
		cout << " Wrong number of arguments!" << endl << endl;
		cout << " ./ass2shirai [channels] [image] [outimage] [d1] [d2] [d3] [maxinterval]" << endl << endl;
		cout << "    [channels]    number of channels per picture" << endl << endl;
		cout << "    [image]       path and filename of the image (_l, _r_g, etc is added automatically!)" << endl;
		cout << "    [outimage]    path and filename of the output image (including .jpg)" << endl;
		cout << "    [d1]          shirai threshold d1" << endl;
		cout << "    [d2]          shirai threshold d2" << endl;
		cout << "    [d3]          shirai threshold d3" << endl;
		cout << "    [maxinterval] percentage of image width (0 <= maxinterval <= 1)" << endl << endl;
		cout << " Example:" << endl;
		cout << "   ./ass2shirai 1 pyrtasse result.jpg 10 30 30 0.3" << endl << endl;
		return(-1);
	}
	int myThreshold = 55;
	// Read input images
	String ext = ".bmp";
	HByteImage leftImage, rightImage, edgeImage, result;
	cout << " Processing images " << argv[2] << "..." << endl;
	switch (atoi(argv[1])) {
		case 1 : {
			leftImage = HByteImage(((String)argv[2])+"_l"+ext); 
			rightImage = HByteImage(((String)argv[2])+"_r"+ext); 
			result = HByteImage(leftImage.Width(),leftImage.Height());
			// Apply Median and Sobel
			edgeImage = leftImage.MedianImage("circle", 3, -1);
			edgeImage = edgeImage.SobelAmp("sum_abs", 3);
			// Convert edge image to black and white
			for (int y = 0; y < edgeImage.Height(); y++) {
				for (int x = 0; x < edgeImage.Width(); x++) {
					if (edgeImage(x, y) > myThreshold) {
						edgeImage(x, y) = 0;
					} else {
						edgeImage(x, y) = 255;
					}
				}
			}
			// Apply Shirai
			new SHIRAI(&leftImage, &rightImage, &edgeImage, &result, atof(argv[4]), atof(argv[5]), atof(argv[6]), atof(argv[7]));
			break;
		}
		case 3 : {
			HByteImage temp;
//Red channel
			cout << "   Red channel..." << endl;
			leftImage = HByteImage(((String)argv[2])+"_l"+"_r"+ext); 
			rightImage = HByteImage(((String)argv[2])+"_r"+"_r"+ext);
			result = HByteImage(leftImage.Width(),leftImage.Height());
			for (int y = 0; y < result.Height(); y++) {
				for (int x = 0; x < result.Width(); x++) {
					result(x,y) = 0;
				}
			}
			temp = HByteImage(leftImage.Width(),leftImage.Height());
			// Apply Median and Sobel
			edgeImage = leftImage.MedianImage("circle", 3, -1);
			edgeImage = edgeImage.SobelAmp("sum_abs", 3);
			// Convert edge image to black and white
			for (int y = 0; y < edgeImage.Height(); y++) {
				for (int x = 0; x < edgeImage.Width(); x++) {
					if (edgeImage(x, y) > myThreshold) {
						edgeImage(x, y) = 0;
					} else {
						edgeImage(x, y) = 255;
					}
				}
			}
			// Apply Shirai
			new SHIRAI(&leftImage, &rightImage, &edgeImage, &temp, atof(argv[4]), atof(argv[5]), atof(argv[6]), atof(argv[7]));
			// Maybe disparities are only found in one channel
			for (int y = 0; y < temp.Height(); y++) {
				for (int x = 0; x < temp.Width(); x++) {
					if (result(x, y) == 0) {
						 result(x, y) = temp(x, y);
					} else {
						 result(x, y) = (int)ceil((result(x, y) + temp(x, y))/2);
					}
				}
			}
//Green channel
			cout << "   Green channel..." << endl;
			leftImage=HByteImage(((String)argv[2])+"_l"+"_g"+ext); 
			rightImage=HByteImage(((String)argv[2])+"_r"+"_g"+ext);
			result = HByteImage(leftImage.Width(),leftImage.Height());
			for (int y = 0; y < result.Height(); y++) {
				for (int x = 0; x < result.Width(); x++) {
					result(x,y) = 0;
				}
			}
			temp = HByteImage(leftImage.Width(),leftImage.Height());
			// Apply Median and Sobel
			edgeImage = leftImage.MedianImage("circle", 3, -1);
			edgeImage = edgeImage.SobelAmp("sum_abs", 3);
			// Convert edge image to black and white
			for (int y = 0; y < edgeImage.Height(); y++) {
				for (int x = 0; x < edgeImage.Width(); x++) {
					if (edgeImage(x, y) > myThreshold) {
						edgeImage(x, y) = 0;
					} else {
						edgeImage(x, y) = 255;
					}
				}
			}
			// Apply Shirai
			new SHIRAI(&leftImage, &rightImage, &edgeImage, &temp, atof(argv[4]), atof(argv[5]), atof(argv[6]), atof(argv[7]));
			// Maybe disparities are only found in one channel
			for (int y = 0; y < temp.Height(); y++) {
				for (int x = 0; x < temp.Width(); x++) {
					if (result(x, y) == 0) {
						 result(x, y) = temp(x, y);
					} else {
						 result(x, y) = (int)ceil((result(x, y) + temp(x, y))/2);
					}
				}
			}
//Blue channel
			cout << "   Blue channel..." << endl;
			leftImage=HByteImage(((String)argv[2])+"_l"+"_b"+ext); 
			rightImage=HByteImage(((String)argv[2])+"_r"+"_b"+ext);
			result = HByteImage(leftImage.Width(),leftImage.Height());
			for (int y = 0; y < result.Height(); y++) {
				for (int x = 0; x < result.Width(); x++) {
					result(x,y) = 0;
				}
			}
			temp = HByteImage(leftImage.Width(),leftImage.Height());
			// Apply Median and Sobel
			edgeImage = leftImage.MedianImage("circle", 3, -1);
			edgeImage = edgeImage.SobelAmp("sum_abs", 3);
			// Convert edge image to black and white
			for (int y = 0; y < edgeImage.Height(); y++) {
				for (int x = 0; x < edgeImage.Width(); x++) {
					if (edgeImage(x, y) > myThreshold) {
						edgeImage(x, y) = 0;
					} else {
						edgeImage(x, y) = 255;
					}
				}
			}
			// Apply Shirai
			new SHIRAI(&leftImage, &rightImage, &edgeImage, &temp, atof(argv[4]), atof(argv[5]), atof(argv[6]), atof(argv[7]));
			// Maybe disparities are only found in one channel
			for (int y = 0; y < temp.Height(); y++) {
				for (int x = 0; x < temp.Width(); x++) {
					if (result(x, y) == 0) {
						 result(x, y) = temp(x, y);
					} else {
						 result(x, y) = (int)ceil((result(x, y) + temp(x, y))/2);
					}
				}
			}
			break;
		}
	}
	cout << " Processing finished..." << endl;
	// Write result to file
	result.WriteImage("jpeg 50", 0xffffff, argv[3]);
/*
	// Display result
	HWindow img1;
	result.Display(img1);
	img1.Click();
*/
	cout << " The End!" << endl << endl << endl << endl;
	return(0);
}