import cornerstone from 'cornerstone-core';
import cornerstoneTools, { importInternal } from 'cornerstone-tools';
import _ from 'lodash';

import bigGaussianBlur from '../filters/bigGaussianBlur';
import { otsu } from '../filters/otsuThreshold';
import dilation from '../operations/dilation';
import findConcaveHull from '../operations/findConcaveHull';
import {
  math,
  transformPointsToPhysicalById,
} from '../../modules/dicom-measurement/src';

// Cornerstone 3rd party dev kit imports: `draw` is the internal
// 'drawing/draw' helper obtained through cornerstone-tools' `importInternal`.
const draw = importInternal('drawing/draw');
// Distance helpers taken from the `math` export of the dicom-measurement module.
const { getDistance, getDistance3D } = math;

/**
 * Threshold modes read from `config.type` by `sphereThresholdPainter`.
 *
 * - BINARY: the threshold becomes the lower bound (upper bound is Infinity).
 * - BINARY_INV: the threshold becomes the upper bound (lower bound is
 *   -Infinity).
 * - DEFAUlT: one of the two behaviours above is picked by comparing the
 *   foreground average with the threshold.
 *
 * NOTE(review): the last key is spelled `DEFAUlT` (lowercase "l").
 * `sphereThresholdPainter` reads it under exactly that spelling, so renaming
 * the key requires updating that reference at the same time.
 */
const THRESHOLD_TYPE = {
  BINARY: 'binary',
  BINARY_INV: 'binary-inv',
  DEFAUlT: 'default',
};

/**
 * Create a painter that thresholds the pixels inside a circle dragged by the
 * user and, through `commit3D`, inside the matching sphere cross-sections on
 * the neighbouring slices of the stack.
 *
 * @param {Object} toolState - Tool state for the current image; it is deep
 *   cloned and the clone is what `getState` returns.
 * @param {string} structureSetSeriesInstanceUid - Passed to the ROI contour
 *   lookup and forwarded to `runProcess` for every processed slice.
 * @param {Object} [config] - Optional settings; `config.type` is compared
 *   against the `THRESHOLD_TYPE` values and falls back to
 *   `THRESHOLD_TYPE.DEFAUlT` when missing.
 * @returns {Object|undefined} The painter
 *   (`getState`, `commit`, `commit3D`, `update`, `cursor`), or `undefined`
 *   when the selected ROI number is not a number.
 */
export function sphereThresholdPainter(
  toolState,
  structureSetSeriesInstanceUid,
  config
) {
  const editModule = cornerstoneTools.getModule('rtstruct-edit');
  const ROINumber = editModule.getters.selectedROINumber();
  // Check the selection before looking up the contour: destructuring the
  // lookup result for a missing ROI would throw instead of returning quietly.
  if (!_.isNumber(ROINumber)) return;

  const state = _.cloneDeep(toolState);
  const { colorArray } = editModule.getters.ROIContour(
    structureSetSeriesInstanceUid,
    ROINumber
  );
  const type = (config && config.type) || THRESHOLD_TYPE.DEFAUlT;

  // Drag end points in canvas coordinates, set by `update`.
  let initialPoint = null;
  let currentPoint = null;
  // Values computed by `commit` and reused by `commit3D`.
  let upper;
  let lower;
  let canvasCenter;
  let canvasRadius;
  let radius = 0;

  /**
   * Walk the stack from `startIndex` in direction `step` (+1 or -1) and apply
   * the committed threshold to each slice until one lies farther from the
   * committed slice than the sphere radius.
   */
  async function paintNeighbours(element, imageIds, startIndex, step, cipp) {
    for (
      let index = startIndex + step;
      index >= 0 && index <= imageIds.length - 1;
      index += step
    ) {
      const image = await cornerstone.loadImage(imageIds[index]);
      const { imageId } = image;
      const ipp = cornerstone.metaData.get('ImagePositionPatient', imageId);
      const distance = getDistance3D(
        { x: ipp[0], y: ipp[1], z: ipp[2] },
        { x: cipp[0], y: cipp[1], z: cipp[2] }
      );
      if (distance > radius) break;

      // Squared radius of the sphere cross-section on this slice, expressed
      // in canvas units.
      const referencedRadiusSquared =
        Math.pow(canvasRadius, 2) -
        Math.pow((distance * canvasRadius) / radius, 2);
      // A zero or NaN radius (slice at the pole of the sphere, rounding
      // error, or `radius` still 0) is falsy, which makes `sphereThreshold`
      // process the WHOLE image; skip such slices instead.
      if (!(referencedRadiusSquared > 0)) continue;

      const painter = editModule.setters.createPainter(
        'sphere-threshold',
        imageId,
        structureSetSeriesInstanceUid
      );
      runProcess(
        { ...image, pixelData: image.getPixelData() },
        structureSetSeriesInstanceUid,
        ROINumber,
        upper,
        lower,
        canvasCenter,
        Math.sqrt(referencedRadiusSquared),
        element
      );
      painter.commitCallback();
    }
  }

  return {
    /** @returns {Object} The deep clone of the tool state taken at creation. */
    getState: function() {
      return state;
    },
    /**
     * Threshold the circle spanned by the dragged points on the current image.
     *
     * @param {Object} evt - Event whose `detail` carries `image` and `element`.
     * @returns {boolean} `false` when nothing was dragged yet or the user
     *   declined to override existing contours, `true` otherwise.
     */
    commit: function(evt) {
      // Without both drag points there is no circle to work with.
      if (!initialPoint || !currentPoint) return false;

      /** check if ROI contours exist */
      if (state.data.some(x => x.ROINumber === ROINumber)) {
        const confirmed = window.confirm(
          'ROI contours exist. Are you sure to override the current ROI contours?'
        );
        if (!confirmed) return false;
      }

      const { image, element } = evt.detail;
      const { imageId } = image;
      const source = { ...image, pixelData: image.getPixelData() };

      /** canvas points to pixel points */
      const initialPixel = cornerstone.canvasToPixel(element, initialPoint);
      const currentPixel = cornerstone.canvasToPixel(element, currentPoint);
      // The drag defines the circle's diameter.
      canvasCenter = {
        x: (initialPoint.x + currentPoint.x) / 2,
        y: (initialPoint.y + currentPoint.y) / 2,
      };
      canvasRadius = getDistance(initialPoint, currentPoint) / 2;
      radius =
        getDistance3D(
          transformPointsToPhysicalById([initialPixel], imageId)[0],
          transformPointsToPhysicalById([currentPixel], imageId)[0]
        ) / 2;

      /** get otsu threshold on the current image */
      const threshold = otsu(
        getSample(source, canvasCenter, canvasRadius, element)
      );
      const foregroundAverage = getForegroundAverage(
        source,
        canvasCenter,
        canvasRadius,
        element
      );
      if (type === THRESHOLD_TYPE.BINARY) {
        upper = Infinity;
        lower = threshold;
      } else if (type === THRESHOLD_TYPE.BINARY_INV) {
        upper = threshold;
        lower = -Infinity;
      } else {
        // Keep the side of the threshold on which the foreground average lies.
        const keepAbove = foregroundAverage > threshold;
        upper = keepAbove ? Infinity : threshold;
        lower = keepAbove ? threshold : -Infinity;
      }

      /** binary threshold on the current image */
      runProcess(
        source,
        structureSetSeriesInstanceUid,
        ROINumber,
        upper,
        lower,
        canvasCenter,
        canvasRadius,
        element
      );

      return true;
    },
    /**
     * Apply the threshold computed by `commit` to the slices before and then
     * after the current one, shrinking the circle with the distance from the
     * committed slice.
     *
     * @param {Object} evt - Event whose `detail` carries `image` and `element`.
     * @returns {Promise<boolean>} Resolves to `true` once both directions
     *   have been walked.
     */
    commit3D: async function(evt) {
      /** binary threshold on the following images */
      const { element } = evt.detail;
      const stack = cornerstoneTools.getToolState(element, 'stack');
      const { currentImageIdIndex, imageIds } = stack.data[0];
      const cipp = cornerstone.metaData.get(
        'ImagePositionPatient',
        evt.detail.image.imageId
      );

      await paintNeighbours(element, imageIds, currentImageIdIndex, -1, cipp);
      await paintNeighbours(element, imageIds, currentImageIdIndex, 1, cipp);

      return true;
    },
    /**
     * Record the latest pointer position; the first call also fixes the
     * starting point of the drag.
     *
     * @param {Object} evt - Event whose `detail.currentPoints.canvas` is the
     *   pointer position.
     * @returns {boolean} Always `true`.
     */
    update: function(evt) {
      const newPoint = evt.detail.currentPoints.canvas;
      if (!initialPoint) initialPoint = newPoint;
      currentPoint = newPoint;

      return true;
    },
    /**
     * Draw the preview circle whose diameter is the current drag.
     *
     * @returns {boolean} `false` when not drawing or when no point has been
     *   recorded yet, `true` after drawing.
     */
    cursor: function(evt, context, cursorCanvasPosition, isDrawing) {
      if (!isDrawing || !initialPoint || !currentPoint) return false;
      draw(context, ctx => {
        const x = (initialPoint.x + currentPoint.x) / 2;
        const y = (initialPoint.y + currentPoint.y) / 2;
        const previewRadius =
          Math.hypot(
            initialPoint.x - currentPoint.x,
            initialPoint.y - currentPoint.y
          ) / 2;
        const rgb = colorArray.join(',');
        ctx.strokeStyle = `rgba(${rgb}, 1)`;
        ctx.fillStyle = `rgba(${rgb}, 0.1)`;

        const circle = new Path2D();
        circle.arc(x, y, previewRadius, 0, Math.PI * 2);
        ctx.stroke(circle);
        ctx.fill(circle);
      });
      return true;
    },
  };
}

/**
 * Threshold one image and hand the result to `processData` for the given ROI.
 *
 * @param {Object} image - Image object (with `pixelData`) to process.
 * @param {string} structureSetSeriesInstanceUid - Forwarded to `processData`.
 * @param {number} ROINumber - Forwarded to `processData`.
 * @param {number} upper - Upper threshold bound, forwarded to `processImage`.
 * @param {number} lower - Lower threshold bound, forwarded to `processImage`.
 * @param {{x: number, y: number}} center - Circle centre, forwarded to
 *   `processImage`.
 * @param {number} radius - Circle radius, forwarded to `processImage`.
 * @param {*} element - Forwarded to `processImage`.
 */
function runProcess(
  image,
  structureSetSeriesInstanceUid,
  ROINumber,
  upper,
  lower,
  center,
  radius,
  element
) {
  processData(
    processImage(image, upper, lower, center, radius, element),
    structureSetSeriesInstanceUid,
    ROINumber
  );
}

/**
 * Run an image through `bigGaussianBlur`, `sphereThreshold` and `dilation`,
 * in that order, and return what `dilation` produces.
 *
 * @param {Object} image - Image object handed to `bigGaussianBlur`.
 * @param {number} upper - Upper bound passed to `sphereThreshold`.
 * @param {number} lower - Lower bound passed to `sphereThreshold`.
 * @param {{x: number, y: number}} center - Circle centre passed to
 *   `sphereThreshold`.
 * @param {number} radius - Circle radius passed to `sphereThreshold`.
 * @param {*} element - Passed to `sphereThreshold`.
 * @returns {*} The result of `dilation`.
 */
function processImage(image, upper, lower, center, radius, element) {
  return dilation(
    sphereThreshold(
      bigGaussianBlur(image),
      upper,
      lower,
      center,
      radius,
      element
    )
  );
}

/**
 * Replace the stored entries of an ROI on one image with the concave hull of
 * a binary image.
 *
 * Existing entries with the same `ROINumber` are always removed; a new entry
 * is appended only when the hull has at least three points.
 *
 * @param {Object} image - Carries `pixelData`, `width`, `height`, `imageId`.
 * @param {string} structureSetSeriesInstanceUid - Stored on the new entry.
 * @param {number} ROINumber - ROI whose entries are replaced.
 */
function processData(image, structureSetSeriesInstanceUid, ROINumber) {
  const toolState = cornerstoneTools
    .getModule('rtstruct-edit')
    .getters.toolState(image.imageId);

  /** clear previous data and add new data */
  toolState.data = toolState.data.filter(
    entry => entry.ROINumber !== ROINumber
  );

  const hull = findConcaveHull({
    pixelData: image.pixelData,
    width: image.width,
    height: image.height,
  });
  // Fewer than three points cannot form a contour; leave the ROI cleared.
  if (hull.length < 3) return;

  toolState.data = toolState.data.concat([
    {
      ROINumber,
      handles: { points: hull },
      structureSetSeriesInstanceUid,
    },
  ]);
}

/**
 * Collect the values of the pixels lying around a canvas-space circle.
 *
 * The pixel range is taken from all four corners of the circle's bounding
 * square (so a flipped or rotated viewport still yields a non-empty range)
 * and is clamped to the image, so a circle reaching past the border never
 * reads `undefined` or wraps into a neighbouring row.
 *
 * @param {Object} image - Carries `pixelData`, `width` and optionally
 *   `height` (derived from the data length when absent).
 * @param {{x: number, y: number}} center - Circle centre in canvas coordinates.
 * @param {number} radius - Circle radius in canvas units.
 * @param {*} element - Element handed to the cornerstone coordinate helpers.
 * @param {function(number): boolean} acceptsDistance - Receives the canvas
 *   distance between a pixel and `center`; the pixel is kept when it returns
 *   true.
 * @returns {number[]} Selected pixel values, column by column.
 */
function collectPixelsAroundCenter(
  image,
  center,
  radius,
  element,
  acceptsDistance
) {
  const { pixelData, width, height } = image;
  const rowCount = Number.isFinite(height)
    ? height
    : Math.floor(pixelData.length / width);

  const corners = [
    { x: center.x - radius, y: center.y - radius },
    { x: center.x + radius, y: center.y - radius },
    { x: center.x - radius, y: center.y + radius },
    { x: center.x + radius, y: center.y + radius },
  ].map(corner => cornerstone.canvasToPixel(element, corner));
  const xs = corners.map(corner => Math.round(corner.x));
  const ys = corners.map(corner => Math.round(corner.y));
  const xStart = Math.max(0, Math.min(...xs));
  const xEnd = Math.min(width, Math.max(...xs));
  const yStart = Math.max(0, Math.min(...ys));
  const yEnd = Math.min(rowCount, Math.max(...ys));

  const values = [];
  for (let i = xStart; i < xEnd; i++) {
    for (let j = yStart; j < yEnd; j++) {
      const point = cornerstone.pixelToCanvas(element, { x: i, y: j });
      if (acceptsDistance(getDistance(center, point))) {
        values.push(pixelData[i + j * width]);
      }
    }
  }
  return values;
}

/**
 * Average the pixels strictly closer to `center` than half of `radius`.
 *
 * @param {Object} image - Carries `pixelData` and `width`.
 * @param {{x: number, y: number}} center - Circle centre in canvas coordinates.
 * @param {number} radius - Circle radius in canvas units.
 * @param {*} element - Element handed to the cornerstone coordinate helpers.
 * @returns {number} The mean value, or `NaN` when no pixel qualifies.
 */
function getForegroundAverage(image, center, radius, element) {
  const sample = collectPixelsAroundCenter(
    image,
    center,
    radius,
    element,
    distance => distance < radius / 2
  );
  const sum = sample.reduce((acc, cur) => acc + cur, 0);
  return sum / sample.length;
}

/**
 * Collect the pixels whose canvas distance to `center` is at most `radius`.
 *
 * @param {Object} image - Carries `pixelData` and `width`.
 * @param {{x: number, y: number}} center - Circle centre in canvas coordinates.
 * @param {number} radius - Circle radius in canvas units.
 * @param {*} element - Element handed to the cornerstone coordinate helpers.
 * @returns {number[]} The selected pixel values.
 */
function getSample(image, center, radius, element) {
  return collectPixelsAroundCenter(
    image,
    center,
    radius,
    element,
    distance => distance <= radius
  );
}

/**
 * Build a binary image marking the pixels whose value lies strictly between
 * `lower` and `upper`, optionally restricted to a canvas-space circle.
 *
 * When `center` or `radius` is missing (or `radius` is 0) the whole image is
 * processed and `element` is not used. Otherwise only pixels whose canvas
 * distance to `center` does not exceed `radius` are considered; the scanned
 * range is taken from all four corners of the circle's bounding square and
 * clamped to the image, so a circle reaching past the border cannot mark
 * pixels in a neighbouring row.
 *
 * @param {Object} image - Carries `pixelData`, `width` and `height`.
 * @param {number} upper - Exclusive upper bound.
 * @param {number} lower - Exclusive lower bound.
 * @param {{x: number, y: number}} [center] - Circle centre in canvas
 *   coordinates.
 * @param {number} [radius] - Circle radius in canvas units.
 * @param {*} [element] - Element handed to the cornerstone coordinate
 *   helpers; only needed together with `center` and `radius`.
 * @returns {Object} A copy of `image` whose `pixelData` is an `Int16Array`
 *   holding 1 for selected pixels and 0 elsewhere.
 */
export default function sphereThreshold(
  image,
  upper,
  lower,
  center,
  radius,
  element
) {
  const { pixelData, width, height } = image;
  const hasRegion = Boolean(center && radius);

  let xStart = 0;
  let xEnd = Math.round(width);
  let yStart = 0;
  let yEnd = Math.round(height);
  if (hasRegion) {
    const corners = [
      { x: center.x - radius, y: center.y - radius },
      { x: center.x + radius, y: center.y - radius },
      { x: center.x - radius, y: center.y + radius },
      { x: center.x + radius, y: center.y + radius },
    ].map(corner => cornerstone.canvasToPixel(element, corner));
    const xs = corners.map(corner => Math.round(corner.x));
    const ys = corners.map(corner => Math.round(corner.y));
    xStart = Math.max(0, Math.min(...xs));
    xEnd = Math.min(width, Math.max(...xs));
    yStart = Math.max(0, Math.min(...ys));
    yEnd = Math.min(height, Math.max(...ys));
  }

  const thresholded = new Int16Array(pixelData.length);
  for (let i = xStart; i < xEnd; i++) {
    for (let j = yStart; j < yEnd; j++) {
      if (hasRegion) {
        // The distance test is done in canvas space, where the circle lives.
        const point = cornerstone.pixelToCanvas(element, { x: i, y: j });
        if (getDistance(center, point) > radius) continue;
      }
      const k = i + j * width;
      const value = pixelData[k];
      if (upper > value && value > lower) {
        thresholded[k] = 1;
      }
    }
  }
  return { ...image, pixelData: thresholded };
}
