# Source code for smartclass.chem.classification.search_classes

"""Substructure search for chemical classes."""

from __future__ import annotations

import csv
import json
import logging
from pathlib import Path

from rdkit.Chem import SubstructMatchParameters

from smartclass.chem.classification.bfs_search_classes_generator import (
    tqdm_bfs_search_classes_generator,
)
from smartclass.chem.conversion.convert_smiles_to_mol import convert_smiles_to_mol
from smartclass.helpers import convert_list_of_dict
from smartclass.io import (
    load_external_classes_file,
    load_pkg_chemical_hierarchy,
    load_pkg_classes,
    load_smiles,
)
from smartclass.resources.chembl import load_latest_chembl

# Public API of this module: only the search entry point is exported.
__all__ = ["search_classes"]


logger = logging.getLogger(__name__)

# Columns written to the TSV exports, in order.
_EXPORT_FIELDS = [
    "class_id",
    "class_structure",
    "structure_inchikey",
    "structure_smarts",
    "structure_ab",
    "matched_ab",
]


def _load_structures(
    input_smiles: None | str,
    smiles: None | (str | list[str]),
) -> list:
    """Collect the unique input SMILES and convert them to molecules.

    SMILES read from ``input_smiles`` come first, then those from ``smiles``.
    Duplicates are dropped while keeping first-seen order, so repeated runs
    process structures in the same order. If no structure could be converted,
    the latest ChEMBL library is loaded instead.
    """
    unique_smiles: dict[str, None] = {}
    if input_smiles:
        unique_smiles.update(dict.fromkeys(load_smiles(input=input_smiles)))
    if smiles:
        # A bare string is one structure; iterating it would yield single characters.
        if isinstance(smiles, str):
            smiles = [smiles]
        unique_smiles.update(dict.fromkeys(smiles))

    # TODO redundant with get_latest_chembl
    structures: list = []
    for smi in unique_smiles:
        # Skip a literal "smiles" entry (presumably a column header from the input file).
        if smi == "smiles":
            continue
        mol = convert_smiles_to_mol(smi)
        # TODO looks important: consider calling Kekulize(mol) here.
        if mol is not None:
            structures.append(mol)

    # TODO change this
    if not structures:
        logger.info("No structures given, loading ChEMBL library instead.")
        structures = load_latest_chembl()
    return structures


def _load_classes(
    classes_file: None | str,
    classes_name_id: None | str,
    classes_name_smarts: None | str,
) -> dict[str, list[str]]:
    """Load the chemical classes and group their patterns by class key.

    Each row's first column is used as the key (the class ID) and its second
    column as a value (presumably the SMARTS pattern). A key may occur on
    several rows, so every key maps to a list of values.
    """
    if classes_file:
        table = load_external_classes_file(
            file=classes_file,
            id_name=classes_name_id,
            smarts_name=classes_name_smarts,
        )
    else:
        logger.info("No classes given, loading default package classes instead.")
        table = load_pkg_classes()

    classes_dict: dict[str, list[str]] = {}
    for row in table.iter_rows():
        classes_dict.setdefault(row[0], []).append(row[1])
    return classes_dict


def _keep_closest(results: list[dict]) -> list[dict]:
    """Keep, per ``structure_inchikey``, only the results with the highest ``matched_ab``.

    Ties are all kept, and the original result order is preserved.
    """
    best_ab: dict = {}
    for result in results:
        inchikey = result["structure_inchikey"]
        matched_ab = result["matched_ab"]
        if inchikey not in best_ab or matched_ab > best_ab[inchikey]:
            best_ab[inchikey] = matched_ab
    return [
        result
        for result in results
        if result["matched_ab"] == best_ab[result["structure_inchikey"]]
    ]


def _write_tsv(path: Path, rows: list[dict]) -> None:
    """Write ``rows`` as a tab-separated file with the ``_EXPORT_FIELDS`` header."""
    with path.open("w", newline="", encoding="utf-8") as file:
        writer = csv.DictWriter(file, fieldnames=_EXPORT_FIELDS, delimiter="\t")
        writer.writeheader()
        writer.writerows(rows)


def _export_results(results: list[dict], output_dir: str | Path) -> None:
    """Write ``results`` to ``output_dir`` as two JSON files and two TSV files.

    The directory is created if missing, so the export does not fail on a
    fresh checkout.
    """
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    key = "class_id"
    value = "structure_inchikey"

    # Key -> value mapping built by convert_list_of_dict.
    with (out_dir / "results_kv.json").open("w", encoding="utf-8") as file:
        json.dump(convert_list_of_dict(results, key, value), file, indent=4)

    # NOTE(review): this TSV is written in search order (unsorted), unlike
    # results_vk.tsv -- confirm that is intended.
    _write_tsv(out_dir / "results_kv.tsv", results)

    # Sort by InChIKey. Because the sort is stable, rows sharing an InChIKey
    # keep the preceding matched_ab-descending order.
    by_ab = sorted(results, key=lambda r: r["matched_ab"], reverse=True)
    by_inchikey = sorted(by_ab, key=lambda r: r["structure_inchikey"])
    _write_tsv(out_dir / "results_vk.tsv", by_inchikey)

    # Inverted (value -> key) mapping.
    with (out_dir / "results_vk.json").open("w", encoding="utf-8") as file:
        json.dump(
            convert_list_of_dict(results, key, value, invert=True), file, indent=4
        )


def search_classes(
    classes_file: None | str = None,
    classes_name_id: None | str = None,
    classes_name_smarts: None | str = None,
    closest_only: bool = True,
    include_hierarchy: bool = False,
    input_smiles: None | str = None,
    smiles: None | (str | list[str]) = None,
    output_dir: str = "scratch",
) -> list[dict]:
    """
    Substructure search for chemical classes.

    Structures are matched against the classes with
    ``tqdm_bfs_search_classes_generator``. The results are also exported to
    ``output_dir`` as ``results_kv.json``, ``results_kv.tsv``,
    ``results_vk.tsv`` and ``results_vk.json``.

    :param classes_file: File providing the chemical classes. If not given,
        the default package classes are used.
    :type classes_file: Union[None,str]
    :param classes_name_id: Name of the ID column in the classes file.
    :type classes_name_id: Union[None,str]
    :param classes_name_smarts: Name of the SMARTS column in the classes file.
    :type classes_name_smarts: Union[None,str]
    :param closest_only: Flag to return, per structure InChIKey, only the
        result(s) with the highest ``matched_ab``.
    :type closest_only: bool
    :param include_hierarchy: Flag to include hierarchy search (default is False).
    :type include_hierarchy: bool
    :param input_smiles: File providing the (list of) structure(s) to classify.
    :type input_smiles: Union[None,str]
    :param smiles: (List of) structure(s) to classify. A single string is
        treated as one SMILES. If no structure can be converted, the latest
        ChEMBL library is classified instead.
    :type smiles: Union[None,str,list[str]]
    :param output_dir: Directory the result files are written to; it is
        created if it does not exist (default ``"scratch"``).
    :type output_dir: str
    :returns: A list of matched classes.
    :rtype: list[dict]
    """
    # NOTE(review): configuring logging from library code is unusual; kept so
    # progress messages are still shown by default. basicConfig is a no-op
    # once the root logger already has handlers.
    logging.basicConfig(level=logging.INFO)

    structures = _load_structures(input_smiles, smiles)
    classes_dict = _load_classes(classes_file, classes_name_id, classes_name_smarts)
    class_hierarchy: dict[str, list[str]] = (
        load_pkg_chemical_hierarchy() if include_hierarchy else {}
    )

    # Enable RDKit's generic matchers for queries that contain generic groups.
    params = SubstructMatchParameters()
    params.useGenericMatchers = True

    logger.info("Classifying %d structures...", len(structures))
    logger.info("...against %d chemical classes...", len(classes_dict))
    results = list(
        tqdm_bfs_search_classes_generator(
            # The generator expects a list containing the class mapping.
            classes=[classes_dict],
            class_hierarchy=class_hierarchy,
            structures=structures,
            params=params,
        )
    )

    if closest_only:
        results = _keep_closest(results)

    _export_results(results, output_dir)

    logger.info("Done")
    return results
# Script entry point: run the search with all defaults.
if __name__ == "__main__":
    search_classes()