"""eispac.core.eisfittemplate: the EISFitTemplate multi-Gaussian fit template."""

__all__ = ['EISFitTemplate']

import sys
import copy
import pathlib
import warnings

import numpy as np
import h5py

# tomllib joined the standard library in Python 3.11; older interpreters use
# the API-compatible third-party backport ``tomli`` under the same name.
if sys.version_info >= (3, 11):
    import tomllib
else:
    import tomli as tomllib

class EISFitTemplate:
    """Multigaussian fitting template for use with MPFIT and `~eispac.core.fit_spectra`

    Parameters
    ----------
    filename : str or `pathlib.Path`, optional
        Path to the template file. Default is "unknown_file". Note: passing
        a filename directly to EISFitTemplate will NOT automatically load the
        file; please use `~eispac.core.read_template` instead.
    template : dict, optional
        Dictionary of template parameters. Valid keys include:

        * ``n_gauss`` (int) - number of Gaussian components
        * ``n_poly`` (int) - number of background polynomial terms. Common
          values are: 0 (no background), 1 (constant), and 2 (linear).
        * ``line_ids`` (array_like) - strings giving the line identification
          for each Gaussian component, for example "Fe XII 195.119". If not
          specified, placeholder values of "unknown I {INITIAL CENTROID VALUE}"
          will be used.
        * ``wmin`` (float) - min wavelength value of data to use for fitting
        * ``wmax`` (float) - max wavelength value of data to use for fitting
    parinfo : list or dict, optional
        Either a list (or tuple) of dicts or a dict of lists giving the
        initial fitting parameters and constraints. If given a `list`, each
        entry must be a dict with the correct keys for a single fit parameter.
        If given a `dict`, each key must contain an array or list with one
        entry per parameter. The order of parameters is assumed to be sets of
        [PEAK, CENTROID, WIDTH] for each Gaussian component followed by any
        coefficients for the background polynomial, starting with the LOWEST
        (constant) order term first. Valid keys include:

        * ``value`` (float) - initial parameter guess
        * ``fixed`` (0 or 1) - if set to 1, the parameter is not fit and the
          initial value is used as-is
        * ``limited`` (two-element array_like) - if set to 1 in the
          first/second element, a limit is applied to the parameter on the
          lower/upper side
        * ``limits`` (two-element array_like) - values of the limits on the
          lower/upper side. "limited" and "limits" must be given together.
        * ``tied`` (str) - string defining a fixed relationship between this
          parameter and one or more others. For example, "p[0] / 2" would
          force a parameter to ALWAYS be exactly half the value of the first
          parameter.

        Additional keys are available, please see the MPFIT documentation in
        the EISPAC user's guide for more details.
    kwargs : value or array_like, optional
        For convenience, users can also pass in any "template" or "parinfo"
        value or array as a separate keyword. Any values defined in this way
        will take priority and overwrite any previous values for that key.

    Attributes
    ----------
    template : dict
        Dictionary of template parameters
    parinfo : list of dicts
        List of parameter constraint dicts
    """

    # Keyword arguments copied into .template
    _TEMPLATE_KEYS = ('component', 'n_gauss', 'n_poly', 'line_ids', 'wmin',
                      'wmax', 'data_e', 'data_x', 'data_y', 'fit', 'fit_back',
                      'fit_gauss', 'order')

    # Keyword arguments (and dict-of-lists keys) copied into .parinfo
    _PARINFO_KEYS = ('value', 'fixed', 'limited', 'limits', 'tied', 'parname',
                     'step', 'mpmaxstep', 'mpside', 'mpprint')

    # Numeric scalar types accepted for template values. NumPy scalars are
    # listed explicitly because e.g. np.int64 and np.float32 are NOT
    # subclasses of the builtin int/float.
    _NUMBER_TYPES = (int, float, np.integer, np.floating)

    # Placeholder ID for Gaussian components without a known line ID
    _UNKNOWN_ID = 'unknown I 001.000'

    def __init__(self, filename=None, template=None, parinfo=None, **kwargs):
        if filename is None:
            self.filename_temp = 'unknown_file'
        else:
            self.filename_temp = str(filename)

        self.template = dict()
        if isinstance(template, dict):
            # Deep copy so the caller's dict is never modified, and ensure
            # that all keys are lower case
            self.template = {KEY.lower(): VAL
                             for KEY, VAL in copy.deepcopy(template).items()}
        elif template is not None:
            print('Error: Invalid datatype for "template". Please input a '
                  +'dictionary. Initializing an empty dict using the other input '
                  +'parameters.', file=sys.stderr)

        # Starting from an empty dict lets parinfo keys given only as kwargs
        # (e.g. value=[...]) still build a dict-of-lists parinfo below
        self.parinfo = dict()
        if isinstance(parinfo, dict):
            self.parinfo = copy.deepcopy(parinfo)
        elif isinstance(parinfo, (list, tuple)):
            # Copy each entry separately so repeated references to one dict
            # (e.g. [d]*3) become independent parameters, and always store a
            # list so entries can be replaced in place during validation
            self.parinfo = [copy.deepcopy(entry) for entry in parinfo]
        elif parinfo is not None:
            print('Error: Invalid datatype for "parinfo". Please input a '
                  +'list of dicts or a dict of lists with the correct keys. '
                  +'Initializing an empty list of dicts using the other input '
                  +'parameters.', file=sys.stderr)

        # Copy over kwargs (if any) to the correct dict or list
        kwargs = {KEY.lower(): VAL for KEY, VAL in kwargs.items()}
        for KEY, VAL in kwargs.items():
            if KEY in self._TEMPLATE_KEYS:
                self.template[KEY] = copy.deepcopy(VAL)
            elif KEY in self._PARINFO_KEYS:
                if isinstance(self.parinfo, dict):
                    self.parinfo[KEY] = copy.deepcopy(VAL)
                else:
                    # List of dicts: update as many parameters as both allow
                    for p in range(min(len(VAL), len(self.parinfo))):
                        self.parinfo[p][KEY] = copy.deepcopy(VAL[p])
            else:
                print(f'Warning: "{KEY}" is not a valid template or parinfo '
                      +f'keyword and will be ignored. Please see the docs '
                      +f'for a list of supported inputs.', file=sys.stderr)

        # Initialize values of n_gauss and n_poly (if not given)
        if 'n_gauss' not in self.template or 'n_poly' not in self.template:
            if len(self.parinfo) > 0:
                # If parinfo is defined, guess from the number of parameters
                if isinstance(self.parinfo, dict):
                    # Alternative dict of arrays: longest array wins
                    n_params = max([1] + [len(VAL) for VAL in self.parinfo.values()])
                else:
                    # Standard list of dicts
                    n_params = len(self.parinfo)
                self.template.setdefault('n_gauss', n_params // 3)
                self.template.setdefault('n_poly', n_params % 3)
            else:
                # If neither template nor parinfo are defined, use defaults
                self.template.setdefault('n_gauss', 1)
                self.template.setdefault('n_poly', 1)

        # If parinfo is not given, initialize minimal parinfo list of dicts
        if len(self.parinfo) == 0:
            n_gauss = int(max(1, self.template.get('n_gauss', 1)))
            n_poly = int(max(0, self.template.get('n_poly', 1)))
            gauss_pars = [{'value': 1.0, 'limited': np.array([1, 0], dtype='int16')}
                          for _ in range(3*n_gauss)]
            poly_pars = [{'value': 1.0, 'limited': np.array([0, 0], dtype='int16')}
                         for _ in range(n_poly)]
            self.parinfo = gauss_pars + poly_pars

        # Finally, validate the actual values and fix format issues
        # Note: parinfo is validated first so that it is always a list of
        # dicts when the template values of n_gauss and n_poly are checked
        self._validate_parinfo_list()
        self._validate_template_dict()

    def _validate_parinfo_list(self):
        """Validate .parinfo and convert it into a complete list of dicts

        A dict-of-lists parinfo is first converted to a list of
        per-parameter dicts. Every parameter dict then gets lower-case keys,
        missing 'fixed', 'limited', 'limits', 'tied' and 'value' entries are
        filled with defaults, and 'limited'/'limits' entries that are not
        two elements long are replaced by the defaults (with one warning).
        """
        # Default dict for a single parameter
        default_parinfo = {'fixed': 0,
                           'limited': np.array([1, 0], dtype='int16'),
                           'limits': np.zeros(2),
                           'tied': np.array(' ', dtype='<U16'),
                           'value': 1.0}

        if isinstance(self.parinfo, dict):
            self.parinfo = self._parinfo_dict_to_list(self.parinfo, default_parinfo)

        # Check for missing keys expected by eispac and fill with defaults
        print_limit_warning = False
        for p, entry in enumerate(self.parinfo):
            # First, ensure all keys are lower case
            entry = {KEY.lower(): VAL for KEY, VAL in entry.items()}
            for KEY, VALUE in default_parinfo.items():
                if KEY not in entry:
                    # Add missing keys that are required for printing details
                    entry[KEY] = copy.deepcopy(VALUE)
                elif (KEY in ('limited', 'limits')
                      and not self._has_two_elements(entry[KEY])):
                    # Overwrite two-element keys given with the wrong shape
                    entry[KEY] = copy.deepcopy(VALUE)
                    print_limit_warning = True
            self.parinfo[p] = entry

        if print_limit_warning:
            print(f'Warning: incorrect length of "limited" and/or "limits" '
                  +f'in found parinfo. Both "limited" and "limits" should be '
                  +f'two-element lists or arrays. Invalid inputs have been '
                  +f'replaced with default values.', file=sys.stderr)

    def _parinfo_dict_to_list(self, parinfo_dict, default_parinfo):
        """Convert a dict-of-lists parinfo into a list of per-parameter dicts

        Values that cannot be converted are skipped with a warning (and are
        later filled with defaults by _validate_parinfo_list).
        """
        # Ensure all keys are lower case
        input_parinfo = {KEY.lower(): VAL for KEY, VAL in parinfo_dict.items()}

        if 'value' in input_parinfo:
            # Read from the lower-cased dict; the raw keys may be e.g. 'VALUE'
            n_params = len(input_parinfo['value'])
        else:
            n_params = max([0] + [len(VAL) for VAL in input_parinfo.values()])
            print('Warning: No initial values given in parinfo! Please '
                  +'define a "value" key for each parameter', file=sys.stderr)

        # Warn once per unsupported key (not once per parameter)
        for KEY in input_parinfo:
            if KEY not in self._PARINFO_KEYS:
                print(f'Warning: {KEY} is not a valid parinfo key '
                      +'and will be ignored.', file=sys.stderr)
        valid_items = [(KEY, VAL) for KEY, VAL in input_parinfo.items()
                       if KEY in self._PARINFO_KEYS]

        par_list = []
        for p in range(n_params):
            info_dict = {}
            for KEY, VALUES in valid_items:
                try:
                    if KEY in ('tied', 'parname'):
                        # Strings: keep at least the default '<U16' width but
                        # never truncate longer text such as tie expressions
                        text = str(VALUES[p])
                        info_dict[KEY] = np.array(text, dtype=f'<U{max(16, len(text))}')
                    elif KEY in ('limited', 'limits'):
                        # Two-element values, cast to the default dtype
                        pair = copy.deepcopy(default_parinfo[KEY])
                        pair[0] = VALUES[p][0]
                        pair[1] = VALUES[p][1]
                        info_dict[KEY] = pair
                    elif KEY in ('fixed', 'mpside', 'mpprint'):
                        info_dict[KEY] = int(VALUES[p])
                    else:  # 'value', 'step', 'mpmaxstep'
                        info_dict[KEY] = float(VALUES[p])
                except Exception:
                    # If there is an issue, give a warning and continue
                    print(f'Warning: There was an issue loading the {KEY} '
                          +f'key for parameter index {p}. Skipping for now; '
                          +f'please check the format and datatype.', file=sys.stderr)
            par_list.append(info_dict)
        return par_list

    @staticmethod
    def _has_two_elements(value):
        """Return True if ``value`` has a length of exactly two"""
        try:
            return len(value) == 2
        except TypeError:
            # Scalars (and 0-d arrays) have no length
            return False

    def _validate_template_dict(self):
        """Validate the keys and values of .template

        Invalid scalar values are replaced with defaults (with an error
        message). If n_gauss and n_poly are invalid or disagree with the
        length of .parinfo, they are re-guessed from .parinfo. The
        'line_ids' array is resized and normalized, and 'fit' is rebuilt
        from the parinfo initial values.
        """
        # Ensure that all keys are lower case
        self.template = {KEY.lower(): VAL for KEY, VAL in self.template.items()}

        # Validate keys actually used by eispac
        n_gauss = self.template.get('n_gauss', -1)
        if isinstance(n_gauss, self._NUMBER_TYPES) and int(n_gauss) > 0:
            self.template['n_gauss'] = int(n_gauss)
        else:
            self.template['n_gauss'] = -1
            print('Error: Invalid value for n_gauss. '
                  +'Please input a positive integer', file=sys.stderr)

        n_poly = self.template.get('n_poly', -1)
        if isinstance(n_poly, self._NUMBER_TYPES) and int(n_poly) >= 0:
            self.template['n_poly'] = int(n_poly)
        else:
            self.template['n_poly'] = -1
            print('Error: Invalid value for n_poly. '
                  +'Please input an integer >= 0.', file=sys.stderr)

        wmin = self.template.get('wmin', 170.0)
        if isinstance(wmin, self._NUMBER_TYPES):
            self.template['wmin'] = float(wmin)
        else:
            self.template['wmin'] = 170.0
            print('Error: Invalid datatype for wmin. Please input a float. '
                  +'Using a default value of 170.', file=sys.stderr)

        wmax = self.template.get('wmax', 292.0)
        if isinstance(wmax, self._NUMBER_TYPES):
            self.template['wmax'] = float(wmax)
        else:
            self.template['wmax'] = 292.0
            print('Error: Invalid datatype for wmax. Please input a float. '
                  +'Using a default value of 292.', file=sys.stderr)

        comp_num = self.template.get('component', 1)
        if isinstance(comp_num, self._NUMBER_TYPES):
            self.template['component'] = int(comp_num)
        else:
            self.template['component'] = 1
            print('Error: Invalid datatype for component. Please input an '
                  +'integer. Using a default value of 1.', file=sys.stderr)

        # Check number of parameters and components. Invalid (-1) values are
        # always re-guessed, even if they happen to match the parinfo length
        n_params = len(self.parinfo)
        n_expected = 3*self.template['n_gauss'] + self.template['n_poly']
        if (n_params != n_expected or self.template['n_gauss'] < 0
                or self.template['n_poly'] < 0):
            # When mismatched, assume input parinfo is correct
            # Then, update/overwrite parts of the template as needed
            n_gauss = n_params // 3
            n_poly = n_params % 3
            self.template['n_gauss'] = n_gauss
            self.template['n_poly'] = n_poly
            self.template['line_ids'] = np.full(n_gauss, self._UNKNOWN_ID, dtype='<U32')
            print(f'Warning: the values of n_gauss and n_poly do not match the '
                  +f'length of parinfo. Defaulting to the best guess values of '
                  +f'n_gauss = {n_gauss} and n_poly = {n_poly}. Template dict '
                  +f'values and line_ids have been overwritten to match.',
                  file=sys.stderr)

        # Check for line_ids and convert as needed
        n_gauss = self.template['n_gauss']
        n_poly = self.template['n_poly']
        if 'line_ids' not in self.template:
            line_ids = np.full(n_gauss, self._UNKNOWN_ID, dtype='<U32')
        else:
            # Force bare values, 0-d arrays and nested sequences into a flat
            # 1-D array of strings (actual values are checked below)
            line_ids = np.atleast_1d(
                np.asarray(self.template['line_ids'], dtype='<U24')).ravel()

        # Trim or expand line_ids array to have the correct length
        if len(line_ids) != n_gauss:
            resized = np.full(n_gauss, self._UNKNOWN_ID, dtype='<U32')
            n_keep = min(n_gauss, len(line_ids))
            resized[:n_keep] = line_ids[:n_keep]
            line_ids = resized

        # Validate format and values of line_id strings
        for g in range(n_gauss):
            # Initial centroid of Gaussian g (parameters are PEAK, CENTROID, WIDTH)
            parinfo_wave_str = f"{self.parinfo[3*g+1]['value']:07.3f}"
            input_id = str(line_ids[g])
            this_id = input_id.strip().lower()
            if not this_id or this_id.startswith(('unknown', 'no line')):
                # Blank or unknown line: just update the wavelength substring
                line_ids[g] = f"unknown I {parinfo_wave_str}"
                continue

            clean_id, warn_bad_id = self._clean_line_id(input_id, parinfo_wave_str)
            line_ids[g] = clean_id
            if warn_bad_id:
                print(f'Warning: "{input_id}" is not a valid or complete '
                      +f'line ID string. Using a replacement ID of '
                      +f'"{clean_id}" instead.', file=sys.stderr)
        self.template['line_ids'] = line_ids

        # Extract initial values and update the 'fit' array
        # Note: this is important for using the scale_guess() function during
        #       the fitting process.
        n_total = 3*n_gauss + n_poly
        fit = np.zeros(n_total)
        for p in range(n_total):
            fit[p] = self.parinfo[p]['value']
        self.template['fit'] = fit

    @staticmethod
    def _clean_line_id(input_id, default_wave):
        """Normalize a non-blank line ID to "<element> <ION> <wavelength>"

        Returns the cleaned ID string and a flag that is True when any part
        of ``input_id`` was missing or invalid and had to be replaced.
        Any words after the third are ignored.
        """
        split_id = input_id.split()
        warn_bad_id = False

        # Element - can be any string, really
        elem_id = split_id[0]

        # Ionization state - only allow roman numeral characters I, V, X, L
        ion_id = 'I'
        if len(split_id) >= 2 and all(ch in 'IVXL' for ch in split_id[1].upper()):
            ion_id = split_id[1].upper()
        else:
            warn_bad_id = True

        # Wavelength - needs to have a decimal point
        wave_id = default_wave
        if len(split_id) >= 3:
            test_wave = split_id[2]
            if test_wave.isdigit():
                wave_id = test_wave + '.000'
            elif '.' in test_wave:
                wave_id = test_wave
            else:
                warn_bad_id = True
        else:
            warn_bad_id = True

        return ' '.join([elem_id, ion_id, wave_id]), warn_bad_id

    def __repr__(self):
        rows = ['--- EISFitTemplate SUMMARY ---',
                f"filename_temp: {self.filename_temp}",
                f"n_gauss: {self.template['n_gauss']}",
                f"n_poly: {self.template['n_poly']}",
                f"line_ids: {self.template['line_ids']}",
                f"wmin, wmax: {self.template['wmin']}, {self.template['wmax']}",
                '',
                '--- PARAMETER CONSTRAINTS ---',
                f"{'*':>4} {'Value':>16} {'Fixed':>7} "
                +f"{'Limited':>11} {'Limits':>19} {'Tied':>19}"]
        for i, p in enumerate(self.parinfo):
            # Explicit casts so float flags (e.g. fixed=1.0) or NumPy scalars
            # can be printed with the integer/float format codes
            rows.append(f"{f'p[{i}]':>6} {float(p['value']):14.4f} {int(p['fixed']):7d} "
                        +f"{int(p['limited'][0]):5d} {int(p['limited'][1]):5d} "
                        +f"{float(p['limits'][0]):12.4f} {float(p['limits'][1]):12.4f} "
                        +f"{str(p['tied']):>18}")
        return '\n'.join(rows)

    @property
    def central_wave(self):
        """Wavelength value in the center of the template wavelength range

        Note: this is calculated using the current values of 'wmin' and 'wmax'
        contained in the template dict.
        """
        return self.template['wmin'] + (self.template['wmax'] - self.template['wmin']) * 0.5

    @property
    def funcinfo(self):
        """List of dicts specifying each subcomponent function used in the template

        Note: this is generated on-demand using the current template dict and
        the ``get_funcinfo`` class method.
        """
        return self.get_funcinfo(self.template)

    @staticmethod
    def get_funcinfo(template):
        """Generate the ``funcinfo`` dict

        Returns a list of dictionaries where each entry describes the parameters
        for one of the fitting basis functions

        Parameters
        ----------
        template : dict
            Template dict with 'n_gauss', 'n_poly' and 'line_ids' keys

        Returns
        -------
        funcinfo : list
            One 'Gaussian1D' entry per Gaussian component, followed by a single
            'Polynomial1D' entry if n_poly > 0
        """
        funcinfo = []
        for g in range(template['n_gauss']):
            funcinfo.append({'func': 'Gaussian1D',
                             'name': template['line_ids'][g],
                             'n_params': 3})
        if template['n_poly'] > 0:
            funcinfo.append({'func': 'Polynomial1D',
                             'name': 'Background',
                             'n_params': template['n_poly']})
        return funcinfo

    @classmethod
    def read_template(cls, filename):
        """Load an `EISFitTemplate` from an HDF5 or TOML template file

        Parameters
        ----------
        filename : str or `pathlib.Path`
            Path to a HDF5 template file provided with eispac or a user-made
            TOML template file.

        Returns
        -------
        cls : `EISFitTemplate` class instance or None
            Object containing the fit template. None (with a warning) if the
            path has the wrong type, is not an existing file, or does not have
            a .h5, .hdf5 or .toml extension.
        """
        # NOTE: return None here rather than allow h5py to handle
        # exception so that spectral fitting pipeline can error
        # more gracefully
        # FIXME: replace with proper exception handling and logging
        if not isinstance(filename, (str, pathlib.Path)):
            warnings.warn('Error: Template filepath must be either a string or pathlib.Path')
            return None

        filename = pathlib.Path(filename)
        if not filename.is_file():
            warnings.warn(f'Error: Template filepath {filename} does not exist')
            return None

        file_type = filename.suffix.lower()
        if file_type in ('.h5', '.hdf5'):
            # Load a standard format HDF5 file (probably packaged with eispac)
            with h5py.File(filename, 'r') as f_temp:
                # Template
                template = {}
                for key in f_temp['template']:
                    val = f_temp['template/'+key]
                    if key == 'line_ids':
                        val = np.char.decode(val).flatten() # convert bytes to unicode
                    elif len(val) > 1:
                        val = np.array(val)
                    else:
                        val = val[0]
                    template[key] = val

                # Parinfo
                nstr = len(f_temp['parinfo/value'])
                parinfo = []
                for istr in range(nstr):
                    parameter = {}
                    for key in f_temp['parinfo']:
                        val = f_temp['parinfo/'+key][istr]
                        if key == 'tied':
                            val = np.char.decode(val) # convert bytes to unicode
                        parameter[key] = val
                    parinfo.append(parameter)

            # Fix the datatypes of important keys
            template['n_gauss'] = int(template.get('n_gauss', 1))
            template['n_poly'] = int(template.get('n_poly', 1))
            template['wmin'] = float(template.get('wmin', 100))
            template['wmax'] = float(template.get('wmax', 1000))
            template['component'] = int(template.get('component', 1))
            if 'line_ids' in template:
                # Missing IDs are filled with placeholders during validation
                template['line_ids'] = template['line_ids'].astype('<U24')
        elif file_type == '.toml':
            # Load a custom template stored in a TOML file
            with open(filename, 'rb') as f_temp:
                toml_dict = tomllib.load(f_temp)

            # Ensure top-level keys are all lower case
            toml_dict = {KEY.lower(): VALUE for KEY, VALUE in toml_dict.items()}
            template = toml_dict.get('template', None)
            parinfo = toml_dict.get('parinfo', None)
            # note: this parinfo will be an alternative dict of lists
        else:
            warnings.warn(f'Error: Unsupported template file type "{filename.suffix}". '
                          +'Please use an HDF5 (.h5, .hdf5) or TOML (.toml) file.')
            return None

        return cls(filename, template, parinfo)