mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-21 22:33:56 +00:00
Merge branch 'ollama:main' into main
This commit is contained in:
@@ -122,9 +122,11 @@ if(CMAKE_HIP_COMPILER)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
|
||||
|
||||
if (WIN32)
|
||||
target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY=1)
|
||||
target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY)
|
||||
endif()
|
||||
|
||||
target_compile_definitions(ggml-hip PRIVATE GGML_HIP_NO_VMM)
|
||||
|
||||
set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm)
|
||||
install(TARGETS ggml-hip
|
||||
RUNTIME_DEPENDENCIES
|
||||
|
||||
@@ -29,6 +29,17 @@ type Cache interface {
|
||||
// cache implementation used.
|
||||
Put(ctx ml.Context, key, value ml.Tensor)
|
||||
|
||||
// SetConfig controls optimizations (mostly backend-specific) that may transform
|
||||
// the output of the cache to work better with specific kernels. If not called,
|
||||
// the backend settings will be used. This works well when calling Attention.
|
||||
//
|
||||
// The config can be overridden by models, especially if they require vanilla
|
||||
// output when implementing their own version of attention. To do this, pass
|
||||
// an empty ml.CacheConfig.
|
||||
//
|
||||
// Most models will not need to use this.
|
||||
SetConfig(ml.CacheConfig)
|
||||
|
||||
// ** cache management **
|
||||
|
||||
// Init sets up runtime parameters
|
||||
|
||||
@@ -22,6 +22,9 @@ type Causal struct {
|
||||
Capacity int32
|
||||
windowSize int32
|
||||
|
||||
// config controls mostly backend-specific optimizations
|
||||
config *ml.CacheConfig
|
||||
|
||||
// ** current forward pass **
|
||||
|
||||
// the active layer for Get and Put
|
||||
@@ -75,14 +78,42 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal {
|
||||
}
|
||||
|
||||
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
|
||||
if c.config == nil {
|
||||
var config ml.CacheConfig
|
||||
if cc, ok := backend.(ml.BackendCacheConfig); ok {
|
||||
config = cc.CacheConfig()
|
||||
}
|
||||
c.config = &config
|
||||
}
|
||||
|
||||
if c.config.CachePadding == 0 {
|
||||
c.config.CachePadding = 1
|
||||
}
|
||||
|
||||
if c.config.MaskBatchPadding == 0 {
|
||||
c.config.MaskBatchPadding = 1
|
||||
}
|
||||
|
||||
if c.config.MaskDType == ml.DTypeOther {
|
||||
c.config.MaskDType = ml.DTypeF32
|
||||
}
|
||||
|
||||
c.DType = dtype
|
||||
c.Capacity = capacity
|
||||
c.cells = make([]cacheCell, capacity)
|
||||
c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding))
|
||||
c.cells = make([]cacheCell, c.Capacity)
|
||||
c.cellRanges = make(map[int]cellRange)
|
||||
c.backend = backend
|
||||
c.cacheCtx = backend.NewContext()
|
||||
}
|
||||
|
||||
func (c *Causal) SetConfig(config ml.CacheConfig) {
|
||||
if c.config != nil {
|
||||
panic("config cannot be changed after being previously set, either by the model or backend")
|
||||
}
|
||||
|
||||
c.config = &config
|
||||
}
|
||||
|
||||
func (c *Causal) Close() {
|
||||
c.cacheCtx.Close()
|
||||
}
|
||||
@@ -157,36 +188,91 @@ func (c *Causal) findStartLoc() (int, error) {
|
||||
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
|
||||
}
|
||||
|
||||
func roundDown(length, pad int) int {
|
||||
return (length / pad) * pad
|
||||
}
|
||||
|
||||
func roundUp(length, pad int) int {
|
||||
return ((length + pad - 1) / pad) * pad
|
||||
}
|
||||
|
||||
// Builds a mask of history x batch indicating whether for each token in the batch the
|
||||
// token in the history should apply. This is based on both the sequence and causality (the
|
||||
// position of the history is not ahead of the token in the batch).
|
||||
func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) {
|
||||
// TODO(jessegross): This does not do padding, which is required for flash attention
|
||||
len := c.curCellRange.max - c.curCellRange.min + 1
|
||||
mask := make([]float32, c.curBatchSize*len)
|
||||
// Align and pad the two dimensions as required by the backend
|
||||
batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
|
||||
|
||||
c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
|
||||
c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
|
||||
|
||||
length := c.curCellRange.max - c.curCellRange.min + 1
|
||||
mask := make([]float32, batchSize*length)
|
||||
|
||||
for i := range c.curBatchSize {
|
||||
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
|
||||
if !slices.Contains(c.cells[j].sequences, seqs[i]) || c.cells[j].pos > positions[i] ||
|
||||
c.cells[j].pos < positions[i]-c.windowSize {
|
||||
mask[i*len+(j-c.curCellRange.min)] = float32(math.Inf(-1))
|
||||
mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ctx.FromFloatSlice(mask, len, c.curBatchSize)
|
||||
// Mask out any padding tokens we added. For padding that we added to the cache history, this
|
||||
// has already been masked out because the sequence doesn't match.
|
||||
for i := c.curBatchSize * length; i < len(mask); i++ {
|
||||
mask[i] = float32(math.Inf(-1))
|
||||
}
|
||||
|
||||
maskTensor, err := ctx.FromFloatSlice(mask, length, batchSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if c.config.MaskDType != ml.DTypeF32 {
|
||||
out := ctx.Empty(c.config.MaskDType, maskTensor.Shape()...)
|
||||
ctx.Forward(maskTensor.Copy(ctx, out))
|
||||
maskTensor = out
|
||||
}
|
||||
|
||||
return maskTensor, nil
|
||||
}
|
||||
|
||||
func moveCell(ctx ml.Context, objs []ml.Tensor, src, dst, len int) {
|
||||
for _, obj := range objs {
|
||||
if obj == nil {
|
||||
func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
|
||||
for i := range c.keys {
|
||||
if c.keys[i] == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
srcView := obj.View(ctx, obj.Stride(2)*src, obj.Dim(0)*obj.Dim(1)*len)
|
||||
dstView := obj.View(ctx, obj.Stride(2)*dst, obj.Dim(0)*obj.Dim(1)*len)
|
||||
key := c.keys[i]
|
||||
|
||||
ctx.Forward(srcView.Copy(ctx, dstView))
|
||||
kHeadDim := key.Dim(0)
|
||||
numKVHeads := key.Dim(1)
|
||||
rowSize := key.Stride(2)
|
||||
|
||||
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len)
|
||||
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len)
|
||||
|
||||
value := c.values[i]
|
||||
var vSrcView, vDstView ml.Tensor
|
||||
if c.config.PermutedV {
|
||||
vHeadDim := value.Dim(1)
|
||||
elemSize := value.Stride(0)
|
||||
|
||||
vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
|
||||
vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
|
||||
} else {
|
||||
vHeadDim := value.Dim(0)
|
||||
rowSize := value.Stride(2)
|
||||
|
||||
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len)
|
||||
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len)
|
||||
}
|
||||
|
||||
ctx.Forward(
|
||||
kSrcView.Copy(ctx, kDstView),
|
||||
vSrcView.Copy(ctx, vDstView),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -238,8 +324,7 @@ func (c *Causal) defrag() {
|
||||
pendingLen++
|
||||
break
|
||||
} else {
|
||||
moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
|
||||
moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
|
||||
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
|
||||
moves++
|
||||
}
|
||||
}
|
||||
@@ -263,8 +348,7 @@ func (c *Causal) defrag() {
|
||||
}
|
||||
|
||||
if pendingLen > 0 {
|
||||
moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
|
||||
moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
|
||||
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
|
||||
moves++
|
||||
}
|
||||
|
||||
@@ -305,35 +389,73 @@ func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
key := c.keys[c.curLayer]
|
||||
value := c.values[c.curLayer]
|
||||
|
||||
key = key.View(ctx, key.Stride(2)*c.curCellRange.min,
|
||||
key.Dim(0), key.Stride(1),
|
||||
key.Dim(1), key.Stride(2),
|
||||
c.curMask.Dim(0),
|
||||
kHeadDim := key.Dim(0)
|
||||
numKVHeads := key.Dim(1)
|
||||
rowSize := key.Stride(2)
|
||||
cachedSize := c.curMask.Dim(0)
|
||||
|
||||
key = key.View(ctx, rowSize*c.curCellRange.min,
|
||||
kHeadDim, key.Stride(1),
|
||||
numKVHeads, key.Stride(2),
|
||||
cachedSize,
|
||||
)
|
||||
|
||||
value = value.View(ctx, key.Stride(2)*c.curCellRange.min,
|
||||
value.Dim(0), value.Stride(1),
|
||||
value.Dim(1), value.Stride(2),
|
||||
c.curMask.Dim(0),
|
||||
)
|
||||
if c.config.PermutedV {
|
||||
vHeadDim := value.Dim(1)
|
||||
elemSize := value.Stride(0)
|
||||
|
||||
value = value.View(ctx, elemSize*c.curCellRange.min,
|
||||
cachedSize, value.Stride(1),
|
||||
vHeadDim, value.Stride(2),
|
||||
numKVHeads,
|
||||
)
|
||||
} else {
|
||||
vHeadDim := value.Dim(0)
|
||||
rowSize := value.Stride(2)
|
||||
|
||||
value = value.View(ctx, rowSize*c.curCellRange.min,
|
||||
vHeadDim, value.Stride(1),
|
||||
numKVHeads, value.Stride(2),
|
||||
cachedSize,
|
||||
)
|
||||
}
|
||||
|
||||
return key, value, c.curMask
|
||||
}
|
||||
|
||||
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
if c.curBatchSize != key.Dim(2) {
|
||||
panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, key.Dim(2)))
|
||||
kHeadDim := key.Dim(0)
|
||||
vHeadDim := value.Dim(0)
|
||||
numKVHeads := key.Dim(1)
|
||||
batchSize := key.Dim(2)
|
||||
|
||||
if c.curBatchSize != batchSize {
|
||||
panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
|
||||
}
|
||||
|
||||
if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
|
||||
c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, key.Dim(0), key.Dim(1), int(c.Capacity))
|
||||
c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, value.Dim(0), value.Dim(1), int(c.Capacity))
|
||||
c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
|
||||
|
||||
if c.config.PermutedV {
|
||||
c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
|
||||
} else {
|
||||
c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
|
||||
}
|
||||
}
|
||||
|
||||
ctx.Forward(
|
||||
key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2))),
|
||||
value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2))),
|
||||
)
|
||||
rowSize := c.keys[c.curLayer].Stride(2)
|
||||
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize)))
|
||||
|
||||
if c.config.PermutedV {
|
||||
elemSize := c.values[c.curLayer].Stride(0)
|
||||
|
||||
value = value.Permute(ctx, 1, 2, 0, 3)
|
||||
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)))
|
||||
} else {
|
||||
rowSize := c.values[c.curLayer].Stride(2)
|
||||
|
||||
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize)))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||
@@ -389,9 +511,13 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
|
||||
continue
|
||||
}
|
||||
|
||||
key = key.View(ctx, key.Stride(2)*seqRange.min,
|
||||
key.Dim(0), key.Stride(1),
|
||||
key.Dim(1), key.Stride(2),
|
||||
kHeadDim := key.Dim(0)
|
||||
numKVHeads := key.Dim(1)
|
||||
rowSize := key.Stride(2)
|
||||
|
||||
key = key.View(ctx, rowSize*seqRange.min,
|
||||
kHeadDim, key.Stride(1),
|
||||
numKVHeads, key.Stride(2),
|
||||
size,
|
||||
)
|
||||
|
||||
|
||||
@@ -309,7 +309,7 @@ func (b *testBackend) SystemInfo() string {
|
||||
|
||||
type testContext struct{}
|
||||
|
||||
func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
total := 0
|
||||
|
||||
if len(shape) > 0 {
|
||||
@@ -322,8 +322,12 @@ func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
|
||||
}
|
||||
|
||||
func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
return c.Empty(dtype, shape...)
|
||||
}
|
||||
|
||||
func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
|
||||
t := c.Zeros(ml.DTypeF32, shape...).(*testTensor)
|
||||
t := c.Empty(ml.DTypeF32, shape...).(*testTensor)
|
||||
|
||||
copy(t.data, s)
|
||||
|
||||
@@ -391,7 +395,7 @@ func (t *testTensor) Floats() []float32 {
|
||||
}
|
||||
|
||||
func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
out := ctx.Zeros(t.DType(), t.Shape()...).(*testTensor)
|
||||
out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
|
||||
|
||||
for i := range out.data {
|
||||
out.data[i] = t.data[i] + t2.(*testTensor).data[i]
|
||||
@@ -468,7 +472,7 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
||||
|
||||
context := &testContext{}
|
||||
|
||||
view := context.Zeros(t.dtype, s...).(*testTensor)
|
||||
view := context.Empty(t.dtype, s...).(*testTensor)
|
||||
view.data = t.data[offset : offset+len(view.data)]
|
||||
|
||||
return view
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package kvcache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
@@ -11,6 +13,9 @@ import (
|
||||
//
|
||||
// Not currently safe for multiple sequences
|
||||
type EncoderCache struct {
|
||||
// config controls mostly backend-specific optimizations
|
||||
config *ml.CacheConfig
|
||||
|
||||
// ** current forward pass **
|
||||
|
||||
// the active layer for Get and Put
|
||||
@@ -40,9 +45,29 @@ func NewEncoderCache() *EncoderCache {
|
||||
}
|
||||
|
||||
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
|
||||
if c.config == nil {
|
||||
var config ml.CacheConfig
|
||||
if cc, ok := backend.(ml.BackendCacheConfig); ok {
|
||||
config = cc.CacheConfig()
|
||||
}
|
||||
c.config = &config
|
||||
}
|
||||
|
||||
if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
|
||||
panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
|
||||
}
|
||||
|
||||
c.cacheCtx = backend.NewContext()
|
||||
}
|
||||
|
||||
func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
|
||||
if c.config != nil {
|
||||
panic("config cannot be changed after being previously set, either by the model or backend")
|
||||
}
|
||||
|
||||
c.config = &config
|
||||
}
|
||||
|
||||
func (c *EncoderCache) Close() {
|
||||
c.cacheCtx.Close()
|
||||
}
|
||||
@@ -75,9 +100,13 @@ func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
c.encoderPos = c.curPos
|
||||
c.encoderCached = true
|
||||
|
||||
if c.config.PermutedV {
|
||||
value = value.Permute(ctx, 1, 2, 0, 3)
|
||||
}
|
||||
|
||||
if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
|
||||
c.keys[c.curLayer] = c.cacheCtx.Zeros(key.DType(), key.Shape()...)
|
||||
c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...)
|
||||
c.keys[c.curLayer] = c.cacheCtx.Empty(key.DType(), key.Shape()...)
|
||||
c.values[c.curLayer] = c.cacheCtx.Empty(value.DType(), value.Shape()...)
|
||||
}
|
||||
|
||||
ctx.Forward(
|
||||
|
||||
@@ -28,6 +28,12 @@ func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
|
||||
for _, cache := range c.caches {
|
||||
cache.SetConfig(config)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WrapperCache) Close() {
|
||||
for _, cache := range c.caches {
|
||||
cache.Close()
|
||||
|
||||
@@ -27,6 +27,35 @@ type Backend interface {
|
||||
SystemInfo() string
|
||||
}
|
||||
|
||||
// BackendCacheConfig should be implemented by backends that need special output
|
||||
// from the cache to meet specific requirements. It is frequently implemented in
|
||||
// conjunction with ScaledDotProductAttention.
|
||||
type BackendCacheConfig interface {
|
||||
CacheConfig() CacheConfig
|
||||
}
|
||||
|
||||
// CacheConfig controls optimizations (mostly backend-specific) that may transform
|
||||
// the output the cache to work better with specific kernels.
|
||||
type CacheConfig struct {
|
||||
// CachePadding specifies the multiple for the number of tokens of cache history
|
||||
// that will be returned from cache Get for k, v and mask. The capacity of the
|
||||
// cache itself will also be increased to a multiple of this size if needed.
|
||||
CachePadding int
|
||||
|
||||
// PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
|
||||
// and return the permuted version via Get. This uses the cache copy operation
|
||||
// to avoid a Contiguous call on the permuted tensor.
|
||||
PermutedV bool
|
||||
|
||||
// MaskDType specifies the data type for generating the mask. If unset it will
|
||||
// default to DTypeF32.
|
||||
MaskDType DType
|
||||
|
||||
// MaskBatchPadding specifies the multiple for the batch size dimension in the mask.
|
||||
// Any position that does not correspond to an actual token will be filled with -Inf.
|
||||
MaskBatchPadding int
|
||||
}
|
||||
|
||||
// BackendParams controls how the backend loads and executes models
|
||||
type BackendParams struct {
|
||||
// NumThreads sets the number of threads to use if running on the CPU
|
||||
@@ -40,6 +69,9 @@ type BackendParams struct {
|
||||
|
||||
// TensorSplit is the fraction of the model to offload to each GPU
|
||||
TensorSplit []float32
|
||||
|
||||
// FlashAttention indicates that we should use a fused flash attention kernel
|
||||
FlashAttention bool
|
||||
}
|
||||
|
||||
var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))
|
||||
@@ -61,6 +93,7 @@ func NewBackend(f *os.File, params BackendParams) (Backend, error) {
|
||||
}
|
||||
|
||||
type Context interface {
|
||||
Empty(dtype DType, shape ...int) Tensor
|
||||
Zeros(dtype DType, shape ...int) Tensor
|
||||
FromFloatSlice(s []float32, shape ...int) (Tensor, error)
|
||||
FromIntSlice(s []int32, shape ...int) (Tensor, error)
|
||||
@@ -116,6 +149,10 @@ type Tensor interface {
|
||||
// operation equivalent to following code on a tensor named
|
||||
// query:
|
||||
//
|
||||
// query = query.Permute(ctx, 0, 2, 1, 3)
|
||||
// key = key.Permute(ctx, 0, 2, 1, 3)
|
||||
// value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
//
|
||||
// kq := key.MulmatFullPrec(ctx, query)
|
||||
//
|
||||
// kq = kq.Scale(ctx, scale)
|
||||
@@ -170,7 +207,7 @@ func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
|
||||
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
|
||||
})
|
||||
case DTypeF16:
|
||||
f32 := ctx.Zeros(DTypeF32, t.Shape()...)
|
||||
f32 := ctx.Empty(DTypeF32, t.Shape()...)
|
||||
f32 = t.Copy(ctx, f32)
|
||||
return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
|
||||
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
|
||||
|
||||
@@ -79,6 +79,8 @@ var devices = sync.OnceValue(func() []device {
|
||||
})
|
||||
|
||||
type Backend struct {
|
||||
flashAttention bool
|
||||
|
||||
meta *fs.GGML
|
||||
cpus, gpus []Context
|
||||
tensors map[string]*Context
|
||||
@@ -192,9 +194,10 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
||||
}
|
||||
|
||||
return &Backend{
|
||||
meta: meta,
|
||||
cpus: cpus,
|
||||
gpus: gpus,
|
||||
flashAttention: params.FlashAttention,
|
||||
meta: meta,
|
||||
cpus: cpus,
|
||||
gpus: gpus,
|
||||
sched: C.ggml_backend_sched_new(
|
||||
(*C.ggml_backend_t)(unsafe.Pointer(&backends[0])),
|
||||
(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])),
|
||||
@@ -219,7 +222,7 @@ func (b *Backend) Get(name string) ml.Tensor {
|
||||
|
||||
for _, c := range append(b.gpus, b.cpus...) {
|
||||
if t := C.ggml_get_tensor(c.ctx, cname); t != nil {
|
||||
return &Tensor{t: t}
|
||||
return &Tensor{b: b, t: t}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -247,6 +250,14 @@ func (b *Backend) NewContext() ml.Context {
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Backend) CacheConfig() ml.CacheConfig {
|
||||
if b.flashAttention {
|
||||
return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
|
||||
} else {
|
||||
return ml.CacheConfig{CachePadding: 32, PermutedV: true}
|
||||
}
|
||||
}
|
||||
|
||||
type Context struct {
|
||||
b *Backend
|
||||
ctx *C.struct_ggml_context
|
||||
@@ -300,7 +311,7 @@ func shapeToGGML(shape []int) *C.int64_t {
|
||||
return &sh[0]
|
||||
}
|
||||
|
||||
func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
func newTensor(ctx Context, dtype ml.DType, zero bool, shape []int) ml.Tensor {
|
||||
if len(shape) < 1 || len(shape) > 4 {
|
||||
panic("unsupported number of dimensions")
|
||||
}
|
||||
@@ -314,19 +325,29 @@ func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
var t *C.struct_ggml_tensor
|
||||
switch dtype {
|
||||
case ml.DTypeF32:
|
||||
t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape))
|
||||
t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape))
|
||||
case ml.DTypeF16:
|
||||
t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape))
|
||||
t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape))
|
||||
case ml.DTypeI32:
|
||||
t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape))
|
||||
t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape))
|
||||
default:
|
||||
panic("unsupported dtype")
|
||||
}
|
||||
|
||||
b := C.ggml_backend_alloc_buffer(c.backend, C.ggml_nbytes(t))
|
||||
b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
|
||||
C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
|
||||
C.ggml_set_zero(t)
|
||||
return &Tensor{t: t}
|
||||
if zero {
|
||||
C.ggml_set_zero(t)
|
||||
}
|
||||
return &Tensor{b: ctx.b, t: t}
|
||||
}
|
||||
|
||||
func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
return newTensor(c, dtype, false, shape)
|
||||
}
|
||||
|
||||
func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
return newTensor(c, dtype, true, shape)
|
||||
}
|
||||
|
||||
func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype uint32) (ml.Tensor, error) {
|
||||
@@ -335,7 +356,7 @@ func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype u
|
||||
if n == 0 {
|
||||
var shape C.int64_t = 0
|
||||
t := C.ggml_new_tensor(ctx.ctx, dtype, 1, &shape)
|
||||
return &Tensor{t: t}, nil
|
||||
return &Tensor{b: ctx.b, t: t}, nil
|
||||
}
|
||||
|
||||
for _, v := range shape {
|
||||
@@ -350,7 +371,7 @@ func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype u
|
||||
b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
|
||||
C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
|
||||
C.ggml_backend_tensor_set(t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t))
|
||||
return &Tensor{t: t}, nil
|
||||
return &Tensor{b: ctx.b, t: t}, nil
|
||||
}
|
||||
|
||||
func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
|
||||
@@ -368,6 +389,7 @@ func (c *Context) Close() {
|
||||
}
|
||||
|
||||
type Tensor struct {
|
||||
b *Backend
|
||||
t *C.struct_ggml_tensor
|
||||
sync func()
|
||||
}
|
||||
@@ -434,6 +456,7 @@ func (t *Tensor) DType() ml.DType {
|
||||
|
||||
func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_add(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
|
||||
}
|
||||
}
|
||||
@@ -448,24 +471,28 @@ func (t *Tensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
|
||||
|
||||
func (t *Tensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_concat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(dim)),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Contiguous(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_cont(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_mul(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
|
||||
}
|
||||
}
|
||||
@@ -475,12 +502,13 @@ func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32)
|
||||
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: mul,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
|
||||
tt := (&Tensor{t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
|
||||
tt := (&Tensor{b: t.b, t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
|
||||
if b != nil {
|
||||
tt = tt.Add(ctx, b)
|
||||
}
|
||||
@@ -489,7 +517,7 @@ func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tenso
|
||||
}
|
||||
|
||||
func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
|
||||
return (&Tensor{t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
|
||||
return (&Tensor{b: t.b, t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
|
||||
}
|
||||
|
||||
func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
@@ -498,6 +526,7 @@ func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
}
|
||||
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
|
||||
}
|
||||
}
|
||||
@@ -508,18 +537,21 @@ func (t *Tensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
}
|
||||
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_permute(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_get_rows(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_cpy(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
|
||||
}
|
||||
}
|
||||
@@ -528,18 +560,22 @@ func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
switch len(shape) {
|
||||
case 1:
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_reshape_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])),
|
||||
}
|
||||
case 2:
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_reshape_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])),
|
||||
}
|
||||
case 3:
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_reshape_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])),
|
||||
}
|
||||
case 4:
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_reshape_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])),
|
||||
}
|
||||
default:
|
||||
@@ -549,18 +585,21 @@ func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
|
||||
func (t *Tensor) Scale(ctx ml.Context, s float64) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_scale(ctx.(*Context).ctx, t.t, (C.float)(s)),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_soft_max(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_tanh_inplace(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
@@ -571,6 +610,7 @@ func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
}
|
||||
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_unpad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
|
||||
}
|
||||
}
|
||||
@@ -579,10 +619,12 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
||||
switch len(shape) {
|
||||
case 1:
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_view_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.size_t(offset)),
|
||||
}
|
||||
case 3:
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_view_2d(ctx.(*Context).ctx, t.t,
|
||||
C.int64_t(shape[0]), C.int64_t(shape[2]),
|
||||
C.size_t(shape[1]),
|
||||
@@ -590,6 +632,7 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
||||
}
|
||||
case 5:
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_view_3d(ctx.(*Context).ctx, t.t,
|
||||
C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]),
|
||||
C.size_t(shape[1]), C.size_t(shape[3]),
|
||||
@@ -597,6 +640,7 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
||||
}
|
||||
case 7:
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_view_4d(ctx.(*Context).ctx, t.t,
|
||||
C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]), C.int64_t(shape[6]),
|
||||
C.size_t(shape[1]), C.size_t(shape[3]), C.size_t(shape[5]),
|
||||
@@ -613,7 +657,7 @@ const (
|
||||
|
||||
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, ropeBase, ropeScale float32) ml.Tensor {
|
||||
if ropeFactors == nil {
|
||||
ropeFactors = &Tensor{}
|
||||
ropeFactors = &Tensor{b: t.b}
|
||||
}
|
||||
|
||||
dequant := t.t
|
||||
@@ -622,6 +666,7 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
|
||||
}
|
||||
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_rope_ext(
|
||||
ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
|
||||
C.int(ropeDim),
|
||||
@@ -639,18 +684,21 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
|
||||
|
||||
func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_conv_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1)),
|
||||
}
|
||||
}
|
||||
@@ -661,13 +709,25 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.T
|
||||
kqMask = mask.(*Tensor).t
|
||||
}
|
||||
|
||||
kq := key.MulmatFullPrec(ctx, t)
|
||||
kq = &Tensor{
|
||||
t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
|
||||
}
|
||||
query := t.Permute(ctx, 0, 2, 1, 3)
|
||||
key = key.Permute(ctx, 0, 2, 1, 3)
|
||||
|
||||
kqv := value.Mulmat(ctx, kq)
|
||||
return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
if t.b.flashAttention {
|
||||
value = value.Permute(ctx, 0, 2, 1, 3)
|
||||
|
||||
kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
|
||||
C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)
|
||||
return &Tensor{b: t.b, t: kqv}
|
||||
} else {
|
||||
kq := key.MulmatFullPrec(ctx, query)
|
||||
kq = &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
|
||||
}
|
||||
|
||||
kqv := value.Mulmat(ctx, kq)
|
||||
return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Backend) SystemInfo() string {
|
||||
|
||||
@@ -3,6 +3,7 @@ package nn
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
@@ -11,40 +12,50 @@ import (
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for tensor operations
|
||||
// - query: Query tensor (Q) with shape [d_k, seq_len_q, heads]
|
||||
// - key: Key tensor (K) with shape [d_k, seq_len_k, kv_heads]
|
||||
// - value: Value tensor (V) with shape [seq_len_k, d_v, kv_heads]
|
||||
// - mask: Optional attention mask that is added to the attention score. If
|
||||
// provided, should broadcast to [seq_len_k, seq_len_q, heads]
|
||||
// - query: Query tensor (Q) with shape [d_k, heads, seq_len_q]
|
||||
// - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only
|
||||
// - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only
|
||||
// - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
|
||||
// - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value
|
||||
//
|
||||
// Returns:
|
||||
//
|
||||
// Attention output with shape [d_v, heads, seq_len_q]
|
||||
func Attention(ctx ml.Context, query, key, value, mask ml.Tensor, scale float64) ml.Tensor {
|
||||
if query.Dim(0) != key.Dim(0) {
|
||||
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
|
||||
func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
|
||||
if key != nil && value != nil {
|
||||
if query.Dim(0) != key.Dim(0) {
|
||||
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
|
||||
}
|
||||
|
||||
if key.Dim(1) != value.Dim(1) {
|
||||
panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)))
|
||||
}
|
||||
|
||||
if key.Dim(2) != value.Dim(2) {
|
||||
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
|
||||
}
|
||||
|
||||
if cache != nil {
|
||||
cache.Put(ctx, key, value)
|
||||
}
|
||||
} else if cache == nil {
|
||||
panic("key & value tensors must be provided if cache is nil")
|
||||
}
|
||||
|
||||
if mask != nil && query.Dim(1) != mask.Dim(1) {
|
||||
panic(fmt.Errorf("seq_len_q in attention operation does not match between query(%v) and mask(%v)", query.Dim(1), mask.Dim(1)))
|
||||
var mask ml.Tensor
|
||||
if cache != nil {
|
||||
key, value, mask = cache.Get(ctx)
|
||||
}
|
||||
|
||||
if key.Dim(1) != value.Dim(0) {
|
||||
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(0)))
|
||||
}
|
||||
|
||||
if mask != nil && key.Dim(1) != mask.Dim(0) {
|
||||
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and mask(%v)", key.Dim(1), mask.Dim(0)))
|
||||
}
|
||||
|
||||
if key.Dim(2) != value.Dim(2) {
|
||||
panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
|
||||
}
|
||||
|
||||
if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
|
||||
// Only use the fast SDPA implementation if we have a cache, since that's what
|
||||
// will do any expected backend-specific transformations for us
|
||||
if sdpa, ok := query.(ml.ScaledDotProductAttention); ok && cache != nil {
|
||||
return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
|
||||
} else {
|
||||
query = query.Permute(ctx, 0, 2, 1, 3)
|
||||
key = key.Permute(ctx, 0, 2, 1, 3)
|
||||
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
|
||||
kq := key.MulmatFullPrec(ctx, query)
|
||||
|
||||
kq = kq.Scale(ctx, scale)
|
||||
|
||||
@@ -81,15 +81,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
|
||||
cache.Put(ctx, k, v)
|
||||
k, v, mask := cache.Get(ctx)
|
||||
|
||||
q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
|
||||
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
|
||||
kqv := nn.Attention(ctx, q, k, v, mask, scaleFactor)
|
||||
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
|
||||
kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||
|
||||
return sa.Output.Forward(ctx, kqv)
|
||||
|
||||
@@ -43,7 +43,9 @@ func New(c ml.Config) (model.Model, error) {
|
||||
TextModel: newTextModel(c),
|
||||
}
|
||||
|
||||
m.Cache = kvcache.NewWrapperCache(kvcache.NewEncoderCache(), kvcache.NewCausalCache(m.TextModel.Shift))
|
||||
encoderCache := kvcache.NewEncoderCache()
|
||||
encoderCache.SetConfig(ml.CacheConfig{})
|
||||
m.Cache = kvcache.NewWrapperCache(encoderCache, kvcache.NewCausalCache(m.TextModel.Shift))
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
@@ -31,22 +31,15 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
|
||||
value := sa.Value.Forward(ctx, hiddenState)
|
||||
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
|
||||
cache.Put(ctx, key, value)
|
||||
key, value, mask := cache.Get(ctx)
|
||||
|
||||
query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
|
||||
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
|
||||
attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
|
||||
attention := nn.Attention(ctx, query, key, value, scaleFactor, cache)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
// This will only get called for layers in the cache, which are just the self attention layers
|
||||
// This will only get called for layers in the causal cache, which are just the self attention layers
|
||||
return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
|
||||
}
|
||||
|
||||
@@ -107,7 +100,7 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
|
||||
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||
query = ca.QueryNorm.Forward(ctx, query, opts.eps)
|
||||
|
||||
var key, value, mask ml.Tensor
|
||||
var key, value ml.Tensor
|
||||
if crossAttentionStates != nil {
|
||||
numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
|
||||
|
||||
@@ -119,16 +112,23 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
|
||||
value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
|
||||
|
||||
cache.Put(ctx, key, value)
|
||||
} else {
|
||||
key, value, mask = cache.Get(ctx)
|
||||
}
|
||||
|
||||
query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
key, value, _ = cache.Get(ctx)
|
||||
|
||||
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
|
||||
attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
|
||||
|
||||
query = query.Permute(ctx, 0, 2, 1, 3)
|
||||
key = key.Permute(ctx, 0, 2, 1, 3)
|
||||
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
|
||||
kq := key.MulmatFullPrec(ctx, query)
|
||||
|
||||
kq = kq.Scale(ctx, scaleFactor)
|
||||
kq = kq.Softmax(ctx)
|
||||
|
||||
kqv := value.Mulmat(ctx, kq)
|
||||
attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||
|
||||
return ca.Output.Forward(ctx, attention)
|
||||
|
||||
@@ -818,7 +818,7 @@ func Execute(args []string) error {
|
||||
batchSize := fs.Int("batch-size", 512, "Batch size")
|
||||
numGPULayers := fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
|
||||
mainGPU := fs.Int("main-gpu", 0, "Main GPU")
|
||||
_ = fs.Bool("flash-attn", false, "Enable flash attention")
|
||||
flashAttention := fs.Bool("flash-attn", false, "Enable flash attention")
|
||||
kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size")
|
||||
kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)")
|
||||
port := fs.Int("port", 8080, "Port to expose the server on")
|
||||
@@ -863,7 +863,6 @@ func Execute(args []string) error {
|
||||
}
|
||||
|
||||
// TODO(jessegross): Parameters that need to be implemented:
|
||||
// flash-attn
|
||||
// no-mmap
|
||||
// mlock
|
||||
|
||||
@@ -878,10 +877,11 @@ func Execute(args []string) error {
|
||||
}
|
||||
|
||||
params := ml.BackendParams{
|
||||
NumThreads: *threads,
|
||||
NumGPULayers: *numGPULayers,
|
||||
MainGPU: *mainGPU,
|
||||
TensorSplit: tensorSplitFloats,
|
||||
NumThreads: *threads,
|
||||
NumGPULayers: *numGPULayers,
|
||||
MainGPU: *mainGPU,
|
||||
TensorSplit: tensorSplitFloats,
|
||||
FlashAttention: *flashAttention,
|
||||
}
|
||||
|
||||
server.ready.Add(1)
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
@@ -53,7 +54,7 @@ var (
|
||||
|
||||
// ErrMissingModel is returned when the model part of a name is missing
|
||||
// or invalid.
|
||||
ErrNameInvalid = errors.New("invalid name; must be in the form {scheme://}{host/}{namespace/}[model]{:tag}{@digest}")
|
||||
ErrNameInvalid = errors.New("invalid or missing name")
|
||||
|
||||
// ErrCached is passed to [Trace.PushUpdate] when a layer already
|
||||
// exists. It is a non-fatal error and is never returned by [Registry.Push].
|
||||
@@ -205,10 +206,18 @@ type Registry struct {
|
||||
// It is only used when a layer is larger than [MaxChunkingThreshold].
|
||||
MaxChunkSize int64
|
||||
|
||||
// NameMask, if set, is the name used to convert non-fully qualified
|
||||
// Mask, if set, is the name used to convert non-fully qualified
|
||||
// names to fully qualified names. If empty, the default mask
|
||||
// ("registry.ollama.ai/library/_:latest") is used.
|
||||
NameMask string
|
||||
Mask string
|
||||
}
|
||||
|
||||
func (r *Registry) completeName(name string) names.Name {
|
||||
mask := defaultMask
|
||||
if r.Mask != "" {
|
||||
mask = names.Parse(r.Mask)
|
||||
}
|
||||
return names.Merge(names.Parse(name), mask)
|
||||
}
|
||||
|
||||
// DefaultRegistry returns a new Registry configured from the environment. The
|
||||
@@ -243,52 +252,6 @@ func DefaultRegistry() (*Registry, error) {
|
||||
return &rc, nil
|
||||
}
|
||||
|
||||
type PushParams struct {
|
||||
// From is an optional destination name for the model. If empty, the
|
||||
// destination name is the same as the source name.
|
||||
From string
|
||||
}
|
||||
|
||||
// parseName parses name using [names.ParseExtended] and then merges the name with the
|
||||
// default name, and checks that the name is fully qualified. If a digest is
|
||||
// present, it parse and returns it with the other fields as their zero values.
|
||||
//
|
||||
// It returns an error if the name is not fully qualified, or if the digest, if
|
||||
// any, is invalid.
|
||||
//
|
||||
// The scheme is returned as provided by [names.ParseExtended].
|
||||
func parseName(s, mask string) (scheme string, n names.Name, d blob.Digest, err error) {
|
||||
maskName := defaultMask
|
||||
if mask != "" {
|
||||
maskName = names.Parse(mask)
|
||||
if !maskName.IsFullyQualified() {
|
||||
return "", names.Name{}, blob.Digest{}, fmt.Errorf("invalid name mask: %s", mask)
|
||||
}
|
||||
}
|
||||
scheme, n, ds := names.ParseExtended(s)
|
||||
if !n.IsValid() {
|
||||
return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s)
|
||||
}
|
||||
n = names.Merge(n, maskName)
|
||||
if ds != "" {
|
||||
// Digest is present. Validate it.
|
||||
d, err = blob.ParseDigest(ds)
|
||||
if err != nil {
|
||||
return "", names.Name{}, blob.Digest{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// The name check is deferred until after the digest check because we
|
||||
// say that digests take precedence over names, and so should there
|
||||
// errors when being parsed.
|
||||
if !n.IsFullyQualified() {
|
||||
return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s)
|
||||
}
|
||||
|
||||
scheme = cmp.Or(scheme, "https")
|
||||
return scheme, n, d, nil
|
||||
}
|
||||
|
||||
func (r *Registry) maxStreams() int {
|
||||
n := cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0))
|
||||
|
||||
@@ -308,6 +271,12 @@ func (r *Registry) maxChunkSize() int64 {
|
||||
return cmp.Or(r.MaxChunkSize, DefaultMaxChunkSize)
|
||||
}
|
||||
|
||||
type PushParams struct {
|
||||
// From is an optional destination name for the model. If empty, the
|
||||
// destination name is the same as the source name.
|
||||
From string
|
||||
}
|
||||
|
||||
// Push pushes the model with the name in the cache to the remote registry.
|
||||
func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *PushParams) error {
|
||||
if p == nil {
|
||||
@@ -337,7 +306,7 @@ func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *
|
||||
|
||||
t := traceFromContext(ctx)
|
||||
|
||||
scheme, n, _, err := parseName(name, r.NameMask)
|
||||
scheme, n, _, err := parseName(name, r.Mask)
|
||||
if err != nil {
|
||||
// This should never happen since ResolveLocal should have
|
||||
// already validated the name.
|
||||
@@ -431,7 +400,7 @@ func canRetry(err error) bool {
|
||||
// typically slower than splitting the model up across layers, and is mostly
|
||||
// utilized for layers of type equal to "application/vnd.ollama.image".
|
||||
func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error {
|
||||
scheme, n, _, err := parseName(name, r.NameMask)
|
||||
scheme, n, _, err := parseName(name, r.Mask)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -582,9 +551,9 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err
|
||||
// Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
|
||||
// before attempting to unlink the model.
|
||||
func (r *Registry) Unlink(c *blob.DiskCache, name string) (ok bool, _ error) {
|
||||
_, n, _, err := parseName(name, r.NameMask)
|
||||
if err != nil {
|
||||
return false, err
|
||||
n := r.completeName(name)
|
||||
if !n.IsFullyQualified() {
|
||||
return false, fmt.Errorf("%w: %q", ErrNameInvalid, name)
|
||||
}
|
||||
return c.Unlink(n.String())
|
||||
}
|
||||
@@ -658,9 +627,9 @@ type Layer struct {
|
||||
}
|
||||
|
||||
// ResolveLocal resolves a name to a Manifest in the local cache. The name is
|
||||
// parsed using [names.ParseExtended] but the scheme is ignored.
|
||||
// parsed using [names.Split] but the scheme is ignored.
|
||||
func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
|
||||
_, n, d, err := parseName(name, r.NameMask)
|
||||
_, n, d, err := parseName(name, r.Mask)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -686,7 +655,7 @@ func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, erro
|
||||
|
||||
// Resolve resolves a name to a Manifest in the remote registry.
|
||||
func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) {
|
||||
scheme, n, d, err := parseName(name, r.NameMask)
|
||||
scheme, n, d, err := parseName(name, r.Mask)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -869,3 +838,69 @@ func maybeUnexpectedEOF(err error) error {
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
type publicError struct {
|
||||
wrapped error
|
||||
message string
|
||||
}
|
||||
|
||||
func withPublicMessagef(err error, message string, args ...any) error {
|
||||
return publicError{wrapped: err, message: fmt.Sprintf(message, args...)}
|
||||
}
|
||||
|
||||
func (e publicError) Error() string { return e.message }
|
||||
func (e publicError) Unwrap() error { return e.wrapped }
|
||||
|
||||
var supportedSchemes = []string{
|
||||
"http",
|
||||
"https",
|
||||
"https+insecure",
|
||||
}
|
||||
|
||||
var supportedSchemesMessage = fmt.Sprintf("supported schemes are %v", strings.Join(supportedSchemes, ", "))
|
||||
|
||||
// parseName parses and validates an extended name, returning the scheme, name,
|
||||
// and digest.
|
||||
//
|
||||
// If the scheme is empty, scheme will be "https". If an unsupported scheme is
|
||||
// given, [ErrNameInvalid] wrapped with a display friendly message is returned.
|
||||
//
|
||||
// If the digest is invalid, [ErrNameInvalid] wrapped with a display friendly
|
||||
// message is returned.
|
||||
//
|
||||
// If the name is not, once merged with the mask, fully qualified,
|
||||
// [ErrNameInvalid] wrapped with a display friendly message is returned.
|
||||
func parseName(s string, mask string) (scheme string, _ names.Name, _ blob.Digest, _ error) {
|
||||
scheme, name, digest := names.Split(s)
|
||||
scheme = cmp.Or(scheme, "https")
|
||||
if !slices.Contains(supportedSchemes, scheme) {
|
||||
err := withPublicMessagef(ErrNameInvalid, "unsupported scheme: %q: %s", scheme, supportedSchemesMessage)
|
||||
return "", names.Name{}, blob.Digest{}, err
|
||||
}
|
||||
|
||||
var d blob.Digest
|
||||
if digest != "" {
|
||||
var err error
|
||||
d, err = blob.ParseDigest(digest)
|
||||
if err != nil {
|
||||
err = withPublicMessagef(ErrNameInvalid, "invalid digest: %q", digest)
|
||||
return "", names.Name{}, blob.Digest{}, err
|
||||
}
|
||||
if name == "" {
|
||||
// We have can resolve a manifest from a digest only,
|
||||
// so skip name validation and return the scheme and
|
||||
// digest.
|
||||
return scheme, names.Name{}, d, nil
|
||||
}
|
||||
}
|
||||
|
||||
maskName := defaultMask
|
||||
if mask != "" {
|
||||
maskName = names.Parse(mask)
|
||||
}
|
||||
n := names.Merge(names.Parse(name), maskName)
|
||||
if !n.IsFullyQualified() {
|
||||
return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s)
|
||||
}
|
||||
return scheme, n, d, nil
|
||||
}
|
||||
|
||||
@@ -84,14 +84,14 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
|
||||
}
|
||||
}
|
||||
|
||||
rc := &Registry{
|
||||
r := &Registry{
|
||||
HTTPClient: &http.Client{
|
||||
Transport: recordRoundTripper(h),
|
||||
},
|
||||
}
|
||||
|
||||
link := func(name string, manifest string) {
|
||||
_, n, _, err := parseName(name, rc.NameMask)
|
||||
_, n, _, err := parseName(name, r.Mask)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@@ -122,7 +122,7 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
|
||||
commit("sizemismatch", mklayer("exists"), &Layer{Digest: blob.DigestFromBytes("present"), Size: 499})
|
||||
link("invalid", "!!!!!")
|
||||
|
||||
return rc, c
|
||||
return r, c
|
||||
}
|
||||
|
||||
func okHandler(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -145,29 +145,6 @@ func importBytes(t *testing.T, c *blob.DiskCache, data string) blob.Digest {
|
||||
return d
|
||||
}
|
||||
|
||||
func TestRegistryPushInvalidNames(t *testing.T) {
|
||||
rc, c := newClient(t, nil)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
err error
|
||||
}{
|
||||
{"", ErrNameInvalid},
|
||||
{"@", ErrNameInvalid},
|
||||
{"@x", blob.ErrInvalidDigest},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create a new registry and push a new image.
|
||||
err := rc.Push(t.Context(), c, tt.name, nil)
|
||||
if !errors.Is(err, tt.err) {
|
||||
t.Errorf("err = %v; want %v", err, tt.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) {
|
||||
t := &Trace{Update: func(*Layer, int64, error) { panic("unexpected") }}
|
||||
return WithTrace(ctx, t), t
|
||||
@@ -622,7 +599,7 @@ func TestInsecureSkipVerify(t *testing.T) {
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
const name = "ollama.com/library/insecure"
|
||||
const name = "library/insecure"
|
||||
|
||||
var rc Registry
|
||||
url := fmt.Sprintf("https://%s/%s", s.Listener.Addr(), name)
|
||||
@@ -724,3 +701,38 @@ func TestErrorUnmarshal(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseNameErrors tests that parseName returns errors messages with enough
|
||||
// detail for users to debug naming issues they may encounter. Previous to this
|
||||
// test, the error messages were not very helpful and each problem was reported
|
||||
// as the same message.
|
||||
//
|
||||
// It is only for testing error messages, not that all invalids and valids are
|
||||
// covered. Those are in other tests for names.Name and blob.Digest.
|
||||
func TestParseNameErrors(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
err error
|
||||
want string
|
||||
}{
|
||||
{"x", nil, ""},
|
||||
{"x@", nil, ""},
|
||||
|
||||
{"", ErrNameInvalid, `invalid or missing name: ""`},
|
||||
{"://", ErrNameInvalid, `invalid or missing name: "://"`},
|
||||
{"x://", ErrNameInvalid, `unsupported scheme: "x": supported schemes are http, https, https+insecure`},
|
||||
|
||||
{"@sha123-1234", ErrNameInvalid, `invalid digest: "sha123-1234"`},
|
||||
{"x@sha123-1234", ErrNameInvalid, `invalid digest: "sha123-1234"`},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
_, _, _, err := parseName(tt.name, DefaultMask)
|
||||
if !errors.Is(err, tt.err) {
|
||||
t.Errorf("[%s]: err = %v; want %v", tt.name, err, tt.err)
|
||||
}
|
||||
if err != nil && !strings.Contains(err.Error(), tt.want) {
|
||||
t.Errorf("[%s]: err =\n\t%v\nwant\n\t%v", tt.name, err, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"github.com/ollama/ollama/server/internal/internal/stringsx"
|
||||
)
|
||||
|
||||
const MaxNameLength = 50 + 1 + 50 + 1 + 50 // <namespace>/<model>:<tag>
|
||||
const MaxNameLength = 350 + 1 + 80 + 1 + 80 + 1 + 80 // <host>/<namespace>/<model>:<tag>
|
||||
|
||||
type Name struct {
|
||||
// Make incomparable to enfoce use of Compare / Equal for
|
||||
@@ -25,19 +25,12 @@ type Name struct {
|
||||
// format of a valid name string is:
|
||||
//
|
||||
// s:
|
||||
// { host } "/" { namespace } "/" { model } ":" { tag } "@" { digest }
|
||||
// { host } "/" { namespace } "/" { model } ":" { tag }
|
||||
// { host } "/" { namespace } "/" { model } "@" { digest }
|
||||
// { host } "/" { namespace } "/" { model }
|
||||
// { namespace } "/" { model } ":" { tag } "@" { digest }
|
||||
// { namespace } "/" { model } ":" { tag }
|
||||
// { namespace } "/" { model } "@" { digest }
|
||||
// { namespace } "/" { model }
|
||||
// { model } ":" { tag } "@" { digest }
|
||||
// { model } ":" { tag }
|
||||
// { model } "@" { digest }
|
||||
// { model }
|
||||
// "@" { digest }
|
||||
// host:
|
||||
// pattern: { alphanum | "_" } { alphanum | "_" | "-" | "." | ":" }*
|
||||
// length: [1, 350]
|
||||
@@ -50,9 +43,6 @@ type Name struct {
|
||||
// tag:
|
||||
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
|
||||
// length: [1, 80]
|
||||
// digest:
|
||||
// pattern: { alphanum | "_" } { alphanum | "-" | ":" }*
|
||||
// length: [1, 80]
|
||||
//
|
||||
// The name returned is not guaranteed to be valid. If it is not valid, the
|
||||
// field values are left in an undefined state. Use [Name.IsValid] to check
|
||||
@@ -82,23 +72,17 @@ func Parse(s string) Name {
|
||||
}
|
||||
}
|
||||
|
||||
// ParseExtended parses and returns any scheme, Name, and digest from from s in
|
||||
// the the form [scheme://][name][@digest]. All parts are optional.
|
||||
//
|
||||
// If the scheme is present, it must be followed by "://". The digest is
|
||||
// prefixed by "@" and comes after the name. The name is parsed using [Parse].
|
||||
//
|
||||
// The scheme and digest are stripped before the name is parsed by [Parse].
|
||||
//
|
||||
// For convience, the scheme is never empty. If the scheme is not present, the
|
||||
// returned scheme is "https".
|
||||
// Split splits an extended name string into its scheme, name, and digest
|
||||
// parts.
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// http://ollama.com/bmizerany/smol:latest@digest
|
||||
// https://ollama.com/bmizerany/smol:latest
|
||||
// ollama.com/bmizerany/smol:latest@digest // returns "https" scheme.
|
||||
func ParseExtended(s string) (scheme string, _ Name, digest string) {
|
||||
// model@digest
|
||||
// @digest
|
||||
func Split(s string) (scheme, name, digest string) {
|
||||
i := strings.Index(s, "://")
|
||||
if i >= 0 {
|
||||
scheme = s[:i]
|
||||
@@ -109,21 +93,7 @@ func ParseExtended(s string) (scheme string, _ Name, digest string) {
|
||||
digest = s[i+1:]
|
||||
s = s[:i]
|
||||
}
|
||||
return scheme, Parse(s), digest
|
||||
}
|
||||
|
||||
func FormatExtended(scheme string, n Name, digest string) string {
|
||||
var b strings.Builder
|
||||
if scheme != "" {
|
||||
b.WriteString(scheme)
|
||||
b.WriteString("://")
|
||||
}
|
||||
b.WriteString(n.String())
|
||||
if digest != "" {
|
||||
b.WriteByte('@')
|
||||
b.WriteString(digest)
|
||||
}
|
||||
return b.String()
|
||||
return scheme, s, digest
|
||||
}
|
||||
|
||||
// Merge merges two names into a single name. Non-empty host, namespace, and
|
||||
@@ -141,39 +111,68 @@ func Merge(a, b Name) Name {
|
||||
|
||||
// IsValid returns true if the name is valid.
|
||||
func (n Name) IsValid() bool {
|
||||
if n.h != "" && !isValidHost(n.h) {
|
||||
if n.h != "" && !isValidPart(partHost, n.h) {
|
||||
return false
|
||||
}
|
||||
if n.n != "" && !isValidNamespace(n.n) {
|
||||
if n.n != "" && !isValidPart(partNamespace, n.n) {
|
||||
return false
|
||||
}
|
||||
if n.m != "" && !isValidModel(n.m) {
|
||||
if n.t != "" && !isValidPart(partTag, n.t) {
|
||||
return false
|
||||
}
|
||||
if n.t != "" && !isValidTag(n.t) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
|
||||
// at bare minimum, model must be present and valid
|
||||
return n.m != "" && isValidPart(partModel, n.m)
|
||||
}
|
||||
|
||||
func (n Name) IsFullyQualified() bool {
|
||||
return n.IsValid() && n.h != "" && n.n != "" && n.m != "" && n.t != ""
|
||||
}
|
||||
|
||||
func isValidHost(_ string) bool {
|
||||
return true // TODO: implement
|
||||
const (
|
||||
partHost = iota
|
||||
partNamespace
|
||||
partModel
|
||||
partTag
|
||||
)
|
||||
|
||||
func isValidPart(kind int, s string) bool {
|
||||
maxlen := 80
|
||||
if kind == partHost {
|
||||
maxlen = 350
|
||||
}
|
||||
if len(s) > maxlen {
|
||||
return false
|
||||
}
|
||||
|
||||
for i := range s {
|
||||
if i == 0 {
|
||||
if !isAlphanumericOrUnderscore(s[i]) {
|
||||
return false
|
||||
}
|
||||
continue
|
||||
}
|
||||
switch s[i] {
|
||||
case '_', '-':
|
||||
case '.':
|
||||
if kind == partNamespace {
|
||||
return false
|
||||
}
|
||||
case ':':
|
||||
if kind != partHost {
|
||||
return false
|
||||
}
|
||||
default:
|
||||
if !isAlphanumericOrUnderscore(s[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func isValidNamespace(_ string) bool {
|
||||
return true // TODO: implement
|
||||
}
|
||||
|
||||
func isValidModel(_ string) bool {
|
||||
return true // TODO: implement
|
||||
}
|
||||
|
||||
func isValidTag(_ string) bool {
|
||||
return true // TODO: implement
|
||||
func isAlphanumericOrUnderscore(c byte) bool {
|
||||
return c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' || c >= '0' && c <= '9' || c == '_'
|
||||
}
|
||||
|
||||
func (n Name) Host() string { return n.h }
|
||||
|
||||
@@ -81,15 +81,11 @@ func TestParseExtended(t *testing.T) {
|
||||
}
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.in, func(t *testing.T) {
|
||||
scheme, name, digest := ParseExtended(tt.in)
|
||||
if scheme != tt.wantScheme || name.Compare(tt.wantName) != 0 || digest != tt.wantDigest {
|
||||
scheme, name, digest := Split(tt.in)
|
||||
n := Parse(name)
|
||||
if scheme != tt.wantScheme || n.Compare(tt.wantName) != 0 || digest != tt.wantDigest {
|
||||
t.Errorf("ParseExtended(%q) = %q, %#v, %q, want %q, %#v, %q", tt.in, scheme, name, digest, tt.wantScheme, tt.wantName, tt.wantDigest)
|
||||
}
|
||||
|
||||
// Round trip
|
||||
if got := FormatExtended(scheme, name, digest); got != tt.in {
|
||||
t.Errorf("FormatExtended(%q, %q, %q) = %q", scheme, name, digest, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -150,3 +146,75 @@ func BenchmarkParseName(b *testing.B) {
|
||||
junkName = Parse("h/n/m:t")
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
part80 = "88888888888888888888888888888888888888888888888888888888888888888888888888888888"
|
||||
part350 = "33333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333"
|
||||
)
|
||||
|
||||
var testCases = map[string]bool{ // name -> valid
|
||||
"": false,
|
||||
|
||||
"_why/_the/_lucky:_stiff": true,
|
||||
|
||||
// minimal
|
||||
"h/n/m:t": true,
|
||||
|
||||
"host/namespace/model:tag": true,
|
||||
"host/namespace/model": true,
|
||||
"namespace/model": true,
|
||||
"model": true,
|
||||
|
||||
// long (but valid)
|
||||
part80 + "/" + part80 + "/" + part80 + ":" + part80: true,
|
||||
part350 + "/" + part80 + "/" + part80 + ":" + part80: true,
|
||||
|
||||
// too long
|
||||
part80 + "/" + part80 + "/" + part80 + ":" + part350: false,
|
||||
"x" + part350 + "/" + part80 + "/" + part80 + ":" + part80: false,
|
||||
|
||||
"h/nn/mm:t": true, // bare minimum part sizes
|
||||
|
||||
// unqualified
|
||||
"m": true,
|
||||
"n/m:": true,
|
||||
"h/n/m": true,
|
||||
"@t": false,
|
||||
"m@d": false,
|
||||
|
||||
// invalids
|
||||
"^": false,
|
||||
"mm:": true,
|
||||
"/nn/mm": true,
|
||||
"//": false, // empty model
|
||||
"//mm": true,
|
||||
"hh//": false, // empty model
|
||||
"//mm:@": false,
|
||||
"00@": false,
|
||||
"@": false,
|
||||
|
||||
// not starting with alphanum
|
||||
"-hh/nn/mm:tt": false,
|
||||
"hh/-nn/mm:tt": false,
|
||||
"hh/nn/-mm:tt": false,
|
||||
"hh/nn/mm:-tt": false,
|
||||
|
||||
// smells like a flag
|
||||
"-h": false,
|
||||
|
||||
// hosts
|
||||
"host:https/namespace/model:tag": true,
|
||||
|
||||
// colon in non-host part before tag
|
||||
"host/name:space/model:tag": false,
|
||||
}
|
||||
|
||||
func TestParseNameValidation(t *testing.T) {
|
||||
for s, valid := range testCases {
|
||||
got := Parse(s)
|
||||
if got.IsValid() != valid {
|
||||
t.Logf("got: %v", got)
|
||||
t.Errorf("Parse(%q).IsValid() = %v; want !%[2]v", s, got.IsValid())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -204,7 +204,7 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return &serverError{404, "manifest_not_found", "manifest not found"}
|
||||
return &serverError{404, "not_found", "model not found"}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -109,11 +109,8 @@ func TestServerDelete(t *testing.T) {
|
||||
got = s.send(t, "DELETE", "/api/delete", ``)
|
||||
checkErrorResponse(t, got, 400, "bad_request", "empty request body")
|
||||
|
||||
got = s.send(t, "DELETE", "/api/delete", `{"model": "!"}`)
|
||||
checkErrorResponse(t, got, 404, "manifest_not_found", "not found")
|
||||
|
||||
got = s.send(t, "DELETE", "/api/delete", `{"model": "://"}`)
|
||||
checkErrorResponse(t, got, 400, "bad_request", "invalid name")
|
||||
checkErrorResponse(t, got, 400, "bad_request", "invalid or missing name")
|
||||
|
||||
got = s.send(t, "DELETE", "/unknown_path", `{}`) // valid body
|
||||
checkErrorResponse(t, got, 404, "not_found", "not found")
|
||||
|
||||
Reference in New Issue
Block a user