# Source code for bag3_testbenches.measurement.search

# SPDX-License-Identifier: Apache-2.0
# Copyright 2019 Blue Cheetah Analog Design Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Any, Tuple, Mapping, Optional, Dict, Set, Union

import abc
import pprint
from enum import Flag, auto

import numpy as np

from bag.util.search import FloatIntervalSearch
from bag.simulation.core import TestbenchManager
from bag.simulation.cache import SimulationDB, SimResults, MeasureResult, DesignInstance
from bag.simulation.measure import MeasurementManager, MeasurementManagerFSM, MeasInfo


class AcceptMode(Flag):
    """Sign(s) of the search error metric (``diff``) under which a result is acceptable.

    Members combine with ``|``; ``BOTH`` is the union of ``POSITIVE`` and ``NEGATIVE``.
    """

    # the result is acceptable when diff > 0
    POSITIVE = auto()
    # the result is acceptable when diff < 0
    NEGATIVE = auto()
    # the result is acceptable on either side of zero
    BOTH = POSITIVE | NEGATIVE
class IntervalSearchMM(MeasurementManagerFSM, abc.ABC):
    """A Measurement manager that performs binary search for you.

    Assumes that no parameters/corners are swept.

    Every adjustable parameter gets its own ``FloatIntervalSearch``.  The FSM goes through
    the states ``'init'`` (only if requested), then ``'bin_0'``, ``'bin_1'``, ... while any
    search still has a next value, and finally ``'done'``.  With exactly one adjustable
    parameter the testbench sweeps it; with several, each step sets one value per parameter.

    Subclasses implement :meth:`process_init` and :meth:`process_output_helper`.
    :meth:`initialize` also calls ``self.init_search(sim_db, dut)``, which this class does
    not declare.  NOTE(review): a subclass must provide it, returning
    ``(tbm, tb_params, intv_params, intv_defaults, has_init, use_dut)``.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        # (testbench manager, testbench parameters); set by initialize() once a search starts
        self._tbm_info: Optional[Tuple[TestbenchManager, Mapping[str, Any]]] = None
        # adjustable parameter name -> (interval search object, single_first flag)
        self._table: Dict[str, Tuple[FloatIntervalSearch, bool]] = {}
        # parameters handed to process_output_helper() and set in the no-sweep mode.
        # NOTE(review): filled in initialize() and never shrunk within this class.
        self._remaining: Set[str] = set()
        # adjustable parameter name -> latest accepted result
        self._result: Dict[str, Dict[str, Any]] = {}
        # returned as the second element of get_sim_info()
        self._use_dut = True

    @property
    def bounds(self) -> Dict[str, Tuple[float, float]]:
        """Dict[str, Tuple[float, float]]: current ``(low, high)`` of every search."""
        return {name: (val[0].low, val[0].high) for name, val in self._table.items()}

    @abc.abstractmethod
    def process_init(self, cur_info: MeasInfo, sim_results: SimResults
                     ) -> Tuple[Dict[str, Any], bool]:
        """Process the results of the optional ``'init'`` simulation.

        Returns
        -------
        result : Dict[str, Any]
            replaces the whole stored result table.
        done : bool
            True to finish immediately; otherwise the search continues at ``'bin_0'``.
        """
        pass

    @abc.abstractmethod
    def process_output_helper(self, cur_info: MeasInfo, sim_results: SimResults,
                              remaining: Set[str]
                              ) -> Mapping[str, Tuple[Tuple[float, float], Dict[str, Any], bool]]:
        """Process the results of one binary-search step.

        Returns
        -------
        info_table : Mapping[str, Tuple[Tuple[float, float], Dict[str, Any], bool]]
            adjustable parameter name -> ``((low, high), result, accept)``: the new search
            interval, the measured result, and whether ``result`` should be stored.
        """
        pass

    def get_init_result(self, adj_name: str) -> Dict[str, Any]:
        """Return the result stored for ``adj_name`` before any search step (empty here)."""
        return {}

    def get_bound(self, adj_name: str) -> Tuple[float, float]:
        """Return the current ``(low, high)`` search interval of ``adj_name``."""
        search = self._table[adj_name][0]
        return search.low, search.high

    def initialize(self, sim_db: SimulationDB, dut: DesignInstance) -> Tuple[bool, MeasInfo]:
        """Build one interval search per adjustable parameter.

        Each search option is read from that parameter's entry in ``intv_params`` first,
        falling back to ``intv_defaults``.

        Returns
        -------
        done : bool
            True if no search has a value left to try.
        info : MeasInfo
            ``'done'``, otherwise ``'init'`` (when requested by ``init_search``) or ``'bin_0'``.
        """
        (tbm, tb_params, intv_params, intv_defaults,
         has_init, use_dut) = self.init_search(sim_db, dut)
        if tbm.swp_info:
            self.error('Parameter sweep is not supported.')
        if tbm.num_sim_envs != 1:
            self.error('Corner sweep is not supported.')

        self._table.clear()
        self._remaining.clear()
        self._result.clear()
        self._use_dut = use_dut

        any_next = False
        for adj_name, search_params in intv_params.items():
            low: float = _get('low', search_params, intv_defaults)
            # no fallback: a missing 'high' raises even though the search accepts None
            high: Optional[float] = _get('high', search_params, intv_defaults)
            step: float = _get('step', search_params, intv_defaults)
            tol: float = _get('tol', search_params, intv_defaults)
            max_err: float = _get('max_err', search_params, intv_defaults)
            overhead_factor: float = _get('overhead_factor', search_params, intv_defaults)
            guess: Optional[Union[float, Tuple[float, float]]] = _get(
                'guess', search_params, intv_defaults, None)
            single_first: bool = _get('single_first', search_params, intv_defaults, False)

            search = FloatIntervalSearch(low, high, overhead_factor, tol=tol, guess=guess,
                                         search_step=step, max_err=max_err)
            any_next = any_next or search.has_next()
            self._table[adj_name] = (search, single_first)
            self._result[adj_name] = self.get_init_result(adj_name)
            self._remaining.add(adj_name)

        if not any_next:
            return True, MeasInfo('done', self._result)

        self._tbm_info = (tbm, tb_params)
        return False, MeasInfo('init' if has_init else 'bin_0', {})

    def get_sim_info(self, sim_db: SimulationDB, dut: DesignInstance, cur_info: MeasInfo
                     ) -> Tuple[Union[Tuple[TestbenchManager, Mapping[str, Any]],
                                      MeasurementManager], bool]:
        """Configure the testbench for the current state and return it.

        In ``'bin_*'`` states with a single adjustable parameter, that parameter is swept.
        The exception is ``'bin_0'`` with ``single_first`` set, where one value is simulated.
        With several parameters, each one in ``_remaining`` is set to its search's value.
        """
        state = cur_info.state
        if state.startswith('bin'):
            if len(self._table) == 1:
                # can do sweeps
                adj_name = next(iter(self._table.keys()))
                search, single_first = self._table[adj_name]
                if state == 'bin_0' and single_first:
                    val = search.get_value()
                    self.log(f'Set {adj_name} to: {val:.4g}')
                    self._tbm_info[0].sim_params[adj_name] = val
                else:
                    swp_specs = search.get_sweep_specs()
                    self.log(f'{adj_name} sweep: {swp_specs}')
                    self._tbm_info[0].set_swp_info([(adj_name, swp_specs)])
            else:
                # no sweeps
                update_dict = {adj_name: self._table[adj_name][0].get_value()
                               for adj_name in self._remaining}
                self.log(f'Setting sim parameters:\n{pprint.pformat(update_dict, width=100)}')
                self._tbm_info[0].sim_params.update(update_dict)

        return self._tbm_info, self._use_dut

    def process_output(self, cur_info: MeasInfo, sim_results: Union[SimResults, MeasureResult]
                       ) -> Tuple[bool, MeasInfo]:
        """Update the searches from a simulation and choose the next state.

        ``'bin_N'`` moves to ``'bin_{N+1}'`` while any search has a next value, otherwise to
        ``'done'``.  ``'init'`` defers to :meth:`process_init`.
        """
        cur_state = cur_info.state
        if cur_state.startswith('bin'):
            info_table = self.process_output_helper(cur_info, sim_results, self._remaining)
            any_next = False
            for key, ((low, high), result, accept) in info_table.items():
                search = self._table[key][0]
                search.set_interval(low, high=high)
                any_next = any_next or search.has_next()
                if accept:
                    self._result[key] = result

            if any_next:
                # cur_state is 'bin_<N>'; [4:] extracts N
                return False, MeasInfo(f'bin_{int(cur_state[4:]) + 1}', self._result)
            else:
                return True, MeasInfo('done', self._result)
        else:
            self._result, done = self.process_init(cur_info, sim_results)
            if done:
                return True, MeasInfo('done', self._result)
            else:
                return False, MeasInfo('bin_0', self._result)

    def get_adj_interval(self, adj_name: str, adj_sign: bool, adj_values: np.ndarray,
                         diff: np.ndarray, accept_mode: AcceptMode = AcceptMode.POSITIVE
                         ) -> Tuple[int, bool, float, float]:
        """Pick a result and the next search interval from one swept simulation.

        Parameters
        ----------
        adj_name : str
            the adjustable parameter name.
        adj_sign : bool
            direction control used when ``diff`` does not cross zero.  With ``adj_sign``
            False, a positive ``diff`` moves the interval past the last swept value; True
            reverses this.
        adj_values : np.ndarray
            the simulated parameter values.  Assumed sorted in increasing order (TODO confirm).
        diff : np.ndarray
            error metric for each entry of ``adj_values``; the target is ``diff == 0``.
        accept_mode : AcceptMode
            which sign of ``diff`` an inexact result may have.  It controls which side of a
            crossing is picked (``BOTH`` picks the smaller ``|diff|``), and whether a
            non-crossing result is accepted.

        Returns
        -------
        res_idx : int
            index of the chosen result in ``adj_values``.
        accept : bool
            True if the result at ``res_idx`` should be accepted.
        low : float
            the new interval lower bound.
        high : float
            the new interval upper bound.  Equals ``low`` when an exact zero was found.
        """
        num_values = adj_values.size
        # find sign change (zero counts as non-negative)
        idx_arr = np.nonzero(np.diff((diff >= 0).astype(int)))[0]
        # idx_arr[0] will be the index before sign change (if it exists)
        if idx_arr.size == 0:
            # No sign change.  diff[0] == 0 does NOT imply every entry is 0 (e.g. [0, 1, 2]),
            # so look for exact zeros explicitly.
            zero_idx = np.flatnonzero(diff == 0)
            if zero_idx.size > 0:
                # Exact zero(s) found: pick the middle one.  For an all-zero sweep this is
                # the middle point, num_values // 2.
                res_idx = int(zero_idx[zero_idx.size // 2])
                low = high = adj_values[res_idx]
                accept = True
            else:
                test_val = diff[0]
                accept = bool(((test_val > 0) and AcceptMode.POSITIVE in accept_mode) or
                              ((test_val < 0) and AcceptMode.NEGATIVE in accept_mode))
                cur_bnds = self.get_bound(adj_name)
                if (test_val > 0) ^ adj_sign:
                    # increase parameter
                    low = adj_values[num_values - 1]
                    high = cur_bnds[1]
                    res_idx = num_values - 1
                else:
                    # decrease parameter
                    low = cur_bnds[0]
                    high = adj_values[0]
                    res_idx = 0
        else:
            res_idx = int(idx_arr[0])
            accept = True
            test_val = diff[res_idx + 1]
            if test_val == 0:
                # exact zero right after a negative -> non-negative crossing
                res_idx += 1
                low = high = adj_values[res_idx]
            elif diff[res_idx] == 0:
                # exact zero right before a non-negative -> negative crossing
                low = high = adj_values[res_idx]
            else:
                low = adj_values[res_idx]
                high = adj_values[res_idx + 1]
                if ((accept_mode is AcceptMode.BOTH and
                     np.abs(test_val) < np.abs(diff[res_idx])) or
                        (accept_mode is AcceptMode.POSITIVE and test_val > 0) or
                        (accept_mode is AcceptMode.NEGATIVE and test_val < 0)):
                    res_idx += 1

        return res_idx, accept, low, high

    def log_result(self, state: str, new_result: Mapping[str, Any]) -> None:
        """Log the FSM state, every search interval, and ``new_result``."""
        fmt = '{:.5g}'
        msg_list = [f'state = {state}']
        for name, val in self._table.items():
            search = val[0]
            msg_list.append(f'{name} bnd = [{fmt.format(search.low)}, {fmt.format(search.high)})')
        msg_list.append('result:')
        msg_list.append(pprint.pformat(new_result, width=100))
        self.log('\n'.join(msg_list))
[docs]def _get(key: str, table1: Mapping[str, Any], table2: Mapping[str, Any], *args: Any) -> Any: if key in table1: return table1[key] if key in table2: return table2[key] if not args: raise ValueError(f'Cannot find key: {key}') return args[0]