mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-21 22:33:56 +00:00
Merge branch 'ollama:main' into main
This commit is contained in:
@@ -122,9 +122,11 @@ if(CMAKE_HIP_COMPILER)
|
|||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
|
||||||
|
|
||||||
if (WIN32)
|
if (WIN32)
|
||||||
target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY=1)
|
target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
target_compile_definitions(ggml-hip PRIVATE GGML_HIP_NO_VMM)
|
||||||
|
|
||||||
set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm)
|
set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm)
|
||||||
install(TARGETS ggml-hip
|
install(TARGETS ggml-hip
|
||||||
RUNTIME_DEPENDENCIES
|
RUNTIME_DEPENDENCIES
|
||||||
|
|||||||
@@ -29,6 +29,17 @@ type Cache interface {
|
|||||||
// cache implementation used.
|
// cache implementation used.
|
||||||
Put(ctx ml.Context, key, value ml.Tensor)
|
Put(ctx ml.Context, key, value ml.Tensor)
|
||||||
|
|
||||||
|
// SetConfig controls optimizations (mostly backend-specific) that may transform
|
||||||
|
// the output of the cache to work better with specific kernels. If not called,
|
||||||
|
// the backend settings will be used. This works well when calling Attention.
|
||||||
|
//
|
||||||
|
// The config can be overridden by models, especially if they require vanilla
|
||||||
|
// output when implementing their own version of attention. To do this, pass
|
||||||
|
// an empty ml.CacheConfig.
|
||||||
|
//
|
||||||
|
// Most models will not need to use this.
|
||||||
|
SetConfig(ml.CacheConfig)
|
||||||
|
|
||||||
// ** cache management **
|
// ** cache management **
|
||||||
|
|
||||||
// Init sets up runtime parameters
|
// Init sets up runtime parameters
|
||||||
|
|||||||
@@ -22,6 +22,9 @@ type Causal struct {
|
|||||||
Capacity int32
|
Capacity int32
|
||||||
windowSize int32
|
windowSize int32
|
||||||
|
|
||||||
|
// config controls mostly backend-specific optimizations
|
||||||
|
config *ml.CacheConfig
|
||||||
|
|
||||||
// ** current forward pass **
|
// ** current forward pass **
|
||||||
|
|
||||||
// the active layer for Get and Put
|
// the active layer for Get and Put
|
||||||
@@ -75,14 +78,42 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
|
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
|
||||||
|
if c.config == nil {
|
||||||
|
var config ml.CacheConfig
|
||||||
|
if cc, ok := backend.(ml.BackendCacheConfig); ok {
|
||||||
|
config = cc.CacheConfig()
|
||||||
|
}
|
||||||
|
c.config = &config
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.config.CachePadding == 0 {
|
||||||
|
c.config.CachePadding = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.config.MaskBatchPadding == 0 {
|
||||||
|
c.config.MaskBatchPadding = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.config.MaskDType == ml.DTypeOther {
|
||||||
|
c.config.MaskDType = ml.DTypeF32
|
||||||
|
}
|
||||||
|
|
||||||
c.DType = dtype
|
c.DType = dtype
|
||||||
c.Capacity = capacity
|
c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding))
|
||||||
c.cells = make([]cacheCell, capacity)
|
c.cells = make([]cacheCell, c.Capacity)
|
||||||
c.cellRanges = make(map[int]cellRange)
|
c.cellRanges = make(map[int]cellRange)
|
||||||
c.backend = backend
|
c.backend = backend
|
||||||
c.cacheCtx = backend.NewContext()
|
c.cacheCtx = backend.NewContext()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Causal) SetConfig(config ml.CacheConfig) {
|
||||||
|
if c.config != nil {
|
||||||
|
panic("config cannot be changed after being previously set, either by the model or backend")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.config = &config
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Causal) Close() {
|
func (c *Causal) Close() {
|
||||||
c.cacheCtx.Close()
|
c.cacheCtx.Close()
|
||||||
}
|
}
|
||||||
@@ -157,36 +188,91 @@ func (c *Causal) findStartLoc() (int, error) {
|
|||||||
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
|
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func roundDown(length, pad int) int {
|
||||||
|
return (length / pad) * pad
|
||||||
|
}
|
||||||
|
|
||||||
|
func roundUp(length, pad int) int {
|
||||||
|
return ((length + pad - 1) / pad) * pad
|
||||||
|
}
|
||||||
|
|
||||||
// Builds a mask of history x batch indicating whether for each token in the batch the
|
// Builds a mask of history x batch indicating whether for each token in the batch the
|
||||||
// token in the history should apply. This is based on both the sequence and causality (the
|
// token in the history should apply. This is based on both the sequence and causality (the
|
||||||
// position of the history is not ahead of the token in the batch).
|
// position of the history is not ahead of the token in the batch).
|
||||||
func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) {
|
func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) {
|
||||||
// TODO(jessegross): This does not do padding, which is required for flash attention
|
// Align and pad the two dimensions as required by the backend
|
||||||
len := c.curCellRange.max - c.curCellRange.min + 1
|
batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
|
||||||
mask := make([]float32, c.curBatchSize*len)
|
|
||||||
|
c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
|
||||||
|
c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
|
||||||
|
|
||||||
|
length := c.curCellRange.max - c.curCellRange.min + 1
|
||||||
|
mask := make([]float32, batchSize*length)
|
||||||
|
|
||||||
for i := range c.curBatchSize {
|
for i := range c.curBatchSize {
|
||||||
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
|
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
|
||||||
if !slices.Contains(c.cells[j].sequences, seqs[i]) || c.cells[j].pos > positions[i] ||
|
if !slices.Contains(c.cells[j].sequences, seqs[i]) || c.cells[j].pos > positions[i] ||
|
||||||
c.cells[j].pos < positions[i]-c.windowSize {
|
c.cells[j].pos < positions[i]-c.windowSize {
|
||||||
mask[i*len+(j-c.curCellRange.min)] = float32(math.Inf(-1))
|
mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return ctx.FromFloatSlice(mask, len, c.curBatchSize)
|
// Mask out any padding tokens we added. For padding that we added to the cache history, this
|
||||||
|
// has already been masked out because the sequence doesn't match.
|
||||||
|
for i := c.curBatchSize * length; i < len(mask); i++ {
|
||||||
|
mask[i] = float32(math.Inf(-1))
|
||||||
|
}
|
||||||
|
|
||||||
|
maskTensor, err := ctx.FromFloatSlice(mask, length, batchSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.config.MaskDType != ml.DTypeF32 {
|
||||||
|
out := ctx.Empty(c.config.MaskDType, maskTensor.Shape()...)
|
||||||
|
ctx.Forward(maskTensor.Copy(ctx, out))
|
||||||
|
maskTensor = out
|
||||||
|
}
|
||||||
|
|
||||||
|
return maskTensor, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func moveCell(ctx ml.Context, objs []ml.Tensor, src, dst, len int) {
|
func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
|
||||||
for _, obj := range objs {
|
for i := range c.keys {
|
||||||
if obj == nil {
|
if c.keys[i] == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
srcView := obj.View(ctx, obj.Stride(2)*src, obj.Dim(0)*obj.Dim(1)*len)
|
key := c.keys[i]
|
||||||
dstView := obj.View(ctx, obj.Stride(2)*dst, obj.Dim(0)*obj.Dim(1)*len)
|
|
||||||
|
|
||||||
ctx.Forward(srcView.Copy(ctx, dstView))
|
kHeadDim := key.Dim(0)
|
||||||
|
numKVHeads := key.Dim(1)
|
||||||
|
rowSize := key.Stride(2)
|
||||||
|
|
||||||
|
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len)
|
||||||
|
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len)
|
||||||
|
|
||||||
|
value := c.values[i]
|
||||||
|
var vSrcView, vDstView ml.Tensor
|
||||||
|
if c.config.PermutedV {
|
||||||
|
vHeadDim := value.Dim(1)
|
||||||
|
elemSize := value.Stride(0)
|
||||||
|
|
||||||
|
vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
|
||||||
|
vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
|
||||||
|
} else {
|
||||||
|
vHeadDim := value.Dim(0)
|
||||||
|
rowSize := value.Stride(2)
|
||||||
|
|
||||||
|
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len)
|
||||||
|
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Forward(
|
||||||
|
kSrcView.Copy(ctx, kDstView),
|
||||||
|
vSrcView.Copy(ctx, vDstView),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -238,8 +324,7 @@ func (c *Causal) defrag() {
|
|||||||
pendingLen++
|
pendingLen++
|
||||||
break
|
break
|
||||||
} else {
|
} else {
|
||||||
moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
|
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
|
||||||
moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
|
|
||||||
moves++
|
moves++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -263,8 +348,7 @@ func (c *Causal) defrag() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if pendingLen > 0 {
|
if pendingLen > 0 {
|
||||||
moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
|
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
|
||||||
moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
|
|
||||||
moves++
|
moves++
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -305,35 +389,73 @@ func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
|||||||
key := c.keys[c.curLayer]
|
key := c.keys[c.curLayer]
|
||||||
value := c.values[c.curLayer]
|
value := c.values[c.curLayer]
|
||||||
|
|
||||||
key = key.View(ctx, key.Stride(2)*c.curCellRange.min,
|
kHeadDim := key.Dim(0)
|
||||||
key.Dim(0), key.Stride(1),
|
numKVHeads := key.Dim(1)
|
||||||
key.Dim(1), key.Stride(2),
|
rowSize := key.Stride(2)
|
||||||
c.curMask.Dim(0),
|
cachedSize := c.curMask.Dim(0)
|
||||||
|
|
||||||
|
key = key.View(ctx, rowSize*c.curCellRange.min,
|
||||||
|
kHeadDim, key.Stride(1),
|
||||||
|
numKVHeads, key.Stride(2),
|
||||||
|
cachedSize,
|
||||||
)
|
)
|
||||||
|
|
||||||
value = value.View(ctx, key.Stride(2)*c.curCellRange.min,
|
if c.config.PermutedV {
|
||||||
value.Dim(0), value.Stride(1),
|
vHeadDim := value.Dim(1)
|
||||||
value.Dim(1), value.Stride(2),
|
elemSize := value.Stride(0)
|
||||||
c.curMask.Dim(0),
|
|
||||||
)
|
value = value.View(ctx, elemSize*c.curCellRange.min,
|
||||||
|
cachedSize, value.Stride(1),
|
||||||
|
vHeadDim, value.Stride(2),
|
||||||
|
numKVHeads,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
vHeadDim := value.Dim(0)
|
||||||
|
rowSize := value.Stride(2)
|
||||||
|
|
||||||
|
value = value.View(ctx, rowSize*c.curCellRange.min,
|
||||||
|
vHeadDim, value.Stride(1),
|
||||||
|
numKVHeads, value.Stride(2),
|
||||||
|
cachedSize,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
return key, value, c.curMask
|
return key, value, c.curMask
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
|
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||||
if c.curBatchSize != key.Dim(2) {
|
kHeadDim := key.Dim(0)
|
||||||
panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, key.Dim(2)))
|
vHeadDim := value.Dim(0)
|
||||||
|
numKVHeads := key.Dim(1)
|
||||||
|
batchSize := key.Dim(2)
|
||||||
|
|
||||||
|
if c.curBatchSize != batchSize {
|
||||||
|
panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
|
if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
|
||||||
c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, key.Dim(0), key.Dim(1), int(c.Capacity))
|
c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
|
||||||
c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, value.Dim(0), value.Dim(1), int(c.Capacity))
|
|
||||||
|
if c.config.PermutedV {
|
||||||
|
c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
|
||||||
|
} else {
|
||||||
|
c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx.Forward(
|
rowSize := c.keys[c.curLayer].Stride(2)
|
||||||
key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2))),
|
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize)))
|
||||||
value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2))),
|
|
||||||
)
|
if c.config.PermutedV {
|
||||||
|
elemSize := c.values[c.curLayer].Stride(0)
|
||||||
|
|
||||||
|
value = value.Permute(ctx, 1, 2, 0, 3)
|
||||||
|
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)))
|
||||||
|
} else {
|
||||||
|
rowSize := c.values[c.curLayer].Stride(2)
|
||||||
|
|
||||||
|
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize)))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||||
@@ -389,9 +511,13 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
key = key.View(ctx, key.Stride(2)*seqRange.min,
|
kHeadDim := key.Dim(0)
|
||||||
key.Dim(0), key.Stride(1),
|
numKVHeads := key.Dim(1)
|
||||||
key.Dim(1), key.Stride(2),
|
rowSize := key.Stride(2)
|
||||||
|
|
||||||
|
key = key.View(ctx, rowSize*seqRange.min,
|
||||||
|
kHeadDim, key.Stride(1),
|
||||||
|
numKVHeads, key.Stride(2),
|
||||||
size,
|
size,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -309,7 +309,7 @@ func (b *testBackend) SystemInfo() string {
|
|||||||
|
|
||||||
type testContext struct{}
|
type testContext struct{}
|
||||||
|
|
||||||
func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
|
||||||
total := 0
|
total := 0
|
||||||
|
|
||||||
if len(shape) > 0 {
|
if len(shape) > 0 {
|
||||||
@@ -322,8 +322,12 @@ func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
|||||||
return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
|
return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
||||||
|
return c.Empty(dtype, shape...)
|
||||||
|
}
|
||||||
|
|
||||||
func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
|
func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
|
||||||
t := c.Zeros(ml.DTypeF32, shape...).(*testTensor)
|
t := c.Empty(ml.DTypeF32, shape...).(*testTensor)
|
||||||
|
|
||||||
copy(t.data, s)
|
copy(t.data, s)
|
||||||
|
|
||||||
@@ -391,7 +395,7 @@ func (t *testTensor) Floats() []float32 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||||
out := ctx.Zeros(t.DType(), t.Shape()...).(*testTensor)
|
out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
|
||||||
|
|
||||||
for i := range out.data {
|
for i := range out.data {
|
||||||
out.data[i] = t.data[i] + t2.(*testTensor).data[i]
|
out.data[i] = t.data[i] + t2.(*testTensor).data[i]
|
||||||
@@ -468,7 +472,7 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
|||||||
|
|
||||||
context := &testContext{}
|
context := &testContext{}
|
||||||
|
|
||||||
view := context.Zeros(t.dtype, s...).(*testTensor)
|
view := context.Empty(t.dtype, s...).(*testTensor)
|
||||||
view.data = t.data[offset : offset+len(view.data)]
|
view.data = t.data[offset : offset+len(view.data)]
|
||||||
|
|
||||||
return view
|
return view
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package kvcache
|
package kvcache
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -11,6 +13,9 @@ import (
|
|||||||
//
|
//
|
||||||
// Not currently safe for multiple sequences
|
// Not currently safe for multiple sequences
|
||||||
type EncoderCache struct {
|
type EncoderCache struct {
|
||||||
|
// config controls mostly backend-specific optimizations
|
||||||
|
config *ml.CacheConfig
|
||||||
|
|
||||||
// ** current forward pass **
|
// ** current forward pass **
|
||||||
|
|
||||||
// the active layer for Get and Put
|
// the active layer for Get and Put
|
||||||
@@ -40,9 +45,29 @@ func NewEncoderCache() *EncoderCache {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
|
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
|
||||||
|
if c.config == nil {
|
||||||
|
var config ml.CacheConfig
|
||||||
|
if cc, ok := backend.(ml.BackendCacheConfig); ok {
|
||||||
|
config = cc.CacheConfig()
|
||||||
|
}
|
||||||
|
c.config = &config
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
|
||||||
|
panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
|
||||||
|
}
|
||||||
|
|
||||||
c.cacheCtx = backend.NewContext()
|
c.cacheCtx = backend.NewContext()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
|
||||||
|
if c.config != nil {
|
||||||
|
panic("config cannot be changed after being previously set, either by the model or backend")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.config = &config
|
||||||
|
}
|
||||||
|
|
||||||
func (c *EncoderCache) Close() {
|
func (c *EncoderCache) Close() {
|
||||||
c.cacheCtx.Close()
|
c.cacheCtx.Close()
|
||||||
}
|
}
|
||||||
@@ -75,9 +100,13 @@ func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
|
|||||||
c.encoderPos = c.curPos
|
c.encoderPos = c.curPos
|
||||||
c.encoderCached = true
|
c.encoderCached = true
|
||||||
|
|
||||||
|
if c.config.PermutedV {
|
||||||
|
value = value.Permute(ctx, 1, 2, 0, 3)
|
||||||
|
}
|
||||||
|
|
||||||
if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
|
if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
|
||||||
c.keys[c.curLayer] = c.cacheCtx.Zeros(key.DType(), key.Shape()...)
|
c.keys[c.curLayer] = c.cacheCtx.Empty(key.DType(), key.Shape()...)
|
||||||
c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...)
|
c.values[c.curLayer] = c.cacheCtx.Empty(value.DType(), value.Shape()...)
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx.Forward(
|
ctx.Forward(
|
||||||
|
|||||||
@@ -28,6 +28,12 @@ func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
|
||||||
|
for _, cache := range c.caches {
|
||||||
|
cache.SetConfig(config)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *WrapperCache) Close() {
|
func (c *WrapperCache) Close() {
|
||||||
for _, cache := range c.caches {
|
for _, cache := range c.caches {
|
||||||
cache.Close()
|
cache.Close()
|
||||||
|
|||||||
@@ -27,6 +27,35 @@ type Backend interface {
|
|||||||
SystemInfo() string
|
SystemInfo() string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BackendCacheConfig should be implemented by backends that need special output
|
||||||
|
// from the cache to meet specific requirements. It is frequently implemented in
|
||||||
|
// conjunction with ScaledDotProductAttention.
|
||||||
|
type BackendCacheConfig interface {
|
||||||
|
CacheConfig() CacheConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// CacheConfig controls optimizations (mostly backend-specific) that may transform
|
||||||
|
// the output the cache to work better with specific kernels.
|
||||||
|
type CacheConfig struct {
|
||||||
|
// CachePadding specifies the multiple for the number of tokens of cache history
|
||||||
|
// that will be returned from cache Get for k, v and mask. The capacity of the
|
||||||
|
// cache itself will also be increased to a multiple of this size if needed.
|
||||||
|
CachePadding int
|
||||||
|
|
||||||
|
// PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
|
||||||
|
// and return the permuted version via Get. This uses the cache copy operation
|
||||||
|
// to avoid a Contiguous call on the permuted tensor.
|
||||||
|
PermutedV bool
|
||||||
|
|
||||||
|
// MaskDType specifies the data type for generating the mask. If unset it will
|
||||||
|
// default to DTypeF32.
|
||||||
|
MaskDType DType
|
||||||
|
|
||||||
|
// MaskBatchPadding specifies the multiple for the batch size dimension in the mask.
|
||||||
|
// Any position that does not correspond to an actual token will be filled with -Inf.
|
||||||
|
MaskBatchPadding int
|
||||||
|
}
|
||||||
|
|
||||||
// BackendParams controls how the backend loads and executes models
|
// BackendParams controls how the backend loads and executes models
|
||||||
type BackendParams struct {
|
type BackendParams struct {
|
||||||
// NumThreads sets the number of threads to use if running on the CPU
|
// NumThreads sets the number of threads to use if running on the CPU
|
||||||
@@ -40,6 +69,9 @@ type BackendParams struct {
|
|||||||
|
|
||||||
// TensorSplit is the fraction of the model to offload to each GPU
|
// TensorSplit is the fraction of the model to offload to each GPU
|
||||||
TensorSplit []float32
|
TensorSplit []float32
|
||||||
|
|
||||||
|
// FlashAttention indicates that we should use a fused flash attention kernel
|
||||||
|
FlashAttention bool
|
||||||
}
|
}
|
||||||
|
|
||||||
var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))
|
var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))
|
||||||
@@ -61,6 +93,7 @@ func NewBackend(f *os.File, params BackendParams) (Backend, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Context interface {
|
type Context interface {
|
||||||
|
Empty(dtype DType, shape ...int) Tensor
|
||||||
Zeros(dtype DType, shape ...int) Tensor
|
Zeros(dtype DType, shape ...int) Tensor
|
||||||
FromFloatSlice(s []float32, shape ...int) (Tensor, error)
|
FromFloatSlice(s []float32, shape ...int) (Tensor, error)
|
||||||
FromIntSlice(s []int32, shape ...int) (Tensor, error)
|
FromIntSlice(s []int32, shape ...int) (Tensor, error)
|
||||||
@@ -116,6 +149,10 @@ type Tensor interface {
|
|||||||
// operation equivalent to following code on a tensor named
|
// operation equivalent to following code on a tensor named
|
||||||
// query:
|
// query:
|
||||||
//
|
//
|
||||||
|
// query = query.Permute(ctx, 0, 2, 1, 3)
|
||||||
|
// key = key.Permute(ctx, 0, 2, 1, 3)
|
||||||
|
// value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||||
|
//
|
||||||
// kq := key.MulmatFullPrec(ctx, query)
|
// kq := key.MulmatFullPrec(ctx, query)
|
||||||
//
|
//
|
||||||
// kq = kq.Scale(ctx, scale)
|
// kq = kq.Scale(ctx, scale)
|
||||||
@@ -170,7 +207,7 @@ func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
|
|||||||
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
|
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
|
||||||
})
|
})
|
||||||
case DTypeF16:
|
case DTypeF16:
|
||||||
f32 := ctx.Zeros(DTypeF32, t.Shape()...)
|
f32 := ctx.Empty(DTypeF32, t.Shape()...)
|
||||||
f32 = t.Copy(ctx, f32)
|
f32 = t.Copy(ctx, f32)
|
||||||
return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
|
return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
|
||||||
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
|
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
|
||||||
|
|||||||
@@ -79,6 +79,8 @@ var devices = sync.OnceValue(func() []device {
|
|||||||
})
|
})
|
||||||
|
|
||||||
type Backend struct {
|
type Backend struct {
|
||||||
|
flashAttention bool
|
||||||
|
|
||||||
meta *fs.GGML
|
meta *fs.GGML
|
||||||
cpus, gpus []Context
|
cpus, gpus []Context
|
||||||
tensors map[string]*Context
|
tensors map[string]*Context
|
||||||
@@ -192,9 +194,10 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &Backend{
|
return &Backend{
|
||||||
meta: meta,
|
flashAttention: params.FlashAttention,
|
||||||
cpus: cpus,
|
meta: meta,
|
||||||
gpus: gpus,
|
cpus: cpus,
|
||||||
|
gpus: gpus,
|
||||||
sched: C.ggml_backend_sched_new(
|
sched: C.ggml_backend_sched_new(
|
||||||
(*C.ggml_backend_t)(unsafe.Pointer(&backends[0])),
|
(*C.ggml_backend_t)(unsafe.Pointer(&backends[0])),
|
||||||
(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])),
|
(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])),
|
||||||
@@ -219,7 +222,7 @@ func (b *Backend) Get(name string) ml.Tensor {
|
|||||||
|
|
||||||
for _, c := range append(b.gpus, b.cpus...) {
|
for _, c := range append(b.gpus, b.cpus...) {
|
||||||
if t := C.ggml_get_tensor(c.ctx, cname); t != nil {
|
if t := C.ggml_get_tensor(c.ctx, cname); t != nil {
|
||||||
return &Tensor{t: t}
|
return &Tensor{b: b, t: t}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -247,6 +250,14 @@ func (b *Backend) NewContext() ml.Context {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *Backend) CacheConfig() ml.CacheConfig {
|
||||||
|
if b.flashAttention {
|
||||||
|
return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
|
||||||
|
} else {
|
||||||
|
return ml.CacheConfig{CachePadding: 32, PermutedV: true}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type Context struct {
|
type Context struct {
|
||||||
b *Backend
|
b *Backend
|
||||||
ctx *C.struct_ggml_context
|
ctx *C.struct_ggml_context
|
||||||
@@ -300,7 +311,7 @@ func shapeToGGML(shape []int) *C.int64_t {
|
|||||||
return &sh[0]
|
return &sh[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
func newTensor(ctx Context, dtype ml.DType, zero bool, shape []int) ml.Tensor {
|
||||||
if len(shape) < 1 || len(shape) > 4 {
|
if len(shape) < 1 || len(shape) > 4 {
|
||||||
panic("unsupported number of dimensions")
|
panic("unsupported number of dimensions")
|
||||||
}
|
}
|
||||||
@@ -314,19 +325,29 @@ func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
|||||||
var t *C.struct_ggml_tensor
|
var t *C.struct_ggml_tensor
|
||||||
switch dtype {
|
switch dtype {
|
||||||
case ml.DTypeF32:
|
case ml.DTypeF32:
|
||||||
t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape))
|
t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape))
|
||||||
case ml.DTypeF16:
|
case ml.DTypeF16:
|
||||||
t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape))
|
t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape))
|
||||||
case ml.DTypeI32:
|
case ml.DTypeI32:
|
||||||
t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape))
|
t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape))
|
||||||
default:
|
default:
|
||||||
panic("unsupported dtype")
|
panic("unsupported dtype")
|
||||||
}
|
}
|
||||||
|
|
||||||
b := C.ggml_backend_alloc_buffer(c.backend, C.ggml_nbytes(t))
|
b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
|
||||||
C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
|
C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
|
||||||
C.ggml_set_zero(t)
|
if zero {
|
||||||
return &Tensor{t: t}
|
C.ggml_set_zero(t)
|
||||||
|
}
|
||||||
|
return &Tensor{b: ctx.b, t: t}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
|
||||||
|
return newTensor(c, dtype, false, shape)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
||||||
|
return newTensor(c, dtype, true, shape)
|
||||||
}
|
}
|
||||||
|
|
||||||
func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype uint32) (ml.Tensor, error) {
|
func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype uint32) (ml.Tensor, error) {
|
||||||
@@ -335,7 +356,7 @@ func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype u
|
|||||||
if n == 0 {
|
if n == 0 {
|
||||||
var shape C.int64_t = 0
|
var shape C.int64_t = 0
|
||||||
t := C.ggml_new_tensor(ctx.ctx, dtype, 1, &shape)
|
t := C.ggml_new_tensor(ctx.ctx, dtype, 1, &shape)
|
||||||
return &Tensor{t: t}, nil
|
return &Tensor{b: ctx.b, t: t}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, v := range shape {
|
for _, v := range shape {
|
||||||
@@ -350,7 +371,7 @@ func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype u
|
|||||||
b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
|
b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
|
||||||
C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
|
C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
|
||||||
C.ggml_backend_tensor_set(t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t))
|
C.ggml_backend_tensor_set(t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t))
|
||||||
return &Tensor{t: t}, nil
|
return &Tensor{b: ctx.b, t: t}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
|
func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
|
||||||
@@ -368,6 +389,7 @@ func (c *Context) Close() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Tensor struct {
|
type Tensor struct {
|
||||||
|
b *Backend
|
||||||
t *C.struct_ggml_tensor
|
t *C.struct_ggml_tensor
|
||||||
sync func()
|
sync func()
|
||||||
}
|
}
|
||||||
@@ -434,6 +456,7 @@ func (t *Tensor) DType() ml.DType {
|
|||||||
|
|
||||||
func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||||
return &Tensor{
|
return &Tensor{
|
||||||
|
b: t.b,
|
||||||
t: C.ggml_add(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
|
t: C.ggml_add(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -448,24 +471,28 @@ func (t *Tensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
|
|||||||
|
|
||||||
func (t *Tensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
|
func (t *Tensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
|
||||||
return &Tensor{
|
return &Tensor{
|
||||||
|
b: t.b,
|
||||||
t: C.ggml_concat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(dim)),
|
t: C.ggml_concat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(dim)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Tensor) Contiguous(ctx ml.Context) ml.Tensor {
|
func (t *Tensor) Contiguous(ctx ml.Context) ml.Tensor {
|
||||||
return &Tensor{
|
return &Tensor{
|
||||||
|
b: t.b,
|
||||||
t: C.ggml_cont(ctx.(*Context).ctx, t.t),
|
t: C.ggml_cont(ctx.(*Context).ctx, t.t),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||||
return &Tensor{
|
return &Tensor{
|
||||||
|
b: t.b,
|
||||||
t: C.ggml_mul(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
|
t: C.ggml_mul(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||||
return &Tensor{
|
return &Tensor{
|
||||||
|
b: t.b,
|
||||||
t: C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
|
t: C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -475,12 +502,13 @@ func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
|||||||
C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32)
|
C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32)
|
||||||
|
|
||||||
return &Tensor{
|
return &Tensor{
|
||||||
|
b: t.b,
|
||||||
t: mul,
|
t: mul,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
|
func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
|
||||||
tt := (&Tensor{t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
|
tt := (&Tensor{b: t.b, t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
|
||||||
if b != nil {
|
if b != nil {
|
||||||
tt = tt.Add(ctx, b)
|
tt = tt.Add(ctx, b)
|
||||||
}
|
}
|
||||||
@@ -489,7 +517,7 @@ func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tenso
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
|
func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
|
||||||
return (&Tensor{t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
|
return (&Tensor{b: t.b, t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
|
func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
|
||||||
@@ -498,6 +526,7 @@ func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &Tensor{
|
return &Tensor{
|
||||||
|
b: t.b,
|
||||||
t: C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
|
t: C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -508,18 +537,21 @@ func (t *Tensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &Tensor{
|
return &Tensor{
|
||||||
|
b: t.b,
|
||||||
t: C.ggml_permute(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
|
t: C.ggml_permute(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||||
return &Tensor{
|
return &Tensor{
|
||||||
|
b: t.b,
|
||||||
t: C.ggml_get_rows(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
|
t: C.ggml_get_rows(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||||
return &Tensor{
|
return &Tensor{
|
||||||
|
b: t.b,
|
||||||
t: C.ggml_cpy(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
|
t: C.ggml_cpy(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -528,18 +560,22 @@ func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
|
|||||||
switch len(shape) {
|
switch len(shape) {
|
||||||
case 1:
|
case 1:
|
||||||
return &Tensor{
|
return &Tensor{
|
||||||
|
b: t.b,
|
||||||
t: C.ggml_reshape_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])),
|
t: C.ggml_reshape_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])),
|
||||||
}
|
}
|
||||||
case 2:
|
case 2:
|
||||||
return &Tensor{
|
return &Tensor{
|
||||||
|
b: t.b,
|
||||||
t: C.ggml_reshape_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])),
|
t: C.ggml_reshape_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])),
|
||||||
}
|
}
|
||||||
case 3:
|
case 3:
|
||||||
return &Tensor{
|
return &Tensor{
|
||||||
|
b: t.b,
|
||||||
t: C.ggml_reshape_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])),
|
t: C.ggml_reshape_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])),
|
||||||
}
|
}
|
||||||
case 4:
|
case 4:
|
||||||
return &Tensor{
|
return &Tensor{
|
||||||
|
b: t.b,
|
||||||
t: C.ggml_reshape_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])),
|
t: C.ggml_reshape_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])),
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
@@ -549,18 +585,21 @@ func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
|
|||||||
|
|
||||||
func (t *Tensor) Scale(ctx ml.Context, s float64) ml.Tensor {
|
func (t *Tensor) Scale(ctx ml.Context, s float64) ml.Tensor {
|
||||||
return &Tensor{
|
return &Tensor{
|
||||||
|
b: t.b,
|
||||||
t: C.ggml_scale(ctx.(*Context).ctx, t.t, (C.float)(s)),
|
t: C.ggml_scale(ctx.(*Context).ctx, t.t, (C.float)(s)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor {
|
func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor {
|
||||||
return &Tensor{
|
return &Tensor{
|
||||||
|
b: t.b,
|
||||||
t: C.ggml_soft_max(ctx.(*Context).ctx, t.t),
|
t: C.ggml_soft_max(ctx.(*Context).ctx, t.t),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
|
func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
|
||||||
return &Tensor{
|
return &Tensor{
|
||||||
|
b: t.b,
|
||||||
t: C.ggml_tanh_inplace(ctx.(*Context).ctx, t.t),
|
t: C.ggml_tanh_inplace(ctx.(*Context).ctx, t.t),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -571,6 +610,7 @@ func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &Tensor{
|
return &Tensor{
|
||||||
|
b: t.b,
|
||||||
t: C.ggml_unpad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
|
t: C.ggml_unpad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -579,10 +619,12 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
|||||||
switch len(shape) {
|
switch len(shape) {
|
||||||
case 1:
|
case 1:
|
||||||
return &Tensor{
|
return &Tensor{
|
||||||
|
b: t.b,
|
||||||
t: C.ggml_view_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.size_t(offset)),
|
t: C.ggml_view_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.size_t(offset)),
|
||||||
}
|
}
|
||||||
case 3:
|
case 3:
|
||||||
return &Tensor{
|
return &Tensor{
|
||||||
|
b: t.b,
|
||||||
t: C.ggml_view_2d(ctx.(*Context).ctx, t.t,
|
t: C.ggml_view_2d(ctx.(*Context).ctx, t.t,
|
||||||
C.int64_t(shape[0]), C.int64_t(shape[2]),
|
C.int64_t(shape[0]), C.int64_t(shape[2]),
|
||||||
C.size_t(shape[1]),
|
C.size_t(shape[1]),
|
||||||
@@ -590,6 +632,7 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
|||||||
}
|
}
|
||||||
case 5:
|
case 5:
|
||||||
return &Tensor{
|
return &Tensor{
|
||||||
|
b: t.b,
|
||||||
t: C.ggml_view_3d(ctx.(*Context).ctx, t.t,
|
t: C.ggml_view_3d(ctx.(*Context).ctx, t.t,
|
||||||
C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]),
|
C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]),
|
||||||
C.size_t(shape[1]), C.size_t(shape[3]),
|
C.size_t(shape[1]), C.size_t(shape[3]),
|
||||||
@@ -597,6 +640,7 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
|||||||
}
|
}
|
||||||
case 7:
|
case 7:
|
||||||
return &Tensor{
|
return &Tensor{
|
||||||
|
b: t.b,
|
||||||
t: C.ggml_view_4d(ctx.(*Context).ctx, t.t,
|
t: C.ggml_view_4d(ctx.(*Context).ctx, t.t,
|
||||||
C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]), C.int64_t(shape[6]),
|
C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]), C.int64_t(shape[6]),
|
||||||
C.size_t(shape[1]), C.size_t(shape[3]), C.size_t(shape[5]),
|
C.size_t(shape[1]), C.size_t(shape[3]), C.size_t(shape[5]),
|
||||||
@@ -613,7 +657,7 @@ const (
|
|||||||
|
|
||||||
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, ropeBase, ropeScale float32) ml.Tensor {
|
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, ropeBase, ropeScale float32) ml.Tensor {
|
||||||
if ropeFactors == nil {
|
if ropeFactors == nil {
|
||||||
ropeFactors = &Tensor{}
|
ropeFactors = &Tensor{b: t.b}
|
||||||
}
|
}
|
||||||
|
|
||||||
dequant := t.t
|
dequant := t.t
|
||||||
@@ -622,6 +666,7 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &Tensor{
|
return &Tensor{
|
||||||
|
b: t.b,
|
||||||
t: C.ggml_rope_ext(
|
t: C.ggml_rope_ext(
|
||||||
ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
|
ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
|
||||||
C.int(ropeDim),
|
C.int(ropeDim),
|
||||||
@@ -639,18 +684,21 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
|
|||||||
|
|
||||||
func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
|
func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
|
||||||
return &Tensor{
|
return &Tensor{
|
||||||
|
b: t.b,
|
||||||
t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
|
t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
|
func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
|
||||||
return &Tensor{
|
return &Tensor{
|
||||||
|
b: t.b,
|
||||||
t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
|
t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
|
func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
|
||||||
return &Tensor{
|
return &Tensor{
|
||||||
|
b: t.b,
|
||||||
t: C.ggml_conv_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1)),
|
t: C.ggml_conv_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -661,13 +709,25 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.T
|
|||||||
kqMask = mask.(*Tensor).t
|
kqMask = mask.(*Tensor).t
|
||||||
}
|
}
|
||||||
|
|
||||||
kq := key.MulmatFullPrec(ctx, t)
|
query := t.Permute(ctx, 0, 2, 1, 3)
|
||||||
kq = &Tensor{
|
key = key.Permute(ctx, 0, 2, 1, 3)
|
||||||
t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
|
|
||||||
}
|
|
||||||
|
|
||||||
kqv := value.Mulmat(ctx, kq)
|
if t.b.flashAttention {
|
||||||
return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
value = value.Permute(ctx, 0, 2, 1, 3)
|
||||||
|
|
||||||
|
kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
|
||||||
|
C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)
|
||||||
|
return &Tensor{b: t.b, t: kqv}
|
||||||
|
} else {
|
||||||
|
kq := key.MulmatFullPrec(ctx, query)
|
||||||
|
kq = &Tensor{
|
||||||
|
b: t.b,
|
||||||
|
t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
kqv := value.Mulmat(ctx, kq)
|
||||||
|
return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Backend) SystemInfo() string {
|
func (b *Backend) SystemInfo() string {
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package nn
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/kvcache"
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -11,40 +12,50 @@ import (
|
|||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - ctx: Context for tensor operations
|
// - ctx: Context for tensor operations
|
||||||
// - query: Query tensor (Q) with shape [d_k, seq_len_q, heads]
|
// - query: Query tensor (Q) with shape [d_k, heads, seq_len_q]
|
||||||
// - key: Key tensor (K) with shape [d_k, seq_len_k, kv_heads]
|
// - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only
|
||||||
// - value: Value tensor (V) with shape [seq_len_k, d_v, kv_heads]
|
// - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only
|
||||||
// - mask: Optional attention mask that is added to the attention score. If
|
|
||||||
// provided, should broadcast to [seq_len_k, seq_len_q, heads]
|
|
||||||
// - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
|
// - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
|
||||||
|
// - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
//
|
//
|
||||||
// Attention output with shape [d_v, heads, seq_len_q]
|
// Attention output with shape [d_v, heads, seq_len_q]
|
||||||
func Attention(ctx ml.Context, query, key, value, mask ml.Tensor, scale float64) ml.Tensor {
|
func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
|
||||||
if query.Dim(0) != key.Dim(0) {
|
if key != nil && value != nil {
|
||||||
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
|
if query.Dim(0) != key.Dim(0) {
|
||||||
|
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
|
||||||
|
}
|
||||||
|
|
||||||
|
if key.Dim(1) != value.Dim(1) {
|
||||||
|
panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)))
|
||||||
|
}
|
||||||
|
|
||||||
|
if key.Dim(2) != value.Dim(2) {
|
||||||
|
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
|
||||||
|
}
|
||||||
|
|
||||||
|
if cache != nil {
|
||||||
|
cache.Put(ctx, key, value)
|
||||||
|
}
|
||||||
|
} else if cache == nil {
|
||||||
|
panic("key & value tensors must be provided if cache is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
if mask != nil && query.Dim(1) != mask.Dim(1) {
|
var mask ml.Tensor
|
||||||
panic(fmt.Errorf("seq_len_q in attention operation does not match between query(%v) and mask(%v)", query.Dim(1), mask.Dim(1)))
|
if cache != nil {
|
||||||
|
key, value, mask = cache.Get(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
if key.Dim(1) != value.Dim(0) {
|
// Only use the fast SDPA implementation if we have a cache, since that's what
|
||||||
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(0)))
|
// will do any expected backend-specific transformations for us
|
||||||
}
|
if sdpa, ok := query.(ml.ScaledDotProductAttention); ok && cache != nil {
|
||||||
|
|
||||||
if mask != nil && key.Dim(1) != mask.Dim(0) {
|
|
||||||
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and mask(%v)", key.Dim(1), mask.Dim(0)))
|
|
||||||
}
|
|
||||||
|
|
||||||
if key.Dim(2) != value.Dim(2) {
|
|
||||||
panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
|
|
||||||
}
|
|
||||||
|
|
||||||
if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
|
|
||||||
return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
|
return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
|
||||||
} else {
|
} else {
|
||||||
|
query = query.Permute(ctx, 0, 2, 1, 3)
|
||||||
|
key = key.Permute(ctx, 0, 2, 1, 3)
|
||||||
|
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||||
|
|
||||||
kq := key.MulmatFullPrec(ctx, query)
|
kq := key.MulmatFullPrec(ctx, query)
|
||||||
|
|
||||||
kq = kq.Scale(ctx, scale)
|
kq = kq.Scale(ctx, scale)
|
||||||
|
|||||||
@@ -81,15 +81,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
|||||||
v := sa.Value.Forward(ctx, hiddenState)
|
v := sa.Value.Forward(ctx, hiddenState)
|
||||||
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||||
|
|
||||||
cache.Put(ctx, k, v)
|
|
||||||
k, v, mask := cache.Get(ctx)
|
|
||||||
|
|
||||||
q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
|
||||||
k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
|
||||||
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
|
||||||
|
|
||||||
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
|
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
|
||||||
kqv := nn.Attention(ctx, q, k, v, mask, scaleFactor)
|
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
|
||||||
kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
|
kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||||
|
|
||||||
return sa.Output.Forward(ctx, kqv)
|
return sa.Output.Forward(ctx, kqv)
|
||||||
|
|||||||
@@ -43,7 +43,9 @@ func New(c ml.Config) (model.Model, error) {
|
|||||||
TextModel: newTextModel(c),
|
TextModel: newTextModel(c),
|
||||||
}
|
}
|
||||||
|
|
||||||
m.Cache = kvcache.NewWrapperCache(kvcache.NewEncoderCache(), kvcache.NewCausalCache(m.TextModel.Shift))
|
encoderCache := kvcache.NewEncoderCache()
|
||||||
|
encoderCache.SetConfig(ml.CacheConfig{})
|
||||||
|
m.Cache = kvcache.NewWrapperCache(encoderCache, kvcache.NewCausalCache(m.TextModel.Shift))
|
||||||
|
|
||||||
return &m, nil
|
return &m, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,22 +31,15 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
|
|||||||
value := sa.Value.Forward(ctx, hiddenState)
|
value := sa.Value.Forward(ctx, hiddenState)
|
||||||
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||||
|
|
||||||
cache.Put(ctx, key, value)
|
|
||||||
key, value, mask := cache.Get(ctx)
|
|
||||||
|
|
||||||
query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
|
||||||
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
|
||||||
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
|
||||||
|
|
||||||
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
|
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
|
||||||
attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
|
attention := nn.Attention(ctx, query, key, value, scaleFactor, cache)
|
||||||
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
|
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||||
|
|
||||||
return sa.Output.Forward(ctx, attention)
|
return sa.Output.Forward(ctx, attention)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||||
// This will only get called for layers in the cache, which are just the self attention layers
|
// This will only get called for layers in the causal cache, which are just the self attention layers
|
||||||
return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
|
return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -107,7 +100,7 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
|
|||||||
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||||
query = ca.QueryNorm.Forward(ctx, query, opts.eps)
|
query = ca.QueryNorm.Forward(ctx, query, opts.eps)
|
||||||
|
|
||||||
var key, value, mask ml.Tensor
|
var key, value ml.Tensor
|
||||||
if crossAttentionStates != nil {
|
if crossAttentionStates != nil {
|
||||||
numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
|
numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
|
||||||
|
|
||||||
@@ -119,16 +112,23 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
|
|||||||
value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
|
value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
|
||||||
|
|
||||||
cache.Put(ctx, key, value)
|
cache.Put(ctx, key, value)
|
||||||
} else {
|
|
||||||
key, value, mask = cache.Get(ctx)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
key, value, _ = cache.Get(ctx)
|
||||||
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
|
||||||
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
|
||||||
|
|
||||||
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
|
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
|
||||||
attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
|
|
||||||
|
query = query.Permute(ctx, 0, 2, 1, 3)
|
||||||
|
key = key.Permute(ctx, 0, 2, 1, 3)
|
||||||
|
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||||
|
|
||||||
|
kq := key.MulmatFullPrec(ctx, query)
|
||||||
|
|
||||||
|
kq = kq.Scale(ctx, scaleFactor)
|
||||||
|
kq = kq.Softmax(ctx)
|
||||||
|
|
||||||
|
kqv := value.Mulmat(ctx, kq)
|
||||||
|
attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||||
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
|
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||||
|
|
||||||
return ca.Output.Forward(ctx, attention)
|
return ca.Output.Forward(ctx, attention)
|
||||||
|
|||||||
@@ -818,7 +818,7 @@ func Execute(args []string) error {
|
|||||||
batchSize := fs.Int("batch-size", 512, "Batch size")
|
batchSize := fs.Int("batch-size", 512, "Batch size")
|
||||||
numGPULayers := fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
|
numGPULayers := fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
|
||||||
mainGPU := fs.Int("main-gpu", 0, "Main GPU")
|
mainGPU := fs.Int("main-gpu", 0, "Main GPU")
|
||||||
_ = fs.Bool("flash-attn", false, "Enable flash attention")
|
flashAttention := fs.Bool("flash-attn", false, "Enable flash attention")
|
||||||
kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size")
|
kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size")
|
||||||
kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)")
|
kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)")
|
||||||
port := fs.Int("port", 8080, "Port to expose the server on")
|
port := fs.Int("port", 8080, "Port to expose the server on")
|
||||||
@@ -863,7 +863,6 @@ func Execute(args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO(jessegross): Parameters that need to be implemented:
|
// TODO(jessegross): Parameters that need to be implemented:
|
||||||
// flash-attn
|
|
||||||
// no-mmap
|
// no-mmap
|
||||||
// mlock
|
// mlock
|
||||||
|
|
||||||
@@ -878,10 +877,11 @@ func Execute(args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
params := ml.BackendParams{
|
params := ml.BackendParams{
|
||||||
NumThreads: *threads,
|
NumThreads: *threads,
|
||||||
NumGPULayers: *numGPULayers,
|
NumGPULayers: *numGPULayers,
|
||||||
MainGPU: *mainGPU,
|
MainGPU: *mainGPU,
|
||||||
TensorSplit: tensorSplitFloats,
|
TensorSplit: tensorSplitFloats,
|
||||||
|
FlashAttention: *flashAttention,
|
||||||
}
|
}
|
||||||
|
|
||||||
server.ready.Add(1)
|
server.ready.Add(1)
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@@ -53,7 +54,7 @@ var (
|
|||||||
|
|
||||||
// ErrMissingModel is returned when the model part of a name is missing
|
// ErrMissingModel is returned when the model part of a name is missing
|
||||||
// or invalid.
|
// or invalid.
|
||||||
ErrNameInvalid = errors.New("invalid name; must be in the form {scheme://}{host/}{namespace/}[model]{:tag}{@digest}")
|
ErrNameInvalid = errors.New("invalid or missing name")
|
||||||
|
|
||||||
// ErrCached is passed to [Trace.PushUpdate] when a layer already
|
// ErrCached is passed to [Trace.PushUpdate] when a layer already
|
||||||
// exists. It is a non-fatal error and is never returned by [Registry.Push].
|
// exists. It is a non-fatal error and is never returned by [Registry.Push].
|
||||||
@@ -205,10 +206,18 @@ type Registry struct {
|
|||||||
// It is only used when a layer is larger than [MaxChunkingThreshold].
|
// It is only used when a layer is larger than [MaxChunkingThreshold].
|
||||||
MaxChunkSize int64
|
MaxChunkSize int64
|
||||||
|
|
||||||
// NameMask, if set, is the name used to convert non-fully qualified
|
// Mask, if set, is the name used to convert non-fully qualified
|
||||||
// names to fully qualified names. If empty, the default mask
|
// names to fully qualified names. If empty, the default mask
|
||||||
// ("registry.ollama.ai/library/_:latest") is used.
|
// ("registry.ollama.ai/library/_:latest") is used.
|
||||||
NameMask string
|
Mask string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Registry) completeName(name string) names.Name {
|
||||||
|
mask := defaultMask
|
||||||
|
if r.Mask != "" {
|
||||||
|
mask = names.Parse(r.Mask)
|
||||||
|
}
|
||||||
|
return names.Merge(names.Parse(name), mask)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultRegistry returns a new Registry configured from the environment. The
|
// DefaultRegistry returns a new Registry configured from the environment. The
|
||||||
@@ -243,52 +252,6 @@ func DefaultRegistry() (*Registry, error) {
|
|||||||
return &rc, nil
|
return &rc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type PushParams struct {
|
|
||||||
// From is an optional destination name for the model. If empty, the
|
|
||||||
// destination name is the same as the source name.
|
|
||||||
From string
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseName parses name using [names.ParseExtended] and then merges the name with the
|
|
||||||
// default name, and checks that the name is fully qualified. If a digest is
|
|
||||||
// present, it parse and returns it with the other fields as their zero values.
|
|
||||||
//
|
|
||||||
// It returns an error if the name is not fully qualified, or if the digest, if
|
|
||||||
// any, is invalid.
|
|
||||||
//
|
|
||||||
// The scheme is returned as provided by [names.ParseExtended].
|
|
||||||
func parseName(s, mask string) (scheme string, n names.Name, d blob.Digest, err error) {
|
|
||||||
maskName := defaultMask
|
|
||||||
if mask != "" {
|
|
||||||
maskName = names.Parse(mask)
|
|
||||||
if !maskName.IsFullyQualified() {
|
|
||||||
return "", names.Name{}, blob.Digest{}, fmt.Errorf("invalid name mask: %s", mask)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
scheme, n, ds := names.ParseExtended(s)
|
|
||||||
if !n.IsValid() {
|
|
||||||
return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s)
|
|
||||||
}
|
|
||||||
n = names.Merge(n, maskName)
|
|
||||||
if ds != "" {
|
|
||||||
// Digest is present. Validate it.
|
|
||||||
d, err = blob.ParseDigest(ds)
|
|
||||||
if err != nil {
|
|
||||||
return "", names.Name{}, blob.Digest{}, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// The name check is deferred until after the digest check because we
|
|
||||||
// say that digests take precedence over names, and so should there
|
|
||||||
// errors when being parsed.
|
|
||||||
if !n.IsFullyQualified() {
|
|
||||||
return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s)
|
|
||||||
}
|
|
||||||
|
|
||||||
scheme = cmp.Or(scheme, "https")
|
|
||||||
return scheme, n, d, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Registry) maxStreams() int {
|
func (r *Registry) maxStreams() int {
|
||||||
n := cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0))
|
n := cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0))
|
||||||
|
|
||||||
@@ -308,6 +271,12 @@ func (r *Registry) maxChunkSize() int64 {
|
|||||||
return cmp.Or(r.MaxChunkSize, DefaultMaxChunkSize)
|
return cmp.Or(r.MaxChunkSize, DefaultMaxChunkSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type PushParams struct {
|
||||||
|
// From is an optional destination name for the model. If empty, the
|
||||||
|
// destination name is the same as the source name.
|
||||||
|
From string
|
||||||
|
}
|
||||||
|
|
||||||
// Push pushes the model with the name in the cache to the remote registry.
|
// Push pushes the model with the name in the cache to the remote registry.
|
||||||
func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *PushParams) error {
|
func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *PushParams) error {
|
||||||
if p == nil {
|
if p == nil {
|
||||||
@@ -337,7 +306,7 @@ func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *
|
|||||||
|
|
||||||
t := traceFromContext(ctx)
|
t := traceFromContext(ctx)
|
||||||
|
|
||||||
scheme, n, _, err := parseName(name, r.NameMask)
|
scheme, n, _, err := parseName(name, r.Mask)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// This should never happen since ResolveLocal should have
|
// This should never happen since ResolveLocal should have
|
||||||
// already validated the name.
|
// already validated the name.
|
||||||
@@ -431,7 +400,7 @@ func canRetry(err error) bool {
|
|||||||
// typically slower than splitting the model up across layers, and is mostly
|
// typically slower than splitting the model up across layers, and is mostly
|
||||||
// utilized for layers of type equal to "application/vnd.ollama.image".
|
// utilized for layers of type equal to "application/vnd.ollama.image".
|
||||||
func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error {
|
func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error {
|
||||||
scheme, n, _, err := parseName(name, r.NameMask)
|
scheme, n, _, err := parseName(name, r.Mask)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -582,9 +551,9 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err
|
|||||||
// Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
|
// Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
|
||||||
// before attempting to unlink the model.
|
// before attempting to unlink the model.
|
||||||
func (r *Registry) Unlink(c *blob.DiskCache, name string) (ok bool, _ error) {
|
func (r *Registry) Unlink(c *blob.DiskCache, name string) (ok bool, _ error) {
|
||||||
_, n, _, err := parseName(name, r.NameMask)
|
n := r.completeName(name)
|
||||||
if err != nil {
|
if !n.IsFullyQualified() {
|
||||||
return false, err
|
return false, fmt.Errorf("%w: %q", ErrNameInvalid, name)
|
||||||
}
|
}
|
||||||
return c.Unlink(n.String())
|
return c.Unlink(n.String())
|
||||||
}
|
}
|
||||||
@@ -658,9 +627,9 @@ type Layer struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ResolveLocal resolves a name to a Manifest in the local cache. The name is
|
// ResolveLocal resolves a name to a Manifest in the local cache. The name is
|
||||||
// parsed using [names.ParseExtended] but the scheme is ignored.
|
// parsed using [names.Split] but the scheme is ignored.
|
||||||
func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
|
func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
|
||||||
_, n, d, err := parseName(name, r.NameMask)
|
_, n, d, err := parseName(name, r.Mask)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -686,7 +655,7 @@ func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, erro
|
|||||||
|
|
||||||
// Resolve resolves a name to a Manifest in the remote registry.
|
// Resolve resolves a name to a Manifest in the remote registry.
|
||||||
func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) {
|
func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) {
|
||||||
scheme, n, d, err := parseName(name, r.NameMask)
|
scheme, n, d, err := parseName(name, r.Mask)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -869,3 +838,69 @@ func maybeUnexpectedEOF(err error) error {
|
|||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type publicError struct {
|
||||||
|
wrapped error
|
||||||
|
message string
|
||||||
|
}
|
||||||
|
|
||||||
|
func withPublicMessagef(err error, message string, args ...any) error {
|
||||||
|
return publicError{wrapped: err, message: fmt.Sprintf(message, args...)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e publicError) Error() string { return e.message }
|
||||||
|
func (e publicError) Unwrap() error { return e.wrapped }
|
||||||
|
|
||||||
|
var supportedSchemes = []string{
|
||||||
|
"http",
|
||||||
|
"https",
|
||||||
|
"https+insecure",
|
||||||
|
}
|
||||||
|
|
||||||
|
var supportedSchemesMessage = fmt.Sprintf("supported schemes are %v", strings.Join(supportedSchemes, ", "))
|
||||||
|
|
||||||
|
// parseName parses and validates an extended name, returning the scheme, name,
|
||||||
|
// and digest.
|
||||||
|
//
|
||||||
|
// If the scheme is empty, scheme will be "https". If an unsupported scheme is
|
||||||
|
// given, [ErrNameInvalid] wrapped with a display friendly message is returned.
|
||||||
|
//
|
||||||
|
// If the digest is invalid, [ErrNameInvalid] wrapped with a display friendly
|
||||||
|
// message is returned.
|
||||||
|
//
|
||||||
|
// If the name is not, once merged with the mask, fully qualified,
|
||||||
|
// [ErrNameInvalid] wrapped with a display friendly message is returned.
|
||||||
|
func parseName(s string, mask string) (scheme string, _ names.Name, _ blob.Digest, _ error) {
|
||||||
|
scheme, name, digest := names.Split(s)
|
||||||
|
scheme = cmp.Or(scheme, "https")
|
||||||
|
if !slices.Contains(supportedSchemes, scheme) {
|
||||||
|
err := withPublicMessagef(ErrNameInvalid, "unsupported scheme: %q: %s", scheme, supportedSchemesMessage)
|
||||||
|
return "", names.Name{}, blob.Digest{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var d blob.Digest
|
||||||
|
if digest != "" {
|
||||||
|
var err error
|
||||||
|
d, err = blob.ParseDigest(digest)
|
||||||
|
if err != nil {
|
||||||
|
err = withPublicMessagef(ErrNameInvalid, "invalid digest: %q", digest)
|
||||||
|
return "", names.Name{}, blob.Digest{}, err
|
||||||
|
}
|
||||||
|
if name == "" {
|
||||||
|
// We have can resolve a manifest from a digest only,
|
||||||
|
// so skip name validation and return the scheme and
|
||||||
|
// digest.
|
||||||
|
return scheme, names.Name{}, d, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
maskName := defaultMask
|
||||||
|
if mask != "" {
|
||||||
|
maskName = names.Parse(mask)
|
||||||
|
}
|
||||||
|
n := names.Merge(names.Parse(name), maskName)
|
||||||
|
if !n.IsFullyQualified() {
|
||||||
|
return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s)
|
||||||
|
}
|
||||||
|
return scheme, n, d, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -84,14 +84,14 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
rc := &Registry{
|
r := &Registry{
|
||||||
HTTPClient: &http.Client{
|
HTTPClient: &http.Client{
|
||||||
Transport: recordRoundTripper(h),
|
Transport: recordRoundTripper(h),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
link := func(name string, manifest string) {
|
link := func(name string, manifest string) {
|
||||||
_, n, _, err := parseName(name, rc.NameMask)
|
_, n, _, err := parseName(name, r.Mask)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
@@ -122,7 +122,7 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
|
|||||||
commit("sizemismatch", mklayer("exists"), &Layer{Digest: blob.DigestFromBytes("present"), Size: 499})
|
commit("sizemismatch", mklayer("exists"), &Layer{Digest: blob.DigestFromBytes("present"), Size: 499})
|
||||||
link("invalid", "!!!!!")
|
link("invalid", "!!!!!")
|
||||||
|
|
||||||
return rc, c
|
return r, c
|
||||||
}
|
}
|
||||||
|
|
||||||
func okHandler(w http.ResponseWriter, r *http.Request) {
|
func okHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -145,29 +145,6 @@ func importBytes(t *testing.T, c *blob.DiskCache, data string) blob.Digest {
|
|||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRegistryPushInvalidNames(t *testing.T) {
|
|
||||||
rc, c := newClient(t, nil)
|
|
||||||
|
|
||||||
cases := []struct {
|
|
||||||
name string
|
|
||||||
err error
|
|
||||||
}{
|
|
||||||
{"", ErrNameInvalid},
|
|
||||||
{"@", ErrNameInvalid},
|
|
||||||
{"@x", blob.ErrInvalidDigest},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range cases {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// Create a new registry and push a new image.
|
|
||||||
err := rc.Push(t.Context(), c, tt.name, nil)
|
|
||||||
if !errors.Is(err, tt.err) {
|
|
||||||
t.Errorf("err = %v; want %v", err, tt.err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) {
|
func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) {
|
||||||
t := &Trace{Update: func(*Layer, int64, error) { panic("unexpected") }}
|
t := &Trace{Update: func(*Layer, int64, error) { panic("unexpected") }}
|
||||||
return WithTrace(ctx, t), t
|
return WithTrace(ctx, t), t
|
||||||
@@ -622,7 +599,7 @@ func TestInsecureSkipVerify(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
const name = "ollama.com/library/insecure"
|
const name = "library/insecure"
|
||||||
|
|
||||||
var rc Registry
|
var rc Registry
|
||||||
url := fmt.Sprintf("https://%s/%s", s.Listener.Addr(), name)
|
url := fmt.Sprintf("https://%s/%s", s.Listener.Addr(), name)
|
||||||
@@ -724,3 +701,38 @@ func TestErrorUnmarshal(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestParseNameErrors tests that parseName returns errors messages with enough
|
||||||
|
// detail for users to debug naming issues they may encounter. Previous to this
|
||||||
|
// test, the error messages were not very helpful and each problem was reported
|
||||||
|
// as the same message.
|
||||||
|
//
|
||||||
|
// It is only for testing error messages, not that all invalids and valids are
|
||||||
|
// covered. Those are in other tests for names.Name and blob.Digest.
|
||||||
|
func TestParseNameErrors(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
err error
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"x", nil, ""},
|
||||||
|
{"x@", nil, ""},
|
||||||
|
|
||||||
|
{"", ErrNameInvalid, `invalid or missing name: ""`},
|
||||||
|
{"://", ErrNameInvalid, `invalid or missing name: "://"`},
|
||||||
|
{"x://", ErrNameInvalid, `unsupported scheme: "x": supported schemes are http, https, https+insecure`},
|
||||||
|
|
||||||
|
{"@sha123-1234", ErrNameInvalid, `invalid digest: "sha123-1234"`},
|
||||||
|
{"x@sha123-1234", ErrNameInvalid, `invalid digest: "sha123-1234"`},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
_, _, _, err := parseName(tt.name, DefaultMask)
|
||||||
|
if !errors.Is(err, tt.err) {
|
||||||
|
t.Errorf("[%s]: err = %v; want %v", tt.name, err, tt.err)
|
||||||
|
}
|
||||||
|
if err != nil && !strings.Contains(err.Error(), tt.want) {
|
||||||
|
t.Errorf("[%s]: err =\n\t%v\nwant\n\t%v", tt.name, err, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import (
|
|||||||
"github.com/ollama/ollama/server/internal/internal/stringsx"
|
"github.com/ollama/ollama/server/internal/internal/stringsx"
|
||||||
)
|
)
|
||||||
|
|
||||||
const MaxNameLength = 50 + 1 + 50 + 1 + 50 // <namespace>/<model>:<tag>
|
const MaxNameLength = 350 + 1 + 80 + 1 + 80 + 1 + 80 // <host>/<namespace>/<model>:<tag>
|
||||||
|
|
||||||
type Name struct {
|
type Name struct {
|
||||||
// Make incomparable to enfoce use of Compare / Equal for
|
// Make incomparable to enfoce use of Compare / Equal for
|
||||||
@@ -25,19 +25,12 @@ type Name struct {
|
|||||||
// format of a valid name string is:
|
// format of a valid name string is:
|
||||||
//
|
//
|
||||||
// s:
|
// s:
|
||||||
// { host } "/" { namespace } "/" { model } ":" { tag } "@" { digest }
|
|
||||||
// { host } "/" { namespace } "/" { model } ":" { tag }
|
// { host } "/" { namespace } "/" { model } ":" { tag }
|
||||||
// { host } "/" { namespace } "/" { model } "@" { digest }
|
|
||||||
// { host } "/" { namespace } "/" { model }
|
// { host } "/" { namespace } "/" { model }
|
||||||
// { namespace } "/" { model } ":" { tag } "@" { digest }
|
|
||||||
// { namespace } "/" { model } ":" { tag }
|
// { namespace } "/" { model } ":" { tag }
|
||||||
// { namespace } "/" { model } "@" { digest }
|
|
||||||
// { namespace } "/" { model }
|
// { namespace } "/" { model }
|
||||||
// { model } ":" { tag } "@" { digest }
|
|
||||||
// { model } ":" { tag }
|
// { model } ":" { tag }
|
||||||
// { model } "@" { digest }
|
|
||||||
// { model }
|
// { model }
|
||||||
// "@" { digest }
|
|
||||||
// host:
|
// host:
|
||||||
// pattern: { alphanum | "_" } { alphanum | "_" | "-" | "." | ":" }*
|
// pattern: { alphanum | "_" } { alphanum | "_" | "-" | "." | ":" }*
|
||||||
// length: [1, 350]
|
// length: [1, 350]
|
||||||
@@ -50,9 +43,6 @@ type Name struct {
|
|||||||
// tag:
|
// tag:
|
||||||
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
|
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
|
||||||
// length: [1, 80]
|
// length: [1, 80]
|
||||||
// digest:
|
|
||||||
// pattern: { alphanum | "_" } { alphanum | "-" | ":" }*
|
|
||||||
// length: [1, 80]
|
|
||||||
//
|
//
|
||||||
// The name returned is not guaranteed to be valid. If it is not valid, the
|
// The name returned is not guaranteed to be valid. If it is not valid, the
|
||||||
// field values are left in an undefined state. Use [Name.IsValid] to check
|
// field values are left in an undefined state. Use [Name.IsValid] to check
|
||||||
@@ -82,23 +72,17 @@ func Parse(s string) Name {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseExtended parses and returns any scheme, Name, and digest from from s in
|
// Split splits an extended name string into its scheme, name, and digest
|
||||||
// the the form [scheme://][name][@digest]. All parts are optional.
|
// parts.
|
||||||
//
|
|
||||||
// If the scheme is present, it must be followed by "://". The digest is
|
|
||||||
// prefixed by "@" and comes after the name. The name is parsed using [Parse].
|
|
||||||
//
|
|
||||||
// The scheme and digest are stripped before the name is parsed by [Parse].
|
|
||||||
//
|
|
||||||
// For convience, the scheme is never empty. If the scheme is not present, the
|
|
||||||
// returned scheme is "https".
|
|
||||||
//
|
//
|
||||||
// Examples:
|
// Examples:
|
||||||
//
|
//
|
||||||
// http://ollama.com/bmizerany/smol:latest@digest
|
// http://ollama.com/bmizerany/smol:latest@digest
|
||||||
// https://ollama.com/bmizerany/smol:latest
|
// https://ollama.com/bmizerany/smol:latest
|
||||||
// ollama.com/bmizerany/smol:latest@digest // returns "https" scheme.
|
// ollama.com/bmizerany/smol:latest@digest // returns "https" scheme.
|
||||||
func ParseExtended(s string) (scheme string, _ Name, digest string) {
|
// model@digest
|
||||||
|
// @digest
|
||||||
|
func Split(s string) (scheme, name, digest string) {
|
||||||
i := strings.Index(s, "://")
|
i := strings.Index(s, "://")
|
||||||
if i >= 0 {
|
if i >= 0 {
|
||||||
scheme = s[:i]
|
scheme = s[:i]
|
||||||
@@ -109,21 +93,7 @@ func ParseExtended(s string) (scheme string, _ Name, digest string) {
|
|||||||
digest = s[i+1:]
|
digest = s[i+1:]
|
||||||
s = s[:i]
|
s = s[:i]
|
||||||
}
|
}
|
||||||
return scheme, Parse(s), digest
|
return scheme, s, digest
|
||||||
}
|
|
||||||
|
|
||||||
func FormatExtended(scheme string, n Name, digest string) string {
|
|
||||||
var b strings.Builder
|
|
||||||
if scheme != "" {
|
|
||||||
b.WriteString(scheme)
|
|
||||||
b.WriteString("://")
|
|
||||||
}
|
|
||||||
b.WriteString(n.String())
|
|
||||||
if digest != "" {
|
|
||||||
b.WriteByte('@')
|
|
||||||
b.WriteString(digest)
|
|
||||||
}
|
|
||||||
return b.String()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Merge merges two names into a single name. Non-empty host, namespace, and
|
// Merge merges two names into a single name. Non-empty host, namespace, and
|
||||||
@@ -141,39 +111,68 @@ func Merge(a, b Name) Name {
|
|||||||
|
|
||||||
// IsValid returns true if the name is valid.
|
// IsValid returns true if the name is valid.
|
||||||
func (n Name) IsValid() bool {
|
func (n Name) IsValid() bool {
|
||||||
if n.h != "" && !isValidHost(n.h) {
|
if n.h != "" && !isValidPart(partHost, n.h) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if n.n != "" && !isValidNamespace(n.n) {
|
if n.n != "" && !isValidPart(partNamespace, n.n) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if n.m != "" && !isValidModel(n.m) {
|
if n.t != "" && !isValidPart(partTag, n.t) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if n.t != "" && !isValidTag(n.t) {
|
|
||||||
return false
|
// at bare minimum, model must be present and valid
|
||||||
}
|
return n.m != "" && isValidPart(partModel, n.m)
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n Name) IsFullyQualified() bool {
|
func (n Name) IsFullyQualified() bool {
|
||||||
return n.IsValid() && n.h != "" && n.n != "" && n.m != "" && n.t != ""
|
return n.IsValid() && n.h != "" && n.n != "" && n.m != "" && n.t != ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func isValidHost(_ string) bool {
|
const (
|
||||||
return true // TODO: implement
|
partHost = iota
|
||||||
|
partNamespace
|
||||||
|
partModel
|
||||||
|
partTag
|
||||||
|
)
|
||||||
|
|
||||||
|
func isValidPart(kind int, s string) bool {
|
||||||
|
maxlen := 80
|
||||||
|
if kind == partHost {
|
||||||
|
maxlen = 350
|
||||||
|
}
|
||||||
|
if len(s) > maxlen {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range s {
|
||||||
|
if i == 0 {
|
||||||
|
if !isAlphanumericOrUnderscore(s[i]) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch s[i] {
|
||||||
|
case '_', '-':
|
||||||
|
case '.':
|
||||||
|
if kind == partNamespace {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
case ':':
|
||||||
|
if kind != partHost {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if !isAlphanumericOrUnderscore(s[i]) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func isValidNamespace(_ string) bool {
|
func isAlphanumericOrUnderscore(c byte) bool {
|
||||||
return true // TODO: implement
|
return c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' || c >= '0' && c <= '9' || c == '_'
|
||||||
}
|
|
||||||
|
|
||||||
func isValidModel(_ string) bool {
|
|
||||||
return true // TODO: implement
|
|
||||||
}
|
|
||||||
|
|
||||||
func isValidTag(_ string) bool {
|
|
||||||
return true // TODO: implement
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n Name) Host() string { return n.h }
|
func (n Name) Host() string { return n.h }
|
||||||
|
|||||||
@@ -81,15 +81,11 @@ func TestParseExtended(t *testing.T) {
|
|||||||
}
|
}
|
||||||
for _, tt := range cases {
|
for _, tt := range cases {
|
||||||
t.Run(tt.in, func(t *testing.T) {
|
t.Run(tt.in, func(t *testing.T) {
|
||||||
scheme, name, digest := ParseExtended(tt.in)
|
scheme, name, digest := Split(tt.in)
|
||||||
if scheme != tt.wantScheme || name.Compare(tt.wantName) != 0 || digest != tt.wantDigest {
|
n := Parse(name)
|
||||||
|
if scheme != tt.wantScheme || n.Compare(tt.wantName) != 0 || digest != tt.wantDigest {
|
||||||
t.Errorf("ParseExtended(%q) = %q, %#v, %q, want %q, %#v, %q", tt.in, scheme, name, digest, tt.wantScheme, tt.wantName, tt.wantDigest)
|
t.Errorf("ParseExtended(%q) = %q, %#v, %q, want %q, %#v, %q", tt.in, scheme, name, digest, tt.wantScheme, tt.wantName, tt.wantDigest)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Round trip
|
|
||||||
if got := FormatExtended(scheme, name, digest); got != tt.in {
|
|
||||||
t.Errorf("FormatExtended(%q, %q, %q) = %q", scheme, name, digest, got)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -150,3 +146,75 @@ func BenchmarkParseName(b *testing.B) {
|
|||||||
junkName = Parse("h/n/m:t")
|
junkName = Parse("h/n/m:t")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
part80 = "88888888888888888888888888888888888888888888888888888888888888888888888888888888"
|
||||||
|
part350 = "33333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333"
|
||||||
|
)
|
||||||
|
|
||||||
|
var testCases = map[string]bool{ // name -> valid
|
||||||
|
"": false,
|
||||||
|
|
||||||
|
"_why/_the/_lucky:_stiff": true,
|
||||||
|
|
||||||
|
// minimal
|
||||||
|
"h/n/m:t": true,
|
||||||
|
|
||||||
|
"host/namespace/model:tag": true,
|
||||||
|
"host/namespace/model": true,
|
||||||
|
"namespace/model": true,
|
||||||
|
"model": true,
|
||||||
|
|
||||||
|
// long (but valid)
|
||||||
|
part80 + "/" + part80 + "/" + part80 + ":" + part80: true,
|
||||||
|
part350 + "/" + part80 + "/" + part80 + ":" + part80: true,
|
||||||
|
|
||||||
|
// too long
|
||||||
|
part80 + "/" + part80 + "/" + part80 + ":" + part350: false,
|
||||||
|
"x" + part350 + "/" + part80 + "/" + part80 + ":" + part80: false,
|
||||||
|
|
||||||
|
"h/nn/mm:t": true, // bare minimum part sizes
|
||||||
|
|
||||||
|
// unqualified
|
||||||
|
"m": true,
|
||||||
|
"n/m:": true,
|
||||||
|
"h/n/m": true,
|
||||||
|
"@t": false,
|
||||||
|
"m@d": false,
|
||||||
|
|
||||||
|
// invalids
|
||||||
|
"^": false,
|
||||||
|
"mm:": true,
|
||||||
|
"/nn/mm": true,
|
||||||
|
"//": false, // empty model
|
||||||
|
"//mm": true,
|
||||||
|
"hh//": false, // empty model
|
||||||
|
"//mm:@": false,
|
||||||
|
"00@": false,
|
||||||
|
"@": false,
|
||||||
|
|
||||||
|
// not starting with alphanum
|
||||||
|
"-hh/nn/mm:tt": false,
|
||||||
|
"hh/-nn/mm:tt": false,
|
||||||
|
"hh/nn/-mm:tt": false,
|
||||||
|
"hh/nn/mm:-tt": false,
|
||||||
|
|
||||||
|
// smells like a flag
|
||||||
|
"-h": false,
|
||||||
|
|
||||||
|
// hosts
|
||||||
|
"host:https/namespace/model:tag": true,
|
||||||
|
|
||||||
|
// colon in non-host part before tag
|
||||||
|
"host/name:space/model:tag": false,
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseNameValidation(t *testing.T) {
|
||||||
|
for s, valid := range testCases {
|
||||||
|
got := Parse(s)
|
||||||
|
if got.IsValid() != valid {
|
||||||
|
t.Logf("got: %v", got)
|
||||||
|
t.Errorf("Parse(%q).IsValid() = %v; want !%[2]v", s, got.IsValid())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -204,7 +204,7 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if !ok {
|
if !ok {
|
||||||
return &serverError{404, "manifest_not_found", "manifest not found"}
|
return &serverError{404, "not_found", "model not found"}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -109,11 +109,8 @@ func TestServerDelete(t *testing.T) {
|
|||||||
got = s.send(t, "DELETE", "/api/delete", ``)
|
got = s.send(t, "DELETE", "/api/delete", ``)
|
||||||
checkErrorResponse(t, got, 400, "bad_request", "empty request body")
|
checkErrorResponse(t, got, 400, "bad_request", "empty request body")
|
||||||
|
|
||||||
got = s.send(t, "DELETE", "/api/delete", `{"model": "!"}`)
|
|
||||||
checkErrorResponse(t, got, 404, "manifest_not_found", "not found")
|
|
||||||
|
|
||||||
got = s.send(t, "DELETE", "/api/delete", `{"model": "://"}`)
|
got = s.send(t, "DELETE", "/api/delete", `{"model": "://"}`)
|
||||||
checkErrorResponse(t, got, 400, "bad_request", "invalid name")
|
checkErrorResponse(t, got, 400, "bad_request", "invalid or missing name")
|
||||||
|
|
||||||
got = s.send(t, "DELETE", "/unknown_path", `{}`) // valid body
|
got = s.send(t, "DELETE", "/unknown_path", `{}`) // valid body
|
||||||
checkErrorResponse(t, got, 404, "not_found", "not found")
|
checkErrorResponse(t, got, 404, "not_found", "not found")
|
||||||
|
|||||||
Reference in New Issue
Block a user