Skip to content

plot

Matplotlib-Based Plotting Routines

This submodule contains routines for visualizing some of the standard Markov state model plots.

plot_ck_test(ck, states=None, frames_per_unit=1, unit='frames', grid=(3, 3))

Plot CK-Test results.

This routine is a basic helper function to visualize the results of msmhelper.msm.chapman_kolmogorov_test.

Parameters:

  • ck (dict) –

    Dictionary holding for each lagtime the CK equation and with 'md' the reference.

  • states (ndarray, default: None ) –

    List containing all states to plot the CK-test.

  • frames_per_unit (float, default: 1 ) –

    Number of frames per given unit. This is used to scale the axis accordingly.

  • unit ([frames, fs, ps, ns, us], default: 'frames' ) –

    Unit to use for label.

  • grid ((int, int), default: (3, 3) ) –

    The number of (n_rows, n_cols) to use for the grid layout.

Returns:

  • fig ( Figure ) –

    Figure holding plots.

Source code in src/msmhelper/plot/_ck_test.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
def plot_ck_test(
    ck,
    states=None,
    frames_per_unit=1,
    unit='frames',
    grid=(3, 3),
):
    """Plot CK-Test results.

    This routine is a basic helper function to visualize the results of
    [msmhelper.msm.chapman_kolmogorov_test][].

    Parameters
    ----------
    ck : dict
        Dictionary holding for each lagtime the CK equation and with 'md' the
        reference.
    states : ndarray, optional
        List containing all states to plot the CK-test. If `None`, all keys
        of `ck['md']['ck']` are used.
    frames_per_unit : float, optional
        Number of frames per given unit. This is used to scale the axis
        accordingly.
    unit : ['frames', 'fs', 'ps', 'ns', 'us'], optional
        Unit to use for label.
    grid : (int, int), optional
        The number of `(n_rows, n_cols)` to use for the grid layout.

    Returns
    -------
    fig : matplotlib.Figure
        Figure holding plots.

    Notes
    -----
    The number of rows actually created is `ceil(len(states) / n_cols)`;
    `n_rows` is only used to position the legend and to pick the long or
    short y-label.

    """
    # load colors
    pplt.load_cmaps()
    pplt.load_colors()

    lagtimes = np.array([key for key in ck if key != 'md'])
    if states is None:
        states = np.array(list(ck['md']['ck']))

    nrows, ncols = grid
    needed_rows = int(np.ceil(len(states) / ncols))

    # squeeze=False always yields a 2d array of shape (needed_rows, ncols).
    # Squeezing followed by np.atleast_2d turns a single-column layout with
    # several rows into shape (1, needed_rows), which breaks axs[irow, icol].
    fig, axs = plt.subplots(
        needed_rows,
        ncols,
        sharex=True,
        sharey='row',
        squeeze=False,
        gridspec_kw={'wspace': 0, 'hspace': 0},
    )

    # quantities shared by all panels
    md_time = ck['md']['time'] / frames_per_unit
    xlim = [
        lagtimes[0] / frames_per_unit,
        np.max(ck['md']['time']) / frames_per_unit,
    ]

    # _split_array is defined elsewhere; presumably it chunks the states into
    # rows of at most ncols entries -- TODO confirm
    for irow, states_row in enumerate(_split_array(states, ncols)):
        for icol, state in enumerate(states_row):
            ax = axs[irow, icol]

            # reference curve
            pplt.plot(
                md_time,
                ck['md']['ck'][state],
                '--',
                ax=ax,
                color='pplt:gray',
                label='MD',
            )
            # one curve per lag time
            for lagtime in lagtimes:
                pplt.plot(
                    ck[lagtime]['time'] / frames_per_unit,
                    ck[lagtime]['ck'][state],
                    ax=ax,
                    label=lagtime / frames_per_unit,
                )
            pplt.text(
                0.5,
                0.9,
                f'S{state}',
                contour=True,
                va='top',
                transform=ax.transAxes,
                ax=ax,
            )

            # set scale
            ax.set_xscale('log')
            ax.set_xlim(xlim)
            ax.set_ylim([0, 1])
            # rows touch each other (hspace=0), hence only the last row
            # carries the tick at 0
            if irow < needed_rows - 1:
                ax.set_yticks([0.5, 1])
            else:
                ax.set_yticks([0, 0.5, 1])

            ax.grid(True, which='major', linestyle='--')
            ax.grid(True, which='minor', linestyle='dotted')
            ax.set_axisbelow(True)

    # set legend
    # NOTE(review): the legend box and the y-label choice rely on the
    # requested nrows, not on needed_rows -- confirm this is intended.
    if ncols in {1, 2}:
        legend_kw = {
            'outside': 'right',
            'bbox_to_anchor': (2.0, (1 - nrows), 0.2, nrows),
        }
    else:
        legend_kw = {
            'outside': 'top',
            'bbox_to_anchor': (0.0, 1.0, ncols, 0.01),
        }
    if ncols == 3:
        legend_kw['ncol'] = 3
    pplt.legend(
        ax=axs[0, 0],
        **legend_kw,
        title=fr'$\tau_\mathrm{{lag}}$ [{unit}]',
        frameon=False,
    )

    if nrows >= 3:
        ylabel = r'self-transition probability $P_{i\to i}$'
    else:
        ylabel = r'$P_{i\to i}$'

    pplt.hide_empty_axes()
    pplt.label_outer()
    pplt.subplot_labels(
        ylabel=ylabel,
        xlabel=f'time $t$ [{unit}]',
    )
    return fig

plot_wtd(wtd, frames_per_unit=1, unit='frames', ax=None, show_md=True, show_fliers=False)

Plot waiting time distribution.

This is a wrapper function to plot the return value of msmhelper.msm.estimate_waiting_time_dist.

Parameters:

  • wtd (dict) –

    Dictionary returned from msmhelper.msm.estimate_wtd, holding stats of waiting time distributions.

  • frames_per_unit (float, default: 1 ) –

    Number of frames per given unit. This is used to scale the axis accordingly.

  • unit ([frames, fs, ps, ns, us], default: 'frames' ) –

    Unit to use for label.

  • ax (Axes, default: None ) –

    Axes to plot figure in. With None the current axes is used.

  • show_md (bool, default: True ) –

    Include boxplot of MD data.

  • show_fliers (bool, default: False ) –

    Show fliers (outliers) in MD and MSM prediction.

Returns:

  • ax ( Axes ) –

    Return axes holding the plot.

Source code in src/msmhelper/plot/_wtd.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
def plot_wtd(
    wtd,
    frames_per_unit=1,
    unit='frames',
    ax=None,
    show_md=True,
    show_fliers=False,
):
    """Plot waiting time distribution.

    This is a wrapper function to plot the return value of
    [msmhelper.msm.estimate_waiting_time_dist][].

    Parameters
    ----------
    wtd : dict
        Dictionary returned from `msmhelper.msm.estimate_wtd`, holding stats of
        waiting time distributions. For every lag time the keys `'whislo'`,
        `'whishi'`, `'q1'`, `'med'`, `'q3'` and (if `show_fliers=True`)
        `'fliers'` are read; `wtd['MD'][0]` is passed to `ax.bxp` if
        `show_md=True`.
    frames_per_unit : float, optional
        Number of frames per given unit. This is used to scale the axis
        accordingly.
    unit : ['frames', 'fs', 'ps', 'ns', 'us'], optional
        Unit to use for label.
    ax : matplotlib.Axes, optional
        Axes to plot figure in. With `None` the current axes is used.
    show_md : bool, optional
        Include boxplot of MD data.
    show_fliers : bool, optional
        Show fliers (outliers) in MD and MSM prediction.

    Returns
    -------
    ax : matplotlib.Axes
        Return axes holding the plot.

    """
    if ax is None:
        ax = plt.gca()

    lagtimes = np.array(
        [time for time in wtd if time != 'MD'], dtype=int,
    )
    max_lagtime_unit = lagtimes.max() / frames_per_unit

    # convert stats to array
    LB, UB, Q1, Q2, Q3 = np.array([
        np.array([
            wtd[lagtime][key] for lagtime in lagtimes
        ]) / frames_per_unit
        for key in ['whislo', 'whishi', 'q1', 'med', 'q3']
    ])

    # plot results
    colors = pplt.categorical_color(4, 'C0')
    lagtimes_unit = lagtimes / frames_per_unit
    if show_fliers:
        # The flier envelope is only needed when it is drawn. Passing the
        # whisker as `initial` keeps the previous min/max semantics and makes
        # lag times without any flier (empty array) fall back to the whisker
        # instead of raising a ValueError.
        FL = np.array([
            np.min(wtd[lagtime]['fliers'], initial=wtd[lagtime]['whislo'])
            for lagtime in lagtimes
        ]) / frames_per_unit
        FU = np.array([
            np.max(wtd[lagtime]['fliers'], initial=wtd[lagtime]['whishi'])
            for lagtime in lagtimes
        ]) / frames_per_unit
        ax.fill_between(
            lagtimes_unit, FL, FU, color=colors[3], label=r'fliers',
        )
    ax.fill_between(
        lagtimes_unit,
        LB,
        UB,
        color=colors[2],
        label=r'$Q_{1/3}\pm1.5\mathrm{IQR}$',
    )
    ax.fill_between(lagtimes_unit, Q1, Q3, color=colors[1], label='IQR')
    ax.plot(lagtimes_unit, Q2, color=colors[0], label='$Q_2$')

    if show_md:
        # the MD boxplot is placed to the right of the largest lag time
        md_position = max_lagtime_unit * 1.125
        bxp = ax.bxp(
            [{
                key: time / frames_per_unit
                for key, time in wtd['MD'][0].items()
            }],
            positions=[md_position],
            widths=max_lagtime_unit * 0.075,
            showfliers=show_fliers,
        )
        for median in bxp['medians']:
            median.set_color('k')

        # vertical line separating MSM prediction and MD boxplot
        ax.axvline(
            max_lagtime_unit,
            0,
            1,
            lw=plt.rcParams['axes.linewidth'],
            color='pplt:axes',
        )

        ax.set_xlim([0, max_lagtime_unit * 1.25])
        xticks = np.array([
            *np.linspace(0, max_lagtime_unit, 4).astype(int),
            md_position,
        ])
        # the last tick marks the MD boxplot and is labeled accordingly
        xticklabels = [
            f'{xtick:.0f}' if idx + 1 < len(xticks) else 'MD'
            for idx, xtick in enumerate(xticks)
        ]
        ax.set_xticks(xticks)
        ax.set_xticklabels(xticklabels)
    else:
        ax.set_xlim([0, max_lagtime_unit])

    # use scientific notation for large values
    ax.ticklabel_format(
        axis='y', style='scientific', scilimits=[0, 2], useMathText=True,
    )
    ax.get_yaxis().get_offset_text().set_ha('right')

    # set legend and labels
    pplt.legend(ax=ax, outside='top', frameon=False)
    ax.set_ylabel(f'time $t$ [{unit}]')
    ax.set_xlabel(fr'$\tau_\mathrm{{lag}}$ [{unit}]')

    return ax