Saturday, November 6, 2010

User-defined Model and Fit Statistic

An advanced demo for ADASS 2010: a Sherpa script that incorporates a user-defined model and a user-defined fit statistic with priors into an X-ray spectral fit.

#!/usr/bin/env python
import sherpa.astro.ui as sherpa
from sherpa.models import Parameter, ArithmeticModel
from sherpa.stats import truncation_value
import numpy

class MyPowerLaw(ArithmeticModel):
    """
    Class to characterize a power-law function

        f(x) = ampl * (x / ref) ** (-gamma)

    Parameters, in the order they arrive in ``p`` inside ``calc``:
        gamma -- power-law index, bounded to [-10, 10]
        ref   -- reference point, initial value 1.0, always frozen
        ampl  -- amplitude (value of f at x == ref), bounded below by 0
    """

    def __init__(self, name='mypowerlaw'):
        self.gamma = Parameter(name, "gamma", 1.0, min=-10, max=10)
        self.ref = Parameter(name, "ref", 1.0, alwaysfrozen=True)
        self.ampl = Parameter(name, "ampl", 1.0, min=0)

        ArithmeticModel.__init__(self, name,
                                 (self.gamma, self.ref, self.ampl))

    def calc(self, p, x, xhi=None, *args, **kwargs):
        """
        Evaluate the model on a point grid or on a binned grid.

        Params
            `p`    list of ordered parameter values (gamma, ref, ampl).
            `x`    ndarray of domain values: bin midpoints, or lower
                   bin edges when `xhi` is given.
            `xhi`  ndarray of upper bin edges, or None for a point grid.

        returns ndarray of calculated function values: point values when
        `xhi` is None, otherwise the integral over each bin.
        """
        if xhi is None:
            return self.point(p, x)
        return self.integrated(p, x, xhi)

    @staticmethod
    def point(p, x):
        """
        Point version of the model:

            f_i(x_i) = p[2] * (x_i / p[1]) ** (-p[0])

        Params
            `p`  list of ordered parameter values (gamma, ref, ampl).
            `x`  ndarray of bin midpoints.

        returns ndarray of calculated function values at the bin midpoints.
        """
        # Evaluated unconditionally: this method has no `xhi` argument, so
        # there is nothing to branch on here (calc() already chose this path).
        return p[2] * numpy.power(x / p[1], -p[0])

    @staticmethod
    def integrated(p, xlo, xhi):
        """
        Integrated form of the model, from each lower bin edge to the
        matching upper bin edge:

            p[0] != 1:  p[2] * p[1]**p[0]
                        * (xhi_i**(1 - p[0]) - xlo_i**(1 - p[0])) / (1 - p[0])
            p[0] == 1:  p[2] * p[1] * ln(xhi_i / xlo_i)

        Params
            `p`    list of ordered parameter values (gamma, ref, ampl).
            `xlo`  ndarray of lower bin edges.
            `xhi`  ndarray of upper bin edges.

        returns ndarray of integrated function values over the lower and
        upper bin edges.
        """
        # gamma == 1 needs the logarithmic antiderivative; the general
        # formula would divide by (1 - gamma) == 0.
        if p[0] == 1.0:
            return p[2] * p[1] * (numpy.log(xhi) - numpy.log(xlo))

        p1 = numpy.power(xlo, 1.0 - p[0]) / (1.0 - p[0])
        p2 = numpy.power(xhi, 1.0 - p[0]) / (1.0 - p[0])
        # Dividing by ref**(-gamma) is the same as multiplying by ref**gamma.
        return p[2] / numpy.power(p[1], -p[0]) * (p2 - p1)

# <demo> stop

def logNormal(x, mu, sigma):
    """
    Return the natural logarithm of the normal (Gaussian) probability
    density, computed analytically rather than as log(pdf):

        ln f(x; mu, sigma) = -ln(sqrt(2*pi*sigma**2)) - (x - mu)**2 / (2*sigma**2)

    Params
        `x`      point(s) at which to evaluate the log density
        `mu`     mean of the distribution
        `sigma`  standard deviation of the distribution

    returns ln(f_X(x; mu, sigma))
    """
    variance = sigma * sigma
    # Log of the normalization constant 1 / sqrt(2*pi*sigma**2), negated below.
    log_norm = numpy.log(numpy.sqrt(2 * numpy.pi * variance))
    return -log_norm - numpy.square(x - mu) / 2 / variance

# <demo> stop

# stat interface instancemethod
def calc_stat(self, data, model, staterror=None,
              syserror=None, weight=None):
    """
    Calculate the fit statistic value and an array of statistic
    value contributions per bin.

    The statistic is the negative log-posterior (up to a constant):
    the Poisson log-likelihood of `data` given `model`, plus Gaussian
    log-priors on three fit parameters.

    Params
        `data`       ndarray of observed data points
        `model`      ndarray of predicted data points (not modified)
        `staterror`  ndarray of statistical error on observed data points
                     (unused)
        `syserror`   ndarray of systematic error on observed data points
                     (unused)
        `weight`     ndarray of statistic weights (unused)

    Reads from `self` the fit parameters (`gamma`, `beta`, `alpha`) and the
    prior hyperparameters (`mugamma`, `siggamma`, `mubeta`, `sigbeta`,
    `muxi`, `sigxi`), each through its `.val` attribute.

    returns a tuple of the fit statistic value and an array of ones shaped
    like `data`. The array is a placeholder, not real per-bin contributions.
    """
    gamma = self.gamma.val           # nH (abs1.nh)
    beta = self.beta.val             # powerlaw index (p1.gamma)
    xi = numpy.log(self.alpha.val)   # log powerlaw norm (p1.ampl)

    # Gaussian log-priors on each parameter (the norm's prior is in log space).
    phigamma = logNormal(gamma, self.mugamma.val, self.siggamma.val)
    phibeta = logNormal(beta, self.mubeta.val, self.sigbeta.val)
    phixi = logNormal(xi, self.muxi.val, self.sigxi.val)

    # truncation_value is used as an approximation for non-positive model
    # values so that log(model) stays finite. Truncate into a new array so
    # the caller's model array is not modified in place.
    if truncation_value > 0:
        model = numpy.where(model <= 0.0, truncation_value, model)

    # Poisson log-likelihood summed over bins. The -ln(data!) term is
    # omitted because it does not depend on the model parameters.
    likelihood = numpy.sum(-model + data * numpy.log(model))

    # Add the prior to the log-likelihood
    prior = phigamma + phibeta + phixi
    stat = likelihood + prior

    # Reverse sign so that minimizing the statistic maximizes the
    # log-posterior. The ones array is a placeholder per-bin vector.
    return (-stat, numpy.ones_like(data))


# <demo> stop

# Load the spectrum and notice the 0 - 6.0 range (in the current
# analysis units).
sherpa.load_pha("3c273.pi")
sherpa.notice(0, 6.0)

# Register MyPowerLaw as a model type, then build the source expression:
# an xsphabs absorption component multiplied by the user power law.
# NOTE(review): `mypowerlaw`, `abs1`, `p1` and (below) `stat` are never
# assigned in this file. Presumably sherpa's UI layer injects them into the
# calling (__main__) namespace. Confirm before reusing this outside a
# top-level script.
sherpa.add_model(MyPowerLaw)
sherpa.set_source(sherpa.xsphabs.abs1 * mypowerlaw.p1)

# Register calc_stat as the user statistic "stat". calc_stat reads every
# key below as an attribute of the statistic (self.mugamma.val, ...). Some
# keys are hyperparameters (means and widths of the Gaussian priors); the
# others link to the fit parameters.
sherpa.load_user_stat("stat", calc_stat,
                      priors=dict(mugamma=0.017,   # hyperparameters
                                  mubeta=1.26,
                                  muxi=-8.5,
                                  siggamma=1,
                                  sigbeta=1,
                                  sigxi=1,
                                  gamma=abs1.nh,   # fit parameters
                                  beta=p1.gamma,
                                  alpha=p1.ampl))

# Parenthesized single-argument print works under both Python 2 and 3.
print(stat)

sherpa.set_stat(stat)
sherpa.set_method("neldermead")

sherpa.fit()

sherpa.plot_fit_delchi()
# <demo> stop

No comments:

Post a Comment