"""
Author: Ryan Friedman (@rfriedman22)
Email: ryan.friedman@wustl.edu
"""

from datetime import datetime
import matplotlib as mpl
import matplotlib.lines as mlines
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy import stats
from sklearn.metrics import auc

from PIL import Image
from io import BytesIO


def set_presentation_params():
    """Apply matplotlib rcParams sized for very large, presentation-style figures.

    Sets large fonts, thick lines, long major ticks, and a 25 x 25 inch default figure. (The name is a misnomer:
    per the original author, this preset is not the one actually used for presentations.)

    Returns
    -------
    None
    """
    mpl.rcParams.update({
        "axes.titlesize": 90,
        "axes.labelsize": 80,
        "xtick.labelsize": 60,
        "ytick.labelsize": 60,
        "legend.fontsize": 60,
        "figure.figsize": (25, 25),
        "image.cmap": "viridis",
        "lines.markersize": 14,
        "lines.linewidth": 15,
        "font.size": 60,
        "xtick.major.size": 10,
        "xtick.major.width": 3,
        "ytick.major.size": 10,
        "ytick.major.width": 3,
    })
    
    
def set_print_params():
    """Apply matplotlib rcParams sized for print-style figures.

    Sets medium fonts and an 8 x 8 inch default figure. (The name is a misnomer: per the original author, this
    preset is the one used in slides.)

    Returns
    -------
    None
    """
    mpl.rcParams.update({
        "axes.titlesize": 25,
        "axes.labelsize": 20,
        "xtick.labelsize": 15,
        "ytick.labelsize": 15,
        "legend.fontsize": 15,
        "figure.figsize": (8, 8),
        "image.cmap": "viridis",
        "lines.markersize": 3,
        "lines.linewidth": 3,
        "font.size": 15,
    })


def set_manuscript_params():
    """Apply matplotlib rcParams sized for manuscript figures.

    Sets small fonts, a 4 x 4 inch default figure, and a 300 dpi default for saved figures.

    Returns
    -------
    None
    """
    mpl.rcParams.update({
        "figure.figsize": (4, 4),
        "axes.titlesize": 15,
        "axes.labelsize": 12,
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "legend.fontsize": 12,
        "image.cmap": "viridis",
        "lines.markersize": 1.25,
        "lines.linewidth": 2,
        "font.size": 12,
        "savefig.dpi": 300,
    })


def add_letter(ax, x, y, letter):
    """Label an axes with a bold panel letter (e.g. "A") for use in a multi-panel figure.

    Parameters
    ----------
    ax : Axes object
        The panel that receives the letter.
    x : int or float
        Position of the letter's right edge, in ax.transAxes (axes-fraction) coordinates.
    y : int or float
        Position of the letter's top edge, in ax.transAxes (axes-fraction) coordinates.
    letter : str
        The text to draw.

    Returns
    -------
    Text
        The Text instance that was created.
    """
    # The letter is drawn at the current axis-label font size and anchored by its top-right corner.
    text_style = {
        "fontsize": mpl.rcParams["axes.labelsize"],
        "fontweight": "bold",
        "ha": "right",
        "va": "top",
        "transform": ax.transAxes,
    }
    return ax.text(x, y, letter, **text_style)


def rotate_ticks(ticks, rotation=90):
    """Set the rotation of already-generated tick labels.

    Parameters
    ----------
    ticks : iterable of Text
        Tick label objects; each must provide a ``set_rotation`` method.
    rotation : int or float
        Angle passed to every label's ``set_rotation``.

    Returns
    -------
    None
    """
    for tick_label in ticks:
        tick_label.set_rotation(rotation)


def set_color(values):
    """Convert numbers into colors using the current default colormap.

    Parameters
    ----------
    values : float or array-like of float
        Number(s) between 0 and 1 to look up in the colormap.

    Returns
    -------
    tuple or np.ndarray
        Whatever the colormap returns when called with ``values`` (RGBA color(s)).
    """
    # plt.get_cmap() with no name returns the default colormap from rcParams["image.cmap"]. It replaces
    # mpl.cm.get_cmap(), which was deprecated in Matplotlib 3.7 and removed in 3.9.
    default_cmap = plt.get_cmap()
    return default_cmap(values)


def save_fig(fig, prefix, tight_layout=True, timestamp=True, tight_pad=1.08):
    """Save a figure as an SVG, a PNG, and a TIFF.

    Parameters
    ----------
    fig : Figure handle
        The figure to save.
    prefix : str
        Output path without extension; writes ``{prefix}.svg``, ``{prefix}.png``, and ``{prefix}.tiff``.
    tight_layout : bool
        If True, call ``fig.tight_layout`` before saving.
    timestamp : bool
        If True, draw the current date and time in the lower left corner of the figure. The text is added to the
        figure itself, so it remains on the figure after saving.
    tight_pad : float
        Padding passed to ``fig.tight_layout`` when tight_layout is True.

    Returns
    -------
    None
    """
    if tight_layout:
        fig.tight_layout(pad=tight_pad)
    if timestamp:
        now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        fig.text(0, 0, now, transform=fig.transFigure)

    fig.savefig(f"{prefix}.svg", bbox_inches="tight")
    fig.savefig(f"{prefix}.png", bbox_inches="tight")

    # Trick to save a TIFF file https://stackoverflow.com/questions/37945495/save-matplotlib-figure-as-tiff
    # Render to an in-memory PNG and let PIL convert it. The context managers release the buffer and the image
    # even if rendering or writing fails.
    with BytesIO() as png_buffer:
        fig.savefig(png_buffer, format="png", bbox_inches="tight")
        png_buffer.seek(0)
        with Image.open(png_buffer) as png_image:
            png_image.save(f"{prefix}.tiff")
    
    
def setup_multiplot(n_plots, n_cols=2, sharex=True, sharey=True, big_dimensions=True):
    """Setup a multiplot and hide any superfluous axes that may result.

    Parameters
    ----------
    n_plots : int
        Number of subplots to make
    n_cols : int
        Number of columns in the multiplot. Number of rows is inferred.
    sharex : bool
        Indicate if the x-axis should be shared.
    sharey : bool
        Indicate if the y-axis should be shared.
    big_dimensions : bool
        If True, then the size of the multiplot is the default figure size multiplied by the number of rows/columns.
        If False, then the entire figure is the default figure size.

    Returns
    -------
    fig : figure handle
    ax_list : np.ndarray of Axes
        The axes in row-major order, with any superfluous axes removed from the figure and replaced by None.
        The array is 1-D when the grid has a single row or a single column (including a 1 x 1 grid, which gives
        an array of length 1) and 2-D with shape (n_rows, n_cols) otherwise.
    """
    n_rows = int(np.ceil(n_plots / n_cols))
    fig_width, fig_height = mpl.rcParams["figure.figsize"]

    if big_dimensions:
        # The width grows with the number of columns and the height with the number of rows
        fig_width *= n_cols
        fig_height *= n_rows

    # squeeze=False always yields a 2-D array, so a 1 x 1 grid no longer comes back as a bare Axes object
    fig, ax_grid = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(fig_width, fig_height), sharex=sharex,
                                sharey=sharey, squeeze=False)

    # The index corresponding to n_plots is the first subplot to be hidden
    for i in range(n_plots, ax_grid.size):
        ax_grid.flat[i].remove()
        ax_grid.flat[i] = None

    # Match the shapes plt.subplots returns by default: flatten grids with a single row or column
    if 1 in ax_grid.shape:
        ax_list = ax_grid.reshape(-1)
    else:
        ax_list = ax_grid

    return fig, ax_list


def volcano_plot(df, x_col, y_col, colors, alpha=1, xaxis_label=None, yaxis_label=None, title=None, figname=None,
                 xline=None, yline=None, xticks=None, vmin=None, vmax=None, cmap=None, colorbar=False, figax=None):
    """Make a volcano plot, without transforming the x-axis but taking -log10 of the y-axis. Assign different points
    different colors to highlight different classes.

    Parameters
    ----------
    df : pd.DataFrame
    x_col : str
        Column of the df to plot on x
    y_col : str
        Column of the df to plot on y. Take -log10 of this column before plotting
    colors : list-like
        Indicates color to use for each row of df.
    alpha : float
        Opacity of the points.
    xaxis_label : str
        If specified, the label for the x-axis. Otherwise use x_col.
    yaxis_label : str
        If specified, the label for the y-axis. Otherwise use "-log10 " followed by y_col.
    title : str
        If specified, make a title for the plot.
    figname : str
        If specified, save the figure with this name.
    xline : int or float or list-like
        If specified, plot a dashed vertical line at each x = xline
    yline : int or float or list-like
        If specified, plot a dashed horizontal line at each y = yline. The value is used as given, i.e. it is in
        the plotted (-log10) units.
    xticks : list
        If specified, set the x ticks to these values.
    vmin : int or float
        If specified, minimum value for the colormap.
    vmax : int or float
        If specified, maximum value for the colormap.
    cmap : str
        If specified, use this colormap. Otherwise, use the default.
    colorbar : bool
        If True, display a colorbar to the right.
    figax : (figure, axes) or None
        If specified, make the plot in the provided axes. Otherwise, generate a new axes.

    Returns
    -------
    fig : Figure handle
    """
    if figax:
        fig, ax = figax
    else:
        fig, ax = plt.subplots()

    # Prepare the data
    x = df[x_col]
    y = -np.log10(df[y_col])
    scatter_kwargs = {"c": colors, "alpha": alpha}
    # Compare against None so that a limit of 0 is still honored
    if vmin is not None:
        scatter_kwargs["vmin"] = vmin
    if vmax is not None:
        scatter_kwargs["vmax"] = vmax
    if cmap:
        scatter_kwargs["cmap"] = cmap

    scatterplot = ax.scatter(x, y, **scatter_kwargs)

    # Default axis labels if none specified
    if not xaxis_label:
        xaxis_label = x_col
    if not yaxis_label:
        yaxis_label = f"-log10 {y_col}"

    # Add dashed lines if specified. np.atleast_1d lets a scalar, list, tuple, or array be handled uniformly.
    line_kwargs = {"linestyle": "--", "color": "black"}
    if xline is not None:
        for x_position in np.atleast_1d(xline):
            # A line at x = constant is vertical
            ax.axvline(x_position, **line_kwargs)
    if yline is not None:
        for y_position in np.atleast_1d(yline):
            # A line at y = constant is horizontal
            ax.axhline(y_position, **line_kwargs)

    # Axis labels, ticks, colorbar, title if specified
    ax.set_xlabel(xaxis_label)
    ax.set_ylabel(yaxis_label)

    if xticks is not None:
        ax.set_xticks(xticks)

    if colorbar:
        # Attach the colorbar to the axes that was drawn on, which matters when figax holds several axes
        fig.colorbar(scatterplot, ax=ax, orientation="vertical")

    if title:
        ax.set_title(title)

    if figname:
        save_fig(fig, figname)

    return fig


def scatter_with_corr(x, y, xlabel, ylabel, colors="black", xticks=None, yticks=None, loc=None, figname=None,
                      alpha=1.0, figax=None):
    """Make a scatter plot and display the correlation coefficients in a specified location.

    Parameters
    ----------
    x : list-like
        Data to plot on the x axis.
    y : list-like
        Data to plot on the y axis.
    xlabel : str
        Label for the x axis.
    ylabel : str
        Label for the y axis.
    colors : "density", str or list-like
        If "density", color points based on point density in 2D space. If another str, make every point the same
        color. If list-like, specifies the color for each point.
    xticks : list-like
        If specified, set the x axis ticks to these values.
    yticks: list-like
        If specified, set the y axis ticks to these values.
    loc : str, must be one of "upper left", "upper right", "lower left", or "lower right"
        The location of the plot to display the correlations. If None, just print to the screen. Any word that is
        missing or not recognized falls back to "lower" / "right" with a printed warning.
    figname : str
        If specified, save the figure with this name.
    alpha : float
        Alpha (opacity) of the points.
    figax : (figure, axes) or None
        If specified, make the plot in the provided axes. Otherwise, generate a new axes.

    Returns
    -------
    fig : Figure handle
    ax : Axes handle
    """
    # Work on plain arrays so that the positional reordering in density mode behaves the same for lists, arrays,
    # and pandas objects regardless of their index.
    x = np.asarray(x)
    y = np.asarray(y)

    # Correlations
    pcc, _ = stats.pearsonr(x, y)
    scc, _ = stats.spearmanr(x, y)
    n = len(x)
    text = f"PCC = {pcc:.3f}\nSCC = {scc:.3f}\nn = {n}"

    # Calculate the density to display on the scatter plot, if specified
    if isinstance(colors, str) and colors == "density":
        xy = np.vstack([x, y])
        density = stats.gaussian_kde(xy)(xy)
        # Sort by density so that the points are drawn in order of increasing density
        order = density.argsort()
        x, y, density = x[order], y[order], density[order]
        colors = set_color(density / density.max())

    if figax:
        fig, ax = figax
    else:
        fig, ax = plt.subplots()

    ax.scatter(x, y, color=colors, alpha=alpha)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    if xticks is not None:
        ax.set_xticks(xticks)

    if yticks is not None:
        ax.set_yticks(yticks)

    # If no location is specified for the correlations, print to screen.
    if loc is None:
        print(text)
    # Parse info on location
    else:
        # Tolerate strings that do not have exactly two words instead of failing on tuple unpacking
        words = loc.split()
        yword = words[0] if len(words) > 0 else ""
        xword = words[1] if len(words) > 1 else ""

        if yword == "upper":
            yloc = 0.98
            va = "top"
        else:
            if yword != "lower":
                print("Warning, did not recognize yloc, assuming lower")
            yloc = 0.02
            va = "bottom"

        if xword == "left":
            xloc = 0.02
            ha = "left"
        else:
            if xword != "right":
                print("Warning, did not recognize xloc, assuming right")
            xloc = 0.98
            ha = "right"

        ax.text(xloc, yloc, text, ha=ha, va=va, transform=ax.transAxes)

    if figname:
        save_fig(fig, figname)

    return fig, ax


# LEGACY FUNCTION
def violin_plot_series(ser, class_masks, class_names, yname, class_colors=None, alpha=1.0, transformation_function=None,
                       pseudocount=0, figname=None, vert=True, yticks=None, figax=None, **kwargs):
    """Draw one violin per boolean mask, each showing the masked subset of a Series.

    Parameters
    ----------
    ser : pd.Series
        The data to plot.
    class_masks : list of pd.Series
        One boolean mask per violin; each selects a subset of ser.
    class_names : list[str]
        Tick label for each violin.
    yname : str
        Label for the data axis.
    class_colors : list
        Optional fill color for each violin.
    alpha : float
        Opacity of the violins.
    transformation_function : function handle
        Optional transformation applied to the data before plotting.
    pseudocount : int or float
        Value added to the data before the transformation is applied.
    figname : str
        If specified, save the figure under this name.
    vert : bool
        If True, violins are vertical. Otherwise, violins are horizontal.
    yticks : list
        If specified, ticks for the data axis.
    figax : (figure, axes) or None
        If specified, draw in the provided axes. Otherwise, create a new figure and axes.
    kwargs : dict
        Extra arguments forwarded to the figure-saving step.

    Returns
    -------
    fig : figure handle
    """
    subsets = []
    for mask in class_masks:
        subsets.append(ser[mask])

    return _make_violin_plot(subsets, class_names, yname, colors=class_colors, alpha=alpha,
                             transformation_function=transformation_function, pseudocount=pseudocount,
                             figname=figname, vert=vert, yticks=yticks, figax=figax, **kwargs)


def violin_plot_groupby(grouper, yname, class_names=None, class_colors=None, alpha=1.0, transformation_function=None,
                        pseudocount=0, figname=None, vert=True, yticks=None, figax=None, **kwargs):
    """Make a violin plot from a groupby object.

    Parameters
    ----------
    grouper : pd.DataFrameGroupBy or pd.SeriesGroupBy
        Group by object where each group is data for a different violin. Empty groups are skipped.
    yname : str
        Name for the y axis
    class_names : list-like
        Optional names for each group. If not specified (None or empty), use the names from the grouper.
    class_colors : list
        Optional colors for each group.
    alpha : float
        Opacity of the violins.
    transformation_function : function handle
        Optional transformation to apply to the data.
    pseudocount : int or float
        Optional pseudocount for the data.
    figname : str
        If specified, save the figure to a file with this name.
    vert : bool
        If True, violins are vertical. Otherwise, violins are horizontal.
    yticks : list
        If specified, indicates the ticks for the y axis.
    figax : (figure, axes) or None
        If specified, make the plot in the provided axes. Otherwise, generate a new axes.
    kwargs : dict
        Arguments for saving the figure

    Returns
    -------
    fig : figure handle

    Raises
    ------
    ValueError
        If every group in the grouper is empty.
    """
    non_empty_groups = [(name, group) for name, group in grouper if len(group) > 0]
    if not non_empty_groups:
        raise ValueError("grouper contains no non-empty groups to plot")
    names, data = zip(*non_empty_groups)

    # Explicit None/length check: truth-testing an array-like (np.ndarray, pd.Index) of names raises ValueError
    if class_names is not None and len(class_names) > 0:
        names = class_names

    fig = _make_violin_plot(data, names, yname, colors=class_colors, alpha=alpha,
                            transformation_function=transformation_function, pseudocount=pseudocount,
                            figname=figname, vert=vert, yticks=yticks, figax=figax, **kwargs)
    return fig


# LEGACY FUNCTION
def violin_plot_by_column(df, y_label, column_colors=None, alpha=1.0, transformation_function=None, pseudocount=0,
                          figname=None, vert=True, xnames=None, yticks=None, figax=None):
    """Draw one violin for each column of a DataFrame.

    Parameters
    ----------
    df : pd.DataFrame
        Each column supplies the data for one violin.
    y_label : str
        Label for the data axis.
    column_colors : list
        Optional fill color for each violin.
    alpha : float
        Opacity of the violins.
    transformation_function : function handle
        Optional transformation applied to the data before plotting.
    pseudocount : int or float
        Value added to the data before the transformation is applied.
    figname : str
        If specified, save the figure under this name.
    vert : bool
        If True, violins are vertical. Otherwise, violins are horizontal.
    xnames : list-like
        Tick label for each violin. If None, the column names of df are used.
    yticks : list
        If specified, ticks for the data axis.
    figax : (figure, axes) or None
        If specified, draw in the provided axes. Otherwise, create a new figure and axes.

    Returns
    -------
    fig : figure handle
    """
    per_column = [df[column] for column in df]
    labels = df.columns if xnames is None else xnames

    return _make_violin_plot(per_column, labels, y_label, colors=column_colors, alpha=alpha,
                             transformation_function=transformation_function, pseudocount=pseudocount,
                             figname=figname, vert=vert, yticks=yticks, figax=figax)


# LEGACY FUNCTION
def violin_plot(df, class_masks, class_names, column_name, class_colors=None, alpha=1.0, transformation_function=None,
                pseudocount=0, y_label=None, figname=None, vert=True, yticks=None, figax=None, **kwargs):
    """Draw one violin per row mask for a single column of a DataFrame.

    Parameters
    ----------
    df : pd.DataFrame
        Table holding the data.
    class_masks : list
        One row selector per violin, used as ``df.loc[mask, column_name]``.
    class_names : list[str]
        Tick label for each violin.
    column_name : str
        Column of df to plot.
    class_colors : list
        Optional fill color for each violin.
    alpha : float
        Opacity of the violins.
    transformation_function : function handle
        Optional transformation applied to the data before plotting.
    pseudocount : int or float
        Value added to the data before the transformation is applied.
    y_label : str
        Label for the data axis. If not given, column_name is used.
    figname : str
        If specified, save the figure under this name.
    vert : bool
        If True, violins are vertical. Otherwise, violins are horizontal.
    yticks : list
        If specified, ticks for the data axis.
    figax : (figure, axes) or None
        If specified, draw in the provided axes. Otherwise, create a new figure and axes.
    kwargs : dict
        Extra arguments forwarded to the figure-saving step.

    Returns
    -------
    fig : figure handle
    """
    subsets = []
    for mask in class_masks:
        subsets.append(df.loc[mask, column_name].values)

    axis_label = y_label or column_name

    return _make_violin_plot(subsets, class_names, axis_label, colors=class_colors, alpha=alpha,
                             transformation_function=transformation_function, pseudocount=pseudocount,
                             figname=figname, vert=vert, yticks=yticks, figax=figax, **kwargs)


def _make_violin_plot(data_values, x_labels, y_label, colors=None, alpha=1.0, transformation_function=None,
                      pseudocount=0, figname=None, vert=True, yticks=None, whisker=1.5, figax=None, **kwargs):
    """Shared implementation behind the violin plot functions.

    For each group, points farther than ``whisker`` times the interquartile range beyond the first or third
    quartile are treated as outliers: they are left out of the violin and drawn as individual black points. The
    median of each group (computed from all of its data) is marked with a thick black line.

    Parameters
    ----------
    data_values : list
        One entry of data per violin. Each entry must support elementwise comparison and boolean-mask indexing.
    x_labels : list-like
        Tick label for each violin.
    y_label : str
        Label for the data axis.
    colors : list
        Optional fill color for each violin. If None, every violin is grey.
    alpha : float
        Opacity of the violins.
    transformation_function : function handle
        Optional transformation applied to ``data + pseudocount`` for each group.
    pseudocount : int or float
        Value added to the data before the transformation. Only used when a transformation is given.
    figname : str
        If specified, save the figure under this name.
    vert : bool
        If True, violins are vertical. Otherwise, violins are horizontal.
    yticks : list
        If specified, ticks for the data axis.
    whisker : int or float
        Multiple of the interquartile range beyond which a point is an outlier.
    figax : (figure, axes) or None
        If specified, draw in the provided axes. Otherwise, create a new figure and axes.
    kwargs : dict
        Extra arguments forwarded to save_fig.

    Returns
    -------
    fig : figure handle
    """
    # Optionally transform the data (e.g. log10) after adding the pseudocount
    if transformation_function:
        data_values = [transformation_function(group + pseudocount) for group in data_values]

    n_groups = len(x_labels)
    positions = np.arange(n_groups) + 1

    # Everything is grey unless colors were supplied
    if colors is None:
        colors = ["grey"] * n_groups

    # Quartiles per group: columns are the 25th, 50th, and 75th percentiles
    quartiles = np.array([np.percentile(group, [25, 50, 75]) for group in data_values])
    first_quartiles = quartiles[:, 0]
    medians = quartiles[:, 1]
    third_quartiles = quartiles[:, 2]
    whisker_lengths = (third_quartiles - first_quartiles) * whisker

    # Split each group into outliers and the data that goes into the violin
    outlier_data = []
    main_data = []
    for group, q1, q3, reach in zip(data_values, first_quartiles, third_quartiles, whisker_lengths):
        is_outlier = (group > q3 + reach) | (group < q1 - reach)
        outlier_data.append(group[is_outlier])
        main_data.append(group[~is_outlier])

    fig, ax = figax if figax else plt.subplots()

    # Draw the violins and style their bodies
    violins = ax.violinplot(main_data, vert=vert)
    for body, color in zip(violins["bodies"], colors):
        body.set_facecolor(color)
        body.set_edgecolor("black")
        body.set_alpha(alpha)

    # Drop the default min/max caps and the center bar
    for part_name in ("cmins", "cmaxes", "cbars"):
        violins[part_name].remove()

    median_line_width = mpl.rcParams["lines.linewidth"] * 2
    if vert:
        # Median markers run horizontally across each vertical violin
        ax.hlines(medians, positions - 0.2, positions + 0.2, colors="black", zorder=3, lw=median_line_width)

        for position, outliers in zip(positions, outlier_data):
            ax.scatter([position] * len(outliers), outliers, color="k")

        ax.set_ylabel(y_label)
        ax.set_xticks(positions)
        ax.set_xticklabels(x_labels)
        if yticks is not None:
            ax.set_yticks(yticks)
    else:
        # Same layout with the axes swapped: the data runs along x, so yticks apply to the x axis
        ax.vlines(medians, positions - 0.2, positions + 0.2, colors="black", zorder=3, lw=median_line_width)

        for position, outliers in zip(positions, outlier_data):
            ax.scatter(outliers, [position] * len(outliers), color="k")

        ax.set_xlabel(y_label)
        ax.set_yticks(positions)
        ax.set_yticklabels(x_labels)
        if yticks is not None:
            ax.set_xticks(yticks)

    fig.tight_layout()
    if figname:
        save_fig(fig, figname, **kwargs)

    return fig


def multi_hist(df, column_list, xlabel, ylabel, n_cols=2, transform=None, sharex=True, sharey=True, bins=10,
               pseudocount=0, figname=None, big_dimensions=True):
    """Make a figure with multiple subplots, each subplot containing a histogram for a different column of the
    dataframe. Optionally add a pseudocount and transform the data before plotting.

    Parameters
    ----------
    df : pd.DataFrame
        The data to plot
    column_list : list-like
        Column names to plot. Each column is plotted on a separate histogram.
    xlabel : str
        Label for the x-axis of the plots
    ylabel : str
        Label for the y-axis of the plots
    n_cols : int
        Number of columns in the multiplot
    transform : function handle
        If specified, add a pseudocount to the data and then apply the transformation function.
    sharex : bool
        Indicates if the x-axis should be shared across subplots.
    sharey : bool
        Same as sharex for y-axis.
    bins : int
        Number of bins for the histogram.
    pseudocount : int or float
        Add a pseudocount to the data if a transformation function is specified.
    figname : str
        If specified, save the figure with this name.
    big_dimensions : bool
        If True, then the size of the multiplot is the default figure size multiplied by the number of rows/columns.
        If False, then the entire figure is the default figure size.

    Returns
    -------
    fig : Figure handle
    """
    n_plots = len(column_list)
    fig, ax_list = setup_multiplot(n_plots, n_cols=n_cols, sharex=sharex, sharey=sharey, big_dimensions=big_dimensions)

    # setup_multiplot returns a 1-D array when there is a single row or a single column. Restore the real grid
    # shape so the row/column indices below are correct in both cases; treating a single row as a single column
    # would put the shared axis labels on the wrong panels.
    n_rows = int(np.ceil(n_plots / n_cols))
    ax_list = np.reshape(ax_list, (n_rows, n_cols))

    for i, label in enumerate(column_list):
        row, col = np.unravel_index(i, ax_list.shape)
        ax = ax_list[row, col]

        # Get rid of any NaN in the data since this is different from a zero
        data = df[label]
        data = data[data.notna()]

        if transform:
            data = transform(data + pseudocount)

        ax.hist(data, bins)
        ax.set_title(label)

        # Add axis labels if the axis is not shared or the axis is shared and on the appropriate axis.
        if not sharex or row == n_rows - 1:
            ax.set_xlabel(xlabel)
        if not sharey or col == 0:
            ax.set_ylabel(ylabel)

    if figname:
        save_fig(fig, figname, tight_layout=True)

    return fig


def roc_pr_curves(xaxis, tpr_list, precision_list, model_names, model_colors=None, prc_chance=None,
                  prc_upper_ylim=None, figname=None, legend=True, figax=None, **kwargs):
    """Draw a ROC curve and a PR curve for every model, optionally with a standard deviation band, and compute the
    area under each curve.

    Parameters
    ----------
    xaxis : list-like
        The shared x values: FPR for the ROC plot and recall for the PR plot. Every TPR and precision list must
        already be interpolated/computed at these points.
    tpr_list : list of lists, shape = [n_models, len(xaxis)]
        tpr_list[i] holds the TPR values of model i along xaxis. A flat list is plotted as a single curve. A list
        of lists is treated as one curve per cross-validation fold, from which the mean and standard deviation
        are computed.
    precision_list : list of lists, shape = [n_models, len(xaxis)]
        precision_list[i] holds the precision values of model i along xaxis, with the same flat vs. per-fold
        convention as tpr_list.
    model_names : list-like
        The name of each model.
    model_colors : list-like or None
        Color for each model. If None, colors are sampled evenly from the default colormap.
    prc_chance : float or None
        If given, draw a dashed horizontal chance line at this value on the PR plot.
    prc_upper_ylim : float or None
        Upper y limit of the PR plot. If not given, the upper y limit of the ROC plot is used.
    figname : str or None
        If specified, save the figures with this prefix (suffixed with "Roc" and "Pr").
    legend : bool
        If True, display a legend on each plot.
    figax : ([figure, figure], [axes, axes]) or None
        If specified, draw in the two provided axes (ROC first, PR second). Otherwise, create new figures.
    kwargs : dict
        Additional parameters for saving a figure.

    Returns
    -------
    fig_list : The handles to both figures (one for the ROC and one for the PR).
    auroc_list : AUROC score for each model
    auroc_std_list : 1SD of the AUROC score for each model, or None where it was not computed.
    aupr_list : AUPR score for each model
    aupr_std_list : 1SD of the AUPR score for each model, or None where it was not computed.

    """
    if figax:
        fig_list, ax_list = figax
    else:
        roc_fig, new_roc_ax = plt.subplots()
        pr_fig, new_pr_ax = plt.subplots()
        fig_list = [roc_fig, pr_fig]
        ax_list = [new_roc_ax, new_pr_ax]

    # Without explicit colors, sample the colormap evenly, one color per model
    if model_colors is None:
        model_colors = set_color(np.linspace(0, 0.99, len(model_names)))

    # --- ROC panel ---
    roc_ax = ax_list[0]
    auroc_list, auroc_std_list = _plot_each_model(roc_ax, xaxis, tpr_list, model_colors, model_names)

    # The diagonal is the chance line
    roc_ax.plot(xaxis, xaxis, color="black", linestyle="--", zorder=1)
    roc_ax.set_xlabel("False Positive Rate")
    roc_ax.set_ylabel("True Positive Rate")
    roc_ax.set_aspect("equal")
    if legend:
        roc_ax.legend(loc="lower right", frameon=False)

    # The PR panel borrows its y limits from the ROC panel
    roc_bottom, roc_top = roc_ax.get_ylim()

    # --- PR panel ---
    pr_ax = ax_list[1]
    aupr_list, aupr_std_list = _plot_each_model(pr_ax, xaxis, precision_list, model_colors, model_names)

    if prc_chance:
        pr_ax.axhline(prc_chance, color="black", linestyle="--", zorder=1)
    pr_top = prc_upper_ylim if prc_upper_ylim else roc_top
    pr_ax.set_ylim(bottom=roc_bottom, top=pr_top)
    pr_ax.set_xlabel("Recall")
    pr_ax.set_ylabel("Precision")
    pr_ax.set_aspect("equal")
    if legend:
        pr_ax.legend(frameon=False)

    if figname:
        save_fig(fig_list[0], figname + "Roc", **kwargs)
        save_fig(fig_list[1], figname + "Pr", **kwargs)

    return fig_list, auroc_list, auroc_std_list, aupr_list, aupr_std_list


def _plot_each_model(ax, xaxis, y_list, model_colors, model_names):
    """Helper for roc_pr_curves: draw one curve per model on an Axes and compute the area under each.

    Parameters
    ----------
    ax : Axes object
        The axes to draw on.
    xaxis : list-like
        Shared x values for every curve.
    y_list : list
        One entry per model. A 1-D entry is a single curve. A 2-D entry holds one curve per row (e.g. per
        cross-validation fold); its mean is plotted with a shaded band of one standard deviation.
    model_colors : list-like
        Color for each model.
    model_names : list-like
        Legend label for each model.

    Returns
    -------
    area_list : list
        Area under the (mean) curve of each model.
    area_std_list : list
        Standard deviation of the per-row areas for each model with a 2-D entry, otherwise None.
    """
    areas = []
    area_stds = []
    for curve, color, name in zip(y_list, model_colors, model_names):
        curve = np.array(curve)
        spread_of_area = None

        # A matrix means one curve per row: summarize it as mean +/- std
        if curve.ndim == 2:
            pointwise_std = np.std(curve, axis=0)
            spread_of_area = np.std([auc(xaxis, single_curve) for single_curve in curve])

            # Replace the matrix with its mean curve
            curve = curve.mean(axis=0)

            # Keep the band within [0, 1]
            band_top = np.minimum(curve + pointwise_std, np.ones(curve.size))
            band_bottom = np.maximum(curve - pointwise_std, np.zeros(curve.size))
            ax.fill_between(xaxis, band_bottom, band_top, alpha=0.2, zorder=2, color=color)

        # Area under the (mean) curve, then the curve itself
        area = auc(xaxis, curve)
        ax.plot(xaxis, curve, label=name, zorder=3, color=color)
        areas.append(area)
        area_stds.append(spread_of_area)

    return areas, area_stds


def stacked_bar_plots(df, ax_name, group_names, value_colors, legend_upper_left=None, legend_title=None,
                      legend_cols=1, vert=False, plot_title=None, figname=None, figax=None, **kwargs):
    """Draw stacked bars, one bar per row of the DataFrame and one segment per column, with an optional legend.

    Parameters
    ----------
    df : pd.DataFrame
        Data to plot. Rows are the bars, columns are the stacked segments.
    ax_name : str
        Label for the value axis.
    group_names : list[str]
        Tick label for each bar.
    value_colors : list-like, length = len(df.columns)
        Color for each column of df.
    legend_upper_left : tuple(float, float)
        If specified, show a legend whose bounding box has its upper left corner at these coordinates.
    legend_title : str
        If specified, title for the legend.
    legend_cols : int
        Number of columns in the legend. Default is 1.
    vert : bool
        If False (default), bars are horizontal. If True, bars are vertical.
    plot_title : str
        If specified, title for the plot.
    figname : str
        If specified, save the figure to this filename.
    figax : (figure, axes) or None
        If specified, draw in the provided axes. Otherwise, create a new figure and axes.
    kwargs : for save_fig

    Returns
    -------
    fig : Figure handle
    """
    bar_positions = np.arange(len(group_names))
    # Running total per bar: where the next segment starts
    stack_edge = np.zeros(len(bar_positions))

    fig, ax = figax if figax else plt.subplots()

    # Vertical bars stack upward from "bottom"; horizontal bars stack rightward from "left"
    if vert:
        draw_segment = ax.bar
        edge_keyword = "bottom"
    else:
        draw_segment = ax.barh
        edge_keyword = "left"

    for (label, values), color in zip(df.items(), value_colors):
        draw_segment(bar_positions, values, color=color, label=label, tick_label=group_names,
                     **{edge_keyword: stack_edge})
        stack_edge += values

    # End the value axis at the tallest stack and label it
    if vert:
        ax.set_ylim(top=stack_edge.max())
        ax.set_ylabel(ax_name)
    else:
        ax.set_xlim(right=stack_edge.max())
        ax.set_xlabel(ax_name)

    if legend_upper_left:
        legend_options = {"ncol": legend_cols, "bbox_to_anchor": legend_upper_left, "loc": "upper left"}
        if legend_title:
            legend_options["title"] = legend_title
        ax.legend(**legend_options)

    if plot_title:
        ax.set_title(plot_title)

    if figname:
        save_fig(fig, figname, **kwargs)

    return fig


def annotate_heatmap(ax, df, thresh, adjust_lower_triangle=False):
    """Display numbers on top of a heatmap to make it easier to view for a reader. If adjust_lower_triangle is True,
    then the lower triangle of the heatmap will display values in parentheses. This should only happen if the heatmap
    is symmetric. Assumes that low values are displayed as a light color and high values are a dark color.

    Each value df.iloc[row, col] is written at x = col, y = row, with values formatted to two decimal places.
    NOTE(review): this assumes the heatmap cell for (row, col) is centered at those integer coordinates, as with
    imshow/matshow of df - confirm against how the heatmap is drawn.

    Parameters
    ----------
    ax : Axes object
        The plot containing the heatmap on which annotations should be made
    df : pd.DataFrame
        The data underlying the heatmap.
    thresh : float
        Cutoff for switching from dark to light colors. Values above the threshold will be displayed as white text,
        those at or below it as black text.
    adjust_lower_triangle : bool
        If True and df is square, the values below the diagonal (row > col) will be shown in parentheses.

    Returns
    -------
    None
    """
    n_rows, n_cols = df.shape
    is_square = n_rows == n_cols

    for row in range(n_rows):
        for col in range(n_cols):
            value = df.iloc[row, col]
            color = "white" if value > thresh else "black"

            # Format the value as text
            label = f"{value:.2f}"
            # Add parentheses if desired and in the lower triangle and the heatmap is square
            if adjust_lower_triangle and is_square and row > col:
                label = f"({label})"

            # x is the column and y is the row; swapping them would transpose the annotations
            ax.text(col, row, label, ha="center", va="center", color=color)


# LEGACY FUNCTION
def gkmsvm_best_kmers(kmer_scores, positives, negatives, num_kmers=500, center_width=0):
    """Generate a plot to visualize the frequency and location of the best scoring k-mers. All positive and negative
    sequences are searched for the num_kmers most positively weighted k-mers and the num_kmers most negatively
    weighted k-mers. Then, count the number of sequences in which each k-mer is found at each position. Make a plot
    with the k-mer position on the x axis, k-mer weight on the y axis, the number of sequences with that k-mer at
    that position indicated by the size of the circle, and the class of sequences specified by the color of the
    circle. The y axis is broken into two stacked panels: the top one shows the positively weighted k-mers and the
    bottom one shows the negatively weighted k-mers. Note that this function assumes all sequences are the same
    length and all k-mers are the same length.

    Parameters
    ----------
    kmer_scores : pd.Series
        The scores assigned to every k-mer. Index is the k-mer, value is the weight. Must contain at least num_kmers
        entries.
    positives : pd.Series
        All sequences belonging to the positives. If positives.name is specified, it is used in creating the figure
        legend.
    negatives : pd.Series
        All sequences belonging to the negatives. If negatives.name is specified, it is used in creating the figure
        legend.
    num_kmers : int
        The number of top scoring k-mers to analyze. Both the num_kmers most positive and num_kmers most negative
        k-mers are analyzed, i.e. 2*num_kmers are analyzed.
    center_width : int
        If greater than 1, draw a vertical grey rectangle of this width starting at x = int(-center_width / 2), i.e.
        along the center of the x axis.

    Returns
    -------
    fig : The figure handle.

    """
    # Sort ascending so the most negative k-mers are first and the most positive are last, then take both ends with a
    # single positional selection. Use .iloc explicitly: plain [] with integer keys on a Series that has a
    # non-integer index relies on a positional fallback that pandas has deprecated.
    kmer_scores = kmer_scores.sort_values()
    best_kmers = kmer_scores.iloc[np.r_[:num_kmers, -num_kmers:0]]

    # Assumes all sequences are the same length. We want to plot the position of the k-mer *center*, not the start of
    #  the k-mer.
    seq_len = len(positives.iloc[0])
    kmer_len = len(kmer_scores.index[0])
    shift_factor = int((seq_len - kmer_len) / 2)
    positive_positions_df = _count_top_kmers(best_kmers, positives, shift_factor)
    negative_positions_df = _count_top_kmers(best_kmers, negatives, shift_factor)

    fig_width, fig_height = mpl.rcParams["figure.figsize"]
    marker_size = mpl.rcParams["lines.markersize"]
    positive_label = positives.name if positives.name else "Positives"
    negative_label = negatives.name if negatives.name else "Negatives"

    # Get the range for the positive and negative k-mers so that the subplots are sized appropriately
    positive_weights = best_kmers[best_kmers > 0]
    negative_weights = best_kmers[best_kmers < 0]
    smallest_pos = positive_weights.min()
    smallest_neg = negative_weights.max()
    pos_range = positive_weights.max() - smallest_pos
    neg_range = smallest_neg - negative_weights.min()

    fig, (ax_pos, ax_neg) = plt.subplots(nrows=2, ncols=1, figsize=(fig_width * 1.5, fig_height), gridspec_kw={
        "height_ratios": [pos_range, neg_range]})

    # First plot everything on both axes. Then, we will resize the axes so ax_pos only shows the positive k-mers and
    # ax_neg only shows the negative k-mers. ax_pos is drawn last so that the handles kept after the loop belong to
    # the axes that carries the class legend.
    for ax in (ax_neg, ax_pos):
        points_neg = ax.scatter(negative_positions_df["Position"], negative_positions_df["Weight"],
                                s=marker_size * negative_positions_df["Count"], alpha=0.25,
                                label=negative_label, color="blue", zorder=2)
        points_pos = ax.scatter(positive_positions_df["Position"], positive_positions_df["Weight"],
                                s=marker_size * positive_positions_df["Count"], alpha=0.25,
                                label=positive_label, color="red", zorder=2)

    # Make the rectangle if desired. The limits are read before the axes are trimmed, so the rectangle spans the full
    # range of weights on both panels.
    bottom, top = ax_pos.get_ylim()
    if center_width > 1:
        rect_start = int(-center_width / 2)
        # A patch can only belong to one axes, so make a separate rectangle for each panel.
        for ax in (ax_pos, ax_neg):
            rect = mpatches.Rectangle((rect_start, bottom), center_width, top - bottom, color="grey", alpha=0.35,
                                      zorder=1)
            ax.add_patch(rect)

    # Trim the y axes
    extra_whitespace = 0.98
    ax_pos.set_ylim(smallest_pos * extra_whitespace, top)
    ax_neg.set_ylim(bottom, smallest_neg * extra_whitespace)

    # Hide the spines between the axes
    ax_neg.spines["top"].set_visible(False)
    ax_pos.spines["bottom"].set_visible(False)
    ax_pos.tick_params(axis="x", bottom=False, labelbottom=False)

    # Get unique values of marker sizes and select 5 (evenly spaced by rank) to display in the legend
    uniq_point_sizes = np.unique(np.concatenate((points_pos.get_sizes(), points_neg.get_sizes())), axis=None)
    uniq_point_sizes = uniq_point_sizes[np.linspace(0, uniq_point_sizes.size - 1, 5).round().astype(int)]
    class_handles = [points_pos, points_neg]
    point_handles = []
    for scatter_size in uniq_point_sizes:
        # Scatter sizes are areas (points squared) while Line2D marker sizes are lengths, hence the square root.
        legend_marker_size = np.sqrt(scatter_size)
        # Undo the scaling applied when plotting to recover the number of sequences.
        n_sequences = int(scatter_size / marker_size)
        handle = mlines.Line2D([], [], color="black", marker="o", linestyle="None", markersize=legend_marker_size,
                               label=f"{n_sequences}")
        point_handles.append(handle)

    # Axis labels. The y label is attached to the figure so that it is centered across both panels.
    ax_neg.set_xlabel("Center of k-mer Relative\nto Center of Sequence")
    fig.text(0.05, 0.5, "k-mer Weight", ha="center", va="center", rotation="vertical",
             fontsize=mpl.rcParams["axes.labelsize"])

    # Legends
    legend_font = mpl.rcParams["legend.fontsize"]
    legend = ax_pos.legend(loc="upper left", bbox_to_anchor=(1.0, 1.0), handles=class_handles, title="Sequence Class")
    plt.setp(legend.get_title(), fontsize=legend_font)
    legend = ax_neg.legend(loc="upper left", bbox_to_anchor=(1.0, 1.0), handles=point_handles,
                           title="Number of Sequences")
    plt.setp(legend.get_title(), fontsize=legend_font)

    # Add hatch marks where the y axis is broken. Coordinates are in axes fractions, so y = 0 is the bottom edge of
    # ax_pos and y = 1 is the top edge of ax_neg.
    diag_size = 0.015
    diag_kwargs = dict(transform=ax_pos.transAxes, color="black", clip_on=False)
    # Bottom left of the top panel
    ax_pos.plot((-diag_size, diag_size), (-diag_size, diag_size), **diag_kwargs)
    # Bottom right of the top panel
    ax_pos.plot((1 - diag_size, 1 + diag_size), (-diag_size, diag_size), **diag_kwargs)

    diag_kwargs.update(transform=ax_neg.transAxes)
    # The panel heights are in the ratio pos_range:neg_range, so rescale the vertical extent of the marks on ax_neg to
    # give them the same slope as the marks on ax_pos.
    diag_y_scaler = pos_range / neg_range
    diag_y = (1 - diag_size * diag_y_scaler, 1 + diag_size * diag_y_scaler)
    # Top left of the bottom panel
    ax_neg.plot((-diag_size, diag_size), diag_y, **diag_kwargs)
    # Top right of the bottom panel
    ax_neg.plot((1 - diag_size, 1 + diag_size), diag_y, **diag_kwargs)

    # Leave room on the left for the figure-level y label.
    fig.tight_layout(rect=(0.1, 0, 1, 1))
    return fig


def _count_top_kmers(best_kmers, sequences, shift_factor):
    """Helper function for gkmsvm_best_kmers to locate the best k-mers in a set of sequences and count how many
    sequences have each k-mer at each position. Only the first occurrence of a k-mer within a sequence is considered.

    Parameters
    ----------
    best_kmers : pd.Series
        The k-mers to analyze, index is the k-mer and value is the weight.
    sequences : pd.Series
        The sequences to use to search for the k-mers.
    shift_factor : int
        Value to subtract from k-mer position to get the position of the center of the k-mer relative to the center
        of the sequences.

    Returns
    -------
    position_counts_df : pd.DataFrame
        Each row contains the shifted sequence position ("Position", stored as a float), the k-mer weight ("Weight"),
        and the number of sequences with that k-mer at that position ("Count"). K-mers that are not found in any
        sequence contribute no rows.

    """
    position_counts = []
    # Series.items() is used because Series.iteritems() was removed in pandas 2.0.
    for kmer, weight in best_kmers.items():
        # str.find gives the start of the first match in each sequence, or -1 if there is none. Cast to float up
        # front so that the missing matches can be replaced with NaN without an implicit change of dtype.
        kmer_pos = sequences.str.find(kmer).astype(float)
        kmer_pos = kmer_pos.where(kmer_pos != -1)
        kmer_pos = kmer_pos - shift_factor
        # value_counts ignores NaN, so sequences without the k-mer are not counted.
        for position, counts in kmer_pos.value_counts().items():
            position_counts.append([position, weight, counts])

    position_counts_df = pd.DataFrame(position_counts, columns=["Position", "Weight", "Count"])
    return position_counts_df