Last active
July 5, 2023 20:52
-
-
Save JotaRata/87c50a8c6bb9517316feb29a8a241aeb to your computer and use it in GitHub Desktop.
Matplotlib subplots shortcut
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class get_axis():
    """
    Context manager to generate subplots quickly.

    Parameters:
        - rows, cols (int): Number of rows and columns of subplots to draw.
        - xlabel, ylabel (str / list / tuple): Labels for the horizontal and
          vertical axes. A single value is applied to every subplot. If
          xlabel/ylabel is a list, tuple or numpy array, each subplot (in
          row-major order) gets its own label; missing entries are filled
          with empty strings.
        - title (str / list / tuple): Title of the subplots. A single value is
          applied to every subplot; a list, tuple or numpy array gives each
          subplot its own title.
        - axsize (tuple): Size in inches of one subplot. The figure size is
          (axsize[0] * cols, axsize[1] * rows) unless figsize is given.
        - figsize (tuple): Size of the whole figure in inches (same as
          figure.figsize). Overrides axsize.
        - figargs (dict): Additional keyword arguments passed to plt.subplots
          (e.g. sharex, sharey, dpi).

    Returns (from the with statement):
        The object plt.subplots returns when it is a single Axes; otherwise
        a flat 1-D numpy array of Axes in row-major order.

    Raises:
        ValueError: if a label/title sequence has more entries than
        rows * cols.

    Usage:
        This class should be used with the ``with`` statement. After entering,
        the created figure is available as the ``fig`` attribute; it is not
        shown or closed on exit.

    Example 1:
        with get_axis(1, 1, title='Plot 1', figsize=(6, 6)) as ax:
            ax.scatter(data[:, 0], data[:, 1], label='Data 1')
            ax.scatter(data2[:, 0], data2[:, 1], label='Data 2')
            ax.legend()

    Example 2:
        with get_axis(1, 2, title=['Cosine function', 'Squared Cosine function'], xlabel='Time', sharey=True) as axs:
            # A with block does not create a new scope: t is still
            # accessible after the block ends.
            t = np.linspace(0, 6.28, 20)
            axs[0].plot(t, np.cos(t))
            axs[1].plot(t, np.cos(t) ** 2)
    """

    def __init__(self, rows=1, cols=1, xlabel=None, ylabel=None, title=None, axsize=(10, 6), **figargs):
        self.r = rows
        self.c = cols
        n_axes = rows * cols
        # Per-subplot label sequences are padded to exactly one entry per
        # subplot here, so __enter__ can index them directly.
        self.xlabel = self._pad_labels(xlabel, n_axes)
        self.ylabel = self._pad_labels(ylabel, n_axes)
        self.title = self._pad_labels(title, n_axes)
        self.args = figargs
        self.fig = None
        if 'figsize' in figargs:
            # An explicit figsize wins over axsize; pop it so it is not passed
            # to plt.subplots a second time through **self.args.
            self.size = figargs.pop('figsize')
        else:
            self.size = (axsize[0] * cols, axsize[1] * rows)

    @staticmethod
    def _pad_labels(labels, count):
        """Return a list of `count` per-subplot labels, padded with ''.

        Values that are not a list, tuple or numpy array (str, None, other
        scalars) are returned unchanged; they apply to every subplot.

        Raises:
            ValueError: if the sequence has more than `count` entries.
        """
        if not isinstance(labels, (list, tuple, np.ndarray)):
            return labels
        if isinstance(labels, np.ndarray):
            padded = labels.ravel().tolist()
        else:
            padded = list(labels)
        if len(padded) > count:
            raise ValueError(
                f'Got {len(padded)} labels for {count} subplot(s)')
        return padded + [''] * (count - len(padded))

    @staticmethod
    def _label_for(labels, index):
        """Return the label for subplot `index`.

        This is the `index`-th entry of a padded list; a single value (or
        None) is returned as is.
        """
        return labels[index] if isinstance(labels, list) else labels

    def __enter__(self):
        self.fig, ax = plt.subplots(self.r, self.c, figsize=self.size, **self.args)
        if isinstance(ax, np.ndarray):
            # Several subplots (or squeeze=False): flatten to 1-D so both the
            # caller and the label lists use one row-major index.
            ax = ax.flatten()
            targets = list(ax)
        else:
            targets = [ax]
        for i, axis in enumerate(targets):
            xlabel = self._label_for(self.xlabel, i)
            ylabel = self._label_for(self.ylabel, i)
            title = self._label_for(self.title, i)
            if xlabel is not None:
                axis.set_xlabel(xlabel)
            if ylabel is not None:
                axis.set_ylabel(ylabel)
            if title is not None:
                axis.set_title(title)
        return ax

    def __exit__(self, *args):
        # Nothing to clean up: the figure is left open so it can still be
        # shown or saved. Returning None lets exceptions raised inside the
        # with block propagate.
        pass
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Context manager to generate subplots quickly
Parameters:
Usage:
This class should be used with the `with` statement.
Example 1:
Example 2:
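You can also use the `with` block to group temporary variables used for the current plot. Note that Python's `with` statement does not create a new scope, so those variables remain accessible after the block ends.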