At a very high level, instead of having embeddings only at the input layer, this method keeps embeddings at the layer level. That is, every transformer layer has its own set of learnable embedding vectors that are used to modify the hidden states flowing through the network. Typically, the embeddings are precomputed and stored separately; they are queried at inference time with very low latency, so you can get comparable performance with half the RAM. (I am not exactly sure how 3n is doing it; I'm describing it in a general sense.)
I simplified what I wrote. There is off-accelerator memory where the embeddings are stored and queried at inference time; I did not want to get into details. That is how you reduce the in-memory RAM usage. There are definitely more things going on in the paper, as it builds on the concept I described, but the central idea remains the same: you have input embedding layers which map text to continuous vectors. Instead of loading all of these at runtime, you can break them up per layer at training time and then fetch the required ones from a separate store during inference, so they would not sit in RAM. Per-layer is not mentioned in the paper, but surely it's not a great leap from the paper itself?
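To make that concrete, here is a minimal PyTorch sketch of the general idea as I described it, not Gemma 3n's actual implementation: the class name `LayerWithPLE`, the dimensions, and the plain CPU tensors standing in for an off-accelerator store are all assumptions for illustration.

```python
# Hypothetical sketch of the general idea, not Gemma 3n's real code.
# Per-layer embedding tables live off-accelerator (here: plain CPU tensors
# standing in for a memory-mapped store) and only the rows needed for the
# current tokens are fetched at each layer during inference.
import torch
import torch.nn as nn

VOCAB, D_MODEL, D_PLE, N_LAYERS = 32_000, 1152, 256, 4  # illustrative sizes

# Off-accelerator store: one small embedding table per layer, kept on CPU.
ple_store = [torch.randn(VOCAB, D_PLE) for _ in range(N_LAYERS)]

class LayerWithPLE(nn.Module):
    """Stand-in transformer layer that mixes in a fetched per-layer embedding."""
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(D_PLE, D_MODEL, bias=False)  # lift 256 -> model dim
        self.ffn = nn.Linear(D_MODEL, D_MODEL)             # placeholder for attn/FFN

    def forward(self, hidden, token_ids, layer_idx):
        # Fetch only the rows for the current tokens from the CPU-side store,
        # then move that small slice onto the accelerator.
        rows = ple_store[layer_idx][token_ids.cpu()].to(hidden.device)
        hidden = hidden + self.proj(rows)   # modulate the hidden state
        return torch.relu(self.ffn(hidden))

tokens = torch.randint(0, VOCAB, (1, 8))     # (batch, seq)
hidden = torch.randn(1, 8, D_MODEL)          # output of the usual input embedding
layers = [LayerWithPLE() for _ in range(N_LAYERS)]
for i, layer in enumerate(layers):
    hidden = layer(hidden, tokens, i)
print(hidden.shape)  # torch.Size([1, 8, 1152])
```

The point of the sketch is just that the per-layer tables never need to be resident in accelerator memory; only a per-token slice is pulled in at each layer.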
The name "per-layer embeddings" is all we have to go on, and there are currently no published papers (that I'm aware of) using any similar mechanism, so, yes, it's a huge leap from a paper that doesn't mention per-layer anything.
It's fine to speculate based on the name, but don't pretend that it's a known technique when it clearly isn't.
Someone [1] inspected the dimensions of the embedding component of the model, and it seems GP was on the right track. Assuming I understood [2] correctly, it does seem to be the embedding of the input tokens that is passed directly into each layer.
I have not looked at the model, but since the embedding dimension of 256 seems quite small (for reference, according to [3] the old Gemma 1B had a 1152-dimensional input embedding), I'm guessing that this is not done _in lieu_ of the main input embedding into the first layer, but in addition to it.
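If that guess is right, a quick back-of-the-envelope calculation suggests why you would want the extra per-layer tables streamed from off-accelerator storage. This is purely illustrative: the vocabulary size and layer count below are assumptions; only the 256 and 1152 dimensions come from the comments above.

```python
# Illustrative arithmetic only; VOCAB and N_LAYERS are assumed values,
# the 256 and 1152 dimensions are the ones discussed above.
VOCAB = 256_000      # assumed vocabulary size
N_LAYERS = 30        # assumed number of transformer layers
D_MAIN, D_PLE = 1152, 256

main_embed_params = VOCAB * D_MAIN            # standard input embedding
per_layer_params = VOCAB * D_PLE * N_LAYERS   # one 256-dim table per layer

print(f"main input embedding: {main_embed_params / 1e6:.0f}M params")  # ~295M
print(f"per-layer embeddings: {per_layer_params / 1e6:.0f}M params")   # ~1966M
# If the per-layer tables live in an off-accelerator store and only the rows
# for the current tokens are fetched, those extra parameters never need to
# sit in accelerator RAM all at once.
```

Under those assumed numbers the per-layer tables would dominate the embedding parameter count, which would be a plausible reason to keep them out of resident memory.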