Merge branch 'ollama:main' into main

This commit is contained in:
likelovewant
2024-11-17 22:54:33 +08:00
committed by GitHub
24 changed files with 351 additions and 155 deletions

View File

@@ -21,6 +21,8 @@ package llama
#cgo cuda CFLAGS: -fPIE -DGGML_USE_CUDA -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
#cgo cuda CXXFLAGS: -DGGML_USE_CUDA -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
#cgo cuda CXXFLAGS: -DGGML_USE_CUDA -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
#cgo cuda_jetpack5 LDFLAGS: -lggml_cuda_jetpack5 -L/usr/local/cuda-11/lib64
#cgo cuda_jetpack6 LDFLAGS: -lggml_cuda_jetpack6 -L/usr/local/cuda-12/lib64
#cgo cuda_v11 LDFLAGS: -lggml_cuda_v11 -L/usr/local/cuda-11/lib64
#cgo cuda_v12 LDFLAGS: -lggml_cuda_v12 -L/usr/local/cuda-12/lib64
#cgo darwin,amd64 CFLAGS: -Wno-incompatible-pointer-types-discards-qualifiers
@@ -36,8 +38,8 @@ package llama
#cgo linux CXXFLAGS: -D_GNU_SOURCE
#cgo linux,amd64 LDFLAGS: -L${SRCDIR}/build/Linux/amd64
#cgo linux,amd64 LDFLAGS: -L${SRCDIR}/build/Linux/amd64
#cgo linux,arm64 CFLAGS: -D__aarch64__ -D__ARM_NEON -D__ARM_FEATURE_FMA -D__ARM_FEATURE_MATMUL_INT8
#cgo linux,arm64 CXXFLAGS: -D__aarch64__ -D__ARM_NEON -D__ARM_FEATURE_FMA -D__ARM_FEATURE_MATMUL_INT8
#cgo linux,arm64 CFLAGS: -D__aarch64__ -D__ARM_NEON -D__ARM_FEATURE_FMA
#cgo linux,arm64 CXXFLAGS: -D__aarch64__ -D__ARM_NEON -D__ARM_FEATURE_FMA
#cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/Linux/arm64
#cgo linux,arm64,sve CFLAGS: -march=armv8.6-a+sve
#cgo linux,arm64,sve CXXFLAGS: -march=armv8.6-a+sve
@@ -598,6 +600,10 @@ func (c *Context) SetCrossAttention(state bool) {
C.llama_set_cross_attention(c.c, C.bool(state))
}
func (c *Context) Synchronize() {
C.llama_synchronize(c.c)
}
// sampling
// TODO: this is a temporary wrapper to allow calling C++ code from CGo
type SamplingContext struct {

View File

@@ -20,7 +20,7 @@ GPU_COMPILER_CFLAGS_LINUX = $(CFLAGS) -Xcompiler -fPIC -D_GNU_SOURCE
GPU_COMPILER_CXXFLAGS_WIN = $(CXXFLAGS) -D_WIN32_WINNT=0x602
GPU_COMPILER_CXXFLAGS_LINUX = $(CXXFLAGS) -Xcompiler -fPIC -D_GNU_SOURCE
GPU_LIBS = $(sort $(wildcard $(addsuffix *.$(SHARED_EXT)*,$(addprefix $(GPU_LIB_DIR)/$(SHARED_PREFIX),$(GPU_RUNNER_LIBS_SHORT)))))
GPU_DIST_DEPS_LIBS= $(sort $(addprefix $(DIST_LIB_DIR)/,$(notdir $(GPU_LIBS))))
GPU_DIST_DEPS_LIBS= $(sort $(addprefix $(DIST_GPU_RUNNER_DEPS_DIR)/,$(notdir $(GPU_LIBS))))
ifeq ($(OS),linux)
CUDA_PATH?=/usr/local/cuda

View File

@@ -2,6 +2,7 @@ package main
import (
"errors"
"fmt"
"log/slog"
"reflect"
"time"
@@ -22,7 +23,11 @@ type InputCache struct {
lc *llama.Context
}
func NewInputCache(lc *llama.Context, kvSize int, numSlots int, multiUserCache bool) *InputCache {
func NewInputCache(lc *llama.Context, kvSize int, numSlots int, multiUserCache bool) (*InputCache, error) {
if kvSize/numSlots < 1 {
return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
}
slots := make([]InputCacheSlot, numSlots)
for i := range slots {
@@ -37,7 +42,7 @@ func NewInputCache(lc *llama.Context, kvSize int, numSlots int, multiUserCache b
slots: slots,
multiUserCache: multiUserCache,
lc: lc,
}
}, nil
}
// Locking: Operations on InputCacheSlot (including finding one
@@ -58,7 +63,7 @@ type InputCacheSlot struct {
lastUsed time.Time
}
func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCacheSlot, []input, int, error) {
func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCacheSlot, []input, error) {
var slot *InputCacheSlot
var numPast int
var err error
@@ -75,7 +80,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCach
slot, numPast, err = c.findBestCacheSlot(prompt)
}
if err != nil {
return nil, nil, 0, err
return nil, nil, err
}
if !cachePrompt {
@@ -102,7 +107,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCach
prompt = prompt[numPast:]
slot.Inputs = slot.Inputs[:numPast]
return slot, prompt, numPast, nil
return slot, prompt, nil
}
func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int, error) {
@@ -194,14 +199,30 @@ func countCommonPrefix(a []input, b []input) int {
return count
}
func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int, numDiscard int, numPast int) {
// TODO (jessegross): KV cache removal can fail for certain types of models
// server.cpp doesn't handle this, though we can be more graceful
c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+numDiscard)
c.lc.KvCacheSeqAdd(slot.Id, numKeep+numDiscard, numPast, -numDiscard)
// Frees up space in the KV cache by deleting the oldest half of history and shifting
// the newest half into that space (saving numKeep inputs at the beginning).
//
// Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx)
func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) {
targetFree := (c.numCtx - numKeep) / 2
targetFree = max(targetFree, 1)
for i := numKeep + numDiscard; i < len(slot.Inputs); i++ {
slot.Inputs[i-numDiscard] = slot.Inputs[i]
currentFree := c.numCtx - len(slot.Inputs)
discard := targetFree - currentFree
if discard <= 0 {
return
}
slot.Inputs = slot.Inputs[:len(slot.Inputs)-numDiscard]
slog.Debug("context limit hit - shifting", "limit", c.numCtx, "input", len(slot.Inputs),
"keep", numKeep, "discard", discard)
// TODO (jessegross): KV cache removal can fail for certain types of models
c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard)
c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, len(slot.Inputs), -discard)
for i := numKeep + discard; i < len(slot.Inputs); i++ {
slot.Inputs[i-discard] = slot.Inputs[i]
}
slot.Inputs = slot.Inputs[:len(slot.Inputs)-discard]
}

View File

@@ -20,6 +20,8 @@ import (
"time"
"unicode/utf8"
"golang.org/x/sync/semaphore"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llama"
)
@@ -34,9 +36,6 @@ type input struct {
}
type Sequence struct {
// number of inputs evaluated
numPast int
// batch index
iBatch int
@@ -112,21 +111,15 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
params.numKeep = len(inputs)
}
if !params.embedding {
// Subtracting 4 ensures that at least 1 input can be discarded during shift
params.numKeep = min(params.numKeep, s.cache.numCtx-4)
params.numKeep += s.bosToken
} else {
// Embeddings are 1 shot - just truncate to the context window, without ever shifting
params.numKeep = min(params.numKeep, s.cache.numCtx)
if s.model.AddBOSToken() {
params.numKeep += 1
}
// truncate to fit in context window
// Ensure that at least 1 input can be discarded during shift
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
if len(inputs) > s.cache.numCtx {
slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "numKeep", params.numKeep)
newInputs := inputs[:params.numKeep]
newInputs = append(newInputs, inputs[len(inputs)-s.cache.numCtx+params.numKeep:]...)
inputs = newInputs
slog.Warn("input exceeds context length", "prompt", len(inputs), "limit", s.cache.numCtx)
}
var sc *llama.SamplingContext
@@ -170,15 +163,13 @@ func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
for i, part := range parts {
// text - tokenize
if strings.TrimSpace(part) != "" {
tokens, err := s.lc.Model().Tokenize(part, i == 0, true)
if err != nil {
return nil, err
}
tokens, err := s.lc.Model().Tokenize(part, i == 0, true)
if err != nil {
return nil, err
}
for _, t := range tokens {
inputs = append(inputs, input{token: t})
}
for _, t := range tokens {
inputs = append(inputs, input{token: t})
}
// image - generate image embedding
@@ -212,41 +203,51 @@ func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
}
type Server struct {
model *llama.Model
lc *llama.Context
// is the server ready to process requests?
// protects access to model and image
ready sync.WaitGroup
// required for image embeddings
// loaded model
model *llama.Model
// image model context for multi-modal models
image *ImageContext
// status for external health reporting - loading, ready to serve, etc.
status ServerStatus
// current progress on loading the model
progress float32
// number of simultaneous requests to handle
parallel int
// maximum number of elements in a batch (per sequence)
// TODO (jmorganca): make this n_batch
batchSize int
// parallel is the number of parallel requests to handle
parallel int
// protects access to everything below this line
// this is context state needed for decoding
mu sync.Mutex
// seqs is the list of parallel sequences being evaluated
// TODO (jmorganca): this can probably be moved into run()
// indicates that data is ready for processing
cond *sync.Cond
// decoding state
lc *llama.Context
// the list of simultaneous sequences being evaluated
seqs []*Sequence
// seqs can have a maximum of parallel entries, which
// is enfoced by seqSem
seqsSem *semaphore.Weighted
// KV cache
cache *InputCache
// does this model require a beginning of sequence token?
bosToken int
// next sequence for prompt processing to avoid starvation
nextSeq int
// is the server ready to process requests?
ready sync.WaitGroup
mu sync.Mutex
cond *sync.Cond
progress float32
status ServerStatus
}
func (s *Server) allNil() bool {
@@ -258,18 +259,6 @@ func (s *Server) allNil() bool {
return true
}
func (s *Server) shiftContext(seq *Sequence) {
numLeft := seq.numPast - seq.numKeep
numDiscard := numLeft / 2
slog.Debug("context limit hit - shifting", "limit", s.cache.numCtx, "numPast", seq.numPast,
"numKeep", seq.numKeep, "numLeft", numLeft, "numDiscard", numDiscard)
s.cache.ShiftCacheSlot(seq.cache, seq.numKeep, numDiscard, seq.numPast)
seq.numPast -= numDiscard
}
func flushPending(seq *Sequence) bool {
joined := strings.Join(seq.pendingResponses, "")
seq.pendingResponses = []string{}
@@ -368,18 +357,33 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
continue
}
// If an error occurred during the processing of a previous batch then we may have emptied the inputs
// without adding a new one. In this case, end the sequence rather than infinite looping.
if len(seq.inputs) == 0 {
slog.Error("removing sequence due to no input tokens", "index", seqIdx, "cache id", seq.cache.Id)
s.removeSequence(seqIdx, "error")
continue
}
// if past the num predict limit
if seq.numPredict > 0 && seq.numPredicted > seq.numPredict {
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
s.removeSequence(seqIdx, "limit")
continue
}
if seq.numPast+len(seq.inputs) > s.cache.numCtx {
s.shiftContext(seq)
}
var numInputsProcessed int
shifted := false
for i, input := range seq.inputs {
if len(seq.cache.Inputs)+1 > s.cache.numCtx {
if !shifted {
s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
shifted = true
} else {
break
}
}
embedding := input.embed != nil
// If we don't currently have a batch, use one of the correct type and
@@ -403,13 +407,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
}
crossAttention = seq.crossAttention
batch.Add(input.token, input.embed, seq.numPast, numInputsProcessed+1 == len(seq.inputs), seq.cache.Id)
seq.numPast++
batch.Add(input.token, input.embed, len(seq.cache.Inputs), i+1 == len(seq.inputs), seq.cache.Id)
seq.cache.Inputs = append(seq.cache.Inputs, input)
numInputsProcessed++
}
if numInputsProcessed > 0 {
seq.cache.Inputs = append(seq.cache.Inputs, seq.inputs[:numInputsProcessed]...)
seq.inputs = seq.inputs[numInputsProcessed:]
seq.iBatch = batch.NumTokens() - 1
}
@@ -427,6 +430,13 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
return
}
if crossAttention {
// synchronize state to ensure the cross attention batch is complete.
// needed specifically for multi-GPU systems otherwise an inflight
// task may be incorrectly invalidated causing a crash
s.lc.Synchronize()
}
for i, seq := range s.seqs {
if seq == nil {
continue
@@ -627,12 +637,17 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
// TODO (jmorganca): add to sequence queue instead of
// failing if a slot isn't available
// Ensure that a place to put the sequence is available
if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
slog.Error("Failed to acquire semaphore", "error", err)
return
}
defer s.seqsSem.Release(1)
s.mu.Lock()
for i, sq := range s.seqs {
if sq == nil {
seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
if err != nil {
s.mu.Unlock()
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
@@ -711,11 +726,17 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
return
}
// TODO (jessegross): Wait for a free slot instead of failing and blocking forever
// Ensure that a place to put the sequence is available
if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
slog.Error("Failed to acquire semaphore", "error", err)
return
}
defer s.seqsSem.Release(1)
s.mu.Lock()
for i, sq := range s.seqs {
if sq == nil {
seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
if err != nil {
s.mu.Unlock()
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
@@ -802,10 +823,6 @@ func (s *Server) loadModel(
}
}
if s.model.AddBOSToken() {
s.bosToken = 1
}
if ppath != "" {
var err error
s.image, err = NewImageContext(s.lc, ppath)
@@ -814,7 +831,10 @@ func (s *Server) loadModel(
}
}
s.cache = NewInputCache(s.lc, kvSize, s.parallel, multiUserCache)
s.cache, err = NewInputCache(s.lc, kvSize, s.parallel, multiUserCache)
if err != nil {
panic(err)
}
s.status = ServerStatusReady
s.ready.Done()
@@ -867,6 +887,7 @@ func main() {
batchSize: *batchSize,
parallel: *parallel,
seqs: make([]*Sequence, *parallel),
seqsSem: semaphore.NewWeighted(int64(*parallel)),
status: ServerStatusLoadingModel,
}