@m3hrdadfi
Last active November 10, 2023 12:46
Heatmap using Plotly
import numpy as np
import plotly.graph_objects as go


def heat_map(x, y, z, width=None, height=None):
    """
    Generate an interactive heat map using Plotly.

    The heat map uses `x` and `y` as axis labels and `z` as cell values, colored
    with the 'Viridis' scale. The layout sets the plot title, the axis titles,
    and the tick properties (maximum number of ticks, label angle, and font).
    If `width` and `height` are None, Plotly sizes the figure automatically;
    otherwise the figure uses the requested size in pixels. The figure is
    displayed with `fig.show()`.

    Args:
        x (list of str): Labels on the x-axis, shape (n,).
        y (list of str): Labels on the y-axis, shape (m,).
        z (2D list of float or numpy.ndarray): Cell values, shape (m, n) or
            larger. Row i is drawn at y[i] and column j at x[j]; only the
            top-left (m, n) block is plotted.
        width (int, optional): Figure width in pixels. Defaults to None.
        height (int, optional): Figure height in pixels. Defaults to None.

    Note:
        The function was written for visualizing attention weights. In that
        setting `z` is a single 2D attention matrix (for example, one `layer`
        and `head` taken from an `attn` array), `x` holds the input tokens and
        `y` the output tokens.

    Raises:
        The function does not handle exceptions itself. For example, ragged rows
        in `z` raise an error when the array is built or sliced, so callers
        should validate their inputs.

    Returns:
        None: The heat map is displayed directly with `fig.show()`.

    Examples:
        >>> heat_map(['Token1', 'Token2'], ['TokenA', 'TokenB'], [[0.1, 0.2], [0.3, 0.4]])
        # Creates and displays a 2x2 heat map.
    """
    x_lim = len(x)
    y_lim = len(y)

    # Convert to an array so 2D slicing also works for nested lists, then keep
    # only the labeled block (rows follow y, columns follow x).
    z = np.asarray(z)[:y_lim, :x_lim]

    # Create the heatmap
    fig = go.Figure(
        data=go.Heatmap(
            z=z,
            x=x,
            y=y,
            colorscale="Viridis",
            hoverongaps=False,
        )
    )

    # Plot and axis titles
    fig.update_layout(
        title="Heatmap",
        xaxis_title="Input",
        yaxis_title="Output",
    )

    # Ticks: allow up to one tick per label so token labels are not thinned out,
    # rotate the x labels vertically, and use a small black Helvetica font.
    fig.update_layout(
        xaxis=dict(
            nticks=len(x),
            # type="category",
            tickangle=-90,
            tickfont=dict(family="Helvetica", size=11, color="black"),
        ),
        yaxis=dict(
            nticks=len(y),
            # type="category",
            # tickangle=-90,
            tickfont=dict(family="Helvetica", size=11, color="black"),
        ),
    )

    # Explicit width/height (in pixels) override autosizing; None leaves them unset.
    fig.update_layout(
        autosize=True,
        width=width,
        height=height,
    )
    fig.show()
    # Blank line after the figure output
    print()
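The shape convention in the docstring (rows of `z` follow `y`, columns follow `x`) and the `z[:y_lim, :x_lim]` trim can be sketched without NumPy or Plotly. `trim_to_labels` is a hypothetical helper, not part of the gist. Unlike NumPy slicing, which silently returns a smaller block when `z` has too few rows or columns, this sketch raises a `ValueError` in that case.

```python
from typing import List, Sequence


def trim_to_labels(
    z: Sequence[Sequence[float]], x: Sequence[str], y: Sequence[str]
) -> List[List[float]]:
    """Return the top-left len(y) x len(x) block of z as nested lists.

    Hypothetical helper: go.Heatmap draws z[i][j] at (x[j], y[i]), so rows
    pair with the y labels and columns with the x labels.
    """
    n_rows, n_cols = len(y), len(x)
    # Fail loudly instead of plotting a block smaller than the label lists.
    if len(z) < n_rows or any(len(row) < n_cols for row in z[:n_rows]):
        raise ValueError(f"z needs at least {n_rows} rows of {n_cols} values each")
    return [list(row[:n_cols]) for row in z[:n_rows]]


# Usage: a 3x3 matrix trimmed to 2 output labels (rows) and 2 input labels (columns).
z = [[0.1, 0.2, 0.9], [0.3, 0.4, 0.8], [0.5, 0.6, 0.7]]
print(trim_to_labels(z, ["Token1", "Token2"], ["TokenA", "TokenB"]))
# [[0.1, 0.2], [0.3, 0.4]]
```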