LMQL's constraint language comes with a set of standard operators that can be combined. However, it is also possible to implement custom operations that enable the validation of more complex properties, while maintaining composability with built-in operations.
This chapter explains how to implement a custom LMQL operator, using an example operator.
Implementing a Custom Operator
To implement a custom operator, you need to define a class that inherits from
lmql.ops.Node. The class must be decorated with the
@LMQLOp decorator, which takes the name of the operator as an argument. The operator name can then be used in LMQL queries to refer to the custom operation.
The basic interface of a custom operator is given by the following methods:
    from lmql.ops import LMQLOp, Node, InOpStrInSet

    # register the class under the operator name used in queries
    @LMQLOp("custom_operator")
    class CustomOperator(Node):
        # provides the value semantics of the operator
        def forward(self, *args, **kwargs):
            ...

        # provides a lookahead on the result of the operator (for next-token masking)
        def follow(self, *args, **kwargs):
            ...

        # provides the definitiveness of a result (temporary or final)
        def final(self, *args, **kwargs):
            ...
forward(self, *args, **kwargs): This method implements the forward semantics of your operator, i.e. it determines the result of your operator
custom_operator(*args), i.e. whether something you are validating holds or does not hold for the specified arguments. This method is called for every token in the model output, and thus allows you to implement early stopping during generation.
follow(self, x, **kwargs): This method implements the follow semantics of your operator, i.e. it provides a lookahead on the result of your operator, given the current value of the specified arguments. The return value of this function is a so-called follow map, that differentiates among the different classes of continuation tokens. This method is called at least once on each token in the model output, to determine the set of allowed next tokens.
final(self, x, **kwargs): Lastly, final implements the final semantics of your operator, i.e. it determines whether the result of your operator is final with respect to the current model output. This allows you to differentiate between temporary and final validation results, which is important to prevent early termination in cases of only temporary violations.
Possible return values are "var" and "fin" for temporary and final (definitive) results, respectively. Going beyond boolean logic, LMQL also supports "inc" and "dec" to indicate monotonically growing and shrinking results, respectively. This is useful for character counting, strings and lists. For instance, the input variable x will have either "inc" final-ness, to indicate that it is a monotonically growing string (currently being generated), or "fin" final-ness, to indicate that it has reached its final value (query execution has completed decoding its value).
Depending on your use case, it may suffice to only implement
forward (to enable early stopping during generation). However, for full masking support, you also have to implement follow and final.
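To make the final-ness values described above more concrete, here is a minimal plain-Python sketch that does not use LMQL itself. The helper name `finalness_of_length_check` and its signature are hypothetical. It assumes a constraint of the form `len(x) < limit`, where the length of a string that is still being generated can only grow.

```python
def finalness_of_length_check(x_finalness: str, result: bool) -> str:
    """Final-ness of a hypothetical `len(x) < limit` check.

    x_finalness is "inc" while x is still being generated and "fin" once
    decoding of x has completed.
    """
    # a finished string can no longer change, so any result is definitive
    if x_finalness == "fin":
        return "fin"
    # x only grows, so once the limit is exceeded it stays exceeded
    if x_finalness == "inc" and not result:
        return "fin"
    # still within the limit, but more tokens may follow: temporary result
    return "var"


print(finalness_of_length_check("inc", True))   # var
print(finalness_of_length_check("inc", False))  # fin
print(finalness_of_length_check("fin", True))   # fin
```

The point of the sketch is that a violation on a monotonically growing value can already be treated as final, while a satisfied check stays temporary until the value itself is final.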
Example: Implementing a foo_bar Operator
To demonstrate, we implement an operator that forces the model to always generate
bar after generating
foo. To clarify, the following sequences should be valid and invalid, respectively:

    "Hello foo!"              # invalid
    "Hello foo bar!"          # valid
    "Hello foo bar foo bar!"  # valid
    "Hello foo bar foo!"      # invalid
To enforce this constraint, we implement a custom operator
foo_bar as follows:
    from lmql.ops import LMQLOp, Node, fmap, tset

    # register the class under the operator name used in queries
    @LMQLOp("foo_bar")
    class FooBar(Node):
        def forward(self, x, **kwargs):
            # no restrictions, if x is not yet generated or there is no "foo" in x
            if x is None or "foo" not in x:
                return True

            # make sure the last "foo" is followed by " bar"
            bar_segment = x.rsplit("foo", 1)[1]
            # accept a full match, or a partial match (including the empty
            # string) that can still grow into " bar"
            return bar_segment.startswith(" bar") or " bar".startswith(bar_segment)

        def follow(self, x, **kwargs):
            # (1) no restrictions, if x is not yet generated
            if x is None:
                return True

            # (2) no restrictions, if there is no "foo" in x
            if "foo" not in x:
                return True

            # (3) get current segment after last "foo"
            bar_segment = x.rsplit("foo", 1)[1]

            # 3(a) already satisfied
            if bar_segment.startswith(" bar"):
                return True

            # 3(b) force continuations to align with " bar"
            return fmap(
                (tset(" bar", regex=True), True),  # " bar" -> True
                ("*", False)  # anything else -> False
            )

        def final(self, x, result=None, **kwargs):
            # violations are definitive
            if not result:
                return "fin"

            # otherwise, depends on definitiveness of x
            return "fin" if x == "fin" else "var"
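For comparison with the operator above, the property it targets can also be stated as a standalone check on completed strings. The following sketch uses only Python's `re` module. The function name `satisfies_foo_bar` is hypothetical, and unlike the operator, which inspects the segment after the last foo during generation, it checks every occurrence of foo in a finished string.

```python
import re


def satisfies_foo_bar(text: str) -> bool:
    """True if every "foo" in a completed string is followed by " bar"."""
    # the negative lookahead finds any "foo" that is NOT followed by " bar"
    return re.search(r"foo(?! bar)", text) is None


# the four example sequences listed earlier, with their expected validity
examples = {
    "Hello foo!": False,
    "Hello foo bar!": True,
    "Hello foo bar foo bar!": True,
    "Hello foo bar foo!": False,
}
for text, expected in examples.items():
    assert satisfies_foo_bar(text) == expected
```

Such a check only works in post-processing. It would reject a partial output like "Hello foo" that can still be completed correctly, which is why the operator needs separate forward, follow and final semantics.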
forward() implementation: First, we implement the
forward method: To validate the
bar property, we check that any string segment following a potential
foo substring aligns with
" bar". For this, we have to consider the case of x being None (not yet generated), a partial match (including empty strings), and a match that extends beyond
" bar". Depending on model tokenization,
forward() may be called on any such variation, and thus has to be able to handle all of them.
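The variations named above can be tried out without LMQL. The following sketch is a plain function, hypothetically named `foo_bar_forward`, that is assumed to mirror the intended forward semantics: it accepts a missing value, a partial match (including the empty string) and a match that extends beyond " bar".

```python
from typing import Optional

TARGET = " bar"


def foo_bar_forward(x: Optional[str]) -> bool:
    """Validate a possibly incomplete output against the foo/bar rule."""
    # nothing generated yet, or no "foo" so far: nothing to check
    if x is None or "foo" not in x:
        return True
    # text after the last "foo"
    segment = x.rsplit("foo", 1)[1]
    # full match, possibly extending beyond " bar"
    if segment.startswith(TARGET):
        return True
    # partial match that can still grow into " bar" (includes "")
    return TARGET.startswith(segment)


for x in [None, "Hello", "Hello foo", "Hello foo b", "Hello foo bar!", "Hello foo!"]:
    print(repr(x), foo_bar_forward(x))
```

Only the last input is rejected, since "!" directly after foo can no longer be extended into " bar".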
follow() implementation: Next, we implement the
follow method. For this, we again consider multiple cases:
(1) If x is not yet generated, we do not have to restrict the next token.
(2) If there is no foo in x, we do not have to restrict the next token.
(3) If there is a foo right at the end of x, we restrict as follows:
(a) If the segment already starts with " bar", no restrictions are necessary.
(b) Otherwise, we restrict the next token to " bar". For this, we construct a so-called follow map, a mapping of token ranges to the future evaluation result of our operator, if the next token is in the specified range.
In the foo_bar case, it suffices to indicate that a continuation of " bar" evaluates to True, and any other continuation to False. This is achieved by the fmap function, which constructs a follow map from a list of token ranges and their respective future evaluation results. To construct token ranges, the tset constructor can be used, which allows you to select tokens by length, set, prefix or regex, independent of the concrete tokenizer in use. In this case, we use the regex=True option to automatically select all tokens that fully or partially match " bar".
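As a rough illustration of what such a follow map expresses, the sketch below builds a toy version over a small, made-up vocabulary in plain Python. It is not how fmap or tset are implemented. The name `toy_follow_map` and the vocabulary are hypothetical, and "fully or partially match" is approximated as "starts with the target, or is a non-empty prefix of it".

```python
def toy_follow_map(vocabulary, target=" bar"):
    """Map each candidate next token to the operator's future result."""
    follow = {}
    for token in vocabulary:
        # full match: the token covers the whole target (and maybe more)
        fully = token.startswith(target)
        # partial match: the token is a non-empty prefix of the target
        partially = token != "" and target.startswith(token)
        follow[token] = fully or partially
    return follow


vocabulary = [" bar", " b", " ba", " bars", "bar", " baz", "!"]
for token, result in toy_follow_map(vocabulary).items():
    print(repr(token), "->", result)
```

Tokens such as " b" or " ba" stay allowed because they can still be completed to " bar" by later tokens, while "bar" (without the leading space), " baz" and "!" map to False.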
final() implementation: Lastly, we implement the
final method. This indicates to the LMQL runtime whether a result of our custom operator is final with respect to the current model output, or only temporary. In this case, a return value of
False can always be considered final (a definitive violation, warranting early termination). Otherwise, we have to consider the definitiveness of the current value of x: if
x is final, then the result of our operator is also final. Otherwise, it is temporary, as a further continuation of
x may still result in satisfying the constraint.
Note: For illustrative purposes, 3(b) of our follow() implementation simplifies an important detail about token alignment. It only considers the case where follow() is called right at the end of "foo". However, depending on model behavior and tokenization, follow() may also run on a partial result like "foo b", where the correct follow map should indicate "ar" as the valid continuation, not the full " bar". To handle this, one can simply rely on the implementation of the built-in InOpStrInSet (implementation of constraint VAR in [...]), and replace the fmap(...) expression in 3(b) with InOpStrInSet().follow(bar_segment, [" bar"]), which will automatically handle all such cases.
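The alignment issue can be seen in a small plain-Python sketch. The helper `aligned_continuation` is hypothetical and only computes which text is still missing for a given segment. It does not reproduce what InOpStrInSet does internally.

```python
def aligned_continuation(bar_segment, target=" bar"):
    """Return the text that must still be generated to complete the target.

    Returns "" if the segment already matches, and None if the segment has
    diverged from the target and can no longer match.
    """
    if bar_segment.startswith(target):
        return ""  # already satisfied, no restriction needed
    if target.startswith(bar_segment):
        return target[len(bar_segment):]  # remaining suffix of the target
    return None  # segment can no longer be completed to the target


print(repr(aligned_continuation("")))       # ' bar'
print(repr(aligned_continuation(" b")))     # 'ar'
print(repr(aligned_continuation(" bar!")))  # ''
print(repr(aligned_continuation("!")))      # None
```

For the partial result "foo b", the segment is " b" and the aligned continuation is "ar", which is the case the simplified 3(b) does not cover.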
Using the Custom Constraint Operator
To use the custom constraint operator, you can simply import it in your query context and use it as follows:
"Say 'foo':[A]" where foo_bar(A)
In general, you have to ensure that the
@LMQLOp decorator is executed in your current process before the query is parsed, e.g. by importing the module containing the operator implementation.
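Why the decorator has to run first can be illustrated with a generic registry pattern. This is a sketch of the general technique only, not LMQL's actual implementation: `REGISTRY`, `register_op` and `resolve` are hypothetical names.

```python
# Hypothetical stand-in for an operator registry; not LMQL's actual code.
REGISTRY = {}


def register_op(name):
    """Class decorator that records the class under the given operator name."""
    def decorator(cls):
        REGISTRY[name] = cls
        return cls
    return decorator


def resolve(name):
    """Look up an operator by name, as a query parser might do."""
    if name not in REGISTRY:
        raise NameError(f"unknown operator: {name}")
    return REGISTRY[name]


# Before the decorated class has been defined (e.g. its module has not been
# imported yet), the name cannot be resolved.
try:
    resolve("foo_bar")
    resolved_early = True
except NameError:
    resolved_early = False


@register_op("foo_bar")
class FooBar:
    pass


print(resolved_early)                # False
print(resolve("foo_bar") is FooBar)  # True
```

In a registry of this kind, importing the module that contains the decorated class is what makes the operator name known, which matches the requirement stated above.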
What Happens Under the Hood?
Given an operator implementation as above, the LMQL runtime will be able to both validate model output during generation and derive token-level prediction masks. For this, forward, follow and
final are called repeatedly during generation, with the current model output
x as input. This allows the runtime to derive both validity of the current model output, as well as next-token ranges for which the operator definitively (
final) evaluates to
False. Based on the correctness of the underlying implementation, this soundly ensures that the model will never select a
follow-masked token that would definitively violate your constraint.
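The property that such a mask should have can be shown with a deliberately naive sketch. Unlike the follow-based lookahead described above, it simply tries every candidate token from a made-up vocabulary and masks those that lead to a definitive violation. All function names are hypothetical, and the forward and final logic is a plain-Python approximation of the foo_bar operator.

```python
TARGET = " bar"


def forward(x):
    """Approximate foo_bar validity of a partial output."""
    if "foo" not in x:
        return True
    segment = x.rsplit("foo", 1)[1]
    return segment.startswith(TARGET) or TARGET.startswith(segment)


def final(x_finalness, result):
    """Violations are definitive; otherwise follow the final-ness of x."""
    if not result:
        return "fin"
    return "fin" if x_finalness == "fin" else "var"


def next_token_mask(x, vocabulary):
    """Mark each candidate token as allowed (True) or masked (False)."""
    mask = {}
    for token in vocabulary:
        result = forward(x + token)
        # x is still being generated, so its final-ness is "inc"
        definitive = final("inc", result) == "fin"
        # mask only tokens that lead to a definitive violation
        mask[token] = not (definitive and not result)
    return mask


print(next_token_mask("Say foo", [" bar", " b", "!", " baz"]))
```

Here " bar" and " b" remain allowed, while "!" and " baz" are masked. Trying every token is far too expensive for a real vocabulary, which is the motivation for describing whole token ranges in a follow map instead.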
Expressiveness of LMQL Constraints
LMQL constraints are applied eagerly during generation by relying on token masking. This means that the model will not be able to generate any tokens that are masked by the constraints. However, this approach is naturally limited with respect to expressiveness, since not all properties of text can be decided on a token-by-token basis. More specifically, expressiveness is limited to the validation of context-free languages. To enable safe use of token masking, LMQL's implementation of final/follow semantics provides a soundness guarantee with respect to token masking (see LMQL paper).
Nonetheless, due to eager evaluation of constraints during generation, LMQL constraints will trigger as soon as the model output violates the constraint definitively (i.e. the validation result is final), preventing the model from the costly generation of invalid output. This is an advantage over validation in post-processing, where violations may only be detected after the model has already generated a large amount of invalid output.
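The difference between eager validation and post-processing can be sketched with a simulated token stream. Everything in this example is hypothetical: the token list stands in for model output, and `no_exclamation` is a stand-in constraint chosen because its violations are always definitive.

```python
def generate_eagerly(tokens, is_valid):
    """Validate after every token and stop at the first violation."""
    output = ""
    for consumed, token in enumerate(tokens, start=1):
        output += token
        if not is_valid(output):
            return consumed, False
    return len(tokens), True


def validate_afterwards(tokens, is_valid):
    """Generate everything first, then validate once."""
    return len(tokens), is_valid("".join(tokens))


def no_exclamation(text):
    # once a "!" has been generated, no continuation can remove it again
    return "!" not in text


tokens = ["Say", " foo", "!", " and", " a", " lot", " more", " text"]
eager = generate_eagerly(tokens, no_exclamation)
late = validate_afterwards(tokens, no_exclamation)
print(eager)  # (3, False)
print(late)   # (8, False)
```

Both strategies reject the output, but the eager variant stops after three tokens, while the post-processing variant only notices the violation after all eight tokens have been generated.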