Source code for twpasolver.modes_rwa

"""
``modes_rwa`` module.

Mode Management and Rotating Wave Approximation Analysis
========================================================

Tools for managing complex mode relationships in TWPA analysis, enabling
simulation of extended coupled mode equation (CME) systems with arbitrary
numbers of modes beyond basic pump-signal-idler configuration.

Key Components
--------------

**ModeArray**: Organizes related modes with automatic frequency propagation,
parameter interpolation, and dependency graph management.

**RWAAnalyzer**: Identifies valid 3-wave and 4-wave mixing terms that satisfy
the Rotating Wave Approximation for any given set of mode relationships.

**ModeArrayFactory**: Factory methods for creating standard and custom mode
configurations including pump harmonics, frequency conversion terms, and harmonics.

**ParameterInterpolator**: Interpolation of mode parameters (kappa, gamma, alpha)
from base TWPA circuit data across frequency ranges.

Capabilities
------------

**RWA Term Selection**:
    - Automatic discovery of valid 3WM and 4WM mixing processes
    - Coefficient calculation including factorial corrections for repeated modes
    - Caching of RWA terms for performance

**Mode Relationships**:
    - Support for arbitrary frequency relationships between modes
    - Forward and backward propagating modes
    - Pump harmonics (p, p2, p3, ...)
    - Frequency conversion terms (p+s, p+i, ...)
    - Signal/idler harmonics (s2, i2, ...)

**Symbolic Frequency Propagation**:
    - O(n) frequency updates using pre-computed symbolic expressions
    - Automatic identification of independent vs. dependent modes
    - Topological sorting of mode dependencies
    - Visualization of the mode relations with a graph

**Parameter Management**:
    - Easy interpolation of mode parameters from base circuit data
    - Automatic handling of mode directions for wavenumber calculation

Examples
--------
See Tutorial 4 (:ref:`tutorials`) for complete examples covering:

- Basic 3WM mode array creation and visualization
- Extended mode arrays with pump harmonics and frequency conversion
- Custom mode configurations for specific processes
- RWA term analysis and CME integration

The tutorial demonstrates the progression from simple mode relationships to
complex multi-mode systems used in realistic TWPA simulations.

Performance Notes
-----------------
- Symbolic frequency propagation enables O(n) scaling, useful when iterating over arrays.
- RWA term caching eliminates redundant calculations during CME solving
- Numba-compatible data structures for integration with fast CME solvers
- Efficient interpolation supports faster parameter sweeps
"""

import itertools as it
from collections import Counter, defaultdict, deque
from copy import deepcopy
from dataclasses import dataclass
from math import factorial
from typing import Any, Dict, List, Optional, Tuple, Union

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from scipy.interpolate import interp1d
from sympy import symbols

from twpasolver.logger import log


[docs]class ParameterInterpolator: """ Interpolates kappa, gamma, and alpha values based on frequency. Always returns real kappa and alpha, complex gamma. """ def __init__( self, frequencies: np.ndarray, kappas: np.ndarray, gammas: np.ndarray, alphas: np.ndarray, kind: str = "cubic", ): """ Initialize interpolators for kappa, gamma, and alpha. Args: frequencies: Array of frequency points kappas: Array of coupling coefficients (real) gammas: Array of reflection coefficients (complex) alphas: Array of attenuation coefficients (real) kind: Interpolation method ('linear', 'cubic', etc.) """ # Validate input arrays if not (len(frequencies) == len(kappas) == len(gammas) == len(alphas)): raise ValueError( "Frequencies, kappas, gammas, and alphas must have the same length" ) if len(frequencies) < 2: raise ValueError("Need at least 2 points for interpolation") self.orig_freqs = frequencies self.orig_kappas = kappas self.orig_gammas = gammas self.orig_alphas = alphas self.freq_min = np.min(frequencies) self.freq_max = np.max(frequencies) # For kappa interpolation (real only) self.kappa_interp = interp1d( frequencies, np.real(kappas), kind=kind, bounds_error=False, fill_value=(np.real(kappas[0]), np.real(kappas[-1])), ) # For gamma interpolation (complex) if np.iscomplexobj(gammas): self.gamma_real = interp1d( frequencies, np.real(gammas), kind=kind, bounds_error=False, fill_value=(gammas[0].real, gammas[-1].real), ) self.gamma_imag = interp1d( frequencies, np.imag(gammas), kind=kind, bounds_error=False, fill_value=(gammas[0].imag, gammas[-1].imag), ) else: self.gamma_real = interp1d( frequencies, gammas, kind=kind, bounds_error=False, fill_value=(gammas[0], gammas[-1]), ) self.gamma_imag = None # For alpha interpolation (real only) self.alpha_interp = interp1d( frequencies, np.real(alphas), kind=kind, bounds_error=False, fill_value=(np.real(alphas[0]), np.real(alphas[-1])), )
[docs] def get_parameters( self, frequency: Union[float, np.ndarray] ) -> Tuple[ Union[float, np.ndarray], Union[complex, np.ndarray], Union[float, np.ndarray] ]: """ Get interpolated kappa, gamma, and alpha for a given frequency or array of frequencies. Args: frequency: Frequency point(s) for interpolation, can be a single value or array Returns: Tuple of (kappa, gamma, alpha) at the requested frequency/frequencies kappa and alpha are real, gamma is complex """ # Check if we're processing a single value or array is_array = isinstance(frequency, np.ndarray) # Handle out-of-range warning if is_array: if np.any((frequency < self.freq_min) | (frequency > self.freq_max)): log.warn( f"Warning: Some frequencies are outside the interpolation range " f"[{self.freq_min}, {self.freq_max}]. Using endpoint values." ) else: if frequency < self.freq_min or frequency > self.freq_max: log.warn( f"Warning: Frequency {frequency} is outside the interpolation range " f"[{self.freq_min}, {self.freq_max}]. Using endpoint values." ) # Interpolate kappa (real only) kappa = self.kappa_interp(frequency) # Interpolate gamma (complex) if self.gamma_imag is not None: gamma_real = self.gamma_real(frequency) gamma_imag = self.gamma_imag(frequency) gamma = gamma_real + 1j * gamma_imag else: gamma = self.gamma_real(frequency) # Interpolate alpha (real only) alpha = self.alpha_interp(frequency) return kappa, gamma, alpha
[docs]@dataclass class Mode: """ Represents a single mode with its physical properties. Args: label: Mode identifier (e.g., 'p' for pump) frequency: Mode frequency direction: 1 for forward, -1 for backward gamma: Reflection coefficient k: Wavenumber (calculated from frequency) alpha: Attenuation constant """ label: str direction: int = 1 # 1 for forward, -1 for backward frequency: Optional[float] = None k: Optional[float] = None gamma: Optional[Union[float, complex]] = 0.0 alpha: Optional[float] = 0.0 def __post_init__(self): """Initialize derived quantities and validate inputs.""" # Ensure direction is ±1 if abs(self.direction) != 1: raise ValueError("Direction must be either 1 (forward) or -1 (backward)") if self.k is not None: self.k = self.k * self.direction def __eq__(self, other): """Compare modes based on their physical properties.""" if not isinstance(other, Mode): return False return ( self.frequency == other.frequency and self.direction == other.direction and self.gamma == other.gamma and self.k == other.k and self.alpha == other.alpha ) def __hash__(self): """Get hash based on immutable properties.""" return hash((self.label, str(self.frequency), self.direction)) def __str__(self): """Get string representation showing direction and label.""" direction_str = "→" if self.direction == 1 else "←" return f"{direction_str}{self.label}" def __repr__(self): """Get representation of class.""" return ( f'Mode("{self.label}", freq={self.frequency}, ' f"dir={self.direction}, gamma={self.gamma}, k={self.k}, alpha={self.alpha})" )
[docs]class RWAAnalyzer: """Analyzer class for a set of coupled modes.""" def __init__(self, modes: List[str], relations: List[List[str]]): """ Initialize the RWA analyzer with modes and their relations. Args: modes: List of mode names (e.g., ["p", "s", "i", "c", "d", "u", "c2"]) relations: List of relations [result, expression] where expression can be any combination of terms with + and - (e.g., ["c2", "c+c"]) """ self.modes = modes self.modes_ext_str = modes + [f"-{m}" for m in modes] self.modes_symbolic = [symbols(m) for m in modes] self.modes_extended = self.modes_symbolic + [-m for m in self.modes_symbolic] self.mode_to_idx = {mode: idx for idx, mode in enumerate(modes)} self._relations = relations self.relations_idx = self._convert_relations_to_indices() self.modes_subs = self._compute_substitutions() # Cache for RWA terms self._rwa_terms_cache: Dict[int, List[Tuple[Any, ...]]] = {} @property def relations(self): """Getter for mode relations.""" return self._relations
[docs] def update_relations(self, relations: List[List[str]]): """Update mode relations.""" self._relations = relations self.relations_idx = self._convert_relations_to_indices() self.modes_subs = self._compute_substitutions() # Clear cache when relations are updated self._rwa_terms_cache = {}
def _parse_expression(self, expr: str) -> tuple: """ Parse an expression string into a list of indices with proper signs. Args: expr: String expression like "p+s-i" Returns: Tuple of indices with signs indicating conjugation """ expr = expr.replace(" ", "") if not expr.startswith("-"): expr = "+" + expr indices = [] signs = [] current_term = "" sign = 1 for char in expr: if char in ["+", "-"]: if current_term: idx = self.mode_to_idx[current_term] indices.append(idx) signs.append(sign) current_term = "" sign = 1 if char == "+" else -1 else: current_term += char if current_term: idx = self.mode_to_idx[current_term] indices.append(idx) signs.append(sign) return tuple(indices), tuple(signs) def _convert_relations_to_indices(self) -> List[Tuple[int, tuple, tuple]]: """Convert string relations to index-based relations.""" index_relations = [] for result, expr in self.relations: result_idx = self.mode_to_idx[result] input_indices, input_signs = self._parse_expression(expr) index_relations.append((result_idx, input_indices, input_signs)) return index_relations def _compute_substitutions(self) -> List: """Compute all mode substitutions based on the relations.""" apply = True modes_subs = deepcopy(self.modes_symbolic) while apply: modes_subs_old = deepcopy(modes_subs) for output, input_idxs, input_signs in self.relations_idx: subs_rel = sum( [ input_signs[i] * modes_subs[abs(rel_idx)] for i, rel_idx in enumerate(input_idxs) ] ) modes_subs[output] = modes_subs[output].subs( modes_subs[output], subs_rel ) if all([m == modes_subs_old[i] for i, m in enumerate(modes_subs)]): apply = False return modes_subs def _calculate_coefficient(self, terms: List[str]) -> float: """Calculate coefficient based on term repetitions.""" coeff = 1.0 if len(set(terms)) != len(terms): for repetitions in Counter(terms).values(): coeff = coeff / factorial(repetitions) return coeff
[docs] def find_rwa_terms(self, power: int = 3) -> List[Tuple[Any, ...]]: """ Find all valid RWA terms of given power. Args: power: Order of the interaction (default: 3 for three-wave mixing) Returns: List of tuples (mode_idx, combination, mode_name, rhs_terms, coefficient) """ # Check if cached result exists if power in self._rwa_terms_cache: return self._rwa_terms_cache[power] modes_subs_extended = self.modes_subs + [-m for m in self.modes_subs] rwa_terms = [] for comb in it.combinations_with_replacement( range(len(self.modes_extended)), power ): terms = [modes_subs_extended[i] for i in comb] sum_terms = sum(terms) for j, mode in enumerate(self.modes_subs): if sum_terms == mode: rhs = [self.modes_ext_str[k] for k in comb] coeff = self._calculate_coefficient(rhs) rwa_terms.append((j, comb, self.modes[j], rhs, coeff)) result = sorted(rwa_terms, key=lambda x: x[0]) # Cache the result self._rwa_terms_cache[power] = result return result
[docs] def print_rwa_terms(self, terms: List[Tuple[Any, ...]]) -> None: """Pretty print the RWA terms with their coefficients.""" for term in terms: mode_name = term[2] rhs_terms = term[3] coeff = term[4] print(f"{mode_name} = {rhs_terms} {coeff}") print(f"\nTotal matches: {len(terms)}")
[docs]class ModeArray: """Class representing a list of modes and frequency relations between them.""" def __init__( self, modes: List[Mode], relations: List[List[str]], interpolator: ParameterInterpolator, ): """ Initialize mode array with modes, relations and interpolator. Args: modes: List of Mode objects relations: List of [result, expression] pairs for mode relationships interpolator: ParameterInterpolator instance for getting mode parameters """ self.modes = {mode.label: mode for mode in modes} self.relations = relations self.interpolator = interpolator self.analyzer = RWAAnalyzer(list(self.modes.keys()), relations) # Build symbolic expressions for efficient frequency propagation self._build_symbolic_expressions() # Cache for computed mixing coefficients self._rwa_terms_3wm: Optional[List[Tuple[Any, ...]]] = None self._rwa_terms_4wm: Optional[List[Tuple[Any, ...]]] = None # Validate initial state self._validate_modes() def _parse_expression_for_propagation(self, expr: str) -> List[Tuple[str, int]]: """Parse expression into list of (mode, coefficient) pairs for frequency propagation.""" expr = expr.replace(" ", "") if not expr.startswith(("+", "-")): expr = "+" + expr terms = [] current_term = "" sign = 1 for char in expr: if char in ["+", "-"]: if current_term: terms.append((current_term, sign)) current_term = "" sign = 1 if char == "+" else -1 else: current_term += char if current_term: terms.append((current_term, sign)) return terms def _build_dependency_graph(self): """Build dependency graph for relation analysis.""" self.dependency_graph = defaultdict(set) self.reverse_deps = defaultdict(set) for result, expr in self.relations: terms = self._parse_expression_for_propagation(expr) for mode, _ in terms: self.dependency_graph[mode].add(result) self.reverse_deps[result].add(mode) def _topological_sort(self) -> List[str]: """Perform topological sort of dependencies.""" in_degree = defaultdict(int) # Calculate in-degrees for mode in self.modes.keys(): in_degree[mode] = len(self.reverse_deps[mode]) # Start with modes that have no dependencies queue = deque([mode for mode in self.modes.keys() if in_degree[mode] == 0]) result = [] while queue: current = queue.popleft() result.append(current) # Update in-degrees of dependent modes for dependent in self.dependency_graph[current]: in_degree[dependent] -= 1 if in_degree[dependent] == 0: queue.append(dependent) return result def _build_symbolic_expressions(self): """ Build symbolic expressions for each mode in terms of independent modes. This enables O(n) frequency propagation instead of iterative solving. """ # Build dependency graph self._build_dependency_graph() # Find independent modes (those not defined by relations) self.independent_modes = set(self.modes.keys()) for result, _ in self.relations: self.independent_modes.discard(result) # Initialize symbolic expressions self.symbolic_expressions = {} # Independent modes have simple expressions for mode in self.independent_modes: self.symbolic_expressions[mode] = {mode: 1.0} # Build expressions for dependent modes using topological order topo_order = self._topological_sort() for mode in topo_order: if mode in self.symbolic_expressions: continue # Already processed (independent mode) # Find the relation that defines this mode for result, expr in self.relations: if result == mode: terms = self._parse_expression_for_propagation(expr) self.symbolic_expressions[mode] = {} for dep_mode, coeff in terms: if dep_mode in self.symbolic_expressions: # Add contribution from dependency for base_mode, base_coeff in self.symbolic_expressions[ dep_mode ].items(): if base_mode in self.symbolic_expressions[mode]: self.symbolic_expressions[mode][base_mode] += ( coeff * base_coeff ) else: self.symbolic_expressions[mode][base_mode] = ( coeff * base_coeff ) else: # This should not happen with proper topological sort log.warning( f"Dependency {dep_mode} not found when processing {mode}" ) break log.info( f"Built symbolic expressions for {len(self.symbolic_expressions)} modes" ) log.info(f"Independent modes: {self.independent_modes}") def _validate_modes(self): """Ensure all modes referenced in relations exist.""" for result, expr in self.relations: if result not in self.modes: raise ValueError(f"Mode {result} in relations not found") # Check all modes in expression exist expr_modes = { m.strip("+-") for m in expr.replace(" ", "").replace("(", "").replace(")", "") if m.strip("+-") } for mode in expr_modes: if mode not in self.modes: raise ValueError(f"Mode {mode} in relations not found") def _propagate_frequencies_symbolic( self, updated_freqs: Dict[str, float] ) -> Dict[str, Optional[float]]: """ Symbolic frequency propagation - O(n) complexity. Args: updated_freqs: Dictionary of mode labels to their new frequencies Returns: Dictionary of all mode labels to their computed frequencies """ # Start with current frequencies frequencies: Dict[str, Optional[float]] = { label: mode.frequency for label, mode in self.modes.items() } frequencies.update(updated_freqs) # Evaluate each mode using its symbolic expression for mode, expression in self.symbolic_expressions.items(): if mode in updated_freqs: continue # Already set by user # Check if all base modes are available all_available = True new_freq = 0.0 for base_mode, coeff in expression.items(): if frequencies.get(base_mode) is None: all_available = False break new_freq += coeff * frequencies[base_mode] if all_available: frequencies[mode] = new_freq return frequencies
[docs] def update_frequencies(self, new_frequencies: Dict[str, float]) -> None: """ Update mode frequencies using optimized symbolic propagation. Args: new_frequencies: Dictionary mapping mode labels to new frequencies """ # Validate input modes exist unknown_modes = set(new_frequencies.keys()) - set(self.modes.keys()) if unknown_modes: raise ValueError(f"Unknown modes in frequency update: {unknown_modes}") # Use symbolic propagation (much faster than iterative) all_frequencies = self._propagate_frequencies_symbolic(new_frequencies) # Update modes with new frequencies and interpolated parameters for label, freq in all_frequencies.items(): if freq is not None: mode = self.modes[label] mode.frequency = freq # Get interpolated parameters (always returns kappa, gamma, alpha) kappa, gamma, alpha = self.interpolator.get_parameters(abs(freq)) mode.alpha = alpha # type: ignore mode.gamma = gamma # type: ignore mode.k = kappa * mode.direction # type: ignore
[docs] def update_base_data(self, interpolator: ParameterInterpolator) -> None: """ Update the base data interpolator and refresh all mode parameters. Args: interpolator: New ParameterInterpolator instance """ self.interpolator = interpolator # Refresh all mode parameters with current frequencies current_freqs = { label: mode.frequency for label, mode in self.modes.items() if mode.frequency is not None } if current_freqs: self.update_frequencies(current_freqs)
[docs] def process_frequency_array( self, mode_label: str, frequencies: np.ndarray ) -> Dict[str, Dict[str, np.ndarray]]: """ Process an array of frequencies for a single mode, propagating to all related modes. Uses symbolic expressions for optimal performance. Args: mode_label: Label of the mode to update with array of frequencies frequencies: Array of frequencies to process Returns: Dictionary with mode parameters for all related modes """ # Validate mode exists if mode_label not in self.modes: raise ValueError(f"Unknown mode: {mode_label}") # Initialize result containers mode_params: Dict[str, Dict[str, np.ndarray]] = {} # For each mode, compute its frequency array using symbolic expressions for mode, expression in self.symbolic_expressions.items(): if mode_label not in expression: indep_mode = self.modes[mode] mode_params[mode] = { "freqs": np.full(len(frequencies), indep_mode.frequency), "k": np.full(len(frequencies), indep_mode.k), "gamma": np.full(len(frequencies), indep_mode.gamma), "alpha": np.full(len(frequencies), indep_mode.alpha), } continue # Compute frequency array for this mode mode_freqs = frequencies * expression[mode_label] # Add contributions from other independent modes if they have values for base_mode, coeff in expression.items(): if ( base_mode != mode_label and self.modes[base_mode].frequency is not None ): mode_freqs += coeff * self.modes[base_mode].frequency # Get direction for this mode direction = self.modes[mode].direction # Get parameters for all frequencies of this mode (always returns kappa, gamma, alpha) kappas, gammas, alphas = self.interpolator.get_parameters( np.abs(mode_freqs) ) # Apply direction to kappas if isinstance(kappas, np.ndarray): kappas = kappas * direction else: kappas = np.array([kappas * direction]) # Store parameters mode_params[mode] = { "freqs": mode_freqs, "k": kappas, "gamma": gammas, # type: ignore "alpha": alphas, # type: ignore } return mode_params
[docs] def plot_mode_relations( self, figsize: Tuple[float, float] = (12, 8), node_size: int = 3000, font_size: int = 12, show_frequencies: bool = False, show_directions: bool = True, layout: str = "hierarchical", ) -> None: """ Plot the relationships between modes as a directed graph. Args: figsize: Figure size (width, height) node_size: Size of mode nodes font_size: Font size for labels show_frequencies: Whether to show current frequencies on nodes show_directions: Whether to show mode propagation directions layout: Graph layout algorithm ('spring', 'circular', 'hierarchical') """ # Create directed graph G = nx.DiGraph() # Add nodes for all modes for mode_label, mode in self.modes.items(): node_attrs = { "frequency": mode.frequency, "direction": mode.direction, "is_independent": mode_label in self.independent_modes, } G.add_node(mode_label, **node_attrs) # Add edges based on relations edge_labels: dict[Tuple[str, str], str] = {} for result, expr in self.relations: terms = self._parse_expression_for_propagation(expr) for dep_mode, coeff in terms: if coeff > 0: edge_color = "blue" edge_style = "-" else: edge_color = "red" edge_style = "--" G.add_edge( dep_mode, result, weight=abs(coeff), color=edge_color, style=edge_style, ) # Create edge label coeff_str = f"+{coeff}" if coeff > 0 else str(coeff) if (dep_mode, result) in edge_labels: edge_labels[(dep_mode, result)] += f", {coeff_str}" else: edge_labels[(dep_mode, result)] = coeff_str # Create figure fig, (ax1, ax2) = plt.subplots( 1, 2, figsize=figsize, gridspec_kw={"width_ratios": [3, 1]} ) # Choose layout if layout == "hierarchical": pos = self._hierarchical_layout(G) elif layout == "circular": pos = nx.circular_layout(G) else: # spring layout pos = nx.spring_layout(G, k=2, iterations=50) # Node colors based on type node_colors = [] for node in G.nodes(): if G.nodes[node]["is_independent"]: node_colors.append("lightgreen") # Independent modes else: node_colors.append("lightblue") # Dependent modes # Draw nodes nx.draw_networkx_nodes( G, pos, node_color=node_colors, node_size=node_size, ax=ax1 ) # Draw edges with different styles edge_colors = [G.edges[edge]["color"] for edge in G.edges()] edge_styles = [G.edges[edge]["style"] for edge in G.edges()] # Draw positive and negative edges separately for different styles pos_edges = [(u, v) for u, v, d in G.edges(data=True) if d["color"] == "blue"] neg_edges = [(u, v) for u, v, d in G.edges(data=True) if d["color"] == "red"] if pos_edges: nx.draw_networkx_edges( G, pos, edgelist=pos_edges, edge_color="blue", style="-", arrows=True, arrowsize=20, ax=ax1, ) if neg_edges: nx.draw_networkx_edges( G, pos, edgelist=neg_edges, edge_color="red", style="--", arrows=True, arrowsize=20, ax=ax1, ) # Node labels if show_frequencies and show_directions: labels = {} for mode_label, mode in self.modes.items(): direction_str = "→" if mode.direction == 1 else "←" freq_str = f"{mode.frequency:.1f}" if mode.frequency else "None" labels[mode_label] = f"{mode_label}\n{direction_str}\n({freq_str})" elif show_frequencies: labels = {} for mode_label, mode in self.modes.items(): freq_str = f"{mode.frequency:.1f}" if mode.frequency else "None" labels[mode_label] = f"{mode_label}\n({freq_str})" elif show_directions: labels = {} for mode_label, mode in self.modes.items(): direction_str = "→" if mode.direction == 1 else "←" labels[mode_label] = f"{mode_label}\n{direction_str}" else: labels = {mode: mode for mode in G.nodes()} nx.draw_networkx_labels(G, pos, labels, font_size=font_size, ax=ax1) # Edge labels (coefficients) nx.draw_networkx_edge_labels( G, pos, edge_labels, font_size=font_size - 2, ax=ax1 ) ax1.set_title("Mode Relationships", fontsize=font_size + 2, fontweight="bold") ax1.axis("off") # Legend legend_elements = [ mpatches.Patch(color="lightgreen", label="Independent modes"), mpatches.Patch(color="lightblue", label="Dependent modes"), plt.Line2D([0], [0], color="blue", lw=2, label="Positive contribution"), plt.Line2D( [0], [0], color="red", lw=2, linestyle="--", label="Negative contribution", ), ] ax1.legend(handles=legend_elements, loc="upper right") # Relations text in second subplot ax2.axis("off") ax2.set_title("Relations", fontsize=font_size + 1, fontweight="bold") relations_text = "Symbolic Expressions:\n\n" for mode, expression in self.symbolic_expressions.items(): expr_parts = [] for base_mode, coeff in expression.items(): if coeff == 1: expr_parts.append(base_mode) elif coeff == -1: expr_parts.append(f"-{base_mode}") else: expr_parts.append(f"{coeff}*{base_mode}") expr_str = " + ".join(expr_parts).replace(" + -", " - ") relations_text += f"{mode} = {expr_str}\n" relations_text += f"\nIndependent: {', '.join(sorted(self.independent_modes))}" ax2.text( 0.05, 0.95, relations_text, transform=ax2.transAxes, fontsize=font_size - 1, verticalalignment="top", bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8), ) plt.tight_layout() plt.show()
def _hierarchical_layout(self, G: nx.DiGraph) -> Dict[str, Tuple[float, float]]: """Create hierarchical layout with independent modes at bottom.""" levels = {} # Independent modes at level 0 for mode in self.independent_modes: levels[mode] = 0 # Assign levels based on dependencies max_level = 0 for mode in nx.topological_sort(G): if mode not in levels: # Find maximum level of predecessors pred_levels = [levels.get(pred, 0) for pred in G.predecessors(mode)] levels[mode] = max(pred_levels, default=0) + 1 max_level = max(max_level, levels[mode]) # Create positions level_counts: dict[int, int] = defaultdict(int) level_positions: dict[int, int] = defaultdict(int) # Count nodes per level for level in levels.values(): level_counts[level] += 1 pos: dict[str, tuple[float, float]] = {} for mode, level in levels.items(): x = level_positions[level] - (level_counts[level] - 1) / 2 y = max_level - level # Flip so independent modes are at bottom pos[mode] = (x, y) level_positions[level] += 1 return pos
[docs] def get_mode(self, label: str) -> Mode: """Get mode by label.""" return self.modes[label]
[docs] def get_rwa_terms(self, power: int = 3) -> List[Tuple[Any, ...]]: """ Get RWA terms for the specified mixing order with caching. Args: power: Order of the interaction (2 for 3WM, 3 for 4WM) Returns: List of RWA terms """ if power == 2: if self._rwa_terms_3wm is None: self._rwa_terms_3wm = self.analyzer.find_rwa_terms(power) return self._rwa_terms_3wm elif power == 3: if self._rwa_terms_4wm is None: self._rwa_terms_4wm = self.analyzer.find_rwa_terms(power) return self._rwa_terms_4wm else: # For other powers, go directly to the analyzer without caching return self.analyzer.find_rwa_terms(power)
[docs] def print_modes(self): """Print current state of all modes.""" for label, mode in self.modes.items(): independence = " (independent)" if label in self.independent_modes else "" print(f"{label}{independence}:") print(f" Frequency: {mode.frequency}") print(f" Direction: {mode.direction}") print(f" k: {mode.k}") print(f" gamma: {mode.gamma}") print(f" alpha: {mode.alpha}") print()
[docs] def print_symbolic_expressions(self): """Print the symbolic expressions for all modes.""" print("Symbolic Expressions:") print("=" * 50) for mode, expression in self.symbolic_expressions.items(): expr_parts = [] for base_mode, coeff in expression.items(): if coeff == 1: expr_parts.append(base_mode) elif coeff == -1: expr_parts.append(f"-{base_mode}") else: expr_parts.append(f"{coeff}*{base_mode}") expr_str = " + ".join(expr_parts).replace(" + -", " - ") independence = " (independent)" if mode in self.independent_modes else "" print(f"{mode}{independence} = {expr_str}")
[docs]class ModeArrayFactory: """Factory for creating standard ModeArray configurations."""
[docs] @staticmethod def create_basic_3wm( base_data: Dict[str, Any], forward_modes: bool = True, ) -> ModeArray: """ Create a basic 3WM ModeArray with pump, signal, and idler modes. Args: base_data: Dictionary containing 'freqs', 'k', 'gammas', and 'alpha' arrays forward_modes: Whether to create forward (True) or backward (False) propagating modes Returns: ModeArray: Configured for basic 3WM operation """ # Extract required arrays from base_data freqs = base_data["freqs"] kappas = base_data["k"] gammas = base_data["gammas"] alphas = base_data["alpha"] # Create interpolator interpolator = ParameterInterpolator(freqs, kappas, gammas, alphas) direction = 1 if forward_modes else -1 modes = [ Mode(label="p", direction=direction), Mode(label="s", direction=direction), Mode(label="i", direction=direction), ] relations = [["i", "p-s"]] # Idler is pump minus signal return ModeArray(modes, relations, interpolator)
[docs] @staticmethod def create_extended_3wm( base_data: Dict[str, Any], n_pump_harmonics: int = 1, n_frequency_conversion: int = 1, n_signal_harmonics: int = 1, n_sidebands: int = 1, forward_modes: bool = True, ) -> ModeArray: """ Create an extended 3WM ModeArray with pump harmonics and conversion terms. Args: base_data: Dictionary containing 'freqs', 'k', 'gammas', and 'alpha' arrays n_pump_harmonics: Number of pump harmonics to include n_frequency_conversion: Number of frequency conversion terms n_signal_harmonics: Number of signal and idler harmonics forward_modes: Whether to create forward (True) or backward (False) propagating modes Returns: ModeArray: Configured for extended 3WM operation with harmonics """ # Extract required arrays from base_data freqs = base_data["freqs"] kappas = base_data["k"] gammas = base_data["gammas"] alphas = base_data["alpha"] # Create interpolator interpolator = ParameterInterpolator(freqs, kappas, gammas, alphas) direction = 1 if forward_modes else -1 # Create basic modes modes = [ Mode(label="p", direction=direction), Mode(label="s", direction=direction), Mode(label="i", direction=direction), ] # Basic relation relations = [["i", "p-s"]] # Idler is pump minus signal for n in range(1, n_frequency_conversion + 1): if n == 1: modes.append(Mode(label="ps", direction=direction)) # p+s modes.append(Mode(label="pi", direction=direction)) # p+i relations.append(["ps", "p+s"]) relations.append(["pi", "p+i"]) else: modes.append(Mode(label=f"p{n}s", direction=direction)) # p+s modes.append(Mode(label=f"p{n}i", direction=direction)) # p+i relations.append([f"p{n}s", "p+" * (n - 1) + "p+s"]) relations.append([f"p{n}i", "p+" * (n - 1) + "p+i"]) for n in range(2, n_pump_harmonics + 2): modes.append(Mode(label=f"p{n}", direction=direction)) relations.append([f"p{n}", "p+" * (n - 1) + "p"]) for n in range(2, n_signal_harmonics + 2): modes.append(Mode(label=f"s{n}", direction=direction)) relations.append([f"s{n}", "s+" * (n - 1) + "s"]) modes.append(Mode(label=f"i{n}", direction=direction)) relations.append([f"i{n}", "i+" * (n - 1) + "i"]) # for n in range(2, n_sidebands + 2): # modes.append(Mode(label=f"s{n}p", direction=direction)) # modes.append(Mode(label=f"i{n}p", direction=direction)) # relations.append([f"s{n}p", "s+s-p"]) # relations.append([f"i{n}p", "i+i-p"]) return ModeArray(modes, relations, interpolator)
[docs] @staticmethod def create_custom( base_data: Dict[str, Any], mode_labels: List[str], mode_directions: List[int], relations: List[List[str]], ) -> ModeArray: """ Create a custom ModeArray with user-defined modes and relations. Args: base_data: Dictionary containing 'freqs', 'k', 'gammas', and 'alpha' arrays mode_labels: List of mode labels mode_directions: List of mode directions (1 for forward, -1 for backward) relations: List of relations between modes Returns: ModeArray: Custom configured ModeArray """ if len(mode_labels) != len(mode_directions): raise ValueError( "mode_labels and mode_directions must have the same length" ) # Extract required arrays from base_data freqs = base_data["freqs"] kappas = base_data["k"] gammas = base_data["gammas"] alphas = base_data["alpha"] # Create interpolator interpolator = ParameterInterpolator(freqs, kappas, gammas, alphas) modes = [ Mode(label=label, direction=direction) for label, direction in zip(mode_labels, mode_directions) ] return ModeArray(modes, relations, interpolator)