Mirror of https://github.com/likelovewant/ollama-for-amd.git, synced 2025-12-21 22:33:56 +00:00
use nn.Linear in place of ml.Tensor (#11049)
While nn.Linear.Forward isn't applicable to the sparse MLP, it's still a nice container for the tensors.
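As a rough illustration of the pattern this change adopts, the sketch below uses hypothetical stand-in types, not the repository's real ml and nn packages, keeping only the names the diff relies on (Weight, MulmatID). It shows why nn.Linear can serve as a plain container here: the sparse MLP never calls Forward on it, it only reaches into the Weight tensor and applies it with MulmatID, the gathered matmul that multiplies each token's hidden state by the weights of its selected experts.

// Sketch only: hypothetical stand-in types that mirror the shapes used in the
// diff below. The real ml.Tensor, ml.Context and nn.Linear have much richer APIs.
package main

import "fmt"

type Context struct{} // stand-in for ml.Context

type Tensor struct{ name string } // stand-in for ml.Tensor

// MulmatID multiplies hiddenStates by the expert slices of t selected by ids,
// i.e. the gathered matmul a mixture-of-experts layer needs instead of a
// plain matrix multiply.
func (t Tensor) MulmatID(ctx Context, hiddenStates, ids Tensor) Tensor {
	return Tensor{name: t.name + "(" + hiddenStates.name + ")[" + ids.name + "]"}
}

// Linear is a stand-in for nn.Linear: just a container for the weight tensor
// loaded from the GGUF file.
type Linear struct {
	Weight Tensor
}

// sparseMLP mirrors the structs in the diff: the expert projections live in
// Linear containers, but Forward on those containers is never called.
type sparseMLP struct {
	Gate, Up, Down *Linear
}

// Forward reaches into each container's Weight and applies it with MulmatID,
// because a per-expert gathered matmul is not what Linear.Forward does.
// (The SILU activation and elementwise product are omitted from this sketch.)
func (m *sparseMLP) Forward(ctx Context, hiddenStates, selectedExperts Tensor) Tensor {
	up := m.Up.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
	gate := m.Gate.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
	_ = gate
	return m.Down.Weight.MulmatID(ctx, up, selectedExperts)
}

func main() {
	m := &sparseMLP{
		Gate: &Linear{Tensor{"gate"}},
		Up:   &Linear{Tensor{"up"}},
		Down: &Linear{Tensor{"down"}},
	}
	fmt.Println(m.Forward(Context{}, Tensor{"h"}, Tensor{"experts"}).name)
}

Note also that the gguf tags in the diff drop the .weight suffix once the fields become *nn.Linear; presumably nn.Linear's own Weight field supplies that suffix, so only the prefix needs to be spelled out on the struct.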
@@ -63,9 +63,9 @@ func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOp
 }
 
 type TextExperts struct {
-	Gate ml.Tensor `gguf:"ffn_gate_exps.weight"`
-	Up   ml.Tensor `gguf:"ffn_up_exps.weight"`
-	Down ml.Tensor `gguf:"ffn_down_exps.weight"`
+	Gate *nn.Linear `gguf:"ffn_gate_exps"`
+	Up   *nn.Linear `gguf:"ffn_up_exps"`
+	Down *nn.Linear `gguf:"ffn_down_exps"`
 }
 
 func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tensor, opts *TextOptions) ml.Tensor {
@@ -76,9 +76,9 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens
 	hiddenStates = hiddenStates.Repeat(ctx, 1, opts.numExpertsUsed)
 	hiddenStates = hiddenStates.Mul(ctx, scores)
 
-	upStates := e.Up.MulmatID(ctx, hiddenStates, experts)
-	gateStates := e.Gate.MulmatID(ctx, hiddenStates, experts)
-	downStates := e.Down.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)
+	upStates := e.Up.Weight.MulmatID(ctx, hiddenStates, experts)
+	gateStates := e.Gate.Weight.MulmatID(ctx, hiddenStates, experts)
+	downStates := e.Down.Weight.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)
 
 	nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2))
 	for i := 1; i < opts.numExpertsUsed; i++ {

@@ -66,9 +66,9 @@ type MLP interface {
 
 type sparse struct {
 	Router *nn.Linear `gguf:"ffn_gate_inp"`
-	Gate   ml.Tensor  `gguf:"ffn_gate_exps.weight"`
-	Up     ml.Tensor  `gguf:"ffn_up_exps.weight"`
-	Down   ml.Tensor  `gguf:"ffn_down_exps.weight"`
+	Gate   *nn.Linear `gguf:"ffn_gate_exps"`
+	Up     *nn.Linear `gguf:"ffn_up_exps"`
+	Down   *nn.Linear `gguf:"ffn_down_exps"`
 }
 
 func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
@@ -87,13 +87,13 @@ func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options
 
 	hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
 
-	upStates := mlp.Up.MulmatID(ctx, hiddenStates, selectedExperts)
+	upStates := mlp.Up.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
 
-	hiddenStates = mlp.Gate.MulmatID(ctx, hiddenStates, selectedExperts)
+	hiddenStates = mlp.Gate.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
 	hiddenStates = hiddenStates.SILU(ctx)
 	hiddenStates = hiddenStates.Mul(ctx, upStates)
 
-	experts := mlp.Down.MulmatID(ctx, hiddenStates, selectedExperts)
+	experts := mlp.Down.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
 	experts = experts.Mul(ctx, routingWeights)
 
 	nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))