| 
          import jax.profiler | 
        
        
           | 
          # Start a JAX profiler server on port 9999 so an external profiling client
          # (e.g. the TensorBoard profiler) can attach to this process.
          jax.profiler.start_server(port=9999)
        
        
           | 
          import numpy as onp | 
        
        
           | 
          import jax.numpy as jnp | 
        
        
           | 
          from functools import partial | 
        
        
           | 
          from jax import random | 
        
        
           | 
          from jax.nn.initializers import (xavier_normal, xavier_uniform, glorot_normal, glorot_uniform, uniform,  | 
        
        
           | 
                                           normal, lecun_uniform, lecun_normal,kaiming_uniform,kaiming_normal) | 
        
        
           | 
          
 | 
        
        
           | 
          from jax.nn import (softplus, selu,gelu,glu,swish,relu,relu6,elu,sigmoid, swish) | 
        
        
           | 
          from jax import vmap, grad, partial, pmap, value_and_grad, jit | 
        
        
           | 
          
 | 
        
        
           | 
          from jax.experimental.ode import odeint | 
        
        
           | 
          
 | 
        
        
           | 
          # Raw inputs, read with NumPy from the current working directory.
          coupling_matrix_ = onp.load('./coupling_matrix.npy')
          epi_array_ = onp.load('./epi_array.npy')
          mobilitypopulation_array_scaled_ = onp.load('./mobilitypopulation_array_scaled.npy')

          # The same data as jax.numpy arrays, used by the model and the loss.
          coupling_matrix, epi_array, mobilitypopulation_array_scaled = map(
              jnp.asarray,
              (coupling_matrix_, epi_array_, mobilitypopulation_array_scaled_),
          )
        
        
           | 
          
 | 
        
        
           | 
          def inv_softplus(x):
              """Inverse of softplus: return y such that softplus(y) == x.

              Evaluates log(exp(x) - 1) in the stable form x + log(1 - exp(-x)),
              using expm1 so small positive x keeps full precision. Only defined
              for x > 0 (x == 0 gives -inf, x < 0 gives nan).
              """
              return jnp.log(-jnp.expm1(-x)) + x
        
        
           | 
          
 | 
        
        
           | 
          key = random.PRNGKey(0)

          # Network widths: 7 inputs, hidden layers of 14, 14 and 7, one output.
          layers = [7, 14, 14, 7, 1]
          # One activation per weight layer; softplus on the output keeps it positive.
          activations = [swish] * 3 + [softplus]

          # Initializer factories; they are called with no arguments to get init(key, shape).
          weight_initializer = kaiming_uniform
          bias_initializer = normal
        
        
           | 
          
 | 
        
        
           | 
          def init_layers(nn_layers, nn_weight_initializer_,
                          nn_bias_initializer_, key=None):
              """Build the flat parameter vector for a fully connected network.

              The layout matches `nnet`: for each consecutive pair (in_, out_) of
              `nn_layers`, in_*out_ weights followed by out_ biases, all concatenated.

              Args:
                nn_layers: layer widths [n_in, h1, ..., n_out].
                nn_weight_initializer_: initializer factory (e.g. kaiming_uniform);
                  called with no arguments to obtain an init(key, shape) function.
                nn_bias_initializer_: initializer factory for the biases, used the
                  same way.
                key: optional PRNG key; defaults to random.PRNGKey(0) so the result
                  is reproducible.

              Returns:
                1-D jnp array of length sum(in_*out_ + out_) over all layers
                (an empty array if `nn_layers` has fewer than two entries).
              """
              init_w = nn_weight_initializer_()
              init_b = nn_bias_initializer_()
              if key is None:
                  key = random.PRNGKey(0)

              params = []
              for in_, out_ in zip(nn_layers[:-1], nn_layers[1:]):
                  # Split a fresh key for every draw: JAX keys must not be reused,
                  # otherwise weights/biases and layers with the same fan-in would
                  # get correlated values.
                  key, w_key, b_key = random.split(key, 3)
                  # Draw with shape (in_, out_) so variance-scaling initializers (default
                  # in_axis=-2) use in_ as the fan-in. nnet reads this block back as an
                  # (out_, in_) matrix; for elementwise i.i.d. initializers such as
                  # kaiming_uniform that reordering leaves the distribution unchanged.
                  weights = init_w(w_key, (in_, out_)).reshape((in_ * out_,))
                  biases = init_b(b_key, (out_,))
                  params.append(jnp.concatenate((weights, biases)))

              if not params:
                  return jnp.zeros((0,))
              return jnp.concatenate(params)
        
        
           | 
          
 | 
        
        
           | 
          def nnet(nn_layers, nn_activations, nn_params, x):
              """Evaluate a fully connected network stored as one flat parameter vector.

              Args:
                nn_layers: layer widths [n_in, h1, ..., n_out].
                nn_activations: one activation per weight layer (len(nn_layers) - 1).
                nn_params: flat parameters; per layer, out*in weights read row-major
                  as an (out, in) matrix, followed by out biases.
                x: input; a new axis is inserted at position 1, so a 1-D input of
                  length n_in is treated as a single column vector.

              Returns:
                Output of the last layer; for a 1-D input its shape is (n_out, 1).
              """
              h = x[:, None]
              offset = 0
              layer_specs = zip(nn_layers[:-1], nn_layers[1:], nn_activations)
              for fan_in, fan_out, activation in layer_specs:
                  w_stop = offset + fan_in * fan_out
                  b_stop = w_stop + fan_out
                  W = nn_params[offset:w_stop].reshape((fan_out, fan_in))
                  b = nn_params[w_stop:b_stop][:, None]
                  h = activation(jnp.matmul(W, h) + b)
                  offset = b_stop
              return h
        
        
           | 
          
 | 
        
        
           | 
          # `nn`: jitted network for a single input, with the architecture bound.
          # NOTE(review): `nn` is not used elsewhere in the visible code.
          nn = jit(partial(nnet, layers, activations))
          # `nn_batch`: maps the network over axis 0 of `x`, sharing one parameter vector.
          nn_batch = vmap(partial(nnet, layers, activations), in_axes=(None, 0), out_axes=0)

          # Flat initial parameter vector for the network.
          p_net = init_layers(layers, weight_initializer, bias_initializer)
        
        
           | 
          
 | 
        
        
           | 
          # County-wise learnable scaling factors: one per row of coupling_matrix.
          n_counties = coupling_matrix.shape[0]
          init_b = bias_initializer()
          # Initial values: softplus of 200 * a bias-initializer draw made with the global `key`.
          # NOTE(review): SEIRD_mobility_coupled applies softplus to this slice of the
          # parameter vector again, so the effective initial scalers are
          # softplus(softplus(...)) > log(2). Confirm that is intended.
          p_scaling = softplus(init_b(key, (n_counties,)) * 200)
        
        
           | 
          
 | 
        
        
           | 
          def SEIRD_mobility_coupled(u, t, p_, mobility_, coupling_matrix_):
              """Right-hand side of the mobility-coupled SEIRD model, in odeint's (u, t, *args) form.

              Args:
                u: state, unpacked row-wise into 16 compartments:
                  s, e, id1..id7, d, ir1..ir5, r.
                t: current time (scalar); rounded with jnp.rint to pick the time slice
                  of `mobility_` and `coupling_matrix_`.
                p_: flat parameters laid out as
                  [3 raw rates (κ, α, γ) | n_c raw county scalers | network params].
                mobility_: mobility features, indexed along the last axis by the
                  rounded time and fed to `nn_batch` to produce the per-county β.
                coupling_matrix_: coupling matrices; its first axis defines n_c and its
                  last axis is indexed by the rounded time.

              Returns:
                jnp array stacking the 16 time derivatives in the same order as `u`.
              """
              (s, e,
               id1, id2, id3, id4, id5, id6, id7, d,
               ir1, ir2, ir3, ir4, ir5, r) = u

              # The first three entries of p_ are raw values; softplus maps them to positive rates.
              kappa, alpha, gamma = softplus(p_[:3])
              # κ*α and γ*η are not independent: the probabilities of leaving e towards
              # Id and towards Ir have to add up to 1, i.e. exp(-κα) + exp(-γη) = 1
              # (up to the 1e-8 guard in the denominator).
              eta = -jnp.log(-jnp.expm1(-kappa * alpha)) / (gamma + 1.0e-8)

              # Time-varying inputs are looked up at the nearest integer time.
              day = jnp.rint(t.astype(jnp.float32)).astype(jnp.int32)

              n_c = coupling_matrix_.shape[0]
              county_scale = softplus(p_[3:3 + n_c])
              coupling = county_scale[:, None] * coupling_matrix_[..., day]
              beta = nn_batch(p_[3 + n_c:], mobility_[..., day])[:, 0, 0]

              # Only id1-id3 and the ir stages count as infectious here.
              infectious = id1 + id2 + id3 + ir1 + ir2 + ir3 + ir4 + ir5

              # NOTE(review): matmul(infectious, coupling.T) equals matmul(coupling, infectious),
              # so the inter-county term is counted twice. If a symmetrised coupling
              # (coupling + coupling.T) was intended, the first product should use
              # `coupling` instead of `coupling.T`. Kept as-is to preserve behaviour.
              new_infections = (beta * s * infectious
                                + beta * s * (jnp.matmul(infectious, coupling.T)
                                              + jnp.matmul(coupling, infectious)))

              ds = -new_infections
              de = new_infections - kappa * alpha * e - gamma * eta * e

              # Death branch: e -> id1 -> ... -> id7 -> d, each stage drained at rate κ.
              death_stages = [id1, id2, id3, id4, id5, id6, id7]
              death_inflow = [alpha * e] + death_stages[:-1]
              d_death_stages = [kappa * (inflow - stage)
                                for inflow, stage in zip(death_inflow, death_stages)]
              d_d = kappa * id7

              # Recovery branch: e -> ir1 -> ... -> ir5 -> r, each stage drained at rate γ.
              recovery_stages = [ir1, ir2, ir3, ir4, ir5]
              recovery_inflow = [eta * e] + recovery_stages[:-1]
              d_recovery_stages = [gamma * (inflow - stage)
                                   for inflow, stage in zip(recovery_inflow, recovery_stages)]
              d_r = gamma * ir5

              return jnp.stack([ds, de, *d_death_stages, d_d, *d_recovery_stages, d_r])
        
        
           | 
          
 | 
        
        
           | 
          # Initial conditions
          n_counties = epi_array.shape[2]
          ic0 = epi_array[0, 0, :]  # first time slice of epi_array channel 0 (initial cases)
          d0 = epi_array[0, 1, :]   # first time slice of epi_array channel 1 (initial deaths)

          ifr = 0.007  # presumably the infection fatality ratio; splits cases into the two branches
          n = jnp.tile(1.0, (n_counties,))  # all ones (presumably population fractions)

          r0 = d0 / ifr
          e0 = ic0
          s0 = n - ic0 - r0 - d0

          # Spread the initial cases evenly over the stages of each branch.
          id10 = id20 = id30 = id40 = id50 = id60 = id70 = ic0 * ifr / 7.0
          ir10 = ir20 = ir30 = ir40 = ir50 = ic0 * (1.0 - ifr) / 5.0

          # NOTE(review): ic0 is put into e0 *and* spread over the id/ir stages, but it is
          # subtracted from s0 only once, so the initial state sums to n + ic0 rather
          # than n. Confirm this is intended.
          u0 = jnp.array([s0, e0,
                          id10, id20, id30, id40, id50, id60, id70, d0,
                          ir10, ir20, ir30, ir40, ir50, r0])
        
        
           | 
          
 | 
        
        
           | 
          # ODE Parameters
          # Initial values for the three rates used by the model (κ, α, γ).
          κ0_ = 0.97
          α0_ = 0.00185
          γ0_ = 0.24
          # NOTE(review): β0_, tb_ and β1_ are not used anywhere in the visible code.
          β0_ = 0.5
          tb_ = 15
          β1_ = 0.4

          # Rates are stored in softplus-inverse space; the model applies softplus to
          # p_[:3] and unpacks it as κ, α, γ, so this order must stay the same.
          p_ode = inv_softplus(jnp.array([κ0_, α0_, γ0_]))

          # Initial model parameters: [ODE rates | county scalers | network weights]
          p_init = jnp.concatenate((p_ode, p_scaling, p_net))

          # Integration grid: T + 1 evenly spaced points from 0 to T (T = epi_array.shape[0]),
          # so the T consecutive differences line up with the T rows of epi_array.
          t0 = jnp.linspace(0, float(epi_array.shape[0]), int(epi_array.shape[0]) + 1)
        
        
           | 
          
 | 
        
        
           | 
          # LOSS Function
          def diff(sol_, data_):
              """Per-day squared errors for a single county.

              Args:
                sol_: ODE solution for one county, time along axis 0 and the 16
                  compartments along axis 1 (column 0 = s, column 9 = d).
                data_: observations for one county, time along axis 0; column 0 is
                  compared with the daily change of (1 - s), column 1 with the daily
                  change of d.

              Returns:
                1-D array of case error + 20000 * death error, one entry per interval.
              """
              susceptible, deaths = sol_[:, 0], sol_[:, 9]
              observed_cases, observed_deaths = data_[:, 0], data_[:, 1]
              case_error = jnp.square(jnp.ediff1d(1 - susceptible) - observed_cases)
              death_error = jnp.square(jnp.ediff1d(deaths) - observed_deaths)
              # Death errors get a large fixed weight (presumably to balance their much
              # smaller magnitude relative to case counts).
              return case_error + 20000 * death_error


          # Map `diff` over the county axis (axis 2) of both the solution and the data.
          diff_v = vmap(diff, in_axes=(2, 2))
        
        
           | 
          
 | 
        
        
           | 
          def loss(data_, m_array_, coupling_matrix_, params_, *, rtol=1e-4, atol=1e-8):
              """Total squared fitting error over all counties and days.

              Integrates SEIRD_mobility_coupled from u0 over t0 with parameters
              `params_`, then sums the per-county, per-day errors computed by diff_v.

              Args:
                data_: observations in epi_array layout (diff_v maps over axis 2).
                m_array_: mobility array passed through to the ODE right-hand side.
                coupling_matrix_: coupling matrices passed through to the ODE right-hand side.
                params_: flat model parameter vector (same layout as p_init).
                rtol, atol: odeint error tolerances. The defaults are the values this
                  function always used. Loosening them typically trades accuracy for
                  speed, which matters most for the much slower gradient.

              Returns:
                Scalar loss.
              """
              sol_ = odeint(SEIRD_mobility_coupled, u0, t0, params_, m_array_, coupling_matrix_,
                            rtol=rtol, atol=atol)
              return jnp.sum(diff_v(sol_, data_))


          # Bind the data so the loss depends on the parameter vector only.
          loss_ = partial(loss, epi_array, mobilitypopulation_array_scaled, coupling_matrix)

          grad_jit = jit(grad(loss_))
          loss_jit = jit(loss_)
        
        
           | 
          
 | 
        
        
           | 
          # Benchmarks. `%timeit` is an IPython magic: these lines only run inside
          # IPython/Jupyter, not as a plain Python script.
          # block_until_ready() waits for JAX's asynchronous dispatch to finish, so the
          # timing covers the actual computation rather than just the dispatch.
          # NOTE(review): no warm-up call precedes these, so the timed runs presumably
          # include JIT compilation on the first call.
          # Forward loss, recorded result:
          %timeit loss_jit(p_init).block_until_ready()
          # 91.1 ms ± 4.13 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

          # Gradient of the loss, recorded result (roughly 270x slower than the forward pass):
          %timeit grad_jit(p_init).block_until_ready()
          # 24.6 s ± 1.13 s per loop (mean ± std. dev. of 7 runs, 1 loop each)