Matplotlib: Create 2D bar graphs in a 3D axes plot

This is a first attempt at plotting a 3D scatter plot together with the histograms of its points along the x, y and z directions, each drawn independently along its own axis. The plots were produced with Python's matplotlib module.

For the x and y directions, the histograms work fine but along the z-direction it doesn’t work the way I wanted. However, this is a good attempt to achieve this (at least partially).

Here is the code that produces the above plots. It has not been cleaned up at this point – refactoring would improve its quality considerably.

Keywords: “3d axes scatter plot with 2d plots”, “3d scatter plot with 2d histograms”, “matplotlib 3d scatter and 2d histogram”, “matplotlib python 3d axes 2d plots”, “3d and 2d plots together matplotlib”, plotting histograms on 3D axes with Python, matplotlib 3d plot and tied 2d plot as another axis, matplotlib overlay 2d plot in a 3d axis, matplotlib plot bar 2d on 3d axes, matplotlib 2d bar plot on 3d z axes. How to plot a 3D looking barchart using Matplotlib in a 2D environment? 3d scatter plot with histograms in 2D dimensions?

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.stats import qmc
from dataclasses import dataclass


def auto_close_matplotlib_plots():
    """Show the open figures without blocking, then wait for the user to hit enter.

    The function itself only displays the figures and blocks on ``input``;
    nothing is closed explicitly here.
    """
    # Tip: to keep the windows open across several runs, start the script in
    # the background by appending `&` to the command line.
    plt.show(block=False)
    # Short pause before prompting (presumably so the GUI can draw the windows).
    plt.pause(0.001)
    input("\nHit [enter] to close matplotlib plots")

def save_snapshots_to_file(
    save_figs: bool = True, filename: str = "snapshotfile"
) -> None:
    """
    Write the current figure to ``filename``, or display it instead.

    When ``save_figs`` is true the figure is saved with a white face colour at
    dpi=240; otherwise ``plt.show()`` is called and nothing is written.

    Resolution note: for a (16,9) sized figure, dpi=120 gives
    (16,9)*120 = [1920,1080] pixels, and dpi=240 gives
    (16,9)*240 = [3840,2160] pixels (4K).
    """
    if not save_figs:
        plt.show()
        return

    plt.savefig(filename, facecolor="w", dpi=240)


def plot_x_and_y_place_surfaces(Eox, Eoy, Lz3, ax, alpha=0.1):
    """Draw two translucent black reference planes anchored at the origin.

    The first surface is the wall at x = 0, spanning y in [0, Eoy] and
    z in [0, Lz3]. The second is the floor at z = 0, spanning x in [0, Eox]
    and y in [0, Eoy].

    Parameters
    ----------
    Eox, Eoy, Lz3 : float
        Extent of the box along x, y and z respectively.
    ax :
        Object providing ``plot_surface(X, Y, Z, **kwargs)`` (a 3D axes).
    alpha : float
        Opacity used for both surfaces.
    """
    n_samples = 100
    x = np.linspace(0, Eox, n_samples)
    y = np.linspace(0, Eoy, n_samples)
    z = np.linspace(0, Lz3, n_samples)

    # Wall at x = 0. np.meshgrid returns its grids in argument order, so the
    # outputs must be unpacked as (Y, Z) to keep the y extent on Y and the
    # z extent on Z.
    Y, Z = np.meshgrid(y, z)
    ax.plot_surface(np.zeros(Y.shape), Y, Z, alpha=alpha, color="black")

    # Floor at z = 0. Its y extent is Eoy (not the z extent Lz3).
    X, Y = np.meshgrid(x, y)
    ax.plot_surface(X, Y, np.zeros(X.shape), alpha=alpha, color="black")


def plot_rectangle_border_lines(x0, x1, y0, y1, z0, z1):
    """Outline the box [x0, x1] x [y0, y1] x [z0, z1] with its 12 edges.

    Each edge is drawn through ``plt.plot`` as a thin solid black line
    (``"-k"``, linewidth 0.5), passing x, y and z coordinate pairs.
    """
    x_span, y_span, z_span = [x0, x1], [y0, y1], [z0, z1]
    thin = {"linewidth": 0.5}

    # Four edges parallel to the x axis.
    for y, z in ((y0, z0), (y1, z0), (y1, z1), (y0, z1)):
        plt.plot(x_span, [y, y], [z, z], "-k", **thin)

    # Four edges parallel to the y axis.
    for x, z in ((x0, z0), (x0, z1), (x1, z1), (x1, z0)):
        plt.plot([x, x], y_span, [z, z], "-k", **thin)

    # Four edges parallel to the z axis.
    for x, y in ((x0, y0), (x0, y1), (x1, y1), (x1, y0)):
        plt.plot([x, x], [y, y], z_span, "-k", **thin)


def plot_lines_to_histograms(shift, x0, x1, y0, y1, z0, show_hist: bool = False, z_hist: bool = False):
    """Draw thin dashed black guide lines from the box out to the histogram positions.

    The guides lie at height ``z0``: two run from ``x0`` out to ``x0 - shift``
    and two from ``y1`` out to ``y1 + shift``, followed by one base line at
    each of those offsets. With ``z_hist`` set, matching guides down to
    ``z0 - shift`` are added as well. When ``show_hist`` is false the offset is
    treated as 0, so the connecting guides have zero length.
    """
    offset = shift if show_hist else 0

    x_out = x0 - offset  # x position of the histogram drawn beside the box
    y_out = y1 + offset  # y position of the histogram drawn behind the box
    z_out = z0 - offset  # z position of the optional z histogram
    level = [z0, z0]

    # Connectors from the box corners out to the histogram base lines.
    segments = [
        ([x_out, x0], [y0, y0], level),
        ([x_out, x0], [y1, y1], level),
        ([x0, x0], [y1, y_out], level),
        ([x1, x1], [y1, y_out], level),
    ]
    if z_hist:
        segments.append(([x0, x0], [y0, y0], [z_out, z0]))
        segments.append(([x1, x1], [y0, y0], [z_out, z0]))

    # Base lines under the histograms.
    segments.append(([x_out, x_out], [y0, y1], level))
    segments.append(([x0, x1], [y_out, y_out], level))
    if z_hist:
        segments.append(([x0, x1], [y0, y0], [z_out, z_out]))

    for xs, ys, zs in segments:
        plt.plot(xs, ys, zs, "--k", linewidth=0.5)


def generate_random_points(Np: int, option: str = "uniform", mu: float = 0.5, sigma: float = 0.12):
    """Return reproducible x, y, z sample arrays of length ``Np``.

    NumPy's global random state is reseeded with 0 on every call, so repeated
    calls with the same arguments return the same points. Note that this also
    resets the stream seen by any later ``np.random`` call.

    Parameters
    ----------
    Np : int
        Number of points to generate.
    option : str
        Case-insensitive sampling scheme:
        ``"uniform"`` - independent uniform samples on [0, 1);
        ``"sobol"``   - scrambled 3-dimensional Sobol sequence;
        ``"normal"``  - Gaussian samples (``mu``, ``sigma``) on all three axes;
        ``"mixed"``   - Gaussian in x and y, uniform in z.
    mu, sigma : float
        Mean and standard deviation used by the Gaussian draws.

    Returns
    -------
    tuple of three 1-D arrays ``(x, y, z)``.

    Raises
    ------
    NotImplementedError
        If ``option`` is not one of the schemes listed above.
    """
    seed = 0
    np.random.seed(seed=seed)
    kind = option.lower()

    if kind == "uniform":
        # Uniform random
        x = np.random.random(Np)
        y = np.random.random(Np)
        z = np.random.random(Np)

    elif kind == "sobol":
        # The QMC engine draws its scrambling from its own generator, which
        # np.random.seed() does not affect; seed it explicitly so this option
        # is as reproducible as the others.
        sampler = qmc.Sobol(d=3, scramble=True, seed=seed)
        sobol = sampler.random(n=Np)
        x = sobol[:, 0]
        y = sobol[:, 1]
        z = sobol[:, 2]

    elif kind == "normal":
        # Gaussian on every axis
        x = np.random.normal(mu, sigma, Np)
        y = np.random.normal(mu, sigma, Np)
        z = np.random.normal(mu, sigma, Np)

    elif kind == "mixed":
        # Gaussian in x and y, uniform in z
        x = np.random.normal(mu, sigma, Np)
        y = np.random.normal(mu, sigma, Np)
        z = np.random.random(Np)

    else:
        raise NotImplementedError("Unimplemented Option")

    return x, y, z


def generate_secondary_points(Np: int):
    """Draw x, y and z arrays of ``Np`` uniform samples each.

    The samples come from NumPy's global random state, consumed in the order
    x, then y, then z; no seeding is done here.
    """
    x, y, z = (np.random.random(Np) for _ in range(3))
    return x, y, z


@dataclass
class Range:
    """Point coordinates together with their per-axis minimum, maximum and span.

    After construction the instance additionally exposes ``min_<axis>``,
    ``max_<axis>`` and ``<axis>_range`` (max minus min) for each of x, y, z.
    """

    x: np.ndarray
    y: np.ndarray
    z: np.ndarray

    def __post_init__(self):
        coords = (self.x, self.y, self.z)
        lows = [np.min(values) for values in coords]
        highs = [np.max(values) for values in coords]

        self.min_x, self.min_y, self.min_z = lows
        self.max_x, self.max_y, self.max_z = highs
        self.x_range, self.y_range, self.z_range = (
            high - low for low, high in zip(lows, highs)
        )


@dataclass
class MinMax:
    """Lower and upper bounds along the x, y and z axes."""

    xmin: float  # lower bound along x
    xmax: float  # upper bound along x
    ymin: float  # lower bound along y
    ymax: float  # upper bound along y
    zmin: float  # lower bound along z
    zmax: float  # upper bound along z


def arrange_ranges(x, y, z, rg: MinMax):
    """Rescale each coordinate as ``v * (max - min) + min`` using the bounds in ``rg``.

    The three rescaled coordinates are returned wrapped in a ``Range``.
    Inputs are presumably normalised to [0, 1], in which case the result
    spans [min, max] on each axis - the code itself does not check this.
    """
    bounds = (
        (x, rg.xmin, rg.xmax),
        (y, rg.ymin, rg.ymax),
        (z, rg.zmin, rg.zmax),
    )
    scaled = [values * (upper - lower) + lower for values, lower, upper in bounds]

    return Range(*scaled)


def plot_histograms(data, Np, nbins, rg, shift, ax, show_hist: bool = False, z_hist: bool = False):
    """Draw the per-axis histograms as 2D bar charts on ``ax`` and set up the view.

    With ``show_hist`` set, the x histogram is placed on the plane
    y = ``rg.ymax + shift`` and the y histogram on the plane
    x = ``rg.xmin - shift``. With ``z_hist`` also set, the z histogram is
    drawn with ``zdir="z"`` at z = ``rg.zmin - shift``. As the original author
    notes, that placement is not the intended one.

    Bin edges are ``nbins`` evenly spaced values between the data minimum and
    maximum (so there are ``nbins - 1`` bars), and counts are divided by ``Np``.

    Regardless of ``show_hist``, the axis limits are widened by ``shift`` on
    the histogram sides, tick labels are cleared and the axes are switched off.
    """

    def draw_bars(samples, lower, upper, extent, **bar_style):
        # Bin the samples on evenly spaced edges and draw one bar per bin,
        # centred on the bin, with the count normalised by Np.
        counts, edges = np.histogram(samples, bins=np.linspace(lower, upper, nbins))
        bin_width = extent / (nbins - 1)
        ax.bar(edges[:-1] + bin_width/2, counts / Np, width=bin_width, **bar_style)

    if show_hist:
        draw_bars(
            data.x, data.min_x, data.max_x, data.x_range,
            edgecolor="black", linewidth=0.7, zdir="y", alpha=0.3, zs=rg.ymax + shift,
        )
        ax.set_xlabel("x")

        draw_bars(
            data.y, data.min_y, data.max_y, data.y_range,
            edgecolor="black", linewidth=0.7, zdir="x", alpha=0.3, zs=rg.xmin - shift,
        )
        ax.set_ylabel("y")
        ax.set_zlabel("z")

        if z_hist:
            # Known limitation: this does not place the bars the way the
            # author wanted.
            draw_bars(
                data.z, data.min_z, data.max_z, data.z_range,
                zs=rg.zmin - shift, edgecolor="white", linewidth=0.7, zdir="z", alpha=0.6,
            )
            ax.set_xlabel("z")

    # Leave room for the histograms next to the data box.
    ax.set_xlim([rg.xmin - shift, rg.xmax])
    ax.set_ylim([rg.ymin, rg.ymax + shift])
    z_floor = rg.zmin - shift if z_hist else rg.zmin
    ax.set_zlim([z_floor, rg.zmax])

    for clear_tick_labels in (ax.set_xticklabels, ax.set_yticklabels, ax.set_zticklabels):
        clear_tick_labels([])

    ax.set_axis_off()


@dataclass
class Settings:
    """Parameters of the demo plot (point counts, histogram offset, binning)."""

    Np: int  # Number of points in the primary data set
    Np2: int  # Number of points in the secondary data set
    shift: float # Shift amount for the histograms
    nbins: int  # Number of bin edges handed to np.linspace (gives nbins - 1 bars)

def main():
    """Build the demo figure: a 3D scatter plot with 2D histograms on the same axes.

    Generates the point sets, rescales them into the box described by ``rg``,
    draws the scatter plot, the histograms, the reference planes and the box
    and guide lines. Depending on ``save_figs`` the figure is then either
    written to a PNG file or shown interactively.
    """
    settings = Settings(Np=1000, Np2=10, shift=0.3, nbins=10)
    rg = MinMax(xmin=0, xmax=1, ymin=0, ymax=1, zmin=0, zmax=1)
    z_hist = True  # Show z-histogram
    show_hist = True
    show_second_set = False
    save_figs = False
    snapshot_index = 2  # Number embedded (zero-padded) in the output file name

    # Other sampling options accepted here: "normal", "mixed", "sobol".
    x_inp, y_inp, z_inp = generate_random_points(settings.Np, option="uniform")
    x_inp2, y_inp2, z_inp2 = generate_secondary_points(settings.Np2)

    data = arrange_ranges(x_inp, y_inp, z_inp, rg)
    data2 = arrange_ranges(x_inp2, y_inp2, z_inp2, rg)

    fig = plt.figure(figsize=(12, 8))
    # Axes3D(fig) no longer attaches itself to the figure on recent matplotlib
    # releases, which leaves the figure empty and makes the plt.plot-based
    # helpers below draw on a different axes. Create the axes through the
    # figure instead, using the same full-figure rectangle Axes3D defaulted to.
    ax = fig.add_axes([0.0, 0.0, 1.0, 1.0], projection="3d")
    ax.scatter(data.x, data.y, data.z)
    if show_second_set:
        # Plot the rescaled secondary points so they share the box of the
        # primary set (the raw inputs only coincide with it for a unit box).
        ax.scatter(data2.x, data2.y, data2.z, color="r", marker="o", s=60)

    plot_histograms(data, settings.Np, settings.nbins, rg, settings.shift, ax, show_hist=show_hist, z_hist=z_hist)

    # Box decorations follow rg instead of hard-coded unit extents.
    # NOTE: the plane helper always starts at the origin, so the planes only
    # line up with the box when the lower bounds in rg are 0.
    plot_x_and_y_place_surfaces(rg.xmax, rg.ymax, rg.zmax, ax, alpha=0.05)
    plot_rectangle_border_lines(rg.xmin, rg.xmax, rg.ymin, rg.ymax, rg.zmin, rg.zmax)

    plot_lines_to_histograms(
        settings.shift, rg.xmin, rg.xmax, rg.ymin, rg.ymax, rg.zmin,
        show_hist=show_hist, z_hist=z_hist,
    )

    if save_figs:
        save_snapshots_to_file(
            save_figs=save_figs,
            filename=f"3daxes_2d_histogram{snapshot_index:0>4}.png",
        )
    else:
        auto_close_matplotlib_plots()


# Build and display the demo figure only when this file is run as a script.
if __name__ == "__main__":
    main()

Leave a comment