Levenshtein distance

Okay, so how do we compute edit (or Levenshtein) distance? The basic idea is to define it recursively, deciding at each point whether to delete the last element of the source or insert the last element of the target (each at some cost \(c\)), or whether to align the last elements of the two strings, which is free if they match and costs a substitution otherwise.

\[d_\text{lev}(\mathbf{a}, \mathbf{b}) = \begin{cases} c_\text{ins} \times |\mathbf{b}| & \text{if } |\mathbf{a}| = 0 \\ c_\text{del} \times |\mathbf{a}| & \text{if } |\mathbf{b}| = 0 \\ \min \begin{cases}d_\text{lev}(a_1\ldots a_{|\mathbf{a}|-1}, \mathbf{b}) + c_\text{del} \\ d_\text{lev}(\mathbf{a}, b_1\ldots b_{|\mathbf{b}|-1}) + c_\text{ins} \\ d_\text{lev}(a_1\ldots a_{|\mathbf{a}|-1}, b_1\ldots b_{|\mathbf{b}|-1}) + c_\text{sub} \times \mathbb{1}[a_{|\mathbf{a}|} \neq b_{|\mathbf{b}|}]\end{cases} & \text{otherwise}\end{cases}\]

where \(c_\text{sub}\) defaults to \(c_\text{del} + c_\text{ins}\).

from collections import defaultdict
from typing import Optional

class StringEdit1:
    '''Levenshtein (edit) distance between strings, via naive recursion

    The distance is computed directly from the recursive definition, so
    the running time grows exponentially with the string lengths. Each
    call records how many times every (source prefix, target prefix) pair
    was evaluated; see ``call_counter``.

    Parameters
    ----------
    insertion_cost
        Cost of inserting one element of the target.
    deletion_cost
        Cost of deleting one element of the source.
    substitution_cost
        Cost of aligning two different elements; defaults to
        ``insertion_cost + deletion_cost``.
    '''

    def __init__(self, insertion_cost: float = 1.,
                 deletion_cost: float = 1.,
                 substitution_cost: Optional[float] = None):
        self._insertion_cost = insertion_cost
        self._deletion_cost = deletion_cost

        if substitution_cost is None:
            self._substitution_cost = insertion_cost + deletion_cost
        else:
            self._substitution_cost = substitution_cost

        # created here so that ``call_counter`` can be read before the first
        # call instead of raising AttributeError
        self._call_counter = defaultdict(int)

    def __call__(self, source: str, target: str) -> float:
        '''Return the edit distance from ``source`` to ``target``.'''
        self._call_counter = defaultdict(int)
        # the base cases multiply a cost by an int length, so with int costs
        # they would return an int; coerce to honour the float annotation
        return float(self._naive_levenshtein(source, target))

    def _naive_levenshtein(self, source, target):
        '''Recursively compute the distance between two prefixes.'''
        self._call_counter[(source, target)] += 1

        # base cases: with one side empty, only insertions/deletions remain
        if not source:
            return self._insertion_cost * len(target)

        if not target:
            return self._deletion_cost * len(source)

        # aligning identical final elements is free
        if source[-1] == target[-1]:
            sub_cost = 0.
        else:
            sub_cost = self._substitution_cost

        # minimum of: delete the last source element, insert the last target
        # element, or align the two last elements (match / substitution)
        return min(self._naive_levenshtein(source[:-1], target) + self._deletion_cost,
                   self._naive_levenshtein(source, target[:-1]) + self._insertion_cost,
                   self._naive_levenshtein(source[:-1], target[:-1]) + sub_cost)

    @property
    def call_counter(self):
        '''Number of evaluations of each (source, target) pair in the last call.'''
        return self._call_counter
editdist = StringEdit1(1, 1)

editdist('æbstɹækt', 'æbstɹækt'), editdist('æbstɹækt', 'æbstɹækʃən'), editdist('æbstɹækʃən', 'æbstɹækt'), editdist('æbstɹækt', '')
(0.0, 4.0, 4.0, 8)

Okay. So here’s the thing. This looks nice, but it’s not efficient, because we’re redoing a whole ton of work: the same pairs of substrings get evaluated over and over, as the call counts below show.

editdist('æbstɹækʃən', 'æbstɹækt')

editdist.call_counter
defaultdict(int,
            {('æbstɹækʃən', 'æbstɹækt'): 1,
             ('æbstɹækʃə', 'æbstɹækt'): 1,
             ('æbstɹækʃ', 'æbstɹækt'): 1,
             ('æbstɹæk', 'æbstɹækt'): 1,
             ('æbstɹæ', 'æbstɹækt'): 1,
             ('æbstɹ', 'æbstɹækt'): 1,
             ('æbst', 'æbstɹækt'): 1,
             ('æbs', 'æbstɹækt'): 1,
             ('æb', 'æbstɹækt'): 1,
             ('æ', 'æbstɹækt'): 1,
             ('', 'æbstɹækt'): 1,
             ('æ', 'æbstɹæk'): 19,
             ('', 'æbstɹæk'): 20,
             ('æ', 'æbstɹæ'): 181,
             ('', 'æbstɹæ'): 200,
             ('æ', 'æbstɹ'): 1159,
             ('', 'æbstɹ'): 1340,
             ('æ', 'æbst'): 5641,
             ('', 'æbst'): 6800,
             ('æ', 'æbs'): 22363,
             ('', 'æbs'): 28004,
             ('æ', 'æb'): 75517,
             ('', 'æb'): 97880,
             ('æ', 'æ'): 224143,
             ('', 'æ'): 299660,
             ('æ', ''): 332688,
             ('', ''): 224143,
             ('æb', 'æbstɹæk'): 17,
             ('æb', 'æbstɹæ'): 145,
             ('æb', 'æbstɹ'): 833,
             ('æb', 'æbst'): 3649,
             ('æb', 'æbs'): 13073,
             ('æb', 'æb'): 40081,
             ('æb', 'æ'): 108545,
             ('æb', ''): 157184,
             ('æbs', 'æbstɹæk'): 15,
             ('æbs', 'æbstɹæ'): 113,
             ('æbs', 'æbstɹ'): 575,
             ('æbs', 'æbst'): 2241,
             ('æbs', 'æbs'): 7183,
             ('æbs', 'æb'): 19825,
             ('æbs', 'æ'): 48639,
             ('æbs', ''): 68464,
             ('æbst', 'æbstɹæk'): 13,
             ('æbst', 'æbstɹæ'): 85,
             ('æbst', 'æbstɹ'): 377,
             ('æbst', 'æbst'): 1289,
             ('æbst', 'æbs'): 3653,
             ('æbst', 'æb'): 8989,
             ('æbst', 'æ'): 19825,
             ('æbst', ''): 27008,
             ('æbstɹ', 'æbstɹæk'): 11,
             ('æbstɹ', 'æbstɹæ'): 61,
             ('æbstɹ', 'æbstɹ'): 231,
             ('æbstɹ', 'æbst'): 681,
             ('æbstɹ', 'æbs'): 1683,
             ('æbstɹ', 'æb'): 3653,
             ('æbstɹ', 'æ'): 7183,
             ('æbstɹ', ''): 9424,
             ('æbstɹæ', 'æbstɹæk'): 9,
             ('æbstɹæ', 'æbstɹæ'): 41,
             ('æbstɹæ', 'æbstɹ'): 129,
             ('æbstɹæ', 'æbst'): 321,
             ('æbstɹæ', 'æbs'): 681,
             ('æbstɹæ', 'æb'): 1289,
             ('æbstɹæ', 'æ'): 2241,
             ('æbstɹæ', ''): 2816,
             ('æbstɹæk', 'æbstɹæk'): 7,
             ('æbstɹæk', 'æbstɹæ'): 25,
             ('æbstɹæk', 'æbstɹ'): 63,
             ('æbstɹæk', 'æbst'): 129,
             ('æbstɹæk', 'æbs'): 231,
             ('æbstɹæk', 'æb'): 377,
             ('æbstɹæk', 'æ'): 575,
             ('æbstɹæk', ''): 688,
             ('æbstɹækʃ', 'æbstɹæk'): 5,
             ('æbstɹækʃ', 'æbstɹæ'): 13,
             ('æbstɹækʃ', 'æbstɹ'): 25,
             ('æbstɹækʃ', 'æbst'): 41,
             ('æbstɹækʃ', 'æbs'): 61,
             ('æbstɹækʃ', 'æb'): 85,
             ('æbstɹækʃ', 'æ'): 113,
             ('æbstɹækʃ', ''): 128,
             ('æbstɹækʃə', 'æbstɹæk'): 3,
             ('æbstɹækʃə', 'æbstɹæ'): 5,
             ('æbstɹækʃə', 'æbstɹ'): 7,
             ('æbstɹækʃə', 'æbst'): 9,
             ('æbstɹækʃə', 'æbs'): 11,
             ('æbstɹækʃə', 'æb'): 13,
             ('æbstɹækʃə', 'æ'): 15,
             ('æbstɹækʃə', ''): 16,
             ('æbstɹækʃən', 'æbstɹæk'): 1,
             ('æbstɹækʃən', 'æbstɹæ'): 1,
             ('æbstɹækʃən', 'æbstɹ'): 1,
             ('æbstɹækʃən', 'æbst'): 1,
             ('æbstɹækʃən', 'æbs'): 1,
             ('æbstɹækʃən', 'æb'): 1,
             ('æbstɹækʃən', 'æ'): 1,
             ('æbstɹækʃən', ''): 1})

We could try to get around this by memoizing using the lru_cache decorator.

from functools import lru_cache

class StringEdit2(StringEdit1):
    '''Distance between strings, via memoized recursion

    Uses the same recursion as ``StringEdit1``, but ``_naive_levenshtein``
    is wrapped in ``functools.lru_cache`` so that, within a single call,
    each (source prefix, target prefix) pair is evaluated only once.

    Parameters
    ----------
    insertion_cost
    deletion_cost
    substitution_cost
    '''

    def __call__(self, source: str, target: str) -> float:
        '''Return the edit distance from ``source`` to ``target``.'''
        # lru_cache on a method builds ONE cache for the whole class, keyed on
        # (self, source, target). Left alone it would keep instances alive and
        # would answer a repeated call from the cache without running the
        # method body, leaving call_counter empty or incomplete. Clearing it
        # around every call scopes the memo to exactly one computation.
        cache = StringEdit2._naive_levenshtein
        cache.cache_clear()
        try:
            return super().__call__(source, target)
        finally:
            cache.cache_clear()

    # maxsize=None: one call reaches at most (len(source)+1)*(len(target)+1)
    # prefix pairs, and a fixed bound such as 256 would evict some of them for
    # longer strings and reintroduce repeated work. The cache is cleared after
    # every call, so it does not grow across calls.
    @lru_cache(maxsize=None)
    def _naive_levenshtein(self, source, target):
        '''Recursively compute the distance between two prefixes (memoized).'''
        self._call_counter[(source, target)] += 1

        # base cases: with one side empty, only insertions/deletions remain
        if not source:
            return self._insertion_cost * len(target)

        if not target:
            return self._deletion_cost * len(source)

        # aligning identical final elements is free
        if source[-1] == target[-1]:
            sub_cost = 0.
        else:
            sub_cost = self._substitution_cost

        # minimum of: delete the last source element, insert the last target
        # element, or align the two last elements (match / substitution)
        return min(self._naive_levenshtein(source[:-1], target) + self._deletion_cost,
                   self._naive_levenshtein(source, target[:-1]) + self._insertion_cost,
                   self._naive_levenshtein(source[:-1], target[:-1]) + sub_cost)
%%timeit

editdist = StringEdit1(1, 1)

editdist('æbstɹækt', 'æbstɹækt'), editdist('æbstɹækt', 'æbstɹækʃən'), editdist('æbstɹækʃən', 'æbstɹækt'), editdist('æbstɹækt', '')
2.19 s ± 8.49 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%%timeit

editdist = StringEdit2(1, 1)

editdist('æbstɹækt', 'æbstɹækt'), editdist('æbstɹækt', 'æbstɹækʃən'), editdist('æbstɹækʃən', 'æbstɹækt'), editdist('æbstɹækt', '')
178 µs ± 892 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
editdist = StringEdit2(1, 1)

editdist('æbstɹækʃən', 'æbstɹækt')

editdist.call_counter
defaultdict(int,
            {('æbstɹækʃən', 'æbstɹækt'): 1,
             ('æbstɹækʃə', 'æbstɹækt'): 1,
             ('æbstɹækʃ', 'æbstɹækt'): 1,
             ('æbstɹæk', 'æbstɹækt'): 1,
             ('æbstɹæ', 'æbstɹækt'): 1,
             ('æbstɹ', 'æbstɹækt'): 1,
             ('æbst', 'æbstɹækt'): 1,
             ('æbs', 'æbstɹækt'): 1,
             ('æb', 'æbstɹækt'): 1,
             ('æ', 'æbstɹækt'): 1,
             ('', 'æbstɹækt'): 1,
             ('æ', 'æbstɹæk'): 1,
             ('', 'æbstɹæk'): 1,
             ('æ', 'æbstɹæ'): 1,
             ('', 'æbstɹæ'): 1,
             ('æ', 'æbstɹ'): 1,
             ('', 'æbstɹ'): 1,
             ('æ', 'æbst'): 1,
             ('', 'æbst'): 1,
             ('æ', 'æbs'): 1,
             ('', 'æbs'): 1,
             ('æ', 'æb'): 1,
             ('', 'æb'): 1,
             ('æ', 'æ'): 1,
             ('', 'æ'): 1,
             ('æ', ''): 1,
             ('', ''): 1,
             ('æb', 'æbstɹæk'): 1,
             ('æb', 'æbstɹæ'): 1,
             ('æb', 'æbstɹ'): 1,
             ('æb', 'æbst'): 1,
             ('æb', 'æbs'): 1,
             ('æb', 'æb'): 1,
             ('æb', 'æ'): 1,
             ('æb', ''): 1,
             ('æbs', 'æbstɹæk'): 1,
             ('æbs', 'æbstɹæ'): 1,
             ('æbs', 'æbstɹ'): 1,
             ('æbs', 'æbst'): 1,
             ('æbs', 'æbs'): 1,
             ('æbs', 'æb'): 1,
             ('æbs', 'æ'): 1,
             ('æbs', ''): 1,
             ('æbst', 'æbstɹæk'): 1,
             ('æbst', 'æbstɹæ'): 1,
             ('æbst', 'æbstɹ'): 1,
             ('æbst', 'æbst'): 1,
             ('æbst', 'æbs'): 1,
             ('æbst', 'æb'): 1,
             ('æbst', 'æ'): 1,
             ('æbst', ''): 1,
             ('æbstɹ', 'æbstɹæk'): 1,
             ('æbstɹ', 'æbstɹæ'): 1,
             ('æbstɹ', 'æbstɹ'): 1,
             ('æbstɹ', 'æbst'): 1,
             ('æbstɹ', 'æbs'): 1,
             ('æbstɹ', 'æb'): 1,
             ('æbstɹ', 'æ'): 1,
             ('æbstɹ', ''): 1,
             ('æbstɹæ', 'æbstɹæk'): 1,
             ('æbstɹæ', 'æbstɹæ'): 1,
             ('æbstɹæ', 'æbstɹ'): 1,
             ('æbstɹæ', 'æbst'): 1,
             ('æbstɹæ', 'æbs'): 1,
             ('æbstɹæ', 'æb'): 1,
             ('æbstɹæ', 'æ'): 1,
             ('æbstɹæ', ''): 1,
             ('æbstɹæk', 'æbstɹæk'): 1,
             ('æbstɹæk', 'æbstɹæ'): 1,
             ('æbstɹæk', 'æbstɹ'): 1,
             ('æbstɹæk', 'æbst'): 1,
             ('æbstɹæk', 'æbs'): 1,
             ('æbstɹæk', 'æb'): 1,
             ('æbstɹæk', 'æ'): 1,
             ('æbstɹæk', ''): 1,
             ('æbstɹækʃ', 'æbstɹæk'): 1,
             ('æbstɹækʃ', 'æbstɹæ'): 1,
             ('æbstɹækʃ', 'æbstɹ'): 1,
             ('æbstɹækʃ', 'æbst'): 1,
             ('æbstɹækʃ', 'æbs'): 1,
             ('æbstɹækʃ', 'æb'): 1,
             ('æbstɹækʃ', 'æ'): 1,
             ('æbstɹækʃ', ''): 1,
             ('æbstɹækʃə', 'æbstɹæk'): 1,
             ('æbstɹækʃə', 'æbstɹæ'): 1,
             ('æbstɹækʃə', 'æbstɹ'): 1,
             ('æbstɹækʃə', 'æbst'): 1,
             ('æbstɹækʃə', 'æbs'): 1,
             ('æbstɹækʃə', 'æb'): 1,
             ('æbstɹækʃə', 'æ'): 1,
             ('æbstɹækʃə', ''): 1,
             ('æbstɹækʃən', 'æbstɹæk'): 1,
             ('æbstɹækʃən', 'æbstɹæ'): 1,
             ('æbstɹækʃən', 'æbstɹ'): 1,
             ('æbstɹækʃən', 'æbst'): 1,
             ('æbstɹækʃən', 'æbs'): 1,
             ('æbstɹækʃən', 'æb'): 1,
             ('æbstɹækʃən', 'æ'): 1,
             ('æbstɹækʃən', ''): 1})

That helps a lot. Why? Because it only ever computes the distance for a given pair of substrings once. This is effectively what the Wagner–Fischer algorithm that you read about is doing. This is our first instance of a dynamic programming algorithm. The basic idea for Wagner–Fischer (and other algorithms we’ll use later in the class) is to store the memoized values for a function in a chart whose rows correspond to positions in the source string and whose columns correspond to positions in the target string.

import numpy as np

class StringEdit3(StringEdit2):
    '''Distance between sequences, via the Wagner–Fischer algorithm

    Fills an (n+1) x (m+1) chart bottom-up, where cell (i, j) holds the
    distance between the first i source elements and the first j target
    elements. Any indexable sequences (str, list, tuple) are accepted.
    This class does not update ``call_counter``.

    Parameters
    ----------
    insertion_cost
    deletion_cost
    substitution_cost
    '''

    def __call__(self, source: str, target: str) -> float:
        return self._wagner_fisher(source, target)

    def _wagner_fisher(self, source: str, target: str):
        '''Fill the dynamic-programming chart and return its final cell.'''
        n, m = len(source), len(target)

        distance = np.zeros([n+1, m+1], dtype=float)

        # first column / first row: distance from or to an empty prefix
        for i in range(1, n+1):
            distance[i, 0] = distance[i-1, 0] + self._deletion_cost

        for j in range(1, m+1):
            distance[0, j] = distance[0, j-1] + self._insertion_cost

        for i in range(1, n+1):
            for j in range(1, m+1):
                # chart row i / column j correspond to element i-1 / j-1
                if source[i-1] == target[j-1]:
                    substitution_cost = 0.
                else:
                    substitution_cost = self._substitution_cost

                # builtin min over three scalars avoids allocating a numpy
                # array for every cell of the chart
                distance[i, j] = min(distance[i-1, j] + self._deletion_cost,
                                     distance[i-1, j-1] + substitution_cost,
                                     distance[i, j-1] + self._insertion_cost)

        return float(distance[n, m])
editdist = StringEdit3(1, 1)

editdist('æbstɹækt', 'æbstɹækʃən')
4.0

So why use Wagner–Fischer when we can just use memoization on the naive algorithm? The reason is that the chart used in Wagner–Fischer allows us to very easily store information about the implicit alignment between string elements. This notion of alignment is the same as the one we saw above when talking about how best to match up a square to the face of a cube when discussing boolean vectors.

So what do we need to add to our previous implementation of Wagner–Fischer to store backtraces? Importantly, note that you will need to return a list of backtraces because there could be multiple equally good ones. (This point will come up for all of the dynamic programming algorithms we look at and, as we’ll see, is actually abstractly related to syntactic ambiguity.)

from typing import Tuple, List

class StringEdit4(StringEdit3):
    '''distance, alignment, and edit paths between strings

    Extends the Wagner–Fischer chart with back-pointers so that every
    minimum-cost path through the chart can be recovered. Inputs may be
    ``str``, ``list``, or ``tuple``.

    Parameters
    ----------
    insertion_cost : float
    deletion_cost : float
    substitution_cost : float | NoneType (default: None)
        Defaults to ``insertion_cost + deletion_cost``.
    '''

    def __call__(self, source: str,
                 target: str) -> Tuple[float, List[List[Tuple[int, int]]]]:
        '''Return the edit distance and every minimum-cost alignment.

        Each alignment lists (source index, target index) pairs, one per
        chart cell visited after the origin. An index of -1 means that no
        element of that string has been consumed yet (the chart's first
        row/column); it does not mean "last element".
        '''
        return self._wagner_fisher(source, target)

    def _wagner_fisher(self, source, target):
        '''compute minimum edit distance and alignment'''
        n, m = len(source), len(target)

        # sentinel at position 0 so chart index i addresses element i
        source, target = self._add_sentinel(source, target)

        distance = np.zeros([n+1, m+1], dtype=float)
        # object array: each cell holds the list of its best predecessor cells
        pointers = np.zeros([n+1, m+1], dtype=object)

        pointers[0, 0] = []

        for i in range(1, n+1):
            distance[i, 0] = distance[i-1, 0] + self._deletion_cost
            pointers[i, 0] = [(i-1, 0)]

        for j in range(1, m+1):
            distance[0, j] = distance[0, j-1] + self._insertion_cost
            pointers[0, j] = [(0, j-1)]

        for i in range(1, n+1):
            for j in range(1, m+1):
                if source[i] == target[j]:
                    substitution_cost = 0.
                else:
                    substitution_cost = self._substitution_cost

                # candidate predecessors in a fixed order (deletion,
                # match/substitution, insertion); this order determines the
                # order in which tied alignments are reported
                candidates = [((i-1, j), distance[i-1, j] + self._deletion_cost),
                              ((i-1, j-1), distance[i-1, j-1] + substitution_cost),
                              ((i, j-1), distance[i, j-1] + self._insertion_cost)]

                best = min(cost for _, cost in candidates)
                distance[i, j] = best

                # keep every predecessor achieving the minimum so that all
                # equally good alignments can be recovered
                pointers[i, j] = [cell for cell, cost in candidates
                                  if cost == best]

        pointer_backtrace = self._construct_backtrace(pointers,
                                                      idx=(n, m))

        # shift chart coordinates to 0-based positions in the original
        # (sentinel-free) inputs; chart row/column 0 becomes -1
        return float(distance[n, m]), [[(i-1, j-1) for i, j in bt]
                                       for bt in pointer_backtrace]

    def _construct_backtrace(self, pointers, idx):
        '''Return every path of chart cells from (0, 0) to ``idx``.

        Each path lists the cells after the origin in order, ending with
        ``idx``. Partial paths are memoized per cell so shared sub-paths are
        built once; the number of returned paths can still be large when
        there are many ties.
        '''
        memo = {}

        def paths_to(cell):
            if cell == (0, 0):
                return [[]]
            if cell not in memo:
                # path + [cell] builds a new list, so memoized lists are
                # never mutated by later extensions
                memo[cell] = [path + [cell]
                              for prev in pointers[cell]
                              for path in paths_to(prev)]
            return memo[cell]

        return paths_to(idx)

    def _add_sentinel(self, source, target):
        '''Prefix both inputs with a '#' sentinel element.

        Raises ValueError if either input is not a str, list, or tuple.
        '''
        return (self._prefix_sentinel(source, 'source'),
                self._prefix_sentinel(target, 'target'))

    @staticmethod
    def _prefix_sentinel(seq, name):
        '''Return ``seq`` with '#' prepended, preserving its type.'''
        if isinstance(seq, str):
            return '#' + seq
        if isinstance(seq, list):
            return ['#'] + seq
        if isinstance(seq, tuple):
            return ('#',) + seq
        raise ValueError(f'{name} must be str, list, or tuple')
editdist = StringEdit4(1, 1)

editdist('æbstɹækʃən', 'æbstɹækt')
(4.0,
 [[(0, 0),
   (1, 1),
   (2, 2),
   (3, 3),
   (4, 4),
   (5, 5),
   (6, 6),
   (6, 7),
   (7, 7),
   (8, 7),
   (9, 7)],
  [(0, 0),
   (1, 1),
   (2, 2),
   (3, 3),
   (4, 4),
   (5, 5),
   (6, 6),
   (7, 7),
   (8, 7),
   (9, 7)],
  [(0, 0),
   (1, 1),
   (2, 2),
   (3, 3),
   (4, 4),
   (5, 5),
   (6, 6),
   (7, 6),
   (7, 7),
   (8, 7),
   (9, 7)],
  [(0, 0),
   (1, 1),
   (2, 2),
   (3, 3),
   (4, 4),
   (5, 5),
   (6, 6),
   (7, 6),
   (8, 7),
   (9, 7)],
  [(0, 0),
   (1, 1),
   (2, 2),
   (3, 3),
   (4, 4),
   (5, 5),
   (6, 6),
   (7, 6),
   (8, 6),
   (8, 7),
   (9, 7)],
  [(0, 0),
   (1, 1),
   (2, 2),
   (3, 3),
   (4, 4),
   (5, 5),
   (6, 6),
   (7, 6),
   (8, 6),
   (9, 7)],
  [(0, 0),
   (1, 1),
   (2, 2),
   (3, 3),
   (4, 4),
   (5, 5),
   (6, 6),
   (7, 6),
   (8, 6),
   (9, 6),
   (9, 7)]])

This isn’t particularly interpretable, so we can postprocess the output slightly to better see what’s going on.

class StringEdit5(StringEdit4):
    '''distance, alignment, and edit paths between strings

    Like ``StringEdit4``, but each alignment is reported as a list of
    (source element, target element) pairs instead of index pairs.

    Parameters
    ----------
    insertion_cost : float
    deletion_cost : float
    substitution_cost : float | NoneType (default: None)
    '''

    def __call__(self, source: str,
                 target: str) -> Tuple[float, List[List[Tuple[str, str]]]]:
        '''Return the edit distance and every minimum-cost alignment.

        A step that is still on the chart's first row/column (no element of
        that side consumed yet) is shown with the '#' sentinel.
        '''
        distance, alignment = self._wagner_fisher(source, target)

        # _wagner_fisher reports 0-based positions, using -1 for "nothing
        # consumed yet". Indexing the raw inputs with -1 would wrap around to
        # the last element (or raise IndexError on an empty input), so index
        # the sentinel-prefixed inputs at position + 1 instead: -1 -> '#'.
        padded_source, padded_target = self._add_sentinel(source, target)

        return distance, [[(padded_source[i + 1], padded_target[j + 1])
                           for i, j in path]
                          for path in alignment]
editdist = StringEdit5(1, 1)

editdist('æbstɹækʃən', 'æbstɹækt')
(4.0,
 [[('æ', 'æ'),
   ('b', 'b'),
   ('s', 's'),
   ('t', 't'),
   ('ɹ', 'ɹ'),
   ('æ', 'æ'),
   ('k', 'k'),
   ('k', 't'),
   ('ʃ', 't'),
   ('ə', 't'),
   ('n', 't')],
  [('æ', 'æ'),
   ('b', 'b'),
   ('s', 's'),
   ('t', 't'),
   ('ɹ', 'ɹ'),
   ('æ', 'æ'),
   ('k', 'k'),
   ('ʃ', 't'),
   ('ə', 't'),
   ('n', 't')],
  [('æ', 'æ'),
   ('b', 'b'),
   ('s', 's'),
   ('t', 't'),
   ('ɹ', 'ɹ'),
   ('æ', 'æ'),
   ('k', 'k'),
   ('ʃ', 'k'),
   ('ʃ', 't'),
   ('ə', 't'),
   ('n', 't')],
  [('æ', 'æ'),
   ('b', 'b'),
   ('s', 's'),
   ('t', 't'),
   ('ɹ', 'ɹ'),
   ('æ', 'æ'),
   ('k', 'k'),
   ('ʃ', 'k'),
   ('ə', 't'),
   ('n', 't')],
  [('æ', 'æ'),
   ('b', 'b'),
   ('s', 's'),
   ('t', 't'),
   ('ɹ', 'ɹ'),
   ('æ', 'æ'),
   ('k', 'k'),
   ('ʃ', 'k'),
   ('ə', 'k'),
   ('ə', 't'),
   ('n', 't')],
  [('æ', 'æ'),
   ('b', 'b'),
   ('s', 's'),
   ('t', 't'),
   ('ɹ', 'ɹ'),
   ('æ', 'æ'),
   ('k', 'k'),
   ('ʃ', 'k'),
   ('ə', 'k'),
   ('n', 't')],
  [('æ', 'æ'),
   ('b', 'b'),
   ('s', 's'),
   ('t', 't'),
   ('ɹ', 'ɹ'),
   ('æ', 'æ'),
   ('k', 'k'),
   ('ʃ', 'k'),
   ('ə', 'k'),
   ('n', 'k'),
   ('n', 't')]])