mirror of https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-21 14:26:30 +00:00
use split activations when possible (#12293)
* use ggml_*_split activations when possible
* forward qkv
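The ggml_*_split activations this commit switches to compute activation(gate) * up element-wise in a single fused graph node, the same value the previous two-step chain (activate the gate projection, then Mul by the up projection) produced. A minimal numeric sketch of that equivalence in plain Go (illustrative only, not ollama code):

package main

import (
	"fmt"
	"math"
)

// silu(x) = x * sigmoid(x), the gate activation used by SwiGLU-style MLPs.
func silu(x float64) float64 { return x / (1 + math.Exp(-x)) }

// swigluSplit mirrors the assumed semantics of the fused split op:
// activate the gate values and multiply element-wise by the up values.
func swigluSplit(gate, up []float64) []float64 {
	out := make([]float64, len(gate))
	for i := range gate {
		out[i] = silu(gate[i]) * up[i]
	}
	return out
}

func main() {
	gate := []float64{-1.0, 0.5, 2.0}
	up := []float64{0.3, -0.7, 1.2}

	// Old pattern: activate, then multiply (two element-wise passes).
	unfused := make([]float64, len(gate))
	for i := range gate {
		unfused[i] = silu(gate[i]) * up[i]
	}

	// New pattern: one fused "split" activation.
	fused := swigluSplit(gate, up)

	fmt.Println(unfused)
	fmt.Println(fused) // identical values, one fewer graph node
}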
@@ -430,12 +430,13 @@ type Tensor interface {
 	Sin(ctx Context) Tensor
 	Cos(ctx Context) Tensor
 	Tanh(ctx Context) Tensor
-	GELU(ctx Context) Tensor
-	QuickGELU(ctx Context) Tensor
-	SILU(ctx Context) Tensor
-	RELU(ctx Context) Tensor
+	GELU(ctx Context, up ...Tensor) Tensor
+	SILU(ctx Context, up ...Tensor) Tensor
+	RELU(ctx Context, up ...Tensor) Tensor
 	Sigmoid(ctx Context) Tensor
-	SwiGLU(ctx Context, up Tensor, alpha, limit float32) Tensor
+
+	// AlphaLimitSILU is a variant of SILU that clamps the input to the range [-limit, limit]
+	SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor

 	Reshape(ctx Context, shape ...int) Tensor
 	View(ctx Context, offset int, shape ...int) Tensor
@@ -1431,35 +1431,46 @@ func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
 	}
 }

-func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
+func (t *Tensor) GELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
+	if len(t2) > 0 {
+		return &Tensor{
+			b: t.b,
+			t: C.ggml_geglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
+		}
+	}
 	return &Tensor{
 		b: t.b,
 		t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
 	}
 }

-func (t *Tensor) QuickGELU(ctx ml.Context) ml.Tensor {
-	return &Tensor{
-		b: t.b,
-		t: C.ggml_gelu_quick_inplace(ctx.(*Context).ctx, t.t),
-	}
-}
-
-func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
+func (t *Tensor) SILU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
+	if len(t2) > 0 {
+		return &Tensor{
+			b: t.b,
+			t: C.ggml_swiglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
+		}
+	}
 	return &Tensor{
 		b: t.b,
 		t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
 	}
 }

-func (t *Tensor) RELU(ctx ml.Context) ml.Tensor {
+func (t *Tensor) RELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
+	if len(t2) > 0 {
+		return &Tensor{
+			b: t.b,
+			t: C.ggml_reglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
+		}
+	}
 	return &Tensor{
 		b: t.b,
 		t: C.ggml_relu_inplace(ctx.(*Context).ctx, t.t),
 	}
 }

-func (t *Tensor) SwiGLU(ctx ml.Context, up ml.Tensor, alpha, limit float32) ml.Tensor {
+func (t *Tensor) SILUAlphaLimit(ctx ml.Context, up ml.Tensor, alpha, limit float32) ml.Tensor {
 	return &Tensor{
 		b: t.b,
 		t: C.ggml_swiglu_oai(ctx.(*Context).ctx, t.t, up.(*Tensor).t, C.float(alpha), C.float(limit)),
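A note on the wrapper shape above: the optional variadic argument keeps a single method name for both paths, falling back to the plain in-place op when no up tensor is passed and emitting the split op when one is. A toy, self-contained sketch of that dispatch pattern (types and names are illustrative, not the real backend):

package main

import (
	"fmt"
	"math"
)

// Vec is a toy stand-in for a tensor; the real wrappers hold C ggml tensors.
type Vec []float64

// SILU mirrors the wrapper's shape: with no argument it applies the plain
// activation, with an up tensor it takes the fused gate*up path.
func (v Vec) SILU(up ...Vec) Vec {
	out := make(Vec, len(v))
	for i, x := range v {
		s := x / (1 + math.Exp(-x)) // silu(x)
		if len(up) > 0 {
			s *= up[0][i] // "split" path: silu(gate) * up
		}
		out[i] = s
	}
	return out
}

func main() {
	gate := Vec{-1, 0.5, 2}
	up := Vec{0.3, -0.7, 1.2}
	fmt.Println(gate.SILU())   // plain activation
	fmt.Println(gate.SILU(up)) // fused gate/up form
}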
@@ -26,6 +26,7 @@ func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache
 }

 func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
+	ctx.Forward(query)
 	if key != nil && value != nil {
 		if query.Dim(0) != key.Dim(0) {
 			panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
@@ -39,6 +40,7 @@ func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scal
 			panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
 		}

+		ctx.Forward(key, value)
 		if cache != nil {
 			cache.Put(ctx, key, value)
 		}
@@ -138,7 +138,7 @@ type MLP struct {
 }

 func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
-	hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+	hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
 	return mlp.Down.Forward(ctx, hiddenState)
 }

@@ -123,7 +123,7 @@ type TextMLP struct {
 }

 func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
-	hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+	hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
 	return mlp.Down.Forward(ctx, hiddenState)
 }

@@ -170,8 +170,7 @@ func (d TextLayer) Forward(ctx ml.Context, hiddenStates, perLayerInput, position
 	}

 	active = d.PerLayerInputGate.Forward(ctx, active)
-	active = active.GELU(ctx)
-	active = active.Mul(ctx, perLayerInput)
+	active = active.GELU(ctx, perLayerInput)

 	active = d.PerLayerProjection.Forward(ctx, active)
 	active = d.PostPerLayerNorm.Forward(ctx, active, opts.eps)
@@ -292,7 +291,7 @@ func (mlp TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, activationSpa
 		hiddenStates = hiddenStates.Sub(ctx, cutoff).RELU(ctx)
 	}

-	hiddenStates = hiddenStates.GELU(ctx).Mul(ctx, upStates)
+	hiddenStates = hiddenStates.GELU(ctx, upStates)
 	hiddenStates = mlp.Down.Forward(ctx, hiddenStates)
 	return hiddenStates
 }
@@ -210,7 +210,7 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates, one ml.Tensor, opts *
 		up = mlp.Up.Forward(ctx, hiddenStates, selectedExperts)
 	}

-	hiddenStates = gate.SwiGLU(ctx, up, 1.702, 7)
+	hiddenStates = gate.SILUAlphaLimit(ctx, up, 1.702, 7)

 	experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts)
 	experts = experts.Mul(ctx, routingWeights)
@@ -118,7 +118,7 @@ type MLP struct {
 }

 func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
-	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
 	return mlp.Down.Forward(ctx, hiddenState)
 }

@@ -58,14 +58,14 @@ type TextMLP struct {
 }

 func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
-	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
+	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
 	return mlp.Down.Forward(ctx, hiddenStates)
 }

 type TextExperts struct {
-	Gate *nn.Linear `gguf:"ffn_gate_exps"`
-	Up   *nn.Linear `gguf:"ffn_up_exps"`
-	Down *nn.Linear `gguf:"ffn_down_exps"`
+	Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
+	Up   *nn.LinearBatch `gguf:"ffn_up_exps"`
+	Down *nn.LinearBatch `gguf:"ffn_down_exps"`
 }

 func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tensor, opts *TextOptions) ml.Tensor {
@@ -76,9 +76,9 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens
 	hiddenStates = hiddenStates.Repeat(ctx, 1, opts.numExpertsUsed)
 	hiddenStates = hiddenStates.Mul(ctx, scores)

-	upStates := e.Up.Weight.MulmatID(ctx, hiddenStates, experts)
-	gateStates := e.Gate.Weight.MulmatID(ctx, hiddenStates, experts)
-	downStates := e.Down.Weight.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)
+	upStates := e.Up.Forward(ctx, hiddenStates, experts)
+	gateStates := e.Gate.Forward(ctx, hiddenStates, experts)
+	downStates := e.Down.Forward(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)

 	nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2))
 	for i := 1; i < opts.numExpertsUsed; i++ {
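The nn.LinearBatch type that replaces nn.Linear for the expert weights is not part of this diff; judging from the call-site change (e.Up.Weight.MulmatID(ctx, x, experts) becomes e.Up.Forward(ctx, x, experts)), it wraps the per-expert batched matmul behind a Forward method that takes the selected expert indices. A hypothetical minimal sketch of such a wrapper (field name and gguf tag assumed, not taken from the commit):

// Hypothetical sketch; the real nn.LinearBatch is not shown in this diff.
package nn

import "github.com/ollama/ollama/ml"

// LinearBatch behaves like Linear, but Forward also takes expert indices and
// routes the matmul through MulmatID, which the old expert code called
// directly on the raw weight tensor.
type LinearBatch struct {
	Weight ml.Tensor `gguf:"weight"`
}

func (l *LinearBatch) Forward(ctx ml.Context, t, indices ml.Tensor) ml.Tensor {
	return l.Weight.MulmatID(ctx, t, indices)
}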
@@ -96,7 +96,7 @@ type TextSharedExpert struct {
 }

 func (mlp *TextSharedExpert) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
-	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
+	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
 	return mlp.Down.Forward(ctx, hiddenStates)
 }

@@ -65,7 +65,7 @@ type MLP struct {
 }

 func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
-	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
 	return mlp.Down.Forward(ctx, hiddenState)
 }

@@ -51,7 +51,7 @@ type VisionMLP struct {
 }

 func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
-	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
+	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
 	return mlp.Down.Forward(ctx, hiddenStates)
 }

@@ -58,7 +58,7 @@ type TextMLP struct {
 }

 func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextModelOptions) ml.Tensor {
-	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
 	return mlp.Down.Forward(ctx, hiddenState)
 }

@@ -59,7 +59,7 @@ type MLP struct {
 }

 func (mlp MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
-	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
+	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
 	return mlp.Down.Forward(ctx, hiddenStates)
 }

@@ -90,7 +90,7 @@ type MLP struct {

 func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
 	// Apply SwiGLU activation gating
-	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
 	// Project back to hidden dimension
 	return mlp.Down.Forward(ctx, hiddenState)
 }
@@ -100,8 +100,7 @@ type VisionMLP struct {
 func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
 	// Using activation as specified in config (likely GELU or SiLU/Swish)
 	gateOutput := mlp.Gate.Forward(ctx, hiddenStates)
-	upOutput := mlp.Up.Forward(ctx, hiddenStates)
-	hiddenStates = gateOutput.SILU(ctx).Mul(ctx, upOutput)
+	hiddenStates = gateOutput.SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))

 	return mlp.Down.Forward(ctx, hiddenStates)
 }
@@ -30,10 +30,10 @@ func (o Options) headDim() int {
 }

 type Attention struct {
-	QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
 	Query *nn.Linear `gguf:"attn_q"`
-	KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
+	QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
 	Key *nn.Linear `gguf:"attn_k"`
+	KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
 	Value *nn.Linear `gguf:"attn_v"`
 	Output *nn.Linear `gguf:"attn_output"`
 }
@@ -65,10 +65,10 @@ type MLP interface {
 }

 type sparse struct {
 	Router *nn.Linear `gguf:"ffn_gate_inp"`
-	Gate *nn.Linear `gguf:"ffn_gate_exps"`
-	Up *nn.Linear `gguf:"ffn_up_exps"`
-	Down *nn.Linear `gguf:"ffn_down_exps"`
+	Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
+	Up *nn.LinearBatch `gguf:"ffn_up_exps"`
+	Down *nn.LinearBatch `gguf:"ffn_down_exps"`
 }

 func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
@@ -87,13 +87,9 @@ func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options

 	hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))

-	upStates := mlp.Up.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
-
-	hiddenStates = mlp.Gate.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
-	hiddenStates = hiddenStates.SILU(ctx)
-	hiddenStates = hiddenStates.Mul(ctx, upStates)
-
-	experts := mlp.Down.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
+	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates, selectedExperts).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates, selectedExperts))
+
+	experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts)
 	experts = experts.Mul(ctx, routingWeights)

 	nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
@@ -111,7 +107,8 @@ type dense struct {
 }

 func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor {
-	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
+	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).
+		SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
 	return mlp.Down.Forward(ctx, hiddenStates)
 }