Probabilistic Context-Free Grammars for Morphological Analysis

Introduction to Probabilistic Context-Free Grammars

So far, we’ve explored context-free grammars (CFGs) as powerful tools for modeling hierarchical structures in language. In this section, we’ll extend our framework to probabilistic context-free grammars (PCFGs), which augment CFGs with probabilities associated with each rule. This probabilistic extension allows us to:

  1. Model the likelihood of different grammatical structures
  2. Disambiguate between multiple possible analyses
  3. Learn grammars from data in an unsupervised or semi-supervised manner
  4. Quantify the well-formedness of novel word formations

PCFGs are particularly useful for morphological analysis, where we often want to determine not just whether a word can be derived according to a grammar, but also how likely that derivation is compared to alternatives.

Formal Definition of PCFGs

A probabilistic context-free grammar is a 5-tuple \(G = (V, \Sigma, R, S, P)\) where:

  1. \(V\) is a finite set of variables (non-terminals)
  2. \(\Sigma\) is a finite set of terminals (the alphabet)
  3. \(R\) is a finite set of rules of the form \(A \rightarrow \alpha\) where \(A \in V\) and \(\alpha \in (V \cup \Sigma)^*\)
  4. \(S \in V\) is the start variable
  5. \(P\) is a probability function that assigns a probability \(P(A \rightarrow \alpha)\) to each rule, such that for each \(A \in V\): \[\sum_{\alpha : (A \rightarrow \alpha) \in R} P(A \rightarrow \alpha) = 1\]

In other words, the probabilities of all rules with the same left-hand side must sum to 1, forming a conditional probability distribution \(P(\alpha | A)\).
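
As a concrete check of this normalization constraint, the sketch below groups a toy rule set by left-hand side and verifies that each group sums to 1. The rule probabilities and the helper name `check_normalized` are illustrative assumptions, not part of the grammar classes developed in this section.

```python
from collections import defaultdict
import math

# Hypothetical toy rules: (left-hand side, right-hand side) -> probability
toy_rules = {
    ("N", ("Adj", "Suffix")): 1.0,
    ("Adj", ("Prefix", "Adj")): 0.4,
    ("Adj", ("happy",)): 0.6,
    ("Prefix", ("un",)): 1.0,
    ("Suffix", ("ness",)): 1.0,
}

def check_normalized(rules, tol=1e-9):
    """Return {lhs: total probability}; raise if any total differs from 1."""
    totals = defaultdict(float)
    for (lhs, _rhs), prob in rules.items():
        totals[lhs] += prob
    for lhs, total in totals.items():
        if not math.isclose(total, 1.0, abs_tol=tol):
            raise ValueError(f"rules for {lhs} sum to {total}, not 1")
    return dict(totals)

totals = check_normalized(toy_rules)
print(totals)
```

Here the two rules for `Adj` share one unit of probability mass, while every other non-terminal has a single rule with probability 1.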

The Probabilistic Rule Class

Let’s extend our existing Rule class to incorporate probabilities. Because a derivation multiplies many rule probabilities together, the product can underflow in floating point, so we’ll store log probabilities rather than raw probabilities:

class ProbabilisticRule(Rule):
    """A probabilistic context-free grammar rule

    Parameters
    ----------
    left_side : str
        The left-hand side variable
    right_side : *str
        The right-hand side sequence of variables and terminals
    logprob : float, optional
        The log probability of the rule, by default None

    Attributes
    ----------
    left_side : str
    right_side : tuple(str)
    logprob : float
    """

    def __init__(self, left_side, *right_side, logprob=None):
        super().__init__(left_side, *right_side)
        self._logprob = logprob

    def __repr__(self) -> str:
        rule = self._left_side + ' -> ' + ' '.join(self._right_side)
        
        if self._logprob is None:
            return rule
        else:
            return rule + ' [' + f"{self._logprob:.4f}" + ']'
    
    @property
    def logprob(self) -> float:
        return self._logprob
    
    @logprob.setter
    def logprob(self, value: float) -> None:
        self._logprob = value
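
To see why log probabilities are the safer representation, the following sketch compares a raw product with a sum of logs. The numbers are made up for illustration: 400 rule applications at probability 0.1 each.

```python
import math

# A hypothetical derivation that uses 400 rules, each with probability 0.1
rule_probs = [0.1] * 400

# Multiplying raw probabilities underflows to 0.0 in double precision
raw_product = 1.0
for p in rule_probs:
    raw_product *= p

# Summing log probabilities stays finite and keeps the information
log_total = sum(math.log(p) for p in rule_probs)

print(raw_product)
print(log_total)
```

The raw product is indistinguishable from the probability of an impossible derivation, whereas the log-space total remains an ordinary finite number.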

The Probabilistic Context-Free Grammar Class

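Now, let’s define our ProbabilisticContextFreeGrammar class, a probabilistic counterpart to the basic ContextFreeGrammar. As written, it is a standalone class rather than a subclass, and it assumes the grammar is in Chomsky Normal Form: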

import numpy as np
from typing import Dict, List, Set, Tuple, Optional, Union
from scipy.special import logsumexp

class ProbabilisticContextFreeGrammar:
    """
    A probabilistic context-free grammar in Chomsky Normal Form

    Parameters
    ----------
    alphabet : set(str)
        The set of terminal symbols
    variables : set(str)
        The set of non-terminal symbols
    rules : set(ProbabilisticRule)
        The set of production rules with associated probabilities
    start_variable : str
        The start symbol of the grammar

    Attributes
    ----------
    alphabet : set(str)
    variables : set(str)
    rules : set(ProbabilisticRule)
    start_variable : str
    """
    
    def __init__(self, alphabet, variables, rules, start_variable):
        self._alphabet = alphabet
        self._variables = variables
        self._rules = rules
        self._start_variable = start_variable
        
        self._validate_variables()
        self._validate_rules()

        self._parser = CKYParser(self)

    def __call__(self, string, mode="recognize"):
        return self._parser(string, mode)
        
    def _validate_variables(self):
        # Use explicit checks rather than assert, which is skipped under `python -O`
        if self._alphabet & self._variables:
            raise ValueError('alphabet and variables must not share elements')

        if self._start_variable not in self._variables:
            raise ValueError('start variable must be in set of variables')

    def _validate_rules(self):
        # Validate that rules are well-formed
        for r in self._rules:
            r.validate(self._alphabet, self._variables)
        
        # Check that the probabilities form valid distributions
        log_probs = {}
        for r in self._rules:
            # A rule without a log probability cannot be part of a distribution
            if r.logprob is None:
                raise ValueError(f'Rule {r} has no log probability')
            log_probs.setdefault(r.left_side, []).append(r.logprob)
        
        # For each non-terminal, check that its rules sum to 1 (within floating-point tolerance)
        for nt, lps in log_probs.items():
            sum_prob = np.exp(logsumexp(lps))
            if not np.isclose(sum_prob, 1.0, rtol=1e-5):
                raise ValueError(f'Rules for {nt} sum to {sum_prob}, not 1.0')

    @property            
    def alphabet(self) -> Set[str]:
        return self._alphabet

    @property    
    def variables(self) -> Set[str]:
        return self._variables

    @property    
    def rules(self) -> Set[ProbabilisticRule]:
        return self._rules

    @property
    def start_variable(self) -> str:
        return self._start_variable
        
    def reduce(self, *right_side) -> Set[Tuple[str, float]]:
        """
        Find all non-terminals that can be rewritten as right_side, with their log probabilities

        Parameters
        ----------
        right_side : tuple(str)
            The right-hand side of a rule

        Returns
        -------
        set(tuple(str, float))
            Set of (non-terminal, log probability) pairs
        """
        return {(r.left_side, r.logprob) for r in self._rules
                if r.right_side == tuple(right_side)}

The Inside Algorithm for Computing String Probabilities

In a PCFG, each parse tree has a probability, which is the product of the probabilities of all rules used in the derivation. The probability of a string in the language is the sum of the probabilities of all possible parse trees for that string.
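
The two definitions above can be worked through by brute force on a tiny ambiguous grammar. The grammar here (S -> S S with probability 0.3, S -> a with probability 0.7) and the helper `tree_logprob` are assumptions chosen for illustration. The string "a a a" has exactly two parse trees, one left-branching and one right-branching.

```python
import math

# Hypothetical ambiguous CNF grammar: S -> S S [0.3], S -> a [0.7]
logp = {("S", ("S", "S")): math.log(0.3), ("S", ("a",)): math.log(0.7)}

# Each parse tree of "a a a" is listed as the rules it uses.
# Both trees use S -> S S twice and S -> a three times; only their shape differs.
left_branching = [("S", ("S", "S")), ("S", ("S", "S")),
                  ("S", ("a",)), ("S", ("a",)), ("S", ("a",))]
right_branching = list(left_branching)

def tree_logprob(rules_used):
    """Log probability of a tree: the sum of the log probabilities of its rules."""
    return sum(logp[r] for r in rules_used)

tree_logprobs = [tree_logprob(left_branching), tree_logprob(right_branching)]

# String probability: sum over trees, computed stably in log space
m = max(tree_logprobs)
string_logprob = m + math.log(sum(math.exp(lp - m) for lp in tree_logprobs))

print(math.exp(tree_logprobs[0]))  # probability of a single tree
print(math.exp(string_logprob))    # probability of the string
```

Each tree has probability 0.3² · 0.7³ ≈ 0.0309, so the string has probability of about 0.0617. Enumerating trees like this does not scale, since the number of trees can grow exponentially with string length.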

To compute this efficiently, we use the Inside algorithm, a dynamic programming algorithm that uses the same chart structure as the CKY algorithm. The Inside algorithm computes what are called “inside probabilities” \(\beta_{i,j}(A)\), which represent the probability that non-terminal \(A\) derives the substring \(w_i \ldots w_{j-1}\). Positions are 0-based, so \(\beta_{i,i+1}(A)\) covers the single symbol \(w_i\).

For a grammar in Chomsky Normal Form, the inside probabilities are defined recursively:

\[ \beta_{i,j}(A) = \begin{cases} P(A \rightarrow w_i) & \text{if } j = i + 1 \\ \sum_{k=i+1}^{j-1} \sum_{B,C \in V} P(A \rightarrow B C) \cdot \beta_{i,k}(B) \cdot \beta_{k,j}(C) & \text{if } j > i + 1 \end{cases} \]

The probability of the whole string \(w\) according to the grammar is then \(\beta_{0,n}(S)\), where \(n\) is the length of \(w\) and \(S\) is the start symbol.
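
The recursion above can be sketched compactly with plain dictionaries before we build it into the parser classes. This is a minimal illustration under stated assumptions: the function `inside_logprob`, the helper `logaddexp`, the dictionary-based grammar encoding, and the toy grammar are all hypothetical and separate from the classes in this section.

```python
import math
from collections import defaultdict

def logaddexp(a, b):
    """Stable log(exp(a) + exp(b)) that also handles -inf."""
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))

def inside_logprob(tokens, lexical, binary, start="S"):
    """Log probability of `tokens` under a CNF PCFG.

    lexical: {terminal: [(A, logprob), ...]} for rules A -> terminal
    binary:  {(B, C): [(A, logprob), ...]} for rules A -> B C
    """
    n = len(tokens)
    # beta[i, j][A] is the log inside probability of A over tokens[i:j]
    beta = defaultdict(lambda: defaultdict(lambda: -math.inf))
    # Base case: spans of length 1
    for i, tok in enumerate(tokens):
        for a, lp in lexical.get(tok, []):
            beta[i, i + 1][a] = logaddexp(beta[i, i + 1][a], lp)
    # Recursive case: sum over split points k and child pairs (B, C)
    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length
            for k in range(i + 1, j):
                for b, lp_b in list(beta[i, k].items()):
                    for c, lp_c in list(beta[k, j].items()):
                        for a, lp_rule in binary.get((b, c), []):
                            beta[i, j][a] = logaddexp(
                                beta[i, j][a], lp_rule + lp_b + lp_c)
    return beta[0, n][start]

# Hypothetical ambiguous grammar: S -> S S [0.3], S -> a [0.7]
lexical = {"a": [("S", math.log(0.7))]}
binary = {("S", "S"): [("S", math.log(0.3))]}

result = inside_logprob(["a", "a", "a"], lexical, binary)
print(math.exp(result))
```

For "a a a" the chart sums the two split points at the top cell, each contributing 0.3 · 0.7 · 0.147, which gives a string probability of about 0.0617. A string containing an unknown terminal receives a log probability of negative infinity.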

Implementation of the Inside Algorithm

Let’s implement the Inside algorithm as part of our CKY parser:

class CKYChart:
    """
    A chart for a CKY parser

    Parameters
    ----------
    input_size : int
        The length of the input string
    """

    def __init__(self, input_size):
        self._input_size = input_size
        
        # Initialize the chart
        self._chart = [[set()
                        for _ in range(input_size+1)]
                       for _ in range(input_size+1)]
        
    def __getitem__(self, index):
        i, j = index
        self._validate_index(i, j)
        return self._chart[i][j]

    def __setitem__(self, key, item):
        i, j = key
        self._chart[i][j] = item
        
    def _validate_index(self, i, j):
        # Explicit check rather than assert, which is skipped under `python -O`
        if not i < j:
            raise ValueError("cannot index into the lower triangle of the chart")

        try:
            self._chart[i]
        except IndexError:
            raise ValueError("row index is too large")

        try:
            self._chart[i][j]
        except IndexError:
            raise ValueError("column index is too large")

    @property
    def input_size(self) -> int:
        return self._input_size

class CKYChartEntry:
    """
    A chart entry for a CKY chart

    Parameters
    ----------
    symbol : str
        The non-terminal symbol
    logprob : float
        The log probability
    backpointers : tuple(CKYChartEntry), optional
        Pointers to child entries (for parsing)
    """

    def __init__(self, symbol, logprob, *backpointers):
        self._symbol = symbol
        self._logprob = logprob
        self._backpointers = backpointers

    def __hash__(self) -> int:
        return hash((self._symbol, self._backpointers))

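    def __eq__(self, other) -> bool:
        # Entries are compared by symbol and backpointers only; logprob is ignored
        if not isinstance(other, CKYChartEntry):
            return NotImplemented
        return (self._symbol == other._symbol and 
                self._backpointers == other._backpointers)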

    def __repr__(self) -> str:
        if self._backpointers:
            return f"{self._symbol}[{self._logprob:.4f}] -> {' '.join(bp.symbol for bp in self._backpointers)}"
        else:
            return f"{self._symbol}[{self._logprob:.4f}]"

    @property
    def symbol(self) -> str:
        return self._symbol

    @property
    def logprob(self) -> float:
        return self._logprob

    @property
    def backpointers(self) -> Tuple["CKYChartEntry", ...]:
        return self._backpointers

class CKYParser:
    """
    A CKY parser for probabilistic context-free grammars

    Parameters
    ----------
    grammar : ProbabilisticContextFreeGrammar
        The grammar to parse with
    """
    
    def __init__(self, grammar):
        self._grammar = grammar
    
    def __call__(self, string, mode="recognize"):
        if mode == "recognize":
            return self._recognize(string)
        elif mode == "parse":
            return self._parse(string)            
        else:
            raise ValueError('mode must be "parse" or "recognize"')
    
    def _fill_chart(self, string, keep_backtraces=True):
        """
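        Fill the CKY chart

        With keep_backtraces=True, every derivation of a span gets its own
        entry (used for parsing). With keep_backtraces=False, entries for the
        same non-terminal are merged by summing their probabilities, which is
        the inside algorithm.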
        
        Parameters
        ----------
        string : list(str)
            The input string
        keep_backtraces : bool, optional
            Whether to keep backtraces for parsing, by default True
            
        Returns
        -------
        CKYChart
            The filled chart
        """
        input_size = len(string)
        chart = CKYChart(input_size)
        
        # Base case: fill in chart[i,i+1] for each terminal
        for j in range(1, input_size+1):
            i = j - 1
            # Get all non-terminals that can derive this terminal
            for symbol, logprob in self._grammar.reduce(string[i]):
                chart[i, j].add(CKYChartEntry(symbol, logprob, CKYChartEntry(string[i], 0.0)))

        # Recursive case: fill in chart[i,j] for j-i > 1
        for length in range(2, input_size+1):
            for i in range(input_size - length + 1):
                j = i + length
                
                # For each possible split point k
                for k in range(i+1, j):
                    # For each pair of non-terminals B, C where B derives w[i:k] and C derives w[k:j]
                    for left_entry in chart[i, k]:
                        for right_entry in chart[k, j]:
                            # Find all non-terminals A such that A -> B C
                            for parent, rule_logprob in self._grammar.reduce(left_entry.symbol, right_entry.symbol):
                                # Compute the probability of this derivation
                                entry_logprob = rule_logprob + left_entry.logprob + right_entry.logprob
                                
                                if keep_backtraces:
                                    # Create a new chart entry with backpointers
                                    new_entry = CKYChartEntry(parent, entry_logprob, left_entry, right_entry)
                                    chart[i, j].add(new_entry)
                                else:
                                    # Update the aggregate probability for this non-terminal
                                    existing_entries = [e for e in chart[i, j] if e.symbol == parent]
                                    if existing_entries:
                                        # Combine probabilities using logsumexp for numerical stability
                                        old_logprob = existing_entries[0].logprob
                                        new_logprob = logsumexp([old_logprob, entry_logprob])
                                        chart[i, j].remove(existing_entries[0])
                                        chart[i, j].add(CKYChartEntry(parent, new_logprob))
                                    else:
                                        chart[i, j].add(CKYChartEntry(parent, entry_logprob))

        return chart

    def _parse(self, string):
        """
        Parse a string to find all possible parse trees
        
        Parameters
        ----------
        string : list(str) or str
            The string to parse
            
        Returns
        -------
        list(tuple(Tree, float))
            List of parse trees with their log probabilities
        """
        # Ensure input is a list of tokens
        if isinstance(string, str):
            string = string.split()

        # Only non-empty input is handled: chart[0, 0] is not a valid chart cell
        if not string:
            return []
        
        chart = self._fill_chart(string)
        
        # Construct parse trees from the chart
        parses = []
        for entry in chart[0, len(string)]:
            if entry.symbol == self._grammar.start_variable:
                parse_tree = self._build_tree(entry)
                parses.append((parse_tree, entry.logprob))
        
        return parses
    
    def _build_tree(self, entry):
        """
        Build a parse tree from a chart entry
        
        Parameters
        ----------
        entry : CKYChartEntry
            The chart entry
            
        Returns
        -------
        Tree
            The parse tree
        """
        if not entry.backpointers:
            # This is a terminal node
            return Tree(entry.symbol, [])
        else:
            # This is a non-terminal node
            children = [self._build_tree(bp) for bp in entry.backpointers]
            return Tree(entry.symbol, children)
        
    def _recognize(self, string):
        """
        Compute the log probability of a string under the grammar
        
        Parameters
        ----------
        string : list(str) or str
            The string to recognize
            
        Returns
        -------
        float
            The log probability of the string, or -inf if the string is not in the language
        """
        # Ensure input is a list of tokens
        if isinstance(string, str):
            string = string.split()

        # Only non-empty input is handled: chart[0, 0] is not a valid chart cell
        if not string:
            return float('-inf')
        
        chart = self._fill_chart(string, keep_backtraces=False)
        
        # Find the probability of the string by looking for the start symbol in the root of the chart
        logprobs = [entry.logprob for entry in chart[0, len(string)] 
                  if entry.symbol == self._grammar.start_variable]
        
        if logprobs:
            # Entries are already summed per non-terminal; logsumexp covers any remaining duplicates
            return float(logsumexp(logprobs))
        else:
            # String is not in the language
            return float('-inf')

Application to Morphological Analysis with CELEX

Now let’s apply our PCFG framework to morphological analysis using the CELEX database. CELEX provides morphological parses for thousands of English words, which we can use to estimate a PCFG for morphological analysis.

Example: Analyzing “unhappiness”

Let’s work through a detailed example of how to use our PCFG framework to analyze the morphological structure of a word like “unhappiness”:

# First, let's define a parser for CELEX-style bracketed morphological parses
import pyparsing

LPAR = pyparsing.Suppress('(')
RPAR = pyparsing.Suppress(')')
LBRAC = pyparsing.Suppress('[')
RBRAC = pyparsing.Suppress(']')

tag = pyparsing.Regex(r'[^\(\)\[\]]+')
morpheme = pyparsing.Regex(r'[a-z]+')

exp = pyparsing.Forward()
tree = pyparsing.Group(LPAR + pyparsing.delimitedList(exp) + RPAR + LBRAC + tag + RBRAC)
exp <<= morpheme | tree

class ParsedWord:
    """A morphological parse
    
    Parameters
    ----------
    data : str
        The root of the tree (always a category)
    word : str, optional
        The word (only exists at the root of the parse)
    children : list(ParsedWord), optional
        The child parses
    """
    
    PARSER = exp

    def __init__(self, data, word=None, children=None):
        self.word = word
        self.data = data
        self.children = children or []
    
    @classmethod
    def from_string(cls, word, tree):
        treelist = cls.PARSER.parseString(tree)[0]
        return cls.from_list(treelist, word)
    
    @classmethod
    def from_list(cls, treelist, word=None):
        if all(isinstance(x, str) for x in treelist):
            return cls(treelist[1], word, [cls(treelist[0])])
        else:
            return cls(treelist[-1], word,
                       [cls.from_list(l) for l in treelist[:-1]])

    def flatten(self):
        """The flattened version of the parse"""
        if len(self.children) > 1:
            return [morph for child in self.children for morph in child.flatten()]
        else:
            return [(self.data, self.children[0].data)]
            
    @property
    def rules(self):
        """Extract CFG rules from this parsed word"""
        if self.children:
            rule = Rule(self.data, *[c.data for c in self.children], logprob=None)
            child_rules = [r for c in self.children for r in c.rules]
            
            return [rule] + child_rules
                   
        else:
            return []

# Let's parse a specific example from CELEX
# (sibling subtrees are comma-separated, which is what delimitedList expects by default)
unhappiness_tree = "(((un)[Prefix],(happy)[Adj])[Adj],(ness)[Suffix])[N]"
unhappiness = ParsedWord.from_string('unhappiness', unhappiness_tree)

# Extract the rules from this parse
rules = unhappiness.rules
for rule in rules:
    print(rule)

# Now let's estimate a PCFG from a treebank of CELEX parses
# Assuming we have a list of parsed words:
parsed_words = [unhappiness]  # In reality, this would be many more examples

# Estimate the grammar
# (PCFGEstimator is assumed to be defined elsewhere; it is not shown in this section)
pcfg_estimator = PCFGEstimator()
pcfg = pcfg_estimator.fit(parsed_words)

# Let's compute the probability of "unhappiness"
morphemes = ['un', 'happy', 'ness']
log_prob = pcfg(morphemes)
print(f"Log probability of 'unhappiness': {log_prob}")

# We can also get the most likely parse
parses = pcfg(morphemes, mode="parse")
best_parse, best_logprob = max(parses, key=lambda x: x[1])
print(f"Best parse: {best_parse}")
print(f"Best parse log probability: {best_logprob}")
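
The `PCFGEstimator` used above is not defined in this section. One common way to estimate rule probabilities from a treebank is relative frequency: count each rule and divide by the total count of rules sharing its left-hand side. The sketch below illustrates that idea on plain `(lhs, rhs)` tuples rather than `ParsedWord` objects. The function name `estimate_rule_logprobs` is hypothetical, and this is not necessarily how `PCFGEstimator` works.

```python
import math
from collections import Counter

def estimate_rule_logprobs(rule_observations):
    """Relative-frequency estimate: log(count(A -> alpha) / count(A))."""
    rule_counts = Counter(rule_observations)
    lhs_counts = Counter()
    for (lhs, _rhs), count in rule_counts.items():
        lhs_counts[lhs] += count
    return {rule: math.log(count / lhs_counts[rule[0]])
            for rule, count in rule_counts.items()}

# Rules read off the single parse of "unhappiness"
observed = [
    ("N", ("Adj", "Suffix")),
    ("Adj", ("Prefix", "Adj")),
    ("Prefix", ("un",)),
    ("Adj", ("happy",)),
    ("Suffix", ("ness",)),
]

logprobs = estimate_rule_logprobs(observed)
for (lhs, rhs), lp in sorted(logprobs.items()):
    print(f"{lhs} -> {' '.join(rhs)} [{math.exp(lp):.2f}]")
```

With only this one training word, the two `Adj` rules each receive probability 0.5 and every other rule receives probability 1, so the single derivation of un + happy + ness has probability 0.5 · 0.5 = 0.25. A realistic estimate would need many more parses and, most likely, some form of smoothing for unseen morphemes.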

Evaluating Novel Morphological Forms

One of the key applications of PCFGs in morphology is evaluating the well-formedness of novel word formations. For example, we can compare the probabilities of attested and unattested forms:

# Attested forms
attested = [
    ['un', 'happy', 'ness'],
    ['re', 'consider', 'ation'],
    ['dis', 'agree', 'ment']
]

# Unattested but well-formed forms
well_formed_unattested = [
    ['un', 'belief', 'able', 'ness'],
    ['re', 'present', 'ative', 'ness'],
    ['anti', 'establish', 'ment', 'arian', 'ism']
]

# Ill-formed examples
ill_formed = [
    ['ness', 'un', 'happy'],  # Wrong order
    ['un', 'ness', 'happy'],  # Wrong order
    ['un', 'ly', 'happy']     # Suffix placed before the stem
]

# Compute log probabilities for each group
attested_probs = [pcfg(word) for word in attested]
well_formed_probs = [pcfg(word) for word in well_formed_unattested]
ill_formed_probs = [pcfg(word) for word in ill_formed]

# Print average log probabilities
# (a single underivable form has log probability -inf, which makes its group's average -inf)
print(f"Average log probability of attested forms: {np.mean(attested_probs)}")
print(f"Average log probability of well-formed but unattested forms: {np.mean(well_formed_probs)}")
print(f"Average log probability of ill-formed examples: {np.mean(ill_formed_probs)}")
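
Because any form the grammar cannot derive has log probability -inf, a plain mean is dominated by such forms. One possible alternative, sketched here with a hypothetical helper `summarize` and made-up scores, is to report the fraction of derivable forms separately from the mean over those derivable forms.

```python
import math

def summarize(logprobs):
    """Return (coverage, mean log probability over derivable forms)."""
    finite = [lp for lp in logprobs if math.isfinite(lp)]
    coverage = len(finite) / len(logprobs) if logprobs else 0.0
    mean_finite = sum(finite) / len(finite) if finite else float("-inf")
    return coverage, mean_finite

# Made-up scores: two derivable forms and one underivable form
example_scores = [-2.0, -4.0, float("-inf")]
coverage, mean_finite = summarize(example_scores)
print(f"coverage: {coverage:.2f}, mean over derivable forms: {mean_finite:.2f}")
```

Under this summary, a group of ill-formed examples would be expected to show low coverage, while attested and well-formed groups can still be compared on the forms the grammar does derive.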