import numpy as np
import matplotlib.pyplot as plt
import matplotlib.lines as mlines

def newline(ax, p1, p2, color="#1f77b4", label="", alpha=1.):
    """Add the infinite line through points ``p1`` and ``p2`` to ``ax``.

    The line is stored as a two-point ``Line2D`` whose endpoints sit on the
    axes' current x-bounds (or, for a vertical line, the current y-bounds),
    so it spans the view as it is at the time of the call.

    Args:
        ax: matplotlib Axes to draw on.
        p1, p2: points given as indexable ``(x, y)`` pairs.
        color, label, alpha: forwarded unchanged to ``Line2D``.

    Returns:
        The ``Line2D`` that was added to ``ax``.
    """
    x_anchor, y_anchor = p1[0], p1[1]

    if p2[0] == x_anchor:
        # Undefined slope: draw a vertical segment across the visible y-range.
        xs = [x_anchor, x_anchor]
        ys = list(ax.get_ybound())
    else:
        slope = (p2[1] - y_anchor) / (p2[0] - x_anchor)
        left, right = ax.get_xbound()
        xs = [left, right]
        ys = [y_anchor + slope * (left - x_anchor),
              y_anchor + slope * (right - x_anchor)]

    segment = mlines.Line2D(xs, ys, color=color, label=label, alpha=alpha)
    ax.add_line(segment)
    return segment

def plot_rbf_svm(rbf_svm, X, Y):
    """Visualise a fitted RBF-kernel SVM on 2-D data.

    Left panel: the samples split by label (-1 as crosses, 1 as circles),
    misclassified samples marked with a black "X", and the zero level set of
    the kernel decision function

        f(x) = b + sum_i alpha_i * exp(-gamma * ||sv_i - x||^2)

    evaluated on a regular grid over [-15, 15).  Right panel: a heat map of
    f on the same grid.  Also prints the number of misclassified samples.

    Args:
        rbf_svm: fitted classifier exposing ``predict``, ``intercept_``,
            ``dual_coef_``, ``support_vectors_`` and ``gamma`` (presumably a
            scikit-learn ``SVC(kernel="rbf")`` -- assumed from the names).
        X: sample matrix; columns 0 and 1 are used as x1 / x2.
        Y: labels compared against -1 / 1 and against ``predict(X)``.
    """
    Y_pred = rbf_svm.predict(X)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))

    ax1.scatter(X[:, 0][Y == -1], X[:, 1][Y == -1], marker="x", label="label: -1")
    ax1.scatter(X[:, 0][Y == 1], X[:, 1][Y == 1], marker="o", label="label: 1")
    ax1.scatter(X[:, 0][Y != Y_pred], X[:, 1][Y != Y_pred], color="black", marker="X", label="incorrect")

    delta = 0.5
    xrange = np.arange(-15., 15., delta)
    yrange = np.arange(-15., 15., delta)
    X1, X2 = np.meshgrid(xrange, yrange)

    gamma = rbf_svm.gamma
    if isinstance(gamma, str):
        # scikit-learn accepts gamma="scale"/"auto"; the numeric value it
        # resolves at fit time is kept in the (private) ``_gamma`` attribute.
        gamma = rbf_svm._gamma

    # Evaluate f on the whole grid at once instead of a Python double loop.
    grid = np.column_stack((X1.ravel(), X2.ravel()))             # (n_grid, 2)
    sv = np.asarray(rbf_svm.support_vectors_)                    # (n_sv, 2)
    sq_dist = ((grid[:, np.newaxis, :] - sv[np.newaxis, :, :]) ** 2).sum(axis=-1)
    G = np.exp(-gamma * sq_dist) @ rbf_svm.dual_coef_[0] + rbf_svm.intercept_[0]
    G = G.reshape(X1.shape)

    ax1.contour(X1, X2, G, [0])
    ax1.set_aspect('equal', 'datalim')
    ax1.set_xlabel("x1")
    ax1.set_ylabel("x2")
    ax1.legend()

    # imshow's extent gives the outer pixel edges; pad by half a cell so each
    # pixel is centred on its grid sample and lines up with the contour.
    half = delta / 2
    extent = [xrange[0] - half, xrange[-1] + half,
              yrange[0] - half, yrange[-1] + half]
    im = ax2.imshow(G, extent=extent, origin="lower")
    plt.colorbar(im, ax=ax2)
    plt.tight_layout()

    print(f"Number of misclassification: {np.sum(Y != Y_pred)} / {len(Y)}")

def plot_poly_svm(poly_svm, X, Y):
    """Visualise a fitted polynomial-kernel SVM on 2-D data.

    Left panel: the samples split by label (-1 as crosses, 1 as circles),
    misclassified samples marked with a black "X", the faint reference curve
    x1^2 + x2^2 = 1, and the zero level set of the kernel decision function

        f(x) = b + sum_i alpha_i * (gamma * <sv_i, x> + coef0) ** degree

    evaluated on a regular grid over [-2, 2).  Right panel: a heat map of f
    on the same grid.  Also prints the number of misclassified samples.

    Args:
        poly_svm: fitted classifier exposing ``predict``, ``intercept_``,
            ``dual_coef_``, ``support_vectors_`` and ``gamma`` (presumably a
            scikit-learn ``SVC(kernel="poly")`` -- assumed from the names).
            ``degree`` and ``coef0`` are read when present; otherwise 2 and 0
            are used, i.e. the kernel (gamma * <sv, x>) ** 2.
        X: sample matrix; columns 0 and 1 are used as x1 / x2.
        Y: labels compared against -1 / 1 and against ``predict(X)``.
    """
    Y_pred = poly_svm.predict(X)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))

    ax1.scatter(X[:, 0][Y == -1], X[:, 1][Y == -1], marker="x", label="label: -1")
    ax1.scatter(X[:, 0][Y == 1], X[:, 1][Y == 1], marker="o", label="label: 1")
    ax1.scatter(X[:, 0][Y != Y_pred], X[:, 1][Y != Y_pred], color="black", marker="X", label="incorrect")

    delta = 0.025
    xrange = np.arange(-2.0, 2.0, delta)
    yrange = np.arange(-2.0, 2.0, delta)
    X1, X2 = np.meshgrid(xrange, yrange)

    # Reference curve: the unit circle (presumably the boundary the data was
    # generated from -- NOTE(review): confirm against the data generator).
    F = X1 * X1 + X2 * X2 - 1

    gamma = poly_svm.gamma
    if isinstance(gamma, str):
        # scikit-learn accepts gamma="scale"/"auto"; the numeric value it
        # resolves at fit time is kept in the (private) ``_gamma`` attribute.
        gamma = poly_svm._gamma
    degree = getattr(poly_svm, "degree", 2)
    coef0 = getattr(poly_svm, "coef0", 0.0)

    # Evaluate f on the whole grid at once instead of a Python double loop.
    grid = np.column_stack((X1.ravel(), X2.ravel()))             # (n_grid, 2)
    sv = np.asarray(poly_svm.support_vectors_)                   # (n_sv, 2)
    kernel = (gamma * (grid @ sv.T) + coef0) ** degree           # (n_grid, n_sv)
    G = kernel @ poly_svm.dual_coef_[0] + poly_svm.intercept_[0]
    G = G.reshape(X1.shape)

    ax1.contour(X1, X2, F, [0], alpha=0.3)
    ax1.contour(X1, X2, G, [0])
    ax1.set_aspect('equal', 'datalim')
    ax1.set_xlabel("x1")
    ax1.set_ylabel("x2")
    ax1.legend()

    # imshow's extent gives the outer pixel edges; pad by half a cell so each
    # pixel is centred on its grid sample.
    half = delta / 2
    extent = [xrange[0] - half, xrange[-1] + half,
              yrange[0] - half, yrange[-1] + half]
    im = ax2.imshow(G, extent=extent, origin="lower")
    plt.colorbar(im, ax=ax2)
    plt.tight_layout()

    print(f"Number of misclassification: {np.sum(Y != Y_pred)} / {len(Y)}")
    
def plot_linear_svm(svm_linear, X, Y):
    """Visualise a fitted linear SVM on 2-D data.

    Left panel: the samples split by label (-1 as crosses, 1 as circles),
    misclassified samples marked with a black "X", the reference line
    x2 = -x1 ("class boundary"), the SVM decision boundary w.x + b = 0 (red),
    and a segment from every support vector to its orthogonal projection
    onto that boundary.  Right panel: a heat map of w.x + b over [-2, 2)^2.
    Also prints the number of misclassified samples.

    Args:
        svm_linear: fitted classifier exposing ``predict``, ``coef_``,
            ``intercept_`` and ``support_vectors_`` (presumably a
            scikit-learn ``SVC(kernel="linear")`` -- assumed from the names).
        X: sample matrix; columns 0 and 1 are used as x1 / x2.
        Y: labels compared against -1 / 1 and against ``predict(X)``.
    """
    Y_pred = svm_linear.predict(X)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))

    ax1.scatter(X[:, 0][Y == -1], X[:, 1][Y == -1], marker="x", label="label: -1")
    ax1.scatter(X[:, 0][Y == 1], X[:, 1][Y == 1], marker="o", label="label: 1")
    ax1.scatter(X[:, 0][Y != Y_pred], X[:, 1][Y != Y_pred], color="black", marker="X", label="incorrect")

    newline(ax1, [0, 0], [1, -1], label="class boundary", color="black", alpha=0.5)

    w = svm_linear.coef_.reshape(-1)
    b = svm_linear.intercept_[0]
    w_sq = w.dot(w)

    if w_sq > 0:
        # Pick two distinct points on w.x + b = 0 by stepping along the axis
        # with the larger |w| component.  The former axis-intercept choice
        # collapsed both points onto the origin when b == 0 (drawing a bogus
        # vertical line) and divided by zero when either w component was 0.
        if abs(w[1]) >= abs(w[0]):
            p1 = [0.0, -b / w[1]]
            p2 = [1.0, -(b + w[0]) / w[1]]
        else:
            p1 = [-b / w[0], 0.0]
            p2 = [-(b + w[1]) / w[0], 1.0]
        newline(ax1, p1, p2, label="SVM decision boundary", color="red")

        # Nearest point on the hyperplane (closed-form solution of the KKT
        # system min ||x - s||^2 s.t. w.x = -b):  x* = s - (w.s + b)/|w|^2 * w
        sv = np.asarray(svm_linear.support_vectors_)
        projections = sv - np.outer((sv @ w + b) / w_sq, w)
        for s, q in zip(sv, projections):
            ax1.add_line(mlines.Line2D([s[0], q[0]], [s[1], q[1]]))

    ax1.set_aspect('equal', 'datalim')

    ax1.set_xlabel("x1")
    ax1.set_ylabel("x2")
    ax1.legend()

    delta = 0.025
    xmin = -2.
    xmax = 2.
    ymin = -2.
    ymax = 2.
    xrange = np.arange(xmin, xmax, delta)
    yrange = np.arange(ymin, ymax, delta)
    X1, X2 = np.meshgrid(xrange, yrange)

    flat_grid = np.column_stack((X1.ravel(), X2.ravel()))
    F = (flat_grid @ w + b).reshape(X1.shape)

    # imshow's extent gives the outer pixel edges; pad by half a cell so each
    # pixel is centred on its grid sample.
    half = delta / 2
    extent = [xrange[0] - half, xrange[-1] + half,
              yrange[0] - half, yrange[-1] + half]
    im = ax2.imshow(F, extent=extent, origin="lower")
    plt.colorbar(im, ax=ax2)
    plt.tight_layout()
    print(f"Number of misclassification: {np.sum(Y != Y_pred)} / {len(Y)}")