But it may very well be slower than just recomputing it, at least for ordinary MHA and even GQA.
So you'd need either some model-architecture voodoo that significantly reduces KV cache size (while keeping roughly the same compute cost), or a really careful implementation that moves the KV cache for upcoming requests onto devices in the background [0].
[0] My back-of-the-envelope calculation shows that even then it still doesn't make sense for, say, Llama 3 70B on H100s. Time to stare harder at the TPU specs to make sense of it, I guess.
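For a rough sense of the sizes involved, here is a minimal Python sketch of the KV-cache arithmetic. It assumes the published Llama 3 70B shape (80 layers, 64 query heads, 8 KV heads, head dimension 128), bf16 storage, and a nominal ~64 GB/s for one direction of a PCIe Gen5 x16 link. The MHA variant is a hypothetical version of the same model with one KV head per query head. The sketch does not model recompute cost, which depends on achieved FLOPs and context length.

```python
# Rough KV-cache sizing. All constants are ASSUMED nominal values, not measurements.
BF16_BYTES = 2
PCIE5_X16_BYTES_PER_S = 64e9  # assumed one-direction PCIe Gen5 x16 peak


def kv_bytes_per_token(n_layers, n_kv_heads, head_dim, bytes_per_elem=BF16_BYTES):
    # One K vector and one V vector per KV head, per layer, for every token.
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem


def transfer_seconds(n_bytes, bytes_per_s=PCIE5_X16_BYTES_PER_S):
    # Idealized transfer time at the assumed peak link bandwidth.
    return n_bytes / bytes_per_s


# Llama 3 70B shape: 80 layers, head_dim 128. GQA uses 8 KV heads;
# a hypothetical MHA variant of the same model would use all 64.
gqa_per_token = kv_bytes_per_token(80, 8, 128)
mha_per_token = kv_bytes_per_token(80, 64, 128)

prompt_tokens = 32_768
for name, per_token in [("GQA", gqa_per_token), ("MHA", mha_per_token)]:
    total = per_token * prompt_tokens
    print(f"{name}: {per_token / 2**10:.0f} KiB/token, "
          f"{total / 2**30:.0f} GiB for {prompt_tokens} tokens, "
          f"~{transfer_seconds(total):.2f} s over one PCIe link")
```

Under these assumptions GQA comes to 320 KiB per token (10 GiB for a 32k-token prompt) and the MHA variant to 8x that. With tensor parallelism, each GPU would hold and transfer only its shard of the KV heads over its own link.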
It depends on how large the input prompt (previous context) is. Also, if you can keep the cache on the GPU with an LRU mechanism, it's very efficient for certain workloads.
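A minimal sketch of what keeping the cache on the GPU with an LRU mechanism could look like, assuming KV states are keyed by some identifier of a shared prompt prefix and that free GPU memory is modelled as a plain byte budget. `LRUKVCache` and its methods are hypothetical, not any serving framework's API.

```python
from collections import OrderedDict


class LRUKVCache:
    """Hypothetical LRU cache of precomputed KV states, bounded by total bytes."""

    def __init__(self, capacity_bytes):
        self.capacity_bytes = capacity_bytes
        self.used_bytes = 0
        self._entries = OrderedDict()  # key -> (kv_state, size_bytes), oldest first

    def get(self, key):
        if key not in self._entries:
            return None  # miss: caller recomputes the prefix
        self._entries.move_to_end(key)  # mark as most recently used
        return self._entries[key][0]

    def put(self, key, kv_state, size_bytes):
        if size_bytes > self.capacity_bytes:
            return False  # too large to cache at all
        if key in self._entries:
            self.used_bytes -= self._entries.pop(key)[1]
        # Evict least recently used entries until the new state fits.
        while self.used_bytes + size_bytes > self.capacity_bytes:
            _, (_, evicted_size) = self._entries.popitem(last=False)
            self.used_bytes -= evicted_size
        self._entries[key] = (kv_state, size_bytes)
        self.used_bytes += size_bytes
        return True

    def __contains__(self, key):
        return key in self._entries


cache = LRUKVCache(capacity_bytes=100)
cache.put("prompt-A", "kv-A", 60)
cache.put("prompt-B", "kv-B", 30)
cache.get("prompt-A")              # A becomes most recently used
cache.put("prompt-C", "kv-C", 40)  # evicts B, the least recently used
```

Whether this pays off depends on the hit rate: workloads that reuse a few long prefixes keep hitting, while mostly unique prompts just churn the cache.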
You can also design an API optimized for batch workloads (say, the same core prompt with different data for instruct-style reasoning), which can yield large savings in those scenarios.
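The savings from a shared core prompt can be estimated by counting prefill tokens. A minimal sketch, assuming every request in the batch starts with an identical prefix whose KV state is computed once and reused; `prefill_tokens` is a hypothetical helper. It counts tokens, not FLOPs: the per-request tokens still attend over the prefix.

```python
def prefill_tokens(prefix_len, suffix_lens, share_prefix):
    """Tokens to prefill for requests sharing one prefix plus per-request data."""
    if share_prefix:
        # Prefix KV is computed once; only the per-request suffixes are new work.
        return prefix_len + sum(suffix_lens)
    # Without reuse, every request pays for the full prefix again.
    return sum(prefix_len + s for s in suffix_lens)


suffixes = [200] * 100  # 100 requests, 200 tokens of per-request data each
naive = prefill_tokens(2_000, suffixes, share_prefix=False)   # 220,000
shared = prefill_tokens(2_000, suffixes, share_prefix=True)   # 22,000
```

The longer the shared prefix relative to the per-request data, the closer the saving gets to a factor equal to the number of requests.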
If you can pipeline upcoming requests and tie state to a specific request, doesn't that let you change how you design physical memory, at least for inference?
Stupid question, but why wouldn't {extremely large slow-write, fast-read memory} + {smaller, very fast-write memory} be a feasible hardware architecture, if you know many, many cycles ahead what you'll need loaded at a specific time?
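When the request order is known ahead of time, the state for the next request can be loaded while the current one computes. A sketch of that pipelining, under simplifying assumptions: one transfer link, one compute device, and a fixed number of on-device state buffers (`buffers=2` is classic double buffering, `buffers=1` degenerates to load-then-compute). `makespan` is a hypothetical helper.

```python
def makespan(load_s, compute_s, buffers=2):
    """Finish time for requests served in a known order with background loads."""
    load_end, compute_end = [], []
    for i, (load, compute) in enumerate(zip(load_s, compute_s)):
        # A load needs the link free and a free buffer slot; the slot held by
        # request i - buffers is released when that request finishes computing.
        start = load_end[-1] if load_end else 0.0
        if i >= buffers:
            start = max(start, compute_end[i - buffers])
        load_end.append(start + load)
        # Compute needs its own state loaded and the device free.
        begin = max(load_end[i], compute_end[-1] if compute_end else 0.0)
        compute_end.append(begin + compute)
    return compute_end[-1] if compute_end else 0.0


serial = makespan([2] * 4, [3] * 4, buffers=1)       # 20: no overlap
overlapped = makespan([2] * 4, [3] * 4, buffers=2)   # 14: only the first load is exposed
link_bound = makespan([5] * 4, [1] * 4, buffers=2)   # 21: loads dominate
```

Prefetching hides transfers only while each load takes no longer than the compute it overlaps with; once loads are slower, the link sets the pace regardless of how far ahead you know the schedule.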
Or hell, maybe it's time to go back to memory bank switching.
The throughput of the PCIe link between the CPU and GPU is far less than the aggregate throughput of the internal interconnects between neighbouring tensor cores.
Matrix operations might flow a lot of data around — but that data flow is akin to a bunch of individual people travelling along the individual residential streets they live on. There's a lot of movement there, but also a lot of capacity for movement, because there's no bottleneck of everyone needing to go to the same place or come from the same place.
Persisting the data out of the GPU and then loading it back in is more like all those people commuting to work and then going back home: a big fan-in onto the PCIe "highway" over to the CPU and into RAM, then a big fan-out back. Traffic jams for miles.
In the time it takes to restore a 1 GB state snapshot from RAM into VRAM, you can probably chew through the equivalent of 1 TB or more of intermediate matrix states.
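The gap can be made concrete with nominal bandwidths. A minimal sketch, assuming ~64 GB/s for one direction of PCIe Gen5 x16 and ~3.35 TB/s for H100 SXM HBM3 (peak figures; sustained rates are lower). It only bounds off-chip HBM traffic and does not model on-chip SRAM/register reuse inside the GPU.

```python
# ASSUMED nominal peak bandwidths in bytes/second, not measurements.
PCIE5_X16 = 64e9         # PCIe Gen5 x16, one direction
H100_SXM_HBM3 = 3.35e12  # H100 SXM HBM3


def transfer_seconds(n_bytes, bytes_per_s):
    return n_bytes / bytes_per_s


snapshot = 1e9  # the 1 GB state snapshot mentioned above
t_pcie = transfer_seconds(snapshot, PCIE5_X16)
hbm_bytes_meanwhile = t_pcie * H100_SXM_HBM3

print(f"PCIe restore: {t_pcie * 1e3:.1f} ms")
print(f"HBM traffic possible in that time: {hbm_bytes_meanwhile / 1e9:.0f} GB")
```

At these numbers the restore takes about 15.6 ms, during which HBM alone could move roughly 52 GB; any larger multiple would have to come from data reused on-chip.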
I don't know of any public details on how they implement Context Caching, but that is presumably exactly what they're doing. Caching just the text would save very little.
The second doesn't require any matrix operations; it's just setting some values.
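Why caching only the text saves little can be seen in a toy single-head attention layer with random weights. This is an illustrative sketch, not how any provider implements Context Caching: with cached K/V, a new token needs just three projections, while with only the text you must re-project every prefix token first. In a real multi-layer model it is worse, since each layer's inputs depend on a full forward pass over the prefix. `ToyAttentionLayer` and its methods are hypothetical.

```python
import math
import random


def matvec(w, x):
    # w is a list of rows; returns w @ x.
    return [sum(a * b for a, b in zip(row, x)) for row in w]


def attend(q, keys, values):
    # Single-head scaled dot-product attention for one query vector.
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[i] for w, v in zip(weights, values)) / total
            for i in range(len(values[0]))]


class ToyAttentionLayer:
    """One toy attention layer with random weights; counts projection work."""

    def __init__(self, d, seed=0):
        rng = random.Random(seed)

        def rand_matrix():
            return [[rng.uniform(-1, 1) for _ in range(d)] for _ in range(d)]

        self.wq, self.wk, self.wv = rand_matrix(), rand_matrix(), rand_matrix()
        self.projections = 0  # number of matrix-vector products performed

    def _proj(self, w, x):
        self.projections += 1
        return matvec(w, x)

    def prefill(self, tokens):
        # Build the KV cache: one K and one V projection per prefix token.
        return ([self._proj(self.wk, x) for x in tokens],
                [self._proj(self.wv, x) for x in tokens])

    def step_with_kv_cache(self, kv_cache, new_token):
        # Restoring cached K/V is just reuse; only the new token is projected.
        keys, values = kv_cache
        keys = keys + [self._proj(self.wk, new_token)]
        values = values + [self._proj(self.wv, new_token)]
        return attend(self._proj(self.wq, new_token), keys, values)

    def step_with_text_only(self, prefix, new_token):
        # With only the text cached, K/V for the whole prefix are recomputed.
        return self.step_with_kv_cache(self.prefill(prefix), new_token)


rng = random.Random(1)
d, prefix_len = 8, 100
prefix = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(prefix_len)]
new_token = [rng.gauss(0, 1) for _ in range(d)]

layer = ToyAttentionLayer(d)
cache = layer.prefill(prefix)  # paid once, when the context is cached

before = layer.projections
out_cached = layer.step_with_kv_cache(cache, new_token)
cost_cached = layer.projections - before  # 3

before = layer.projections
out_text = layer.step_with_text_only(prefix, new_token)
cost_text = layer.projections - before    # 2 * 100 + 3 = 203
```

Both paths produce the same output; only the amount of matrix work differs, and that difference grows linearly with the length of the cached prefix.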