Finding data with tree pattern matching

Definition of Tree up to this point
class Tree:
    """A tree.

    Parameters
    ----------
    data : str
        The data contained in this tree.
    children : list[Tree], optional
        The subtrees of this tree. If omitted, the tree is a leaf with its
        own fresh (empty) list of children.

    Raises
    ------
    TypeError
        If any element of ``children`` is not a ``Tree``.
    """
    def __init__(self, data: str, children: 'list[Tree] | None' = None):
        self._data = data
        # Create a fresh list per instance: a shared ``[]`` default would make
        # every leaf alias the same list, so appending to one leaf's children
        # would silently add children to all of them.
        self._children = [] if children is None else children

        self._validate()

    def _validate(self) -> None:
        """Raise TypeError unless every child is a Tree."""
        # Explicit check instead of ``assert``: assertions are stripped when
        # Python runs with -O, which would silently disable validation.
        if not all(isinstance(c, Tree) for c in self._children):
            msg = 'all children must be trees'
            raise TypeError(msg)

    @property
    def data(self) -> str:
        """The data at this node."""
        return self._data

    @property
    def children(self) -> list['Tree']:
        """The subtrees of this node."""
        return self._children

    def __str__(self):
        """Return the leaves' data joined by single spaces (left to right)."""
        if self._children:
            return ' '.join(str(c) for c in self._children)
        return str(self._data)

    def __repr__(self):
        return self.to_string(0)

    def to_string(self, depth: int) -> str:
        """Render the tree as an indented string.

        Parameters
        ----------
        depth : int
            The current depth for indentation.

        Returns
        -------
        str
            An indented text representation of the tree, one node per line;
            every non-root node is prefixed with ``--``.
        """
        # The root (depth 0) gets no indentation and no '--' marker.
        prefix = '  ' * (depth - 1) + ('--' if depth > 0 else '')
        s = prefix + self._data + '\n'
        s += ''.join(c.to_string(depth + 1) for c in self._children)

        return s

    def __contains__(self, data: str) -> bool:
        """Return True if ``data`` labels this node or any descendant.

        The search is a pre-order depth-first search that stops at the first
        match.
        """
        if self._data == data:
            return True
        return any(data in child for child in self._children)

    def __getitem__(self, idx: 'int | tuple[int, ...]') -> 'Tree':
        """Return the subtree at an index path.

        Parameters
        ----------
        idx : int or tuple of int
            A child position, or a path of child positions from this node.
            The empty tuple ``()`` refers to this node itself.

        Raises
        ------
        IndexError
            If any index is not a non-negative int, or is out of range.
        """
        idx = (idx,) if isinstance(idx, int) else idx

        # Explicit check instead of ``assert`` so it survives python -O.
        if not all(isinstance(i, int) and i >= 0 for i in idx):
            errmsg = ('index must be a non-negative int or tuple of '
                      'non-negative ints')
            raise IndexError(errmsg)

        if not idx:
            return self
        elif len(idx) == 1:
            return self._children[idx[0]]
        else:
            return self._children[idx[0]][idx[1:]]

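We can get from indices to trees, but how would we go from data to indices? Analogous to `list.index()`, we can implement an `index()` method — though unlike `list.index()`, ours returns every matching index path (in pre-order) rather than only the first.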

class Tree(Tree):

    def index(self, data: str, index_path: tuple = tuple()) -> list[tuple]:
        """Find all index paths where the node data matches.

        Unlike ``list.index``, every match is returned (not just the first),
        and an empty list is returned instead of raising when nothing
        matches. Paths are listed in pre-order: a node before its
        descendants, children left to right.

        Parameters
        ----------
        data : str
            The data value to search for.
        index_path : tuple
            The path of this node from the root (used in recursion).

        Returns
        -------
        list[tuple]
            All index paths whose node data equals ``data``.
        """
        indices = [index_path] if self._data == data else []

        # Extend the path with each child's position and collect its matches.
        for i, child in enumerate(self._children):
            indices.extend(child.index(data, index_path + (i,)))

        return indices
# Parse tree for "a greyhound loves the greyhound".
tree1 = Tree('S', 
             [Tree('NP', 
                   [Tree('D', 
                         [Tree('a')]),
                    Tree('N', 
                         [Tree('greyhound')])]),
             Tree('VP', 
                   [Tree('V', 
                         [Tree('loves')]),
                    Tree('NP',
                         [Tree('D',
                               [Tree('the')]),
                          Tree('N',
                               [Tree('greyhound')])])])])
# Every index path whose node data is 'D', in pre-order.
determiner_indices = tree1.index('D')

determiner_indices
[(0, 0), (1, 1, 0)]
# Follow the first path, (0, 0), back to its subtree.
tree1[determiner_indices[0]]
D
--a
# Follow the second path, (1, 1, 0), back to its subtree.
tree1[determiner_indices[1]]
D
--the

Searching on tree patterns

What if, instead, we wanted to find where a piece of data occurs based on an entire tree pattern, rather than on a single node label?

# A tree pattern: S over an NP (itself over D over 'the') and a bare VP.
tree_pattern = Tree('S', 
                    [Tree('NP',
                          [Tree('D', 
                                [Tree('the')])]),
                     Tree('VP')])

tree_pattern
S
--NP
  --D
    --the
--VP

We could implement a find() method.

class Tree(Tree):

    def find(self, pattern: 'Tree',
             subtree_idx: tuple = tuple()) -> list[tuple]:
        """Find subtrees matching a tree pattern.

        Parameters
        ----------
        pattern : Tree
            The tree pattern to match against.
        subtree_idx : tuple
            The index within the match to return (default: entire match).
            It is appended to the path of each matching node.

        Returns
        -------
        list[tuple]
            Index paths to matching subtrees, in pre-order.
        """
        # Only nodes whose data equals the pattern's root data can match, so
        # enumerate those candidates with index() and check each one.
        return [i + subtree_idx
                for i in self.index(pattern.data)
                if self[i].match(pattern)]

    def match(self, pattern: 'Tree') -> bool:
        """Check whether this tree matches a pattern.

        Children are compared positionally: the pattern's i-th child must
        match this tree's i-th child. The pattern may have fewer children
        than this tree (the extra children are ignored), but not more.

        Parameters
        ----------
        pattern : Tree
            The tree pattern to match against.

        Returns
        -------
        bool
            True if this tree matches the pattern.
        """
        if self._data != pattern.data:
            return False

        # zip() alone would silently drop pattern children that have no
        # counterpart here, letting the pattern match a tree that lacks them.
        if len(pattern.children) > len(self._children):
            return False

        return all(child.match(sub)
                   for child, sub in zip(self._children, pattern.children))
# Four variants of the same sentence, differing only in the determiners
# ('a' vs 'the') of the subject NP and the object NP.
tree1 = Tree('S', 
             [Tree('NP', 
                   [Tree('D', 
                         [Tree('a')]),
                    Tree('N', 
                         [Tree('greyhound')])]),
             Tree('VP', 
                   [Tree('V', 
                         [Tree('loves')]),
                    Tree('NP',
                         [Tree('D',
                               [Tree('a')]),
                          Tree('N',
                               [Tree('greyhound')])])])])

tree2 = Tree('S', 
             [Tree('NP', 
                   [Tree('D', 
                         [Tree('the')]),
                    Tree('N', 
                         [Tree('greyhound')])]),
             Tree('VP', 
                   [Tree('V', 
                         [Tree('loves')]),
                    Tree('NP',
                         [Tree('D',
                               [Tree('a')]),
                          Tree('N',
                               [Tree('greyhound')])])])])

tree3 = Tree('S', 
             [Tree('NP', 
                   [Tree('D', 
                         [Tree('a')]),
                    Tree('N', 
                         [Tree('greyhound')])]),
             Tree('VP', 
                   [Tree('V', 
                         [Tree('loves')]),
                    Tree('NP',
                         [Tree('D',
                               [Tree('the')]),
                          Tree('N',
                               [Tree('greyhound')])])])])

tree4 = Tree('S', 
             [Tree('NP', 
                   [Tree('D', 
                         [Tree('the')]),
                    Tree('N', 
                         [Tree('greyhound')])]),
             Tree('VP', 
                   [Tree('V', 
                         [Tree('loves')]),
                    Tree('NP',
                         [Tree('D',
                               [Tree('the')]),
                          Tree('N',
                               [Tree('greyhound')])])])])
# Match the S pattern and report position (0, 0) inside each match,
# i.e. the subject NP's determiner.
tree2.find(tree_pattern, (0,0))
[(0, 0)]
# A tree pattern: VP over a V and an NP whose D dominates 'the'.
tree_pattern = Tree('VP', 
                    [Tree('V'),
                     Tree('NP', 
                          [Tree('D', 
                                [Tree('the')])])])

tree_pattern
VP
--V
--NP
  --D
    --the
# subtree_idx=(1,) selects the NP (second child) of each matched VP;
# tree1's object determiner is 'a', so nothing matches.
tree1.find(tree_pattern, subtree_idx=(1,))
[]
# tree2's object determiner is 'a', so nothing matches.
tree2.find(tree_pattern, subtree_idx=(1,))
[]
# tree3's object determiner is 'the': the VP at (1,) matches, giving NP (1, 1).
tree3.find(tree_pattern, subtree_idx=(1,))
[(1, 1)]
# tree4's object determiner is 'the': the VP at (1,) matches, giving NP (1, 1).
tree4.find(tree_pattern, subtree_idx=(1,))
[(1, 1)]

This sort of treelet-based matching is somewhat weak as it stands. What if we wanted:

  1. …nodes to be allowed to have some value from a set?
  2. …arbitrary distance between the nodes we are matching on?
  3. …arbitrary boolean conditions on node matches?

Expanding pattern-based search with SPARQL

To handle this, we need both a domain-specific language (DSL) for specifying such queries and an interpreter for that language. We can use SPARQL as our DSL. To interpret SPARQL, we will use the existing interpreter in rdflib.

To use rdflib’s interpreter, we need to map our Tree objects into an in-memory format for which a SPARQL interpreter is already implemented. We will use the Resource Description Framework (RDF) as implemented in rdflib.

from rdflib import Graph, URIRef

class Tree(Tree):

    # Class-level cache mapping node data -> URIRef, shared by all trees.
    RDF_TYPES = {}
    # Edge predicates emitted by to_rdf(); SPARQL queries refer to them by
    # these relative IRIs, e.g. ``<is-a>`` or ``<is-the-parent-of>``.
    RDF_EDGES = {'is': URIRef('is-a'),
                 'parent': URIRef('is-the-parent-of'),
                 'child': URIRef('is-a-child-of'),
                 'sister': URIRef('is-a-sister-of')}

    def to_rdf(self, graph=None, nodes=None, idx=tuple()) -> Graph:
        """Convert the tree to an RDF graph for SPARQL querying.

        Each node becomes a URI named by its index path joined with ``_``
        (``(1, 1)`` -> ``1_1``; the root's empty path gives the empty URI).
        The following triples are added:

        - ``node is-a <data>``
        - ``parent is-the-parent-of child`` and ``child is-a-child-of parent``
        - ``a is-a-sister-of b`` for every ordered pair of *distinct*
          children of the same parent

        Parameters
        ----------
        graph : Graph, optional
            An existing graph to add triples to; a new one is created if
            omitted.
        nodes : dict, optional
            A mapping from index tuples to URI nodes, filled in during the
            walk; a new dict is created if omitted.
        idx : tuple, optional
            The index path of this node from the root tree.

        Returns
        -------
        Graph
            The RDF graph representing the tree.
        """
        graph = Graph() if graph is None else graph
        # Fresh dict per top-level call: a shared ``{}`` default would keep
        # accumulating (and leaking) entries from every tree ever converted,
        # and that dict is what gets stored in ``_rdf_nodes`` below.
        nodes = {} if nodes is None else nodes

        # NOTE(review): node URIs ('0_1') and data URIs (URIRef(data)) share
        # one namespace, so data that looks like an index path could collide.
        nodes[idx] = URIRef('_'.join(str(i) for i in idx))

        if self._data not in Tree.RDF_TYPES:
            Tree.RDF_TYPES[self._data] = URIRef(self._data)

        graph.add((nodes[idx],
                   Tree.RDF_EDGES['is'],
                   Tree.RDF_TYPES[self._data]))

        child_idxs = [idx + (i,) for i in range(len(self._children))]

        for child, childidx in zip(self._children, child_idxs):
            # Recurse first so nodes[childidx] exists for the edges below.
            child.to_rdf(graph, nodes, childidx)

            graph.add((nodes[idx], Tree.RDF_EDGES['parent'], nodes[childidx]))
            graph.add((nodes[childidx], Tree.RDF_EDGES['child'], nodes[idx]))

        # Sisters are distinct children of the same parent; skipping equal
        # indices keeps a node from being recorded as its own sister.
        for child1idx in child_idxs:
            for child2idx in child_idxs:
                if child1idx != child2idx:
                    graph.add((nodes[child1idx],
                               Tree.RDF_EDGES['sister'],
                               nodes[child2idx]))

        self._rdf_nodes = nodes

        return graph

    @property
    def rdf(self) -> Graph:
        """The lazily-constructed RDF graph for this tree.

        Built by to_rdf() on first access and cached in ``self._rdf``; later
        changes to the tree are not reflected in the cached graph.
        """
        if not hasattr(self, "_rdf"):
            self._rdf = self.to_rdf()

        return self._rdf

    def find(self, query: str) -> list[tuple[int]]:
        """Find subtrees matching a SPARQL query.

        The first projected variable of ``query`` must bind to tree nodes
        (URIs produced by to_rdf()); each is decoded back to its index path.

        Parameters
        ----------
        query : str
            A SPARQL SELECT query.

        Returns
        -------
        list[tuple[int]]
            Index paths to matching nodes; ``()`` denotes the root.
        """
        paths = []
        for row in self.rdf.query(query):
            name = str(row[0])
            # The root's URI is the empty string; map it to the empty path
            # instead of failing on int('').
            paths.append(tuple(int(i) for i in name.split('_'))
                         if name else ())
        return paths
# Rebuild the four trees so they are instances of the RDF-capable Tree.
tree1 = Tree('S', 
             [Tree('NP', 
                   [Tree('D', 
                         [Tree('a')]),
                    Tree('N', 
                         [Tree('greyhound')])]),
             Tree('VP', 
                   [Tree('V', 
                         [Tree('loves')]),
                    Tree('NP',
                         [Tree('D',
                               [Tree('a')]),
                          Tree('N',
                               [Tree('greyhound')])])])])

tree2 = Tree('S', 
             [Tree('NP', 
                   [Tree('D', 
                         [Tree('the')]),
                    Tree('N', 
                         [Tree('greyhound')])]),
             Tree('VP', 
                   [Tree('V', 
                         [Tree('loves')]),
                    Tree('NP',
                         [Tree('D',
                               [Tree('a')]),
                          Tree('N',
                               [Tree('greyhound')])])])])

tree3 = Tree('S', 
             [Tree('NP', 
                   [Tree('D', 
                         [Tree('a')]),
                    Tree('N', 
                         [Tree('greyhound')])]),
             Tree('VP', 
                   [Tree('V', 
                         [Tree('loves')]),
                    Tree('NP',
                         [Tree('D',
                               [Tree('the')]),
                          Tree('N',
                               [Tree('greyhound')])])])])

tree4 = Tree('S', 
             [Tree('NP', 
                   [Tree('D', 
                         [Tree('the')]),
                    Tree('N', 
                         [Tree('greyhound')])]),
             Tree('VP', 
                   [Tree('V', 
                         [Tree('loves')]),
                    Tree('NP',
                         [Tree('D',
                               [Tree('the')]),
                          Tree('N',
                               [Tree('greyhound')])])])])
# NP nodes that reach a 'the' node via zero or more parent-of steps, reach
# an S via zero or more child-of steps, and have a VP sister. tree1 contains
# no 'the', so the result is empty.
tree1.find('''SELECT ?node
              WHERE { ?node <is-a> <NP>.
                      ?node <is-the-parent-of>* ?child.
                      ?node <is-a-child-of>* ?parent.
                      ?parent <is-a> <S>.
                      ?child <is-a> <the>.
                      ?node <is-a-sister-of> ?sister.
                      ?sister <is-a> <VP>.
                    }''')
[]
# NP over 'the' with a VP sister: tree2's subject NP at (0,).
tree2.find('''SELECT ?node
              WHERE { ?node <is-a> <NP>.
                      ?node <is-the-parent-of>* ?child.
                      ?child <is-a> <the>.
                      ?node <is-a-sister-of> ?sister.
                      ?sister <is-a> <VP>.
                    }''')
[(0,)]
# Same query, using ';' to attach several predicates to the shared ?node.
tree2.find('''SELECT ?node
              WHERE { ?node <is-a> <NP>;
                            <is-the-parent-of>* ?child;
                            <is-a-sister-of> ?sister.
                      ?child <is-a> <the>.
                      ?sister <is-a> <VP>.
                    }''')
[(0,)]
# tree3's only NP over 'the' is the object NP, whose sister is a V rather
# than a VP, so nothing matches.
tree3.find('''SELECT ?node
              WHERE { ?node <is-a> <NP>;
                            <is-the-parent-of>* ?child;
                            <is-a-sister-of> ?sister.
                      ?child <is-a> <the>.
                      ?sister <is-a> <VP>.
                    }''')
[]
# Requiring a V sister instead selects the object NP at (1, 1).
tree4.find('''SELECT ?node
              WHERE { ?node <is-a> <NP>;
                            <is-the-parent-of>* ?child;
                            <is-a-sister-of> ?sister.
                      ?child <is-a> <the>.
                      ?sister <is-a> <V>.
                    }''')
[(1, 1)]