"""
Author: Ryan Friedman (@rfriedman22)
Email: ryan.friedman@wustl.edu
"""

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy import stats

import modeling
import plot_utils


def read_bc_count_files(files, labels):
    """Read in barcode count files and store the result in a DataFrame, where rows are barcodes and columns are
    different replicates/experiments/conditions.

    Each file must be tab-separated with (at least) ``barcode``, ``label`` and ``count`` columns. The barcodes and
    their sequence labels are taken from the first file; every file then contributes one count column, aligned on
    barcode. Barcodes missing from a file get NaN in that file's column, and barcodes that are not in the first file
    are dropped.

    Parameters
    ----------
    files : list-like
        File names, each of which contains barcode counts from different replicates/experiments/conditions.
    labels : list-like
        Name to assign to each file, becomes the columns of the DataFrame. Paired with ``files`` by position; extra
        entries in either are ignored.

    Returns
    -------
    bc_counts_df : pd.DataFrame
        Indexed by barcode, with a ``label`` column followed by one raw count column per entry of ``labels``.

    """
    # Use the first file to get the barcodes and the sequence label ID. Index by column name rather than position:
    # with ``usecols``, a positional index_col refers to the subset in file order, which depends on column layout.
    bc_counts_df = pd.read_csv(files[0], sep="\t", index_col="barcode", usecols=["barcode", "label"])

    # Then for each file, read in the barcodes and counts, then join with the label ID
    for file, name in zip(files, labels):
        # Select the column explicitly; read_csv(squeeze=True) was removed in pandas 2.0.
        counts = pd.read_csv(file, sep="\t", index_col="barcode", usecols=["barcode", "count"])["count"]
        # Assignment aligns on the barcode index.
        bc_counts_df[name] = counts

    return bc_counts_df


def per_million(df):
    """Return the totals of ``df`` expressed in millions.

    For a DataFrame this is the column-wise sum divided by one million (a Series). For a Series it is the scalar sum
    divided by one million.
    """
    totals = df.sum()
    return totals / 1_000_000


def assess_balance_coverage(df, sample_labels):
    """Report total barcode counts and plot per-sample coverage and balance.

    Prints the total number of counts (in millions) summed over ``sample_labels``, then opens two bar-chart figures:
    the average number of counts per barcode (row) in each sample, and each sample's share of all counts in percent.

    Parameters
    ----------
    df : pd.DataFrame
        Barcode counts in multiple conditions/samples. Rows are barcodes, columns are the samples.
    sample_labels : list-like
        Column names to use in assessing balance and coverage.

    Returns
    -------
    None
        The figures are created with ``plt.subplots`` and are not returned.

    """
    per_sample_totals = df[sample_labels].sum()
    print(f"There are a total of {per_million(per_sample_totals): .3f} million barcode counts.")

    tick_positions = np.arange(len(sample_labels))
    # (y-axis label, bar heights) for each figure, in the order they are drawn: coverage first, then balance.
    panels = (
        ("Coverage per barcode", per_sample_totals / df.index.size),
        ("Pct. total counts", per_sample_totals / per_sample_totals.sum() * 100),
    )
    for ylabel, heights in panels:
        _, ax = plt.subplots()
        ax.bar(tick_positions, heights)
        ax.set_ylabel(ylabel)
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(sample_labels)


def filter_low_counts(df, sample_labels, cutoff_values, dna_labels=None, bc_per_seq=3, n_rna_samples_missing=0,
                      n_dna_samples_missing=0, cpm_normalize=True):
    """Apply per-sample count cutoffs to barcodes, then optionally normalize each sample to counts per million.

    DNA samples (``dna_labels``, if given) are handled first. A barcode below the cutoff in a DNA sample is set to
    NaN in that DNA sample. A barcode below the cutoff in more than ``n_dna_samples_missing`` DNA samples is set
    to NaN in every DNA sample. The remaining samples are treated as RNA. A barcode below the cutoff in more than
    ``n_rna_samples_missing`` RNA samples is set to zero in every RNA sample. A barcode below the cutoff in fewer
    RNA samples keeps its raw counts. Values that are already NaN never count as below the cutoff.

    Parameters
    ----------
    df : pd.DataFrame
        Raw barcode counts. It is copied, not modified in place.
    sample_labels : list-like
        Labels of the df indicating which columns contain barcode counts (both DNA and RNA).
    cutoff_values : list-like
        Cutoff threshold for each entry of ``sample_labels``, matched by position. A count strictly less than its
        cutoff is considered below the cutoff.
    dna_labels : list-like, optional
        Subset of ``sample_labels`` that are DNA samples. If given, DNA filtering is applied and status reports are
        printed for each DNA sample. If None, every sample is treated as RNA.
    bc_per_seq : int
        Not used by this function; kept for backward compatibility.
    n_rna_samples_missing : int
        A barcode is zeroed in all RNA samples only if it is below the cutoff in more than this many RNA samples.
        With the default of 0, being below the cutoff in any single RNA sample is enough.
    n_dna_samples_missing : int
        A barcode is set to NaN in all DNA samples only if it is below the cutoff in more than this many DNA
        samples. Whatever this value is, it is always set to NaN in the individual DNA samples where it is below
        the cutoff.
    cpm_normalize : bool
        If True, call ``assess_balance_coverage`` on the thresholded counts and then divide each sample column by
        its total count in millions. Otherwise, simply return the thresholded counts.

    Returns
    -------
    filtered_df : pd.DataFrame
        Copy of ``df`` with thresholded (and, if requested, CPM-normalized) sample columns. Other columns are
        unchanged.
    """
    df = df.copy()
    # Accept any list-like (list, ndarray, Index) for the labels.
    sample_labels = list(sample_labels)
    # Store the cutoffs as a series, labels as index and cutoffs as values
    cutoffs = pd.Series(dict(zip(sample_labels, cutoff_values)))

    # If there are DNA samples, first filter any barcodes missing from the DNA.
    if dna_labels is not None:
        dna_labels = list(dna_labels)
        # Vectorized comparison: the cutoff Series aligns on column labels, so each column uses its own cutoff.
        dna_mask = df[dna_labels].lt(cutoffs[dna_labels])
        print("Barcodes missing in DNA:")
        for label in dna_labels:
            print(f"Sample {label}: {dna_mask[label].sum()} barcodes")
            # Replace barcodes below their cutoff with an NaN in this DNA sample only
            df.loc[dna_mask[label], label] = np.nan

        # Blank out any barcodes missing from too many DNA samples (in the DNA columns only)
        dna_missing_mask = dna_mask.sum(axis=1) > n_dna_samples_missing
        print(f"{dna_missing_mask.sum()} barcodes are missing from more than {n_dna_samples_missing} DNA samples.")
        df.loc[dna_missing_mask, dna_labels] = np.nan

        # RNA samples are all remaining sample labels, in their original order.
        dna_set = set(dna_labels)
        rna_labels = [label for label in sample_labels if label not in dna_set]

    else:
        rna_labels = sample_labels

    # Filter barcodes missing from RNA: True where a count is below that sample's cutoff
    rna_mask = df[rna_labels].lt(cutoffs[rna_labels])
    print("Barcodes off in RNA:")
    for label in rna_labels:
        print(f"Sample {label}: {rna_mask[label].sum()} barcodes")

    # Zero out barcodes that are off in too many RNA samples (in the RNA columns only)
    rna_missing_mask = rna_mask.sum(axis=1) > n_rna_samples_missing
    print(f"{rna_missing_mask.sum()} barcodes are off in more than {n_rna_samples_missing} RNA samples.")
    df.loc[rna_missing_mask, rna_labels] = 0

    # Normalize the data to CPM if desired
    if cpm_normalize:
        assess_balance_coverage(df, sample_labels)
        df[sample_labels] /= per_million(df[sample_labels])

    return df


def reproducibility_plots(df, labels, unit, big_dimensions=False):
    """Build an n-by-n grid of panels comparing every pair of replicates.

    Diagonal panels name each replicate. Panels above the diagonal scatter one replicate against another. Panels
    below the diagonal show the squared Pearson correlation (R^2) of the corresponding pair.

    Parameters
    ----------
    df : pd.DataFrame
        Rows are the measured items, columns include the replicate columns named in ``labels``.
    labels : list-like
        Column names representing replicates.
    unit : str
        Unit of the counts (raw, cpm, RNA/DNA, etc.), shown next to each replicate name on the diagonal.
    big_dimensions : bool
        Indicates whether or not the reproducibility plot should use big dimensions.

    Returns
    -------
    fig : figure handle
    """
    label_size = mpl.rcParams["axes.labelsize"]

    n_reps = len(labels)
    fig, axes = plot_utils.setup_multiplot(n_reps**2, n_cols=n_reps, sharex=True, sharey=True,
                                           big_dimensions=big_dimensions)

    # np.ndindex walks the panel grid in row-major order, yielding one (row, col) pair per panel.
    for row, col in np.ndindex(axes.shape):
        ax = axes[row, col]

        if row == col:
            ax.text(0.5, 0.5, f"{labels[row]} {unit}", fontsize=label_size,
                    transform=ax.transAxes, ha="center", va="center")
            continue

        x_values = df[labels[col]]
        y_values = df[labels[row]]
        if row > col:
            # Lower triangle: squared Pearson correlation of the pair.
            r_squared = x_values.corr(y_values) ** 2
            ax.text(0.5, 0.5, f"$R^2 = ${r_squared: 0.3f}", fontsize=label_size,
                    transform=ax.transAxes, ha="center", va="center")
        else:
            # Upper triangle: raw scatter of the pair.
            ax.scatter(x_values, y_values, c="black")

    return fig


def normalize_rna_by_dna(df, rna_labels, dna_labels, average_dna=False):
    """Normalize RNA counts by DNA counts and return the result.

    Parameters
    ----------
    df : pd.DataFrame
        Contains columns for both RNA and DNA replicates, and potentially other information too. Not modified.
    rna_labels : list-like
        Column names for RNA labels.
    dna_labels : list-like
        Column names for DNA labels. When ``average_dna`` is False there must be either one DNA label (used for
        every RNA sample) or one per RNA label, matched by position. When ``average_dna`` is True, any non-zero
        number of DNA labels is accepted.
    average_dna : bool
        If True, divide every RNA column by the row-wise mean across the DNA columns (NaN skipped). If False,
        then normalize RNA counts by its paired DNA sample.

    Returns
    -------
    norm_df : pd.DataFrame
        Copy of the original df, but RNA columns normalized to their DNA counterparts.

    Raises
    ------
    ValueError
        If ``average_dna`` is True and ``dna_labels`` is empty.
    Exception
        Raised by ``_check_rna_dna_labels`` when ``average_dna`` is False and the number of DNA labels is neither 1
        nor the number of RNA labels.
    """
    norm_df = df.copy()

    if average_dna:
        dna_labels = list(dna_labels)
        if not dna_labels:
            raise ValueError("At least one DNA label is required to normalize data.")
        # Pooled DNA needs no RNA/DNA pairing, so the number of DNA replicates need not match the RNA replicates.
        dna_counts = norm_df[dna_labels].mean(axis=1)
        for rna in rna_labels:
            norm_df[rna] /= dna_counts
    else:
        # Pairing by position: expand a single DNA label, or require one DNA label per RNA label.
        dna_labels = _check_rna_dna_labels(rna_labels, dna_labels)
        for dna, rna in zip(dna_labels, rna_labels):
            norm_df[rna] /= norm_df[dna]

    return norm_df


def _check_rna_dna_labels(rna, dna):
    """Validate DNA labels against RNA labels and expand them to one DNA label per RNA label.

    Parameters
    ----------
    rna : list-like
        RNA column names.
    dna : list-like
        DNA column names. Must contain either exactly one label (shared by every RNA sample) or exactly as many
        labels as ``rna`` (paired by position).

    Returns
    -------
    dna : list-like
        If a single DNA label was given, an ndarray repeating it ``len(rna)`` times. Otherwise ``dna`` is returned
        unchanged.

    Raises
    ------
    ValueError
        If the number of DNA labels is neither 1 nor the number of RNA labels. ValueError subclasses Exception, so
        existing ``except Exception`` handlers still catch it.
    """
    if len(dna) == 1:
        # Take the first element positionally via iteration (works for label-indexed Series too).
        dna = np.full(len(rna), next(iter(dna)))

    if len(dna) != len(rna):
        raise ValueError("The number of DNA labels is neither 1 nor the number of RNA labels. Cannot normalize data.")

    return dna


def bc_counts_vs_dna_plot(df, rna_labels, dna_labels, n_cols=2, big_dimensions=False, y_label="RNA/DNA Counts"):
    """Scatter each RNA column against its DNA column to check that RNA/DNA is not an artifact of DNA abundance.

    Each RNA sample gets one panel. Its paired DNA column is on the x-axis and the RNA column, as stored in ``df``,
    is on the y-axis. The squared Pearson correlation (R^2) of the two is printed in the upper-right corner.

    Parameters
    ----------
    df : pd.DataFrame
        Contains columns for both RNA and DNA replicates, and potentially other information too. The RNA columns
        are plotted as-is, so presumably they should already be RNA/DNA normalized.
    rna_labels : list-like
        Column names for RNA labels.
    dna_labels : list-like
        Column names for DNA labels. If there are multiple, assumes that ordering is matched with RNA labels.
    n_cols : int
        Number of columns for the multiplot.
    big_dimensions : bool
        Indicates whether or not this multiplot should use big dimensions.
    y_label : str
        Label for the y axis.

    Returns
    -------
    fig : figure handle
    """
    paired_dna = _check_rna_dna_labels(rna_labels, dna_labels)
    fig, axes = plot_utils.setup_multiplot(len(rna_labels), n_cols=n_cols, big_dimensions=big_dimensions)
    label_size = mpl.rcParams["axes.labelsize"]

    for panel, (dna_col, rna_col) in enumerate(zip(paired_dna, rna_labels)):
        ax = axes[np.unravel_index(panel, axes.shape)]
        dna_values = df[dna_col]
        rna_values = df[rna_col]
        r_squared = dna_values.corr(rna_values) ** 2

        ax.scatter(dna_values, rna_values, color="k")
        ax.text(0.95, 0.95, f"$R^2 = ${r_squared: 0.3f}", transform=ax.transAxes, fontsize=label_size,
                ha="right", va="center")
        ax.set_xlabel("DNA CPM")
        ax.set_ylabel(y_label)
        ax.set_title(rna_col)

    return fig


def _make_basal_mask(df, key):
    """Return a boolean Series flagging rows whose ``label`` column contains ``key``.

    The match is case-insensitive, and ``key`` is interpreted as a regular expression (``str.contains`` default).
    """
    labels = df["label"]
    return labels.str.contains(pat=key, case=False)


def average_barcodes(df, sequence_label="label", out_prefix=None):
    """Average RNA/DNA barcode counts for each sequence within replicates.

    Parameters
    ----------
    df : pd.DataFrame
        Index is the barcode. One column must have the key ``sequence_label``; the rest are assumed to be RNA/DNA
        counts for each replicate.
    sequence_label : str
        Name of the column in df containing the sequence IDs.
    out_prefix : str
        If non-empty, also write the result as a tab-separated file named
        ``f"{out_prefix}AverageExpressionPerReplicate.txt"``, with NaN written as "NaN".

    Returns
    -------
    expression_df : pd.DataFrame
        Mean RNA/DNA counts per sequence in each replicate (pandas' groupby mean, which skips NaN). Index is now the
        sequence label, columns are the replicates.
    """
    by_sequence = df.groupby(sequence_label)
    expression_df = by_sequence.mean()

    if not out_prefix:
        return expression_df

    out_path = f"{out_prefix}AverageExpressionPerReplicate.txt"
    expression_df.to_csv(out_path, sep="\t", na_rep="NaN")
    return expression_df


def basal_normalize(df, basal_key):
    """Normalize activity levels to basal in the replicate, then take averages and stddev across replicates.

    Each column (replicate) is divided by the basal row's value in that column. The basal row is then dropped, and
    each remaining row is summarized across replicates, with NaN values skipped.

    Parameters
    ----------
    df : pd.DataFrame
        Index is sequence ID, columns are average RNA/DNA barcode counts for each replicate.
    basal_key : str
        Index value for basal.

    Returns
    -------
    activity_df : pd.DataFrame
        Index is sequence ID (basal excluded). Columns are ``expression`` (mean), ``expression_std`` (sample standard
        deviation, ddof=1) and ``expression_reps`` (number of non-NaN replicates, stored as float). If no rows remain
        after dropping basal, an empty frame with these same columns is returned.
    """
    normalized = df / df.loc[basal_key]
    # Drop basal
    normalized = normalized.drop(index=basal_key)

    # Vectorized row-wise summaries. Unlike a row-wise apply, this keeps the output schema even when empty.
    activity_df = pd.DataFrame(
        {
            "expression": normalized.mean(axis=1),
            "expression_std": normalized.std(axis=1),
            # Float dtype matches what the previous row-wise apply produced (mixed float/int rows upcast to float).
            "expression_reps": normalized.count(axis=1).astype(float),
        },
        index=normalized.index,
    )
    return activity_df


def log_ttest_vs_basal(df, basal_key):
    """Test, in log space, whether each sequence has the same activity as basal.

    Each row of ``df`` is converted to lognormal parameters with ``_get_lognormal_params``. Basal's parameters are
    then compared with every other row's using ``scipy.stats.ttest_ind_from_stats`` with ``equal_var=False``
    (Welch's t-test).

    Parameters
    ----------
    df : pd.DataFrame
        Index is sequence ID, columns are average RNA/DNA barcode counts for each replicate. Must contain a row
        labelled ``basal_key``.
    basal_key : str
        Index value for basal.

    Returns
    -------
    pvals : pd.Series
        p-value for the null hypothesis that the log activity of a sequence equals that of basal, indexed by
        sequence ID. Basal itself is not included.
    """
    lognormal = df.apply(_get_lognormal_params, axis=1)
    basal = lognormal.loc[basal_key]
    candidates = lognormal.drop(index=basal_key)

    def _pvalue_vs_basal(params):
        # Welch's test from summary statistics: the two variances are not assumed equal.
        result = stats.ttest_ind_from_stats(basal["mean"], basal["std"], basal["n"],
                                            params["mean"], params["std"], params["n"],
                                            equal_var=False)
        return result[1]

    return candidates.apply(_pvalue_vs_basal, axis=1)


def _get_lognormal_params(row):
    """Estimate lognormal parameters from linear-scale replicate values by the method of moments.

    A lognormal variable with parameters mu and sigma has mean exp(mu + sigma**2 / 2) and variance
    mean**2 * (exp(sigma**2) - 1). Solving with the sample mean m and coefficient of variation cv = s / m gives
    sigma**2 = log(cv**2 + 1) and mu = log(m / sqrt(cv**2 + 1)).

    Parameters
    ----------
    row : pd.Series
        Row of a df corresponding to barcode averages in each replicate. NaN entries are skipped.

    Returns
    -------
    params : pd.Series
        ``mean`` (mu) and ``std`` (sigma) of the lognormal distribution, and ``n``, the number of non-NaN
        replicates the sequence was measured in.
    """
    linear_mean = row.mean()
    cv_term = (row.std() / linear_mean) ** 2 + 1

    return pd.Series({
        "mean": np.log(linear_mean / np.sqrt(cv_term)),
        "std": np.sqrt(np.log(cv_term)),
        "n": row.count(),
    })


def compare_to_normal(data, xlabel, bins=10, show_mean=True, show_median=True, figname=None):
    """Given some data, Z-score it and perform a KS test against the normal distribution. Then plot a histogram of
    the data.

    Parameters
    ----------
    data : pd.Series or pd.DataFrame with one column
        Values of the data of interest. A one-column DataFrame is reduced to a Series first.
    xlabel : str
        Label for the x-axis.
    bins : int
        Number of bins for the histogram.
    show_mean : bool
        If True, plot the mean on the histogram.
    show_median : bool
        If True, plot the median on the histogram.
    figname : str
        If specified, save the figure using this name.

    Returns
    -------
    fig : Figure handle
    pval : float
        The p-value from a one-sample KS test of the Z-scored data against the standard normal distribution. The
        null hypothesis is that the data are normally distributed.
    """
    # DataFrame.mean()/median() return a Series, which cannot be formatted with ".2f" in the legend labels below,
    # so a one-column DataFrame (documented as valid input) is reduced to a plain Series up front.
    if isinstance(data, pd.DataFrame):
        data = data.squeeze(axis="columns")

    zscores = stats.zscore(data)

    # Are the sequences normally distributed?
    _, pval = stats.kstest(zscores, cdf="norm")

    # Make the histogram (density-normalized)
    fig, ax = plt.subplots()
    ax.hist(data, bins=bins, density=True, label=None)
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Frequency")

    # Put the p-value in the title
    ax.set_title(f"KS test to normal p={pval : 1.2e}")

    if show_mean:
        mean = data.mean()
        ax.axvline(mean, color="k", label=f"mean={mean: .2f}")

    if show_median:
        median = data.median()
        ax.axvline(median, color="k", linestyle="--", label=f"median={median: .2f}")

    # Only draw a legend when something labelled was added; otherwise matplotlib warns about an empty legend.
    if show_mean or show_median:
        ax.legend()

    if figname:
        plot_utils.save_fig(fig, figname)

    return fig, pval


def compare_two_sets(data1, data2, xlabel, label1, label2, bins=10, show_means=True, show_medians=True,
                     figname=None):
    """Perform a 2-sample KS test with the null hypothesis that the two datasets come from the same distribution.
    Then, plot a histogram with the two datasets.

    Parameters
    ----------
    data1 : pd.Series or pd.DataFrame with one column
        Values for the first dataset. A one-column DataFrame is reduced to a Series first.
    data2 : pd.Series or pd.DataFrame with one column
        Values for the second dataset. A one-column DataFrame is reduced to a Series first.
    xlabel : str
        Label for the x-axis of the histogram.
    label1 : str
        Name of the first dataset.
    label2 : str
        Name of the second dataset.
    bins : int
        Number of bins for the histogram.
    show_means : bool
        If True, put the means of the two datasets on the histogram.
    show_medians : bool
        If True, put the medians of the two datasets on the histogram.
    figname : str
        If specified, save the figure with this name.

    Returns
    -------
    fig : Figure handle
    pval : float
        The p-value from the two-sample KS test. The null hypothesis is that ``data1`` and ``data2`` are drawn
        from the same distribution.
    """
    # Reduce one-column DataFrames (documented as valid input) to Series. This way mean()/median() return scalars
    # for axvline, and the KS test receives 1-D samples rather than 2-D arrays.
    if isinstance(data1, pd.DataFrame):
        data1 = data1.squeeze(axis="columns")
    if isinstance(data2, pd.DataFrame):
        data2 = data2.squeeze(axis="columns")

    # Perform the 2-sample KS test
    _, pval = stats.ks_2samp(data1, data2)

    fig, ax = plt.subplots()
    ax.hist(data1, bins=bins, label=label1, alpha=0.5, density=True, color=plot_utils.set_color(0.25))
    ax.hist(data2, bins=bins, label=label2, alpha=0.5, density=True, color=plot_utils.set_color(0.75))
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Frequency")
    ax.set_title(f"2-sample KS test p={pval: 1.2e}")

    # Display the means if desired
    if show_means:
        ax.axvline(data1.mean(), color="k", label=label1 + " mean")
        ax.axvline(data2.mean(), color="k", linestyle="--", label=label2 + " mean")

    # Display medians if desired
    if show_medians:
        ax.axvline(data1.median(), color="k", linestyle=":", label=label1 + " median")
        ax.axvline(data2.median(), color="k", linestyle="-.", label=label2 + " median")

    ax.legend()
    if figname:
        plot_utils.save_fig(fig, figname)

    return fig, pval