Add deepseek v3.1 (#13063)

* Add mla for flash attention
* Revert to using chunks
This commit is contained in:
Grace
2025-11-17 18:03:21 -08:00
committed by GitHub
parent 1fd4cb87b2
commit 584e2d646f
4 changed files with 67 additions and 24 deletions

View File

@@ -1625,7 +1625,7 @@ func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
}
}
func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, scale float64) ml.Tensor {
func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, vmla ml.Tensor, scale float64) ml.Tensor {
var kqMask *C.struct_ggml_tensor
if mask != nil {
kqMask = mask.(*Tensor).t
@@ -1642,6 +1642,16 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin
C.ggml_flash_attn_ext_add_sinks(kqv, sinks.(*Tensor).t)
}
C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)
if vmla != nil {
var cur ml.Tensor = &Tensor{b: t.b, t: kqv}
cur = cur.Permute(ctx, 0, 2, 1, 3)
cur = vmla.Mulmat(ctx, cur)
cur = cur.Permute(ctx, 0, 2, 1, 3)
cur = cur.Contiguous(ctx)
kqv = cur.(*Tensor).t
}
return &Tensor{b: t.b, t: kqv}
} else {
kq := key.MulmatFullPrec(ctx, query)
@@ -1654,6 +1664,10 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin
}
kqv := value.Mulmat(ctx, kq)
if vmla != nil {
kqv = vmla.Mulmat(ctx, kqv)
}
return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
}
}