convert(gptoss): mxfp4 to ggml layout to avoid jit conversion (#12018)

* convert: return bytes written

* ggml flavor mxfp4

* simplify jit conversion

* comment
Author: Michael Yang
Date: 2025-08-26 16:41:02 -07:00
Committed by: GitHub
Parent: 86834a2797
Commit: 59412fbb43
6 changed files with 49 additions and 58 deletions
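
Note: the hunks below move the MXFP4 nibble repacking from load time into the converter; both sides now share the same single-pass transform over each 16-byte block payload. As a rough illustration of the in-code comment "a1b2c3 ... x7y8z9 -> 71xa82yb93zc", here is a standalone sketch of one nibble pair (the concrete values are made up; only the bit operations mirror the diff):

package main

import "fmt"

func main() {
	// Pair 0xA1 (from the first 8-byte half) with 0xF7 (from the second half, "x" = 0xF).
	a, b := byte(0xA1), byte(0xF7)
	out0 := (a & 0x0F) | (b << 4) // 0x71: low nibble of a, low nibble of b
	out1 := (a >> 4) | (b & 0xF0) // 0xFA: high nibble of a, high nibble of b
	fmt.Printf("%#02x %#02x\n", out0, out1)
}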

@@ -172,7 +172,20 @@ func (m *mxfp4) WriteTo(w io.Writer) (int64, error) {
 		blocksDims[i] = int(d)
 	}
 
-	var blocks tensor.Tensor = tensor.New(tensor.WithShape(blocksDims...), tensor.WithBacking(b.Bytes()))
+	bts := b.Bytes()
+	var tmp [16]byte
+	for i := 0; i < b.Len(); i += 16 {
+		for j := range 8 {
+			// transform a1b2c3 ... x7y8z9 -> 71xa82yb93zc
+			a, b := bts[i+j], bts[i+j+8]
+			tmp[2*j+0] = (a & 0x0F) | (b << 4)
+			tmp[2*j+1] = (a >> 4) | (b & 0xF0)
+		}
+		copy(bts[i:i+16], tmp[:])
+	}
+
+	var blocks tensor.Tensor = tensor.New(tensor.WithShape(blocksDims...), tensor.WithBacking(bts))
 
 	var s bytes.Buffer
 	if _, err := m.scales.WriteTo(&s); err != nil {
@@ -206,5 +219,5 @@ func (m *mxfp4) WriteTo(w io.Writer) (int64, error) {
 		return 0, err
 	}
 
-	return 0, nil
+	return int64(len(u8s)), nil
 }

@@ -33,8 +33,8 @@ func (t tensorBase) Shape() []uint64 {
 const (
 	tensorKindFP32 uint32 = iota
 	tensorKindFP16
-	tensorKindMXFP4 = 4
 	tensorKindBF16 = 30
+	tensorKindMXFP4 = 39
 )
 
 func (t tensorBase) Kind() uint32 {

@@ -188,17 +188,17 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) {
 	switch st.Kind() {
 	case tensorKindFP32:
-		return 0, binary.Write(w, binary.LittleEndian, f32s)
+		return int64(len(f32s) * 4), binary.Write(w, binary.LittleEndian, f32s)
 	case tensorKindFP16:
 		f16s := make([]uint16, len(f32s))
 		for i := range f32s {
 			f16s[i] = float16.Fromfloat32(f32s[i]).Bits()
 		}
-		return 0, binary.Write(w, binary.LittleEndian, f16s)
+		return int64(len(f16s) * 2), binary.Write(w, binary.LittleEndian, f16s)
 	case tensorKindBF16:
 		u8s := bfloat16.EncodeFloat32(f32s)
-		return 0, binary.Write(w, binary.LittleEndian, u8s)
+		return int64(len(u8s)), binary.Write(w, binary.LittleEndian, u8s)
 	default:
 		return 0, fmt.Errorf("unknown storage type: %d", st.Kind())
 	}

@@ -290,24 +290,24 @@ func (t Tensor) blockSize() uint64 {
 func (t TensorType) BlockSize() uint64 {
 	switch t {
 	case
-		0,  // F32
-		1,  // F16
-		24, // I8
-		25, // I16
-		26, // I32
-		27, // I64
-		28, // F64
-		30: // BF16
+		TensorTypeF32,
+		TensorTypeF16,
+		TensorTypeI8,
+		TensorTypeI16,
+		TensorTypeI32,
+		TensorTypeI64,
+		TensorTypeF64,
+		TensorTypeBF16:
 		return 1
 	case
-		2,  // Q4_0
-		3,  // Q4_1
-		4,  // MXFP4
-		6,  // Q5_0
-		7,  // Q5_1
-		8,  // Q8_0
-		9,  // Q8_1
-		20: // IQ4_NL
+		TensorTypeQ4_0,
+		TensorTypeQ4_1,
+		TensorTypeQ5_0,
+		TensorTypeQ5_1,
+		TensorTypeQ8_0,
+		TensorTypeQ8_1,
+		tensorTypeIQ4_NL,
+		4, TensorTypeMXFP4:
 		return 32
 	default:
 		return 256
@@ -330,8 +330,6 @@ func (t TensorType) TypeSize() uint64 {
 		return 2 + blockSize/2
 	case TensorTypeQ4_1:
 		return 2 + 2 + blockSize/2
-	case TensorTypeMXFP4, 39:
-		return 1 + blockSize/2
 	case TensorTypeQ5_0:
 		return 2 + 4 + blockSize/2
 	case TensorTypeQ5_1:
@@ -382,6 +380,8 @@ func (t TensorType) TypeSize() uint64 {
 		return blockSize/8 + blockSize/16 + blockSize/32
 	case TensorTypeBF16:
 		return 2
+	case 4, TensorTypeMXFP4:
+		return 1 + blockSize/2
 	default:
 		return 0
 	}
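
For reference, the two MXFP4 cases above pin down the block geometry used by the loader further down: BlockSize() is 32 elements and TypeSize() is 1 + 32/2 = 17 bytes (one shared scale byte plus sixteen bytes of packed 4-bit values), which is exactly the `const BS = 17 // MXFP4 block size` in the backend hunk below. A trivial standalone check:

package main

import "fmt"

func main() {
	const blockSize = 32             // MXFP4 BlockSize()
	const typeSize = 1 + blockSize/2 // MXFP4 TypeSize(): scale byte + packed nibbles
	fmt.Println(typeSize)            // 17
}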

@@ -146,8 +146,6 @@ func (ftype FileType) ToTensorType() TensorType {
 		return TensorTypeQ4_0
 	case fileTypeQ4_1:
 		return TensorTypeQ4_1
-	case fileTypeMXFP4:
-		return TensorTypeMXFP4 // Formerly unused tensorTypeQ4_2
 	case FileTypeQ8_0:
 		return TensorTypeQ8_0
 	case fileTypeQ5_0:
@@ -176,6 +174,8 @@ func (ftype FileType) ToTensorType() TensorType {
 		return TensorTypeQ2_K
 	case FileTypeBF16:
 		return TensorTypeBF16
+	case fileTypeMXFP4:
+		return TensorTypeMXFP4
 	default:
 		slog.Warn("unsupported file type", "type", ftype)
 		return 0 // F32
@@ -191,7 +191,7 @@ const (
 	TensorTypeF16
 	TensorTypeQ4_0
 	TensorTypeQ4_1
-	TensorTypeMXFP4 // Formerly unused tensorTypeQ4_2
+	tensorTypeQ4_2
 	tensorTypeQ4_3 // unused by GGML
 	TensorTypeQ5_0
 	TensorTypeQ5_1
@@ -226,6 +226,7 @@ const (
 	tensorTypeIQ4_NL_4_4 // unused by GGML
 	tensorTypeIQ4_NL_4_8 // unused by GGML
 	tensorTypeIQ4_NL_8_8 // unused by GGML
+	TensorTypeMXFP4
 )
 
 // ParseFileType parses the provided GGUF file type
@@ -318,7 +319,7 @@ func (t TensorType) String() string {
 		return "F64"
 	case TensorTypeBF16:
 		return "BF16"
-	case TensorTypeMXFP4:
+	case 4, TensorTypeMXFP4:
 		return "MXFP4"
 	default:
 		return "unknown"

@@ -535,6 +535,7 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
 			const BS = 17 // MXFP4 block size
 			bts := make([]byte, 8*BS*format.KibiByte) // ~128k block aligned
 			var s uint64
+			var tmp [16]byte
 			for s < t.Size() {
 				// Stop if either the parent context has been canceled or if any of the other tensors returned an error
 				if err := ctx.Err(); err != nil {
@@ -546,37 +547,13 @@
 					return err
 				}
 
 				for j := range n / BS {
-					for i := 1; i < BS; i++ {
-						// swap nibbles
-						t_lo := bts[j*BS+i] & 0x0F
-						t_hi := bts[j*BS+i] & 0xF0
-						bts[j*BS+i] = (t_lo << 4) | (t_hi >> 4)
-					}
-					// transform aaaa...bbbb... to abababab...
-					oi := 0
-					tmp := [16]byte{}
 					for i := 1; i < 9; i++ {
-						blk_a0 := bts[j*BS+i] & 0xF0
-						blk_a1 := bts[j*BS+i] << 4
-						blk_b0 := bts[j*BS+i+8] >> 4
-						blk_b1 := bts[j*BS+i+8] & 0x0F
-						// swap once more
-						out0 := blk_a0 | blk_b0
-						out1 := blk_a1 | blk_b1
-						out_h0 := out0 & 0xF0
-						out_l0 := out0 & 0x0F
-						out_h1 := out1 & 0xF0
-						out_l1 := out1 & 0x0F
-						out0 = (out_h0 >> 4) | (out_l0 << 4)
-						out1 = (out_h1 >> 4) | (out_l1 << 4)
-						tmp[oi] = out0
-						oi++
-						tmp[oi] = out1
-						oi++
-					}
-					for i := range tmp {
-						bts[j*BS+i+1] = tmp[i]
+						// transform a1b2c3 ... x7y8z9 -> 71xa82yb93zc
+						a, b := bts[j*BS+i], bts[j*BS+i+8]
+						tmp[2*(i-1)] = (a & 0x0F) | (b << 4)
+						tmp[2*(i-1)+1] = (a >> 4) | (b & 0xF0)
 					}
+					copy(bts[j*BS+1:j*BS+17], tmp[:])
 				}
 				for _, tt := range tts {
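
For what it's worth, the removed three-step loop (swap nibbles, interleave the two 8-byte halves, swap again) and the new single-pass loop produce identical bytes, which is what makes the simplification safe. A standalone sketch that checks this on a sample 16-byte payload (function names and input values are illustrative, not part of the change):

package main

import (
	"bytes"
	"fmt"
)

// oldRepack mirrors the removed load-time conversion: swap nibbles,
// interleave the two 8-byte halves, then swap nibbles once more.
func oldRepack(in []byte) []byte {
	b := append([]byte(nil), in...)
	for i := range b {
		b[i] = b[i]<<4 | b[i]>>4 // swap nibbles
	}
	out := make([]byte, 16)
	for i := 0; i < 8; i++ {
		a0, a1 := b[i]&0xF0, b[i]<<4
		b0, b1 := b[i+8]>>4, b[i+8]&0x0F
		o0, o1 := a0|b0, a1|b1
		out[2*i] = o0<<4 | o0>>4 // swap once more
		out[2*i+1] = o1<<4 | o1>>4
	}
	return out
}

// newRepack mirrors the single-pass transform now shared by the converter and the loader.
func newRepack(in []byte) []byte {
	out := make([]byte, 16)
	for j := 0; j < 8; j++ {
		a, b := in[j], in[j+8]
		out[2*j] = a&0x0F | b<<4
		out[2*j+1] = a>>4 | b&0xF0
	}
	return out
}

func main() {
	in := []byte{
		0xA1, 0xB2, 0xC3, 0xD4, 0xE5, 0xF6, 0x07, 0x18,
		0x29, 0x3A, 0x4B, 0x5C, 0x6D, 0x7E, 0x8F, 0x90,
	}
	fmt.Println(bytes.Equal(oldRepack(in), newRepack(in))) // true
}

Running it prints true: in both versions each output byte ends up as the low nibbles of the paired bytes (b low, a low) or their high nibbles (b high, a high).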