import math
import numpy as np
import matplotlib.pyplot as plt
import cv2


def imshow(img, w_max=1200, h_max=700, name="show"):
    """Show *img* in the OpenCV window *name*, scaled to fit a w_max x h_max box.

    The aspect ratio is preserved: the limiting dimension is scaled to exactly
    its maximum and the other one follows proportionally (truncated to int).
    Images smaller than the box are therefore enlarged. This function does not
    call ``cv2.waitKey``; per the OpenCV docs the caller must do so for the
    window to actually be drawn.
    """
    height, width = img.shape[0], img.shape[1]

    if width / height >= w_max / h_max:
        # Relatively wider than the box: width is the binding constraint.
        target = (w_max, int(height * w_max / width))
    else:
        # Relatively taller than the box: height is the binding constraint.
        target = (int(width * h_max / height), h_max)

    cv2.imshow(name, cv2.resize(img, target))


def mill(img_mill, img_keep, dpi, dia_mm, stepover=0.8, once=False):
    """Compute closed tool paths that clear the ``img_mill`` region with a round tool.

    The mill region is grown by the tool footprint (a dilation) and every
    pixel of the keep region is removed from it. The resulting mask is then
    peeled pass by pass: each pass traces the current mask outline with
    ``cv2.findContours`` and erodes the mask to obtain the next, inner pass.

    Args:
        img_mill: image whose pixels >= 127 mark material to remove.
            Presumably a single-channel uint8 image (``main`` passes grayscale
            images) — TODO confirm for other dtypes.
        img_keep: image of the same shape whose pixels >= 127 must not be cut.
        dpi: resolution of both images in dots per inch.
        dia_mm: tool diameter in millimetres.
        stepover: scale factor for the erosion between passes; the step kernel
            diameter is ``2 * stepover * dia_mm``, so each pass moves inward by
            roughly ``stepover * dia_mm``.
        once: if True, stop after emitting the first contour pass.

    Returns:
        ``(paths, img_miss)``. ``paths`` is a list of ``(N, 2)`` float32 arrays
        of (x, y) points in millimetres, with x measured from the left image
        edge and y measured upward from the bottom image edge. ``img_miss`` has
        the shape of ``img_mill`` and marks (255) pixels of the mill region that
        the tool footprint never covers when centred anywhere in the first
        pass's allowed region (``main`` re-mills these with a smaller tool).

    Raises:
        ValueError: if the tool or step kernel is smaller than one pixel at
            this resolution (OpenCV cannot build a zero-sized kernel).
    """
    img_mill_bin = np.zeros_like(img_mill)
    img_mill_bin[img_mill >= 127] = 255

    img_keep_bin = np.zeros_like(img_keep)
    img_keep_bin[img_keep >= 127] = 255

    dpmm = dpi / 25.4  # dots per millimetre
    dia_px = int(dia_mm * dpmm)

    print(f"Tool: {dia_mm} mm = {dia_px} px")
    if dia_px < 10:
        print(f"Warning: tool is only {dia_px} pixels wide")

    # Kernel diameters in pixels: the tool footprint and the per-pass erosion.
    k_tool = dia_px
    k_step = int(dia_mm * 2 * stepover * dpmm)
    if k_tool < 1:
        raise ValueError(f"tool diameter {dia_mm} mm is below one pixel at {dpi} dpi")
    if k_step < 1:
        raise ValueError(f"step kernel (dia {dia_mm} mm, stepover {stepover}) is below one pixel at {dpi} dpi")

    kernel_step = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k_step, k_step))
    kernel_tool = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k_tool, k_tool))

    # Everything is zero-padded by one tool diameter on every side; edge_px
    # removes that offset again when converting back to image/mm coordinates.
    edge_px = k_tool
    img_keep_pad = np.pad(img_keep_bin, pad_width=(edge_px, edge_px), mode="constant", constant_values=0)

    # Grow the mill region by the tool footprint, but never into keep pixels.
    mask = cv2.dilate(img_mill_bin, kernel_tool)
    mask[img_keep_bin > 0] = 0
    mask = np.pad(mask, pad_width=(edge_px, edge_px), mode="constant", constant_values=0)

    h_p, w_p = mask.shape

    mask_diff = np.zeros_like(mask)
    mask_miss = np.zeros_like(mask)
    mask_next = np.zeros_like(mask)

    paths = []
    img_miss = np.zeros_like(img_mill_bin)

    i = 0
    done = False
    while not done:
        if i == 0:
            # First pass emits no contours; it only restricts tool centres to
            # mask pixels outside the keep region dilated by the tool kernel.
            dil = cv2.dilate(img_keep_pad, kernel_tool)
            mask_next[(mask > 0) & ~(dil > 0)] = 255
            # Mask pixels the tool never covers when centred anywhere in mask_next.
            mask_next_dil = cv2.dilate(mask_next, kernel_tool)
            missed = (mask > 0) & ~(mask_next_dil > 0)
            img_miss[missed[edge_px:h_p - edge_px, edge_px:w_p - edge_px]] = 255
            # ignore pixels that weren't marked to begin with
            img_miss[img_mill_bin == 0] = 0
        else:
            # Next pass = erosion by the step kernel. Parts of the removed ring
            # wide enough to hold the whole tool footprint (mask_miss) are
            # added back below so they are not skipped.
            mask_next[:, :] = cv2.erode(mask, kernel_step)
            mask_diff[:, :] = np.where(mask > 0, 255, 0)
            mask_diff[mask_next > 0] = 0
            mask_miss[:, :] = cv2.erode(mask_diff, kernel_tool)

            # [-2] selects the contour list from both the OpenCV 3.x
            # (image, contours, hierarchy) and 4.x (contours, hierarchy) returns.
            contours = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_TC89_KCOS)[-2]
            for c in contours:
                pts = c[:, 0, :].astype(np.float32)
                path = np.empty_like(pts)
                path[:, 0] = (pts[:, 0] - edge_px) / dpmm
                # Flip rows so y grows upward from the bottom image edge.
                path[:, 1] = ((h_p - pts[:, 1]) - edge_px) / dpmm
                paths.append(path)
            if once:
                done = True

        mask[:, :] = mask_next[:, :]
        if i > 0:
            mask[mask_miss > 0] = 255

        i += 1

        if not np.any(mask > 0):
            done = True

        print(i)

    return paths, img_miss


def add_gcode(f, paths, z_down=0.0, z_up=1.0, feedrate_cut=480.0, feedrate_plunge=240.0):
    """Append G-code that cuts every path in *paths* as a closed loop.

    For each path: rapid (G0) to its first point, plunge (G1) to ``z_down``
    at ``feedrate_plunge``, cut through every point at ``feedrate_cut``, move
    back to the first point to close the loop, then rapid up to ``z_up``.
    No Z move precedes the first rapid, so the tool must already be at a safe
    height when this is called. Empty paths are skipped.

    Args:
        f: writable text stream (anything with ``write(str)``).
        paths: iterable of ``(N, 2)`` array-likes of (x, y) points.
        z_down: cutting depth.
        z_up: retract height after each path.
        feedrate_cut: feed for the XY cutting moves.
        feedrate_plunge: feed for the Z plunge.
    """
    for p in paths:
        pts = np.asarray(p)
        if len(pts) == 0:
            # Nothing to cut; indexing pts[0] would raise.
            continue

        x0 = pts[0, 0]
        y0 = pts[0, 1]
        f.write(f"G0 X{x0:.4f} Y{y0:.4f}\n")
        f.write(f"G1 Z{z_down:.4f} F{feedrate_plunge:.4f}\n")

        for x, y in pts:
            f.write(f"G1 X{x:.4f} Y{y:.4f} F{feedrate_cut:.4f}\n")

        # Close the loop back to the start before retracting.
        f.write(f"G1 X{x0:.4f} Y{y0:.4f} F{feedrate_cut:.4f}\n")
        f.write(f"G0 Z{z_up:.4f}\n")


# Program start: "%" plus modal setup with standard RS-274 meanings — XY
# plane, millimetre units, cancel cutter-radius compensation and tool-length
# offset, work offset G54, cancel canned cycles, absolute coordinates, feed
# per minute.
_GCODE_PREAMBLE = ("%", "G17", "G21", "G40", "G49", "G54", "G80", "G90", "G94")
# Program end. NOTE(review): "M6 T-1" and "M496.1" follow the closing "%",
# which some controllers ignore; they look controller-specific — confirm.
_GCODE_POSTAMBLE = ("M30", "%", "M6 T-1", "M496.1")


def _load_gray(filename):
    """Read *filename* with OpenCV and return it as a single-channel image.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read the file (it returns None).
    """
    img = cv2.imread(filename)
    if img is None:
        raise FileNotFoundError(f"could not read image: {filename}")
    if len(img.shape) > 2:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    return img


def _binarize(img):
    """Return a copy of *img* with pixels >= 127 set to 255 and all others to 0."""
    out = np.zeros_like(img)
    out[img >= 127] = 255
    return out


def _start_tool(f, tool, spindle_speed=16000):
    """Write a tool change to *tool*, set spindle speed, move to Z1 and start the spindle (M03)."""
    f.write(f"T{tool} M06\n")
    f.write(f"S{spindle_speed}\n")
    f.write("G0 Z1\n")
    f.write("M03\n")


def _plot_paths(paths, dpmm, h, color):
    """Overlay closed mm paths on the current figure, converted back to pixel coordinates."""
    for c in paths:
        xs = c[:, 0] * dpmm
        ys = h - c[:, 1] * dpmm  # undo mill()'s upward-y flip
        # Append the first point so the drawn outline is closed.
        plt.plot(np.append(xs, xs[0]), np.append(ys, ys[0]), color)


def main():
    """Generate ``result.nc`` from the copper and edge images, then plot the paths.

    Reads ``input/Cu.png`` and ``input/Edge.png``, computes three sets of
    paths with :func:`mill` (large tool, small tool for what the large one
    missed, and the board outline), writes them as G-code and shows an
    overlay of the paths on the milled region.
    """
    filename_copper = "input/Cu.png"
    filename_edge = "input/Edge.png"
    filename_out = "result.nc"

    dpi = 1000.0
    dpmm = dpi / 25.4

    img_copper = _load_gray(filename_copper)
    img_edge = _load_gray(filename_edge)

    # Keep pixels that are bright in the copper image or dark in the edge
    # image; mill everything else.
    img1_keep = np.zeros_like(img_edge)
    img1_keep[(img_edge < 127) | (img_copper > 127)] = 255
    img1_mill = 255 - img1_keep

    # Use the shared dpi everywhere so changing it rescales every pass.
    paths1, img_miss1 = mill(img1_mill, img1_keep, dpi, 0.794)
    # Second, smaller tool re-mills only what the first tool could not reach.
    paths2, img_miss2 = mill(img_miss1, img1_keep, dpi, 0.397)

    # Board outline: a single contour pass around the edge image.
    img_edge_keep = _binarize(img_edge)
    img_edge_mill = 255 - img_edge_keep
    paths3, _ = mill(img_edge_mill, img_edge_keep, dpi, 0.794, once=True)

    with open(filename_out, "w") as f:
        f.writelines(f"{line}\n" for line in _GCODE_PREAMBLE)

        _start_tool(f, 4)
        add_gcode(f, paths2, z_down=-0.12)
        f.write("M05\n")

        _start_tool(f, 3)
        add_gcode(f, paths1, z_down=-0.12)
        # Outline is cut in three progressively deeper passes.
        for z_down in (-0.6, -1.2, -1.75):
            add_gcode(f, paths3, z_down=z_down)

        f.writelines(f"{line}\n" for line in _GCODE_POSTAMBLE)

    h, w = np.shape(img1_mill)
    img_show = np.zeros((h, w, 3), dtype=np.uint8)
    img_show[img1_mill > 0, :] = 150
    # Pixels still missed by the small tool: channel 0 only (red in
    # matplotlib's RGB display).
    img_show[img_miss2 > 0, :] = 0
    img_show[img_miss2 > 0, 0] = 255

    plt.figure()
    plt.imshow(img_show)
    _plot_paths(paths1, dpmm, h, "b")
    _plot_paths(paths2, dpmm, h, "g")
    _plot_paths(paths3, dpmm, h, "b")
    plt.axis("equal")
    plt.show()


if __name__ == "__main__":
    # Script entry point: run the full image -> G-code pipeline. main()
    # returns None, so SystemExit(None) exits with status 0 as before.
    raise SystemExit(main())