convert: slice gate_up weight

This commit is contained in:
Michael Yang
2025-10-06 16:05:38 -07:00
committed by Michael Yang
parent c00fa9cc2b
commit 93085127f4

View File

@@ -108,17 +108,26 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
for name, mxfp4 := range mxfp4s { for name, mxfp4 := range mxfp4s {
dims := mxfp4.blocks.Shape() dims := mxfp4.blocks.Shape()
if strings.Contains(name, "ffn_down_exps") {
if !strings.HasSuffix(name, ".weight") { out = append(out, &ggml.Tensor{
name += ".weight" Name: name + ".weight",
Kind: uint32(ggml.TensorTypeMXFP4),
Shape: []uint64{dims[0], dims[1], dims[2] * dims[3] * 2},
WriterTo: mxfp4,
})
} else if strings.Contains(name, "ffn_gate_up_exps") {
out = append(out, &ggml.Tensor{
Name: strings.Replace(name, "gate_up", "gate", 1) + ".weight",
Kind: uint32(ggml.TensorTypeMXFP4),
Shape: []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2},
WriterTo: mxfp4.slice(1, 0, int(dims[1]), 2),
}, &ggml.Tensor{
Name: strings.Replace(name, "gate_up", "up", 1) + ".weight",
Kind: uint32(ggml.TensorTypeMXFP4),
Shape: []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2},
WriterTo: mxfp4.slice(1, 1, int(dims[1]), 2),
})
} }
out = append(out, &ggml.Tensor{
Name: name,
Kind: uint32(ggml.TensorTypeMXFP4),
Shape: []uint64{dims[0], dims[1], dims[2] * dims[3] * 2},
WriterTo: mxfp4,
})
} }
return out return out
@@ -169,9 +178,21 @@ func (m *gptossModel) Replacements() []string {
} }
// mxfp4 bundles a blocks tensor and a scales tensor. Tensors uses values of
// this type as the WriterTo of tensors emitted with kind ggml.TensorTypeMXFP4.
type mxfp4 struct { type mxfp4 struct {
// slices is optional; when it is non-empty, WriteTo passes it to out.Slice
// before materializing the output. The slice method is what populates it.
slices []tensor.Slice
// blocks and scales are the two source tensors; the slice method copies both
// references into the value it returns without modifying them.
blocks, scales Tensor blocks, scales Tensor
} }
// slice returns a new mxfp4 that refers to the same blocks and scales as m and
// records a slice specification for WriteTo to apply later.
//
// The specification has one entry per dimension of m.blocks.Shape(). Every
// entry is nil except the one at index dim, which is set to
// tensor.S(start, end, step). The receiver itself is not modified. dim must be
// a valid index into the shape; otherwise the assignment below panics.
func (m *mxfp4) slice(dim, start, end, step int) *mxfp4 {
	rank := len(m.blocks.Shape())

	// A freshly made []tensor.Slice holds rank nil entries.
	ranges := make([]tensor.Slice, rank)
	ranges[dim] = tensor.S(start, end, step)

	sliced := &mxfp4{blocks: m.blocks, scales: m.scales}
	sliced.slices = ranges
	return sliced
}
func (m *mxfp4) WriteTo(w io.Writer) (int64, error) { func (m *mxfp4) WriteTo(w io.Writer) (int64, error) {
var b bytes.Buffer var b bytes.Buffer
if _, err := m.blocks.WriteTo(&b); err != nil { if _, err := m.blocks.WriteTo(&b); err != nil {
@@ -215,6 +236,13 @@ func (m *mxfp4) WriteTo(w io.Writer) (int64, error) {
return 0, err return 0, err
} }
if len(m.slices) > 0 {
out, err = out.Slice(m.slices...)
if err != nil {
return 0, err
}
}
out = tensor.Materialize(out) out = tensor.Materialize(out)
if err := out.Reshape(out.Shape().TotalSize()); err != nil { if err := out.Reshape(out.Shape().TotalSize()); err != nil {