from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
import plotly.graph_objs as go
import numpy as np

def hidden_axis(ax, r):
    """Turn off the visual decorations of an axis and fix its range.

    Assigns the following attributes on ``ax``, in order: ``showgrid``,
    ``zeroline``, ``showline`` and ``showticklabels`` are set to ``False``,
    ``ticks`` and ``title`` are set to the empty string, and ``range`` is
    set to ``r``.

    Args:
        ax: Object whose attributes are assigned in place (elsewhere in this
            file it is a plotly scene axis such as ``fig.layout.scene.xaxis``).
        r: Value stored as the axis ``range``.

    Returns:
        None. ``ax`` is modified in place.
    """
    # Kept as an ordered sequence so attributes are assigned in a fixed order.
    axis_settings = (
        ("showgrid", False),
        ("zeroline", False),
        ("showline", False),
        ("ticks", ''),
        ("showticklabels", False),
        ("range", r),
        ("title", ""),
    )
    for attribute, value in axis_settings:
        setattr(ax, attribute, value)


def plot_point_cloud(clouds, colors, marker_size=15, colorscale="Greys",
                     opacity=0.5, zyx_range=None, filename='tmp.html'):
    """Render a 3D scatter plot of points with plotly and write it to a file.

    Args:
        clouds: Sequence of three coordinate collections, unpacked in
            ``(z, y, x)`` order.
        colors: Per-point values used as the marker color; they are mapped
            through ``colorscale``.
        marker_size: Marker size passed to the scatter trace.
        colorscale: Name of the plotly colorscale applied to ``colors``.
        opacity: Marker opacity passed to the scatter trace.
        zyx_range: Optional sequence of three axis ranges in ``(z, y, x)``
            order. When given, each scene axis is stripped of its
            decorations via ``hidden_axis`` and pinned to that range.
        filename: Output path handed to ``plotly.offline.plot``.

    Returns:
        None.
    """
    z, y, x = clouds
    trace = go.Scatter3d(
        x=x,
        y=y,
        z=z,
        mode='markers',
        marker=dict(
            size=marker_size,
            # Fixed: previously referenced an undefined name ``c`` here,
            # which raised NameError on every call.
            color=colors,
            colorscale=colorscale,
            opacity=opacity,
        ),
    )

    layout = go.Layout(
        margin=dict(
            l=None,
            r=None,
            b=None,
            t=None,
        ),
    )
    fig = go.Figure(data=[trace], layout=layout)

    if zyx_range is not None:
        hidden_axis(fig.layout.scene.zaxis, zyx_range[0])
        hidden_axis(fig.layout.scene.yaxis, zyx_range[1])
        hidden_axis(fig.layout.scene.xaxis, zyx_range[2])

    plot(fig, filename=filename)
    
def plot_voxel_as_point_clouds(voxel, segmentation,
                               marker_size=10,colorscale="Greys",opacity=0.5,
                               zyx_range=None,filename='tmp.html'):
    """Plot the selected cells of a voxel volume as a colored point cloud.

    The truthy entries of ``segmentation`` select which cells are drawn;
    each point is colored by the value of ``voxel`` at that cell. Rendering
    is delegated to ``plot_point_cloud``.

    Args:
        voxel: Array whose values provide the point colors. When
            ``zyx_range`` is omitted its shape is unpacked into three
            dimensions, so it must then be 3-dimensional.
        segmentation: Array with the same shape as ``voxel``; its truthy
            entries are the cells to plot.
        marker_size: Forwarded to ``plot_point_cloud``.
        colorscale: Forwarded to ``plot_point_cloud``.
        opacity: Forwarded to ``plot_point_cloud``.
        zyx_range: Optional three axis ranges; defaults to ``(0, size)``
            for each dimension of ``voxel``.
        filename: Forwarded to ``plot_point_cloud``.

    Raises:
        AssertionError: If ``voxel`` and ``segmentation`` differ in shape.
    """
    assert voxel.shape == segmentation.shape

    if zyx_range is None:
        # Default to the full extent of the volume along every axis.
        depth, height, width = voxel.shape
        zyx_range = ((0, depth), (0, height), (0, width))

    # Index arrays (one per dimension) of the selected cells.
    selected = np.where(segmentation)
    point_values = voxel[selected]

    plot_point_cloud(
        selected,
        point_values,
        marker_size=marker_size,
        colorscale=colorscale,
        opacity=opacity,
        zyx_range=zyx_range,
        filename=filename,
    )
    
if __name__ == "__main__":
    # Demo: plot the cells of a random 32x32x32 volume whose value exceeds 0.99.
    volume = np.random.random((32, 32, 32))
    mask = volume > 0.99
    plot_voxel_as_point_clouds(volume, mask)