fix tensor merge (#13053)

This commit is contained in:
Michael Yang
2025-11-13 15:32:34 -08:00
committed by GitHub
parent 333203d871
commit 12b174b10e
2 changed files with 66 additions and 0 deletions

View File

@@ -3,8 +3,10 @@ package convert
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"iter"
"math/rand/v2"
"slices"
"strings"
"testing"
@@ -951,3 +953,45 @@ func TestMerge(t *testing.T) {
}
})
}
func TestMergeOrder(t *testing.T) {
for range 8 {
t.Run("", func(t *testing.T) {
tensors := make([]Tensor, 16)
for i := range tensors {
tensors[i] = &fakeTensor{
name: fmt.Sprintf("layer.%d.weight", i),
shape: []uint64{1},
data: []float32{float32(i)},
}
}
rand.Shuffle(len(tensors), func(i, j int) {
tensors[i], tensors[j] = tensors[j], tensors[i]
})
matched, unmatched := mergeTensors(tensors, merge{"layer.*.weight", "layer.weight"})
if len(unmatched) != 0 {
t.Error("expected no remaining tensors, got", len(unmatched))
}
if len(matched) != 1 {
t.Error("expected 1 merged tensor, got", len(matched))
}
var b bytes.Buffer
if _, err := matched[0].WriteTo(&b); err != nil {
t.Fatal(err)
}
var f32s [16]float32
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if !slices.IsSorted(f32s[:]) {
t.Errorf("merged tensor data is not in order: %+v", f32s)
}
})
}
}