diff --git a/CMakeLists.txt b/CMakeLists.txt index e57d9f65..875dc4ae 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -122,9 +122,11 @@ if(CMAKE_HIP_COMPILER) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip) if (WIN32) - target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY=1) + target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY) endif() + target_compile_definitions(ggml-hip PRIVATE GGML_HIP_NO_VMM) + set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm) install(TARGETS ggml-hip RUNTIME_DEPENDENCIES diff --git a/kvcache/cache.go b/kvcache/cache.go index 5d8b2f9b..2541f7c1 100644 --- a/kvcache/cache.go +++ b/kvcache/cache.go @@ -29,6 +29,17 @@ type Cache interface { // cache implementation used. Put(ctx ml.Context, key, value ml.Tensor) + // SetConfig controls optimizations (mostly backend-specific) that may transform + // the output of the cache to work better with specific kernels. If not called, + // the backend settings will be used. This works well when calling Attention. + // + // The config can be overridden by models, especially if they require vanilla + // output when implementing their own version of attention. To do this, pass + // an empty ml.CacheConfig. + // + // Most models will not need to use this. 
+ SetConfig(ml.CacheConfig) + // ** cache management ** // Init sets up runtime parameters diff --git a/kvcache/causal.go b/kvcache/causal.go index 69068439..b2e7b3ab 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -22,6 +22,9 @@ type Causal struct { Capacity int32 windowSize int32 + // config controls mostly backend-specific optimizations + config *ml.CacheConfig + // ** current forward pass ** // the active layer for Get and Put @@ -75,14 +78,42 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal { } func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) { + if c.config == nil { + var config ml.CacheConfig + if cc, ok := backend.(ml.BackendCacheConfig); ok { + config = cc.CacheConfig() + } + c.config = &config + } + + if c.config.CachePadding == 0 { + c.config.CachePadding = 1 + } + + if c.config.MaskBatchPadding == 0 { + c.config.MaskBatchPadding = 1 + } + + if c.config.MaskDType == ml.DTypeOther { + c.config.MaskDType = ml.DTypeF32 + } + c.DType = dtype - c.Capacity = capacity - c.cells = make([]cacheCell, capacity) + c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding)) + c.cells = make([]cacheCell, c.Capacity) c.cellRanges = make(map[int]cellRange) c.backend = backend c.cacheCtx = backend.NewContext() } +func (c *Causal) SetConfig(config ml.CacheConfig) { + if c.config != nil { + panic("config cannot be changed after being previously set, either by the model or backend") + } + + c.config = &config +} + func (c *Causal) Close() { c.cacheCtx.Close() } @@ -157,36 +188,91 @@ func (c *Causal) findStartLoc() (int, error) { return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity) } +func roundDown(length, pad int) int { + return (length / pad) * pad +} + +func roundUp(length, pad int) int { + return ((length + pad - 1) / pad) * pad +} + // Builds a mask of history x batch indicating whether for each token in the batch the // token in the history should apply. 
This is based on both the sequence and causality (the // position of the history is not ahead of the token in the batch). func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) { - // TODO(jessegross): This does not do padding, which is required for flash attention - len := c.curCellRange.max - c.curCellRange.min + 1 - mask := make([]float32, c.curBatchSize*len) + // Align and pad the two dimensions as required by the backend + batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding) + + c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding) + c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1 + + length := c.curCellRange.max - c.curCellRange.min + 1 + mask := make([]float32, batchSize*length) for i := range c.curBatchSize { for j := c.curCellRange.min; j <= c.curCellRange.max; j++ { if !slices.Contains(c.cells[j].sequences, seqs[i]) || c.cells[j].pos > positions[i] || c.cells[j].pos < positions[i]-c.windowSize { - mask[i*len+(j-c.curCellRange.min)] = float32(math.Inf(-1)) + mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1)) } } } - return ctx.FromFloatSlice(mask, len, c.curBatchSize) + // Mask out any padding tokens we added. For padding that we added to the cache history, this + // has already been masked out because the sequence doesn't match. + for i := c.curBatchSize * length; i < len(mask); i++ { + mask[i] = float32(math.Inf(-1)) + } + + maskTensor, err := ctx.FromFloatSlice(mask, length, batchSize) + if err != nil { + return nil, err + } + + if c.config.MaskDType != ml.DTypeF32 { + out := ctx.Empty(c.config.MaskDType, maskTensor.Shape()...) 
+ ctx.Forward(maskTensor.Copy(ctx, out)) + maskTensor = out + } + + return maskTensor, nil } -func moveCell(ctx ml.Context, objs []ml.Tensor, src, dst, len int) { - for _, obj := range objs { - if obj == nil { +func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) { + for i := range c.keys { + if c.keys[i] == nil { continue } - srcView := obj.View(ctx, obj.Stride(2)*src, obj.Dim(0)*obj.Dim(1)*len) - dstView := obj.View(ctx, obj.Stride(2)*dst, obj.Dim(0)*obj.Dim(1)*len) + key := c.keys[i] - ctx.Forward(srcView.Copy(ctx, dstView)) + kHeadDim := key.Dim(0) + numKVHeads := key.Dim(1) + rowSize := key.Stride(2) + + kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len) + kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len) + + value := c.values[i] + var vSrcView, vDstView ml.Tensor + if c.config.PermutedV { + vHeadDim := value.Dim(1) + elemSize := value.Stride(0) + + vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads) + vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads) + } else { + vHeadDim := value.Dim(0) + rowSize := value.Stride(2) + + vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len) + vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len) + } + + ctx.Forward( + kSrcView.Copy(ctx, kDstView), + vSrcView.Copy(ctx, vDstView), + ) } } @@ -238,8 +324,7 @@ func (c *Causal) defrag() { pendingLen++ break } else { - moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen) - moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen) + c.moveCells(ctx, pendingSrc, pendingDst, pendingLen) moves++ } } @@ -263,8 +348,7 @@ func (c *Causal) defrag() { } if pendingLen > 0 { - moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen) - moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen) + c.moveCells(ctx, pendingSrc, pendingDst, pendingLen) moves++ } @@ -305,35 +389,73 @@ func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, 
ml.Tensor) { key := c.keys[c.curLayer] value := c.values[c.curLayer] - key = key.View(ctx, key.Stride(2)*c.curCellRange.min, - key.Dim(0), key.Stride(1), - key.Dim(1), key.Stride(2), - c.curMask.Dim(0), + kHeadDim := key.Dim(0) + numKVHeads := key.Dim(1) + rowSize := key.Stride(2) + cachedSize := c.curMask.Dim(0) + + key = key.View(ctx, rowSize*c.curCellRange.min, + kHeadDim, key.Stride(1), + numKVHeads, key.Stride(2), + cachedSize, ) - value = value.View(ctx, key.Stride(2)*c.curCellRange.min, - value.Dim(0), value.Stride(1), - value.Dim(1), value.Stride(2), - c.curMask.Dim(0), - ) + if c.config.PermutedV { + vHeadDim := value.Dim(1) + elemSize := value.Stride(0) + + value = value.View(ctx, elemSize*c.curCellRange.min, + cachedSize, value.Stride(1), + vHeadDim, value.Stride(2), + numKVHeads, + ) + } else { + vHeadDim := value.Dim(0) + rowSize := value.Stride(2) + + value = value.View(ctx, rowSize*c.curCellRange.min, + vHeadDim, value.Stride(1), + numKVHeads, value.Stride(2), + cachedSize, + ) + } return key, value, c.curMask } func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) { - if c.curBatchSize != key.Dim(2) { - panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, key.Dim(2))) + kHeadDim := key.Dim(0) + vHeadDim := value.Dim(0) + numKVHeads := key.Dim(1) + batchSize := key.Dim(2) + + if c.curBatchSize != batchSize { + panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize)) } if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil { - c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, key.Dim(0), key.Dim(1), int(c.Capacity)) - c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, value.Dim(0), value.Dim(1), int(c.Capacity)) + c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity)) + + if c.config.PermutedV { + c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, int(c.Capacity), vHeadDim, 
numKVHeads) + } else { + c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity)) + } } - ctx.Forward( - key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2))), - value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2))), - ) + rowSize := c.keys[c.curLayer].Stride(2) + ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize))) + + if c.config.PermutedV { + elemSize := c.values[c.curLayer].Stride(0) + + value = value.Permute(ctx, 1, 2, 0, 3) + ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads))) + } else { + rowSize := c.values[c.curLayer].Stride(2) + + ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize))) + } } func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) { @@ -389,9 +511,13 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error { continue } - key = key.View(ctx, key.Stride(2)*seqRange.min, - key.Dim(0), key.Stride(1), - key.Dim(1), key.Stride(2), + kHeadDim := key.Dim(0) + numKVHeads := key.Dim(1) + rowSize := key.Stride(2) + + key = key.View(ctx, rowSize*seqRange.min, + kHeadDim, key.Stride(1), + numKVHeads, key.Stride(2), size, ) diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go index bd7d0ae8..84d8de54 100644 --- a/kvcache/causal_test.go +++ b/kvcache/causal_test.go @@ -309,7 +309,7 @@ func (b *testBackend) SystemInfo() string { type testContext struct{} -func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor { +func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor { total := 0 if len(shape) > 0 { @@ -322,8 +322,12 @@ func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor { return &testTensor{dtype: dtype, elementSize: 4, data: 
make([]float32, total), shape: shape} } +func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor { + return c.Empty(dtype, shape...) +} + func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) { - t := c.Zeros(ml.DTypeF32, shape...).(*testTensor) + t := c.Empty(ml.DTypeF32, shape...).(*testTensor) copy(t.data, s) @@ -391,7 +395,7 @@ func (t *testTensor) Floats() []float32 { } func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor { - out := ctx.Zeros(t.DType(), t.Shape()...).(*testTensor) + out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor) for i := range out.data { out.data[i] = t.data[i] + t2.(*testTensor).data[i] @@ -468,7 +472,7 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor { context := &testContext{} - view := context.Zeros(t.dtype, s...).(*testTensor) + view := context.Empty(t.dtype, s...).(*testTensor) view.data = t.data[offset : offset+len(view.data)] return view diff --git a/kvcache/encoder.go b/kvcache/encoder.go index b85b1046..39b4cdfb 100644 --- a/kvcache/encoder.go +++ b/kvcache/encoder.go @@ -1,6 +1,8 @@ package kvcache import ( + "fmt" + "github.com/ollama/ollama/ml" ) @@ -11,6 +13,9 @@ import ( // // Not currently safe for multiple sequences type EncoderCache struct { + // config controls mostly backend-specific optimizations + config *ml.CacheConfig + // ** current forward pass ** // the active layer for Get and Put @@ -40,9 +45,29 @@ func NewEncoderCache() *EncoderCache { } func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) { + if c.config == nil { + var config ml.CacheConfig + if cc, ok := backend.(ml.BackendCacheConfig); ok { + config = cc.CacheConfig() + } + c.config = &config + } + + if c.config.CachePadding != 0 && c.config.CachePadding != 1 { + panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding)) + } + c.cacheCtx = backend.NewContext() } +func (c *EncoderCache) 
SetConfig(config ml.CacheConfig) { + if c.config != nil { + panic("config cannot be changed after being previously set, either by the model or backend") + } + + c.config = &config +} + func (c *EncoderCache) Close() { c.cacheCtx.Close() } @@ -75,9 +100,13 @@ func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) { c.encoderPos = c.curPos c.encoderCached = true + if c.config.PermutedV { + value = value.Permute(ctx, 1, 2, 0, 3) + } + if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil { - c.keys[c.curLayer] = c.cacheCtx.Zeros(key.DType(), key.Shape()...) - c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...) + c.keys[c.curLayer] = c.cacheCtx.Empty(key.DType(), key.Shape()...) + c.values[c.curLayer] = c.cacheCtx.Empty(value.DType(), value.Shape()...) } ctx.Forward( diff --git a/kvcache/wrapper.go b/kvcache/wrapper.go index 2d4c1089..76956a88 100644 --- a/kvcache/wrapper.go +++ b/kvcache/wrapper.go @@ -28,6 +28,12 @@ func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) } } +func (c *WrapperCache) SetConfig(config ml.CacheConfig) { + for _, cache := range c.caches { + cache.SetConfig(config) + } +} + func (c *WrapperCache) Close() { for _, cache := range c.caches { cache.Close() diff --git a/ml/backend.go b/ml/backend.go index 07bc75b6..83b7a8c9 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -27,6 +27,35 @@ type Backend interface { SystemInfo() string } +// BackendCacheConfig should be implemented by backends that need special output +// from the cache to meet specific requirements. It is frequently implemented in +// conjunction with ScaledDotProductAttention. +type BackendCacheConfig interface { + CacheConfig() CacheConfig +} + +// CacheConfig controls optimizations (mostly backend-specific) that may transform +// the output of the cache to work better with specific kernels.
+type CacheConfig struct { + // CachePadding specifies the multiple for the number of tokens of cache history + // that will be returned from cache Get for k, v and mask. The capacity of the + // cache itself will also be increased to a multiple of this size if needed. + CachePadding int + + // PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put + // and returns the permuted version via Get. This uses the cache copy operation + // to avoid a Contiguous call on the permuted tensor. + PermutedV bool + + // MaskDType specifies the data type for generating the mask. If unset it will + // default to DTypeF32. + MaskDType DType + + // MaskBatchPadding specifies the multiple for the batch size dimension in the mask. + // Any position that does not correspond to an actual token will be filled with -Inf. + MaskBatchPadding int +} + // BackendParams controls how the backend loads and executes models type BackendParams struct { // NumThreads sets the number of threads to use if running on the CPU @@ -40,6 +69,9 @@ type BackendParams struct { // TensorSplit is the fraction of the model to offload to each GPU TensorSplit []float32 + + // FlashAttention indicates that we should use a fused flash attention kernel + FlashAttention bool } var backends = make(map[string]func(*os.File, BackendParams) (Backend, error)) @@ -61,6 +93,7 @@ func NewBackend(f *os.File, params BackendParams) (Backend, error) { } type Context interface { + Empty(dtype DType, shape ...int) Tensor Zeros(dtype DType, shape ...int) Tensor FromFloatSlice(s []float32, shape ...int) (Tensor, error) FromIntSlice(s []int32, shape ...int) (Tensor, error) @@ -116,6 +149,10 @@ type Tensor interface { // operation equivalent to following code on a tensor named // query: // +// query = query.Permute(ctx, 0, 2, 1, 3) +// key = key.Permute(ctx, 0, 2, 1, 3) +// value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) +// // kq := key.MulmatFullPrec(ctx, query) // // kq = kq.Scale(ctx, scale) @@ -170,7 +207,7
@@ func Dump(ctx Context, t Tensor, opts ...DumpOptions) string { return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32) }) case DTypeF16: - f32 := ctx.Zeros(DTypeF32, t.Shape()...) + f32 := ctx.Empty(DTypeF32, t.Shape()...) f32 = t.Copy(ctx, f32) return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string { return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32) diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 7f91990c..f4948fca 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -79,6 +79,8 @@ var devices = sync.OnceValue(func() []device { }) type Backend struct { + flashAttention bool + meta *fs.GGML cpus, gpus []Context tensors map[string]*Context @@ -192,9 +194,10 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) { } return &Backend{ - meta: meta, - cpus: cpus, - gpus: gpus, + flashAttention: params.FlashAttention, + meta: meta, + cpus: cpus, + gpus: gpus, sched: C.ggml_backend_sched_new( (*C.ggml_backend_t)(unsafe.Pointer(&backends[0])), (*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])), @@ -219,7 +222,7 @@ func (b *Backend) Get(name string) ml.Tensor { for _, c := range append(b.gpus, b.cpus...) 
{ if t := C.ggml_get_tensor(c.ctx, cname); t != nil { - return &Tensor{t: t} + return &Tensor{b: b, t: t} } } @@ -247,6 +250,14 @@ func (b *Backend) NewContext() ml.Context { } } +func (b *Backend) CacheConfig() ml.CacheConfig { + if b.flashAttention { + return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD} + } else { + return ml.CacheConfig{CachePadding: 32, PermutedV: true} + } +} + type Context struct { b *Backend ctx *C.struct_ggml_context @@ -300,7 +311,7 @@ func shapeToGGML(shape []int) *C.int64_t { return &sh[0] } -func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor { +func newTensor(ctx Context, dtype ml.DType, zero bool, shape []int) ml.Tensor { if len(shape) < 1 || len(shape) > 4 { panic("unsupported number of dimensions") } @@ -314,19 +325,29 @@ func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor { var t *C.struct_ggml_tensor switch dtype { case ml.DTypeF32: - t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape)) + t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape)) case ml.DTypeF16: - t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape)) + t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape)) case ml.DTypeI32: - t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape)) + t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape)) default: panic("unsupported dtype") } - b := C.ggml_backend_alloc_buffer(c.backend, C.ggml_nbytes(t)) + b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t)) C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b)) - C.ggml_set_zero(t) - return &Tensor{t: t} + if zero { + C.ggml_set_zero(t) + } + return &Tensor{b: ctx.b, t: t} +} + +func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor { + return newTensor(c, dtype, false, shape) +} + +func (c 
Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor { + return newTensor(c, dtype, true, shape) } func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype uint32) (ml.Tensor, error) { @@ -335,7 +356,7 @@ func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype u if n == 0 { var shape C.int64_t = 0 t := C.ggml_new_tensor(ctx.ctx, dtype, 1, &shape) - return &Tensor{t: t}, nil + return &Tensor{b: ctx.b, t: t}, nil } for _, v := range shape { @@ -350,7 +371,7 @@ func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype u b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t)) C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b)) C.ggml_backend_tensor_set(t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t)) - return &Tensor{t: t}, nil + return &Tensor{b: ctx.b, t: t}, nil } func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) { @@ -368,6 +389,7 @@ func (c *Context) Close() { } type Tensor struct { + b *Backend t *C.struct_ggml_tensor sync func() } @@ -434,6 +456,7 @@ func (t *Tensor) DType() ml.DType { func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_add(ctx.(*Context).ctx, t.t, t2.(*Tensor).t), } } @@ -448,24 +471,28 @@ func (t *Tensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor { func (t *Tensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_concat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(dim)), } } func (t *Tensor) Contiguous(ctx ml.Context) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_cont(ctx.(*Context).ctx, t.t), } } func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_mul(ctx.(*Context).ctx, t.t, t2.(*Tensor).t), } } func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t), } } @@ -475,12 +502,13 @@ 
func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor { C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32) return &Tensor{ + b: t.b, t: mul, } } func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor { - tt := (&Tensor{t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w) + tt := (&Tensor{b: t.b, t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w) if b != nil { tt = tt.Add(ctx, b) } @@ -489,7 +517,7 @@ func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tenso } func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor { - return (&Tensor{t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w) + return (&Tensor{b: t.b, t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w) } func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor { @@ -498,6 +526,7 @@ func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor { } return &Tensor{ + b: t.b, t: C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])), } } @@ -508,18 +537,21 @@ func (t *Tensor) Permute(ctx ml.Context, shape ...int) ml.Tensor { } return &Tensor{ + b: t.b, t: C.ggml_permute(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])), } } func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_get_rows(ctx.(*Context).ctx, t.t, t2.(*Tensor).t), } } func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_cpy(ctx.(*Context).ctx, t.t, t2.(*Tensor).t), } } @@ -528,18 +560,22 @@ func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor { switch len(shape) { case 1: return &Tensor{ + b: t.b, t: C.ggml_reshape_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])), } case 2: return &Tensor{ + b: t.b, t: C.ggml_reshape_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])), } case 3: 
return &Tensor{ + b: t.b, t: C.ggml_reshape_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])), } case 4: return &Tensor{ + b: t.b, t: C.ggml_reshape_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])), } default: @@ -549,18 +585,21 @@ func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor { func (t *Tensor) Scale(ctx ml.Context, s float64) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_scale(ctx.(*Context).ctx, t.t, (C.float)(s)), } } func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_soft_max(ctx.(*Context).ctx, t.t), } } func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_tanh_inplace(ctx.(*Context).ctx, t.t), } } @@ -571,6 +610,7 @@ func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor { } return &Tensor{ + b: t.b, t: C.ggml_unpad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])), } } @@ -579,10 +619,12 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor { switch len(shape) { case 1: return &Tensor{ + b: t.b, t: C.ggml_view_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.size_t(offset)), } case 3: return &Tensor{ + b: t.b, t: C.ggml_view_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[2]), C.size_t(shape[1]), @@ -590,6 +632,7 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor { } case 5: return &Tensor{ + b: t.b, t: C.ggml_view_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]), C.size_t(shape[1]), C.size_t(shape[3]), @@ -597,6 +640,7 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor { } case 7: return &Tensor{ + b: t.b, t: C.ggml_view_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]), C.int64_t(shape[6]), C.size_t(shape[1]), C.size_t(shape[3]), 
C.size_t(shape[5]), @@ -613,7 +657,7 @@ const ( func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, ropeBase, ropeScale float32) ml.Tensor { if ropeFactors == nil { - ropeFactors = &Tensor{} + ropeFactors = &Tensor{b: t.b} } dequant := t.t @@ -622,6 +666,7 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi } return &Tensor{ + b: t.b, t: C.ggml_rope_ext( ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t, C.int(ropeDim), @@ -639,18 +684,21 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi func (t *Tensor) GELU(ctx ml.Context) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t), } } func (t *Tensor) SILU(ctx ml.Context) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t), } } func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor { return &Tensor{ + b: t.b, t: C.ggml_conv_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1)), } } @@ -661,13 +709,25 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.T kqMask = mask.(*Tensor).t } - kq := key.MulmatFullPrec(ctx, t) - kq = &Tensor{ - t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0), - } + query := t.Permute(ctx, 0, 2, 1, 3) + key = key.Permute(ctx, 0, 2, 1, 3) - kqv := value.Mulmat(ctx, kq) - return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) + if t.b.flashAttention { + value = value.Permute(ctx, 0, 2, 1, 3) + + kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0) + C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32) + return &Tensor{b: t.b, t: kqv} + } else { + kq := key.MulmatFullPrec(ctx, query) + kq = &Tensor{ + b: t.b, + t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, 
kqMask, C.float(scale), 0), + } + + kqv := value.Mulmat(ctx, kq) + return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) + } } func (b *Backend) SystemInfo() string { diff --git a/ml/nn/attention.go b/ml/nn/attention.go index 4f0c9fa1..a3f43a1e 100644 --- a/ml/nn/attention.go +++ b/ml/nn/attention.go @@ -3,6 +3,7 @@ package nn import ( "fmt" + "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" ) @@ -11,40 +12,50 @@ import ( // // Parameters: // - ctx: Context for tensor operations -// - query: Query tensor (Q) with shape [d_k, seq_len_q, heads] -// - key: Key tensor (K) with shape [d_k, seq_len_k, kv_heads] -// - value: Value tensor (V) with shape [seq_len_k, d_v, kv_heads] -// - mask: Optional attention mask that is added to the attention score. If -// provided, should broadcast to [seq_len_k, seq_len_q, heads] +// - query: Query tensor (Q) with shape [d_k, heads, seq_len_q] +// - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only +// - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only // - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension +// - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value // // Returns: // // Attention output with shape [d_v, heads, seq_len_q] -func Attention(ctx ml.Context, query, key, value, mask ml.Tensor, scale float64) ml.Tensor { - if query.Dim(0) != key.Dim(0) { - panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0))) +func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor { + if key != nil && value != nil { + if query.Dim(0) != key.Dim(0) { + panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0))) + } + + if key.Dim(1) != value.Dim(1) { + panic(fmt.Errorf("kv_heads in attention operation 
does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1))) + } + + if key.Dim(2) != value.Dim(2) { + panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2))) + } + + if cache != nil { + cache.Put(ctx, key, value) + } + } else if cache == nil { + panic("key & value tensors must be provided if cache is nil") } - if mask != nil && query.Dim(1) != mask.Dim(1) { - panic(fmt.Errorf("seq_len_q in attention operation does not match between query(%v) and mask(%v)", query.Dim(1), mask.Dim(1))) + var mask ml.Tensor + if cache != nil { + key, value, mask = cache.Get(ctx) } - if key.Dim(1) != value.Dim(0) { - panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(0))) - } - - if mask != nil && key.Dim(1) != mask.Dim(0) { - panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and mask(%v)", key.Dim(1), mask.Dim(0))) - } - - if key.Dim(2) != value.Dim(2) { - panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2))) - } - - if sdpa, ok := query.(ml.ScaledDotProductAttention); ok { + // Only use the fast SDPA implementation if we have a cache, since that's what + // will do any expected backend-specific transformations for us + if sdpa, ok := query.(ml.ScaledDotProductAttention); ok && cache != nil { return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale) } else { + query = query.Permute(ctx, 0, 2, 1, 3) + key = key.Permute(ctx, 0, 2, 1, 3) + value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) + kq := key.MulmatFullPrec(ctx, query) kq = kq.Scale(ctx, scale) diff --git a/model/models/llama/model.go b/model/models/llama/model.go index 6106af86..9bf6f497 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -81,15 +81,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten v := 
sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - cache.Put(ctx, k, v) - k, v, mask := cache.Get(ctx) - - q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) - k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) - v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) - scaleFactor := 1.0 / math.Sqrt(float64(headDim)) - kqv := nn.Attention(ctx, q, k, v, mask, scaleFactor) + kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache) kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize) return sa.Output.Forward(ctx, kqv) diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go index 9b35a262..743f4c32 100644 --- a/model/models/mllama/model.go +++ b/model/models/mllama/model.go @@ -43,7 +43,9 @@ func New(c ml.Config) (model.Model, error) { TextModel: newTextModel(c), } - m.Cache = kvcache.NewWrapperCache(kvcache.NewEncoderCache(), kvcache.NewCausalCache(m.TextModel.Shift)) + encoderCache := kvcache.NewEncoderCache() + encoderCache.SetConfig(ml.CacheConfig{}) + m.Cache = kvcache.NewWrapperCache(encoderCache, kvcache.NewCausalCache(m.TextModel.Shift)) return &m, nil } diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go index 003bf9cb..e294b4c7 100644 --- a/model/models/mllama/model_text.go +++ b/model/models/mllama/model_text.go @@ -31,22 +31,15 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m value := sa.Value.Forward(ctx, hiddenState) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - cache.Put(ctx, key, value) - key, value, mask := cache.Get(ctx) - - query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) - key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) - value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) - scaleFactor := 1.0 / math.Sqrt(float64(headDim)) - attention := nn.Attention(ctx, query, key, value, mask, scaleFactor) + attention := nn.Attention(ctx, query, key, value, scaleFactor, cache) attention = attention.Reshape(ctx, opts.hiddenSize, 
batchSize) return sa.Output.Forward(ctx, attention) } func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - // This will only get called for layers in the cache, which are just the self attention layers + // This will only get called for layers in the causal cache, which are just the self attention layers return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil } @@ -107,7 +100,7 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) query = ca.QueryNorm.Forward(ctx, query, opts.eps) - var key, value, mask ml.Tensor + var key, value ml.Tensor if crossAttentionStates != nil { numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2) @@ -119,16 +112,23 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles) cache.Put(ctx, key, value) - } else { - key, value, mask = cache.Get(ctx) } - query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) - key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) - value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) + key, value, _ = cache.Get(ctx) scaleFactor := 1.0 / math.Sqrt(float64(headDim)) - attention := nn.Attention(ctx, query, key, value, mask, scaleFactor) + + query = query.Permute(ctx, 0, 2, 1, 3) + key = key.Permute(ctx, 0, 2, 1, 3) + value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) + + kq := key.MulmatFullPrec(ctx, query) + + kq = kq.Scale(ctx, scaleFactor) + kq = kq.Softmax(ctx) + + kqv := value.Mulmat(ctx, kq) + attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) attention = attention.Reshape(ctx, opts.hiddenSize, batchSize) return ca.Output.Forward(ctx, attention) diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index db9b271e..5705931a 100644 --- a/runner/ollamarunner/runner.go +++ 
b/runner/ollamarunner/runner.go @@ -818,7 +818,7 @@ func Execute(args []string) error { batchSize := fs.Int("batch-size", 512, "Batch size") numGPULayers := fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU") mainGPU := fs.Int("main-gpu", 0, "Main GPU") - _ = fs.Bool("flash-attn", false, "Enable flash attention") + flashAttention := fs.Bool("flash-attn", false, "Enable flash attention") kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size") kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)") port := fs.Int("port", 8080, "Port to expose the server on") @@ -863,7 +863,6 @@ func Execute(args []string) error { } // TODO(jessegross): Parameters that need to be implemented: - // flash-attn // no-mmap // mlock @@ -878,10 +877,11 @@ func Execute(args []string) error { } params := ml.BackendParams{ - NumThreads: *threads, - NumGPULayers: *numGPULayers, - MainGPU: *mainGPU, - TensorSplit: tensorSplitFloats, + NumThreads: *threads, + NumGPULayers: *numGPULayers, + MainGPU: *mainGPU, + TensorSplit: tensorSplitFloats, + FlashAttention: *flashAttention, } server.ready.Add(1) diff --git a/server/internal/client/ollama/registry.go b/server/internal/client/ollama/registry.go index e4c36d7d..82a8bbca 100644 --- a/server/internal/client/ollama/registry.go +++ b/server/internal/client/ollama/registry.go @@ -24,6 +24,7 @@ import ( "os" "path/filepath" "runtime" + "slices" "strconv" "strings" "sync/atomic" @@ -53,7 +54,7 @@ var ( // ErrMissingModel is returned when the model part of a name is missing // or invalid. - ErrNameInvalid = errors.New("invalid name; must be in the form {scheme://}{host/}{namespace/}[model]{:tag}{@digest}") + ErrNameInvalid = errors.New("invalid or missing name") // ErrCached is passed to [Trace.PushUpdate] when a layer already // exists. It is a non-fatal error and is never returned by [Registry.Push]. 
@@ -205,10 +206,18 @@ type Registry struct { // It is only used when a layer is larger than [MaxChunkingThreshold]. MaxChunkSize int64 - // NameMask, if set, is the name used to convert non-fully qualified + // Mask, if set, is the name used to convert non-fully qualified // names to fully qualified names. If empty, the default mask // ("registry.ollama.ai/library/_:latest") is used. - NameMask string + Mask string +} + +func (r *Registry) completeName(name string) names.Name { + mask := defaultMask + if r.Mask != "" { + mask = names.Parse(r.Mask) + } + return names.Merge(names.Parse(name), mask) } // DefaultRegistry returns a new Registry configured from the environment. The @@ -243,52 +252,6 @@ func DefaultRegistry() (*Registry, error) { return &rc, nil } -type PushParams struct { - // From is an optional destination name for the model. If empty, the - // destination name is the same as the source name. - From string -} - -// parseName parses name using [names.ParseExtended] and then merges the name with the -// default name, and checks that the name is fully qualified. If a digest is -// present, it parse and returns it with the other fields as their zero values. -// -// It returns an error if the name is not fully qualified, or if the digest, if -// any, is invalid. -// -// The scheme is returned as provided by [names.ParseExtended]. -func parseName(s, mask string) (scheme string, n names.Name, d blob.Digest, err error) { - maskName := defaultMask - if mask != "" { - maskName = names.Parse(mask) - if !maskName.IsFullyQualified() { - return "", names.Name{}, blob.Digest{}, fmt.Errorf("invalid name mask: %s", mask) - } - } - scheme, n, ds := names.ParseExtended(s) - if !n.IsValid() { - return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s) - } - n = names.Merge(n, maskName) - if ds != "" { - // Digest is present. Validate it. 
- d, err = blob.ParseDigest(ds) - if err != nil { - return "", names.Name{}, blob.Digest{}, err - } - } - - // The name check is deferred until after the digest check because we - // say that digests take precedence over names, and so should there - // errors when being parsed. - if !n.IsFullyQualified() { - return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s) - } - - scheme = cmp.Or(scheme, "https") - return scheme, n, d, nil -} - func (r *Registry) maxStreams() int { n := cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0)) @@ -308,6 +271,12 @@ func (r *Registry) maxChunkSize() int64 { return cmp.Or(r.MaxChunkSize, DefaultMaxChunkSize) } +type PushParams struct { + // From is an optional destination name for the model. If empty, the + // destination name is the same as the source name. + From string +} + // Push pushes the model with the name in the cache to the remote registry. func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *PushParams) error { if p == nil { @@ -337,7 +306,7 @@ func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p * t := traceFromContext(ctx) - scheme, n, _, err := parseName(name, r.NameMask) + scheme, n, _, err := parseName(name, r.Mask) if err != nil { // This should never happen since ResolveLocal should have // already validated the name. @@ -431,7 +400,7 @@ func canRetry(err error) bool { // typically slower than splitting the model up across layers, and is mostly // utilized for layers of type equal to "application/vnd.ollama.image". func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error { - scheme, n, _, err := parseName(name, r.NameMask) + scheme, n, _, err := parseName(name, r.Mask) if err != nil { return err } @@ -582,9 +551,9 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err // Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified // before attempting to unlink the model. 
func (r *Registry) Unlink(c *blob.DiskCache, name string) (ok bool, _ error) { - _, n, _, err := parseName(name, r.NameMask) - if err != nil { - return false, err + n := r.completeName(name) + if !n.IsFullyQualified() { + return false, fmt.Errorf("%w: %q", ErrNameInvalid, name) } return c.Unlink(n.String()) } @@ -658,9 +627,9 @@ type Layer struct { } // ResolveLocal resolves a name to a Manifest in the local cache. The name is -// parsed using [names.ParseExtended] but the scheme is ignored. +// parsed using [names.Split] but the scheme is ignored. func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) { - _, n, d, err := parseName(name, r.NameMask) + _, n, d, err := parseName(name, r.Mask) if err != nil { return nil, err } @@ -686,7 +655,7 @@ func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, erro // Resolve resolves a name to a Manifest in the remote registry. func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) { - scheme, n, d, err := parseName(name, r.NameMask) + scheme, n, d, err := parseName(name, r.Mask) if err != nil { return nil, err } @@ -869,3 +838,69 @@ func maybeUnexpectedEOF(err error) error { } return err } + +type publicError struct { + wrapped error + message string +} + +func withPublicMessagef(err error, message string, args ...any) error { + return publicError{wrapped: err, message: fmt.Sprintf(message, args...)} +} + +func (e publicError) Error() string { return e.message } +func (e publicError) Unwrap() error { return e.wrapped } + +var supportedSchemes = []string{ + "http", + "https", + "https+insecure", +} + +var supportedSchemesMessage = fmt.Sprintf("supported schemes are %v", strings.Join(supportedSchemes, ", ")) + +// parseName parses and validates an extended name, returning the scheme, name, +// and digest. +// +// If the scheme is empty, scheme will be "https". 
If an unsupported scheme is +// given, [ErrNameInvalid] wrapped with a display friendly message is returned. +// +// If the digest is invalid, [ErrNameInvalid] wrapped with a display friendly +// message is returned. +// +// If the name is not, once merged with the mask, fully qualified, +// [ErrNameInvalid] wrapped with a display friendly message is returned. +func parseName(s string, mask string) (scheme string, _ names.Name, _ blob.Digest, _ error) { + scheme, name, digest := names.Split(s) + scheme = cmp.Or(scheme, "https") + if !slices.Contains(supportedSchemes, scheme) { + err := withPublicMessagef(ErrNameInvalid, "unsupported scheme: %q: %s", scheme, supportedSchemesMessage) + return "", names.Name{}, blob.Digest{}, err + } + + var d blob.Digest + if digest != "" { + var err error + d, err = blob.ParseDigest(digest) + if err != nil { + err = withPublicMessagef(ErrNameInvalid, "invalid digest: %q", digest) + return "", names.Name{}, blob.Digest{}, err + } + if name == "" { + // We can resolve a manifest from a digest only, + // so skip name validation and return the scheme and + // digest. 
+ return scheme, names.Name{}, d, nil + } + } + + maskName := defaultMask + if mask != "" { + maskName = names.Parse(mask) + } + n := names.Merge(names.Parse(name), maskName) + if !n.IsFullyQualified() { + return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s) + } + return scheme, n, d, nil +} diff --git a/server/internal/client/ollama/registry_test.go b/server/internal/client/ollama/registry_test.go index af898c26..20a1f159 100644 --- a/server/internal/client/ollama/registry_test.go +++ b/server/internal/client/ollama/registry_test.go @@ -84,14 +84,14 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) { } } - rc := &Registry{ + r := &Registry{ HTTPClient: &http.Client{ Transport: recordRoundTripper(h), }, } link := func(name string, manifest string) { - _, n, _, err := parseName(name, rc.NameMask) + _, n, _, err := parseName(name, r.Mask) if err != nil { panic(err) } @@ -122,7 +122,7 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) { commit("sizemismatch", mklayer("exists"), &Layer{Digest: blob.DigestFromBytes("present"), Size: 499}) link("invalid", "!!!!!") - return rc, c + return r, c } func okHandler(w http.ResponseWriter, r *http.Request) { @@ -145,29 +145,6 @@ func importBytes(t *testing.T, c *blob.DiskCache, data string) blob.Digest { return d } -func TestRegistryPushInvalidNames(t *testing.T) { - rc, c := newClient(t, nil) - - cases := []struct { - name string - err error - }{ - {"", ErrNameInvalid}, - {"@", ErrNameInvalid}, - {"@x", blob.ErrInvalidDigest}, - } - - for _, tt := range cases { - t.Run(tt.name, func(t *testing.T) { - // Create a new registry and push a new image. 
- err := rc.Push(t.Context(), c, tt.name, nil) - if !errors.Is(err, tt.err) { - t.Errorf("err = %v; want %v", err, tt.err) - } - }) - } -} - func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) { t := &Trace{Update: func(*Layer, int64, error) { panic("unexpected") }} return WithTrace(ctx, t), t @@ -622,7 +599,7 @@ func TestInsecureSkipVerify(t *testing.T) { })) defer s.Close() - const name = "ollama.com/library/insecure" + const name = "library/insecure" var rc Registry url := fmt.Sprintf("https://%s/%s", s.Listener.Addr(), name) @@ -724,3 +701,38 @@ func TestErrorUnmarshal(t *testing.T) { }) } } + +// TestParseNameErrors tests that parseName returns error messages with enough +// detail for users to debug naming issues they may encounter. Previous to this +// test, the error messages were not very helpful and each problem was reported +// as the same message. +// +// It is only for testing error messages, not that all invalids and valids are +// covered. Those are in other tests for names.Name and blob.Digest. 
+func TestParseNameErrors(t *testing.T) { + cases := []struct { + name string + err error + want string + }{ + {"x", nil, ""}, + {"x@", nil, ""}, + + {"", ErrNameInvalid, `invalid or missing name: ""`}, + {"://", ErrNameInvalid, `invalid or missing name: "://"`}, + {"x://", ErrNameInvalid, `unsupported scheme: "x": supported schemes are http, https, https+insecure`}, + + {"@sha123-1234", ErrNameInvalid, `invalid digest: "sha123-1234"`}, + {"x@sha123-1234", ErrNameInvalid, `invalid digest: "sha123-1234"`}, + } + + for _, tt := range cases { + _, _, _, err := parseName(tt.name, DefaultMask) + if !errors.Is(err, tt.err) { + t.Errorf("[%s]: err = %v; want %v", tt.name, err, tt.err) + } + if err != nil && !strings.Contains(err.Error(), tt.want) { + t.Errorf("[%s]: err =\n\t%v\nwant\n\t%v", tt.name, err, tt.want) + } + } +} diff --git a/server/internal/internal/names/name.go b/server/internal/internal/names/name.go index 361cce76..f0a1185d 100644 --- a/server/internal/internal/names/name.go +++ b/server/internal/internal/names/name.go @@ -8,7 +8,7 @@ import ( "github.com/ollama/ollama/server/internal/internal/stringsx" ) -const MaxNameLength = 50 + 1 + 50 + 1 + 50 // /: +const MaxNameLength = 350 + 1 + 80 + 1 + 80 + 1 + 80 // //: type Name struct { // Make incomparable to enfoce use of Compare / Equal for @@ -25,19 +25,12 @@ type Name struct { // format of a valid name string is: // // s: -// { host } "/" { namespace } "/" { model } ":" { tag } "@" { digest } // { host } "/" { namespace } "/" { model } ":" { tag } -// { host } "/" { namespace } "/" { model } "@" { digest } // { host } "/" { namespace } "/" { model } -// { namespace } "/" { model } ":" { tag } "@" { digest } // { namespace } "/" { model } ":" { tag } -// { namespace } "/" { model } "@" { digest } // { namespace } "/" { model } -// { model } ":" { tag } "@" { digest } // { model } ":" { tag } -// { model } "@" { digest } // { model } -// "@" { digest } // host: // pattern: { alphanum | "_" } { alphanum | 
"_" | "-" | "." | ":" }* // length: [1, 350] @@ -50,9 +43,6 @@ type Name struct { // tag: // pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }* // length: [1, 80] -// digest: -// pattern: { alphanum | "_" } { alphanum | "-" | ":" }* -// length: [1, 80] // // The name returned is not guaranteed to be valid. If it is not valid, the // field values are left in an undefined state. Use [Name.IsValid] to check @@ -82,23 +72,17 @@ func Parse(s string) Name { } } -// ParseExtended parses and returns any scheme, Name, and digest from from s in -// the the form [scheme://][name][@digest]. All parts are optional. -// -// If the scheme is present, it must be followed by "://". The digest is -// prefixed by "@" and comes after the name. The name is parsed using [Parse]. -// -// The scheme and digest are stripped before the name is parsed by [Parse]. -// -// For convience, the scheme is never empty. If the scheme is not present, the -// returned scheme is "https". +// Split splits an extended name string into its scheme, name, and digest +// parts. // // Examples: // // http://ollama.com/bmizerany/smol:latest@digest // https://ollama.com/bmizerany/smol:latest // ollama.com/bmizerany/smol:latest@digest // returns "https" scheme. -func ParseExtended(s string) (scheme string, _ Name, digest string) { +// model@digest +// @digest +func Split(s string) (scheme, name, digest string) { i := strings.Index(s, "://") if i >= 0 { scheme = s[:i] @@ -109,21 +93,7 @@ func ParseExtended(s string) (scheme string, _ Name, digest string) { digest = s[i+1:] s = s[:i] } - return scheme, Parse(s), digest -} - -func FormatExtended(scheme string, n Name, digest string) string { - var b strings.Builder - if scheme != "" { - b.WriteString(scheme) - b.WriteString("://") - } - b.WriteString(n.String()) - if digest != "" { - b.WriteByte('@') - b.WriteString(digest) - } - return b.String() + return scheme, s, digest } // Merge merges two names into a single name. 
Non-empty host, namespace, and @@ -141,39 +111,68 @@ func Merge(a, b Name) Name { // IsValid returns true if the name is valid. func (n Name) IsValid() bool { - if n.h != "" && !isValidHost(n.h) { + if n.h != "" && !isValidPart(partHost, n.h) { return false } - if n.n != "" && !isValidNamespace(n.n) { + if n.n != "" && !isValidPart(partNamespace, n.n) { return false } - if n.m != "" && !isValidModel(n.m) { + if n.t != "" && !isValidPart(partTag, n.t) { return false } - if n.t != "" && !isValidTag(n.t) { - return false - } - return true + + // at bare minimum, model must be present and valid + return n.m != "" && isValidPart(partModel, n.m) } func (n Name) IsFullyQualified() bool { return n.IsValid() && n.h != "" && n.n != "" && n.m != "" && n.t != "" } -func isValidHost(_ string) bool { - return true // TODO: implement +const ( + partHost = iota + partNamespace + partModel + partTag +) + +func isValidPart(kind int, s string) bool { + maxlen := 80 + if kind == partHost { + maxlen = 350 + } + if len(s) > maxlen { + return false + } + + for i := range s { + if i == 0 { + if !isAlphanumericOrUnderscore(s[i]) { + return false + } + continue + } + switch s[i] { + case '_', '-': + case '.': + if kind == partNamespace { + return false + } + case ':': + if kind != partHost { + return false + } + default: + if !isAlphanumericOrUnderscore(s[i]) { + return false + } + } + } + return true } -func isValidNamespace(_ string) bool { - return true // TODO: implement -} - -func isValidModel(_ string) bool { - return true // TODO: implement -} - -func isValidTag(_ string) bool { - return true // TODO: implement +func isAlphanumericOrUnderscore(c byte) bool { + return c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' || c >= '0' && c <= '9' || c == '_' } func (n Name) Host() string { return n.h } diff --git a/server/internal/internal/names/name_test.go b/server/internal/internal/names/name_test.go index 760fec5f..e3dc5fe3 100644 --- a/server/internal/internal/names/name_test.go +++ 
b/server/internal/internal/names/name_test.go @@ -81,15 +81,11 @@ func TestParseExtended(t *testing.T) { } for _, tt := range cases { t.Run(tt.in, func(t *testing.T) { - scheme, name, digest := ParseExtended(tt.in) - if scheme != tt.wantScheme || name.Compare(tt.wantName) != 0 || digest != tt.wantDigest { + scheme, name, digest := Split(tt.in) + n := Parse(name) + if scheme != tt.wantScheme || n.Compare(tt.wantName) != 0 || digest != tt.wantDigest { t.Errorf("ParseExtended(%q) = %q, %#v, %q, want %q, %#v, %q", tt.in, scheme, name, digest, tt.wantScheme, tt.wantName, tt.wantDigest) } - - // Round trip - if got := FormatExtended(scheme, name, digest); got != tt.in { - t.Errorf("FormatExtended(%q, %q, %q) = %q", scheme, name, digest, got) - } }) } } @@ -150,3 +146,75 @@ func BenchmarkParseName(b *testing.B) { junkName = Parse("h/n/m:t") } } + +const ( + part80 = "88888888888888888888888888888888888888888888888888888888888888888888888888888888" + part350 = "33333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333" +) + +var testCases = map[string]bool{ // name -> valid + "": false, + + "_why/_the/_lucky:_stiff": true, + + // minimal + "h/n/m:t": true, + + "host/namespace/model:tag": true, + "host/namespace/model": true, + "namespace/model": true, + "model": true, + + // long (but valid) + part80 + "/" + part80 + "/" + part80 + ":" + part80: true, + part350 + "/" + part80 + "/" + part80 + ":" + part80: true, + + // too long + part80 + "/" + part80 + "/" + part80 + ":" + part350: false, + "x" + part350 + "/" + part80 + "/" + part80 + ":" + part80: false, + + "h/nn/mm:t": true, // bare minimum part sizes + + // unqualified + "m": true, + "n/m:": true, + "h/n/m": true, + "@t": false, 
+ "m@d": false, + + // invalids + "^": false, + "mm:": true, + "/nn/mm": true, + "//": false, // empty model + "//mm": true, + "hh//": false, // empty model + "//mm:@": false, + "00@": false, + "@": false, + + // not starting with alphanum + "-hh/nn/mm:tt": false, + "hh/-nn/mm:tt": false, + "hh/nn/-mm:tt": false, + "hh/nn/mm:-tt": false, + + // smells like a flag + "-h": false, + + // hosts + "host:https/namespace/model:tag": true, + + // colon in non-host part before tag + "host/name:space/model:tag": false, +} + +func TestParseNameValidation(t *testing.T) { + for s, valid := range testCases { + got := Parse(s) + if got.IsValid() != valid { + t.Logf("got: %v", got) + t.Errorf("Parse(%q).IsValid() = %v; want !%[2]v", s, got.IsValid()) + } + } +} diff --git a/server/internal/registry/server.go b/server/internal/registry/server.go index 8eb6daf8..6ea590a7 100644 --- a/server/internal/registry/server.go +++ b/server/internal/registry/server.go @@ -204,7 +204,7 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error { return err } if !ok { - return &serverError{404, "manifest_not_found", "manifest not found"} + return &serverError{404, "not_found", "model not found"} } return nil } diff --git a/server/internal/registry/server_test.go b/server/internal/registry/server_test.go index 22267ba7..7ba13d50 100644 --- a/server/internal/registry/server_test.go +++ b/server/internal/registry/server_test.go @@ -109,11 +109,8 @@ func TestServerDelete(t *testing.T) { got = s.send(t, "DELETE", "/api/delete", ``) checkErrorResponse(t, got, 400, "bad_request", "empty request body") - got = s.send(t, "DELETE", "/api/delete", `{"model": "!"}`) - checkErrorResponse(t, got, 404, "manifest_not_found", "not found") - got = s.send(t, "DELETE", "/api/delete", `{"model": "://"}`) - checkErrorResponse(t, got, 400, "bad_request", "invalid name") + checkErrorResponse(t, got, 400, "bad_request", "invalid or missing name") got = s.send(t, "DELETE", "/unknown_path", `{}`) // 
valid body checkErrorResponse(t, got, 404, "not_found", "not found")