Skip to content

Commit

Permalink
fock state transmission table generation
Browse files Browse the repository at this point in the history
  • Loading branch information
daquintero committed Nov 12, 2024
1 parent f62cf28 commit 6fe2cfe
Show file tree
Hide file tree
Showing 18 changed files with 253 additions and 93 deletions.
2 changes: 1 addition & 1 deletion docs/examples/03b_optical_function_verification.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@
"target_mode_output": None,
},
]
target_optical_state_transition_mzi_2x2 = piel.types.OpticalStateTransitions(
target_optical_state_transition_mzi_2x2 = piel.types.OpticalStateTransitionCollection(
transmission_data=target_output_transition_mzi_2x2,
mode_amount=2,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
# We begin by importing a parametric circuit from `gdsfactory`:
import hdl21 as h
import numpy as np
import jax.numpy as jnp
import pandas as pd
from gdsfactory.generic_tech import get_generic_pdk
import piel
Expand All @@ -54,6 +55,7 @@
component_lattice_generic,
straight_heater_metal_simple,
)
import os

# First, let's set up the filesystem in the directory in which all our files will be generated and stored. This is really an extension of a full mixed-signal design compatible with the tools supported by `piel`.

Expand Down Expand Up @@ -147,8 +149,27 @@ def create_switch_fabric():
target_mode_index=2,
)

chain_fock_state_transitions.transmission_data[0]

chain_fock_state_transitions.transition_dataframe

transition_dataframe_latex = piel.visual.table.electro_optic.compose_optical_state_transition_dataframe_latex_table(
chain_fock_state_transitions.transition_dataframe
)
target_output_dataframe_latex = piel.visual.table.electro_optic.compose_optical_state_transition_dataframe_latex_table(
chain_fock_state_transitions.target_output_dataframe
)
piel.write_file(
directory_path=os.getenv("TAT"),
file_text=transition_dataframe_latex,
file_name="chain_3_transition_dataframe.tex",
)
piel.write_file(
directory_path=os.getenv("TAT"),
file_text=target_output_dataframe_latex,
file_name="chain_3_target_dataframe.tex",
)

# We can plot this to show the electronic-photonic behaviour we want to see:

chain_fock_state_transitions.transmission_data[0].keys()
Expand All @@ -162,7 +183,7 @@ def create_switch_fabric():
# Now, each of these electronic phases applied correspond to a given digital value that we want to implement on the electronic logic.

basic_ideal_phase_map = piel.models.logic.electro_optic.linear_bit_phase_map(
bits_amount=5, final_phase_rad=np.pi, initial_phase_rad=0
bits_amount=5, final_phase_rad=jnp.pi, initial_phase_rad=0
)
basic_ideal_phase_map.dataframe

Expand Down
Original file line number Diff line number Diff line change
@@ -1,52 +1,52 @@
/* Generated by Yosys 0.38 (git sha1 543faed9c8c, clang++ 17.0.6 -fPIC -Os) */
/* Generated by Amaranth Yosys 0.40 (PyPI ver 0.40.0.0.post100, git sha1 a1bb0255d) */

(* top = 1 *)
(* generator = "Amaranth" *)
module top(bit_phase_0, bit_phase_1, input_fock_state_str);
reg \$auto$verilog_backend.cc:2334:dump_module$1 = 0;
(* src = "/nix/store/gi3w8lbvfd4x4gf9m2shpybb58mdkzrr-python3.11-piel-0.1.0/lib/python3.11/site-packages/piel/tools/amaranth/construct.py:86" *)
reg \$auto$verilog_backend.cc:2352:dump_module$1 = 0;
(* src = "/home/daquintero/phd/piel/piel/tools/amaranth/construct.py:86" *)
output [4:0] bit_phase_0;
reg [4:0] bit_phase_0;
(* src = "/nix/store/gi3w8lbvfd4x4gf9m2shpybb58mdkzrr-python3.11-piel-0.1.0/lib/python3.11/site-packages/piel/tools/amaranth/construct.py:86" *)
(* src = "/home/daquintero/phd/piel/piel/tools/amaranth/construct.py:86" *)
output [4:0] bit_phase_1;
reg [4:0] bit_phase_1;
(* src = "/nix/store/gi3w8lbvfd4x4gf9m2shpybb58mdkzrr-python3.11-piel-0.1.0/lib/python3.11/site-packages/piel/tools/amaranth/construct.py:82" *)
(* src = "/home/daquintero/phd/piel/piel/tools/amaranth/construct.py:82" *)
input [2:0] input_fock_state_str;
wire [2:0] input_fock_state_str;
always @* begin
if (\$auto$verilog_backend.cc:2334:dump_module$1 ) begin end
if (\$auto$verilog_backend.cc:2352:dump_module$1 ) begin end
(* full_case = 32'd1 *)
(* src = "/nix/store/gi3w8lbvfd4x4gf9m2shpybb58mdkzrr-python3.11-piel-0.1.0/lib/python3.11/site-packages/piel/tools/amaranth/construct.py:112" *)
(* src = "/home/daquintero/phd/piel/piel/tools/amaranth/construct.py:112" *)
casez (input_fock_state_str)
/* src = "/nix/store/gi3w8lbvfd4x4gf9m2shpybb58mdkzrr-python3.11-piel-0.1.0/lib/python3.11/site-packages/piel/tools/amaranth/construct.py:115" */
/* src = "/home/daquintero/phd/piel/piel/tools/amaranth/construct.py:115" */
3'h4:
bit_phase_0 = 5'h00;
/* src = "/nix/store/gi3w8lbvfd4x4gf9m2shpybb58mdkzrr-python3.11-piel-0.1.0/lib/python3.11/site-packages/piel/tools/amaranth/construct.py:115" */
/* src = "/home/daquintero/phd/piel/piel/tools/amaranth/construct.py:115" */
3'h1:
bit_phase_0 = 5'h00;
/* src = "/nix/store/gi3w8lbvfd4x4gf9m2shpybb58mdkzrr-python3.11-piel-0.1.0/lib/python3.11/site-packages/piel/tools/amaranth/construct.py:115" */
/* src = "/home/daquintero/phd/piel/piel/tools/amaranth/construct.py:115" */
3'h2:
bit_phase_0 = 5'h1f;
/* src = "/nix/store/gi3w8lbvfd4x4gf9m2shpybb58mdkzrr-python3.11-piel-0.1.0/lib/python3.11/site-packages/piel/tools/amaranth/construct.py:124" */
/* src = "/home/daquintero/phd/piel/piel/tools/amaranth/construct.py:124" */
default:
bit_phase_0 = 5'h00;
endcase
end
always @* begin
if (\$auto$verilog_backend.cc:2334:dump_module$1 ) begin end
if (\$auto$verilog_backend.cc:2352:dump_module$1 ) begin end
(* full_case = 32'd1 *)
(* src = "/nix/store/gi3w8lbvfd4x4gf9m2shpybb58mdkzrr-python3.11-piel-0.1.0/lib/python3.11/site-packages/piel/tools/amaranth/construct.py:112" *)
(* src = "/home/daquintero/phd/piel/piel/tools/amaranth/construct.py:112" *)
casez (input_fock_state_str)
/* src = "/nix/store/gi3w8lbvfd4x4gf9m2shpybb58mdkzrr-python3.11-piel-0.1.0/lib/python3.11/site-packages/piel/tools/amaranth/construct.py:115" */
/* src = "/home/daquintero/phd/piel/piel/tools/amaranth/construct.py:115" */
3'h4:
bit_phase_1 = 5'h00;
/* src = "/nix/store/gi3w8lbvfd4x4gf9m2shpybb58mdkzrr-python3.11-piel-0.1.0/lib/python3.11/site-packages/piel/tools/amaranth/construct.py:115" */
/* src = "/home/daquintero/phd/piel/piel/tools/amaranth/construct.py:115" */
3'h1:
bit_phase_1 = 5'h1f;
/* src = "/nix/store/gi3w8lbvfd4x4gf9m2shpybb58mdkzrr-python3.11-piel-0.1.0/lib/python3.11/site-packages/piel/tools/amaranth/construct.py:115" */
/* src = "/home/daquintero/phd/piel/piel/tools/amaranth/construct.py:115" */
3'h2:
bit_phase_1 = 5'h00;
/* src = "/nix/store/gi3w8lbvfd4x4gf9m2shpybb58mdkzrr-python3.11-piel-0.1.0/lib/python3.11/site-packages/piel/tools/amaranth/construct.py:124" */
/* src = "/home/daquintero/phd/piel/piel/tools/amaranth/construct.py:124" */
default:
bit_phase_1 = 5'h00;
endcase
Expand Down
1 change: 1 addition & 0 deletions piel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from piel.connectivity import * # NOQA: F403

import piel.analysis as analysis # NOQA: F401
import piel.conversion as conversion # NOQA: F401
import piel.base as base # NOQA: F401
import piel.experimental as experimental # NOQA: F401
import piel.file_system as file_system # NOQA: F401
Expand Down
4 changes: 2 additions & 2 deletions piel/types/type_conversion.py → piel/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import jax.numpy as jnp
import numpy as np
import pandas as pd
from .core import ArrayTypes, PackageArrayType, TupleIntType
from .digital import AbstractBitsType, BitsType, LogicSignalsList
from piel.types.core import ArrayTypes, PackageArrayType, TupleIntType
from piel.types.digital import AbstractBitsType, BitsType, LogicSignalsList


def convert_array_type(array: ArrayTypes, output_type: PackageArrayType):
Expand Down
14 changes: 10 additions & 4 deletions piel/flows/digital_electro_optic.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
from collections import OrderedDict
import numpy as np
from typing import Optional, Callable
from ..types import (
from piel.types import (
BitPhaseMap,
BitsType,
PhaseMapType,
OpticalStateTransitions,
OpticalStateTransitionCollection,
TruthTable,
TruthTableLogicType,
)
from piel.conversion import (
convert_tuple_to_string,
convert_to_bits,
)
import logging

logger = logging.getLogger(__name__)


def add_truth_table_bit_to_phase_data(
Expand Down Expand Up @@ -124,7 +129,7 @@ def add_truth_table_phase_to_bit_data(


def convert_optical_transitions_to_truth_table(
optical_state_transitions: OpticalStateTransitions,
optical_state_transitions: OpticalStateTransitionCollection,
bit_phase_map=BitPhaseMap,
logic: TruthTableLogicType = "implementation",
) -> TruthTable:
Expand All @@ -137,7 +142,8 @@ def convert_optical_transitions_to_truth_table(
else:
raise ValueError(f"Invalid logic type: {logic}")

phase_bit_array_length = len(transitions_dataframe["phase"][0])
logger.debug(transitions_dataframe["phase"])
phase_bit_array_length = len(transitions_dataframe["phase"].iloc[0])
truth_table_raw = dict()

# Check if all input and output connection are in the dataframe
Expand Down
2 changes: 1 addition & 1 deletion piel/flows/digital_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
LogicSignalsList,
PathTypes,
TruthTable,
convert_dataframe_to_bits,
)
from piel.conversion import convert_dataframe_to_bits
from ..tools.cocotb import (
configure_cocotb_simulation,
run_cocotb_simulation,
Expand Down
49 changes: 31 additions & 18 deletions piel/flows/electro_optic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,21 @@
import logging
from itertools import product
from typing import Optional, Callable, Any
from ..types import (
absolute_to_threshold,
convert_array_type,
from piel.types import (
ArrayTypes,
PhotonicCircuitComponent,
FockStatePhaseTransitionType,
FockStatePhaseTransition,
NumericalTypes,
PhaseTransitionTypes,
OpticalTransmissionCircuit,
OpticalStateTransitions,
OpticalStateTransitionCollection,
SParameterCollection,
TupleIntType,
)
from piel.conversion import (
absolute_to_threshold,
convert_array_type,
)
from ..tools.sax.netlist import (
address_value_dictionary_to_function_parameter_dictionary,
get_matched_model_recursive_netlist_instances,
Expand Down Expand Up @@ -113,7 +115,7 @@ def calculate_all_transition_probability_amplitudes(
unitary_matrix: ArrayTypes,
input_fock_states: list[ArrayTypes],
output_fock_states: list[ArrayTypes],
) -> dict[int, FockStatePhaseTransitionType]:
) -> dict[int, FockStatePhaseTransition]:
"""
This tells us the transition probabilities between our photon states for a particular implemented unitary.
Expand All @@ -123,7 +125,7 @@ def calculate_all_transition_probability_amplitudes(
output_fock_states (list): The list of output Fock states.
Returns:
dict[int, FockStatePhaseTransitionType]: The dictionary of the Fock state phase transition type.
dict[int, FockStatePhaseTransition]: The dictionary of the Fock state phase transition type.
"""
i = 0
circuit_transition_probability_data_i = dict()
Expand Down Expand Up @@ -174,10 +176,11 @@ def calculate_classical_transition_probability_amplitudes(
for i, input_fock_state in enumerate(input_fock_states):
mode_transformation = jnp.dot(unitary_matrix, input_fock_state)
classical_transition_mode_probability = jnp.abs(
mode_transformation
mode_transformation,
) # Assuming probabilities are the squares of the amplitudes TODO recheck

if target_mode_index is not None:
logger.debug(classical_transition_mode_probability[target_mode_index])
if (
isinstance(
classical_transition_mode_probability[target_mode_index],
Expand Down Expand Up @@ -207,6 +210,8 @@ def calculate_classical_transition_probability_amplitudes(
)
pass

logger.debug(classical_transition_target_mode_probability)

data = {
"input_fock_state": input_fock_state,
"mode_transformation": mode_transformation,
Expand All @@ -224,7 +229,7 @@ def construct_unitary_transition_probability_performance(
unitary_phase_implementations_dictionary: dict,
input_fock_states: list,
output_fock_states: list,
) -> dict[int, dict[int, FockStatePhaseTransitionType]]:
) -> dict[int, dict[int, FockStatePhaseTransition]]:
"""
This function determines the Fock state probability performance for a given implemented unitary. This means we
iterate over each circuit, then each implemented unitary, and we determine the probability transformation
Expand Down Expand Up @@ -356,14 +361,14 @@ def compose_network_matrix_from_models(


def extract_phase_from_fock_state_transitions(
optical_state_transitions: OpticalStateTransitions,
optical_state_transitions: OpticalStateTransitionCollection,
transition_type: PhaseTransitionTypes = "cross",
):
"""
Extracts the phase corresponding to the specified transition type.
Parameters:
optical_state_transitions (OpticalStateTransitions): Optical state transitions.
optical_state_transitions (OpticalStateTransitionCollection): Optical state transitions.
transition_type (str): Type of transition to extract phase for ('cross' or 'bar').
Returns:
Expand Down Expand Up @@ -414,9 +419,9 @@ def format_electro_optic_fock_transition(
input_fock_state_array: ArrayTypes,
raw_output_state: ArrayTypes,
**kwargs,
) -> FockStatePhaseTransitionType:
) -> FockStatePhaseTransition:
"""
Formats the electro-optic state into a standard FockStatePhaseTransitionType format. This is useful for the
Formats the electro-optic state into a standard FockStatePhaseTransition format. This is useful for the
electro-optic model to ensure that the output state is in the correct format. The output state is a dictionary
that contains the phase, input fock state, and output fock state. The idea is that this will allow us to
standardise and compare the output states of the electro-optic model across multiple formats.
Expand All @@ -428,7 +433,7 @@ def format_electro_optic_fock_transition(
**kwargs: Additional keyword arguments.
Returns:
electro_optic_state(FockStatePhaseTransitionType): Electro-optic state.
electro_optic_state(FockStatePhaseTransition): Electro-optic state.
"""
electro_optic_state = {
"phase": convert_array_type(switch_state_array, "tuple"),
Expand All @@ -438,7 +443,7 @@ def format_electro_optic_fock_transition(
),
**kwargs,
}
# assert type(electro_optic_state) == FockStatePhaseTransitionType # TODO fix this
# assert isinstance(electro_optic_state, FockStatePhaseTransition) # TODO FIX ME
return electro_optic_state


Expand Down Expand Up @@ -524,7 +529,7 @@ def get_state_phase_transitions(
netlist_function: Optional[Callable] = None,
target_mode_index: Optional[int] = None,
**kwargs,
) -> OpticalStateTransitions:
) -> OpticalStateTransitionCollection:
"""
The goal of this function is to extract the corresponding phase required to implement a state transition.
Expand Down Expand Up @@ -590,6 +595,12 @@ def get_state_phase_transitions(
)

for id_i_i, _ in data_i.items():
logger.debug(data_i[id_i_i]["classical_transition_target_mode_probability"])
logger.debug(
jnp.round(
data_i[id_i_i]["classical_transition_target_mode_probability"]
)
)
output_state_i = format_electro_optic_fock_transition(
switch_state_array=extract_phase_tuple_from_phase_address_state(
circuit_phase_address_state[id_i]
Expand All @@ -599,7 +610,9 @@ def get_state_phase_transitions(
"classical_transition_mode_probability"
],
target_mode_output=int(
data_i[id_i_i]["classical_transition_target_mode_probability"]
jnp.round(
data_i[id_i_i]["classical_transition_target_mode_probability"]
)
)
if data_i[id_i_i]["classical_transition_target_mode_probability"]
is not None
Expand All @@ -612,7 +625,7 @@ def get_state_phase_transitions(
output_states.append(output_state_i)
id_i += 1

output_optical_state_transitions = OpticalStateTransitions(
output_optical_state_transitions = OpticalStateTransitionCollection(
mode_amount=mode_amount,
target_mode_index=target_mode_index,
transmission_data=output_states,
Expand Down
2 changes: 1 addition & 1 deletion piel/tools/qutip/fock.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import jax.numpy as jnp
from typing import Optional, Literal, Any
from piel.types.type_conversion import convert_array_type
from piel.conversion import convert_array_type


def all_fock_states_from_photon_number(
Expand Down
Loading

0 comments on commit 6fe2cfe

Please sign in to comment.