Last active November 10, 2023 12:46
m3hrdadfi/6a52e1581290931624e61e8a57bea37d
heatmap using plotly
import numpy as np
import plotly.graph_objects as go


def heat_map(x, y, z, width=None, height=None):
    """
    Generates an interactive heat map using Plotly.

    Draws the values in `z` as a heat map with the 'Viridis' color scale, labelled
    with `x` along the horizontal axis and `y` along the vertical axis. The layout
    sets the plot and axis titles, shows a tick for every label, rotates the x-axis
    labels by -90 degrees, and uses an 11 pt black Helvetica tick font. The figure
    is displayed with `fig.show()`.

    Args:
        x (list of str): Labels for the x-axis (columns), length n.
        y (list of str): Labels for the y-axis (rows), length m.
        z (2D list or numpy array of float): Heat map values, shape (m, n); row i
            corresponds to y[i] and column j to x[j]. A larger input is truncated
            to its first m rows and n columns.
        width (int, optional): Figure width in pixels. If None, Plotly sizes the
            figure automatically.
        height (int, optional): Figure height in pixels. If None, Plotly sizes the
            figure automatically.

    Note:
        A typical input is a matrix of attention weights (for example, one layer
        and head), with input tokens on the x-axis and output tokens on the y-axis.
        Both axes are categorical, so Plotly merges duplicate labels into a single
        category; labels should therefore be unique.

    Raises:
        The function does not catch exceptions. For example, a ragged `z` that
        cannot be converted to a 2D array raises an error from NumPy.

    Returns:
        None: The heat map is displayed directly with `fig.show()`.

    Examples:
        >>> heat_map(['Token1', 'Token2'], ['TokenA', 'TokenB'], [[0.1, 0.2], [0.3, 0.4]])
        # Creates and displays a heat map with 2x2 data points.
    """
    x_lim = len(x)
    y_lim = len(y)

    # Convert to an array so that 2D slicing also works for nested lists,
    # then keep only the rows/columns that have a label.
    z = np.asarray(z)[:y_lim, :x_lim]

    # Create the heatmap
    fig = go.Figure(
        data=go.Heatmap(
            z=z,
            x=x,
            y=y,
            colorscale="Viridis",
            hoverongaps=False,
        )
    )

    # Titles for the plot and both axes
    fig.update_layout(
        title="Heatmap",
        xaxis_title="Input",
        yaxis_title="Output",
    )

    # Treat the labels as categories and show one tick per category
    # (dtick must be positive; dtick=1 means every category gets a tick).
    fig.update_layout(
        xaxis=dict(
            type="category",
            dtick=1,
            tickangle=-90,
            tickfont=dict(family="Helvetica", size=11, color="black"),
        ),
        yaxis=dict(
            type="category",
            dtick=1,
            tickfont=dict(family="Helvetica", size=11, color="black"),
        ),
    )

    # Use explicit dimensions if given; otherwise let Plotly size the figure.
    fig.update_layout(
        autosize=True,
        width=width,
        height=height,
    )
    fig.show()
    print()