cat2cat.cat2cat_ml
==================

.. py:module:: cat2cat.cat2cat_ml


Functions
---------

.. autoapisummary::

   cat2cat.cat2cat_ml.cat2cat_ml_run


Module Contents
---------------

.. py:function:: cat2cat_ml_run(mappings: cat2cat.dataclass.cat2cat_mappings, ml: cat2cat.dataclass.cat2cat_ml, **kwargs: Any) -> cat2cat_ml_run_results

   Run model diagnostics before using ML-based cat2cat weights.

   This helper evaluates baseline and model-based classification quality
   within each mapping group and aggregates summary statistics across groups.

   :param mappings: Mapping configuration created with ``cat2cat_mappings``.
   :param ml: ML configuration created with ``cat2cat_ml``.
   :param \*\*kwargs: Optional diagnostics settings:

      - ``test_prop`` (float): test split proportion in ``(0, 1)``.
        Default is ``0.2``.
      - ``split_seed`` (int): random seed for train/test split.
        Default is ``42``.
      - ``min_match`` (float): minimum fraction of records in ``ml.data``
        whose category appears in the mapping table. Must be in ``[0, 1)``.
        Default is ``0.8``.

   :returns: object with per-group raw diagnostics and aggregated metrics
             such as mean accuracy, mean Brier score, mean P(true class),
             failure rates, and model-vs-baseline comparisons.
   :rtype: cat2cat_ml_run_results

   :raises TypeError: if ``mappings`` or ``ml`` has invalid type.
   :raises ValueError: if kwargs names/ranges are invalid or mapping coverage
                       is below ``min_match``.

   .. rubric:: Examples

   >>> from sklearn.ensemble import RandomForestClassifier
   >>> from cat2cat import cat2cat_ml_run
   >>> from cat2cat.dataclass import cat2cat_mappings, cat2cat_ml
   >>> from cat2cat.datasets import load_trans, load_occup
   >>> trans = load_trans()
   >>> occup = load_occup()
   >>> data_2010 = occup.loc[occup.year == 2010, :].copy()
   >>> mappings = cat2cat_mappings(trans, "backward")
   >>> ml = cat2cat_ml(
   ...     data=data_2010,
   ...     cat_var="code",
   ...     features=["salary", "age", "edu", "sex"],
   ...     models=[RandomForestClassifier(n_estimators=50, random_state=1234)],
   ... )
   >>> out = cat2cat_ml_run(mappings=mappings, ml=ml, test_prop=0.2)
   >>> hasattr(out, "mean_acc")
   True
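
The metrics named in the return description (accuracy, Brier score, mean P(true class), and a model-vs-baseline comparison) can be illustrated for a single mapping group with a small standalone sketch. This is not the implementation used by ``cat2cat_ml_run``. The helper names ``frequency_baseline`` and ``diagnostics``, the toy labels, and the choice of the multiclass Brier score (sum of squared errors over classes) are all assumptions made for illustration; the library may define its baseline and its Brier score differently.

```python
from collections import Counter


def frequency_baseline(train_labels):
    """Hypothetical baseline: class probabilities equal to training frequencies."""
    counts = Counter(train_labels)
    total = sum(counts.values())
    return {label: n / total for label, n in counts.items()}


def diagnostics(probs_per_row, true_labels):
    """Accuracy, multiclass Brier score, and mean P(true class).

    probs_per_row is a list of dicts mapping class label -> predicted
    probability, one dict per test record.
    """
    n = len(true_labels)
    hits = 0
    brier = 0.0
    p_true = 0.0
    for probs, truth in zip(probs_per_row, true_labels):
        # The predicted class is the one with the highest probability.
        predicted = max(probs, key=probs.get)
        hits += predicted == truth
        # Assumed definition: squared error summed over all classes.
        classes = set(probs) | {truth}
        brier += sum(
            (probs.get(c, 0.0) - (1.0 if c == truth else 0.0)) ** 2
            for c in classes
        )
        p_true += probs.get(truth, 0.0)
    return {"acc": hits / n, "brier": brier / n, "p_true": p_true / n}


# Toy mapping group with two candidate codes (made-up data).
train = ["A", "A", "A", "B"]
test = ["A", "B"]

# The baseline predicts the same frequency-based distribution for every record.
baseline = frequency_baseline(train)
baseline_metrics = diagnostics([baseline] * len(test), test)

# Made-up probabilities standing in for a fitted classifier's output.
model_probs = [{"A": 0.9, "B": 0.1}, {"A": 0.2, "B": 0.8}]
model_metrics = diagnostics(model_probs, test)

# The model beats the baseline when accuracy is higher and Brier score lower.
model_better = (
    model_metrics["acc"] > baseline_metrics["acc"]
    and model_metrics["brier"] < baseline_metrics["brier"]
)
print(baseline_metrics, model_metrics, model_better)
```

In this toy group the baseline always predicts the majority code, so it gets one of the two test records right, while the stand-in model gets both right with a lower Brier score. Averaging such per-group values over all mapping groups gives summaries analogous to the mean accuracy and mean Brier score described above.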