CrabNetModel
class mastml.models.CrabNetModel(composition_column, epochs, val_frac, drop_unary=False, savepath=None)
Bases: BaseEstimator, TransformerMixin
Implementation of a PyTorch-based CrabNet regressor model, based on the following work:
Wang, A., Kauwe, S., Murdock, R., Sparks, T. “Compositionally restricted attention-based network for materials property predictions”, npj Computational Materials (2021) (https://www.nature.com/articles/s41524-021-00545-1)
The code to run CrabNet was integrated into MAST-ML based on source files available in this repository: https://github.com/anthony-wang/CrabNet
Args:
- composition_column (str): name of the input column containing material compositions.
- epochs (int): number of training epochs.
- val_frac (float): fraction of the input data to use for validation (0 to 1). This can be set to 0 for most MAST-ML runs, because data splitting is done outside of this model.
- drop_unary (bool): whether to drop compositions containing only one element. This should generally be left as False; setting it to True can cause array length mismatches.
- savepath (str): path to save the data to. This is set automatically in the data_splitters routine.
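Because val_frac can be 0 when splitting happens outside the model, the held-out data comes from an outer split made before the model sees any data. In a normal MAST-ML run, the data splitters perform that split and set savepath for you. The minimal sketch below only illustrates the idea using scikit-learn's KFold on a toy DataFrame that has just a composition column.

Several details are assumptions chosen for illustration:
- the column name "composition",
- the formulas and placeholder target values,
- epochs=100.

The CrabNetModel calls are left as comments so the sketch runs without MAST-ML or PyTorch installed.

```python
import pandas as pd
from sklearn.model_selection import KFold

# Placeholder data: a composition column plus made-up target values.
X = pd.DataFrame({"composition": ["Fe2O3", "NaCl", "SiO2", "Al2O3", "TiO2", "MgO"]})
y = pd.Series([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], name="target")

# Arguments following the documented signature. val_frac=0.0 because the
# outer split below already holds out test data; epochs=100 is arbitrary.
model_kwargs = {"composition_column": "composition", "epochs": 100, "val_frac": 0.0}

fold_sizes = []
for train_idx, test_idx in KFold(n_splits=3, shuffle=True, random_state=0).split(X):
    X_train, y_train = X.iloc[train_idx], y.iloc[train_idx]
    X_test = X.iloc[test_idx]
    # With MAST-ML installed, each fold would train and predict roughly as:
    #   from mastml.models import CrabNetModel
    #   model = CrabNetModel(**model_kwargs)
    #   model.fit(X_train, y_train)
    #   y_pred = model.predict(X_test)
    fold_sizes.append((len(X_train), len(X_test)))

print(fold_sizes)  # [(4, 2), (4, 2), (4, 2)]
```

Every row ends up in exactly one test fold, so no extra validation split inside the model is needed to evaluate it.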
Methods:
- get_model: prepares the CrabNet model. Called when the CrabNetModel class is instantiated.
- fit: fits the model parameters to the provided training data.
  Args:
  - X (pd.DataFrame): dataframe of X data; must contain at least a column of material compositions. Featurization is done internally by the model.
  - y (pd.Series): series of y target data.
  Returns:
  - fitted model
- predict: evaluates the model on new data to give predictions.
  Args:
  - X (pd.DataFrame): dataframe of X data; must contain at least a column of material compositions. Featurization is done internally by the model.
  - as_frame (bool): whether to return the predictions as a pandas Series (otherwise a numpy array).
  Returns:
  - series or array of predicted values
Methods Summary
- fit(X, y)
- predict(X[, as_frame])
- set_predict_request(*[, as_frame]): Configure whether metadata should be requested to be passed to the predict method.

Methods Documentation
set_predict_request(*, as_frame: bool | None | str = '$UNCHANGED$') → CrabNetModel
Configure whether metadata should be requested to be passed to the predict method.

Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.

The options for each parameter are:
- True: metadata is requested, and passed to predict if provided. The request is ignored if metadata is not provided.
- False: metadata is not requested and the meta-estimator will not pass it to predict.
- None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
- str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in scikit-learn version 1.3.
Parameters
- as_frame : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED
  Metadata routing for the as_frame parameter in predict.
Returns
- self : object
  The updated object.