# -*- coding: utf-8 -*-
from matplotlib.pyplot import subplots
import mpl_toolkits.mplot3d
from numpy import array
def _get_legend_handles(legend):
    """Return every handle drawn in ``legend`` (lines, patches, collections...).

    ``Legend.get_patches`` only returns the Patch handles, so for a legend that
    mixes lines and patches it yields fewer handles than ``get_texts`` yields
    labels. The attribute holding all handles is ``legend_handles`` since
    matplotlib 3.7 and ``legendHandles`` before that.
    """
    for attr in ("legend_handles", "legendHandles"):
        handles = getattr(legend, attr, None)
        if handles is not None:
            return list(handles)
    # Unknown legend implementation: fall back to patches only (legacy behavior)
    return list(legend.get_patches())


def init_fig(fig=None, ax=None, shape="default", is_3d=False):
    """Return a figure, its axes and the current legend entries, creating the
    figure if needed (for matplotlib)

    Parameters
    ----------
    fig : Matplotlib.figure.Figure
        The figure to get the handle from. When None and ``ax`` is also None,
        a new figure with a single axes is created; when None but ``ax`` is
        given, the figure owning ``ax`` is used.
    ax : Matplotlib.axes.Axes object
        Axis on which to plot the data (default: first axes of ``fig``; an
        axes is added if ``fig`` has none)
    shape : str
        Shape of a newly created figure: "default" (matplotlib default size),
        "square" for an 8x8 figure or "rectangle" for an 8x4 figure
    is_3d : bool
        3D or 2D axes (only used when an axes has to be created)

    Returns
    -------
    (fig,axes,patch_leg,label_leg): Matplotlib.figure.Figure, matplotlib.axes._subplots.AxesSubplot, list, list
        Figure handle, Axes handle, list of the existing legend handles (all
        of them, not only patches, so it lines up with label_leg), list of the
        existing legend labels
    """
    if fig is None and ax is not None:
        # An axes was given without its figure: reuse it instead of ignoring it
        fig = ax.figure

    if fig is None:
        # Create a new figure with an empty legend; None keeps matplotlib's
        # default size / subplot keywords for the "default" shape
        figsize = {"rectangle": (8, 4), "square": (8, 8)}.get(shape)
        subplot_kw = dict(projection="3d") if is_3d else None
        fig, axes = subplots(tight_layout=True, figsize=figsize, subplot_kw=subplot_kw)
        return (fig, axes, [], [])

    if ax is not None:
        axes = ax
    elif fig.axes:
        axes = fig.axes[0]
    else:
        # Existing but empty figure: give it an axes rather than IndexError
        axes = fig.add_subplot(**({"projection": "3d"} if is_3d else {}))

    legend = axes.legend_
    if legend is None:
        # Empty legend
        return (fig, axes, [], [])

    # Get the symbol and label of every legend entry (same length, same order)
    patch_leg = _get_legend_handles(legend)
    label_leg = [text.get_text() for text in legend.get_texts()]
    return (fig, axes, patch_leg, label_leg)
def init_subplot(fig=None, subplot_index=None, is_3d=False):
    """Initialize subplot (given position or automatic stacking)

    Parameters
    ----------
    fig : Matplotlib.figure.Figure
        The figure to get the handle from (can be None: a new "rectangle"
        figure with a single axes is then created)
    subplot_index : int
        Index in ``fig.axes`` of the subplot to return, or None. With None and
        an existing figure, a new axes is stacked below the existing ones
        (single column); with None and a new figure, its only axes is used.
    is_3d : bool
        3D or 2D figure

    Returns
    -------
    (fig,ax): Matplotlib.figure.Figure, matplotlib.axes._subplots.AxesSubplot
        Figure handle, Axes Handle
    """
    if fig is None:
        # A brand-new figure already holds the single axes to use
        (fig, _, _, _) = init_fig(None, shape="rectangle", is_3d=is_3d)
        return fig, fig.axes[0 if subplot_index is None else subplot_index]

    if subplot_index is not None:
        return fig, fig.axes[subplot_index]

    # Auto-stack: move the n existing axes onto rows 0..n-1 of an (n+1)-row
    # single-column grid and put the new axes on the last row.
    # (Axes.change_geometry was removed in matplotlib 3.6; set_subplotspec
    # is its replacement.)
    existing = list(fig.axes)
    grid = fig.add_gridspec(len(existing) + 1, 1)
    for row, old_ax in enumerate(existing):
        old_ax.set_subplotspec(grid[row])
    subplot_kw = {"projection": "3d"} if is_3d else {}
    ax = fig.add_subplot(grid[len(existing)], **subplot_kw)
    return fig, ax