/*******************************************************************************
 * Copyright (C) 2018 OTK Software
 * 
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 ******************************************************************************/
package com.otk.application.image.filter;

import java.awt.image.BufferedImage;
import java.util.Arrays;

import com.otk.application.util.ImageUtils;

/**
 * Multi-Scale Retinex with Color Restoration (MSRCR) filter.
 * <p>
 * The public fields are the user-facing settings; they are copied into a
 * single shared {@link RetinexAlgorithm} instance on every call to
 * {@link #doExecute(FilteringContext)}. Because that instance is mutated per
 * call, one {@code Retinex} object should not be executed concurrently from
 * several threads.
 */
public class Retinex extends AbstractFilter {

	/**
	 * Accepted values for {@link #level}. The index of each entry is used as the
	 * scale-distribution mode, so the order must match
	 * {@code RETINEX_UNIFORM}, {@code RETINEX_LOW}, {@code RETINEX_HIGH}.
	 */
	private static final String[] POSSIBLE_LEVELS = new String[] { "Uniform", "Low", "High" };

	/** Scale distribution mode; one of {@link #POSSIBLE_LEVELS}. Unknown values behave like "Uniform". */
	public String level = "Uniform";
	/** Largest filter scale; each scale is turned into a box-filter radius of {@code scale * 2.5}. */
	public int scale = 240;
	/** Number of filter scales combined by the multi-scale retinex. */
	public int scaleDivision = 3;
	/**
	 * Dynamic range control in percent; mapped to {@code cvar = 4 * dynamic / 100}. The output
	 * range is {@code mean +/- cvar * stddev} of the retinex values.
	 */
	public int dynamic = 50;
	/** Blend percentage between the filtered (100) and original (0) image. */
	public int amount = 100;

	private RetinexAlgorithm algo = new RetinexAlgorithm();

	/**
	 * Applies the retinex filter to the context image and, when {@link #amount} is
	 * not 100, blends the result with the original image.
	 *
	 * @param context source of the input image
	 * @return a context carrying the filtered image
	 */
	@Override
	public FilteringContext doExecute(FilteringContext context) {
		// -1 for an unknown level; RetinexAlgorithm treats unknown modes as uniform.
		algo.scales_mode = Arrays.asList(POSSIBLE_LEVELS).indexOf(level);
		algo.scale = scale;
		algo.nscales = scaleDivision;
		algo.cvar = (float) (4.0 * dynamic / 100.0);
		BufferedImage resultImage = algo.execute(context.getImage());
		if (amount != 100) {
			resultImage = ImageUtils.blendImages(resultImage, amount, context.getImage(), 100 - amount);
		}
		return context.withImage(resultImage);
	}

	/*
	 * Based on Jimenez-Hernandez Francisco <jimenezf@fi.uaemex.mx> Retinex
	 * implementation Using ImageJ.
	 * 
	 * Based on: MSRCR Retinex (Multi-Scale Retinex with Color Restoration) 2003
	 * Fabien Pelisson <Fabien.Pelisson@inrialpes.fr> Retinex GIMP plug-in
	 * 
	 */

	/**
	 * MSRCR implementation operating on interleaved B,G,R integer buffers.
	 * Static because it never uses the enclosing {@link Retinex} instance.
	 */
	private static class RetinexAlgorithm {
		private static final int RETINEX_UNIFORM = 0, RETINEX_LOW = 1, RETINEX_HIGH = 2;
		/** Initial capacity of the scale buffer; it grows if more scales are requested. */
		private static final int MAX_RETINEX_SCALES = 8;

		/* Global vars */
		public int alpha = 128, offset = 0;
		public int scale = 0, nscales = 0, scales_mode = 0;
		public float cvar = 0f, gain = 1f;

		private float[] RetinexScales = new float[MAX_RETINEX_SCALES];
		/** Results of {@link #compute_mean_var}: {@code var} holds the standard deviation. */
		private float mean = 0f, var = 0f;

		private BoxFilter gaussian = new BoxFilter();

		/**
		 * Runs MSRCR on an image.
		 *
		 * @param input source image; its pixels are read via {@code FastestRGBAccess.get}
		 * @return a new image of the same size whose color bytes are the MSRCR result
		 *         and whose top (alpha) byte is copied from the input pixel
		 */
		public BufferedImage execute(BufferedImage input) {
			final int width = input.getWidth();
			final int height = input.getHeight();
			final int pixelCount = width * height;

			// Unpack each packed pixel into three consecutive ints: B, G, R.
			final int[] inputPixels = FastestRGBAccess.get(input);
			final int[] msrcrInput = new int[3 * pixelCount];
			for (int p = 0, k = 0; p < pixelCount; p++, k += 3) {
				final int rgb = inputPixels[p];
				msrcrInput[k] = rgb & 0xff; // B
				msrcrInput[k + 1] = (rgb >> 8) & 0xff; // G
				msrcrInput[k + 2] = (rgb >> 16) & 0xff; // R
			}

			final int[] msrcrOutput = MSRCR(msrcrInput, width, height, 3);

			// Repack the B, G, R results, keeping the top byte of each input pixel.
			final int[] outputPixels = new int[pixelCount];
			for (int p = 0, k = 0; p < pixelCount; p++, k += 3) {
				outputPixels[p] = (inputPixels[p] & 0xff000000) | ((msrcrOutput[k + 2] & 0xff) << 16)
						| ((msrcrOutput[k + 1] & 0xff) << 8) | (msrcrOutput[k] & 0xff);
			}
			BufferedImage output = new BufferedImage(width, height, ImageUtils.getAdaptedBufferedImageType());
			FastestRGBAccess.set(outputPixels, output);
			return output;
		}

		/** Clamps {@code val} to the closed range [{@code minv}, {@code maxv}]. */
		private float clip(float val, int minv, int maxv) {
			return ((val = (val < minv ? minv : val)) > maxv ? maxv : val);
		}

		/**
		 * MSRCR = MultiScale Retinex with Color Restoration.
		 *
		 * @param src    interleaved samples, {@code bytes} entries per pixel; the first
		 *               three entries of each pixel are processed (values 0..255 as
		 *               produced by {@link #execute})
		 * @param width  image width in pixels
		 * @param height image height in pixels
		 * @param bytes  entries per pixel in {@code src} (3 when called from execute)
		 * @return a new array with the same layout as {@code src}, values in [0, 255]
		 */
		private int[] MSRCR(int[] src, int width, int height, int bytes) {
			final int size = width * height * bytes;
			final int channelsize = width * height;
			final float[] dst = new float[size];
			final float[][] blurred = new float[3][channelsize];

			/*
			 * Calculate the scales of filtering according to the number of filter and their
			 * distribution.
			 */
			RetinexScales = retinex_scales_distribution(RetinexScales, this.nscales, this.scales_mode, this.scale);

			/*
			 * Filtering according to the various scales. Summarize the results of the
			 * various filters according to a specific weight (here equal for all).
			 */
			final float weight = 1f / (float) nscales;
			/*
			 * The original plug-in used a recursive Gaussian filter; this port replaces it
			 * with BoxFilter because the recursive filter results drifted to the right of
			 * the image, possibly a bug.
			 */
			for (int channel = 0; channel < 3; channel++) {
				// Shift to 1..255 so the logarithm below is defined; BoxFilter's lookup
				// table only covers values up to 255.
				for (int i = 0, pos = channel; i < channelsize; i++, pos += bytes) {
					blurred[channel][i] = clip(src[pos] + 1.0f, 1, 255);
				}

				for (int scaleIndex = 0; scaleIndex < nscales; scaleIndex++) {
					// A negative radius would make BoxFilter allocate a negative-size table.
					int radius = Math.max(0, (int) (RetinexScales[scaleIndex] * 2.5f));
					// NOTE(review): the buffer is overwritten, so each scale blurs the result of
					// the previous one (cumulative smoothing). Kept as-is to preserve output.
					blurred[channel] = gaussian.filter(blurred[channel], width, height, radius);

					/*
					 * Summarize the filtered values. In fact one calculates a ratio between the
					 * original values and the filtered values.
					 */
					for (int i = 0, pos = channel; i < channelsize; i++, pos += bytes)
						dst[pos] += weight * (float) (Math.log(src[pos] + 1f) - Math.log(blurred[channel][i]));
				}
			}

			/*
			 * Final calculation with original value and cumulated filter values (color
			 * restoration). The parameters gain, alpha and offset are fields.
			 * Ci(x,y)=log[a Ii(x,y)]-log[ Ei=1-s Ii(x,y)]
			 */
			for (int i = 0; i < size; i += bytes) {
				final float logl = (float) Math.log((float) src[i] + (float) src[i + 1] + (float) src[i + 2] + 3f);
				for (int j = 0; j < 3; j++) {
					dst[i + j] = gain * ((float) (Math.log(alpha * (src[i + j] + 1.0f)) - logl) * dst[i + j]) + offset;
				}
			}

			// Map [mean - cvar*stddev, mean + cvar*stddev] linearly onto [0, 255].
			compute_mean_var(dst, size, bytes);
			final float mini = mean - cvar * var;
			final float maxi = mean + cvar * var;
			float range = maxi - mini;
			if (range == 0)
				range = 1f;

			final int[] result = new int[size];
			for (int i = 0; i < size; i += bytes) {
				for (int j = 0; j < 3; j++) {
					result[i + j] = (int) clip(255f * (dst[i + j] - mini) / range, 0, 255);
				}
			}
			return result;
		}

		/**
		 * Computes the mean and standard deviation of the first three entries of every
		 * pixel in {@code src}, storing them in {@link #mean} and {@link #var}.
		 * Sums are accumulated in double to avoid float precision loss on large
		 * images, and the variance is clamped at 0 so rounding cannot make
		 * {@code Math.sqrt} return NaN (which previously turned the output black).
		 */
		private void compute_mean_var(float[] src, int size, int bytes) {
			if (size <= 0) {
				mean = 0f;
				var = 0f;
				return;
			}
			double sum = 0.0;
			double sumOfSquares = 0.0;
			for (int i = 0; i < size; i += bytes) {
				for (int j = 0; j < 3; j++) {
					final double v = src[i + j];
					sum += v;
					sumOfSquares += v * v;
				}
			}
			final double m = sum / size;
			final double variance = sumOfSquares / size - m * m;
			mean = (float) m;
			var = (float) Math.sqrt(Math.max(0.0, variance));
		}

		/**
		 * Calculates scale values for the desired distribution.
		 *
		 * @param scales  buffer to fill; a larger one is allocated and returned when
		 *                it cannot hold {@code nscales} entries
		 * @param nscales number of scales
		 * @param mode    RETINEX_UNIFORM, RETINEX_LOW or RETINEX_HIGH; any other value
		 *                is treated as RETINEX_UNIFORM
		 * @param s       largest scale
		 * @return the filled buffer (possibly a new array)
		 */
		private float[] retinex_scales_distribution(float[] scales, int nscales, int mode, int s) {
			if (scales.length < nscales) {
				scales = new float[nscales];
			}
			if (nscales == 1) {
				scales[0] = (float) s / 2f;
			} else if (nscales == 2) {
				scales[0] = (float) s / 2f;
				scales[1] = (float) s;
			} else {
				float size_step = (float) s / (float) nscales;
				switch (mode) {
				case RETINEX_LOW:
					size_step = (float) Math.log(s - 2f) / (float) nscales;
					for (int i = 0; i < nscales; ++i)
						scales[i] = 2f + (float) Math.pow(10, (i * size_step) / Math.log(10));
					break;

				case RETINEX_HIGH:
					size_step = (float) Math.log(s - 2f) / (float) nscales;
					for (int i = 0; i < nscales; ++i)
						scales[i] = s - (float) Math.pow(10, (i * size_step) / Math.log(10));
					break;

				case RETINEX_UNIFORM:
				default:
					// Unknown modes fall back to uniform spacing instead of leaving stale values.
					for (int i = 0; i < nscales; ++i)
						scales[i] = 2f + (float) i * size_step;
					break;
				}
			}
			return scales;
		}

		/**
		 * Separable box (moving-average) filter on a single-channel float image:
		 * a horizontal pass into an int buffer followed by a vertical pass.
		 * Sample indices are clamped to the image bounds at the edges. Window means
		 * come from a lookup table of {@code round(sum / windowLength)}, so input values
		 * are expected to lie in [0, 255]; larger values would index past the table.
		 */
		private static class BoxFilter {

			/** Lookup table: index = window sum, value = rounded window mean. */
			private int[] precomputedOutputValues;
			/** Radius the table was built for; -1 before the first build. */
			private int precomputedRadius = -1;

			private boolean isPrecomputingRequired(int radius) {
				return precomputedRadius != radius;
			}

			/** Builds the mean lookup table for a window of {@code 2 * radius + 1} samples. */
			private void precompute(int radius) {
				int slidingRangeLenth = radius + radius + 1;
				precomputedOutputValues = new int[256 * slidingRangeLenth];
				for (int sum = 0; sum < 256 * slidingRangeLenth; sum++) {
					precomputedOutputValues[sum] = getPrecomputedOutputValue(sum, slidingRangeLenth);
				}
				precomputedRadius = radius;
			}

			private int getPrecomputedOutputValue(int sum, int count) {
				return Math.round(((float) sum) / ((float) count));
			}

			/**
			 * Box-filters {@code pixels} (row-major, {@code w * h} entries) with the given
			 * radius. The input array is not modified.
			 *
			 * @return a new array containing the filtered values
			 */
			private float[] filter(float[] pixels, int w, int h, int radius) {

				if (isPrecomputingRequired(radius)) {
					precompute(radius);
				}

				// Work on a copy so the caller's buffer stays intact.
				float[] newPixels = new float[pixels.length];
				System.arraycopy(pixels, 0, newPixels, 0, pixels.length);
				pixels = newPixels;

				int xMax = w - 1;
				int yMax = h - 1;
				int pixelCount = w * h;
				int componentCount = 1;
				int components[][] = new int[componentCount][pixelCount];
				int x, y, firstRangePixelIndex, lastRangePixelIndex, i, yp, yi, yw;
				float pixel, firstRangePixelValue, lastRangePixelValue;
				float componentSums[] = new float[componentCount];
				int vmin[] = new int[Math.max(w, h)];
				int vmax[] = new int[Math.max(w, h)];

				yw = yi = 0;

				// Horizontal pass: sliding window sum along each row into components.
				for (y = 0; y < h; y++) {
					for (int c = 0; c < componentCount; c++) {
						componentSums[c] = 0;
					}
					for (i = -radius; i <= radius; i++) {
						pixel = pixels[yi + Math.min(xMax, Math.max(i, 0))];
						for (int c = 0; c < componentCount; c++) {
							componentSums[c] += getComponentValueFromPixel(pixel, c);
						}
					}
					for (x = 0; x < w; x++) {

						for (int c = 0; c < componentCount; c++) {
							components[c][yi] += getPrecomputedOutputValue(componentSums[c]);
						}

						if (y == 0) {
							vmin[x] = Math.min(x + radius + 1, xMax);
							vmax[x] = Math.max(x - radius, 0);
						}
						firstRangePixelValue = pixels[yw + vmin[x]];
						lastRangePixelValue = pixels[yw + vmax[x]];

						for (int c = 0; c < componentCount; c++) {
							componentSums[c] += getComponentValueFromPixel(firstRangePixelValue, c);
							componentSums[c] -= getComponentValueFromPixel(lastRangePixelValue, c);
						}
						yi++;
					}
					yw += w;
				}

				// Vertical pass: sliding window sum down each column of components into pixels.
				for (x = 0; x < w; x++) {
					for (int c = 0; c < componentCount; c++) {
						componentSums[c] = 0;
					}
					yp = -radius * w;
					for (i = -radius; i <= radius; i++) {
						yi = Math.max(0, yp) + x;
						for (int c = 0; c < componentCount; c++) {
							componentSums[c] += components[c][yi];
						}
						yp += w;
					}
					yi = x;
					for (y = 0; y < h; y++) {
						for (int c = 0; c < componentCount; c++) {
							pixels[yi] = getPixelWithNewComponentValue(pixels[yi], c,
									getPrecomputedOutputValue(componentSums[c]));
						}
						if (x == 0) {
							vmin[y] = Math.min(y + radius + 1, yMax) * w;
							vmax[y] = Math.max(y - radius, 0) * w;
						}
						firstRangePixelIndex = x + vmin[y];
						lastRangePixelIndex = x + vmax[y];

						for (int c = 0; c < componentCount; c++) {
							componentSums[c] += components[c][firstRangePixelIndex];
							componentSums[c] -= components[c][lastRangePixelIndex];
						}

						yi += w;
					}
				}

				return pixels;
			}

			/**
			 * Looks up the window mean for a (possibly fractional) sum by averaging the
			 * table entries at its floor and ceiling.
			 */
			private float getPrecomputedOutputValue(float key) {
				int keyFloor = (int) Math.floor(key);
				int keyCeil = (int) Math.ceil(key);
				return (precomputedOutputValues[keyFloor] + precomputedOutputValues[keyCeil]) / 2.0f;
			}

			/** Single-component image: the pixel value is the component value. */
			private float getComponentValueFromPixel(float pixel, int c) {
				return pixel;
			}

			/** Single-component image: the new component value replaces the pixel. */
			private float getPixelWithNewComponentValue(float pixel, int c, float value) {
				return value;
			}
		}

	}

}
