Code Documentation: Models

mastml.models Module

Module for constructing models for use in MAST-ML.

SklearnModel:

Class that wraps scikit-learn models to provide MAST-ML functionality. Passing the model name as a string, together with keyword arguments for the model parameters, constructs the model. This class also supports constructing XGBoost models and Keras neural network models, the latter via Keras’ keras.wrappers.scikit_learn.KerasRegressor class.
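The "model name plus keyword arguments" construction described above can be sketched in plain Python. The registry, `MeanRegressor`, and `NamedModelWrapper` below are hypothetical stand-ins, not MAST-ML's implementation. Per the description above, the real `SklearnModel` resolves scikit-learn, XGBoost, and Keras models, which this sketch does not attempt.

```python
# Hypothetical sketch of the "name string + kwargs" construction pattern.
# MeanRegressor, MODEL_REGISTRY, and NamedModelWrapper are illustrative only;
# they are not part of MAST-ML or scikit-learn.
from statistics import mean


class MeanRegressor:
    """Toy estimator that predicts the training-set mean plus a fixed offset."""

    def __init__(self, offset=0.0):
        self.offset = offset
        self.mean_ = None

    def fit(self, X, y):
        self.mean_ = mean(y)
        return self

    def predict(self, X):
        return [self.mean_ + self.offset for _ in X]


# Hypothetical name -> class lookup table. A real wrapper would look names up
# in the libraries it supports instead.
MODEL_REGISTRY = {"MeanRegressor": MeanRegressor}


class NamedModelWrapper:
    """Builds the wrapped estimator from its name and keyword arguments."""

    def __init__(self, model, **kwargs):
        if model not in MODEL_REGISTRY:
            raise ValueError(f"Unknown model name: {model!r}")
        self.model = MODEL_REGISTRY[model](**kwargs)

    def fit(self, X, y):
        self.model.fit(X, y)
        return self

    def predict(self, X):
        return self.model.predict(X)


wrapper = NamedModelWrapper("MeanRegressor", offset=1.0)
wrapper.fit([[0], [1], [2]], [1.0, 2.0, 3.0])
print(wrapper.predict([[5], [6]]))  # [3.0, 3.0]
```

Keeping the constructor signature as `(model, **kwargs)` mirrors the listed `SklearnModel(model, **kwargs)` signature. With MAST-ML itself, a call would presumably look like `SklearnModel(model='LinearRegression', fit_intercept=True)`, but check the `SklearnModel` API page for the exact accepted names.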

EnsembleModel:

Class that constructs a model which is an ensemble of many base models (sometimes called weak learners). This class supports construction of ensembles of most scikit-learn regression models as well as ensembles of neural networks that are made via Keras’ keras.wrappers.scikit_learn.KerasRegressor class.
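The ensemble idea above can be illustrated with a minimal bagging sketch in plain Python. Each base learner is fit on a bootstrap resample of the training data, and the ensemble prediction is the mean of the members' predictions. The one-dimensional least-squares `LineFit` learner and the helper names are illustrative assumptions, not `EnsembleModel`'s implementation, which wraps scikit-learn or Keras base models. The spread across members is also returned, because it is a common, uncalibrated uncertainty signal.

```python
# Minimal bagging sketch with a 1-D least-squares base learner.
# LineFit, fit_bagged, and predict_bagged are hypothetical names.
import random
from statistics import mean, pstdev


class LineFit:
    """Ordinary least-squares fit of y = a + b*x for scalar x (illustrative base learner)."""

    def fit(self, xs, ys):
        x_bar, y_bar = mean(xs), mean(ys)
        sxx = sum((x - x_bar) ** 2 for x in xs)
        sxy = sum((x - x_bar) * (y - y_bar) for x, y in zip(xs, ys))
        # A bootstrap sample can repeat a single x value; fall back to a flat line.
        self.b = sxy / sxx if sxx else 0.0
        self.a = y_bar - self.b * x_bar
        return self

    def predict(self, x):
        return self.a + self.b * x


def fit_bagged(xs, ys, n_estimators=10, seed=0):
    """Fit n_estimators base learners, each on its own bootstrap resample."""
    rng = random.Random(seed)
    n = len(xs)
    members = []
    for _ in range(n_estimators):
        idx = [rng.randrange(n) for _ in range(n)]  # draw n rows with replacement
        members.append(LineFit().fit([xs[i] for i in idx], [ys[i] for i in idx]))
    return members


def predict_bagged(members, x):
    """Return the mean member prediction and the spread across members."""
    preds = [m.predict(x) for m in members]
    return mean(preds), pstdev(preds)


# Noise-free data on the line y = 2x + 1.
xs = [float(i) for i in range(20)]
ys = [2.0 * x + 1.0 for x in xs]
members = fit_bagged(xs, ys, n_estimators=5)
mu, sigma = predict_bagged(members, 10.0)
```

On this noise-free line every member recovers y = 2x + 1, so `mu` is 21 at x = 10 (up to rounding) and `sigma` is essentially zero. On noisy data the members disagree and `sigma` grows.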

CrabNetModel:

Class that provides a PyTorch-based implementation of the CrabNet regressor model, based on the following work: Wang, A., Kauwe, S., Murdock, R., Sparks, T. “Compositionally restricted attention-based network for materials property predictions”, npj Computational Materials (2021) (https://www.nature.com/articles/s41524-021-00545-1)
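CrabNet is described above as an attention-based network. To illustrate the attention mechanism only, the sketch below computes scaled dot-product self-attention over a toy composition in which each element is a token with a made-up vector. It omits learned query/key/value projections, multiple heads, and everything else specific to CrabNet. See the cited paper for the actual architecture.

```python
# Generic scaled dot-product self-attention in plain Python.
# Not CrabNet's code: there are no learned projections, and the
# element vectors are hypothetical numbers.
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(tokens):
    """Self-attention with queries = keys = values = tokens."""
    d = len(tokens[0])
    outputs, weights = [], []
    for q in tokens:
        # Similarity of this query to every key, scaled by sqrt(d).
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in tokens]
        w = softmax(scores)
        weights.append(w)
        # Each output is the attention-weighted average of the value vectors.
        outputs.append([sum(w[j] * tokens[j][i] for j in range(len(tokens))) for i in range(d)])
    return outputs, weights


# Hypothetical 2-D vectors for the two elements of a binary compound.
tokens = [[1.0, 0.0], [0.0, 1.0]]
outputs, weights = self_attention(tokens)
```

Each row of `weights` sums to 1, so every element's output vector is a context-dependent mixture of all elements in the composition.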

Classes

BaggingRegressor([estimator, n_estimators, ...])

A bagging regressor.

BaseEstimator()

Base class for all estimators in scikit-learn.

CrabNetModel(composition_column, epochs, ...)

Implementation of the PyTorch-based CrabNet regressor model, based on Wang et al., npj Computational Materials (2021).

EnsembleModel(model, n_estimators, **kwargs)

Class used to construct ensemble models with a particular number and type of weak learner (base model).

GaussianProcessRegressor([kernel, alpha, ...])

Gaussian process regression (GPR).

HostedModel(container_name)

KAN([width, grid, k, noise_scale, ...])

Kolmogorov-Arnold Network (KAN) class.

KANLayer([in_dim, out_dim, num, k, ...])

Kolmogorov-Arnold Network layer class.

KANModel(width[, grid, k, steps, seed, ...])

Implementation of Kolmogorov-Arnold Networks (KANs) from the following work:

LBFGS(params[, lr, max_iter, max_eval, ...])

Implements the L-BFGS algorithm.

LinearRegression(*[, fit_intercept, copy_X, ...])

Ordinary least squares Linear Regression.

SklearnModel(model, **kwargs)

Class to wrap any sklearn estimator and provide additional DataFrame functionality.

SourceNN(source_arch, nn_params[, val_size, ...])

Symbolic_KANLayer([in_dim, out_dim, device])

Symbolic KAN layer class.

Transfer(prefit_path[, nn_params, val_size, ...])

TransformerMixin()

Mixin class for all transformers in scikit-learn.

model_wrapper(model)

Wrapper for a PyTorch model that adds a predict method.

tqdm(*_, **__)

Decorate an iterable object, returning an iterator which acts exactly like the original iterable, but prints a dynamically updating progress bar every time a value is requested.
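The KAN, KANLayer, and KANModel entries in the class list refer to Kolmogorov-Arnold Networks. In such a network, every input-output connection of a layer carries its own learnable univariate function, and each output is the sum of those functions applied to the inputs: y_j = Σ_i φ_{j,i}(x_i). The sketch below shows only that forward pass. Fixed piecewise-linear functions on a grid stand in for the learnable B-splines of the original KAN formulation. `PiecewiseLinear` and `ToyKANLayer` are hypothetical names, and the sketch performs no training.

```python
# Forward pass of a toy KAN-style layer: y_j = sum_i phi_{j,i}(x_i).
# Each phi is a piecewise-linear function on a fixed grid, a simplified
# stand-in for B-splines. The grid values are made up for illustration.
from bisect import bisect_right


class PiecewiseLinear:
    """Univariate function defined by its values at sorted grid points."""

    def __init__(self, grid, values):
        assert len(grid) == len(values) >= 2
        self.grid, self.values = grid, values

    def __call__(self, x):
        g, v = self.grid, self.values
        # Clamp to the grid range; the sketch does not extrapolate.
        if x <= g[0]:
            return v[0]
        if x >= g[-1]:
            return v[-1]
        k = bisect_right(g, x) - 1  # index of the interval [g[k], g[k+1])
        t = (x - g[k]) / (g[k + 1] - g[k])
        return (1 - t) * v[k] + t * v[k + 1]


class ToyKANLayer:
    """phis[j][i] is the edge function from input i to output j."""

    def __init__(self, phis):
        self.phis = phis

    def __call__(self, x):
        return [sum(phi(xi) for phi, xi in zip(row, x)) for row in self.phis]


grid = [-1.0, 0.0, 1.0]
square_ish = PiecewiseLinear(grid, [1.0, 0.0, 1.0])  # coarse approximation of x**2
identity = PiecewiseLinear(grid, [-1.0, 0.0, 1.0])   # equals x on [-1, 1]
layer = ToyKANLayer([[square_ish, identity]])        # 2 inputs -> 1 output
print(layer([0.5, -0.5]))  # [0.0]
```

Unlike a multilayer perceptron, which learns scalar weights on edges and applies fixed activations at nodes, a KAN learns the edge functions themselves. Training would adjust the grid values, which this sketch holds fixed.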

Class Inheritance Diagram

[Inheritance diagram: SklearnModel, EnsembleModel, and CrabNetModel inherit from scikit-learn's BaseEstimator and TransformerMixin; KANModel inherits from KAN, which inherits from PyTorch's Module; HostedModel, SourceNN, Transfer, and model_wrapper are shown without base classes.]