Created January 4, 2022 03:55
-
-
Save twiecki/86b02349c60385eb6d77793d37bd96a9 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class ModelBuilder:
    """
    Base class that builds, samples and stores a PyMC3 model of field data.

    In this base class :meth:`load_default_config`, :meth:`_clean_data` and
    :meth:`build` are placeholders; subclasses are meant to override them.
    """

    @classmethod
    def load_default_config(cls):
        """
        Return the default configuration dictionary.

        Placeholder in the base class (it returns ``None``). Subclasses should
        return the dict of parameters, including the ``"sampler"`` options
        that :meth:`sample` forwards to ``pm.sample``.
        """
        ...

    def __init__(self, data: pd.DataFrame, trait: str, config: Dict):
        """
        Initialize the model builder.

        Parameters
        ----------
        data: pd.DataFrame
            Dataframe containing the raw, uncleaned field data. It is passed
            through :meth:`_clean_data` before being stored as ``self.data``.
        trait: str
            Name of the trait the model is built for, stored as ``self.trait``.
        config: Dict
            Dictionary of configuration for the models' priors and variables.
            It is deep-copied, so later edits to the caller's dict do not leak
            into this builder.
            See :func:`ModelBuilder.load_default_config` for an example.
        """
        self.config = copy.deepcopy(config)
        # Set before _clean_data so subclass cleaning logic can rely on it.
        self.trait = trait
        self.data, self.coords = self._clean_data(data)
        self.model_type = None  # Attribute for the type of bayesian model
        self.model = None  # Attribute for the pymc3 model
        self.idata = None  # Attribute for the az.InferenceData result
        self.run_id = -1  # Attribute for the bayesian run id (-1 until set)

    def _clean_data(
        self, data: pd.DataFrame
    ) -> Tuple[pd.DataFrame, Dict[str, Union[pd.CategoricalIndex, pd.Series]]]:
        """
        Clean the data passed to the model.

        Parameters
        ----------
        data: pd.DataFrame
            Dataframe containing the raw, uncleaned field data.

        Returns
        -------
        Tuple[pd.DataFrame, Dict[str, Union[pd.CategoricalIndex, pd.Series]]]
            The clean field data, as well as the dictionary of coordinates for
            the model.

        Notes
        -----
        The base implementation only returns a copy of ``data`` (the
        caller's frame is never modified) and an empty coords dict.
        Subclasses are meant to filter out zero counts and define the model
        coords here.
        """
        data = data.copy()
        coords = {}
        return data, coords

    def build(self) -> "pm.Model":
        """
        Build the single field model for the given data, trait and config.

        The base implementation just returns ``self.model``, which is
        ``None`` unless something has assigned it. Subclasses should override
        this to construct and return a ``pm.Model``.
        """
        return self.model

    def _sampler_kwargs(self, overrides: Dict) -> Dict:
        """
        Merge the configured sampler options with call-time overrides.

        Parameters
        ----------
        overrides: Dict
            Keyword arguments given to :meth:`sample`. They take precedence
            over ``config["sampler"]`` instead of clashing with it.

        Returns
        -------
        Dict
            Keyword arguments for ``pm.sample``. ``config["sampler"]`` is
            optional (missing or ``None`` means no options), and
            ``return_inferencedata`` is always ``False`` because
            :meth:`sample` converts the raw trace itself via
            ``az.from_pymc3``.
        """
        options = dict(self.config.get("sampler") or {})
        options.update(overrides)
        options["return_inferencedata"] = False
        return options

    def sample(self, *, model: "pm.Model" = None, **kwargs) -> "az.InferenceData":
        """
        Sample the model and return the trace as InferenceData.

        Also draws posterior-predictive and prior-predictive samples and
        stores the combined result on ``self.idata``.

        Parameters
        ----------
        model : optional
            A model previously created using `self.build()`. If None, use
            ``self.model`` when it is set, otherwise call ``self.build()``.
        **kwargs : dict
            Additional arguments to `pm.sample`; they override matching keys
            in ``config["sampler"]``.

        Returns
        -------
        az.InferenceData
            Posterior, prior and posterior-predictive samples.

        Raises
        ------
        RuntimeError
            If no model is available: none was passed, ``self.model`` is
            unset and ``self.build()`` returned ``None``.
        """
        if model is None:
            model = self.model if self.model is not None else self.build()
        if model is None:
            # Fail clearly here rather than with an opaque error from `with None:`.
            raise RuntimeError(
                "No model to sample: pass `model`, set `self.model`, "
                "or implement `build()` to return a pm.Model."
            )
        with model:
            trace = pm.sample(**self._sampler_kwargs(kwargs))
            ppc = pm.sample_posterior_predictive(trace)
            prior = pm.sample_prior_predictive()
            idata = az.from_pymc3(
                trace=trace,
                prior=prior,
                posterior_predictive=ppc,
                model=model,
            )
        self.idata = idata
        return idata
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.