Distributional representations from language models

A lot of recent work uses the representations learned by a language model both to develop systems for performing natural language processing tasks and as a means of taking an analysis-driven approach to scientific questions. Remember that, in this sort of approach, one aims to learn highly expressive representations and then extract generalizations about those representations post hoc.

Language models can also be useful in a hypothesis-driven approach, but our interest will not be in the representations they themselves learn. Rather, we’ll be interested in what those representations allow us to avoid doing. Specifically, we’ll use language models as components of the models we develop in this section: they provide a rich representation of the distributional properties of items in a sentence in a context where we want to view those properties largely as nuisance variables.

To understand how one derives such representations from a language model, we first need to discuss what a language model is in the classical sense and how the particular language model we’ll use is related to the classical notion of a language model.

What is a language model

In the classical sense, a language model is a probability distribution \(p(\mathbf{w})\) over strings \(\mathbf{w} \in \Sigma^*\) built from some vocabulary \(\Sigma\). Language models can be parameterized in a wide variety of ways. One way is to define \(p(\mathbf{w})\) in terms of the probability that some probabilistic grammar–e.g. a weighted finite state automaton (WFSA) or probabilistic context free grammar–assigns to \(\mathbf{w}\), summing across the analyses that the grammar assigns.

A specific case of this idea that has been important for the development of modern large language models is the family of \(n\)-gram models. This family of models is derived by starting from the fact that, if we view \(p(\mathbf{w})\) as a joint probability, we can rewrite it in the following way by the chain rule.1

\[p(w_1 \ldots w_L) = p(w_1)p(w_2 \mid w_1) \ldots p(w_L \mid w_1 \ldots w_{L-1}) = p(w_1)\prod_{i=2}^L p(w_i \mid w_1 \ldots w_{i-1})\]

The crucial modeling assumption that \(n\)-gram models make is that \(W_i\) is conditionally independent of \(\{W_j \mid j < i - (n - 1)\}\) given \(\{W_j \mid i > j \geq i - (n - 1)\}\).2

\[p(w_i \mid w_1 \ldots w_{i-1}) = p(w_i \mid w_{i - (n - 1)} \ldots w_{i-1})\]

This assumption doesn’t tell us how to compute \(p(w_i \mid w_{i - (n - 1)} \ldots w_{i-1})\). Generally, an \(n\)-gram model will assume that:

\[W_i \mid W_{i - (n - 1)} = w_{i - (n - 1)}, \ldots, W_{i-1} = w_{i-1} \sim \text{Cat}(\boldsymbol\theta_{w_{i - (n - 1)} \ldots w_{i-1}})\]

Under this assumption, every substring \(\mathbf{w} \in \Sigma^{n-1}\) has its own \(\boldsymbol\theta_\mathbf{w}\). This assumption in turn implies that we must estimate \(\boldsymbol\theta\)s for every one of the \(|\Sigma^{n-1}| = |\Sigma|^{n-1}\) possible substrings of length \(n-1\), where each \(\boldsymbol\theta\) itself contains \(|\Sigma|\) parameters.
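To make the parameter bookkeeping concrete, here is a minimal sketch of maximum likelihood estimation for an \(n\)-gram model, where each \(\boldsymbol\theta\) is just the relative frequency of next words after a given context. The function name `fit_ngram_mle` and the boundary symbols `<s>` and `</s>` are assumptions made for illustration; they are not part of the discussion above.

```python
from collections import Counter, defaultdict


def fit_ngram_mle(corpus, n=2, bos="<s>", eos="</s>"):
    """Estimate one categorical distribution per (n-1)-word context by MLE."""
    context_counts = defaultdict(Counter)
    for sentence in corpus:
        # Pad so that the first word also has a full (n-1)-word context.
        padded = [bos] * (n - 1) + list(sentence) + [eos]
        for i in range(n - 1, len(padded)):
            context = tuple(padded[i - (n - 1):i])
            context_counts[context][padded[i]] += 1
    # Normalizing the counts within a context gives the MLE of its theta.
    return {
        context: {w: c / sum(counts.values()) for w, c in counts.items()}
        for context, counts in context_counts.items()
    }


corpus = [
    ["the", "dog", "barks"],
    ["the", "cat", "sleeps"],
    ["the", "dog", "sleeps"],
]
theta = fit_ngram_mle(corpus, n=2)
print(theta[("the",)])
```

Note that this sketch only stores a \(\boldsymbol\theta\) for contexts that actually occur in the corpus. The full table would need an entry for each of the \(|\Sigma|^{n-1}\) possible contexts, which is one reason the parameter count becomes a problem as \(n\) grows.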

The idea behind neural language models is to use an alternative parameterization of the distribution of \(W_i\).

Neural language models

The trick to understanding neural language models is to see that, even if we constrain ourselves to categorical distributions, the distribution of \(W_i\) given its string context \(w_1 \ldots w_{i-1} w_{i+1} \ldots w_L\) can be defined in terms of an arbitrary function of that context.3

For instance, suppose we start from the factorization of the joint probability \(p(w_1 \ldots w_L) = p(w_1)\prod_{i=2}^L p(w_i \mid w_1 \ldots w_{i-1})\) that we discussed above–no \(n\)-gram assumption. What we need to compute the probabilities in this product is a way of mapping from an arbitrary substring to the parameters of a categorical distribution over \(\Sigma\). Let’s call this mapping \(f\).

In the context of \(n\)-gram models, \(f\) is sort of trivial: \(f(w_1\ldots w_{i-1}) = \boldsymbol\theta_{w_{i - (n - 1)} \ldots w_{i-1}}\), and we assume we somehow know \(\boldsymbol\theta_{w_{i - (n - 1)} \ldots w_{i-1}}\) (e.g. because we estimated it using MLE or MAP estimation or whatever). But \(f\) need not be so trivial.

One way to make it nontrivial is to define \(f\) in such a way that it can handle strings of arbitrary length by constructing a compressed representation of those strings and then mapping that representation to the parameters of a categorical distribution on \(\Sigma\). This idea can be implemented in a variety of ways. One way to do it–popular in the early days of neural language models–is to use recurrent neural networks (RNNs). The more common approach now is to use transformers.

In both cases, \(f\) is a parameterized function, whose parameters can be trained using gradient-based optimization by taking the derivative of the log-likelihood \(\log p(\mathbf{w})\) with respect to those parameters.
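As a rough illustration of what training such an \(f\) by gradient-based optimization looks like, here is a minimal sketch in PyTorch. It deliberately uses a much simpler \(f\) than an RNN or transformer: one that only looks at the previous word. The class name `TinyContextModel`, the toy string, and the optimizer settings are all assumptions made for the example.

```python
import torch
from torch.nn import Embedding, Linear, Module
from torch.nn.functional import log_softmax


class TinyContextModel(Module):
    """A toy f: maps the previous word's index to log-probabilities over the vocabulary."""

    def __init__(self, n_vocab: int, dim: int):
        super().__init__()
        self.embed = Embedding(n_vocab, dim)
        self.out = Linear(dim, n_vocab)

    def forward(self, prev_words: torch.Tensor) -> torch.Tensor:
        return log_softmax(self.out(self.embed(prev_words)), dim=-1)


torch.manual_seed(0)
f = TinyContextModel(n_vocab=5, dim=8)
optimizer = torch.optim.SGD(f.parameters(), lr=0.05)

# A toy string of word indices; each word is predicted from the one before it.
w = torch.tensor([0, 1, 2, 3, 1, 2, 4])
prev_words, next_words = w[:-1], w[1:]

losses = []
for _ in range(100):
    optimizer.zero_grad()
    log_probs = f(prev_words)  # shape (L - 1, n_vocab)
    # Sum of log p(w_i | w_{i-1}) for i = 2..L (the first word is conditioned on).
    log_likelihood = log_probs[torch.arange(len(next_words)), next_words].sum()
    loss = -log_likelihood  # minimizing this maximizes the log-likelihood
    loss.backward()         # derivative of the loss w.r.t. the parameters of f
    optimizer.step()
    losses.append(loss.item())

print(losses[0], losses[-1])
```

The same loop would apply to an RNN or transformer: only the definition of \(f\) changes, not the way its parameters are updated.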

RNN language models

An RNN language model is generally defined in terms of three components: (i) an embedding method; (ii) an RNN cell; and (iii) a language modeling head.

Embedding method

An embedding module implements some method \(e\) of mapping elements \(w \in \Sigma\) to some representation (or embedding) of those elements \(\mathbf{x} \in \mathbb{R}^{D_\text{vocab}}\). A simple variant of such a module simply keeps these embeddings in a matrix \(\mathbf{X}\) and returns \(e(w) = \mathbf{x}_w\), but alternative variants exist. This simple variant is what Embedding implements in torch.nn.

from torch.nn import Embedding
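A quick sketch of what this lookup does in practice. The sizes and word indices here are arbitrary choices for illustration.

```python
import torch
from torch.nn import Embedding

torch.manual_seed(0)
n_vocab, vocab_dim = 4, 3
e = Embedding(n_vocab, vocab_dim)   # e.weight plays the role of the matrix X

word_ids = torch.tensor([2, 0, 2])  # a string of three word indices
x = e(word_ids)                     # one embedding per word: shape (3, vocab_dim)

# Looking up an index returns the corresponding row of the embedding matrix,
# so both occurrences of word 2 receive the same type-level vector.
print(torch.equal(x[0], e.weight[2]), torch.equal(x[0], x[2]))
```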

RNN cell

The RNN cell is the workhorse of an RNN language model in the sense that it is what is used to compute the representation \(\mathbf{h}_\mathbf{w} \in \mathbb{R}^{D_\text{string}}\) of a string. RNN cells can have more or less complex structure. The simplest variant is often called an Elman RNN cell (Elman 1990). This variant defines a function \(g: \mathbb{R}^{D_\text{string}} \times \mathbb{R}^{D_\text{vocab}} \rightarrow \mathbb{R}^{D_\text{string}}\) that constructs the representation of a string \(\mathbf{w'}w\) from the representation of its prefix \(\mathbf{w'}\) and the representation of the input element \(w\).

\[\mathbf{h}_{\mathbf{w'}w} = g(\mathbf{h}_\mathbf{w'}, e(w))\]

This function could itself be arbitrarily complex. Generally, it is defined in terms of an affine map with parameters \(\mathbf{W}_1 \in \mathbb{R}^{D_\text{string} \times D_\text{string}}\), \(\mathbf{W}_2 \in \mathbb{R}^{D_\text{string} \times D_\text{vocab}}\), \(\mathbf{b} \in \mathbb{R}^{D_\text{string}}\) composed with some pointwise nonlinearity \(\sigma\)–usually a logistic \(\text{logit}^{-1}\) or a hyperbolic tangent \(\text{tanh}\):

\[g\left(\mathbf{h}_\mathbf{w'}, e(w)\right) \equiv \sigma\left(\mathbf{W}_1\mathbf{h}_\mathbf{w'} + \mathbf{W}_2e(w) + \mathbf{b}\right)\]

One thing the nonlinearity does is keep the elements of \(\mathbf{h}_\mathbf{w}\) from getting very large as \(\mathbf{w}\) gets longer.

Elman RNN cells are implemented as the basic RNNCell in torch.nn.

from torch.nn import RNNCell
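To see how the cell lines up with the definition of \(g\), here is a small sketch that applies an `RNNCell` once and then recomputes the result by hand. One difference from the equation for \(g\): PyTorch stores two bias vectors (`bias_ih` and `bias_hh`), whose sum plays the role of a single bias term. The dimensions and random inputs are arbitrary.

```python
import torch
from torch.nn import RNNCell

torch.manual_seed(0)
vocab_dim, string_dim = 3, 4
cell = RNNCell(vocab_dim, string_dim, nonlinearity="tanh")

h_prefix = torch.randn(1, string_dim)  # stands in for the prefix representation
x = torch.randn(1, vocab_dim)          # stands in for the embedding e(w)

h_next = cell(x, h_prefix)             # representation of the extended string

# weight_hh acts on the prefix representation, weight_ih on the input embedding.
manual = torch.tanh(
    h_prefix @ cell.weight_hh.T
    + x @ cell.weight_ih.T
    + cell.bias_hh
    + cell.bias_ih
)
print(torch.allclose(h_next, manual, atol=1e-6))
```

Because the tanh squashes every element into \((-1, 1)\), the representation stays bounded no matter how many times the cell is applied.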

These cells are in turn bundled into a container module RNN.

from torch.nn import RNN

What this bundling allows one to do is to easily pass a sequence of embeddings \(e(w_1)e(w_2)\ldots e(w_L)\) for the string of interest \(\mathbf{w}\) and return a sequence of representations \(\mathbf{h}_{w_1}\mathbf{h}_{w_1w_2}\ldots \mathbf{h}_\mathbf{w}\). These representations can then be passed to a language modeling head.

Language modeling head

A language modeling head \(m\) maps the representation of a string \(\mathbf{h}_\mathbf{w'}\) to the parameters of a probability distribution over the vocabulary items \(\Sigma\). In the setup we’ve been discussing, this distribution is used to model \(p(w_i \mid w_1\ldots w_{i-1})\) by defining:

\[\begin{align*} \boldsymbol\theta_{w_1\ldots w_{i-1}} &= m\left(\mathbf{h}_{w_1\ldots w_{i-1}}\right)\\ &= m\left(g\left(\mathbf{h}_{w_1\ldots w_{i-2}}, e(w_{i-1})\right)\right)\\ &= m\left(g\left(g\left(\mathbf{h}_{w_1\ldots w_{i-3}}, e(w_{i-2})\right), e(w_{i-1})\right)\right) \end{align*}\]

Like the other components, the language modeling head can take a variety of forms. One of the simplest is to apply an affine map \(\mathbf{V} \in \mathbb{R}^{|\Sigma| \times D_\text{string}}, \mathbf{b} \in \mathbb{R}^{|\Sigma|}\) to \(\mathbf{h}_\mathbf{w'}\) and then send it through a softmax function.

\[m(\mathbf{h}) = \text{softmax}\left(\mathbf{V}\mathbf{h} + \mathbf{b}\right)\]

where \(\text{softmax}(\mathbf{x}) = \left[\frac{\exp(x_1)}{\sum_i \exp(x_i)}, \frac{\exp(x_2)}{\sum_i \exp(x_i)}, \ldots\right]\).
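As a small sanity check on this definition, here is a sketch that computes the softmax by hand and compares it against PyTorch’s built-in. Subtracting the maximum before exponentiating is an addition that is not part of the definition above: it leaves the result unchanged but avoids overflow for large scores.

```python
import torch


def softmax(x: torch.Tensor) -> torch.Tensor:
    # Shifting by the max does not change the ratios but keeps exp() from overflowing.
    z = torch.exp(x - x.max())
    return z / z.sum()


scores = torch.tensor([2.0, 1.0, 0.1])
theta = softmax(scores)

# The result is a valid set of categorical parameters: nonnegative and summing to one.
print(theta, theta.sum())
```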

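from torch import Tensor
from torch.nn import Module, Sequential, Linear, Softmax

class LanguageModelHead(Module):
    def __init__(self, string_dim: int, n_vocab: int):
        # Module.__init__ must run before any submodule is assigned
        super().__init__()

        self.head = Sequential(
            # affine map from string representations to vocabulary scores
            Linear(string_dim, n_vocab),
            # Softmax takes the dimension to normalize over, not its size;
            # the vocabulary scores live on the last dimension
            Softmax(dim=-1)
        )
        
    def forward(self, x: Tensor) -> Tensor:
        return self.head(x)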

Composing the three

We can then build a torch.nn.Module that composes the three.

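class RNNLanguageModel(Module):
    def __init__(self, vocab_dim: int, string_dim: int, n_vocab: int):
        # Module.__init__ must run before any submodule is assigned
        super().__init__()

        # index 0 is reserved for padding, hence the n_vocab + 1 rows
        self.embeddings = Embedding(
            n_vocab + 1, vocab_dim, 
            padding_idx = 0
        )
        # with the default batch_first=False, inputs are (position, string)
        self.rnn = RNN(vocab_dim, string_dim)
        self.lm_head = LanguageModelHead(string_dim, n_vocab)
        
    def forward(self, strings_hashed: Tensor) -> Tensor:
        words_embedded = self.embeddings(strings_hashed)
        # RNN returns a pair: the representation at every position and
        # the final hidden state; only the former is needed here
        strings_embedded, _ = self.rnn(words_embedded)
        next_word_probs = self.lm_head(strings_embedded)
        
        return next_word_probs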

Extensions

For the most part, differences among neural language models come down to how the component that constructs the string representations \(\mathbf{h}_\mathbf{w}\) is set up. A few ideas that were pursued heavily for a while in the NLP literature relied on modifications to the core RNN architecture:

  1. Modifying the form of the RNN cell–e.g. using a long short term memory (LSTM) or gated recurrent unit (GRU), rather than an Elman cell.
  2. Stacking multiple RNN cells on top of each other so that we have multiple layers of representation \(\mathbf{h}^{(l)}_{\mathbf{w'}w} = g\left(\mathbf{h}^{(l)}_\mathbf{w'}, \mathbf{h}^{(l-1)}_{\mathbf{w'}w}\right)\), where \(\mathbf{h}^{(0)}_{\mathbf{w'}w} \equiv e(w)\) and the parameters of \(g\) generally differ by layer.
  3. Having RNNs that provide representations for both the forward factorization \(p(\mathbf{w}) = p(w_1)\prod_{i=2}^L p(w_i \mid w_1 \ldots w_{i-1})\) and the backward factorization \(p(\mathbf{w}) = p(w_L)\prod_{i=1}^{L-1} p(w_i \mid w_{i+1} \ldots w_L)\) by defining a forward representation \(\mathbf{h}^{\rightarrow}_{\mathbf{w'}w} = g\left(\mathbf{h}^{\rightarrow}_\mathbf{w'}, e(w)\right)\), built up from the prefix \(\mathbf{w'}\) preceding \(w\), and a backward representation \(\mathbf{h}^{\leftarrow}_{w\mathbf{w'}} = g\left(\mathbf{h}^{\leftarrow}_\mathbf{w'}, e(w)\right)\), built up from the suffix \(\mathbf{w'}\) following \(w\), where the parameters of \(g\) generally differ by direction.

The stacking and bidirectionality ideas were ingredients along the path toward modern language models–most of which use transformers.
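All three modifications are available through the recurrent modules in torch.nn. The following is a minimal sketch, with arbitrary dimensions and random inputs standing in for embeddings, of a two-layer bidirectional LSTM and the shapes it returns. It only illustrates how representations are extracted: at each position the concatenated output has seen the word at that position from both directions, so it is not by itself a setup for training against either factorization.

```python
import torch
from torch.nn import LSTM

torch.manual_seed(0)
vocab_dim, string_dim, length, batch = 3, 4, 5, 2

rnn = LSTM(
    vocab_dim, string_dim,
    num_layers=2,        # stacking: layer 2 reads the outputs of layer 1
    bidirectional=True,  # separate parameters for each direction
    batch_first=True,    # inputs are (string, position, dimension)
)

embedded = torch.randn(batch, length, vocab_dim)  # stand-in for embedded words
output, (h_n, c_n) = rnn(embedded)

# output: top-layer forward and backward representations, concatenated per position
print(output.shape)  # torch.Size([2, 5, 8])
# h_n: final hidden state for every (layer, direction) pair
print(h_n.shape)     # torch.Size([4, 2, 4])
```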

Transformer language models

A popular alternative to RNNs is the transformer, which is what most modern language models use. The transformers that researchers use in practice are somewhat complex, but they boil down to a fairly simple idea. I’m going to discuss a very stripped down version of what a transformer is in terms of this simple idea. But just note that, in practice, their internals aren’t as simple as what I’m about to lay out.

In the context of building a neural language model, we want our transformer cells to perform effectively the same function that our RNN cells did: construct a representation \(\mathbf{h}_\mathbf{w}\) of a string \(\mathbf{w}\), which we can use to predict the next word in the sequence. To do this, transformer cells use a self-attention module \(a\) that maps (i) a query vector \(\mathbf{q} \in \mathbb{R}^{D_\text{query}}\); (ii) a collection of key vectors \(\mathbf{K} \in \mathbb{R}^{L \times D_\text{query}}\); and (iii) a collection of value vectors \(\mathbf{V}_\text{in} \in \mathbb{R}^{L \times D_\text{value}}\) to an output value \(\mathbf{v}_\text{out} \in \mathbb{R}^{D_\text{value}}\). It does this by looking at the similarity of each key to the query–as measured by the dot product–and using that similarity to compute a weighted sum over the values.

There are a few ways to implement this idea, but transformers use a particular form of dot-product attention. A very simple form of dot-product attention is:4

\[a\left(\mathbf{q}, \mathbf{K}, \mathbf{V}\right) = \text{softmax}\left(\mathbf{K}\mathbf{q}\right)\mathbf{V}\]

where the softmax can be thought of as producing the parameters of a categorical distribution over the \(L\) values, under which we take the expectation of those values.
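This simple form of dot-product attention is short enough to write out directly. Here is a minimal sketch for a single query, with arbitrary dimensions and random inputs; the function name `attention` is just a label chosen for the example.

```python
import torch


def attention(q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    """Simple dot-product attention for one query: softmax(K q) V."""
    weights = torch.softmax(K @ q, dim=0)  # (L,): one weight per key, summing to 1
    return weights @ V                     # (D_value,): weighted sum of the values


torch.manual_seed(0)
L, d_query, d_value = 4, 3, 2
q = torch.randn(d_query)
K = torch.randn(L, d_query)
V = torch.randn(L, d_value)

v_out = attention(q, K, V)
print(v_out.shape)  # torch.Size([2])
```

Because the weights are nonnegative and sum to one, each element of the output lies between the smallest and largest corresponding elements of the values.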

An extremely stripped down variant of a transformer would then define \(\mathbf{h}_\mathbf{w}\) in terms of dot-product attention. For simplicity, let’s assume that \(D_\text{vocab} = D_\text{string} = D_\text{query} = D_\text{value}\).5 We can then treat the representations of previous words as both the keys and the values.

\[\mathbf{h}_{\mathbf{w'}w} = a\left(e(w), \begin{bmatrix} e(w'_1) \\ e(w'_2) \\ \ldots \\ e(w'_L)\end{bmatrix}, \begin{bmatrix} e(w'_1) \\ e(w'_2) \\ \ldots \\ e(w'_L)\end{bmatrix}\right)\]

And just as in RNNs, we can create a stack of representations.

\[\mathbf{h}^{(l)}_{\mathbf{w'}w} = a\left(\mathbf{h}^{(l-1)}_{\mathbf{w'}w}, \begin{bmatrix} \mathbf{h}^{(l-1)}_{w'_1} \\ \mathbf{h}^{(l-1)}_{w'_1w'_2} \\ \ldots \\ \mathbf{h}^{(l-1)}_\mathbf{w'}\end{bmatrix}, \begin{bmatrix} \mathbf{h}^{(l-1)}_{w'_1} \\ \mathbf{h}^{(l-1)}_{w'_1w'_2} \\ \ldots \\ \mathbf{h}^{(l-1)}_\mathbf{w'}\end{bmatrix}\right)\]

where, again, \(\mathbf{h}^{(0)}_{\mathbf{w'}w} \equiv e(w)\).

This setup satisfies the requirement that we be able to build a representation of the substring prior to some word \(w\) and use it to predict \(w\). But there’s something a bit off: in contrast to the RNN, the representation cannot capture the ordering of the words. To see this, note that we could randomly permute the string \(\mathbf{w'}\) before \(w\) and still get the same \(\mathbf{h}_{\mathbf{w'}w}\).
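The permutation point is easy to check numerically. In this sketch, random vectors stand in for the embeddings of the prefix and of \(w\), and `attention` is the simple unscaled form of dot-product attention, redefined here so the example is self-contained.

```python
import torch


def attention(q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    return torch.softmax(K @ q, dim=0) @ V


torch.manual_seed(0)
prefix = torch.randn(5, 3)  # stand-ins for the embeddings of the five prefix words
query = torch.randn(3)      # stand-in for e(w)

h = attention(query, prefix, prefix)

# Shuffle the order of the prefix words and recompute the representation.
perm = torch.randperm(5)
h_permuted = attention(query, prefix[perm], prefix[perm])

print(torch.allclose(h, h_permuted, atol=1e-5))  # True
```

The weighted sum is over the same set of key-value pairs either way, so the order in which they are stacked makes no difference up to floating point error.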

To handle this lack of information about position, transformer language models generally combine the input embeddings with a positional encoding; in the original transformer, this is done by adding the two together. What this positional encoding looks like isn’t so important for our purposes. Just know that it is a vector that provides information about the position of a word in a string.

Representation from language models

In learning a representation \(\mathbf{h}_\mathbf{w}\) for a string, language models are definitionally learning a representation of the distribution of the items in that string: the embedding \(e(w)\) provides a type-level (or static) representation for the word \(w\) and the representation \(\mathbf{h}_{\mathbf{w'}w}\) provides a token-level (or contextual) representation for the word \(w\) in the context of \(\mathbf{w'}\). We often loosely talk about this representation as “semantic” because a linguistic expression’s distributional properties are correlated with some aspects of its semantics; but strictly speaking, these representations are fully distributional in nature.

The token-level distributional representations \(\mathbf{h}_{\mathbf{w'}w}\) have turned out to be very useful as a way to provide inputs to some system that benefits from representations of the distributional properties. This is particularly true of systems that need some representation of an expression’s meaning, since (as I just noted) distributional properties correlate with some aspects of meaning.

But they can also be useful for systems that simply need a good representation of the distributional properties of an expression. Indeed, if we only need the representations to be representations of distributional properties, we are on much firmer ground in using these representations in theory-building than if we furthermore need good representations of the meaning, because the relationship between meaning and distribution is one important part of what we’re developing theories about when we’re developing syntactic and semantic theories.

In this context, where we want maximally good distributional representations, one question that arises is whether only modeling the prior context of a particular element is the right thing to do. Wouldn’t it be a lot better if that representation were sensitive to the entirety of the element’s context? That is, for an element \(w_i\), we might want the representation to be sensitive not only to the forward context \(w_1\ldots w_{i-1}\), but also to the backward context \(w_{i+1}\ldots w_{L}\).

We already briefly saw one way to try to capture both the forward and backward contexts: simultaneously train one language model against the forward factorization of the joint probability to obtain \(\mathbf{h}^{\rightarrow}_{\mathbf{w'}w}\) as well as another language model against the backward factorization to obtain \(\mathbf{h}^{\leftarrow}_{\mathbf{w'}w}\). Then, we can treat the concatenation of the two as the representation of \(w\).

An alternative approach, which works particularly well in the context of transformer-based models, is to derive a token-level representation \(\mathbf{h}^{(l)}_i\) for \(w_i\) from the entire context of \(w_i\).

\[\mathbf{h}^{(l)}_i = a\left(\mathbf{h}^{(l-1)}_i, \begin{bmatrix} \mathbf{h}^{(l-1)}_1 \\ \mathbf{h}^{(l-1)}_2 \\ \ldots \\ \mathbf{h}^{(l-1)}_L\end{bmatrix}, \begin{bmatrix} \mathbf{h}^{(l-1)}_1 \\ \mathbf{h}^{(l-1)}_2 \\ \ldots \\ \mathbf{h}^{(l-1)}_L\end{bmatrix}\right)\]

where \(\mathbf{h}^{(0)}_i \equiv e(w_i)\).
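To make the contrast with the prior-context-only setup concrete, here is a sketch that applies simple dot-product attention at every position at once, either over the whole string or only over positions up to and including the current one. The `causal` flag and the choice to let a position attend to itself in the causal case are assumptions of this sketch. Changing only the last position’s input then changes earlier representations in the full-context case but not in the causal one.

```python
import torch


def self_attention(H: torch.Tensor, causal: bool = False) -> torch.Tensor:
    """One layer of simple dot-product attention, using each row of H as a query."""
    n = H.shape[0]
    scores = H @ H.T  # (n, n): dot product of every query with every key
    if causal:
        # Position i may only attend to positions j <= i.
        future = torch.triu(torch.ones(n, n), diagonal=1).bool()
        scores = scores.masked_fill(future, float("-inf"))
    return torch.softmax(scores, dim=-1) @ H


torch.manual_seed(0)
H0 = torch.randn(4, 3)   # stand-ins for the embeddings of a four-word string
H0_changed = H0.clone()
H0_changed[3] += 1.0     # alter only the last position

full_delta = self_attention(H0) - self_attention(H0_changed)
causal_delta = self_attention(H0, causal=True) - self_attention(H0_changed, causal=True)

print(full_delta[0])    # nonzero: position 1 sees the change at position 4
print(causal_delta[0])  # zeros: position 1 cannot see later positions
```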

One question that arises here is how we could train such a system, since the representation \(\mathbf{h}^{(l)}_i\) is sensitive to the identity of \(w_i\) via its embedding. One answer is to move away from a language modeling objective of maximizing \(p(\mathbf{w})\) to an alternative objective.

Neural non-language models

In modern parlance, the term language model has come to encompass more than just models that compute \(p(\mathbf{w})\). One important example of this broadening of the term is the introduction of masked language models, which we will use for providing our distributional representations. Rather than being trained to compute \(p(\mathbf{w})\) via some factorization of that probability, a masked language model is trained to compute \(p\left(w_{i_1}, w_{i_2}, \ldots, w_{i_M} \mid \{w_j \mid j \not\in \{i_1, \ldots, i_M\}\}\right)\), where the words at positions \(\{i_1, \ldots, i_M\}\) are masked by replacing them with a special token [MASK] \(\not\in \Sigma\) that has its own embedding \(e(\) [MASK] \()\). The objective is then to predict the identity of the masked \(w_{i_j}\) given the representations \(\mathbf{h}^{(l)}_{i_j}\) by applying a language modeling head to each such representation.
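Here is a minimal sketch of what this objective looks like for a single string. Everything about the setup is a simplifying assumption: the tiny dimensions, the single layer of unscaled dot-product attention with no positional encoding, the linear language modeling head, and the choice of which positions to mask. It is meant to show the bookkeeping, not how BERT-style models are actually configured.

```python
import torch
from torch.nn import Embedding, Linear
from torch.nn.functional import cross_entropy

torch.manual_seed(0)
n_vocab, dim = 6, 8
mask_id = n_vocab                    # [MASK] gets an index outside the vocabulary
embed = Embedding(n_vocab + 1, dim)  # one extra row for the embedding of [MASK]
lm_head = Linear(dim, n_vocab)       # scores over real vocabulary items only


def self_attention(H: torch.Tensor) -> torch.Tensor:
    # Every position attends to the entire (masked) string.
    return torch.softmax(H @ H.T, dim=-1) @ H


w = torch.tensor([3, 1, 4, 1, 5])        # the original string, as word indices
masked_positions = torch.tensor([1, 3])  # positions chosen for masking

w_masked = w.clone()
w_masked[masked_positions] = mask_id     # replace the chosen words with [MASK]

H = self_attention(embed(w_masked))      # token-level representations
logits = lm_head(H[masked_positions])    # one prediction per masked position

# Negative log-probability of the true words at the masked positions,
# averaged over those positions.
loss = cross_entropy(logits, w[masked_positions])
print(w_masked.tolist(), loss.item())
```

Since the words at the masked positions never enter the computation of the representations used to predict them, the problem of a representation “seeing” the word it is supposed to predict does not arise.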

One early example of this sort of model was Bidirectional Encoder Representations from Transformers (BERT), which incorporates not only a masked language modeling objective, but also a next sentence prediction objective (Devlin et al. 2019). A popular related model that only uses the masked language modeling objective is RoBERTa, which we will use to provide us with our distributional representation in this module (Liu et al. 2019). Neither BERT nor RoBERTa is even close to the best model for representing distributional properties available nowadays; but RoBERTa is sufficient for our purposes, and relatively small in terms of parameters as modern neural language models go.

Summing up

We looked at how neural language models can be used to derive both type- and token-level representations of linguistic expressions with the aim of using these models as components in our own models. Remember that, in our hypothesis-driven approach, our interest will not be in the representations these models themselves learn. Rather, we’ll be interested in what those representations allow us to avoid doing. Specifically, we’ll use language models as components of the models we develop in this section: they provide a rich representation of the distributional properties of items in a sentence in a context where we want to view those properties largely as nuisance variables.

References

Devlin, Jacob, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. 2019. “BERT: Pre-Training of Deep Bidirectional Transformers for Language Understanding.” In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers), 4171–86. Minneapolis, Minnesota: Association for Computational Linguistics. https://doi.org/10.18653/v1/N19-1423.
Elman, Jeffrey L. 1990. “Finding Structure in Time.” Cognitive Science 14 (2): 179–211.
Liu, Yinhan, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, and Veselin Stoyanov. 2019. “RoBERTa: A Robustly Optimized BERT Pretraining Approach.” CoRR abs/1907.11692. http://arxiv.org/abs/1907.11692.

Footnotes

  1. This way of decomposing the joint probability is not the only one implied by the chain rule. For instance, we could arbitrarily permute the indices with a function \(\pi\) and then compute \(p(w_1 \ldots w_L) = p(w_{\pi^{-1}(1)})p(w_{\pi^{-1}(2)} \mid w_{\pi^{-1}(1)}) \ldots p(w_{\pi^{-1}(L)} \mid w_{\pi^{-1}(1)} \ldots w_{\pi^{-1}(L-1)}) = p(w_{\pi^{-1}(1)})\prod_{i=2}^L p(w_{\pi^{-1}(i)} \mid w_{\pi^{-1}(1)} \ldots w_{\pi^{-1}(i-1)})\), and the expression would still conform to the chain rule. We just gravitate toward the expression in terms of observed string position because it’s, in a sense, the most straightforward.

  2. At base, \(n\)-gram models are WFSAs whose states are strings \(\mathbf{w} \in \Sigma^{n-1}\) representing the previous \(n-1\) words and whose transitions never produce an empty string–i.e. they have no \(\epsilon\)-transitions. The probability \(p(\mathbf{w})\) is more straightforward to compute under an \(n\)-gram model than an arbitrary WFSA because, in an \(n\)-gram model, the states themselves are assumed to be observed, so we don’t need to marginalize over them–as we do, for instance, in computing language model probabilities for hidden Markov models.

  3. We also need to ensure that we satisfy the assumption of unit measure: \(\mathbb{P}(\Sigma^*) = \sum_{\mathbf{w} \in \Sigma^*} p(\mathbf{w}) = 1\). I’m going to ignore this point for the purposes of this discussion.

  4. This form is not exactly the one used in transformers, which scales the dot product by \(\frac{1}{\sqrt{D_\text{query}}}\): \(a\left(\mathbf{q}, \mathbf{K}, \mathbf{V}\right) = \text{softmax}\left(\frac{\mathbf{K}\mathbf{q}}{\sqrt{D_\text{query}}}\right)\mathbf{V}\). This scaling serves to tamp down the dot products so that they don’t get huge as the dimensions grow.

  5. If any are not equal, we simply define a parameterized mapping from \(\mathbb{R}^{D_1}\) to \(\mathbb{R}^{D_2}\) in order to get them in the same vector space.