Constrained sampling is planning

(Work in progress.)

Introducing the problem

OpenAI recently released a blog post detailing some of its correspondence with Elon Musk. Although large sections of this correspondence were redacted, observant readers realized that the lengths of the redacted words were leaked in the HTML of the censor bars.

[Embedded tweet from @varepsilon.]

This raises an obvious research question. Is it possible to recover a redacted text sequence given:

Many smart people think the answer is obvious. But they are split on whether it is obviously yes or obviously no.

What would it mean for the text to be recoverable? Certainly not that the text is uniquely determined by the lengths. We can go word by word and construct syntactically correct sentences satisfying the constraint, but the semantics will be highly unlikely — to say nothing of the pitfalls in numerical figures and typographical errors.

We might say that the text is recoverable if there exists a completion that seems significantly more plausible than any other, or if all “most plausible” completions are semantically similar up to details like capitalization or specific numerical figures. Then:

The latter seems like a higher bar. Regardless, both require us to produce plausible completions, which will involve a painstaking, backtracking-like approach that is well-suited to a computer search.


Constrained sampling is not sampling from a constrained distribution

Any algorithmic resolution to the problem is going to involve the use of a reasonably powerful open-source language model to construct a measure of “plausibility” that incorporates the semantic and grammatical information. Specifically:

What do we want to do with this distribution? Intuitively we want to find token sequences \(x_{\leq N}\) that maximize it. If we can produce two such sequences comparable in machine-assessed or human-assessed quality, the problem is underdetermined. However, this distribution is nasty:
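One way to see the difficulty, writing \(C_N\) for the event that the full sequence \(x_{\leq N}\) satisfies the constraint and \(X_{C_N}\) for the set of such sequences (matching the notation below): the distribution we would actually like to sample from conditions every token on the future of the constraint,

\[p_\theta(x_i \mid x_{<i}, C_N) = \frac{p_\theta(x_i \mid x_{<i}) \, p_\theta(C_N \mid x_{\leq i})}{p_\theta(C_N \mid x_{<i})},\]

and the factor \(p_\theta(C_N \mid x_{\leq i})\) is a sum over every constraint-satisfying continuation of \(x_{\leq i}\), which an autoregressive model gives us no cheap way to evaluate.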

This intractability comes from an autoregressive LLM’s inability to effectively take future constraints into consideration. But if this is the case, how do constrained sampling frameworks like Guidance or Outlines work?

One good overview can be found here; but generally, these frameworks condition only on the next token not violating the constraint. That is, if we let \(C_i\) be the event that \(x_{\leq i}\) does not violate the constraint, they sample from the distribution

\[q_\theta(x_{\leq N}) \triangleq \prod_{i = 1}^N q_\theta(x_i \mid x_{< i}) \triangleq \prod_{i = 1}^N p_\theta(x_i \mid x_{< i}, C_{i}).\]

The advantage of \(q_\theta\) is that it’s a distribution over \(X_{C_N}\) that we can actually sample from. The disadvantage is that it’s also a terrible estimate of \(p_\theta(x_i \mid x_{<i}, C_N)\) in many cases. If we are looking for an eight-letter word to follow the quick brown fox jumped over the lazy, the token _dog will satisfy the constraint — but the five-letter continuations that follow will all have higher net perplexity than _kangaroo.
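In code, that greedy masking looks roughly like the following. This is a minimal sketch of the idea rather than how Guidance or Outlines actually implement it (real frameworks precompute which tokens are admissible instead of decoding every candidate); `violates` is a hypothetical predicate on partial strings.

```python
import torch

@torch.no_grad()
def sample_q_theta(model, tokenizer, prompt, violates, max_new_tokens=32):
    """Sample from q_theta: at each step, mask out tokens whose *immediate*
    addition would violate the constraint, renormalize, and sample. Because
    this only looks one token ahead, it happily commits to "_dog" even when
    every eight-letter completion starting with "dog" is unlikely."""
    ids = tokenizer(prompt, return_tensors="pt").input_ids
    for _ in range(max_new_tokens):
        probs = torch.softmax(model(ids).logits[0, -1], dim=-1)
        allowed = torch.zeros_like(probs)
        for tok in range(probs.shape[-1]):  # slow; real frameworks precompute the allowed set
            if not violates(tokenizer.decode(ids[0].tolist() + [tok])):
                allowed[tok] = 1.0
        probs = probs * allowed
        probs = probs / probs.sum()  # renormalization gives p(x_i | x_<i, C_i); assumes some token is admissible
        nxt = torch.multinomial(probs, 1).view(1, 1)
        ids = torch.cat([ids, nxt], dim=1)
    return tokenizer.decode(ids[0])
```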


Let’s reframe the problem slightly. What we want is to find \(x_{\leq N}\) that minimizes the following “surprisal” function:

\[f_\theta(x_{\leq N}) = \begin{cases} - \log p_\theta(x_{\leq N}) & \text{if $x_{\leq N} \in X_{C_N}$} \\ +\infty & \text{otherwise.} \end{cases}\]

But \(p_\theta(x_{\leq N})\) still factors:

\[\log p_\theta(x_{\leq N}) = \sum_{i=1}^N \log p_\theta(x_i|x_{<i})\]

So getting relatively low-surprisal samples is a tree search problem:
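Concretely, the search can be organized around a priority queue of prefixes ordered by running surprisal. A minimal skeleton, with a hypothetical `children(prefix)` callback that yields `(child, edge_surprisal)` pairs (for instance, the word-level subroutine described in the next section):

```python
import heapq

def search(root, children, is_goal, max_expansions=10_000):
    """Best-first search over prefix strings, ordered by running surprisal
    (path cost so far). `children(prefix)` yields (child, edge_surprisal)
    pairs and should only yield children that can still satisfy the
    constraint; `is_goal(prefix)` marks complete, constraint-satisfying
    sequences."""
    frontier = [(0.0, root)]          # (running surprisal, prefix string)
    complete = []
    for _ in range(max_expansions):
        if not frontier:
            break
        surprisal, prefix = heapq.heappop(frontier)
        if is_goal(prefix):
            complete.append((surprisal, prefix))
            continue
        for child, edge in children(prefix):
            heapq.heappush(frontier, (surprisal + edge, child))
    return sorted(complete)
```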

How does this relate to the existing literature for language model decoding strategies? Lilian Weng has a good overview of that literature on her blog post on Controllable Neural Text Generation. One takeaway is that a good amount of work on decoding strategies pre-GPT-3.5 was concerned with the unreliability of \(p_\theta\), which seems less relevant today now that LLMs are less prone to mode collapse or strange outliers. For instance, Meister et al. (2020) proposes smoothing “overfit” logits with a regularization term, but we’ve already been trusting our \(p_\theta\) enough to use it as a \(p\) this whole time.

Speaking of which — we have made no assumptions about \(p_\theta\) so far, but the choice of search algorithm is going to come down to the properties of this distribution.


Constructing the word tree

At this point we have the option of searching over a tree of tokens or of words. I prefer searching over words because we can decompose the task into two separate search problems for which we can define different heuristics.

What this means is that we implement a next-word subroutine \(NextWord(prompt, \ell)\) which outputs a non-exhaustive list of words \(w_i\) of length \(\ell\), each often consisting of multiple tokens, alongside the surprisals \(s_i = -\log p_\theta(w_i \mid prompt)\). These words and surprisals define the children and edge distances of \(prompt\) in the search tree.

We should take care when constructing this subroutine to define the probability of a word precisely. The LLaMA and Mistral tokenizers include spaces at the start of a token, which means that the surprisal of the first word of our completion being “deepmind's” is actually something like

\[-\log p_\theta(\texttt{_deepmind} \mid \cdots) -\log p_\theta(\texttt{'s} \mid \cdots) \color{red}- \log \sum_{x \text{ starts with } \texttt{_}} p_\theta(x \mid \cdots).\]

Thus, the subroutine can be implemented as a breadth-first search over token sequences with the following adjustments:
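A simplified sketch of one way to implement this subroutine is below. It assumes a SentencePiece-style tokenizer (LLaMA/Mistral) where word-initial tokens carry the "▁" marker, skips batching and KV-cache reuse, and uses helper names of my own choosing.

```python
import heapq
import torch

@torch.no_grad()
def next_word(model, tokenizer, prompt_ids, length, topk=16, max_results=20, max_pops=200):
    """Sketch of NextWord(prompt, length): search over token sequences for
    next words of exactly `length` characters, returning (surprisal, word)
    pairs. Assumes word-initial tokens start with the SentencePiece marker
    "▁"; no batching or KV-cache reuse."""
    vocab = [tokenizer.convert_ids_to_tokens(i) for i in range(len(tokenizer))]
    word_start = torch.tensor([t.startswith("▁") for t in vocab])

    # Heap entries: (surprisal so far, partial word text, appended token ids).
    frontier = [(0.0, "", [])]
    results = []
    for _ in range(max_pops):
        if not frontier or len(results) >= max_results:
            break
        surprisal, text, toks = heapq.heappop(frontier)
        ids = torch.tensor([list(prompt_ids) + toks])
        logprobs = torch.log_softmax(model(ids).logits[0, -1], dim=-1)

        if len(text) == length:
            # The word is complete; charge the probability that the *next*
            # token starts a new word (the final term in the formula above).
            boundary = torch.logsumexp(logprobs[word_start], dim=0)
            results.append((surprisal - boundary.item(), text))
            continue

        values, indices = logprobs.topk(topk)
        for lp, tok in zip(values.tolist(), indices.tolist()):
            piece = vocab[tok].replace("▁", " ")
            starts = bool(word_start[tok])
            # The first token must start a new word; later tokens must not.
            if (not toks and not starts) or (toks and starts):
                continue
            candidate = (text + piece).lstrip()
            if len(candidate) > length:   # can no longer hit the target length
                continue
            heapq.heappush(frontier, (surprisal - lp, candidate, toks + [tok]))
    return sorted(results)
```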

We can see the results of one run of this subroutine below:

[Image: one run of the NextWord subroutine, showing candidate words and their surprisals.]

The final entry here is instructive: the model thought an ArXiv link was a plausible continuation, and indeed https://ar is ten characters long. But the likelihood of a space after these ten characters is so vanishingly small that the word has over twice the surprisal of any of the others returned.


Parametrizing the tree search (WIP)

Our top-level search problem is as follows.

What function do we use to assign priorities to word sequences? Using the running surprisal alone gives a uniform-cost search that behaves much like breadth-first search (surprisal only accumulates with depth), which is good enough for discovering words but certainly not up to the task of exploring a tree with depth 48 and a moderate branching factor.

Incorporating the heuristics above (A* search) seems like a good way to balance exploitation and exploration here. However, tuning these heuristics is difficult in practice — my experiments show that word-count-based, character-count-based, and position-based heuristics often become too greedy later in the sequence even when they seem almost breadth-first near the start.
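As an illustration, one character-count-based heuristic is an estimate of the surprisal still to come, turning the priority into an A*-style cost; the constant below is a hypothetical value, not a tuned one.

```python
# Hypothetical estimate of the average surprisal (nats) per remaining character,
# e.g. measured from unconstrained samples of the model.
ALPHA = 0.9

def priority(running_surprisal, chars_remaining):
    """A*-style priority: cost so far plus an estimated cost to go.
    ALPHA = 0 recovers the running-surprisal (near breadth-first) ordering;
    larger ALPHA favors deeper prefixes and behaves more greedily,
    especially late in the sequence."""
    return running_surprisal + ALPHA * chars_remaining
```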

I run my experiments on an RTX 4090 with Mistral-7B on HuggingFace Transformers (mostly for convenience). Limiting the search space via old-fashioned beam search gets us closer to a coherent response but also becomes quite greedy due to the branching factor: with a beam size of 100, we quickly end up discarding all possibilities aside from something like —

working at the cutting edge of ai is unfortunately expensive. for example, developing something state-of-the-art in deep nets is mostly about the very large data and compute demands, which means [sic] running large neural net models and collecting very large datasets. for these purposes, large data centers and large google compute engine bills are required, with the hardware and data costs

(This was sampled from an earlier version of the code that didn’t constrain the last character to be a period.)

A completion like this may seem superficially plausible, although it still has some glaring errors: repetitive wording, uninformative content that doesn’t need to be redacted, no mention of DeepMind, and an anomalous reference to GCE.

Regardless, my thinking is that to make further progress we need to manage the exploitation-exploration tradeoff directly. We may need to enforce a certain amount of diversity in our exploration, in terms of explored sequence length (multiple queues for each length, with a scheduler that determines which queue to explore from) and/or set of words chosen (contrastive sampling).
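A sketch of the multiple-queue idea, with one hypothetical scheduler choice (round-robin over prefix depths):

```python
import heapq

class LengthBucketedFrontier:
    """One priority queue per prefix depth (number of words so far), visited
    round-robin, so deep low-surprisal prefixes cannot starve shallow ones.
    Nodes can be any comparable objects, e.g. prefix strings."""

    def __init__(self, max_depth):
        self.queues = [[] for _ in range(max_depth + 1)]
        self.cursor = 0

    def push(self, priority, node, depth):
        heapq.heappush(self.queues[depth], (priority, node))

    def pop(self):
        # Round-robin over depths, skipping empty queues; None when exhausted.
        for _ in range(len(self.queues)):
            q = self.queues[self.cursor]
            self.cursor = (self.cursor + 1) % len(self.queues)
            if q:
                return heapq.heappop(q)
        return None
```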

Another approach is to fine-tune a model to accept conditional generation tokens:

Finally, a good number of performance optimizations are available to make the model run faster. We can squeeze out some small performance gains by increasing the model’s batch size. Moreover, because the tree structure of our token sequences is well understood, reusing the “common ancestor” KV cache between evaluations and across batches is another obvious optimization.
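For example, with HuggingFace Transformers the shared prefix can be run through the model once and its KV cache reused for each sibling continuation. A single-sequence sketch (batching the siblings, and expanding the cache along the batch dimension, is the next step):

```python
import copy
import torch

@torch.no_grad()
def score_siblings(model, prefix_ids, continuations):
    """Run the shared prefix through the model once, then score each sibling
    continuation on top of a copy of that KV cache."""
    out = model(torch.tensor([prefix_ids]), use_cache=True)
    prefix_cache = out.past_key_values
    logits = []
    for cont_ids in continuations:
        # Recent transformers versions mutate the cache in place, so give each
        # sibling its own copy of the common-ancestor cache.
        cache = copy.deepcopy(prefix_cache)
        res = model(torch.tensor([cont_ids]), past_key_values=cache, use_cache=True)
        logits.append(res.logits)
    return logits
```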

Work continues in fits and starts. As of this writing, I’m running another beam search with larger beam size and smaller branching factor.