# SPDX-License-Identifier: Apache-2.0
# Copyright 2019 Blue Cheetah Analog Design Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module handles high level simulation routines.
This module defines SimAccess, which provides methods to run simulations
and retrieve results.
"""
from __future__ import annotations
from typing import (
Tuple, Union, Iterable, List, Dict, Any, Optional, TypeVar, Type, Sequence, ItemsView, Mapping
)
import math
from enum import Enum
from dataclasses import dataclass
import numpy as np
from pybag.enum import DesignOutput
from pybag.core import convert_cdba_name_bit
from ..util.immutable import ImmutableList, ImmutableSortedDict
###############################################################################
# Sweep specifications
###############################################################################
class SweepSpecType(Enum):
    """Kinds of 1-D sweep specification accepted by ``swp_spec_from_dict``.

    Members are looked up by name (``SweepSpecType[table['type']]``).
    """
    # NOTE(review): member values are placeholders restored from the names used in
    # swp_spec_from_dict; confirm against any code that relies on ``.value``.
    LIST = 0
    LINEAR = 1
    LOG = 2
class SweepInfoType(Enum):
    """Kind of parameter sweep: a multi-dimensional grid (MD) or a set of combinations (SET)."""
    # NOTE(review): member values are placeholders restored from the names used by
    # MDSweepInfo.stype / SetSweepInfo.stype; confirm against any use of ``.value``.
    MD = 0
    SET = 1
@dataclass(eq=True, frozen=True)
class SweepList:
    """A sweep over an explicit list of values."""

    values: ImmutableList[float]

    def __len__(self) -> int:
        """Number of sweep points."""
        return len(self.values)

    @property
    def start(self) -> float:
        """First sweep value (raises IndexError for an empty list)."""
        return self.values[0]
@dataclass(eq=True, frozen=True)
class SweepLinear:
    """A linearly spaced sweep of ``num`` points from ``start`` towards ``stop``.

    Follows the ``numpy.linspace`` convention: ``stop`` is the last point only when
    ``endpoint`` is True. Use ``stop_inc`` for the last value actually swept.
    """

    start: float
    stop: float
    num: int
    endpoint: bool = True

    def __len__(self) -> int:
        """Number of sweep points."""
        return self.num

    @property
    def step(self) -> float:
        """Spacing between consecutive points (0.0 when there is no spacing to define)."""
        den = self.num - 1 if self.endpoint else self.num
        if den <= 0:
            # a single-point endpoint sweep has no spacing; avoid ZeroDivisionError
            return 0.0
        return (self.stop - self.start) / den

    @property
    def stop_inc(self) -> float:
        """The last value of the sweep (inclusive)."""
        if self.endpoint and self.num > 1:
            return self.stop
        # num == 1 yields [start] even with endpoint=True, matching numpy.linspace
        return self.start + (self.num - 1) * self.step

    @property
    def values(self) -> np.ndarray:
        """All sweep points as an array."""
        return np.linspace(self.start, self.stop, self.num, endpoint=self.endpoint)
@dataclass(eq=True, frozen=True)
class SweepLog:
    """A logarithmically spaced sweep of ``num`` points from ``start`` towards ``stop``.

    Follows the ``numpy.logspace`` convention: ``stop`` is the last point only when
    ``endpoint`` is True. Use ``stop_inc`` for the last value actually swept.
    ``start`` and ``stop`` must be positive (they are passed to ``math.log10``).
    """

    start: float
    stop: float
    num: int
    endpoint: bool = True

    def __len__(self) -> int:
        """Number of sweep points."""
        return self.num

    @property
    def start_log(self) -> float:
        """log10 of the start value."""
        return math.log10(self.start)

    @property
    def stop_log(self) -> float:
        """log10 of the stop value."""
        return math.log10(self.stop)

    @property
    def step_log(self) -> float:
        """Spacing between consecutive points in log10 units (0.0 if undefined)."""
        den = self.num - 1 if self.endpoint else self.num
        if den <= 0:
            # a single-point endpoint sweep has no spacing; avoid ZeroDivisionError
            return 0.0
        return (self.stop_log - self.start_log) / den

    @property
    def stop_inc(self) -> float:
        """The last value of the sweep (inclusive)."""
        if self.endpoint and self.num > 1:
            return self.stop
        return 10.0 ** (self.start_log + (self.num - 1) * self.step_log)

    @property
    def values(self) -> np.ndarray:
        """All sweep points as an array, consistent with SweepLinear.values."""
        return np.logspace(self.start_log, self.stop_log, self.num, endpoint=self.endpoint)
# Any 1-D sweep specification.
SweepSpec = Union[SweepLinear, SweepLog, SweepList]
def swp_spec_from_dict(table: Mapping[str, Any]) -> SweepSpec:
    """Build a 1-D sweep specification from a dictionary.

    Parameters
    ----------
    table : Mapping[str, Any]
        must contain ``'type'`` (a SweepSpecType member name). For ``LIST`` it needs
        ``'values'``; for ``LINEAR`` / ``LOG`` it needs ``'start'``, ``'stop'``,
        ``'num'`` and optionally ``'endpoint'`` (default True).

    Returns
    -------
    spec : SweepSpec
        the sweep specification.

    Raises
    ------
    KeyError
        if ``'type'`` is not a SweepSpecType name or a required key is missing.
    """
    swp_type = SweepSpecType[table['type']]
    if swp_type is SweepSpecType.LIST:
        return SweepList(ImmutableList(table['values']))
    if swp_type is SweepSpecType.LINEAR:
        return SweepLinear(table['start'], table['stop'], table['num'], table.get('endpoint', True))
    if swp_type is SweepSpecType.LOG:
        return SweepLog(table['start'], table['stop'], table['num'], table.get('endpoint', True))
    # only reachable if SweepSpecType gains members not handled above
    raise ValueError(f'Unsupported sweep type: {swp_type}')
@dataclass(eq=True, frozen=True)
class MDSweepInfo:
    """A multi-dimensional (grid) sweep: one independent sweep spec per parameter."""

    # (parameter name, sweep spec) pairs, in sweep order
    params: ImmutableList[Tuple[str, SweepSpec]]

    @property
    def ndim(self) -> int:
        """Number of swept parameters."""
        return len(self.params)

    @property
    def stype(self) -> SweepInfoType:
        return SweepInfoType.MD

    @property
    def shape(self) -> Tuple[int, ...]:
        """Number of points of each swept parameter."""
        return tuple(len(spec) for _, spec in self.params)

    def __contains__(self, item: str) -> bool:
        return any(name == item for name, _ in self.params)

    def __iter__(self) -> Iterable[str]:
        """Iterate over parameter names."""
        return (name for name, _ in self.params)

    def default_items(self) -> Iterable[Tuple[str, float]]:
        """Yield ``(name, first sweep value)`` for each parameter."""
        for name, spec in self.params:
            yield name, spec.start
@dataclass(eq=True, frozen=True)
class SetSweepInfo:
    """A sweep over an explicit set of parameter combinations."""

    # parameter names
    params: ImmutableList[str]
    # one combination per entry; each has len(params) values in params order
    values: ImmutableList[ImmutableList[float]]

    @property
    def stype(self) -> SweepInfoType:
        return SweepInfoType.SET

    @property
    def shape(self) -> Tuple[int, ...]:
        """One-element tuple holding the number of combinations."""
        # NOTE: one-element tuple, not typo
        return len(self.values),

    def __contains__(self, item: str) -> bool:
        return item in self.params

    def __iter__(self) -> Iterable[str]:
        """Iterate over parameter names."""
        # __iter__ must return an iterator, not merely an iterable container
        return iter(self.params)

    def default_items(self) -> Iterable[Tuple[str, float]]:
        """Yield ``(name, value)`` pairs from the first combination."""
        for idx, name in enumerate(self.params):
            yield name, self.values[0][idx]
# Any parameter sweep description.
SweepInfo = Union[MDSweepInfo, SetSweepInfo]
def swp_info_from_struct(table: Union[Sequence[Tuple[str, Mapping[str, Any]]], Mapping[str, Any]]
                         ) -> SweepInfo:
    """Build a sweep description from a mapping or a sequence of pairs.

    Parameters
    ----------
    table : Union[Sequence[Tuple[str, Mapping[str, Any]]], Mapping[str, Any]]
        either a mapping with ``'params'`` (parameter names) and ``'values'``
        (combinations, each with one value per parameter) describing a set sweep,
        or a sequence of ``(name, sweep_spec_dict)`` pairs describing a grid sweep.

    Returns
    -------
    info : SweepInfo
        SetSweepInfo for a mapping input, MDSweepInfo otherwise.

    Raises
    ------
    ValueError
        if a set-sweep combination does not have one value per parameter.
    """
    # any mapping (not only dict / ImmutableSortedDict) describes a set sweep;
    # iterating a mapping as (name, spec) pairs would be wrong
    if isinstance(table, Mapping) or isinstance(table, ImmutableSortedDict):
        params = ImmutableList(table['params'])
        num_par = len(params)
        values = []
        for combo in table['values']:
            if len(combo) != num_par:
                raise ValueError('Invalid param set values.')
            values.append(ImmutableList(combo))
        return SetSweepInfo(params, ImmutableList(values))

    par_list = [(par, swp_spec_from_dict(spec)) for par, spec in table]
    return MDSweepInfo(ImmutableList(par_list))
###############################################################################
# Analyses
###############################################################################
class AnalysisType(Enum):
    """Kinds of analysis handled by ``analysis_from_dict``.

    Members are looked up by name (``AnalysisType[table['type']]``), and
    ``SimData.open_analysis`` uses the lower-cased name as the result group name.
    """
    # NOTE(review): member values are placeholders restored from the names used in
    # analysis_from_dict; confirm against any code that relies on ``.value``.
    DC = 0
    AC = 1
    SP = 2
    NOISE = 3
    TRAN = 4
    PSS = 5
    PAC = 6
    PNOISE = 7
# Type variable so AnalysisSweep1D.from_dict and overrides return the calling subclass.
T = TypeVar('T', bound='AnalysisSweep1D')
@dataclass(eq=True, frozen=True)
class AnalysisSweep1D:
    """Base class for analyses that may sweep one parameter."""

    # swept parameter name, or '' when the analysis is not swept
    param: str
    # sweep specification, or None when the analysis is not swept
    sweep: Optional[SweepSpec]
    options: ImmutableSortedDict[str, str]
    save_outputs: ImmutableList[str]

    @classmethod
    def from_dict(cls: Type[T], table: Dict[str, Any], def_param: str = '') -> T:
        """Build the analysis from a dictionary.

        Parameters
        ----------
        table : Dict[str, Any]
            optional keys: ``'param'`` (defaults to ``def_param``), ``'sweep'``
            (a sweep spec dict), ``'options'``, ``'save_outputs'``.
        def_param : str
            parameter name to use when ``table`` has no ``'param'``.

        Returns
        -------
        ana : T
            the analysis. If either the parameter name or the sweep is missing,
            the analysis is unswept (``param == ''`` and ``sweep is None``).
        """
        param = table.get('param', def_param)
        sweep = table.get('sweep', None)
        opt = table.get('options', {})
        out = table.get('save_outputs', [])
        if not param or sweep is None:
            param = ''
            swp = None
        else:
            swp = swp_spec_from_dict(sweep)
        return cls(param, swp, ImmutableSortedDict(opt), ImmutableList(out))

    @property
    def param_start(self) -> float:
        """First value of the swept parameter, or 0.0 when not swept."""
        if self.param:
            return self.sweep.start
        return 0.0
@dataclass(eq=True, frozen=True)
class AnalysisDC(AnalysisSweep1D):
    """DC analysis, optionally swept over one parameter."""

    @property
    def name(self) -> str:
        return 'dc'
@dataclass(eq=True, frozen=True)
class AnalysisAC(AnalysisSweep1D):
    """AC analysis; swept over ``freq`` by default, or over another parameter at a fixed frequency."""

    # fixed analysis frequency; 0.0 when the swept parameter is 'freq' itself
    freq: float

    @property
    def name(self) -> str:
        return 'ac'

    @classmethod
    def from_dict(cls: Type[T], table: Dict[str, Any], def_param: str = '') -> T:
        """Build the analysis from a dictionary.

        Parameters
        ----------
        table : Dict[str, Any]
            same keys as ``AnalysisSweep1D.from_dict``; ``'freq'`` is required
            whenever the swept parameter is not ``'freq'`` (including when unswept).
        def_param : str
            default swept parameter name; ``'freq'`` when empty.

        Returns
        -------
        ana : T
            the analysis.
        """
        base = AnalysisSweep1D.from_dict(table, def_param=def_param or 'freq')
        if base.param != 'freq':
            freq_val = table['freq']
        else:
            freq_val = 0.0
        return cls(base.param, base.sweep, base.options, base.save_outputs, freq_val)
@dataclass(eq=True, frozen=True)
class AnalysisSP(AnalysisAC):
    """S-parameter analysis."""

    # port names
    ports: ImmutableList[str]
    # parameter type, looked up by name in analysis_from_dict (``SPType[...]``)
    param_type: SPType

    @property
    def name(self) -> str:
        return 'sp'
@dataclass(eq=True, frozen=True)
class AnalysisNoise(AnalysisAC):
    """Noise analysis; ports and probes default to '' in analysis_from_dict."""

    p_port: str
    n_port: str
    out_probe: str
    in_probe: str

    @property
    def name(self) -> str:
        return 'noise'
@dataclass(eq=True, frozen=True)
class AnalysisTran:
    """Transient analysis (never parameter-swept)."""

    start: float
    stop: float
    strobe: float
    # analysis_from_dict defaults this to -1.0
    out_start: float
    options: ImmutableSortedDict[str, str]
    save_outputs: ImmutableList[str]

    @property
    def param(self) -> str:
        """Always '' (no swept parameter)."""
        return ''

    @property
    def param_start(self) -> float:
        return 0.0

    @property
    def name(self) -> str:
        return 'tran'
@dataclass(eq=True, frozen=True)
class AnalysisPSS:
    """Periodic steady-state analysis (never parameter-swept)."""

    p_port: str
    n_port: str
    period: float
    fund: float
    autofund: bool
    strobe: float
    options: ImmutableSortedDict[str, str]
    save_outputs: ImmutableList[str]

    @property
    def param(self) -> str:
        """Always '' (no swept parameter)."""
        return ''

    @property
    def param_start(self) -> float:
        """Always 0.0, consistent with AnalysisTran."""
        return 0.0

    @property
    def name(self) -> str:
        return 'pss'
@dataclass(eq=True, frozen=True)
class AnalysisPAC(AnalysisAC):
    """Periodic AC analysis; same fields as AnalysisAC."""

    @property
    def name(self) -> str:
        return 'pac'
@dataclass(eq=True, frozen=True)
class AnalysisPNoise(AnalysisNoise):
    """Periodic noise analysis with optional jitter measurements."""

    # built from table['measurement'] in analysis_from_dict; None when not given
    measurement: Optional[ImmutableList[JitterEvent]] = None

    @property
    def name(self) -> str:
        return 'pnoise'
# NOTE(review): an orphaned ``@dataclass(eq=True, frozen=True)`` decorator stood here,
# decorating this assignment (a syntax error). The class it belonged to is presumably
# ``JitterEvent``, which AnalysisPNoise and analysis_from_dict reference but which is
# not defined in this listing; it should be restored.

# Any analysis description.
AnalysisInfo = Union[AnalysisDC, AnalysisAC, AnalysisSP, AnalysisNoise, AnalysisTran,
                     AnalysisPSS, AnalysisPAC, AnalysisPNoise]
def analysis_from_dict(table: Dict[str, Any]) -> AnalysisInfo:
    """Build an analysis description from a dictionary.

    Parameters
    ----------
    table : Dict[str, Any]
        must contain ``'type'`` (an AnalysisType member name); the remaining keys
        depend on the analysis type (see the branches below).

    Returns
    -------
    ana : AnalysisInfo
        the analysis description.

    Raises
    ------
    KeyError
        if ``'type'`` is not an AnalysisType name or a required key is missing.
    """
    ana_type = AnalysisType[table['type']]
    if ana_type is AnalysisType.DC:
        return AnalysisDC.from_dict(table)
    if ana_type is AnalysisType.AC:
        return AnalysisAC.from_dict(table)
    if ana_type is AnalysisType.SP:
        base = AnalysisAC.from_dict(table)
        return AnalysisSP(base.param, base.sweep, base.options, base.save_outputs, base.freq,
                          ImmutableList(table['ports']), SPType[table['param_type']])
    if ana_type is AnalysisType.NOISE:
        base = AnalysisAC.from_dict(table)
        return AnalysisNoise(base.param, base.sweep, base.options, base.save_outputs, base.freq,
                             table.get('p_port', ''), table.get('n_port', ''),
                             table.get('out_probe', ''), table.get('in_probe', ''))
    if ana_type is AnalysisType.TRAN:
        return AnalysisTran(table.get('start', 0.0), table['stop'], table.get('strobe', 0.0),
                            table.get('out_start', -1.0),
                            ImmutableSortedDict(table.get('options', {})),
                            ImmutableList(table.get('save_outputs', [])))
    if ana_type is AnalysisType.PSS:
        return AnalysisPSS(table.get('p_port', ''), table.get('n_port', ''),
                           table.get('period', 0.0), table.get('fund', 0.0),
                           table.get('autofund', False), table.get('strobe', 0.0),
                           ImmutableSortedDict(table.get('options', {})),
                           ImmutableList(table.get('save_outputs', [])))
    if ana_type is AnalysisType.PAC:
        base = AnalysisAC.from_dict(table)
        return AnalysisPAC(base.param, base.sweep, base.options, base.save_outputs, base.freq)
    if ana_type is AnalysisType.PNOISE:
        base = AnalysisAC.from_dict(table)
        pnoise_meas = table.get('measurement', None)
        if pnoise_meas:
            pnoise_meas = ImmutableList([JitterEvent(**_dict) for _dict in pnoise_meas])
        return AnalysisPNoise(base.param, base.sweep, base.options, base.save_outputs, base.freq,
                              table.get('p_port', ''), table.get('n_port', ''),
                              table.get('out_probe', ''), table.get('in_probe', ''), pnoise_meas)
    # only reachable if AnalysisType gains members not handled above
    raise ValueError(f'Unknown analysis type: {ana_type}')
###############################################################################
# Simulation Netlist Info
###############################################################################
@dataclass(eq=True, frozen=True)
class MonteCarlo:
    """Monte Carlo simulation settings."""

    numruns: int
    seed: int
    options: ImmutableSortedDict[str, Any]

    @property
    def name(self) -> str:
        return 'mc'
def monte_carlo_from_dict(mc_dict: Optional[Dict[str, Any]]) -> Optional[MonteCarlo]:
    """Build Monte Carlo settings from a dictionary.

    Parameters
    ----------
    mc_dict : Optional[Dict[str, Any]]
        dictionary with ``'numruns'``, ``'seed'`` and optional ``'options'``;
        None or empty disables Monte Carlo.

    Returns
    -------
    mc : Optional[MonteCarlo]
        the settings, or None if ``mc_dict`` is None or empty.
    """
    if not mc_dict:
        return None

    numruns: int = mc_dict['numruns']
    seed: int = mc_dict['seed']
    options: Dict[str, Any] = mc_dict.get('options', {})
    return MonteCarlo(numruns, seed, options=ImmutableSortedDict(options))
@dataclass(eq=True, frozen=True)
class SimNetlistInfo:
    """Simulation setup: environments, analyses, parameters, sweeps and options."""

    sim_envs: ImmutableList[str]
    analyses: ImmutableList[AnalysisInfo]
    params: ImmutableSortedDict[str, float]
    # per-environment parameter values: one entry per sim_env
    env_params: ImmutableSortedDict[str, ImmutableList[float]]
    swp_info: SweepInfo
    outputs: ImmutableSortedDict[str, str]
    options: ImmutableSortedDict[str, Any]
    monte_carlo: Optional[MonteCarlo]
    init_voltages: ImmutableSortedDict[str, Union[str, float]]

    @property
    def sweep_type(self) -> SweepInfoType:
        """Kind of parameter sweep (grid or set)."""
        return self.swp_info.stype
def netlist_info_from_dict(table: Dict[str, Any]) -> SimNetlistInfo:
    """Build a SimNetlistInfo from a dictionary.

    Parameters
    ----------
    table : Dict[str, Any]
        requires ``'sim_envs'`` and ``'analyses'``; optional keys are
        ``'params'``, ``'env_params'``, ``'swp_info'``, ``'outputs'``,
        ``'options'``, ``'monte_carlo'`` and ``'init_voltages'``.

    Returns
    -------
    info : SimNetlistInfo
        the simulation setup.

    Raises
    ------
    ValueError
        if ``sim_envs`` is empty, or an ``env_params`` entry does not have one
        value per simulation environment.
    """
    sim_envs: List[str] = table['sim_envs']
    analyses: List[Dict[str, Any]] = table['analyses']
    params: Dict[str, float] = table.get('params', {})
    env_params: Dict[str, List[float]] = table.get('env_params', {})
    swp_info: Union[List[Tuple[str, Dict[str, Any]]], Dict[str, Any]] = table.get('swp_info', [])
    outputs: Dict[str, str] = table.get('outputs', {})
    options: Dict[str, Any] = table.get('options', {})
    monte_carlo: Optional[Dict[str, Any]] = table.get('monte_carlo', None)
    init_voltages: Dict[str, Union[str, float]] = table.get('init_voltages', {})

    if not sim_envs:
        raise ValueError('simulation environments list is empty')

    num_env = len(sim_envs)
    env_par_dict = {}
    for key, val in env_params.items():
        if len(val) != num_env:
            raise ValueError("Invalid env_param value.")
        env_par_dict[key] = ImmutableList(val)

    ana_list = [analysis_from_dict(val) for val in analyses]

    return SimNetlistInfo(ImmutableList(sim_envs), ImmutableList(ana_list),
                          ImmutableSortedDict(params), ImmutableSortedDict(env_par_dict),
                          swp_info_from_struct(swp_info), ImmutableSortedDict(outputs),
                          ImmutableSortedDict(options), monte_carlo_from_dict(monte_carlo),
                          ImmutableSortedDict(init_voltages))
###############################################################################
# Simulation data classes
###############################################################################
class AnalysisData:
    """A data struct that stores simulation data from a single analysis.

    ``data`` maps names to numpy arrays. Names listed in ``sweep_params`` hold the
    sweep-variable values; every other key is treated as a signal. ``is_md``
    records whether the data is stored as a multi-dimensional grid.
    """

    def __init__(self, sweep_params: Sequence[str], data: Dict[str, np.ndarray],
                 is_md: bool) -> None:
        self._swp_pars = ImmutableList(sweep_params)
        self._data = data
        self._is_md = is_md
        swp_set = set(sweep_params)
        # signals are all data entries that are not sweep parameters
        self._signals = [key for key in data.keys() if key not in swp_set]

    def __getitem__(self, item: str) -> np.ndarray:
        return self._data[item]

    def __contains__(self, item: str) -> bool:
        return item in self._data

    @property
    def data_shape(self) -> Tuple[int, ...]:
        """Shape of the first signal array, or ``()`` if there are no signals."""
        if not self._signals:
            return ()
        return self._data[self._signals[0]].shape

    @property
    def is_md(self) -> bool:
        return self._is_md

    @property
    def sweep_params(self) -> ImmutableList[str]:
        return self._swp_pars

    @property
    def signals(self) -> List[str]:
        return self._signals

    @classmethod
    def combine(cls, data_list: Sequence[AnalysisData], swp_name: str,
                swp_vals: Optional[np.ndarray] = None, axis: int = 0) -> AnalysisData:
        """Stack several AnalysisData along a new sweep dimension.

        Each signal is NaN-padded to the element-wise maximum shape across
        ``data_list`` before stacking; the new dimension is then moved to ``axis``.

        Parameters
        ----------
        data_list : Sequence[AnalysisData]
            data to combine; the first element supplies signal and sweep names.
        swp_name : str
            name of the new sweep parameter, inserted at position ``axis``.
        swp_vals : Optional[np.ndarray]
            values of the new sweep parameter; defaults to ``0 .. len(data_list)-1``.
        axis : int
            position of the new dimension.

        Returns
        -------
        data : AnalysisData
            the combined data.

        Raises
        ------
        ValueError
            if ``data_list`` is empty.
        """
        ndata = len(data_list)
        if ndata < 1:
            raise ValueError('Must combine at least 1 data.')
        if swp_vals is None:
            swp_vals = np.arange(ndata)

        data0 = data_list[0]
        new_data = {}
        swp_par_list = list(data0.sweep_params)

        # get all signals
        max_size = None
        for sig in data0.signals:
            arr_list = [arr[sig] for arr in data_list]
            sizes = [x.shape for x in arr_list]
            # element-wise maximum shape across all inputs
            max_size = np.max(list(zip(*sizes)), -1)
            cur_ans = np.full((len(arr_list),) + tuple(max_size), np.nan, dtype=arr_list[0].dtype)
            for idx, arr in enumerate(arr_list):
                # noinspection PyTypeChecker
                select = (idx,) + tuple(slice(0, s) for s in sizes[idx])
                cur_ans[select] = arr
            new_data[sig] = np.moveaxis(cur_ans, 0, axis)

        # get last sweep parameter
        last_par = swp_par_list[-1]
        last_xvec = data0[last_par]
        xvec_list = [data[last_par] for data in data_list]
        for xvec in xvec_list:
            if not np.array_equal(xvec_list[0], xvec):
                # last sweep parameter has to be a multi dimensional array
                sizes = [x.shape for x in xvec_list]
                # NOTE(review): max_size comes from the last signal processed above and is
                # still None if data0 has no signals, which would fail here.
                cur_ans = np.full((len(xvec_list),) + tuple(max_size), np.nan)
                for idx, _xvec in enumerate(xvec_list):
                    # noinspection PyTypeChecker
                    select = (idx, ...) + tuple(slice(0, s) for s in sizes[idx])
                    cur_ans[select] = _xvec
                last_xvec = np.moveaxis(cur_ans, 0, axis)
                break
        new_data[last_par] = last_xvec

        # get all other sweep params
        for sn in swp_par_list[:-1]:
            if sn != 'corner':
                new_data[sn] = data0[sn]

        swp_par_list.insert(axis, swp_name)
        new_data[swp_name] = swp_vals

        return AnalysisData(swp_par_list, new_data, data0.is_md)

    def get_param_value(self, name: str) -> np.ndarray:
        """Return sweep parameter ``name`` broadcast to ``data_shape[:-1]``.

        The stored values are reshaped to have size 1 on every axis except the
        parameter's own axis (its index in ``sweep_params``), so they are assumed
        to hold exactly ``data_shape[idx]`` elements.
        """
        param_idx = self._swp_pars.index(name)
        shape = self.data_shape[:-1]
        shape_init = [1] * len(shape)
        shape_init[param_idx] = shape[param_idx]
        arr = self._data[name].reshape(tuple(shape_init))
        return np.broadcast_to(arr, shape)

    def items(self) -> ItemsView[str, np.ndarray]:
        return self._data.items()

    def insert(self, name: str, data: np.ndarray) -> None:
        """Store ``data`` under ``name``; a new name is registered as a signal."""
        self._data[name] = data
        if name not in self._signals:
            self._signals.append(name)

    def copy(self) -> AnalysisData:
        """Return a copy with every array copied via ``np.ndarray.copy``."""
        new_data = {key: arr.copy() for key, arr in self._data.items()}
        return AnalysisData(self._swp_pars, new_data, self._is_md)

    def add(self, new_data: Dict[str, np.ndarray]):
        """Adds combination to simulation results.

        Appends new sweep points (and the matching signal data) to this analysis.
        Only supported when ``is_md`` is False.

        Parameters
        ----------
        new_data : Dict[str, np.ndarray]
            one entry per existing data key. Sweep parameters and ``'hash'`` are
            1-D with a common length; other entries are checked on ``len(arr[0])``
            and appended with ``np.hstack``.

        Raises
        ------
        AttributeError
            if the data is multi-dimensional.
        ValueError
            if a sweep parameter is missing from ``new_data``.
        """
        if self.is_md:
            raise AttributeError('Currently only supported in is_md = False mode')

        # check that the size of new data is the same as existing data
        assert len(self._data.keys()) == len(new_data.keys())

        # check that all sweep parameters are provided
        for param in self.sweep_params:
            if param not in new_data.keys():
                raise ValueError('Param %s not provided in data' % param)

        ref_length = len(list(new_data.values())[0])

        # add data points
        for name, arr in new_data.items():
            # check that all new data arrays are the correct length
            if name in self.sweep_params or name == 'hash':
                assert len(arr) == ref_length
            else:
                assert len(arr[0]) == ref_length

            if name in self.sweep_params:
                # new sweep point
                self._data[name] = np.append(self._data[name], arr)
            else:
                # sweep data
                self._data[name] = np.hstack((self._data[name], arr))

    def remove_sweep(self, name: str, rtol: float = 1e-8, atol: float = 1e-20) -> bool:
        """Remove sweep parameter ``name``.

        In multi-dimensional mode the parameter must have exactly one value, and its
        axis is squeezed out of every signal. Otherwise the parameter is dropped and,
        if the remaining sweep values form a full grid (checked with ``_check_is_md``
        using ``rtol`` / ``atol``), the data is reshaped to multi-dimensional form.

        Returns
        -------
        removed : bool
            False if ``name`` is not a sweep parameter, True otherwise.

        Raises
        ------
        ValueError
            if in multi-dimensional mode the parameter has more than one value.
        """
        new_swp_vars = list(self._swp_pars)
        try:
            idx = new_swp_vars.index(name)
        except ValueError:
            return False

        if self._is_md:
            swp_vals = self._data.pop(name)
            if swp_vals.size != 1:
                # restore before failing so the object is left unchanged
                self._data[name] = swp_vals
                raise ValueError('Can only remove sweep with 1 value in a MD sweep.')

            for sig in self._signals:
                self._data[sig] = np.squeeze(self._data[sig], axis=idx)

            last_var_name = self._swp_pars[-1]
            last_var_arr = self._data[last_var_name]
            if len(last_var_arr.shape) != 1:
                # also need to squeeze last x axis values
                self._data[last_var_name] = np.squeeze(last_var_arr, axis=idx)

            del new_swp_vars[idx]
            self._swp_pars = ImmutableList(new_swp_vars)
        else:
            del new_swp_vars[idx]

            # remove corners (the first remaining sweep parameter is not grid-checked)
            swp_names = new_swp_vars[1:]
            sig_shape = self._data[self._signals[0]].shape
            num_env = sig_shape[0]
            if len(sig_shape) == 2:
                # inner most dimension is part of param sweep
                swp_shape, swp_vals = _check_is_md(num_env, [self._data[par] for par in swp_names],
                                                   rtol, atol, None)
                if swp_shape is not None:
                    for par, vals in zip(swp_names, swp_vals):
                        self._data[par] = vals
            else:
                # inner most dimension is not part of param sweep
                last_par = swp_names[-1]
                last_dset = self._data[last_par]
                swp_names = swp_names[:-1]
                swp_shape, swp_vals = _check_is_md(num_env, [self._data[par] for par in swp_names],
                                                   rtol, atol, last_dset.shape[-1])
                if len(swp_names) == 0:  # TODO: this is a hack to fix for 1 variable sweep
                    swp_shape = list(swp_shape)
                    swp_shape[-1] *= self._data[name].size
                    swp_shape = tuple(swp_shape)

                if swp_shape is not None:
                    for par, vals in zip(swp_names, swp_vals):
                        self._data[par] = vals
                    if len(last_dset.shape) > 1:
                        self._data[last_par] = last_dset.reshape(swp_shape)

            self._swp_pars = ImmutableList(new_swp_vars)
            del self._data[name]
            if swp_shape is not None:
                # this is multi-D
                for sig in self._signals:
                    self._data[sig] = self._data[sig].reshape(swp_shape)
                self._is_md = True

        return True
class SimData:
    """A data structure that stores simulation data as a multi-dimensional array.

    Holds one AnalysisData per group name. Data accessors act on the currently
    opened group, which is initially the first key of ``data``.
    """

    def __init__(self, sim_envs: Sequence[str], data: Dict[str, AnalysisData],
                 sim_netlist_type: DesignOutput) -> None:
        if not data:
            raise ValueError('Empty simulation data.')

        self._sim_envs = ImmutableList(sim_envs)
        self._table = data
        self._cur_name = next(iter(self._table.keys()))
        self._cur_ana: AnalysisData = self._table[self._cur_name]
        self._netlist_type = sim_netlist_type

    @property
    def group(self) -> str:
        """Name of the currently opened group."""
        return self._cur_name

    @property
    def group_list(self) -> List[str]:
        return list(self._table.keys())

    @property
    def sim_envs(self) -> ImmutableList[str]:
        return self._sim_envs

    @property
    def sweep_params(self) -> ImmutableList[str]:
        return self._cur_ana.sweep_params

    @property
    def signals(self) -> List[str]:
        return self._cur_ana.signals

    @property
    def is_md(self) -> bool:
        return self._cur_ana.is_md

    @property
    def data_shape(self) -> Tuple[int, ...]:
        return self._cur_ana.data_shape

    @property
    def netlist_type(self) -> DesignOutput:
        return self._netlist_type

    def _convert_name(self, item: str) -> str:
        """Pass names ending in '>' through convert_cdba_name_bit for this netlist type."""
        if item.endswith('>'):
            return convert_cdba_name_bit(item, self._netlist_type)
        return item

    def __getitem__(self, item: str) -> np.ndarray:
        return self._cur_ana[self._convert_name(item)]

    def __contains__(self, item: str) -> bool:
        # apply the same name conversion as __getitem__ so `x in d` agrees with `d[x]`
        return self._convert_name(item) in self._cur_ana

    def items(self) -> ItemsView[str, np.ndarray]:
        return self._cur_ana.items()

    def open_group(self, val: str) -> None:
        """Make group ``val`` current; raises ValueError if it does not exist."""
        tmp = self._table.get(val, None)
        if tmp is None:
            raise ValueError(f'Group {val} not found.')
        self._cur_name = val
        self._cur_ana = tmp

    def open_analysis(self, atype: AnalysisType) -> None:
        """Open the group named after the analysis type (lower-cased member name)."""
        self.open_group(atype.name.lower())

    def insert(self, name: str, data: np.ndarray) -> None:
        self._cur_ana.insert(name, data)

    def add(self, new_data: Dict[str, np.ndarray]):
        self._cur_ana.add(new_data)

    def copy(self, rename: Optional[Dict[str, str]] = None) -> SimData:
        """Return a shallow copy, renaming groups via ``rename`` (old -> new).

        The AnalysisData objects are shared with this instance, not copied.
        """
        if rename is None:
            rename = {}
        new_table = {rename.get(grp, grp): ana for grp, ana in self._table.items()}
        return SimData(self._sim_envs, new_table, self.netlist_type)

    def deep_copy(self, rename: Optional[Dict[str, str]] = None) -> SimData:
        """Return a copy with every AnalysisData copied, renaming groups via ``rename``."""
        if rename is None:
            rename = {}
        new_table = {rename.get(grp, grp): ana.copy() for grp, ana in self._table.items()}
        return SimData(self._sim_envs, new_table, self.netlist_type)

    def remove_sweep(self, name: str, rtol: float = 1e-8, atol: float = 1e-20) -> bool:
        return self._cur_ana.remove_sweep(name, rtol=rtol, atol=atol)

    def get_param_value(self, name: str) -> np.ndarray:
        return self._cur_ana.get_param_value(name)

    @classmethod
    def combine(cls, data_list: List[SimData], swp_name: str,
                swp_vals: Optional[np.ndarray] = None) -> SimData:
        """Combine several SimData along a new sweep parameter, group by group.

        The first element supplies the environments, group names and netlist type.
        Raises ValueError if ``data_list`` is empty.
        """
        ndata = len(data_list)
        if ndata < 1:
            raise ValueError('Must combine at least 1 data.')

        data0 = data_list[0]
        sim_envs = data0.sim_envs
        new_data = {}
        for grp in data0.group_list:
            ana_list = [sim_data._table[grp] for sim_data in data_list]
            # new sweep axis goes at position 1, presumably right after the corner axis
            new_data[grp] = AnalysisData.combine(ana_list, swp_name, swp_vals=swp_vals, axis=1)
        return SimData(sim_envs, new_data, data0.netlist_type)
[docs]def _check_is_md(num_env: int, swp_vals: List[np.ndarray], rtol: float, atol: float,
last: Optional[int]) -> Tuple[Optional[Tuple[int, ...]], List[np.ndarray]]:
num = len(swp_vals)
shape_list = [num_env] * (num + 1)
new_vals = [np.nan] * num
prev_size = 1
for idx in range(num - 1, -1, -1):
cur_vals = swp_vals[idx]
if prev_size > 1:
rep_prev = cur_vals.size // prev_size
for start_idx in range(0, rep_prev * prev_size, prev_size):
if not np.allclose(cur_vals[start_idx:start_idx + prev_size], cur_vals[start_idx],
rtol=rtol, atol=atol):
# is not MD
return None, []
cur_vals = cur_vals[0::prev_size]
occ_vec = np.nonzero(np.isclose(cur_vals, cur_vals[0], rtol=rtol, atol=atol))[0]
if occ_vec.size < 2:
unique_size = cur_vals.size
else:
unique_size = occ_vec[1]
rep, remain = divmod(cur_vals.size, unique_size)
if remain != 0 or not np.allclose(cur_vals, np.tile(cur_vals[:unique_size], rep),
rtol=rtol, atol=atol):
# is not MD
return None, []
new_vals[idx] = cur_vals[:unique_size]
shape_list[idx + 1] = unique_size
prev_size *= unique_size
if last is not None:
shape_list.append(last)
return tuple(shape_list), new_vals
def combine_ana_sim_envs(ana_dict: Dict[str, AnalysisData], sim_envs: List[str]) -> AnalysisData:
    """Combine multiple single-corner analysis data to a single multi-corner analysis data.

    Each signal array is expected to have a leading dimension of size 1 (one
    corner). The arrays are NaN-padded to a common shape and stacked along a new
    leading corner axis, in ``sim_envs`` order.

    Parameters
    ----------
    ana_dict : Dict[str, AnalysisData]
        dictionary mapping corner to analysis data.
    sim_envs: List[str]
        list of corners; must contain exactly the keys of ``ana_dict``.

    Returns
    -------
    ana_data : AnalysisData
        the combined analysis data (the single input itself if there is one corner).
    """
    cur_ana_sim_envs = list(ana_dict.keys())
    assert sorted(cur_ana_sim_envs) == sorted(sim_envs), f"Expected corners {sim_envs}, got {cur_ana_sim_envs}"
    num_sim_envs = len(sim_envs)
    if num_sim_envs == 1:  # Single corner, nothing to combine
        return ana_dict[sim_envs[0]]

    ana_list = [ana_dict[sim_env] for sim_env in sim_envs]  # Reorder analyses by corner
    merged_data = {}
    ana0 = ana_list[0]
    swp_par_list = ana0.sweep_params

    # get all signals
    max_size = None
    for sig in ana0.signals:
        arr_list = [arr[sig] for arr in ana_list]
        sizes = [x.shape for x in arr_list]
        # element-wise maximum shape across corners
        max_size = np.max(list(zip(*sizes)), -1)
        assert max_size[0] == 1
        # noinspection PyTypeChecker
        cur_ans = np.full((num_sim_envs,) + tuple(max_size[1:]), np.nan, dtype=arr_list[0].dtype)
        for idx, arr in enumerate(arr_list):
            select = (idx,) + tuple(slice(0, s) for s in sizes[idx][1:])
            cur_ans[select] = arr
        merged_data[sig] = cur_ans

    if len(swp_par_list) > 1:
        # get last sweep parameter
        last_par = swp_par_list[-1]
        last_xvec = ana0[last_par]
        xvec_list = [ana[last_par] for ana in ana_list]
        xshape_list = [x.shape for x in xvec_list]
        for xvec in xvec_list[1:]:
            # if the last sweep parameter values are different across corners,
            # the last sweep parameter has to be a multi dimensional array
            if not np.array_equal(xvec_list[0], xvec):
                # NOTE(review): max_size is still None here if there are no signals.
                # noinspection PyTypeChecker
                cur_ans = np.full((num_sim_envs,) + tuple(max_size[1:]), np.nan)
                if len(xshape_list[0]) == len(cur_ans.shape):
                    # if last sweep parameter has the same shape as the data,
                    # then join these together along the first (corner) axis
                    for idx, (_xvec, _xshape) in enumerate(zip(xvec_list, xshape_list)):
                        select = (idx, ) + tuple(slice(0, s) for s in _xshape[1:])
                        cur_ans[select] = _xvec
                else:
                    # if not the same shape as the data, assume corner is missing
                    # and add it to the merged swept values
                    for idx, (_xvec, _xshape) in enumerate(zip(xvec_list, xshape_list)):
                        select = (idx, ...) + tuple(slice(0, s) for s in _xshape)
                        cur_ans[select] = _xvec
                last_xvec = cur_ans
                break
        merged_data[last_par] = last_xvec

    # get all other sweep params
    for sn in swp_par_list[:-1]:
        if sn != 'corner':
            merged_data[sn] = ana0[sn]

    return AnalysisData(swp_par_list, merged_data, ana0.is_md)