import copy
from typing import List, Tuple, Union
from linchemin import settings
from linchemin.cgu.convert import converter
from linchemin.cgu.syngraph import (
BipartiteSynGraph,
MonopartiteMolSynGraph,
MonopartiteReacSynGraph,
)
from linchemin.cgu.syngraph_operations import merge_syngraph
from linchemin.cgu.translate import nx, translator
from linchemin.cheminfo.constructors import MoleculeConstructor
from linchemin.cheminfo.models import Molecule
from linchemin.utilities import console_logger
""" Module containing functions and classes to identify, extract and mine routes """
# Module-level logger created through the project's console_logger helper.
logger = console_logger(__name__)
class RouteMiner:
    """Extract routes from a list of SynGraph objects.

    The input routes are stored as MonopartiteReacSynGraph objects (they are
    passed through ``converter`` when at least one of them is of another
    SynGraph type). ``mine_routes`` merges them with ``merge_syngraph`` and
    hands the merged graph to ``TreeMiner``.
    """

    def __init__(
        self,
        route_list: List[
            Union[MonopartiteReacSynGraph, BipartiteSynGraph, MonopartiteMolSynGraph]
        ],
        root: Union[str, None] = settings.ROUTE_MINING.root,
    ):
        """Initialize a RouteMiner object.

        Parameters:
        ----------
        route_list:
            The SynGraph routes to be merged and mined.
        root:
            The smiles of the target molecule, or None.

        Raises:
        -------
        TypeError: if any element of ``route_list`` is not a SynGraph object,
            or if ``root`` is neither a string nor None.
        """
        syngraph_types = (
            MonopartiteReacSynGraph,
            BipartiteSynGraph,
            MonopartiteMolSynGraph,
        )
        # Validate every element up front: a list mixing valid SynGraph
        # objects with foreign objects must be rejected here rather than
        # being forwarded to the converter.
        if not all(isinstance(r, syngraph_types) for r in route_list):
            message = "Only syngraph objects can be used."
            logger.error(message)
            raise TypeError(message)

        if all(isinstance(r, MonopartiteReacSynGraph) for r in route_list):
            # Already in the expected format (this also covers an empty list).
            self.route_list = route_list
        else:
            # At least one route has a different SynGraph type: every route is
            # passed through the converter so the list ends up homogeneous.
            self.route_list = [
                converter(route, "monopartite_reactions") for route in route_list
            ]

        if root is not None and not isinstance(root, str):
            message = "The input target molecule should be in smiles (string) form."
            logger.error(message)
            raise TypeError(message)
        self.root = root

    def mine_routes(self) -> List[MonopartiteReacSynGraph]:
        """Merge the stored routes and return the routes mined by TreeMiner."""
        tree = merge_syngraph(self.route_list)
        return TreeMiner(tree, self.root).mine_tree()
class TreeMiner:
    """Mine the routes contained in a single SynGraph object.

    The graph is translated to a networkx graph, the routes are enumerated by
    ``RouteFinder`` starting from the root molecule, and each route is
    translated back to a MonopartiteReacSynGraph.
    """

    def __init__(
        self,
        tree: Union[MonopartiteReacSynGraph, BipartiteSynGraph, MonopartiteMolSynGraph],
        root: Union[str, None] = settings.ROUTE_MINING.root,
    ):
        """Initialize a TreeMiner object.

        Parameters:
        ----------
        tree:
            The SynGraph object to be mined.
        root:
            The smiles of the target molecule; if None, the root is taken
            from the roots reported by the tree itself.

        Raises:
        -------
        TypeError: if ``tree`` is not a SynGraph object.
        KeyError: if ``root`` is given but does not appear in the tree.
        IndexError: if ``root`` is None and the tree reports no root.
        """
        # The tree must be set first: set_root reads self.tree.
        self.tree = self.set_tree(tree)
        self.root = self.set_root(root)

    def set_root(self, root: Union[str, None]) -> Union[Molecule, None]:
        """Return the root to be stored in the ``root`` attribute.

        When ``root`` is None the first of the (de-duplicated) roots reported
        by the tree is returned. Otherwise a Molecule is built from the smiles
        and returned only if it is a node of the tree converted to
        "monopartite_molecules".
        """
        if root is None:
            if isinstance(self.tree, MonopartiteReacSynGraph):
                extracted_roots = list(set(self.tree.get_molecule_roots()))
            else:
                extracted_roots = list(set(self.tree.get_roots()))
            if not extracted_roots:
                # Same exception type that indexing the empty list would
                # raise, but logged and with an explicit message.
                message = "No root could be extracted from the tree"
                logger.error(message)
                raise IndexError(message)
            # NOTE(review): if the tree reports several distinct roots, the
            # one picked here depends on set iteration order - confirm that
            # callers only pass single-root trees when root is None.
            return extracted_roots[0]

        mol_root = MoleculeConstructor().build_from_molecule_string(root, "smiles")
        tree_molecules = set(converter(self.tree, "monopartite_molecules").graph)
        if mol_root in tree_molecules:
            return mol_root
        message = "The selected root does not appear in the tree"
        logger.error(message)
        raise KeyError(message)

    @staticmethod
    def set_tree(
        tree: Union[MonopartiteReacSynGraph, BipartiteSynGraph, MonopartiteMolSynGraph]
    ) -> Union[MonopartiteReacSynGraph, BipartiteSynGraph, MonopartiteMolSynGraph]:
        """Return ``tree`` unchanged if it is a SynGraph object, else raise TypeError."""
        if isinstance(
            tree, (MonopartiteReacSynGraph, BipartiteSynGraph, MonopartiteMolSynGraph)
        ):
            return tree
        message = "Only syngraph objects can be used."
        logger.error(message)
        raise TypeError(message)

    def mine_tree(self) -> List[MonopartiteReacSynGraph]:
        """Mine the routes from the tree and return them as MonopartiteReacSynGraph objects."""
        tree_nx = translator("syngraph", self.tree, "networkx", "bipartite")
        # RouteFinder identifies the root node by its smiles string.
        routes_nx = RouteFinder(tree_nx, self.root.smiles).find_routes()
        return [
            translator("networkx", route_nx, "syngraph", "monopartite_reactions")
            for route_nx in routes_nx
        ]
class RouteFinder:
    """Enumerate the routes contained in a networkx tree.

    The tree is walked backwards (following incoming edges) starting from the
    root node. Nodes are told apart by their "label" attribute, edges by
    their "label" attribute (product edge vs reactant edge). Every completed
    route is a networkx DiGraph collected in ``self.routes``.
    """

    def __init__(
        self,
        nx_tree: nx.DiGraph,
        root: str,
        product_edge_label: str = settings.ROUTE_MINING.product_edge_label,
        reactant_edge_label: str = settings.ROUTE_MINING.reactant_edge_label,
    ):
        """To initialize a new RouteFinder object.

        Parameters:
        ----------
        nx_tree:
            The graph to be mined; its nodes and edges are expected to carry
            a "label" attribute.
        root:
            The identifier of the node of ``nx_tree`` from which the
            traversal starts.
        product_edge_label:
            The edge label identifying the edges returned by ``get_p_neighbors``.
        reactant_edge_label:
            The edge label identifying the edges returned by ``get_r_neighbors``.
        """
        self.nx_tree = nx_tree
        self.root = root
        # Routes found so far; filled during the traversal.
        self.routes: List = []
        self.product_edge = product_edge_label
        self.reactant_edge = reactant_edge_label

    def find_routes(self) -> List[nx.DiGraph]:
        """To find routes from a tree.

        Starts the traversal from the root with an empty route graph, an
        empty set of visited nodes and an empty stack, and returns
        ``self.routes``. ``self.routes`` is not cleared here.
        """
        initial_route = nx.DiGraph()
        initial_seen = set()
        # Maps a ChemicalEquation node to the list of its reactant nodes that
        # still have to be explored.
        stack = {}
        self.traverse_route(initial_route, self.root, initial_seen, stack)
        return self.routes

    def traverse_route(
        self,
        route_graph: nx.DiGraph,
        current_node,
        seen: set,
        stack: dict,
    ):
        """To extract routes from a Tree.

        Adds ``current_node`` to ``route_graph``, marks it as visited and
        dispatches to the handler matching its "label" attribute. A node
        whose label matches neither configured label is added but not
        expanded.
        """
        # Add the current node with its attributes to the route_graph
        route_graph.add_node(current_node, **self.nx_tree.nodes[current_node])
        seen.add(current_node)
        current_label = route_graph.nodes[current_node]["label"]
        # handle node differently depending on its type
        if current_label == settings.ROUTE_MINING.molecule_node_label:
            self.handle_molecule_node(route_graph, current_node, seen, stack)
        elif current_label == settings.ROUTE_MINING.chemicalequation_node_label:
            self.handle_chemeq_node(route_graph, current_node, seen, stack)

    def handle_molecule_node(
        self, route_graph: nx.DiGraph, current_node: str, seen: set, stack: dict
    ):
        """To handle Molecule nodes during the tree traversal.

        A node with no incoming product edge is left as it is (nothing is
        expanded from it).
        """
        # Find P edges connected to the current M node
        p_neighbors = self.get_p_neighbors(current_node)
        # If the current M node has more than one P edge, it is an OR node
        if len(p_neighbors) > 1:
            self.handle_or_node(route_graph, p_neighbors, seen, stack)
        elif p_neighbors:
            # Exactly one P edge
            self.handle_and_node(route_graph, p_neighbors, seen, stack)

    def handle_chemeq_node(
        self, route_graph: nx.DiGraph, current_node: str, seen: set, stack: dict
    ) -> None:
        """To handle the ChemicalEquation nodes during the tree traversal.

        The not-yet-visited reactants of the node are pushed on the stack and
        explored one by one. Once they are exhausted the node's stack entry
        is removed; if other entries remain they are processed by
        ``handle_leaf_nodes``, otherwise the route is stored.
        """
        # Find R edges connected to the current CE node
        r_neighbors = self.get_r_neighbors(current_node)
        stack[current_node] = [u for u, v, d in r_neighbors if u not in seen]
        while stack[current_node]:
            n = stack[current_node].pop()
            # Recover the full edge tuple of the popped reactant
            (u, v, edge_data) = next(tup for tup in r_neighbors if tup[0] == n)
            route_graph.add_edge(u, v, **edge_data)
            self.traverse_route(route_graph, u, seen, stack)
        stack.pop(current_node)
        if stack:
            self.handle_leaf_nodes(route_graph, seen, stack)
        elif route_graph not in self.routes:
            # No pending ChemicalEquation node is left: the route is stored,
            # unless it is already in the list.
            self.routes.append(route_graph)

    def handle_or_node(
        self, route_graph: nx.DiGraph, p_neighbors: List[tuple], seen: set, stack: dict
    ) -> None:
        """To handle 'OR' Molecule nodes --> Molecule nodes that are product of more than one ChemicalEquation node.

        Each P edge that does not lead to a loop starts an alternative route:
        all of them but the last one continue on deep copies of the current
        route/seen/stack, the last one continues on the current objects.
        """
        for i, (u, v, edge_data) in enumerate(p_neighbors):
            if self.is_loop(u, v, edge_data, seen, route_graph):
                continue
            if i < len(p_neighbors) - 1:  # Only create new routes for non-last P edges
                new_route, new_seen, new_stack = self.create_new_route(
                    route_graph, seen, stack
                )
                new_route.add_edge(u, v, **edge_data)
                self.traverse_route(
                    new_route,
                    u,
                    new_seen,
                    new_stack,
                )
            else:
                route_graph.add_edge(u, v, **edge_data)
                self.traverse_route(
                    route_graph,
                    u,
                    seen,
                    stack,
                )

    @staticmethod
    def create_new_route(
        route_graph: nx.DiGraph, seen: set, stack: dict
    ) -> Tuple[nx.DiGraph, set, dict]:
        """To create a new route and the corresponding set of visited nodes.

        Returns deep copies of the route graph, of the visited set and of the
        stack, so that the new route can evolve independently.
        """
        return copy.deepcopy(route_graph), copy.deepcopy(seen), copy.deepcopy(stack)

    def handle_and_node(
        self, route_graph: nx.DiGraph, p_neighbors: List[tuple], seen: set, stack: dict
    ) -> None:
        """To handle Molecule nodes with a single P edge.

        Only the first element of ``p_neighbors`` is considered. The edge is
        added to the route and the traversal continues from its source node,
        unless following it would produce a loop.
        """
        u, v, edge_data = p_neighbors[0]
        if self.is_loop(u, v, edge_data, seen, route_graph):
            return
        route_graph.add_edge(u, v, **edge_data)
        self.traverse_route(
            route_graph,
            u,
            seen,
            stack,
        )

    def is_loop(
        self,
        node: str,
        previous_node: str,
        edge_data: dict,
        seen: set,
        route_graph: nx.DiGraph,
    ) -> bool:
        """To check if the next node is involved in a loop.

        Parameters:
        ----------
        node:
            The candidate node to be visited next.
        previous_node:
            The node from which ``node`` would be reached.
        edge_data:
            The attributes of the edge from ``node`` to ``previous_node``.
        seen:
            The nodes already visited in the current route.
        route_graph:
            The route built so far (it is not modified).

        Returns:
        --------
        True if adding ``node`` is considered to create a loop, False otherwise.
        """
        r_edge_list = self.get_r_neighbors(node)
        r_neighbors = {m for (m, ce, d) in r_edge_list}
        # All the reactants of the node were already visited. An empty set is
        # a subset of any set, so this is also True for a node without R edges.
        if r_neighbors.issubset(seen):
            return True
        # Only some reactants were visited: test on a copy of the route
        # whether one of them closes a cycle.
        if n := next((neighbor for neighbor in r_neighbors if neighbor in seen), None):
            graph_test = copy.deepcopy(route_graph)
            (u, v, edge_data_new) = next(tup for tup in r_edge_list if tup[0] == n)
            graph_test.add_edge(u, v, **edge_data_new)
            graph_test.add_edge(node, previous_node, **edge_data)
            try:
                nx.find_cycle(graph_test, orientation="original")
                return True
            except nx.NetworkXNoCycle:
                return False
        return False

    def get_p_neighbors(self, node: str) -> List[tuple]:
        """To get the ChemicalEquation nodes, parents of a Molecule node, following the 'PRODUCT' edges.

        Returns the incoming edges of ``node`` as (source, target, data)
        tuples, keeping only those labelled as product edges.
        """
        return [
            (u, v, d)
            for u, v, d in self.nx_tree.in_edges(node, data=True)
            if d["label"] == self.product_edge
        ]

    def get_r_neighbors(self, node: str) -> List[tuple]:
        """To get the Molecule nodes, parents of a ChemicalEquation node, following the 'REACTANT' edges.

        Returns the incoming edges of ``node`` as (source, target, data)
        tuples, keeping only those labelled as reactant edges.
        """
        return [
            (u, v, d)
            for u, v, d in self.nx_tree.in_edges(node, data=True)
            if d["label"] == self.reactant_edge
        ]

    def handle_leaf_nodes(
        self, route_graph: nx.DiGraph, seen: set, stack: dict
    ) -> None:
        """To resume the traversal from the reactants still pending in the stack.

        For every ChemicalEquation node that still has reactants in the
        stack, the reactants are popped, linked to the route and explored.
        At the end the route is stored, unless it is already in the list.
        """
        # Snapshot of the ChemicalEquation nodes with pending reactants
        if ce_nodes := [ce for ce, val in stack.items() if len(val) > 0]:
            for ce in ce_nodes:
                while stack[ce]:
                    n = stack[ce].pop()
                    r_neighbors = self.get_r_neighbors(ce)
                    (u, v, edge_data) = next(tup for tup in r_neighbors if tup[0] == n)
                    route_graph.add_edge(u, v, **edge_data)
                    self.traverse_route(
                        route_graph,
                        u,
                        seen,
                        stack,
                    )
        if route_graph not in self.routes:
            self.routes.append(route_graph)
# Public entry point for route mining.
def mine_routes(
    input_list: Union[
        List[Union[MonopartiteReacSynGraph, BipartiteSynGraph, MonopartiteMolSynGraph]]
    ],
    root: Union[str, None] = settings.ROUTE_MINING.root,
    new_reaction_list: Union[List[str], None] = settings.ROUTE_MINING.new_reaction_list,
) -> List[MonopartiteReacSynGraph]:
    """
    To mine all the routes found in the tree obtained by merging a list of routes.

    Parameters:
    ----------
    input_list: Union[List[Union[MonopartiteReacSynGraph, BipartiteSynGraph, MonopartiteMolSynGraph]]]
        A list of SynGraph routes. The list is deep-copied, so it is not modified.
    root : Optional[Union[str, None]]
        The smiles of the target molecule for which routes should be searched.
        If None, the root node is determined automatically
        (default: settings.ROUTE_MINING.root)
    new_reaction_list : Optional[Union[List[str], None]]
        The list of smiles of the chemical reactions to be added.
        If None, only the input graph objects are considered
        (default: settings.ROUTE_MINING.new_reaction_list)

    Returns:
    --------
    extracted routes : List[MonopartiteReacSynGraph]
        A list of MonopartiteReacSynGraph objects corresponding to ALL the mined routes (including the input ones)

    Raises:
    -------
    TypeError: If the input data is not a list of SynGraph objects.

    Example:
    --------
    >>> input_list = [route1, route2]
    >>> root = 'CCC(=O)Nc1ccc(cc1)C(=O)N[C@@H](CO)C(=O)O'
    >>> new_reaction_list = ['CC(=O)Nc1ccccc1C(=O)O.[O-]S(=O)(=O)C(F)(F)F>>CC(=O)Nc1ccccc1C(=O)OS(=O)(=O)C(F)(F)F']
    >>> mine_routes(input_list,root,new_reaction_list)
    """
    accepted_types = (
        MonopartiteReacSynGraph,
        BipartiteSynGraph,
        MonopartiteMolSynGraph,
    )
    input_is_valid = isinstance(input_list, list) and all(
        isinstance(graph, accepted_types) for graph in input_list
    )
    # Guard clause: reject anything that is not a list of SynGraph objects
    if not input_is_valid:
        logger.error("Only a list of syngraph objects can be used.")
        raise TypeError

    # Work on a copy so that the caller's list is left untouched
    graphs_list = copy.deepcopy(input_list)
    if new_reaction_list is not None:
        # The extra reactions take part in the mining as one more graph
        graphs_list.append(build_graph_from_node_sequence(new_reaction_list))
    miner = RouteMiner(graphs_list, root)
    return miner.mine_routes()
def build_graph_from_node_sequence(new_nodes: List[str]) -> MonopartiteReacSynGraph:
    """To build a MonopartiteReacSynGraph from a list of reaction smiles.

    Each reaction string is wrapped in a dictionary holding its position in
    the list ("query_id") and the string itself ("output_string"); the list
    of dictionaries is passed to the MonopartiteReacSynGraph constructor.
    """
    reaction_entries = []
    for position, reaction_smiles in enumerate(new_nodes):
        reaction_entries.append(
            {"query_id": position, "output_string": reaction_smiles}
        )
    return MonopartiteReacSynGraph(reaction_entries)