feat(model): add qwen3vl (#12665)

This commit is contained in:
Michael Yang
2025-10-28 17:39:47 -07:00
committed by GitHub
parent 36d64fb531
commit 7d25b9e194
22 changed files with 1502 additions and 35 deletions

View File

@@ -198,6 +198,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
conv = &qwen2Model{}
case "Qwen2_5_VLForConditionalGeneration":
conv = &qwen25VLModel{}
case "Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration":
conv = &qwen3VLModel{}
case "BertModel":
conv = &bertModel{}
case "CohereForCausalLM":

157
convert/convert_qwen3.go Normal file
View File

@@ -0,0 +1,157 @@
package convert
import (
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
)
type qwen3Model struct {
ModelParameters
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
HiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
HeadDim uint32 `json:"head_dim"`
NumExperts uint32 `json:"num_experts"`
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
NormTopkProb bool `json:"norm_topk_prob"`
RopeTheta float32 `json:"rope_theta"`
RopeScaling struct {
Type string `json:"type"`
Factor ropeFactor `json:"factor"`
OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
MropeSection []int32 `json:"mrope_section"`
} `json:"rope_scaling"`
RMSNormEPS float32 `json:"rms_norm_eps"`
}
// KV implements ModelConverter.
func (q *qwen3Model) KV(t *Tokenizer) ggml.KV {
arch := "qwen3"
if q.NumExperts > 0 {
arch += "moe"
}
kv := q.ModelParameters.KV(t)
kv["general.architecture"] = arch
kv["block_count"] = q.HiddenLayers
kv["context_length"] = q.MaxPositionEmbeddings
kv["embedding_length"] = q.HiddenSize
kv["feed_forward_length"] = q.IntermediateSize
kv["attention.head_count"] = q.NumAttentionHeads
kv["attention.head_count_kv"] = q.NumKeyValueHeads
kv["attention.key_length"] = q.HeadDim
kv["attention.value_length"] = q.HeadDim
if q.NumExperts > 0 {
kv["expert_count"] = q.NumExperts
kv["expert_used_count"] = q.NumExpertsPerToken
kv["norm_top_k_prob"] = q.NormTopkProb
}
kv["rope.freq_base"] = q.RopeTheta
kv["attention.layer_norm_rms_epsilon"] = q.RMSNormEPS
switch q.RopeScaling.Type {
case "":
// no scaling
case "yarn":
kv["rope.scaling.type"] = q.RopeScaling.Type
kv["rope.scaling.factor"] = q.RopeScaling.Factor
case "mrope", "default":
kv["rope.mrope_section"] = q.RopeScaling.MropeSection
default:
panic("unknown rope scaling type")
}
return kv
}
// Tensors implements ModelConverter.
func (q *qwen3Model) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
// TODO: handle split experts
for _, t := range ts {
switch {
case strings.Contains(t.Name(), "ffn_gate_up_exps"):
afterFunc := func(t tensor.Tensor) (tensor.Tensor, error) { return tensor.Transpose(t, 0, 2, 1) }
for t := range splitDim(t, 2,
split{Replacer: strings.NewReplacer("gate_up", "gate"), afterFunc: afterFunc},
split{Replacer: strings.NewReplacer("gate_up", "up"), afterFunc: afterFunc},
) {
t.Shape[1], t.Shape[2] = t.Shape[2], t.Shape[1]
out = append(out, t)
}
case strings.Contains(t.Name(), "ffn_down_exps"):
shape := slices.Clone(t.Shape())
shape[1], shape[2] = shape[2], shape[1]
t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
dims := make([]int, len(shape))
for i := range shape {
dims[i] = int(shape[i])
}
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
tt, err := tensor.Transpose(tt, 0, 2, 1)
if err != nil {
return nil, err
}
// flatten tensor so it can be written as a vector
if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(tt.(*tensor.Dense))
})
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: shape,
WriterTo: t,
})
default:
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
}
return out
}
// Replacements implements ModelConverter.
func (q *qwen3Model) Replacements() []string {
return []string{
"lm_head", "output",
"model.embed_tokens", "token_embd",
"model.layers", "blk",
"input_layernorm", "attn_norm",
"self_attn.k_proj", "attn_k",
"self_attn.k_norm", "attn_k_norm",
"self_attn.v_proj", "attn_v",
"self_attn.q_proj", "attn_q",
"self_attn.q_norm", "attn_q_norm",
"self_attn.o_proj", "attn_output",
"mlp.down_proj", "ffn_down",
"mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up",
"mlp.gate.weight", "ffn_gate_inp.weight",
"mlp.experts.down_proj", "ffn_down_exps.weight",
"mlp.experts.gate_up_proj", "ffn_gate_up_exps.weight",
"post_attention_layernorm", "ffn_norm",
"model.norm", "output_norm",
}
}
var _ ModelConverter = (*qwen3Model)(nil)

116
convert/convert_qwen3vl.go Normal file
View File

@@ -0,0 +1,116 @@
package convert
import (
"cmp"
"encoding/json"
"io/fs"
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
)
type qwen3VLModel struct {
qwen3Model `json:"text_config"`
VisionModel struct {
Depth uint32 `json:"depth"`
HiddenSize uint32 `json:"hidden_size"`
NumHeads uint32 `json:"num_heads"`
InChannels uint32 `json:"in_channels"`
PatchSize uint32 `json:"patch_size"`
SpatialMergeSize uint32 `json:"spatial_merge_size"`
WindowSize uint32 `json:"window_size"`
RMSNormEps float32 `json:"layer_norm_epsilon"`
RopeTheta float32 `json:"rope_theta"`
TemporalPatchSize uint32 `json:"temporal_patch_size"`
DeepstackVisualIndexes []int32 `json:"deepstack_visual_indexes"`
Size struct {
ShortestEdge uint32 `json:"shortest_edge"`
LongestEdge uint32 `json:"longest_edge"`
} `json:"size"`
ImageMean []float32 `json:"image_mean"`
ImageStd []float32 `json:"image_std"`
} `json:"vision_config"`
}
func (m *qwen3VLModel) parseMore(fsys fs.FS) error {
bts, err := fs.ReadFile(fsys, "preprocessor_config.json")
if err != nil {
return err
}
return json.Unmarshal(bts, &m.VisionModel)
}
func (m *qwen3VLModel) KV(t *Tokenizer) ggml.KV {
kv := m.qwen3Model.KV(t)
arch := "qwen3vl"
if m.NumExperts > 0 {
arch += "moe"
}
// override architecture
kv["general.architecture"] = arch
kv["vision.block_count"] = cmp.Or(m.VisionModel.Depth, 32)
kv["vision.embedding_length"] = m.VisionModel.HiddenSize
kv["vision.attention.head_count"] = cmp.Or(m.VisionModel.NumHeads, 16)
kv["vision.num_channels"] = m.VisionModel.InChannels
kv["vision.patch_size"] = cmp.Or(m.VisionModel.PatchSize, 14)
kv["vision.spatial_merge_size"] = cmp.Or(m.VisionModel.SpatialMergeSize, 2)
kv["vision.attention.layer_norm_epsilon"] = cmp.Or(m.VisionModel.RMSNormEps, 1e-6)
kv["vision.rope.freq_base"] = cmp.Or(m.VisionModel.RopeTheta, 1e4)
kv["vision.temporal_patch_size"] = cmp.Or(m.VisionModel.TemporalPatchSize, 2)
kv["vision.deepstack_visual_indexes"] = m.VisionModel.DeepstackVisualIndexes
kv["vision.shortest_edge"] = m.VisionModel.Size.ShortestEdge
kv["vision.longest_edge"] = m.VisionModel.Size.LongestEdge
kv["vision.image_mean"] = m.VisionModel.ImageMean
kv["vision.image_std"] = m.VisionModel.ImageStd
return kv
}
func (m *qwen3VLModel) Tensors(ts []Tensor) []*ggml.Tensor {
var rest []Tensor
var out []*ggml.Tensor
for _, t := range ts {
switch {
case strings.Contains(t.Name(), "attn_qkv"):
out = append(out, slices.Collect(splitDim(t, 0,
split{Replacer: strings.NewReplacer("attn_qkv", "attn_q")},
split{Replacer: strings.NewReplacer("attn_qkv", "attn_k")},
split{Replacer: strings.NewReplacer("attn_qkv", "attn_v")},
))...)
case strings.Contains(t.Name(), "patch_embed") && strings.HasSuffix(t.Name(), "weight"):
shape := t.Shape()
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: append([]uint64{shape[0] * shape[1]}, shape[2:]...),
WriterTo: t,
})
default:
rest = append(rest, t)
}
}
return append(m.qwen3Model.Tensors(rest), out...)
}
func (m *qwen3VLModel) Replacements() []string {
return append(
m.qwen3Model.Replacements(),
"model.language_", "",
"model.visual", "v",
"patch_embed.proj", "patch_embed",
"blocks", "blk",
"attn.qkv", "attn_qkv",
"attn.proj", "attn_out",
"deepstack_merger_list", "deepstack_merger",
)
}

View File

@@ -19,8 +19,8 @@ type split struct {
dim int
slices []tensor.Slice
// fn is an optional function to apply to the tensor after slicing
fn func(tensor.Tensor) (tensor.Tensor, error)
// afterFunc is an optional function to apply to the tensor after slicing
afterFunc func(tensor.Tensor) (tensor.Tensor, error)
}
// splitDim splits a tensor along a specified dimension into multiple tensors. The dimension
@@ -54,8 +54,8 @@ func splitDim(t Tensor, dim int, splits ...split) iter.Seq[*ggml.Tensor] {
tt = tensor.Materialize(tt)
if split.fn != nil {
tt, err = split.fn(tt)
if split.afterFunc != nil {
tt, err = split.afterFunc(tt)
if err != nil {
return nil, err
}

View File

@@ -432,7 +432,7 @@ func TestSplitDim(t *testing.T) {
t.Run("split with transpose", func(t *testing.T) {
next, stop := iter.Pull(splitDim(&r, 1,
split{Replacer: strings.NewReplacer("a", "x")},
split{Replacer: strings.NewReplacer("b", "y"), fn: func(tt tensor.Tensor) (tensor.Tensor, error) {
split{Replacer: strings.NewReplacer("b", "y"), afterFunc: func(tt tensor.Tensor) (tensor.Tensor, error) {
return tensor.Transpose(tt, 1, 0)
}},
))