embeddings: modified batch size (#13429)
This PR detects embedding models and sets batch_size = context_size so the full input fits in a single batch. Previously, if the batch size was smaller than the input, tokens could be split across batches and cause a SIGTRAP crash. This change ensures all tokens stay in one batch and prevents crashes.

Fixes: #12938 #13054

Co-authored-by: Jesse Gross <jesse@ollama.com>
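The mechanism is a metadata lookup plus a clamp: GGUF embedding architectures carry a "<architecture>.pooling_type" key, and its presence is used as the embedding-model signal. Below is a minimal, self-contained sketch of that logic under simplified assumptions; the kv map, LoadRequest, and Options types here are hypothetical stand-ins for illustration, not the actual server types.

```go
package main

import (
	"fmt"
	"log/slog"
)

// Simplified stand-ins for the server's load request and model options.
type LoadRequest struct{ BatchSize int }
type Options struct{ NumCtx int }

// clampBatchForEmbedding raises the batch size to the context length when the
// model metadata contains a "<arch>.pooling_type" key, i.e. an embedding model.
func clampBatchForEmbedding(kv map[string]any, arch string, req *LoadRequest, opts Options) {
	_, isEmbedding := kv[fmt.Sprintf("%s.pooling_type", arch)]
	if isEmbedding && req.BatchSize < opts.NumCtx {
		req.BatchSize = opts.NumCtx
		slog.Info("embedding model detected, setting batch size to context length",
			"batch_size", req.BatchSize)
	}
}

func main() {
	// Hypothetical metadata for a BERT-style embedding model.
	kv := map[string]any{"bert.pooling_type": uint32(1)}
	req := &LoadRequest{BatchSize: 512} // a typical default batch size
	clampBatchForEmbedding(kv, "bert", req, Options{NumCtx: 2048})
	fmt.Println("batch size:", req.BatchSize) // 2048: the whole input fits in one batch
}
```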
@@ -474,6 +474,13 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system
 		s.mem.GPUs[i].Cache = make([]uint64, s.totalLayers)
 	}
 
+	// Check if embedding model and adjust batch size accordingly
+	_, isEmbedding := s.ggml.KV()[fmt.Sprintf("%s.pooling_type", s.ggml.KV().Architecture())]
+	if isEmbedding && s.loadRequest.BatchSize < s.options.NumCtx {
+		s.loadRequest.BatchSize = s.options.NumCtx
+		slog.Info("embedding model detected, setting batch size to context length", "batch_size", s.loadRequest.BatchSize)
+	}
+
 	kv, graphPartialOffload, graphFullOffload := s.ggml.GraphSize(uint64(s.options.NumCtx), uint64(s.loadRequest.BatchSize),
 		s.loadRequest.Parallel, s.loadRequest.KvCacheType, s.loadRequest.FlashAttention)
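To make the failure mode concrete, here is a small illustration of the batching arithmetic; the 512 default and the 1500-token input are assumed values for the example, not taken from the commit.

```go
package main

import "fmt"

func main() {
	const tokens = 1500 // hypothetical embedding input length
	for _, batch := range []int{512, 2048} {
		// Number of batches the input would be split into at this batch size.
		n := (tokens + batch - 1) / batch
		fmt.Printf("batch_size=%d -> %d batch(es)\n", batch, n)
	}
	// With batch_size raised to the context length (2048 here), the whole
	// input stays in a single batch, which is what the fix guarantees.
}
```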