mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-21 22:33:56 +00:00
runner.go: Retry decoding after defragmentation if needed
Fragmentation of the KV cache can occur due to cache shifting or different sequences getting processed. Decode uses a heuristic to decide if it should defrag. However, this heuristic isn't 100% accurate, so decoding can sometimes fail by surprise. For these cases, if decode indicates that there is no KV cache space, we should defrag and then try again.
This commit is contained in:
@@ -157,9 +157,7 @@ type Context struct {
|
||||
numThreads int
|
||||
}
|
||||
|
||||
// KvCacheClear invokes llama_kv_cache_clear on the underlying C context
// handle (c.c). It takes no arguments and returns nothing.
// NOTE(review): presumably this drops every entry in the KV cache — confirm
// against llama.h.
func (c *Context) KvCacheClear() {
|
||||
C.llama_kv_cache_clear(c.c)
|
||||
}
|
||||
// ErrKvCacheFull is the sentinel error returned by Decode when the decode
// call reports a positive code, i.e. no KV cache slot could be found for the
// batch. Callers can detect it with errors.Is and retry after defragmenting
// the cache.
var ErrKvCacheFull = errors.New("could not find a kv cache slot")
|
||||
|
||||
func (c *Context) Decode(batch *Batch) error {
|
||||
// Positive return values do not mean a fatal error, but rather a warning.
|
||||
@@ -173,7 +171,7 @@ func (c *Context) Decode(batch *Batch) error {
|
||||
}
|
||||
|
||||
if code > 0 {
|
||||
return fmt.Errorf("could not find a KV slot for the batch - try reducing the size of the batch or increase the context. code: %d", code)
|
||||
return ErrKvCacheFull
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -195,6 +193,14 @@ func (c *Context) KvCacheSeqCp(srcSeqId int, dstSeqId int, p0 int, p1 int) {
|
||||
C.llama_kv_cache_seq_cp(c.c, C.int(srcSeqId), C.int(dstSeqId), C.int(p0), C.int(p1))
|
||||
}
|
||||
|
||||
// KvCacheClear invokes llama_kv_cache_clear on the underlying C context
// handle (c.c). It takes no arguments and returns nothing.
// NOTE(review): presumably this drops every entry in the KV cache — confirm
// against llama.h.
func (c *Context) KvCacheClear() {
|
||||
C.llama_kv_cache_clear(c.c)
|
||||
}
|
||||
|
||||
// KvCacheDefrag invokes llama_kv_cache_defrag on the underlying C context
// handle (c.c). The runner calls it when Decode fails with ErrKvCacheFull,
// then retries the decode once.
// NOTE(review): whether the defragmentation is applied immediately or only
// scheduled for the next decode is not visible here — confirm against llama.h.
func (c *Context) KvCacheDefrag() {
|
||||
C.llama_kv_cache_defrag(c.c)
|
||||
}
|
||||
|
||||
// Get the embeddings for a sequence id
|
||||
func (c *Context) GetEmbeddingsSeq(seqId int) []float32 {
|
||||
embeddings := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId)))
|
||||
|
||||
@@ -426,8 +426,15 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||
|
||||
err := s.lc.Decode(batch)
|
||||
if err != nil {
|
||||
slog.Error("failed to decode batch", "error", err)
|
||||
return
|
||||
if errors.Is(err, llama.ErrKvCacheFull) {
|
||||
slog.Debug("defragmenting kv cache")
|
||||
s.cache.lc.KvCacheDefrag()
|
||||
err = s.lc.Decode(batch)
|
||||
}
|
||||
if err != nil {
|
||||
slog.Error("failed to decode batch", "error", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if crossAttention {
|
||||
|
||||
Reference in New Issue
Block a user