import numpy as np
import matplotlib.pyplot as plt
from pynlo.interactions.FourWaveMixing import SSFM
from pynlo.media.fibers import fiber
from pynlo.light.DerivedPulses import SechPulse

########## Parameters:
Pulse1   = .200  # pulse duration (ps)
Pulse2   = .200
Pulse3   = .200

pulseWL1 = 1035.0          # pulse central wavelength (nm)
pulseWL2 = pulseWL1/2.0    # half the wavelength, i.e. twice the frequency of pulse 1
pulseWL3 = pulseWL1/0.74   # 0.74x the frequency of pulse 1

EPP1    = 0.100e-6  # Pulse energy (J)
EPP2    = 0.100e-6
EPP3    = 0.010e-6

Length  = 10.00 # mm

GDD1     = 0*0.040   # Group delay dispersion (ps^2); the leading 0* currently disables it
GDD2     = 0*0.040
GDD3     = 0*0.020

# NOTE(review): TOD is not referenced anywhere else in this file.
TOD      = 0.0        # Third order dispersion (ps^3), currently the same for all pulses

Window  = 6.0   # simulation window (ps)
Steps   = 20 # simulation steps - note that the integrator typically takes many more steps than this
Points  = 2**16 # simulation points
error   = 0.001  # the error desired for the integration. 0.001 is usually good. larger allows for faster calculations

Alpha   = 0.0   # background attenuation coefficient (dB/cm); converted to 1/m below

Aeff = 2 * np.pi * (10e-6)**2   # effective area in m^2
n2_per_torr = 1.1e-22/760 # nonlinearity in m^2/(W torr)

# Nonlinear coefficient per torr at pulse 1's wavelength: 1/(W m torr).
Gamma_per_torr = 2*np.pi*n2_per_torr/(pulseWL1*1e-9 * Aeff)

Raman = False # enable Raman? (calculates fused silica Raman response)
Steep = True  # enable self-steepening?

# Extra-loss band: frequencies between 5.5x and 6.5x that of pulse 1.
fund_THz = 3e8/(pulseWL1*1e-9) * 1e-12   # frequency of pulse 1 (THz)
loss_lo = fund_THz * 5.5
loss_hi = fund_THz * 6.5
loss_dB_per_cm = 0  # attenuation inside the loss_lo..loss_hi band (dB/cm)

max_pressure = 1350 # pressure in mBar
#######################


fibWL = pulseWL1   # Center WL (nm)
# Convert dB/cm into a natural-log attenuation coefficient in 1/m:
# ln(10**(dB/10)) per cm, times 100 cm per m.
alpha = np.log((10**(Alpha * 0.1))) * 100
# `loss` holds the band attenuation in 1/m (the dB/cm input is kept separately
# so the name is never reused with two different units).
loss  = np.log((10**(loss_dB_per_cm * 0.1))) * 100

def dB(num):
    """Return 10*log10(|num|**2), i.e. the power of ``num`` in decibels.

    Works element-wise on real or complex arrays; a zero amplitude yields -inf.
    """
    power = np.abs(num) ** 2
    return 10 * np.log10(power)

def pressure_at_z(z):
    """Return the gas pressure (mBar) at position ``z`` along the fiber.

    Callers in this file pass ``z*1e3/Length``. The pressure is currently
    uniform: ``z`` is ignored and ``max_pressure`` is returned. Edit this body
    to model a pressure profile along the fiber.
    """
    uniform_pressure = max_pressure
    return uniform_pressure

def myDispersion(z):
    """Return the refractive-index table used as the fiber's dispersion function.

    The function is registered with ``dispersion_format='n'``, so it returns a
    wavelength grid together with the refractive index on that grid.

    Args:
        z: position along the fiber. Currently ignored, because the index is 1
           everywhere (no dispersion). To make the dispersion pressure
           dependent, evaluate ``pressure_at_z(z*1e3/Length)`` here and scale
           the index with it.

    Returns:
        Tuple ``(wavelengths, indices)``: 1000 wavelengths spanning 50 to 1200
        (the 0.05-1.2 grid scaled by 1e3, presumably um -> nm), and an index of
        exactly 1.0 at each.
    """
    wavelengths_um = np.linspace(0.05, 1.2, 1000)
    indices = np.ones_like(wavelengths_um)  # refractive index of 1 everywhere
    return wavelengths_um * 1e3, indices

def myGamma(z):
    """Return the nonlinear coefficient gamma, in 1/(W m), at position ``z``.

    The gas nonlinearity scales linearly with pressure:
    gamma = pressure[torr] * Gamma_per_torr.

    Args:
        z: position along the fiber. It is converted with ``z*1e3/Length``
           before being passed to ``pressure_at_z``. This presumably assumes
           ``z`` is in meters, matching the ``Length * 1e-3`` fiber length;
           TODO confirm against pynlo.
    """
    pressure_mbar = pressure_at_z(z*1e3/Length)
    # 1 atm = 760 torr = 1013.25 mbar (both exact), so 1 mbar = 0.750062 torr.
    # 1 bar is *not* 760 torr; using 760/1000 would overstate gamma by ~1.3%.
    pressure_torr = pressure_mbar * 760.0 / 1013.25
    return pressure_torr * Gamma_per_torr
    


# Figure layout on a 3x2 grid. The top row holds two single-height panels.
# Each panel on the bottom two rows is double-height and shares its x-axis
# with the panel above it.
fig = plt.figure(figsize=(8, 8))
grid_shape = (3, 2)
ax0 = plt.subplot2grid(grid_shape, (0, 0), rowspan=1)
ax1 = plt.subplot2grid(grid_shape, (0, 1), rowspan=1)
ax2 = plt.subplot2grid(grid_shape, (1, 0), rowspan=2, sharex=ax0)
ax3 = plt.subplot2grid(grid_shape, (1, 1), rowspan=2, sharex=ax1)

axs = (ax0, ax1, ax2, ax3)

# Title lists wavelength (nm), duration (fs) and energy (nJ) of each pulse.
title_values = (pulseWL1, Pulse1*1e3, EPP1*1e9,
                pulseWL2, Pulse2*1e3, EPP2*1e9,
                pulseWL3, Pulse3*1e3, EPP3*1e9)
fig.suptitle('Pulse 1: %.0f nm, %.0f fs, %.0f nJ \nPulse 2: %.0f nm, %.0f fs, %.0f nJ\nPulse 3: %.0f nm, %.0f fs, %.0f nJ' % title_values)


# Set up the three input pulses. Each is created with a nominal power of 1 and
# then rescaled to the requested pulse energy with set_epp().
# NOTE(review): dividing the duration by 1.76 presumably converts a sech^2 FWHM
# into SechPulse's width argument -- confirm against the pynlo docs.
pulse1 = SechPulse(1, Pulse1/1.76, pulseWL1, time_window_ps=Window, NPTS=Points, frep_MHz=100, power_is_avg=False)
pulse2 = SechPulse(1, Pulse2/1.76, pulseWL2, time_window_ps=Window, NPTS=Points, frep_MHz=100, power_is_avg=False)
pulse3 = SechPulse(1, Pulse3/1.76, pulseWL3, time_window_ps=Window, NPTS=Points, frep_MHz=100, power_is_avg=False)

pulse1.set_epp(EPP1)
pulse2.set_epp(EPP2)
pulse3.set_epp(EPP3)

pulse1.chirp_pulse_W(GDD1)
pulse2.chirp_pulse_W(GDD2)
pulse3.chirp_pulse_W(GDD3)

# Put all three fields on pulse1's frequency grid (centered at pulseWL1), so
# pulse1 carries the total input field into the propagation.
pulse1.set_AW(pulse1.AW
              + pulse2.interpolate_to_new_center_wl(pulseWL1).AW
              + pulse3.interpolate_to_new_center_wl(pulseWL1).AW)

fiber1 = fiber.FiberInstance()
fiber1.generate_fiber(Length * 1e-3, center_wl_nm=fibWL, betas=(0, 0, 0),
                      gamma_W_m=0, gain=-alpha)

# Position-dependent refractive index and nonlinearity (myDispersion / myGamma).
fiber1.set_dispersion_function(myDispersion, dispersion_format='n')
fiber1.set_gamma_function(myGamma)

F = pulse1.F_THz     # Frequency grid of pulse (THz)

# Frequency-dependent gain (negative values mean loss). Start from the
# background attenuation everywhere; the old `-alpha * np.zeros(...)` was
# always zero and silently dropped it. Then apply the extra band loss.
# NOTE(review): set_gain presumably overrides the scalar gain passed to
# generate_fiber above, which is why the background term must be included here.
gain = np.full(F.size, -alpha)
gain[(np.abs(F) > loss_lo) & (np.abs(F) < loss_hi)] = -loss
fiber1.set_gain(gain)

# Propagation
evol = SSFM.SSFM(local_error=error, USE_SIMPLE_RAMAN=True,
                 disable_Raman=not Raman,
                 disable_self_steepening=not Steep)

y, AW, AT, pulse_out = evol.propagate(pulse_in=pulse1, fiber=fiber1, n_steps=Steps, reload_fiber_each_step=True)


########## That's it! Physics complete. Just plotting commands from here! ################


def dB(num):
    # Re-declares the dB() helper defined near the top of the script (same
    # result): power of `num` in decibels, 10*log10(|num|^2).
    squared_magnitude = np.abs(num) ** 2
    return 10 * np.log10(squared_magnitude)

# Evolution maps in dB. They are transposed so that the extents below map rows
# onto propagation distance; the spectral map keeps only positive frequencies.
zW = dB(np.transpose(AW)[:, (F > 0)])
zT = dB(np.transpose(AT))

y_mm = y * 1e3 # convert distance to mm

# Top row: output (red) and input (blue) spectrum (ax0) and time trace (ax1).
ax0.plot(pulse_out.F_THz,    dB(pulse_out.AW),  color='r')
ax1.plot(pulse_out.T_ps,     dB(pulse_out.AT),  color='r')

ax0.plot(pulse1.F_THz,    dB(pulse1.AW),  color='b')
ax1.plot(pulse1.T_ps,     dB(pulse1.AT),  color='b')

# Function-call form with a single pre-formatted string prints identically on
# Python 2 and 3; the old `print 'x', y` statement is a SyntaxError on Python 3.
print('timestep: %s' % (pulse1.T_ps[1] - pulse1.T_ps[0]))

# Both evolution maps use the same propagation-distance range.
z_lo, z_hi = np.min(y_mm), Length

extent = (np.min(F[F > 0]), np.max(F[F > 0]), z_lo, z_hi)
ax2.imshow(zW, extent=extent,
           vmin=np.max(zW) - 200.0, vmax=np.max(zW),
           aspect='auto', origin='lower')

extent = (np.min(pulse1.T_ps), np.max(pulse1.T_ps), z_lo, z_hi)
ax3.imshow(zT, extent=extent,
           vmin=np.max(zT) - 100.0, vmax=np.max(zT),
           aspect='auto', origin='lower')

ax0.set_ylabel('Intensity (dB)')
ax0.set_ylim(-250, 100)
ax1.set_ylim(-200, 100)

ax2.set_ylabel('Propagation distance (mm)')
ax2.set_xlabel('Frequency (THz)')
ax2.set_xlim(0, 4000)

ax3.set_xlabel('Time (ps)')

for ax in axs:
    ax.grid(alpha=0.1)

plt.subplots_adjust(left=0.10, bottom=0.07, right=0.98, top=0.89, wspace=0.26, hspace=0.20)
plt.savefig('Three-pulses.png', dpi=200)

plt.show()