mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-21 14:26:30 +00:00
fix tensor merge (#13053)
This commit is contained in:
@@ -3,8 +3,10 @@ package convert
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"iter"
|
||||
"math/rand/v2"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -951,3 +953,45 @@ func TestMerge(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMergeOrder(t *testing.T) {
|
||||
for range 8 {
|
||||
t.Run("", func(t *testing.T) {
|
||||
tensors := make([]Tensor, 16)
|
||||
for i := range tensors {
|
||||
tensors[i] = &fakeTensor{
|
||||
name: fmt.Sprintf("layer.%d.weight", i),
|
||||
shape: []uint64{1},
|
||||
data: []float32{float32(i)},
|
||||
}
|
||||
}
|
||||
|
||||
rand.Shuffle(len(tensors), func(i, j int) {
|
||||
tensors[i], tensors[j] = tensors[j], tensors[i]
|
||||
})
|
||||
|
||||
matched, unmatched := mergeTensors(tensors, merge{"layer.*.weight", "layer.weight"})
|
||||
if len(unmatched) != 0 {
|
||||
t.Error("expected no remaining tensors, got", len(unmatched))
|
||||
}
|
||||
|
||||
if len(matched) != 1 {
|
||||
t.Error("expected 1 merged tensor, got", len(matched))
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := matched[0].WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var f32s [16]float32
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !slices.IsSorted(f32s[:]) {
|
||||
t.Errorf("merged tensor data is not in order: %+v", f32s)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user