# Source code for platon.combined_retriever

from __future__ import print_function

import os

import numpy as np
import matplotlib.pyplot as plt
import scipy.interpolate
import emcee
import nestle
import copy

from .transit_depth_calculator import TransitDepthCalculator
from .eclipse_depth_calculator import EclipseDepthCalculator
from .fit_info import FitInfo
from .constants import METRES_TO_UM, M_jup, R_jup, R_sun
from ._params import _UniformParam
from .errors import AtmosphereError
from ._output_writer import write_param_estimates_file
from .TP_profile import Profile

class CombinedRetriever:
    '''Fit atmospheric parameters to transit and/or eclipse depths.

    The class wraps a log-probability function (:meth:`_ln_prob`) around a
    TransitDepthCalculator and an EclipseDepthCalculator, and exposes two
    samplers over it: :meth:`run_emcee` (emcee's EnsembleSampler) and
    :meth:`run_multinest` (nestle's multi-ellipsoid nested sampling).

    The attributes ``last_params`` and ``last_lnprob`` are set by
    :meth:`_ln_prob` every time a model evaluation completes, and are read by
    :meth:`pretty_print` for progress reporting.
    '''

    def pretty_print(self, fit_info):
        '''Format the most recently evaluated parameters as a single line.

        Parameters
        ----------
        fit_info : :class:`.FitInfo` object
            Only ``fit_info.fit_param_names`` is read; entry i names the
            value stored in ``self.last_params[i]``.

        Returns
        -------
        line : str
            Tab-separated ``name=value unit`` entries, preceded by the last
            log probability. Rs, Mp and Rp are divided by R_sun, M_jup and
            R_jup respectively before printing; T is labelled in K.

        Notes
        -----
        Reads ``self.last_params`` and ``self.last_lnprob``, which only exist
        after :meth:`_ln_prob` has completed at least one evaluation.
        '''
        line = "ln_prob={:.2e}\t".format(self.last_lnprob)

        for i, name in enumerate(fit_info.fit_param_names):
            value = self.last_params[i]
            unit = ""

            # Convert to the more readable unit for the few parameters that
            # have one. The names are mutually exclusive.
            if name == "Rs":
                value = value / R_sun
                unit = "R_sun"
            elif name == "Mp":
                value = value / M_jup
                unit = "M_jup"
            elif name == "Rp":
                value = value / R_jup
                unit = "R_jup"
            elif name == "T":
                unit = "K"

            # Temperatures are printed as integers; everything else uses
            # fixed-point unless it is too large to read comfortably.
            if name == "T":
                number_format = "{:4.0f}"
            elif abs(value) < 1e4:
                number_format = "{:.2f}"
            else:
                number_format = "{:.2e}"

            entry_format = "{}=" + number_format + " " + unit + "\t"
            line += entry_format.format(name, value)

        return line

    def _validate_params(self, fit_info, calculator):
        '''Check that the fit configuration is self-consistent.

        Parameters
        ----------
        fit_info : :class:`.FitInfo` object
            Configuration to check. It is deep-copied first, so the caller's
            object is never modified.
        calculator : object
            Its ``_validate_params`` method is called with the values of
            T, None, logZ, CO_ratio and 10**log_cloudtop_P at both limits of
            every uniformly distributed fit parameter.

        Raises
        ------
        ValueError
            If the Mie scattering parameters are inconsistent, if a uniform
            parameter has ``low_lim >= high_lim``, or if its best guess lies
            outside its limits. ``calculator._validate_params`` may raise
            its own errors.
        '''
        # This assumes that the valid parameter space is rectangular, so that
        # the bounds for each parameter can be treated separately. Unfortunately
        # there is no good way to validate Gaussian parameters, which have
        # infinite range.
        fit_info = copy.deepcopy(fit_info)

        if fit_info.all_params["ri"].best_guess is None:
            # Not using Mie scattering
            if fit_info.all_params["log_number_density"].best_guess != -np.inf:
                raise ValueError("log number density must be -inf if not using Mie scattering")
        else:
            if fit_info.all_params["log_scatt_factor"].best_guess != 0:
                raise ValueError("log scattering factor must be 0 if using Mie scattering")

        for name in fit_info.fit_param_names:
            this_param = fit_info.all_params[name]
            if not isinstance(this_param, _UniformParam):
                continue

            # Check the limits themselves first: with inverted limits the
            # range test below can never pass, and its message would hide
            # the real problem.
            if this_param.low_lim >= this_param.high_lim:
                raise ValueError(
                    "low_lim for {} is higher than high_lim".format(name))
            if this_param.best_guess < this_param.low_lim \
               or this_param.best_guess > this_param.high_lim:
                raise ValueError(
                    "Value {} for {} not between low and high limits".format(
                        this_param.best_guess, name))

            # Try both edges of the allowed range. This mutates only the
            # deep copy made above.
            for lim in [this_param.low_lim, this_param.high_lim]:
                this_param.best_guess = lim
                calculator._validate_params(
                    fit_info._get("T"),
                    None,
                    fit_info._get("logZ"),
                    fit_info._get("CO_ratio"),
                    10**fit_info._get("log_cloudtop_P"))

    @staticmethod
    def _gaussian_ln_likelihood(calculated, measured, errors, error_multiple):
        '''Return the Gaussian log likelihood of `measured` given `calculated`.

        The error bars are multiplied by `error_multiple` before use, and
        the normalization term is included.
        '''
        residuals = calculated - measured
        scaled_errors = error_multiple * errors
        return -0.5 * np.sum(
            residuals**2 / scaled_errors**2
            + np.log(2 * np.pi * scaled_errors**2))

    @staticmethod
    def _plot_fit(fig_num, unbinned_wavelengths, unbinned_depths,
                  binned_wavelengths, measured_depths, measured_errors,
                  calculated_depths, ylabel):
        '''Draw the unbinned model, binned model and data on figure `fig_num`.

        Wavelengths are multiplied by METRES_TO_UM before plotting.
        '''
        plt.figure(fig_num)
        plt.plot(METRES_TO_UM * unbinned_wavelengths, unbinned_depths,
                 color='b', label="Calculated (unbinned)")
        plt.errorbar(METRES_TO_UM * binned_wavelengths, measured_depths,
                     yerr=measured_errors, fmt='.', color='k',
                     label="Observed")
        plt.scatter(METRES_TO_UM * binned_wavelengths, calculated_depths,
                    color='r', label="Calculated (binned)")
        plt.legend()
        # Raw string: "\m" is not a valid escape sequence in a normal string.
        plt.xlabel(r"Wavelength ($\mu m$)")
        plt.ylabel(ylabel)
        plt.xscale('log')
        plt.tight_layout()

    @staticmethod
    def _make_calculators(transit_bins, eclipse_bins, include_condensation):
        '''Build the transit and eclipse calculators for the given bins.

        NOTE(review): include_condensation is only forwarded to the transit
        calculator; the eclipse calculator is built with its defaults.
        '''
        transit_calc = TransitDepthCalculator(
            include_condensation=include_condensation)
        transit_calc.change_wavelength_bins(transit_bins)
        eclipse_calc = EclipseDepthCalculator()
        eclipse_calc.change_wavelength_bins(eclipse_bins)
        return transit_calc, eclipse_calc

    def _ln_prob(self, params, transit_calc, eclipse_calc, fit_info,
                 measured_transit_depths, measured_transit_errors,
                 measured_eclipse_depths, measured_eclipse_errors,
                 plot=False):
        '''Return the log posterior probability of a parameter array.

        Parameters
        ----------
        params : array_like
            Values of the fit parameters, interpreted by
            ``fit_info._interpret_param_array``.
        transit_calc, eclipse_calc : calculator objects
            Used to compute model transit and eclipse depths.
        fit_info : :class:`.FitInfo` object
            Supplies limits, priors and the values of fixed parameters.
        measured_transit_depths, measured_transit_errors : array or None
            Transit data. The transit term is skipped if the depths are None.
        measured_eclipse_depths, measured_eclipse_errors : array or None
            Eclipse data. The eclipse term is skipped if the depths are None.
        plot : bool, optional
            If True, plot the model against the data (figure 1 for transit
            depths, figure 2 for eclipse depths).

        Returns
        -------
        lnprob : float
            ``fit_info._ln_prior(params)`` plus the Gaussian log likelihood
            of every supplied data set. -inf is returned if `params` is
            outside the limits, if Rs or Mp is not positive, or if an
            AtmosphereError is raised while computing the model.

        Raises
        ------
        ValueError
            If transit depths are supplied but T is None.

        Notes
        -----
        On a completed evaluation, `params` and the result are stored in
        ``self.last_params`` and ``self.last_lnprob``.
        '''
        if not fit_info._within_limits(params):
            return -np.inf

        params_dict = fit_info._interpret_param_array(params)

        Rp = params_dict["Rp"]
        T = params_dict["T"]
        logZ = params_dict["logZ"]
        CO_ratio = params_dict["CO_ratio"]
        scatt_factor = 10.0**params_dict["log_scatt_factor"]
        scatt_slope = params_dict["scatt_slope"]
        cloudtop_P = 10.0**params_dict["log_cloudtop_P"]
        error_multiple = params_dict["error_multiple"]
        Rs = params_dict["Rs"]
        Mp = params_dict["Mp"]
        T_star = params_dict["T_star"]
        T_spot = params_dict["T_spot"]
        spot_cov_frac = params_dict["spot_cov_frac"]
        frac_scale_height = params_dict["frac_scale_height"]
        number_density = 10.0**params_dict["log_number_density"]
        part_size = 10.0**params_dict["log_part_size"]
        ri = params_dict["ri"]

        if Rs <= 0 or Mp <= 0:
            return -np.inf

        ln_likelihood = 0
        try:
            if measured_transit_depths is not None:
                if T is None:
                    raise ValueError("Must fit for T if using transit depths")

                transit_wavelengths, calculated_transit_depths, info_dict = transit_calc.compute_depths(
                    Rs, Mp, Rp, T, logZ, CO_ratio,
                    scattering_factor=scatt_factor,
                    scattering_slope=scatt_slope,
                    cloudtop_pressure=cloudtop_P, T_star=T_star,
                    T_spot=T_spot, spot_cov_frac=spot_cov_frac,
                    frac_scale_height=frac_scale_height,
                    number_density=number_density,
                    part_size=part_size, ri=ri, full_output=True)
                ln_likelihood += self._gaussian_ln_likelihood(
                    calculated_transit_depths, measured_transit_depths,
                    measured_transit_errors, error_multiple)

                if plot:
                    self._plot_fit(
                        1, info_dict["unbinned_wavelengths"],
                        info_dict["unbinned_depths"], transit_wavelengths,
                        measured_transit_depths, measured_transit_errors,
                        calculated_transit_depths, "Transit depth")

            if measured_eclipse_depths is not None:
                t_p_profile = Profile()
                t_p_profile.set_from_params_dict(
                    params_dict["profile_type"], params_dict)

                if np.any(np.isnan(t_p_profile.temperatures)):
                    raise AtmosphereError("Invalid T/P profile")

                eclipse_wavelengths, calculated_eclipse_depths, info_dict = eclipse_calc.compute_depths(
                    t_p_profile, Rs, Mp, Rp, T_star, logZ, CO_ratio,
                    scattering_factor=scatt_factor,
                    scattering_slope=scatt_slope,
                    cloudtop_pressure=cloudtop_P,
                    T_spot=T_spot, spot_cov_frac=spot_cov_frac,
                    frac_scale_height=frac_scale_height,
                    number_density=number_density,
                    part_size=part_size, ri=ri, full_output=True)
                ln_likelihood += self._gaussian_ln_likelihood(
                    calculated_eclipse_depths, measured_eclipse_depths,
                    measured_eclipse_errors, error_multiple)

                if plot:
                    self._plot_fit(
                        2, info_dict["unbinned_wavelengths"],
                        info_dict["unbinned_eclipse_depths"],
                        eclipse_wavelengths, measured_eclipse_depths,
                        measured_eclipse_errors, calculated_eclipse_depths,
                        "Eclipse depth")

        except AtmosphereError as e:
            # An unphysical atmosphere is treated as an impossible sample
            # rather than a fatal error, so that sampling can continue.
            print(e)
            return -np.inf

        lnprob = fit_info._ln_prior(params) + ln_likelihood
        self.last_params = params
        self.last_lnprob = lnprob
        return lnprob

    def run_emcee(self, transit_bins, transit_depths, transit_errors,
                  eclipse_bins, eclipse_depths, eclipse_errors,
                  fit_info, nwalkers=50,
                  nsteps=1000, include_condensation=True,
                  plot_best=False):
        '''Runs affine-invariant MCMC to retrieve atmospheric parameters.

        Parameters
        ----------
        transit_bins : array_like, shape (N,2)
            Wavelength bins, where transit_bins[i][0] is the start
            wavelength and transit_bins[i][1] is the end wavelength for
            bin i.
        transit_depths : array_like, length N
            Measured transit depths for the specified wavelength bins
        transit_errors : array_like, length N
            Errors on the aforementioned transit depths
        eclipse_bins : array_like, shape (N,2)
            Wavelength bins, where eclipse_bins[i][0] is the start
            wavelength and eclipse_bins[i][1] is the end wavelength for
            bin i.
        eclipse_depths : array_like, length N
            Measured eclipse depths for the specified wavelength bins
        eclipse_errors : array_like, length N
            Errors on the aforementioned eclipse depths
        fit_info : :class:`.FitInfo` object
            Tells the method what parameters to
            freely vary, and in what range those parameters can vary. Also
            sets default values for the fixed parameters.
        nwalkers : int, optional
            Number of walkers to use
        nsteps : int, optional
            Number of steps that the walkers should walk for
        include_condensation : bool, optional
            When determining atmospheric abundances, whether to include
            condensation.
        plot_best : bool, optional
            If True, plots the best fit model with the data

        Returns
        -------
        result : EnsembleSampler object
            This returns emcee's EnsembleSampler object. The most useful
            attributes in this item are result.chain, which is a (W x S x P)
            array where W is the number of walkers, S is the number of steps,
            and P is the number of parameters; and result.lnprobability, a
            (W x S) array of log probabilities. For your convenience, this
            object also contains result.flatchain, which is a (WS x P) array
            where WS = W x S is the number of samples; and
            result.flatlnprobability, an array of length WS

        Notes
        -----
        Progress is printed every 10 steps, and the samples and best-fit
        parameters are passed to write_param_estimates_file once sampling
        has finished.
        '''
        initial_positions = fit_info._generate_rand_param_arrays(nwalkers)
        transit_calc, eclipse_calc = self._make_calculators(
            transit_bins, eclipse_bins, include_condensation)

        self._validate_params(fit_info, transit_calc)

        sampler = emcee.EnsembleSampler(
            nwalkers, fit_info._get_num_fit_params(), self._ln_prob,
            args=(transit_calc, eclipse_calc, fit_info,
                  transit_depths, transit_errors,
                  eclipse_depths, eclipse_errors))

        # The yielded state is not needed; the sampler accumulates the chain.
        for i, result in enumerate(sampler.sample(
                initial_positions, iterations=nsteps)):
            if (i + 1) % 10 == 0:
                print("Step {}: {}".format(i + 1, self.pretty_print(fit_info)))

        best_params_arr = sampler.flatchain[np.argmax(
            sampler.flatlnprobability)]

        write_param_estimates_file(
            sampler.flatchain,
            best_params_arr,
            np.max(sampler.flatlnprobability),
            fit_info.fit_param_names)

        if plot_best:
            self._ln_prob(best_params_arr, transit_calc, eclipse_calc,
                          fit_info,
                          transit_depths, transit_errors,
                          eclipse_depths, eclipse_errors, plot=True)
        return sampler

    def run_multinest(self, transit_bins, transit_depths, transit_errors,
                      eclipse_bins, eclipse_depths, eclipse_errors,
                      fit_info,
                      include_condensation=True, plot_best=False,
                      **nestle_kwargs):
        '''Runs nested sampling to retrieve atmospheric parameters.

        Parameters
        ----------
        transit_bins : array_like, shape (N,2)
            Wavelength bins, where transit_bins[i][0] is the start
            wavelength and transit_bins[i][1] is the end wavelength for
            bin i.
        transit_depths : array_like, length N
            Measured transit depths for the specified wavelength bins
        transit_errors : array_like, length N
            Errors on the aforementioned transit depths
        eclipse_bins : array_like, shape (N,2)
            Wavelength bins, where eclipse_bins[i][0] is the start
            wavelength and eclipse_bins[i][1] is the end wavelength for
            bin i.
        eclipse_depths : array_like, length N
            Measured eclipse depths for the specified wavelength bins
        eclipse_errors : array_like, length N
            Errors on the aforementioned eclipse depths
        fit_info : :class:`.FitInfo` object
            Tells us what parameters to
            freely vary, and in what range those parameters can vary. Also
            sets default values for the fixed parameters.
        include_condensation : bool, optional
            When determining atmospheric abundances, whether to include
            condensation.
        plot_best : bool, optional
            If True, plots the best fit model with the data
        **nestle_kwargs : keyword arguments to pass to nestle's sample method

        Returns
        -------
        result : Result object
            This returns the object returned by nestle.sample The object is
            dictionary-like and has many useful items. For example,
            result.samples (or alternatively, result["samples"]) are the
            parameter values of each sample, result.weights contains the
            weights, and result.logl contains the log likelihoods. result.logz
            is the natural logarithm of the evidence.

        Notes
        -----
        Progress is printed from nestle's callback, and equally weighted
        resampled points plus the best-fit parameters are passed to
        write_param_estimates_file once sampling has finished.
        '''
        transit_calc, eclipse_calc = self._make_calculators(
            transit_bins, eclipse_bins, include_condensation)

        self._validate_params(fit_info, transit_calc)

        def transform_prior(cube):
            # Map each unit-interval coordinate onto its parameter's range.
            return np.array(
                [fit_info._from_unit_interval(i, unit_value)
                 for i, unit_value in enumerate(cube)],
                dtype=float)

        def multinest_ln_prob(cube):
            return self._ln_prob(cube, transit_calc, eclipse_calc, fit_info,
                                 transit_depths, transit_errors,
                                 eclipse_depths, eclipse_errors)

        def callback(callback_info):
            print("Iteration {}: {}".format(
                callback_info["it"], self.pretty_print(fit_info)))

        result = nestle.sample(
            multinest_ln_prob, transform_prior, fit_info._get_num_fit_params(),
            callback=callback, method='multi', **nestle_kwargs)

        best_params_arr = result.samples[np.argmax(result.logl)]

        write_param_estimates_file(
            nestle.resample_equal(result.samples, result.weights),
            best_params_arr,
            np.max(result.logl),
            fit_info.fit_param_names)

        if plot_best:
            self._ln_prob(best_params_arr, transit_calc, eclipse_calc,
                          fit_info,
                          transit_depths, transit_errors,
                          eclipse_depths, eclipse_errors, plot=True)
        return result

    @staticmethod
    def get_default_fit_info(Rs, Mp, Rp, T=None, logZ=0, CO_ratio=0.53,
                             log_cloudtop_P=np.inf, log_scatt_factor=0,
                             scatt_slope=4, error_multiple=1, T_star=None,
                             T_spot=None, spot_cov_frac=None,
                             frac_scale_height=1,
                             log_number_density=-np.inf, log_part_size=-6,
                             ri=None, profile_type='isothermal',
                             **profile_kwargs):
        '''Get a :class:`.FitInfo` object filled with best guess values. A few
        parameters are required, but others can be set to default values if you
        do not want to specify them. All parameters are in SI.

        Parameters
        ----------
        Rs : float
            Stellar radius
        Mp : float
            Planetary mass
        Rp : float
            Planetary radius
        T : float, optional
            Temperature of the isothermal planetary atmosphere. Must be
            given when transit depths are fit.
        logZ : float, optional
            Base-10 logarithm of the metallicity, in solar units
        CO_ratio : float, optional
            C/O atomic ratio in the atmosphere. The solar value is 0.53.
        log_cloudtop_P : float, optional
            Base-10 log of the pressure level (in Pa) below which light cannot
            penetrate. Use np.inf for a cloudless atmosphere.
        log_scatt_factor : float, optional
            Base-10 logarithm of scattering factoring, which make scattering
            that many times as strong. If `scatt_slope` is 4, corresponding to
            Rayleigh scattering, the absorption coefficients are simply
            multiplied by `scattering_factor`. If slope is not 4,
            `scattering_factor` is defined such that the absorption coefficient
            is that many times as strong as Rayleigh scattering at
            the reference wavelength of 1 um. Must be 0 if `ri` is given.
        scatt_slope : float, optional
            Wavelength dependence of scattering, with 4 being Rayleigh.
        error_multiple : float, optional
            All error bars are multiplied by this factor.
        T_star : float, optional
            Effective temperature of the star. This is used to make wavelength
            binning of transit depths more accurate.
        T_spot : float, optional
            Effective temperature of the star spots. This is used to make
            wavelength dependent correction to the observed transit depths.
        spot_cov_frac : float, optional
            The spot covering fraction of the star by area. This is used to
            make wavelength dependent correction to the transit depths.
        frac_scale_height : float, optional
            Forwarded to the depth calculators as `frac_scale_height`.
        log_number_density : float, optional
            Base-10 logarithm of the `number_density` forwarded to the depth
            calculators. Must be -np.inf if `ri` is None.
        log_part_size : float, optional
            Base-10 logarithm of the `part_size` forwarded to the depth
            calculators.
        ri : optional
            Forwarded to the depth calculators as `ri`. None means Mie
            scattering is not used.
        profile_type : str, optional
            Type of temperature-pressure profile, used to build the profile
            when eclipse depths are fit.
        **profile_kwargs : optional
            Additional values stored alongside the parameters above, intended
            for the chosen temperature-pressure profile.

        Returns
        -------
        fit_info : :class:`.FitInfo` object
            This object is used to indicate which parameters to fit for, which
            to fix, and what values all parameters should take.'''
        # locals() must be captured before any other local is defined, so
        # that it contains exactly the arguments of this call.
        all_variables = locals().copy()
        del all_variables["profile_kwargs"]
        all_variables.update(profile_kwargs)

        fit_info = FitInfo(all_variables)
        return fit_info