@@ -446,6 +446,10 @@ def __init__(self, ckt: Model, **kwargs) -> None:
446446 raise ValueError ("Must specify 'wl' (wavelengths to simulate)." )
447447 super ().__init__ (ckt , kwargs ["wl" ])
448448
449+ # get the unitary s-parameters of the circuit
450+ self .s_params = dict_to_matrix (self .ckt ())
451+ self .unitary = self .to_unitary (self .s_params )
452+
449453 def add_qstate (self , qstate : QuantumState ) -> None :
450454 """Add a quantum state to the simulation.
451455
@@ -494,13 +498,10 @@ def run(self) -> QuantumResult:
494498 """Run the simulation."""
495499 ports = get_ports (self .ckt ())
496500 n_ports = len (ports )
497- # get the unitary s-parameters of the circuit
498- s_params = dict_to_matrix (self .ckt ())
499- unitary = self .to_unitary (s_params )
500501 # get an array of the indices of the input ports
501502 input_indices = [ports .index (port ) for port in self .input .ports ]
502503 # create vacuum ports for each extra mode in the unitary matrix
503- n_modes = unitary .shape [1 ]
504+ n_modes = self . unitary .shape [1 ]
504505 n_vacuum = n_modes - len (input_indices )
505506 self .input ._add_vacuums (n_vacuum )
506507 input_indices += [i for i in range (n_modes ) if i not in input_indices ]
@@ -511,7 +512,7 @@ def run(self) -> QuantumResult:
511512 means = []
512513 covs = []
513514 for wl_ind in range (len (self .wl )):
514- s_wl = unitary [wl_ind ]
515+ s_wl = self . unitary [wl_ind ]
515516 transform = jnp .zeros ((n_modes * 2 , n_modes * 2 ))
516517 n = n_modes
517518
@@ -533,7 +534,7 @@ def run(self) -> QuantumResult:
533534 covs .append (output_cov )
534535
535536 return QuantumResult (
536- s_params = s_params ,
537+ s_params = self . s_params ,
537538 input_means = input_means ,
538539 input_cov = input_cov ,
539540 transforms = jnp .stack (transforms ),
0 commit comments