model: conversion and hyperparameter fixes for ministral and devstral (#13424)

This commit is contained in:
Jeffrey Morgan
2025-12-11 13:04:00 -08:00
committed by GitHub
parent 1c4e85b4df
commit a838421ea3
4 changed files with 250 additions and 12 deletions

View File

@@ -8,6 +8,7 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model/input"
)
@@ -17,10 +18,41 @@ type TextOptions struct {
eps, ropeBase, ropeScale float32
ropeOrigPosEmbeddings int
ropeScalingBeta float32
ropeType string
ropeExtrapolation float32
ropeBetaFast float32
ropeBetaSlow float32
ropeMscale float32
ropeMscaleAllDim float32
}
func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale)
var ropeOpts []func(*rope.Options)
if o.ropeType == "yarn" {
getMscale := func(scale, mscale float64) float64 {
if scale <= 1.0 {
return 1.0
}
return 0.1*mscale*math.Log(scale) + 1.0
}
var attnFactor float32
if o.ropeMscale != 0 && o.ropeMscaleAllDim != 0 {
attnFactor = float32(getMscale(float64(o.ropeScale), float64(o.ropeMscale)) / getMscale(float64(o.ropeScale), float64(o.ropeMscaleAllDim)))
} else {
attnFactor = float32(getMscale(float64(o.ropeScale), 1))
}
ropeOpts = append(ropeOpts,
rope.WithOriginalContextLength(o.ropeOrigPosEmbeddings),
rope.WithExtrapolationFactor(o.ropeExtrapolation),
rope.WithAttentionFactor(attnFactor),
rope.WithBetaFast(o.ropeBetaFast),
rope.WithBetaSlow(o.ropeBetaSlow),
)
}
return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale, ropeOpts...)
}
type TextModel struct {
@@ -150,9 +182,15 @@ func newTextModel(c fs.Config) *TextModel {
ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1),
ropeScale: c.Float("rope.scaling.factor", 1.0),
ropeOrigPosEmbeddings: int(c.Uint("rope.scaling.original_context_length")),
ropeScalingBeta: c.Float("rope.scaling_beta"),
ropeScalingBeta: c.Float("rope.scaling_beta", 0.1),
ropeBetaFast: c.Float("rope.scaling.beta_fast", 32.0),
ropeBetaSlow: c.Float("rope.scaling.beta_slow", 1.0),
ropeType: c.String("rope.scaling.type"),
ropeMscale: c.Float("rope.scaling.mscale"),
ropeMscaleAllDim: c.Float("rope.scaling.mscale_all_dim"),
ropeExtrapolation: c.Float("rope.scaling.extrapolation_factor", 1),
},
}
}