Finding data with tree pattern matching

Definition of Tree up to this point
from typing import TypeVar

DataType = TypeVar("DataType")


class Tree:
    """A tree

    Parameters
    ----------
    data
        The data contained in this tree
    children
        The subtrees of this tree; omitted or ``None`` means no children

    Raises
    ------
    TypeError
        If any element of ``children`` is not a ``Tree``
    """

    def __init__(self, data: DataType,
                 children: "list[Tree] | None" = None):
        self._data = data
        # Build a fresh list per instance: a literal ``[]`` default would be
        # a single list object shared by every tree created without children.
        self._children = [] if children is None else children

        self._validate()

    def _validate(self) -> None:
        """Check that every child is itself a ``Tree``.

        Raises
        ------
        TypeError
            If some child is not a ``Tree``
        """
        # An explicit check (not ``assert``) so validation survives ``-O``.
        if not all(isinstance(c, Tree) for c in self._children):
            raise TypeError('all children must be trees')

    @property
    def data(self) -> DataType:
        """The data stored at the root of this tree."""
        return self._data

    @property
    def children(self) -> "list[Tree]":
        """The immediate subtrees of this tree."""
        return self._children

    def __str__(self) -> str:
        """The leaves of the tree, left to right, joined by spaces."""
        if self._children:
            return ' '.join(str(c) for c in self._children)

        return str(self._data)

    def __repr__(self) -> str:
        return self.to_string(0)

    def to_string(self, depth: int = 0) -> str:
        """Render the tree as an indented outline, one node per line.

        Parameters
        ----------
        depth
            The depth of this node below the node being rendered; the
            root (depth 0) gets no prefix, deeper nodes are prefixed
            with two spaces per extra level followed by ``--``
        """
        prefix = ('  ' * (depth - 1) + '--') if depth > 0 else ''

        # ``str`` so that non-string data can be rendered too.
        lines = [prefix + str(self._data) + '\n']
        lines.extend(c.to_string(depth + 1) for c in self._children)

        return ''.join(lines)

    def __contains__(self, data: DataType) -> bool:
        """Whether any node of the tree holds ``data``."""
        # pre-order depth-first search
        if self._data == data:
            return True

        return any(data in child for child in self._children)

    def __getitem__(self, idx: "int | tuple[int, ...]") -> 'Tree':
        """The subtree at an index path.

        Parameters
        ----------
        idx
            A child position, or a sequence of child positions to follow
            from the root; the empty tuple selects the tree itself

        Raises
        ------
        IndexError
            If any position is not a non-negative int, or no child exists
            at some position along the path
        """
        idx = (idx,) if isinstance(idx, int) else tuple(idx)

        # Negative positions are rejected so that every subtree has exactly
        # one index path.
        if not all(isinstance(i, int) and i >= 0 for i in idx):
            errmsg = ('index must be a non-negative int or '
                      'tuple of non-negative ints')
            raise IndexError(errmsg)

        subtree = self

        for i in idx:
            subtree = subtree._children[i]

        return subtree

We can get from indices to trees, but how would we go from data to indices? Similar to a list, we can implement an index() method.

class Tree(Tree):

    def index(self, data: DataType,
              index_path: tuple[int, ...] = tuple()
              ) -> list[tuple[int, ...]]:
        """The index paths of every node holding ``data``.

        Parameters
        ----------
        data
            The data to search for
        index_path
            The index path of this node within the tree being searched;
            callers normally leave it at the default (this node is the root)

        Returns
        -------
        The index path of each matching node, in pre-order (a node before
        its descendants, children left to right)
        """
        indices = [index_path] if self._data == data else []

        # Each child's path is this node's path extended by its position.
        for i, child in enumerate(self._children):
            indices += child.index(data, index_path + (i,))

        return indices
# Parse tree for "a greyhound loves the greyhound".
tree1 = Tree('S', [
    Tree('NP', [
        Tree('D', [Tree('a')]),
        Tree('N', [Tree('greyhound')]),
    ]),
    Tree('VP', [
        Tree('V', [Tree('loves')]),
        Tree('NP', [
            Tree('D', [Tree('the')]),
            Tree('N', [Tree('greyhound')]),
        ]),
    ]),
])

# Index paths of every node labelled 'D'.
determiner_indices = tree1.index('D')

determiner_indices
[(0, 0), (1, 1, 0)]
tree1[determiner_indices[0]]
D
--a
tree1[determiner_indices[1]]
D
--the

Searching on tree patterns

What if we instead wanted to locate a piece of data by matching an entire tree pattern around it?

# Pattern: an S whose first child is an NP starting with the determiner
# 'the', and whose second child is a VP.
tree_pattern = Tree('S', [
    Tree('NP', [
        Tree('D', [Tree('the')]),
    ]),
    Tree('VP'),
])

tree_pattern
S
--NP
  --D
    --the
--VP

We could implement a find() method.

class Tree(Tree):

    def find(self, pattern: 'Tree',
             subtree_idx: tuple = tuple()) -> list[tuple]:
        '''The index paths of the subtrees matching the pattern

        Parameters
        ----------
        pattern
            the tree pattern to match against
        subtree_idx
            the index of the subtree within the tree pattern to return
            defaults to the entire match

        Returns
        -------
        For every node at which the pattern matches, the index path of
        that node extended by ``subtree_idx``
        '''
        # Only nodes whose data equals the pattern's root data can match, so
        # use index() to shortlist candidates before the full comparison.
        return [i + subtree_idx
                for i in self.index(pattern.data)
                if self[i].match(pattern)]

    def match(self, pattern: 'Tree') -> bool:
        '''Whether the pattern matches at the root of this tree

        The pattern matches when the root data are equal and each of the
        pattern's children matches the child of this tree at the same
        position. This tree may have additional children (and deeper
        structure) that the pattern does not mention.

        Parameters
        ----------
        pattern
            the tree pattern to match against
        '''
        if self._data != pattern.data:
            return False

        # zip() below stops at the shorter sequence, so a pattern asking for
        # more children than this node has must be rejected explicitly.
        if len(pattern.children) > len(self._children):
            return False

        return all(child.match(subpattern)
                   for child, subpattern
                   in zip(self._children, pattern.children))
# "a greyhound loves a greyhound"
tree1 = Tree('S', [
    Tree('NP', [
        Tree('D', [Tree('a')]),
        Tree('N', [Tree('greyhound')]),
    ]),
    Tree('VP', [
        Tree('V', [Tree('loves')]),
        Tree('NP', [
            Tree('D', [Tree('a')]),
            Tree('N', [Tree('greyhound')]),
        ]),
    ]),
])

# "the greyhound loves a greyhound"
tree2 = Tree('S', [
    Tree('NP', [
        Tree('D', [Tree('the')]),
        Tree('N', [Tree('greyhound')]),
    ]),
    Tree('VP', [
        Tree('V', [Tree('loves')]),
        Tree('NP', [
            Tree('D', [Tree('a')]),
            Tree('N', [Tree('greyhound')]),
        ]),
    ]),
])

# "a greyhound loves the greyhound"
tree3 = Tree('S', [
    Tree('NP', [
        Tree('D', [Tree('a')]),
        Tree('N', [Tree('greyhound')]),
    ]),
    Tree('VP', [
        Tree('V', [Tree('loves')]),
        Tree('NP', [
            Tree('D', [Tree('the')]),
            Tree('N', [Tree('greyhound')]),
        ]),
    ]),
])

# "the greyhound loves the greyhound"
tree4 = Tree('S', [
    Tree('NP', [
        Tree('D', [Tree('the')]),
        Tree('N', [Tree('greyhound')]),
    ]),
    Tree('VP', [
        Tree('V', [Tree('loves')]),
        Tree('NP', [
            Tree('D', [Tree('the')]),
            Tree('N', [Tree('greyhound')]),
        ]),
    ]),
])

# Where the S pattern matches tree2, report the D inside its first NP.
tree2.find(tree_pattern, subtree_idx=(0, 0))
[(0, 0)]
# Pattern: a VP whose first child is a V and whose second child is an NP
# starting with the determiner 'the'.
tree_pattern = Tree('VP', [
    Tree('V'),
    Tree('NP', [
        Tree('D', [Tree('the')]),
    ]),
])

tree_pattern
VP
--V
--NP
  --D
    --the
# For each tree, find every VP matching tree_pattern and report the index
# path of its child 1 (the object NP). Each call is followed by its output.
# tree1 and tree2 have the determiner 'a' in the object NP, so nothing matches.
tree1.find(tree_pattern, subtree_idx=(1,))
[]
tree2.find(tree_pattern, subtree_idx=(1,))
[]
# tree3 and tree4 have 'the' in the object NP: the VP at (1,) matches, and
# appending subtree_idx gives (1, 1).
tree3.find(tree_pattern, subtree_idx=(1,))
[(1, 1)]
tree4.find(tree_pattern, subtree_idx=(1,))
[(1, 1)]

This sort of treelet-based matching is somewhat weak as it stands. What if we wanted:

  1. …nodes to be allowed to have some value from a set?
  2. …arbitrary distance between the nodes we are matching on?
  3. …arbitrary boolean conditions on node matches?

Expanding pattern-based search with SPARQL

To handle this, we need both a domain-specific language (DSL) for specifying such queries and an interpreter for that language. We can use SPARQL for our DSL. To interpret SPARQL, we will use the existing interpreter in rdflib.

To use rdflib’s interpreter, we need to map our Tree objects into an in-memory format for which a SPARQL interpreter is already implemented. We will use the Resource Description Framework (RDF) as implemented in rdflib.

from rdflib import Graph, URIRef

class Tree(Tree):

    # Cache of node data -> URIRef, shared by all trees and filled lazily
    # by to_rdf().
    RDF_TYPES = {}
    # The predicates used in the generated graph.
    RDF_EDGES = {'is': URIRef('is-a'),
                 'parent': URIRef('is-the-parent-of'),
                 'child': URIRef('is-a-child-of'),
                 'sister': URIRef('is-a-sister-of')}

    def to_rdf(self, graph: "Graph | None" = None,
               nodes: "dict | None" = None,
               idx: tuple = tuple()) -> Graph:
        """Add this tree to an RDF graph.

        Every node becomes a URIRef named by its index path joined with
        ``_`` (the root's name is the empty string). Each node gets an
        ``is-a`` triple to a URIRef built from its data, and parent/child
        triples link it to its children. Sister triples link the children
        of a node to one another.

        Parameters
        ----------
        graph
            The graph to add triples to; a new one is created if omitted
        nodes
            Mapping from index path to node URIRef, filled in as the tree
            is traversed; a new one is created if omitted
        idx
            The index path of this node within the tree being converted

        Returns
        -------
        The graph containing the triples for this tree
        """
        graph = Graph() if graph is None else graph
        # A fresh mapping per top-level call: a literal ``{}`` default would
        # be one dict shared by every tree ever converted.
        nodes = {} if nodes is None else nodes

        idxstr = '_'.join(str(i) for i in idx)
        nodes[idx] = URIRef(idxstr)

        if self._data not in Tree.RDF_TYPES:
            Tree.RDF_TYPES[self._data] = URIRef(self._data)

        typetriple = (nodes[idx],
                      Tree.RDF_EDGES['is'],
                      Tree.RDF_TYPES[self._data])

        graph.add(typetriple)

        child_indices = [idx + (i,) for i in range(len(self._children))]

        for child, childidx in zip(self._children, child_indices):
            # Recurse first so that nodes[childidx] exists.
            child.to_rdf(graph, nodes, childidx)

            partriple = (nodes[idx],
                         Tree.RDF_EDGES['parent'],
                         nodes[childidx])
            chitriple = (nodes[childidx],
                         Tree.RDF_EDGES['child'],
                         nodes[idx])

            graph.add(partriple)
            graph.add(chitriple)

        # NOTE(review): every ordered pair of children is linked, including
        # a child with itself, so each non-root node is its own sister.
        # Confirm whether that is intended before relying on it in queries.
        for child1idx in child_indices:
            for child2idx in child_indices:
                sistriple = (nodes[child1idx],
                             Tree.RDF_EDGES['sister'],
                             nodes[child2idx])

                graph.add(sistriple)

        self._rdf_nodes = nodes

        return graph

    @property
    def rdf(self) -> Graph:
        """The RDF graph for this tree, built on first access and cached."""
        if not hasattr(self, "_rdf"):
            self._rdf = self.to_rdf()

        return self._rdf

    def find(self, query: str) -> list[tuple[int, ...]]:
        """The index paths of the nodes selected by a SPARQL query.

        Parameters
        ----------
        query
            A SPARQL SELECT query over this tree's RDF graph whose first
            selected variable is bound to tree nodes

        Returns
        -------
        One index path per result row, decoded from the node's name
        """
        # The root is named by the empty string, which splits to [''];
        # dropping empty parts decodes it to () instead of failing on int('').
        return [tuple(int(i)
                      for i in str(res[0]).split('_') if i)
                for res in self.rdf.query(query)]
# "a greyhound loves a greyhound"
tree1 = Tree('S', [
    Tree('NP', [
        Tree('D', [Tree('a')]),
        Tree('N', [Tree('greyhound')]),
    ]),
    Tree('VP', [
        Tree('V', [Tree('loves')]),
        Tree('NP', [
            Tree('D', [Tree('a')]),
            Tree('N', [Tree('greyhound')]),
        ]),
    ]),
])

# "the greyhound loves a greyhound"
tree2 = Tree('S', [
    Tree('NP', [
        Tree('D', [Tree('the')]),
        Tree('N', [Tree('greyhound')]),
    ]),
    Tree('VP', [
        Tree('V', [Tree('loves')]),
        Tree('NP', [
            Tree('D', [Tree('a')]),
            Tree('N', [Tree('greyhound')]),
        ]),
    ]),
])

# "a greyhound loves the greyhound"
tree3 = Tree('S', [
    Tree('NP', [
        Tree('D', [Tree('a')]),
        Tree('N', [Tree('greyhound')]),
    ]),
    Tree('VP', [
        Tree('V', [Tree('loves')]),
        Tree('NP', [
            Tree('D', [Tree('the')]),
            Tree('N', [Tree('greyhound')]),
        ]),
    ]),
])

# "the greyhound loves the greyhound"
tree4 = Tree('S', [
    Tree('NP', [
        Tree('D', [Tree('the')]),
        Tree('N', [Tree('greyhound')]),
    ]),
    Tree('VP', [
        Tree('V', [Tree('loves')]),
        Tree('NP', [
            Tree('D', [Tree('the')]),
            Tree('N', [Tree('greyhound')]),
        ]),
    ]),
])
# Each call below is followed by its output.
#
# Select NP nodes that (a) reach a node of type 'the' through zero or more
# parent edges, (b) reach a node of type S through zero or more child edges,
# and (c) have a sister of type VP. tree1 contains no 'the', so nothing matches.
tree1.find('''SELECT ?node
              WHERE { ?node <is-a> <NP>.
                      ?node <is-the-parent-of>* ?child.
                      ?node <is-a-child-of>* ?parent.
                      ?parent <is-a> <S>.
                      ?child <is-a> <the>.
                      ?node <is-a-sister-of> ?sister.
                      ?sister <is-a> <VP>.
                    }''')
[]
# The same query without the S constraint: NP nodes dominating a 'the' that
# have a VP sister.
tree2.find('''SELECT ?node
              WHERE { ?node <is-a> <NP>.
                      ?node <is-the-parent-of>* ?child.
                      ?child <is-a> <the>.
                      ?node <is-a-sister-of> ?sister.
                      ?sister <is-a> <VP>.
                    }''')
[(0,)]
# The same query again, written with ';' so that the three predicates share
# the ?node subject.
tree2.find('''SELECT ?node
              WHERE { ?node <is-a> <NP>;
                            <is-the-parent-of>* ?child;
                            <is-a-sister-of> ?sister.
                      ?child <is-a> <the>.
                      ?sister <is-a> <VP>.
                    }''')
[(0,)]
# In tree3 the only NP dominating 'the' is the object NP, whose sister is a
# V rather than a VP.
tree3.find('''SELECT ?node
              WHERE { ?node <is-a> <NP>;
                            <is-the-parent-of>* ?child;
                            <is-a-sister-of> ?sister.
                      ?child <is-a> <the>.
                      ?sister <is-a> <VP>.
                    }''')
[]
# Here the sister must be a V, which selects the object NP only.
tree4.find('''SELECT ?node
              WHERE { ?node <is-a> <NP>;
                            <is-the-parent-of>* ?child;
                            <is-a-sister-of> ?sister.
                      ?child <is-a> <the>.
                      ?sister <is-a> <V>.
                    }''')
[(1, 1)]