diff --git a/convert/convert_gptoss.go b/convert/convert_gptoss.go index 478f29f4..c5ac0b37 100644 --- a/convert/convert_gptoss.go +++ b/convert/convert_gptoss.go @@ -108,17 +108,26 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor { for name, mxfp4 := range mxfp4s { dims := mxfp4.blocks.Shape() - - if !strings.HasSuffix(name, ".weight") { - name += ".weight" + if strings.Contains(name, "ffn_down_exps") { /* names containing ffn_down_exps are emitted unsplit as a single MXFP4 tensor with the .weight suffix appended */ + out = append(out, &ggml.Tensor{ + Name: name + ".weight", + Kind: uint32(ggml.TensorTypeMXFP4), + Shape: []uint64{dims[0], dims[1], dims[2] * dims[3] * 2}, + WriterTo: mxfp4, + }) + } else if strings.Contains(name, "ffn_gate_up_exps") { /* names containing ffn_gate_up_exps produce two tensors, gate and up, each declared with half of dims[1]; their writers slice dimension 1 with step 2 starting at 0 and at 1 respectively, which presumably means gate and up entries are interleaved in the source tensor - TODO confirm */ + out = append(out, &ggml.Tensor{ + Name: strings.Replace(name, "gate_up", "gate", 1) + ".weight", + Kind: uint32(ggml.TensorTypeMXFP4), + Shape: []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2}, + WriterTo: mxfp4.slice(1, 0, int(dims[1]), 2), + }, &ggml.Tensor{ + Name: strings.Replace(name, "gate_up", "up", 1) + ".weight", + Kind: uint32(ggml.TensorTypeMXFP4), + Shape: []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2}, + WriterTo: mxfp4.slice(1, 1, int(dims[1]), 2), + }) /* NOTE(review): the unconditional append is removed below, so a name matching neither substring now adds nothing to out - confirm that is intended */ } - - out = append(out, &ggml.Tensor{ - Name: name, - Kind: uint32(ggml.TensorTypeMXFP4), - Shape: []uint64{dims[0], dims[1], dims[2] * dims[3] * 2}, - WriterTo: mxfp4, - }) } return out @@ -169,9 +178,21 @@ func (m *gptossModel) Replacements() []string { } type mxfp4 struct { + slices []tensor.Slice /* per-dimension slices that WriteTo applies when non-empty; when empty, WriteTo skips the slicing step */ + blocks, scales Tensor } +func (m *mxfp4) slice(dim, start, end, step int) *mxfp4 { /* slice returns a new mxfp4 that shares the blocks and scales of m and records tensor.S(start, end, step) for dimension dim, with nil entries for every other dimension of m.blocks; no tensor data is read here, the recorded slices are only applied later in WriteTo */ + slice := slices.Repeat([]tensor.Slice{nil}, len(m.blocks.Shape())) + slice[dim] = tensor.S(start, end, step) + return &mxfp4{ + slices: slice, + blocks: m.blocks, + scales: m.scales, + } +} + func (m *mxfp4) WriteTo(w io.Writer) (int64, error) { var b bytes.Buffer if _, err := m.blocks.WriteTo(&b); err != nil { @@ -215,6 +236,13 @@ func (m *mxfp4) WriteTo(w io.Writer) (int64, error) { return 0, err } + if len(m.slices) > 0 { /* apply the slices recorded by the slice method before the result is materialized */ + out, err = out.Slice(m.slices...)
+ if err != nil { + return 0, err + } + } + out = tensor.Materialize(out) if err := out.Reshape(out.Shape().TotalSize()); err != nil {