# Source code for smartclass.chem.classification.search_classes

"""Substructure search for chemical classes.

This module provides the main classification functionality for smartclass,
matching chemical structures against SMARTS-based chemical class definitions.
"""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING

from rdkit.Chem import SubstructMatchParameters

from smartclass.chem.classification.bfs_search_classes_generator import (
    tqdm_bfs_search_classes_generator,
)
from smartclass.chem.conversion.convert_smiles_to_mol import convert_smiles_to_mol
from smartclass.config import get_config
from smartclass.io import (
    export_results,
    load_external_classes_file,
    load_pkg_chemical_hierarchy,
    load_pkg_classes,
    load_smiles,
)
from smartclass.logging import get_logger
from smartclass.resources.chembl import load_latest_chembl


if TYPE_CHECKING:
    from rdkit.Chem import Mol

# Public API of this module.
__all__ = ["search_classes"]

# Module logger obtained from the project's logging helper.
logger = get_logger(__name__)

# Result field names for export: the columns of one classification record.
# NOTE(review): not referenced in the code visible here; presumably consumed by
# the export layer — confirm before renaming or reordering.
RESULT_FIELDS = [
    "class_id",  # identifier of the matched chemical class
    "class_structure",  # SMARTS pattern of the class that matched
    "structure_inchikey",  # InChIKey of the classified input structure
    "structure_smarts",
    "structure_ab",
    "matched_ab",  # score used to pick the closest class per structure
]


def _parse_smiles_to_mols(smiles_set: set[str]) -> list[Mol]:
    """Convert a set of SMILES strings to RDKit Mol objects.

    Inputs are processed in sorted order, so the returned list (and therefore
    the order of downstream classification results) is reproducible across
    interpreter runs. Iterating a ``set[str]`` directly follows string hash
    order, which Python randomizes per process.

    :param smiles_set: Set of SMILES strings to convert.
    :returns: List of valid Mol objects. Blank entries, header rows (the text
        ``"smiles"``, case-insensitive, ignoring surrounding whitespace) and
        SMILES that fail to convert are skipped.
    """
    structures: list[Mol] = []
    for smi in sorted(smiles_set):
        token = smi.strip()
        # Skip blank lines and header rows that might be in the data.
        if not token or token.lower() == "smiles":
            continue
        mol = convert_smiles_to_mol(smi)
        if mol is None:
            logger.debug(f"Failed to parse SMILES: {smi}")
            continue
        structures.append(mol)
    return structures


def _load_classes(
    classes_file: str | None,
    classes_name_id: str | None,
    classes_name_smarts: str | None,
) -> list[dict[str, list[str]]]:
    """Read chemical class definitions and group their SMARTS by class ID.

    When ``classes_file`` is falsy the package's bundled classes are used.
    Otherwise the external file is read, with the ID/SMARTS column names
    falling back to ``"class"`` and ``"structure"``.

    Each row's first value is taken as the class ID and its second value as a
    SMARTS pattern. A class that appears on several rows collects all of its
    patterns, in row order.

    :param classes_file: Optional path to external classes file.
    :param classes_name_id: Name of ID column in external file.
    :param classes_name_smarts: Name of SMARTS column in external file.
    :returns: A one-element list holding the ``{class_id: [smarts, ...]}`` map.
    """
    if not classes_file:
        logger.info("No classes file provided, loading default package classes.")
        table = load_pkg_classes()
    else:
        table = load_external_classes_file(
            file=classes_file,
            id_name=classes_name_id or "class",
            smarts_name=classes_name_smarts or "structure",
        )

    grouped: dict[str, list[str]] = {}
    for row in table.iter_rows():
        grouped.setdefault(row[0], []).append(row[1])

    return [grouped]


def _filter_closest_matches(results: list[dict]) -> list[dict]:
    """Filter results to keep only the closest class match for each structure.

    :param results: List of classification results.
    :returns: Filtered list with only closest matches per InChIKey.
    """
    max_ab_per_inchikey: dict[str, int] = {}
    for result in results:
        inchikey = result["structure_inchikey"]
        matched_ab = result["matched_ab"]
        if (
            inchikey not in max_ab_per_inchikey
            or matched_ab > max_ab_per_inchikey[inchikey]
        ):
            max_ab_per_inchikey[inchikey] = matched_ab

    return [
        result
        for result in results
        if result["matched_ab"] == max_ab_per_inchikey[result["structure_inchikey"]]
    ]


def _export_classification_results(
    results: list[dict],
    output_dir: Path | None = None,
) -> None:
    """Export classification results to two TSV files.

    Writes ``results_by_class.tsv`` (rows ordered by ``matched_ab``, highest
    first) and ``results_by_structure.tsv`` (rows ordered by
    ``structure_inchikey``; within one structure the descending
    ``matched_ab`` order is kept because Python's sort is stable).

    :param results: List of classification results to export. Each entry is
        read via its ``"matched_ab"`` and ``"structure_inchikey"`` keys.
    :param output_dir: Directory for output files. Falls back to the
        ``output_dir`` of the package config if None. The directory is
        created if it does not exist yet.
    """
    if not results:
        logger.warning("No results to export.")
        return

    config = get_config()
    # Coerce to Path: the configured default is not guaranteed to be a Path
    # object, and the ``/`` joins below require one.
    output_path = Path(output_dir or config.output_dir)
    # Ensure the target directory exists before any file is written.
    output_path.mkdir(parents=True, exist_ok=True)

    # Export sorted by matched_ab (descending)
    results_by_match = sorted(results, key=lambda x: x["matched_ab"], reverse=True)
    export_results(
        output=str(output_path / "results_by_class.tsv"),
        results=results_by_match,
    )

    # Export sorted by structure. Sorting the already score-ordered list relies
    # on sort stability to keep each structure's matches best-first.
    results_by_structure = sorted(
        results_by_match,
        key=lambda x: x["structure_inchikey"],
    )
    export_results(
        output=str(output_path / "results_by_structure.tsv"),
        results=results_by_structure,
    )

    logger.info(f"Results exported to {output_path}")


def search_classes(
    classes_file: str | None = None,
    classes_name_id: str | None = None,
    classes_name_smarts: str | None = None,
    closest_only: bool = True,
    include_hierarchy: bool = False,
    input_smiles: str | None = None,
    smiles: str | list[str] | None = None,
    export: bool = True,
    output_dir: Path | str | None = None,
) -> list[dict]:
    """Perform substructure search to classify chemical structures.

    This function matches input structures against a set of chemical class
    definitions using SMARTS patterns. Results include the class ID, matching
    SMARTS pattern, and structural similarity metrics.

    :param classes_file: Path to TSV file with chemical class definitions.
        If None, uses the default package classes.
    :param classes_name_id: Column name for class IDs in the classes file.
        Defaults to "class".
    :param classes_name_smarts: Column name for SMARTS in the classes file.
        Defaults to "structure".
    :param closest_only: If True, return only the closest matching class
        (highest ``matched_ab``) for each structure. Default is True.
    :param include_hierarchy: If True, load the packaged chemical hierarchy
        and pass it to the BFS search (intended to speed up searching).
        Default is False.
    :param input_smiles: Path to file containing SMILES strings to classify.
    :param smiles: Single SMILES string or list of SMILES to classify.
    :param export: If True, export results to files. Default is True.
    :param output_dir: Directory for output files. Uses config default if None.
    :returns: List of dictionaries with classification results.
    :raises ValueError: If no structures are provided and the ChEMBL fallback
        also yields none.
    """
    # Set column name defaults
    if not classes_name_id:
        classes_name_id = "class"
    if not classes_name_smarts:
        classes_name_smarts = "structure"

    # Collect SMILES from all sources (a set removes duplicates across them)
    smiles_set: set[str] = set()
    if input_smiles:
        smiles_set.update(load_smiles(input=input_smiles))
    if smiles:
        if isinstance(smiles, str):
            smiles_set.add(smiles)
        else:
            smiles_set.update(smiles)

    # Convert SMILES to Mol objects
    structures = _parse_smiles_to_mols(smiles_set)

    # Fallback to ChEMBL if no structures provided
    if not structures:
        logger.info("No structures provided, loading ChEMBL library as fallback.")
        # NOTE(review): assumes load_latest_chembl() returns RDKit Mol objects
        # like _parse_smiles_to_mols does — confirm.
        structures = load_latest_chembl()

    if not structures:
        raise ValueError("No structures available for classification.")

    # Load chemical classes
    classes = _load_classes(classes_file, classes_name_id, classes_name_smarts)

    # Load class hierarchy if requested
    class_hierarchy: dict[str, list[str]] = {}
    if include_hierarchy:
        class_hierarchy = load_pkg_chemical_hierarchy()

    # Configure substructure matching parameters
    params = SubstructMatchParameters()
    params.useGenericMatchers = True

    logger.info(
        f"Classifying {len(structures)} structures "
        f"against {len(classes[0])} chemical classes...",
    )

    # Perform classification
    results = list(
        tqdm_bfs_search_classes_generator(
            classes=classes,
            class_hierarchy=class_hierarchy,
            structures=structures,
            params=params,
        ),
    )

    # Filter to closest matches if requested
    if closest_only and results:
        results = _filter_closest_matches(results)

    # Export results if requested
    if export and results:
        output_path = Path(output_dir) if output_dir else None
        _export_classification_results(results, output_path)

    logger.info(f"Classification complete. Found {len(results)} matches.")
    return results
# Script entry point: classify using all defaults.
if __name__ == "__main__":
    search_classes()