by thepasch 3 days ago

If I’m reading this right, this is pretty wild. They turned a Qwen autoregressor into a diffuser by using a bunch of really clever techniques, and they vastly outperform any “native diffuser,” actually being competitive with the base model they were trained from. The obvious upside here is the massive speedup in generation.

And then through a LoRA adapter, you can ground the diffuser on the base model’s distribution (essentially have it “compare” its proposals against what the base model would’ve generated), which effectively means: exact same byte-for-byte output for the same seed, just roughly twice as fast (which should improve even more for batched tasks).

I’m not an expert, more of a “practicing enthusiast,” so I might be missing something, but at first glance, this reads super exciting to me.

oliver236 3 days ago

I think your excitement is justified. The paper is claiming a serious bridge between AR quality and parallel decoding, and the lossless LoRA-assisted mode is the wildest part.

awestroke 3 days ago

I don't understand how you can compare against the base model output without generating with the base model, in which case what's the point?

radarsat1 3 days ago

Because the nature of transformers is that running a bunch of pregenerated tokens through them is a parallel operation, not autoregressive. That's how it works at training time, but speculative decoding uses it at inference time. So if you just want to check whether a set of known tokens is "likely" given the base model, you can run them all through and get probability distributions, no need to sample.

It's the same reason there's a difference in speed between "prompt processing" and "generation". The former is just taking the pre-generated prompt and building the KV cache, which is parallel, not autoregressive and therefore way faster.

qeternity 3 days ago

I haven't read TFA yet but a common technique is speculative decoding where a fast draft model will generate X tokens, which are then verified by the larger target model. The target model may accept some Y <= X tokens but the speedup comes from the fact that this can be done in parallel as a prefill operation due to the nature of transformers.

So let's say a draft model generates 5 tokens, all 5 of these can be verified in parallel with a single forward pass of the target model. The target model may only accept the first 4 tokens (or whatever) but as long as the 5 forward passes of the draft model + 1 prefill of the target model is faster than 4 forward passes of the target, you will have a speedup while maintaining the exact output distribution as the target.
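To make the draft-then-verify loop concrete, here is a minimal sketch of greedy speculative decoding. Both "models" are hypothetical toy functions (`target_next`, `draft_next`) invented for illustration; they are not from the paper or any library. In a real transformer the verification step is a single batched forward pass, which the list comprehension below only simulates.

```python
from typing import Callable, List, Tuple

# A "model" here is any function mapping a token prefix to its greedy next token.
NextToken = Callable[[List[int]], int]


def target_next(prefix: List[int]) -> int:
    """Hypothetical stand-in for the large target model (greedy decoding)."""
    return (sum(prefix) * 7 + len(prefix)) % 11


def draft_next(prefix: List[int]) -> int:
    """Hypothetical cheap draft: agrees with the target except at some positions."""
    guess = target_next(prefix)
    return guess if len(prefix) % 4 else (guess + 1) % 11


def speculative_decode(
    prompt: List[int], n_new: int, k: int, draft: NextToken, target: NextToken
) -> Tuple[List[int], int]:
    out = list(prompt)
    target_passes = 0
    while len(out) - len(prompt) < n_new:
        # 1) Draft k tokens autoregressively with the cheap model.
        proposal: List[int] = []
        for _ in range(k):
            proposal.append(draft(out + proposal))
        # 2) Verify: the target scores every drafted position. In a real
        #    transformer this is ONE batched pass; here it is simulated.
        target_passes += 1
        expected = [target(out + proposal[:i]) for i in range(k + 1)]
        # 3) Accept the longest prefix on which draft and target agree.
        accepted = 0
        while accepted < k and proposal[accepted] == expected[accepted]:
            accepted += 1
        # The target's own token at the first mismatch (or after all k
        # accepted tokens) comes for free from the same verification pass.
        out += proposal[:accepted] + [expected[accepted]]
    return out[: len(prompt) + n_new], target_passes
```

Because every emitted token is either confirmed or supplied by the target, the result matches plain greedy decoding with the target alone; only the number of target passes changes. With sampling instead of greedy decoding, real implementations use a rejection-sampling acceptance rule rather than exact token equality.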

nodja 3 days ago

Same reason why prompt processing is faster than text generation.

When you already know the tokens ahead of time you can calculate the probabilities of all tokens batched together, which saves a significant amount of memory bandwidth. This won't help if you're already compute bound, so people with Macs/etc. won't get as much benefit from this.

Majromax 2 days ago

Are Macs/etc compute bound with their 'it fits in unified memory' language models? Certainly by the time you're streaming weights from SSD you must be back in a bandwidth-bound regime.

dd8601fn 2 days ago

From what I understood, if we’re talking a single user on a mac (not batching) you’re rarely compute bound in the first place. More rows per pass is nearly free that way when cores were sitting idle anyway.

If that’s wrong I would certainly appreciate being corrected, though. But if it’s right, a 2.9x speed-up after rejected tokens, nearly for free, sounds amazing.

nodja 2 days ago

That will depend on the model, but they'll hit compute limits before a typical GPU in almost all cases. Macs will still get a speedup from this, just not one as big as the one reported.

Balinares 3 days ago

Isn't that exactly how draft models speed up inference, though? Validating a batch of tokens is significantly faster than generating them.

anentropic 3 days ago

presumably that happens at training time?

then once successfully trained you get faster inference from just the diffusion model

a1j9o94 3 days ago

You would only use the base model during training. This is a distillation technique

porridgeraisin 3 days ago

Eh. There is nothing diffusion about this. Nothing to do with denoising. This setup is still purely causal, making it quite a dishonest framing IMO. There is no more introspection here than what happens in MTP + SD setups.

Let me explain what is going on here. This is basically a form of multi-token prediction, combined with speculative decoding at inference. See my earlier post[1] to understand what that is. TL;DR: in multi-token prediction you train separate LM heads to predict the next token, the next-to-next token, and so on, up to a chosen k-th next token. Training multiple LM heads is expensive and can be unnecessary, so what people typically do is have a common base for all the k heads, explained further in [1]. These guys do another variant.

Here is what they do mechanically, given a sequence p consisting of five tokens, PE([p1, p2, p3, p4, p5]), where PE(.) adds relative position info to each token.

1. Create an augmented sequence PE([p1 MASK MASK MASK MASK]). Do a training pass on that, with the ground truth sequence p1..5 as the target. Here it is trained, for example, to predict p3 given p1+pos=-2 MASK+pos=-1 MASK+pos=0, loosely notating.

2. Then separately[2], train it as usual on PE([p1 p2 p3 p4 p5]).

Step (1) teaches it to do multi-token prediction, essentially the single LM head will (very very loosely speaking) condition on the position `k` of the special MASK token and "route" it to the "implicit" k'th LM head.

Step (2) teaches it to be a usual LLM and predict the next token. No MASK tokens involved.
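As a rough illustration of the two training views described in steps (1) and (2), the sketch below builds the input/target pairs for each. It is not the paper's code: `MASK`, `IGNORE`, `masked_example`, and `causal_example` are hypothetical names, and it assumes the shifted convention used in the inference example later in this comment (the output at position i predicts token i + 1), since the notation in step (1) is explicitly loose. It also assumes prefix positions before the last real token contribute no loss in step (1).

```python
from typing import List, Optional, Tuple

MASK = "<MASK>"   # hypothetical placeholder for the special MASK token
IGNORE = None     # marks positions that contribute no loss (assumption)


def masked_example(
    tokens: List[str], prefix_len: int, k: int
) -> Tuple[List[str], List[Optional[str]]]:
    """Step (1): keep a real prefix and append k MASK tokens.

    The output at position i is trained to predict tokens[i + 1], so the
    last real token predicts the next token and the j-th MASK predicts
    the token j + 1 steps past the prefix.
    """
    inputs = tokens[:prefix_len] + [MASK] * k
    targets: List[Optional[str]] = []
    for i in range(len(inputs)):
        has_target = i >= prefix_len - 1 and i + 1 < len(tokens)
        targets.append(tokens[i + 1] if has_target else IGNORE)
    return inputs, targets


def causal_example(tokens: List[str]) -> Tuple[List[str], List[str]]:
    """Step (2): ordinary next-token prediction, no MASK tokens involved."""
    return tokens[:-1], tokens[1:]


if __name__ == "__main__":
    p = ["p1", "p2", "p3", "p4", "p5"]
    print(masked_example(p, prefix_len=1, k=4))
    print(causal_example(p))
```

With a five-token sequence and four MASK tokens, the last MASK has no ground-truth token to predict, so it is marked `IGNORE` in this sketch.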

So far, you have trained a multi-token predictor.

Now during inference

You use this for speculative decoding. You generate 5 tokens ahead at once with MASK tokens. And then you run that sequence through the LLM again. This has the same benefits as usual speculative decoding, namely that you can do matrix-matrix multiplication as opposed to matrix-vector. The former is more memory-bandwidth efficient due to higher arithmetic intensity.
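The arithmetic-intensity point can be checked with a back-of-the-envelope calculation. The sketch below counts FLOPs and bytes moved for a single weight matrix applied to n tokens at once. The 4096x4096 layer size and 16-bit values are illustrative assumptions, and the model ignores caches and attention entirely.

```python
def arithmetic_intensity(
    d_in: int, d_out: int, n_tokens: int, bytes_per_value: int = 2
) -> float:
    """FLOPs per byte moved for Y = X @ W.

    X has shape (n_tokens, d_in), W has shape (d_in, d_out).
    Assumes W, X and Y each cross the memory bus exactly once.
    """
    flops = 2 * n_tokens * d_in * d_out  # one multiply + one add per term
    values_moved = d_in * d_out + n_tokens * d_in + n_tokens * d_out
    return flops / (bytes_per_value * values_moved)


if __name__ == "__main__":
    for n in (1, 5):
        ai = arithmetic_intensity(4096, 4096, n)
        print(f"n_tokens={n}: {ai:.2f} FLOPs/byte")
```

Under these assumptions the weight matrix dominates the traffic, so processing five tokens in one pass does roughly five times the useful work for nearly the same bytes read. That holds until the hardware becomes compute bound, which is the caveat raised earlier in the thread.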

here is an example,

query = ["what", "is", "2+2"] and prompt = PE([...query, MASK*5]). You run output = LLM(prompt). Say output is ["what", "is", "2+2", "it", "is", "4"]. Note that the NN is trained to predict the kth next token when faced with positionally encoded MASK tokens, so you get all 5 in one go. To be precise, it learns to predict "4" given ["what", "is", "2+2", MASK, MASK]. Since it does not need the "it" and "is" explicitly, you can do it in parallel with generating the "it" and the "is". Likewise, "is" is predicted given ["what", "is", "2+2", MASK], which also doesn't depend on the explicit "it" being there, and thus can be done in parallel with generating "it", which is just normal next-token generation given the query. You then use this as a draft in your speculative decoding setup.
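The draft-with-MASKs, then verify-with-the-same-model loop can be sketched with a toy model. Everything here is hypothetical and for illustration only: `true_next` stands in for ordinary next-token prediction, and `forward` deliberately makes MASK-based guesses wrong beyond two steps ahead so that rejection is exercised. It is not the paper's implementation.

```python
from typing import List

MASK = -1  # hypothetical id for the special MASK token


def true_next(prefix: List[int]) -> int:
    """Toy stand-in for ordinary next-token prediction."""
    return (3 * prefix[-1] + len(prefix)) % 10


def forward(seq: List[int]) -> List[int]:
    """Toy "forward pass": one prediction per position (MASKs only trailing).

    At a real token the prediction is the exact next token. At the j-th
    trailing MASK it is a guess at the token j + 1 steps past the real
    prefix; this toy makes guesses with j >= 2 wrong on purpose.
    """
    n_real = len(seq) - seq.count(MASK)
    preds = []
    for i in range(len(seq)):
        if i < n_real:
            preds.append(true_next(seq[: i + 1]))
            continue
        j = i - n_real + 1
        rollout = seq[:n_real]
        for _ in range(j + 1):
            rollout.append(true_next(rollout))
        guess = rollout[-1]
        preds.append(guess if j < 2 else (guess + 1) % 10)
    return preds


def self_speculative_step(prefix: List[int], k: int) -> List[int]:
    """Draft k + 1 tokens in one pass, then verify them in a second pass."""
    # Pass 1: predictions from the last real token and from each MASK.
    draft = forward(prefix + [MASK] * k)[len(prefix) - 1:]
    # Pass 2: feed the draft back as ordinary tokens; check[i] is the exact
    # token that should follow prefix + draft[:i].
    check = forward(prefix + draft)[len(prefix) - 1:]
    n = 0
    while n < len(draft) and draft[n] == check[n]:
        n += 1
    # Keep the agreed tokens plus the verifier's token at the first mismatch.
    return draft[:n] + [check[n]]
```

Looping `self_speculative_step` until enough tokens are produced gives the same sequence as calling `true_next` one token at a time; in this toy each step yields three tokens for two passes.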

Their claim is that using a multi-token predictor this way as a draft model works really well. To be clear, this is still causal. The reason diffusion models have hype is that they are capable of global refinement; this is not. In the same thread as [1], I explain how increasing the number of MASK tokens, i.e. increasing `k`, i.e. the number of tokens you predict at once in your multi-token prediction setup, quickly leads to poor quality. This paper agrees with that. They try out k=2,3,4,8 and see a drop in quality already at 8. So finally, this is 4-token prediction with self-speculative decoding (sans LayerSkip or such), removing seemingly no existing limitation of such setups. It is definitely an interesting way to train MTP though.

[1] https://news.ycombinator.com/item?id=45221692

[2] Note that it is computationally a single forward pass. Attention masks help you fuse steps 1 and 2 into a single operation. However, you still have 2 separate loss values.

Reubend a day ago

After trying to understand their method, I think you're right. Doesn't seem like anything that I would personally call "diffusion". Much closer to MTP + speculative decoding.

Then again, their results with it are great. It would be interesting to benchmark it against standard SD on a model that already uses MTP.

porridgeraisin a day ago

Yeah, I think it's a super neat way to do MTP. Conceptually much more pleasing and simple than existing methods. Especially since this way scaling `k` as models get better will be easier. Wish it had been presented as such.

radarsat1 2 days ago

This reminds me a lot of the tricks to turn BERT into a generative model. I guess the causal masking that keeps it essentially autoregressive is an important difference though. Kind of the best of both worlds.