Merge branch 'ollama:main' into main

Author: likelovewant
Date: 2025-03-02 13:45:52 +08:00
Committed by: GitHub

19 changed files with 665 additions and 273 deletions

View File

@@ -122,9 +122,11 @@ if(CMAKE_HIP_COMPILER)
     add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
     if (WIN32)
-        target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY=1)
+        target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY)
     endif()
+    target_compile_definitions(ggml-hip PRIVATE GGML_HIP_NO_VMM)
     set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm)
     install(TARGETS ggml-hip
         RUNTIME_DEPENDENCIES

View File

@@ -29,6 +29,17 @@ type Cache interface {
     // cache implementation used.
     Put(ctx ml.Context, key, value ml.Tensor)

+    // SetConfig controls optimizations (mostly backend-specific) that may transform
+    // the output of the cache to work better with specific kernels. If not called,
+    // the backend settings will be used. This works well when calling Attention.
+    //
+    // The config can be overridden by models, especially if they require vanilla
+    // output when implementing their own version of attention. To do this, pass
+    // an empty ml.CacheConfig.
+    //
+    // Most models will not need to use this.
+    SetConfig(ml.CacheConfig)
+
     // ** cache management **

     // Init sets up runtime parameters
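A model that implements its own attention can use the override described above to keep the cache output vanilla. A minimal sketch (the helper name is hypothetical; the mllama constructor later in this diff does the same thing inline):

package example

import (
    "github.com/ollama/ollama/kvcache"
    "github.com/ollama/ollama/ml"
)

// newVanillaCache returns an encoder cache that always yields untransformed
// K/V tensors: pinning an empty ml.CacheConfig before Init means the
// backend's own CacheConfig() is never consulted.
func newVanillaCache() kvcache.Cache {
    c := kvcache.NewEncoderCache()
    c.SetConfig(ml.CacheConfig{})
    return c
}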

View File

@@ -22,6 +22,9 @@ type Causal struct {
     Capacity   int32
     windowSize int32

+    // config controls mostly backend-specific optimizations
+    config *ml.CacheConfig
+
     // ** current forward pass **

     // the active layer for Get and Put
@@ -75,14 +78,42 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal {
 }

 func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
+    if c.config == nil {
+        var config ml.CacheConfig
+        if cc, ok := backend.(ml.BackendCacheConfig); ok {
+            config = cc.CacheConfig()
+        }
+        c.config = &config
+    }
+
+    if c.config.CachePadding == 0 {
+        c.config.CachePadding = 1
+    }
+
+    if c.config.MaskBatchPadding == 0 {
+        c.config.MaskBatchPadding = 1
+    }
+
+    if c.config.MaskDType == ml.DTypeOther {
+        c.config.MaskDType = ml.DTypeF32
+    }
+
     c.DType = dtype
-    c.Capacity = capacity
-    c.cells = make([]cacheCell, capacity)
+    c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding))
+    c.cells = make([]cacheCell, c.Capacity)
     c.cellRanges = make(map[int]cellRange)
     c.backend = backend
     c.cacheCtx = backend.NewContext()
 }

+func (c *Causal) SetConfig(config ml.CacheConfig) {
+    if c.config != nil {
+        panic("config cannot be changed after being previously set, either by the model or backend")
+    }
+
+    c.config = &config
+}
+
 func (c *Causal) Close() {
     c.cacheCtx.Close()
 }
@@ -157,36 +188,91 @@ func (c *Causal) findStartLoc() (int, error) {
     return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
 }

+func roundDown(length, pad int) int {
+    return (length / pad) * pad
+}
+
+func roundUp(length, pad int) int {
+    return ((length + pad - 1) / pad) * pad
+}
+
 // Builds a mask of history x batch indicating whether for each token in the batch the
 // token in the history should apply. This is based on both the sequence and causality (the
 // position of the history is not ahead of the token in the batch).
 func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) {
-    // TODO(jessegross): This does not do padding, which is required for flash attention
-    len := c.curCellRange.max - c.curCellRange.min + 1
-    mask := make([]float32, c.curBatchSize*len)
+    // Align and pad the two dimensions as required by the backend
+    batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
+
+    c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
+    c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
+
+    length := c.curCellRange.max - c.curCellRange.min + 1
+    mask := make([]float32, batchSize*length)

     for i := range c.curBatchSize {
         for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
             if !slices.Contains(c.cells[j].sequences, seqs[i]) || c.cells[j].pos > positions[i] ||
                 c.cells[j].pos < positions[i]-c.windowSize {
-                mask[i*len+(j-c.curCellRange.min)] = float32(math.Inf(-1))
+                mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
             }
         }
     }

-    return ctx.FromFloatSlice(mask, len, c.curBatchSize)
+    // Mask out any padding tokens we added. For padding that we added to the cache history, this
+    // has already been masked out because the sequence doesn't match.
+    for i := c.curBatchSize * length; i < len(mask); i++ {
+        mask[i] = float32(math.Inf(-1))
+    }
+
+    maskTensor, err := ctx.FromFloatSlice(mask, length, batchSize)
+    if err != nil {
+        return nil, err
+    }
+
+    if c.config.MaskDType != ml.DTypeF32 {
+        out := ctx.Empty(c.config.MaskDType, maskTensor.Shape()...)
+        ctx.Forward(maskTensor.Copy(ctx, out))
+        maskTensor = out
+    }
+
+    return maskTensor, nil
 }

-func moveCell(ctx ml.Context, objs []ml.Tensor, src, dst, len int) {
-    for _, obj := range objs {
-        if obj == nil {
+func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
+    for i := range c.keys {
+        if c.keys[i] == nil {
             continue
         }

-        srcView := obj.View(ctx, obj.Stride(2)*src, obj.Dim(0)*obj.Dim(1)*len)
-        dstView := obj.View(ctx, obj.Stride(2)*dst, obj.Dim(0)*obj.Dim(1)*len)
-        ctx.Forward(srcView.Copy(ctx, dstView))
+        key := c.keys[i]
+
+        kHeadDim := key.Dim(0)
+        numKVHeads := key.Dim(1)
+        rowSize := key.Stride(2)
+
+        kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len)
+        kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len)
+
+        value := c.values[i]
+        var vSrcView, vDstView ml.Tensor
+        if c.config.PermutedV {
+            vHeadDim := value.Dim(1)
+            elemSize := value.Stride(0)
+
+            vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
+            vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
+        } else {
+            vHeadDim := value.Dim(0)
+            rowSize := value.Stride(2)
+
+            vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len)
+            vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len)
+        }
+
+        ctx.Forward(
+            kSrcView.Copy(ctx, kDstView),
+            vSrcView.Copy(ctx, vDstView),
+        )
     }
 }
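The rounding helpers above drive both the cache capacity (in Init) and the mask shape (in buildMask). A standalone sketch of the arithmetic, with assumed padding values for illustration (the real ones come from the backend's CacheConfig):

package example

func roundDown(length, pad int) int { return (length / pad) * pad }
func roundUp(length, pad int) int   { return ((length + pad - 1) / pad) * pad }

// maskDims mirrors the sizing logic in buildMask: the cell range is aligned
// outward to CachePadding and the batch is rounded up to MaskBatchPadding.
func maskDims(cellMin, cellMax, batchSize, cachePadding, maskBatchPadding int) (length, batch int) {
    cellMin = roundDown(cellMin, cachePadding)
    cellMax = roundUp(cellMax+1, cachePadding) - 1
    return cellMax - cellMin + 1, roundUp(batchSize, maskBatchPadding)
}

// Example: cells [3, 300] and a batch of 7 with CachePadding=256 and
// MaskBatchPadding=32 give a 512x32 mask; the rows and columns that do not
// correspond to real tokens are filled with -Inf, as in buildMask.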
@@ -238,8 +324,7 @@ func (c *Causal) defrag() {
                 pendingLen++
                 break
             } else {
-                moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
-                moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
+                c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
                 moves++
             }
         }

@@ -263,8 +348,7 @@ func (c *Causal) defrag() {
     }

     if pendingLen > 0 {
-        moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
-        moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
+        c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
         moves++
     }
@@ -305,35 +389,73 @@ func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
     key := c.keys[c.curLayer]
     value := c.values[c.curLayer]

-    key = key.View(ctx, key.Stride(2)*c.curCellRange.min,
-        key.Dim(0), key.Stride(1),
-        key.Dim(1), key.Stride(2),
-        c.curMask.Dim(0),
+    kHeadDim := key.Dim(0)
+    numKVHeads := key.Dim(1)
+    rowSize := key.Stride(2)
+    cachedSize := c.curMask.Dim(0)
+
+    key = key.View(ctx, rowSize*c.curCellRange.min,
+        kHeadDim, key.Stride(1),
+        numKVHeads, key.Stride(2),
+        cachedSize,
     )

-    value = value.View(ctx, key.Stride(2)*c.curCellRange.min,
-        value.Dim(0), value.Stride(1),
-        value.Dim(1), value.Stride(2),
-        c.curMask.Dim(0),
-    )
+    if c.config.PermutedV {
+        vHeadDim := value.Dim(1)
+        elemSize := value.Stride(0)
+
+        value = value.View(ctx, elemSize*c.curCellRange.min,
+            cachedSize, value.Stride(1),
+            vHeadDim, value.Stride(2),
+            numKVHeads,
+        )
+    } else {
+        vHeadDim := value.Dim(0)
+        rowSize := value.Stride(2)
+
+        value = value.View(ctx, rowSize*c.curCellRange.min,
+            vHeadDim, value.Stride(1),
+            numKVHeads, value.Stride(2),
+            cachedSize,
+        )
+    }

     return key, value, c.curMask
 }

 func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
-    if c.curBatchSize != key.Dim(2) {
-        panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, key.Dim(2)))
+    kHeadDim := key.Dim(0)
+    vHeadDim := value.Dim(0)
+    numKVHeads := key.Dim(1)
+    batchSize := key.Dim(2)
+
+    if c.curBatchSize != batchSize {
+        panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
     }

     if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
-        c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, key.Dim(0), key.Dim(1), int(c.Capacity))
-        c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, value.Dim(0), value.Dim(1), int(c.Capacity))
+        c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
+
+        if c.config.PermutedV {
+            c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
+        } else {
+            c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
+        }
     }

-    ctx.Forward(
-        key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2))),
-        value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2))),
-    )
+    rowSize := c.keys[c.curLayer].Stride(2)
+    ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize)))
+
+    if c.config.PermutedV {
+        elemSize := c.values[c.curLayer].Stride(0)
+
+        value = value.Permute(ctx, 1, 2, 0, 3)
+        ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)))
+    } else {
+        rowSize := c.values[c.curLayer].Stride(2)
+        ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize)))
+    }
 }

 func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {

@@ -389,9 +511,13 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
             continue
         }

-        key = key.View(ctx, key.Stride(2)*seqRange.min,
-            key.Dim(0), key.Stride(1),
-            key.Dim(1), key.Stride(2),
+        kHeadDim := key.Dim(0)
+        numKVHeads := key.Dim(1)
+        rowSize := key.Stride(2)
+
+        key = key.View(ctx, rowSize*seqRange.min,
+            kHeadDim, key.Stride(1),
+            numKVHeads, key.Stride(2),
             size,
         )
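The PermutedV branches in Get, Put, and moveCells all encode a single layout decision: whether a token occupies a contiguous row of the value cache or a strided column. A small sketch of the offset math, using the names from Put (it mirrors, rather than replaces, the View calls above):

package example

// valueOffset returns the offset of token i in the per-layer value cache.
//
//	PermutedV=false: values stored [vHeadDim, numKVHeads, capacity]; token i
//	                 begins a full row, at offset rowSize*i (same as keys).
//	PermutedV=true:  values stored [capacity, vHeadDim, numKVHeads]; token i
//	                 contributes one element per (dim, head) column, at offset
//	                 elemSize*i, which lets Get return a transposed view
//	                 without a Contiguous copy.
func valueOffset(permutedV bool, i, rowSize, elemSize int) int {
    if permutedV {
        return elemSize * i
    }
    return rowSize * i
}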

View File

@@ -309,7 +309,7 @@ func (b *testBackend) SystemInfo() string {
 type testContext struct{}

-func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
+func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
     total := 0

     if len(shape) > 0 {

@@ -322,8 +322,12 @@ func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
     return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
 }

+func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
+    return c.Empty(dtype, shape...)
+}
+
 func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
-    t := c.Zeros(ml.DTypeF32, shape...).(*testTensor)
+    t := c.Empty(ml.DTypeF32, shape...).(*testTensor)

     copy(t.data, s)

@@ -391,7 +395,7 @@ func (t *testTensor) Floats() []float32 {
 }

 func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
-    out := ctx.Zeros(t.DType(), t.Shape()...).(*testTensor)
+    out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)

     for i := range out.data {
         out.data[i] = t.data[i] + t2.(*testTensor).data[i]

@@ -468,7 +472,7 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
     context := &testContext{}

-    view := context.Zeros(t.dtype, s...).(*testTensor)
+    view := context.Empty(t.dtype, s...).(*testTensor)
     view.data = t.data[offset : offset+len(view.data)]

     return view

View File

@@ -1,6 +1,8 @@
 package kvcache

 import (
+    "fmt"
+
     "github.com/ollama/ollama/ml"
 )

@@ -11,6 +13,9 @@ import (
 //
 // Not currently safe for multiple sequences
 type EncoderCache struct {
+    // config controls mostly backend-specific optimizations
+    config *ml.CacheConfig
+
     // ** current forward pass **

     // the active layer for Get and Put

@@ -40,9 +45,29 @@ func NewEncoderCache() *EncoderCache {
 }

 func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
+    if c.config == nil {
+        var config ml.CacheConfig
+        if cc, ok := backend.(ml.BackendCacheConfig); ok {
+            config = cc.CacheConfig()
+        }
+        c.config = &config
+    }
+
+    if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
+        panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
+    }
+
     c.cacheCtx = backend.NewContext()
 }

+func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
+    if c.config != nil {
+        panic("config cannot be changed after being previously set, either by the model or backend")
+    }
+
+    c.config = &config
+}
+
 func (c *EncoderCache) Close() {
     c.cacheCtx.Close()
 }

@@ -75,9 +100,13 @@ func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
     c.encoderPos = c.curPos
     c.encoderCached = true

+    if c.config.PermutedV {
+        value = value.Permute(ctx, 1, 2, 0, 3)
+    }
+
     if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
-        c.keys[c.curLayer] = c.cacheCtx.Zeros(key.DType(), key.Shape()...)
-        c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...)
+        c.keys[c.curLayer] = c.cacheCtx.Empty(key.DType(), key.Shape()...)
+        c.values[c.curLayer] = c.cacheCtx.Empty(value.DType(), value.Shape()...)
     }

     ctx.Forward(

View File

@@ -28,6 +28,12 @@ func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32)
     }
 }

+func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
+    for _, cache := range c.caches {
+        cache.SetConfig(config)
+    }
+}
+
 func (c *WrapperCache) Close() {
     for _, cache := range c.caches {
         cache.Close()
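Because the wrapper fans SetConfig out to every child, a model-level override configures all wrapped caches at once. A hypothetical sketch (shift stands in for the model's RoPE shift function, and CachePadding 1 is used because the encoder cache rejects any other padding):

package example

import (
    "github.com/ollama/ollama/kvcache"
    "github.com/ollama/ollama/ml"
)

func newWrappedCache(shift func(ml.Context, int, ml.Tensor, ml.Tensor) (ml.Tensor, error)) kvcache.Cache {
    c := kvcache.NewWrapperCache(
        kvcache.NewEncoderCache(),
        kvcache.NewCausalCache(shift),
    )
    // One call configures both children before Init runs.
    c.SetConfig(ml.CacheConfig{CachePadding: 1})
    return c
}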

View File

@@ -27,6 +27,35 @@ type Backend interface {
     SystemInfo() string
 }

+// BackendCacheConfig should be implemented by backends that need special output
+// from the cache to meet specific requirements. It is frequently implemented in
+// conjunction with ScaledDotProductAttention.
+type BackendCacheConfig interface {
+    CacheConfig() CacheConfig
+}
+
+// CacheConfig controls optimizations (mostly backend-specific) that may transform
+// the output of the cache to work better with specific kernels.
+type CacheConfig struct {
+    // CachePadding specifies the multiple for the number of tokens of cache history
+    // that will be returned from cache Get for k, v and mask. The capacity of the
+    // cache itself will also be increased to a multiple of this size if needed.
+    CachePadding int
+
+    // PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
+    // and returns the permuted version via Get. This uses the cache copy operation
+    // to avoid a Contiguous call on the permuted tensor.
+    PermutedV bool
+
+    // MaskDType specifies the data type for generating the mask. If unset it will
+    // default to DTypeF32.
+    MaskDType DType
+
+    // MaskBatchPadding specifies the multiple for the batch size dimension in the mask.
+    // Any position that does not correspond to an actual token will be filled with -Inf.
+    MaskBatchPadding int
+}
+
 // BackendParams controls how the backend loads and executes models
 type BackendParams struct {
     // NumThreads sets the number of threads to use if running on the CPU
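A backend opts in by implementing the interface above. The ggml backend later in this diff does exactly that, switching on its flash-attention flag; a trimmed sketch (the literal 32 stands in for the C constant GGML_KQ_MASK_PAD used in the real code):

func (b *Backend) CacheConfig() ml.CacheConfig {
    if b.flashAttention {
        // Fused kernels want coarse cache alignment and an f16 mask.
        return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: 32}
    }
    // The fallback path prefers values pre-permuted by the cache so it can
    // skip a Contiguous call after transposing.
    return ml.CacheConfig{CachePadding: 32, PermutedV: true}
}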
@@ -40,6 +69,9 @@ type BackendParams struct {
     // TensorSplit is the fraction of the model to offload to each GPU
     TensorSplit []float32
+
+    // FlashAttention indicates that we should use a fused flash attention kernel
+    FlashAttention bool
 }

 var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))

@@ -61,6 +93,7 @@ func NewBackend(f *os.File, params BackendParams) (Backend, error) {
 }

 type Context interface {
+    Empty(dtype DType, shape ...int) Tensor
     Zeros(dtype DType, shape ...int) Tensor
     FromFloatSlice(s []float32, shape ...int) (Tensor, error)
     FromIntSlice(s []int32, shape ...int) (Tensor, error)

@@ -116,6 +149,10 @@ type Tensor interface {
     // operation equivalent to following code on a tensor named
     // query:
     //
+    // query = query.Permute(ctx, 0, 2, 1, 3)
+    // key = key.Permute(ctx, 0, 2, 1, 3)
+    // value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+    //
     // kq := key.MulmatFullPrec(ctx, query)
     //
     // kq = kq.Scale(ctx, scale)

@@ -170,7 +207,7 @@ func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
             return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
         })
     case DTypeF16:
-        f32 := ctx.Zeros(DTypeF32, t.Shape()...)
+        f32 := ctx.Empty(DTypeF32, t.Shape()...)
         f32 = t.Copy(ctx, f32)
         return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
             return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)

View File

@@ -79,6 +79,8 @@ var devices = sync.OnceValue(func() []device {
 })

 type Backend struct {
+    flashAttention bool
+
     meta       *fs.GGML
     cpus, gpus []Context
     tensors    map[string]*Context

@@ -192,9 +194,10 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
     }

     return &Backend{
-        meta: meta,
-        cpus: cpus,
-        gpus: gpus,
+        flashAttention: params.FlashAttention,
+        meta:           meta,
+        cpus:           cpus,
+        gpus:           gpus,
         sched: C.ggml_backend_sched_new(
             (*C.ggml_backend_t)(unsafe.Pointer(&backends[0])),
             (*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])),

@@ -219,7 +222,7 @@ func (b *Backend) Get(name string) ml.Tensor {
     for _, c := range append(b.gpus, b.cpus...) {
         if t := C.ggml_get_tensor(c.ctx, cname); t != nil {
-            return &Tensor{t: t}
+            return &Tensor{b: b, t: t}
         }
     }

@@ -247,6 +250,14 @@ func (b *Backend) NewContext() ml.Context {
     }
 }

+func (b *Backend) CacheConfig() ml.CacheConfig {
+    if b.flashAttention {
+        return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
+    } else {
+        return ml.CacheConfig{CachePadding: 32, PermutedV: true}
+    }
+}
+
 type Context struct {
     b   *Backend
     ctx *C.struct_ggml_context
@@ -300,7 +311,7 @@ func shapeToGGML(shape []int) *C.int64_t {
     return &sh[0]
 }

-func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
+func newTensor(ctx Context, dtype ml.DType, zero bool, shape []int) ml.Tensor {
     if len(shape) < 1 || len(shape) > 4 {
         panic("unsupported number of dimensions")
     }

@@ -314,19 +325,29 @@ func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
     var t *C.struct_ggml_tensor
     switch dtype {
     case ml.DTypeF32:
-        t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape))
+        t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape))
     case ml.DTypeF16:
-        t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape))
+        t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape))
     case ml.DTypeI32:
-        t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape))
+        t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape))
     default:
         panic("unsupported dtype")
     }

-    b := C.ggml_backend_alloc_buffer(c.backend, C.ggml_nbytes(t))
+    b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
     C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
-    C.ggml_set_zero(t)
-    return &Tensor{t: t}
+    if zero {
+        C.ggml_set_zero(t)
+    }
+
+    return &Tensor{b: ctx.b, t: t}
+}
+
+func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
+    return newTensor(c, dtype, false, shape)
+}
+
+func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
+    return newTensor(c, dtype, true, shape)
 }

 func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype uint32) (ml.Tensor, error) {
@@ -335,7 +356,7 @@ func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype u
     if n == 0 {
         var shape C.int64_t = 0
         t := C.ggml_new_tensor(ctx.ctx, dtype, 1, &shape)
-        return &Tensor{t: t}, nil
+        return &Tensor{b: ctx.b, t: t}, nil
     }

     for _, v := range shape {

@@ -350,7 +371,7 @@ func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype u
     b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
     C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
     C.ggml_backend_tensor_set(t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t))
-    return &Tensor{t: t}, nil
+    return &Tensor{b: ctx.b, t: t}, nil
 }

 func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {

@@ -368,6 +389,7 @@ func (c *Context) Close() {
 }

 type Tensor struct {
+    b    *Backend
     t    *C.struct_ggml_tensor
     sync func()
 }
@@ -434,6 +456,7 @@ func (t *Tensor) DType() ml.DType {
 func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
     return &Tensor{
+        b: t.b,
         t: C.ggml_add(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
     }
 }

@@ -448,24 +471,28 @@ func (t *Tensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
 func (t *Tensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
     return &Tensor{
+        b: t.b,
         t: C.ggml_concat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(dim)),
     }
 }

 func (t *Tensor) Contiguous(ctx ml.Context) ml.Tensor {
     return &Tensor{
+        b: t.b,
         t: C.ggml_cont(ctx.(*Context).ctx, t.t),
     }
 }

 func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
     return &Tensor{
+        b: t.b,
         t: C.ggml_mul(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
     }
 }

 func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
     return &Tensor{
+        b: t.b,
         t: C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
     }
 }

@@ -475,12 +502,13 @@ func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
     C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32)

     return &Tensor{
+        b: t.b,
         t: mul,
     }
 }

 func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
-    tt := (&Tensor{t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
+    tt := (&Tensor{b: t.b, t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
     if b != nil {
         tt = tt.Add(ctx, b)
     }

@@ -489,7 +517,7 @@ func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tenso
 }

 func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
-    return (&Tensor{t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
+    return (&Tensor{b: t.b, t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
 }

 func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {

@@ -498,6 +526,7 @@ func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
     }

     return &Tensor{
+        b: t.b,
         t: C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
     }
 }
@@ -508,18 +537,21 @@ func (t *Tensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
     }

     return &Tensor{
+        b: t.b,
         t: C.ggml_permute(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
     }
 }

 func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
     return &Tensor{
+        b: t.b,
         t: C.ggml_get_rows(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
     }
 }

 func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
     return &Tensor{
+        b: t.b,
         t: C.ggml_cpy(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
     }
 }

@@ -528,18 +560,22 @@ func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
     switch len(shape) {
     case 1:
         return &Tensor{
+            b: t.b,
             t: C.ggml_reshape_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])),
         }
     case 2:
         return &Tensor{
+            b: t.b,
             t: C.ggml_reshape_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])),
         }
     case 3:
         return &Tensor{
+            b: t.b,
             t: C.ggml_reshape_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])),
         }
     case 4:
         return &Tensor{
+            b: t.b,
             t: C.ggml_reshape_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])),
         }
     default:

@@ -549,18 +585,21 @@ func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
 func (t *Tensor) Scale(ctx ml.Context, s float64) ml.Tensor {
     return &Tensor{
+        b: t.b,
         t: C.ggml_scale(ctx.(*Context).ctx, t.t, (C.float)(s)),
     }
 }

 func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor {
     return &Tensor{
+        b: t.b,
         t: C.ggml_soft_max(ctx.(*Context).ctx, t.t),
     }
 }

 func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
     return &Tensor{
+        b: t.b,
         t: C.ggml_tanh_inplace(ctx.(*Context).ctx, t.t),
     }
 }

@@ -571,6 +610,7 @@ func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
     }

     return &Tensor{
+        b: t.b,
         t: C.ggml_unpad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
     }
 }
@@ -579,10 +619,12 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
     switch len(shape) {
     case 1:
         return &Tensor{
+            b: t.b,
             t: C.ggml_view_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.size_t(offset)),
         }
     case 3:
         return &Tensor{
+            b: t.b,
             t: C.ggml_view_2d(ctx.(*Context).ctx, t.t,
                 C.int64_t(shape[0]), C.int64_t(shape[2]),
                 C.size_t(shape[1]),

@@ -590,6 +632,7 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
         }
     case 5:
         return &Tensor{
+            b: t.b,
             t: C.ggml_view_3d(ctx.(*Context).ctx, t.t,
                 C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]),
                 C.size_t(shape[1]), C.size_t(shape[3]),

@@ -597,6 +640,7 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
         }
     case 7:
         return &Tensor{
+            b: t.b,
             t: C.ggml_view_4d(ctx.(*Context).ctx, t.t,
                 C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]), C.int64_t(shape[6]),
                 C.size_t(shape[1]), C.size_t(shape[3]), C.size_t(shape[5]),

@@ -613,7 +657,7 @@ const (
 func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, ropeBase, ropeScale float32) ml.Tensor {
     if ropeFactors == nil {
-        ropeFactors = &Tensor{}
+        ropeFactors = &Tensor{b: t.b}
     }

     dequant := t.t

@@ -622,6 +666,7 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
     }

     return &Tensor{
+        b: t.b,
         t: C.ggml_rope_ext(
             ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
             C.int(ropeDim),

@@ -639,18 +684,21 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
 func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
     return &Tensor{
+        b: t.b,
         t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
     }
 }

 func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
     return &Tensor{
+        b: t.b,
         t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
     }
 }

 func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
     return &Tensor{
+        b: t.b,
         t: C.ggml_conv_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1)),
     }
 }
@@ -661,13 +709,25 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.T
         kqMask = mask.(*Tensor).t
     }

-    kq := key.MulmatFullPrec(ctx, t)
-    kq = &Tensor{
-        t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
-    }
+    query := t.Permute(ctx, 0, 2, 1, 3)
+    key = key.Permute(ctx, 0, 2, 1, 3)

-    kqv := value.Mulmat(ctx, kq)
-    return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+    if t.b.flashAttention {
+        value = value.Permute(ctx, 0, 2, 1, 3)
+
+        kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
+        C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)
+
+        return &Tensor{b: t.b, t: kqv}
+    } else {
+        kq := key.MulmatFullPrec(ctx, query)
+        kq = &Tensor{
+            b: t.b,
+            t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
+        }
+
+        kqv := value.Mulmat(ctx, kq)
+        return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+    }
 }

 func (b *Backend) SystemInfo() string {
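Both branches compute the same attention function; with the mask M as an additive term and s the caller's scale (typically 1/√d_k):

\mathrm{Attention}(Q, K, V) = \operatorname{softmax}\!\left(s\,QK^{\top} + M\right)V

The flash path fuses the product, masked softmax, and weighted sum into one ggml_flash_attn_ext call at F32 precision, while the fallback materializes the score matrix explicitly and permutes the result back afterwards.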

View File

@@ -3,6 +3,7 @@ package nn
 import (
     "fmt"

+    "github.com/ollama/ollama/kvcache"
     "github.com/ollama/ollama/ml"
 )

@@ -11,40 +12,50 @@ import (
 //
 // Parameters:
 //   - ctx: Context for tensor operations
-//   - query: Query tensor (Q) with shape [d_k, seq_len_q, heads]
-//   - key: Key tensor (K) with shape [d_k, seq_len_k, kv_heads]
-//   - value: Value tensor (V) with shape [seq_len_k, d_v, kv_heads]
-//   - mask: Optional attention mask that is added to the attention score. If
-//     provided, should broadcast to [seq_len_k, seq_len_q, heads]
+//   - query: Query tensor (Q) with shape [d_k, heads, seq_len_q]
+//   - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only
+//   - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only
 //   - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
+//   - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value
 //
 // Returns:
 //
 //	Attention output with shape [d_v, heads, seq_len_q]
-func Attention(ctx ml.Context, query, key, value, mask ml.Tensor, scale float64) ml.Tensor {
-    if query.Dim(0) != key.Dim(0) {
-        panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
-    }
+func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
+    if key != nil && value != nil {
+        if query.Dim(0) != key.Dim(0) {
+            panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
+        }

-    if mask != nil && query.Dim(1) != mask.Dim(1) {
-        panic(fmt.Errorf("seq_len_q in attention operation does not match between query(%v) and mask(%v)", query.Dim(1), mask.Dim(1)))
-    }
+        if key.Dim(1) != value.Dim(1) {
+            panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)))
+        }

-    if key.Dim(1) != value.Dim(0) {
-        panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(0)))
-    }
+        if key.Dim(2) != value.Dim(2) {
+            panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
+        }

-    if mask != nil && key.Dim(1) != mask.Dim(0) {
-        panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and mask(%v)", key.Dim(1), mask.Dim(0)))
-    }
+        if cache != nil {
+            cache.Put(ctx, key, value)
+        }
+    } else if cache == nil {
+        panic("key & value tensors must be provided if cache is nil")
+    }

-    if key.Dim(2) != value.Dim(2) {
-        panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
-    }
+    var mask ml.Tensor
+    if cache != nil {
+        key, value, mask = cache.Get(ctx)
+    }

-    if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
+    // Only use the fast SDPA implementation if we have a cache, since that's what
+    // will do any expected backend-specific transformations for us
+    if sdpa, ok := query.(ml.ScaledDotProductAttention); ok && cache != nil {
         return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
     } else {
+        query = query.Permute(ctx, 0, 2, 1, 3)
+        key = key.Permute(ctx, 0, 2, 1, 3)
+        value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+
         kq := key.MulmatFullPrec(ctx, query)

         kq = kq.Scale(ctx, scale)
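With the cache handling Put, Get, masking, and permutation internally, a model's attention block reduces to reshapes plus one call. A sketch following the llama change below (names and the ml/nn import path are taken from this diff):

package example

import (
    "math"

    "github.com/ollama/ollama/kvcache"
    "github.com/ollama/ollama/ml"
    "github.com/ollama/ollama/ml/nn"
)

// forward sketches a self-attention block after this change: nn.Attention
// stores k/v in the cache, fetches history plus the mask, and takes the
// fused SDPA path when the backend provides one.
func forward(ctx ml.Context, q, k, v ml.Tensor, headDim int, cache kvcache.Cache) ml.Tensor {
    // q: [headDim, numHeads, batch]; k, v: [headDim, numKVHeads, batch]
    return nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(headDim)), cache)
}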

View File

@@ -81,15 +81,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
     v := sa.Value.Forward(ctx, hiddenState)
     v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)

-    cache.Put(ctx, k, v)
-    k, v, mask := cache.Get(ctx)
-
-    q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-    k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-    v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
-
     scaleFactor := 1.0 / math.Sqrt(float64(headDim))
-    kqv := nn.Attention(ctx, q, k, v, mask, scaleFactor)
+    kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
     kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)

     return sa.Output.Forward(ctx, kqv)

View File

@@ -43,7 +43,9 @@ func New(c ml.Config) (model.Model, error) {
         TextModel: newTextModel(c),
     }

-    m.Cache = kvcache.NewWrapperCache(kvcache.NewEncoderCache(), kvcache.NewCausalCache(m.TextModel.Shift))
+    encoderCache := kvcache.NewEncoderCache()
+    encoderCache.SetConfig(ml.CacheConfig{})
+    m.Cache = kvcache.NewWrapperCache(encoderCache, kvcache.NewCausalCache(m.TextModel.Shift))

     return &m, nil
 }

View File

@@ -31,22 +31,15 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
     value := sa.Value.Forward(ctx, hiddenState)
     value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)

-    cache.Put(ctx, key, value)
-    key, value, mask := cache.Get(ctx)
-
-    query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-    key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-    value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
-
     scaleFactor := 1.0 / math.Sqrt(float64(headDim))
-    attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
+    attention := nn.Attention(ctx, query, key, value, scaleFactor, cache)
     attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)

     return sa.Output.Forward(ctx, attention)
 }

 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-    // This will only get called for layers in the cache, which are just the self attention layers
+    // This will only get called for layers in the causal cache, which are just the self attention layers
     return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
 }

@@ -107,7 +100,7 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
     query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
     query = ca.QueryNorm.Forward(ctx, query, opts.eps)

-    var key, value, mask ml.Tensor
+    var key, value ml.Tensor
     if crossAttentionStates != nil {
         numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)

@@ -119,16 +112,23 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
         value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)

         cache.Put(ctx, key, value)
-    } else {
-        key, value, mask = cache.Get(ctx)
     }

-    query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-    key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-    value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+    key, value, _ = cache.Get(ctx)

     scaleFactor := 1.0 / math.Sqrt(float64(headDim))
-    attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
+
+    query = query.Permute(ctx, 0, 2, 1, 3)
+    key = key.Permute(ctx, 0, 2, 1, 3)
+    value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+
+    kq := key.MulmatFullPrec(ctx, query)
+
+    kq = kq.Scale(ctx, scaleFactor)
+    kq = kq.Softmax(ctx)
+
+    kqv := value.Mulmat(ctx, kq)
+    attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
     attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)

     return ca.Output.Forward(ctx, attention)

View File

@@ -818,7 +818,7 @@ func Execute(args []string) error {
     batchSize := fs.Int("batch-size", 512, "Batch size")
     numGPULayers := fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
     mainGPU := fs.Int("main-gpu", 0, "Main GPU")
-    _ = fs.Bool("flash-attn", false, "Enable flash attention")
+    flashAttention := fs.Bool("flash-attn", false, "Enable flash attention")
     kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size")
     kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)")
     port := fs.Int("port", 8080, "Port to expose the server on")

@@ -863,7 +863,6 @@ func Execute(args []string) error {
     }

     // TODO(jessegross): Parameters that need to be implemented:
-    // flash-attn
     // no-mmap
     // mlock

@@ -878,10 +877,11 @@ func Execute(args []string) error {
     }

     params := ml.BackendParams{
-        NumThreads:   *threads,
-        NumGPULayers: *numGPULayers,
-        MainGPU:      *mainGPU,
-        TensorSplit:  tensorSplitFloats,
+        NumThreads:     *threads,
+        NumGPULayers:   *numGPULayers,
+        MainGPU:        *mainGPU,
+        TensorSplit:    tensorSplitFloats,
+        FlashAttention: *flashAttention,
     }

     server.ready.Add(1)

View File

@@ -24,6 +24,7 @@ import (
     "os"
     "path/filepath"
     "runtime"
+    "slices"
     "strconv"
     "strings"
     "sync/atomic"

@@ -53,7 +54,7 @@ var (
     // ErrMissingModel is returned when the model part of a name is missing
     // or invalid.
-    ErrNameInvalid = errors.New("invalid name; must be in the form {scheme://}{host/}{namespace/}[model]{:tag}{@digest}")
+    ErrNameInvalid = errors.New("invalid or missing name")

     // ErrCached is passed to [Trace.PushUpdate] when a layer already
     // exists. It is a non-fatal error and is never returned by [Registry.Push].

@@ -205,10 +206,18 @@ type Registry struct {
     // It is only used when a layer is larger than [MaxChunkingThreshold].
     MaxChunkSize int64

-    // NameMask, if set, is the name used to convert non-fully qualified
+    // Mask, if set, is the name used to convert non-fully qualified
     // names to fully qualified names. If empty, the default mask
     // ("registry.ollama.ai/library/_:latest") is used.
-    NameMask string
+    Mask string
+}
+
+func (r *Registry) completeName(name string) names.Name {
+    mask := defaultMask
+    if r.Mask != "" {
+        mask = names.Parse(r.Mask)
+    }
+    return names.Merge(names.Parse(name), mask)
 }

 // DefaultRegistry returns a new Registry configured from the environment. The
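With a zero Mask, completeName falls back to the default mask, so a bare model name gains host, namespace, and tag. A sketch from inside the package (output values assume the default "registry.ollama.ai/library/_:latest" mask quoted above):

var r Registry // zero Mask: the default mask applies

n := r.completeName("llama3:8b")
// n.IsFullyQualified() == true
// n.String() == "registry.ollama.ai/library/llama3:8b"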
@@ -243,52 +252,6 @@ func DefaultRegistry() (*Registry, error) {
     return &rc, nil
 }

-type PushParams struct {
-    // From is an optional destination name for the model. If empty, the
-    // destination name is the same as the source name.
-    From string
-}
-
-// parseName parses name using [names.ParseExtended] and then merges the name with the
-// default name, and checks that the name is fully qualified. If a digest is
-// present, it parse and returns it with the other fields as their zero values.
-//
-// It returns an error if the name is not fully qualified, or if the digest, if
-// any, is invalid.
-//
-// The scheme is returned as provided by [names.ParseExtended].
-func parseName(s, mask string) (scheme string, n names.Name, d blob.Digest, err error) {
-    maskName := defaultMask
-    if mask != "" {
-        maskName = names.Parse(mask)
-        if !maskName.IsFullyQualified() {
-            return "", names.Name{}, blob.Digest{}, fmt.Errorf("invalid name mask: %s", mask)
-        }
-    }
-    scheme, n, ds := names.ParseExtended(s)
-    if !n.IsValid() {
-        return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s)
-    }
-    n = names.Merge(n, maskName)
-    if ds != "" {
-        // Digest is present. Validate it.
-        d, err = blob.ParseDigest(ds)
-        if err != nil {
-            return "", names.Name{}, blob.Digest{}, err
-        }
-    }
-    // The name check is deferred until after the digest check because we
-    // say that digests take precedence over names, and so should there
-    // errors when being parsed.
-    if !n.IsFullyQualified() {
-        return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s)
-    }
-    scheme = cmp.Or(scheme, "https")
-    return scheme, n, d, nil
-}
 func (r *Registry) maxStreams() int {
     n := cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0))

@@ -308,6 +271,12 @@ func (r *Registry) maxChunkSize() int64 {
     return cmp.Or(r.MaxChunkSize, DefaultMaxChunkSize)
 }

+type PushParams struct {
+    // From is an optional destination name for the model. If empty, the
+    // destination name is the same as the source name.
+    From string
+}
+
 // Push pushes the model with the name in the cache to the remote registry.
 func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *PushParams) error {
     if p == nil {

@@ -337,7 +306,7 @@ func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *
     t := traceFromContext(ctx)

-    scheme, n, _, err := parseName(name, r.NameMask)
+    scheme, n, _, err := parseName(name, r.Mask)
     if err != nil {
         // This should never happen since ResolveLocal should have
         // already validated the name.

@@ -431,7 +400,7 @@ func canRetry(err error) bool {
 // typically slower than splitting the model up across layers, and is mostly
 // utilized for layers of type equal to "application/vnd.ollama.image".
 func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error {
-    scheme, n, _, err := parseName(name, r.NameMask)
+    scheme, n, _, err := parseName(name, r.Mask)
     if err != nil {
         return err
     }

@@ -582,9 +551,9 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err
 // Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
 // before attempting to unlink the model.
 func (r *Registry) Unlink(c *blob.DiskCache, name string) (ok bool, _ error) {
-    _, n, _, err := parseName(name, r.NameMask)
-    if err != nil {
-        return false, err
+    n := r.completeName(name)
+    if !n.IsFullyQualified() {
+        return false, fmt.Errorf("%w: %q", ErrNameInvalid, name)
     }
     return c.Unlink(n.String())
 }

@@ -658,9 +627,9 @@ type Layer struct {
 }

 // ResolveLocal resolves a name to a Manifest in the local cache. The name is
-// parsed using [names.ParseExtended] but the scheme is ignored.
+// parsed using [names.Split] but the scheme is ignored.
 func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
-    _, n, d, err := parseName(name, r.NameMask)
+    _, n, d, err := parseName(name, r.Mask)
     if err != nil {
         return nil, err
     }

@@ -686,7 +655,7 @@ func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, erro
 // Resolve resolves a name to a Manifest in the remote registry.
 func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) {
-    scheme, n, d, err := parseName(name, r.NameMask)
+    scheme, n, d, err := parseName(name, r.Mask)
     if err != nil {
         return nil, err
     }
@@ -869,3 +838,69 @@ func maybeUnexpectedEOF(err error) error {
     }
     return err
 }

+type publicError struct {
+    wrapped error
+    message string
+}
+
+func withPublicMessagef(err error, message string, args ...any) error {
+    return publicError{wrapped: err, message: fmt.Sprintf(message, args...)}
+}
+
+func (e publicError) Error() string { return e.message }
+func (e publicError) Unwrap() error { return e.wrapped }
+
+var supportedSchemes = []string{
+    "http",
+    "https",
+    "https+insecure",
+}
+
+var supportedSchemesMessage = fmt.Sprintf("supported schemes are %v", strings.Join(supportedSchemes, ", "))
+
+// parseName parses and validates an extended name, returning the scheme, name,
+// and digest.
+//
+// If the scheme is empty, scheme will be "https". If an unsupported scheme is
+// given, [ErrNameInvalid] wrapped with a display friendly message is returned.
+//
+// If the digest is invalid, [ErrNameInvalid] wrapped with a display friendly
+// message is returned.
+//
+// If the name is not, once merged with the mask, fully qualified,
+// [ErrNameInvalid] wrapped with a display friendly message is returned.
+func parseName(s string, mask string) (scheme string, _ names.Name, _ blob.Digest, _ error) {
+    scheme, name, digest := names.Split(s)
+    scheme = cmp.Or(scheme, "https")
+    if !slices.Contains(supportedSchemes, scheme) {
+        err := withPublicMessagef(ErrNameInvalid, "unsupported scheme: %q: %s", scheme, supportedSchemesMessage)
+        return "", names.Name{}, blob.Digest{}, err
+    }
+
+    var d blob.Digest
+    if digest != "" {
+        var err error
+        d, err = blob.ParseDigest(digest)
+        if err != nil {
+            err = withPublicMessagef(ErrNameInvalid, "invalid digest: %q", digest)
+            return "", names.Name{}, blob.Digest{}, err
+        }
+        if name == "" {
+            // We can resolve a manifest from a digest only,
+            // so skip name validation and return the scheme and
+            // digest.
+            return scheme, names.Name{}, d, nil
+        }
+    }
+
+    maskName := defaultMask
+    if mask != "" {
+        maskName = names.Parse(mask)
+    }
+    n := names.Merge(names.Parse(name), maskName)
+    if !n.IsFullyQualified() {
+        return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s)
+    }
+    return scheme, n, d, nil
+}
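A behavior sketch of the new parseName, with cases drawn from the TestParseNameErrors table added below:

scheme, n, _, err := parseName("x", "")
// scheme == "https"; n completed by the default mask; err == nil

_, _, _, err = parseName("x://", "")
// errors.Is(err, ErrNameInvalid); the message lists the supported schemes

_, _, _, err = parseName("@sha123-1234", "")
// errors.Is(err, ErrNameInvalid); the message quotes the bad digest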

View File

@@ -84,14 +84,14 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
         }
     }

-    rc := &Registry{
+    r := &Registry{
         HTTPClient: &http.Client{
             Transport: recordRoundTripper(h),
         },
     }

     link := func(name string, manifest string) {
-        _, n, _, err := parseName(name, rc.NameMask)
+        _, n, _, err := parseName(name, r.Mask)
         if err != nil {
             panic(err)
         }

@@ -122,7 +122,7 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
     commit("sizemismatch", mklayer("exists"), &Layer{Digest: blob.DigestFromBytes("present"), Size: 499})
     link("invalid", "!!!!!")

-    return rc, c
+    return r, c
 }

 func okHandler(w http.ResponseWriter, r *http.Request) {
@@ -145,29 +145,6 @@ func importBytes(t *testing.T, c *blob.DiskCache, data string) blob.Digest {
     return d
 }

-func TestRegistryPushInvalidNames(t *testing.T) {
-    rc, c := newClient(t, nil)
-
-    cases := []struct {
-        name string
-        err  error
-    }{
-        {"", ErrNameInvalid},
-        {"@", ErrNameInvalid},
-        {"@x", blob.ErrInvalidDigest},
-    }
-
-    for _, tt := range cases {
-        t.Run(tt.name, func(t *testing.T) {
-            // Create a new registry and push a new image.
-            err := rc.Push(t.Context(), c, tt.name, nil)
-            if !errors.Is(err, tt.err) {
-                t.Errorf("err = %v; want %v", err, tt.err)
-            }
-        })
-    }
-}
-
 func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) {
     t := &Trace{Update: func(*Layer, int64, error) { panic("unexpected") }}
     return WithTrace(ctx, t), t
@@ -622,7 +599,7 @@ func TestInsecureSkipVerify(t *testing.T) {
 	}))
 	defer s.Close()
-	const name = "ollama.com/library/insecure"
+	const name = "library/insecure"
 	var rc Registry
 	url := fmt.Sprintf("https://%s/%s", s.Listener.Addr(), name)
@@ -724,3 +701,38 @@ func TestErrorUnmarshal(t *testing.T) {
 		})
 	}
 }
// TestParseNameErrors tests that parseName returns error messages with enough
// detail for users to debug naming issues they may encounter. Prior to this
// test, the error messages were not very helpful and each problem was reported
// as the same message.
//
// It is only for testing error messages, not for covering every valid and
// invalid input. Those cases are in other tests for names.Name and blob.Digest.
func TestParseNameErrors(t *testing.T) {
cases := []struct {
name string
err error
want string
}{
{"x", nil, ""},
{"x@", nil, ""},
{"", ErrNameInvalid, `invalid or missing name: ""`},
{"://", ErrNameInvalid, `invalid or missing name: "://"`},
{"x://", ErrNameInvalid, `unsupported scheme: "x": supported schemes are http, https, https+insecure`},
{"@sha123-1234", ErrNameInvalid, `invalid digest: "sha123-1234"`},
{"x@sha123-1234", ErrNameInvalid, `invalid digest: "sha123-1234"`},
}
for _, tt := range cases {
_, _, _, err := parseName(tt.name, DefaultMask)
if !errors.Is(err, tt.err) {
t.Errorf("[%s]: err = %v; want %v", tt.name, err, tt.err)
}
if err != nil && !strings.Contains(err.Error(), tt.want) {
t.Errorf("[%s]: err =\n\t%v\nwant\n\t%v", tt.name, err, tt.want)
}
}
}

View File

@@ -8,7 +8,7 @@ import (
"github.com/ollama/ollama/server/internal/internal/stringsx" "github.com/ollama/ollama/server/internal/internal/stringsx"
) )
const MaxNameLength = 50 + 1 + 50 + 1 + 50 // <namespace>/<model>:<tag> const MaxNameLength = 350 + 1 + 80 + 1 + 80 + 1 + 80 // <host>/<namespace>/<model>:<tag>
type Name struct { type Name struct {
// Make incomparable to enfoce use of Compare / Equal for // Make incomparable to enfoce use of Compare / Equal for
@@ -25,19 +25,12 @@ type Name struct {
 // format of a valid name string is:
 //
 //	s:
-//		{ host } "/" { namespace } "/" { model } ":" { tag } "@" { digest }
 //		{ host } "/" { namespace } "/" { model } ":" { tag }
-//		{ host } "/" { namespace } "/" { model } "@" { digest }
 //		{ host } "/" { namespace } "/" { model }
-//		{ namespace } "/" { model } ":" { tag } "@" { digest }
 //		{ namespace } "/" { model } ":" { tag }
-//		{ namespace } "/" { model } "@" { digest }
 //		{ namespace } "/" { model }
-//		{ model } ":" { tag } "@" { digest }
 //		{ model } ":" { tag }
-//		{ model } "@" { digest }
 //		{ model }
-//		"@" { digest }
 //	host:
 //		pattern: { alphanum | "_" } { alphanum | "_" | "-" | "." | ":" }*
 //		length: [1, 350]
@@ -50,9 +43,6 @@ type Name struct {
 //	tag:
 //		pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
 //		length: [1, 80]
-//	digest:
-//		pattern: { alphanum | "_" } { alphanum | "-" | ":" }*
-//		length: [1, 80]
 //
 // The name returned is not guaranteed to be valid. If it is not valid, the
 // field values are left in an undefined state. Use [Name.IsValid] to check
@@ -82,23 +72,17 @@ func Parse(s string) Name {
 	}
 }
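To make the grammar above concrete, a short sketch (assuming the usual per-part accessors alongside the Host accessor shown further down):

	n := Parse("ollama.com/library/smol:latest")
	// host "ollama.com", namespace "library", model "smol", tag "latest"
	m := Parse("smol")
	// only the model part is set: m is valid, but not fully qualified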
-// ParseExtended parses and returns any scheme, Name, and digest from s in
-// the form [scheme://][name][@digest]. All parts are optional.
-//
-// If the scheme is present, it must be followed by "://". The digest is
-// prefixed by "@" and comes after the name. The name is parsed using [Parse].
-//
-// The scheme and digest are stripped before the name is parsed by [Parse].
-//
-// For convenience, the scheme is never empty. If the scheme is not present, the
-// returned scheme is "https".
+// Split splits an extended name string into its scheme, name, and digest
+// parts.
 //
 // Examples:
 //
 //	http://ollama.com/bmizerany/smol:latest@digest
 //	https://ollama.com/bmizerany/smol:latest
 //	ollama.com/bmizerany/smol:latest@digest // returns "https" scheme.
-func ParseExtended(s string) (scheme string, _ Name, digest string) {
+//	model@digest
+//	@digest
+func Split(s string) (scheme, name, digest string) {
 	i := strings.Index(s, "://")
 	if i >= 0 {
 		scheme = s[:i]
@@ -109,21 +93,7 @@ func ParseExtended(s string) (scheme string, _ Name, digest string) {
 		digest = s[i+1:]
 		s = s[:i]
 	}
-	return scheme, Parse(s), digest
-}
-func FormatExtended(scheme string, n Name, digest string) string {
-	var b strings.Builder
-	if scheme != "" {
-		b.WriteString(scheme)
-		b.WriteString("://")
-	}
-	b.WriteString(n.String())
-	if digest != "" {
-		b.WriteByte('@')
-		b.WriteString(digest)
-	}
-	return b.String()
+	return scheme, s, digest
 }
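A sketch of what Split returns for a few inputs, following the examples in its doc comment (the digest value is illustrative):

	scheme, name, digest := Split("http://ollama.com/bmizerany/smol:latest@sha256-abc")
	// scheme == "http", name == "ollama.com/bmizerany/smol:latest", digest == "sha256-abc"
	scheme, name, digest = Split("smol")
	// scheme == "", name == "smol", digest == ""; parseName later defaults the scheme to "https"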
 // Merge merges two names into a single name. Non-empty host, namespace, and
@@ -141,39 +111,68 @@ func Merge(a, b Name) Name {
 // IsValid returns true if the name is valid.
 func (n Name) IsValid() bool {
-	if n.h != "" && !isValidHost(n.h) {
+	if n.h != "" && !isValidPart(partHost, n.h) {
 		return false
 	}
-	if n.n != "" && !isValidNamespace(n.n) {
+	if n.n != "" && !isValidPart(partNamespace, n.n) {
 		return false
 	}
-	if n.m != "" && !isValidModel(n.m) {
+	if n.t != "" && !isValidPart(partTag, n.t) {
 		return false
 	}
-	if n.t != "" && !isValidTag(n.t) {
-		return false
-	}
-	return true
+	// at bare minimum, model must be present and valid
+	return n.m != "" && isValidPart(partModel, n.m)
 }
 func (n Name) IsFullyQualified() bool {
 	return n.IsValid() && n.h != "" && n.n != "" && n.m != "" && n.t != ""
 }
-func isValidHost(_ string) bool {
-	return true // TODO: implement
-}
-func isValidNamespace(_ string) bool {
-	return true // TODO: implement
-}
-func isValidModel(_ string) bool {
-	return true // TODO: implement
-}
-func isValidTag(_ string) bool {
-	return true // TODO: implement
-}
+const (
+	partHost = iota
+	partNamespace
+	partModel
+	partTag
+)
+func isValidPart(kind int, s string) bool {
+	maxlen := 80
+	if kind == partHost {
+		maxlen = 350
+	}
+	if len(s) > maxlen {
+		return false
+	}
+	for i := range s {
+		if i == 0 {
+			if !isAlphanumericOrUnderscore(s[i]) {
+				return false
+			}
+			continue
+		}
+		switch s[i] {
+		case '_', '-':
+		case '.':
+			if kind == partNamespace {
+				return false
+			}
+		case ':':
+			if kind != partHost {
+				return false
+			}
+		default:
+			if !isAlphanumericOrUnderscore(s[i]) {
+				return false
+			}
+		}
+	}
+	return true
+}
+func isAlphanumericOrUnderscore(c byte) bool {
+	return c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' || c >= '0' && c <= '9' || c == '_'
+}
 func (n Name) Host() string { return n.h }
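To see the per-part rules in action, a few illustrative calls (a sketch based on the switch above):

	isValidPart(partHost, "ollama.com:443") // true: '.' and ':' are allowed in hosts
	isValidPart(partNamespace, "team.a")    // false: '.' is rejected in namespaces
	isValidPart(partModel, "smol-7b")       // true: '-' is allowed anywhere but first
	isValidPart(partTag, "-latest")         // false: parts must start with alphanum or '_'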

View File

@@ -81,15 +81,11 @@ func TestParseExtended(t *testing.T) {
 	}
 	for _, tt := range cases {
 		t.Run(tt.in, func(t *testing.T) {
-			scheme, name, digest := ParseExtended(tt.in)
-			if scheme != tt.wantScheme || name.Compare(tt.wantName) != 0 || digest != tt.wantDigest {
+			scheme, name, digest := Split(tt.in)
+			n := Parse(name)
+			if scheme != tt.wantScheme || n.Compare(tt.wantName) != 0 || digest != tt.wantDigest {
 				t.Errorf("ParseExtended(%q) = %q, %#v, %q, want %q, %#v, %q", tt.in, scheme, name, digest, tt.wantScheme, tt.wantName, tt.wantDigest)
 			}
-			// Round trip
-			if got := FormatExtended(scheme, name, digest); got != tt.in {
-				t.Errorf("FormatExtended(%q, %q, %q) = %q", scheme, name, digest, got)
-			}
 		})
 	}
 }
@@ -150,3 +146,75 @@ func BenchmarkParseName(b *testing.B) {
 		junkName = Parse("h/n/m:t")
 	}
 }
const (
part80 = "88888888888888888888888888888888888888888888888888888888888888888888888888888888"
part350 = "33333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333"
)
var testCases = map[string]bool{ // name -> valid
"": false,
"_why/_the/_lucky:_stiff": true,
// minimal
"h/n/m:t": true,
"host/namespace/model:tag": true,
"host/namespace/model": true,
"namespace/model": true,
"model": true,
// long (but valid)
part80 + "/" + part80 + "/" + part80 + ":" + part80: true,
part350 + "/" + part80 + "/" + part80 + ":" + part80: true,
// too long
part80 + "/" + part80 + "/" + part80 + ":" + part350: false,
"x" + part350 + "/" + part80 + "/" + part80 + ":" + part80: false,
"h/nn/mm:t": true, // bare minimum part sizes
// unqualified
"m": true,
"n/m:": true,
"h/n/m": true,
"@t": false,
"m@d": false,
// invalids
"^": false,
"mm:": true,
"/nn/mm": true,
"//": false, // empty model
"//mm": true,
"hh//": false, // empty model
"//mm:@": false,
"00@": false,
"@": false,
// not starting with alphanum
"-hh/nn/mm:tt": false,
"hh/-nn/mm:tt": false,
"hh/nn/-mm:tt": false,
"hh/nn/mm:-tt": false,
// smells like a flag
"-h": false,
// hosts
"host:https/namespace/model:tag": true,
// colon in non-host part before tag
"host/name:space/model:tag": false,
}
func TestParseNameValidation(t *testing.T) {
for s, valid := range testCases {
got := Parse(s)
if got.IsValid() != valid {
t.Logf("got: %v", got)
t.Errorf("Parse(%q).IsValid() = %v; want !%[2]v", s, got.IsValid())
}
}
}

View File

@@ -204,7 +204,7 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
 		return err
 	}
 	if !ok {
-		return &serverError{404, "manifest_not_found", "manifest not found"}
+		return &serverError{404, "not_found", "model not found"}
 	}
 	return nil
 }

View File

@@ -109,11 +109,8 @@ func TestServerDelete(t *testing.T) {
 	got = s.send(t, "DELETE", "/api/delete", ``)
 	checkErrorResponse(t, got, 400, "bad_request", "empty request body")
-	got = s.send(t, "DELETE", "/api/delete", `{"model": "!"}`)
-	checkErrorResponse(t, got, 404, "manifest_not_found", "not found")
 	got = s.send(t, "DELETE", "/api/delete", `{"model": "://"}`)
-	checkErrorResponse(t, got, 400, "bad_request", "invalid name")
+	checkErrorResponse(t, got, 400, "bad_request", "invalid or missing name")
 	got = s.send(t, "DELETE", "/unknown_path", `{}`) // valid body
 	checkErrorResponse(t, got, 404, "not_found", "not found")