TL;DR: instead of just sampling a token and checking whether the parser would accept it, you can zero out the probabilities of all invalid tokens, and compute that mask in parallel at effectively zero cost:
> Here, compute_mask() can run on the CPU during the time it would be normally just waiting for the GPU to finish. The line prob[~mask] = 0.0 would normally be fused into the softmax kernel in the last stage of the LLM, with negligible overhead. Therefore, as long as the compute_mask() function completes faster than the LLM forward pass and parser.consume() is negligible (typically follows from compute_mask() speed), the constrained generation will be as fast as the unconstrained one.
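A toy sketch of the loop described in the quote may help. It is written under loose assumptions: `fake_forward_pass` and `ToyParser` are hypothetical stand-ins, not any real library's API. The sketch submits the forward pass and `compute_mask()` concurrently, then applies the equivalent of `prob[~mask] = 0.0`. In a real system the forward pass runs on the GPU. Here both jobs are Python threads, so under the GIL they don't truly overlap. The point is the structure, not the speedup.

```python
import math
import random
from concurrent.futures import ThreadPoolExecutor

# Hypothetical toy vocabulary; not taken from any real tokenizer.
VOCAB = ["{", "}", '"a"', ":", "1", "x"]

def fake_forward_pass(prefix):
    """Hypothetical stand-in for the GPU forward pass: one logit per vocab token."""
    rng = random.Random(len(prefix))  # deterministic toy logits
    return [rng.uniform(-1.0, 1.0) for _ in VOCAB]

class ToyParser:
    """Hypothetical parser that only accepts the token sequence { "a" : 1 }."""
    TARGET = ["{", '"a"', ":", "1", "}"]

    def __init__(self):
        self.pos = 0

    def compute_mask(self):
        # True for each token the grammar allows next, False otherwise.
        expected = self.TARGET[self.pos] if self.pos < len(self.TARGET) else None
        return [tok == expected for tok in VOCAB]

    def consume(self, token):
        assert token == self.TARGET[self.pos], "sampled a token the mask should forbid"
        self.pos += 1

    def done(self):
        return self.pos == len(self.TARGET)

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def constrained_generate(parser, rng):
    out = []
    with ThreadPoolExecutor(max_workers=2) as pool:
        while not parser.done():
            # Start the "forward pass" and the mask computation at the same time.
            logits_future = pool.submit(fake_forward_pass, list(out))
            mask_future = pool.submit(parser.compute_mask)
            prob = softmax(logits_future.result())
            mask = mask_future.result()
            # Equivalent of prob[~mask] = 0.0, then renormalize over allowed tokens.
            prob = [p if ok else 0.0 for p, ok in zip(prob, mask)]
            total = sum(prob)
            prob = [p / total for p in prob]
            token = rng.choices(VOCAB, weights=prob)[0]
            parser.consume(token)  # cheap once the mask is known
            out.append(token)
    return out

print(constrained_generate(ToyParser(), random.Random(0)))  # prints ['{', '"a"', ':', '1', '}']
```

Because invalid tokens get zero probability, the sampler can never pick one. So `consume()` never has to reject a token and trigger a resample.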
I'm curious: has there been any research or discussion about pushing masking even earlier in the pipeline? In theory, a fair amount of compute goes into computing probabilities for tokens that will end up masked away anyway.
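One place where the "wasted" compute is easy to see is the final output (unembedding) projection. Softmax over only the allowed tokens equals the full softmax with the other tokens masked out and the result renormalized. So if the mask were ready before that step, you would only need the rows of the projection matrix for allowed tokens. The sketch below is pure Python with made-up sizes, and it only checks that equivalence. It assumes the mask is known before the last matmul. That means the mask computation can no longer fully overlap the forward pass, and all the transformer layers before the projection still run in full.

```python
import math
import random

def logits_for_rows(W, h, rows):
    """Logits only for the selected vocabulary rows: logit[i] = dot(W[i], h)."""
    return {i: sum(w * x for w, x in zip(W[i], h)) for i in rows}

def softmax_dict(logits):
    m = max(logits.values())
    exps = {i: math.exp(v - m) for i, v in logits.items()}
    z = sum(exps.values())
    return {i: e / z for i, e in exps.items()}

def full_then_mask(W, h, allowed):
    # Baseline: score every token (V * d multiply-adds), then drop invalid ones and renormalize.
    probs = softmax_dict(logits_for_rows(W, h, range(len(W))))
    kept = {i: probs[i] for i in allowed}
    z = sum(kept.values())
    return {i: p / z for i, p in kept.items()}

def mask_before_projection(W, h, allowed):
    # "Earlier" masking: only len(allowed) * d multiply-adds in the final projection.
    return softmax_dict(logits_for_rows(W, h, allowed))

rng = random.Random(0)
d, V = 16, 2000                      # toy hidden size and vocabulary size
W = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(V)]
h = [rng.gauss(0, 1) for _ in range(d)]
allowed = [3, 17, 42, 1999]          # hypothetical tokens the grammar permits

a = full_then_mask(W, h, allowed)
b = mask_before_projection(W, h, allowed)
print(max(abs(a[i] - b[i]) for i in allowed) < 1e-9)  # prints True
```

The saving is limited to the final projection. It is largest when the vocabulary is big and only a handful of tokens are valid.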